From 99f613b1c029f4ba7379fdf7b1c5ea653f5021b0 Mon Sep 17 00:00:00 2001 From: Shawn Yang Date: Mon, 16 Dec 2024 17:04:44 -0800 Subject: [PATCH] feat: Support `stream_query` in LangChain Agent Templates in the Python Reasoning Engine Client PiperOrigin-RevId: 706882105 --- ...st_reasoning_engine_templates_langchain.py | 10 ++++++ .../reasoning_engines/templates/langchain.py | 31 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py b/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py index 307ecd045f..f88581c51d 100644 --- a/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py +++ b/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py @@ -221,6 +221,16 @@ def test_query(self, langchain_dump_mock): [mock.call.invoke.invoke(input={"input": "test query"}, config=None)] ) + def test_stream_query(self, langchain_dump_mock): + agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL) + agent._runnable = mock.Mock() + agent._runnable.stream.return_value = [] + list(agent.stream_query(input="test stream query")) + agent._runnable.stream.assert_called_once_with( + input={"input": "test stream query"}, + config=None, + ) + @pytest.mark.usefixtures("caplog") def test_enable_tracing( self, diff --git a/vertexai/preview/reasoning_engines/templates/langchain.py b/vertexai/preview/reasoning_engines/templates/langchain.py index 0282b98047..b037fa40f1 100644 --- a/vertexai/preview/reasoning_engines/templates/langchain.py +++ b/vertexai/preview/reasoning_engines/templates/langchain.py @@ -18,6 +18,7 @@ Any, Callable, Dict, + Iterable, Mapping, Optional, Sequence, @@ -609,3 +610,33 @@ def query( return langchain_load_dump.dumpd( self._runnable.invoke(input=input, config=config, **kwargs) ) + + def stream_query( + self, + *, + input: Union[str, Mapping[str, Any]], + config: Optional["RunnableConfig"] = None, + **kwargs, + ) -> Iterable[Any]: + """Stream queries the Agent with the given input and config. + + Args: + input (Union[str, Mapping[str, Any]]): + Required. The input to be passed to the Agent. + config (langchain_core.runnables.RunnableConfig): + Optional. The config (if any) to be used for invoking the Agent. + **kwargs: + Optional. Any additional keyword arguments to be passed to the + `.invoke()` method of the corresponding AgentExecutor. + + Yields: + The output of querying the Agent with the given input and config. + """ + from langchain.load import dump as langchain_load_dump + + if isinstance(input, str): + input = {"input": input} + if not self._runnable: + self.set_up() + for chunk in self._runnable.stream(input=input, config=config, **kwargs): + yield langchain_load_dump.dumpd(chunk)