Skip to content

Commit

Permalink
feat: Support stream_query in LangChain Agent Templates in the Python Reasoning Engine Client
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 706882105
  • Loading branch information
shawn-yang-google authored and copybara-github committed Dec 17, 2024
1 parent 25622f8 commit 99f613b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,16 @@ def test_query(self, langchain_dump_mock):
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
)

def test_stream_query(self, langchain_dump_mock):
    """Verify stream_query forwards the wrapped input and config to the runnable's .stream()."""
    agent = reasoning_engines.LangchainAgent(model=_TEST_MODEL)
    runnable_mock = mock.Mock()
    runnable_mock.stream.return_value = []
    agent._runnable = runnable_mock
    # stream_query is a generator; drain it so the .stream() call actually executes.
    for _ in agent.stream_query(input="test stream query"):
        pass
    runnable_mock.stream.assert_called_once_with(
        input={"input": "test stream query"},
        config=None,
    )

@pytest.mark.usefixtures("caplog")
def test_enable_tracing(
self,
Expand Down
31 changes: 31 additions & 0 deletions vertexai/preview/reasoning_engines/templates/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -609,3 +610,33 @@ def query(
return langchain_load_dump.dumpd(
self._runnable.invoke(input=input, config=config, **kwargs)
)

def stream_query(
    self,
    *,
    input: Union[str, Mapping[str, Any]],
    config: Optional["RunnableConfig"] = None,
    **kwargs,
) -> Iterable[Any]:
    """Stream queries the Agent with the given input and config.

    Args:
        input (Union[str, Mapping[str, Any]]):
            Required. The input to be passed to the Agent. A bare string
            is wrapped as ``{"input": <string>}`` before dispatch.
        config (langchain_core.runnables.RunnableConfig):
            Optional. The config (if any) to be used for invoking the Agent.
        **kwargs:
            Optional. Any additional keyword arguments to be passed to the
            `.stream()` method of the corresponding AgentExecutor.

    Yields:
        The output of querying the Agent with the given input and config,
        serialized via ``langchain.load.dump.dumpd``.
    """
    # Imported lazily so the template module can be loaded without
    # langchain installed (mirrors the pattern used by `query`).
    from langchain.load import dump as langchain_load_dump

    if isinstance(input, str):
        input = {"input": input}
    # Lazily initialize the runnable on first use, consistent with `query`.
    if not self._runnable:
        self.set_up()
    for chunk in self._runnable.stream(input=input, config=config, **kwargs):
        yield langchain_load_dump.dumpd(chunk)

0 comments on commit 99f613b

Please sign in to comment.