From 4e1c9a2e81f9b39bc77baab0aec43679949ab25c Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 13 Apr 2025 13:14:22 -0700 Subject: [PATCH 1/3] Allow specification of maximum event contents in prompt to help mitigate error compounding --- src/google/adk/agents/llm_agent.py | 8 ++++++++ src/google/adk/flows/llm_flows/contents.py | 8 +++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index a140997228..c6c427b275 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -133,6 +133,14 @@ class LlmAgent(BaseAgent): user messages, tool results, etc. """ + max_contents: Optional[int] = None + """The maximum number of contents to include in the model request. + Hallucinations can lead to compounding error if allowed to persist in the + system prompt indefinitely. Recommend setting this large enough to allow + relevant content to stay in memory but short enough that older irrelevant + content can be forgotten. + """ + # Controlled input/output configurations - Start input_schema: Optional[type[BaseModel]] = None """The input schema when agent is used as a tool.""" diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index f2847554c8..3faacbb5e5 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -48,6 +48,7 @@ async def run_async( invocation_context.branch, invocation_context.session.events, agent.name, + agent.max_contents, ) # Maintain async generator behavior @@ -186,7 +187,10 @@ def _rearrange_events_for_latest_function_response( def _get_contents( - current_branch: Optional[str], events: list[Event], agent_name: str = '' + current_branch: Optional[str], + events: list[Event], + agent_name: str = '', + max_contents: Optional[int] = None ) -> list[types.Content]: """Get the contents for the LLM request. @@ -224,6 +228,8 @@ def _get_contents( result_events = _rearrange_events_for_async_function_responses_in_history( result_events ) + if max_contents: + result_events = result_events[-max_contents:] contents = [] for event in result_events: content = copy.deepcopy(event.content) From 798aab044adb4fafa55f1a386d869f17c49076e1 Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 13 Apr 2025 13:46:49 -0700 Subject: [PATCH 2/3] Allow malformed function calls to be fixed on the spot --- src/google/adk/models/google_llm.py | 42 +++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 29988dfc91..c2b8beb8ec 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -40,6 +40,46 @@ _EXCLUDED_PART_FIELD = {'inline_data': {'data'}} +def _extract_function_call(input_string): + """ Use regex to convert a function call string into a FunctionCall object. """ + pattern = r"Malformed function call:\s*(\w+)\((.*)\)" + match = re.search(pattern, input_string) + if match: + func_name = match.group(1) + args_str = match.group(2).strip() + + # Create a dummy function call with the captured arguments. + # This will allow us to use ast to parse the function call. + dummy_call = f"dummy({args_str})" + + # Parse the dummy function call. + tree = ast.parse(dummy_call, mode='eval') + call_node = tree.body + + # Extract keyword arguments (if there are any). + args_dict = {kw.arg: ast.literal_eval(kw.value) for kw in call_node.keywords} + + return types.FunctionCall(args=args_dict, name=func_name) + else: + return None + + +def _fix_malformed_function_calls(response): + """ Check if there's malformed error, create FunctionCall object using args. + Then remove the error and insert the FunctionCall into the response. + """ + for candidate in response.candidates: + if candidate.finish_reason == types.FinishReason.MALFORMED_FUNCTION_CALL: + function_call = _extract_function_call(candidate.finish_message) + if function_call is None: + logging.warning("could not parse function call: %s", candidate.finish_message) + continue + logging.warning("malformed function call caught and overwritten: %s", candidate.finish_message) + candidate.content = types.Content(parts=[types.Part(function_call=function_call)], role="model") + candidate.finish_message = None + candidate.finish_reason = types.FinishReason.STOP + + class Gemini(BaseLlm): """Integration for Gemini models. @@ -102,6 +142,7 @@ async def generate_content_async( # previous partial content. The only difference is bidi rely on # complete_turn flag to detect end while sse depends on finish_reason. async for response in responses: + _fix_malformed_function_calls(response) logger.info(_build_response_log(response)) llm_response = LlmResponse.create(response) if ( @@ -142,6 +183,7 @@ async def generate_content_async( contents=llm_request.contents, config=llm_request.config, ) + _fix_malformed_function_calls(response) logger.info(_build_response_log(response)) yield LlmResponse.create(response) From ff1d4eb42fd35bbf33349fa60c1e09eba664b352 Mon Sep 17 00:00:00 2001 From: Alex Braylan Date: Thu, 22 May 2025 19:19:13 -0500 Subject: [PATCH 3/3] fix missing imports --- src/google/adk/models/google_llm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index c2b8beb8ec..3eefd5beca 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -13,6 +13,8 @@ # limitations under the License. from __future__ import annotations +import re +import ast import contextlib from functools import cached_property import logging