From fbb6553e03304b5b83029be4e96e0102a30bdb12 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 26 Jan 2024 11:19:52 -0800
Subject: [PATCH] add keep_alive

---
 ollama/_client.py    | 26 ++++++++++++++++++++++++--
 tests/test_client.py | 12 ++++++++++++
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/ollama/_client.py b/ollama/_client.py
index 5908dfc..22f1565 100644
--- a/ollama/_client.py
+++ b/ollama/_client.py
@@ -92,6 +92,7 @@ def generate(
     format: Literal['', 'json'] = '',
     images: Optional[Sequence[AnyStr]] = None,
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
   ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
     """
     Create a response using the requested model.
@@ -120,6 +121,7 @@
         'images': [_encode_image(image) for image in images or []],
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
@@ -131,6 +133,7 @@ def chat(
     stream: bool = False,
     format: Literal['', 'json'] = '',
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
   ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
     """
     Create a chat response using the requested model.
@@ -164,11 +167,18 @@ def chat(
         'stream': stream,
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
 
-  def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
+  def embeddings(
+    self,
+    model: str = '',
+    prompt: str = '',
+    options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
+  ) -> Sequence[float]:
     return self._request(
       'POST',
       '/api/embeddings',
@@ -176,6 +186,7 @@ def embeddings(self, model: str = '', prompt: str = '', options: Optional[Option
         'model': model,
         'prompt': prompt,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
     ).json()
 
@@ -360,6 +371,7 @@ async def generate(
     format: Literal['', 'json'] = '',
     images: Optional[Sequence[AnyStr]] = None,
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
   ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
     """
     Create a response using the requested model.
@@ -387,6 +399,7 @@
         'images': [_encode_image(image) for image in images or []],
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
@@ -398,6 +411,7 @@ async def chat(
     stream: bool = False,
     format: Literal['', 'json'] = '',
     options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
  ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
     """
     Create a chat response using the requested model.
@@ -430,11 +444,18 @@ async def chat(
         'stream': stream,
         'format': format,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
       stream=stream,
     )
 
-  async def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
+  async def embeddings(
+    self,
+    model: str = '',
+    prompt: str = '',
+    options: Optional[Options] = None,
+    keep_alive: Optional[Union[float, str]] = None,
+  ) -> Sequence[float]:
     response = await self._request(
       'POST',
       '/api/embeddings',
@@ -442,6 +463,7 @@ async def embeddings(self, model: str = '', prompt: str = '', options: Optional[
         'model': model,
         'prompt': prompt,
         'options': options or {},
+        'keep_alive': keep_alive,
       },
     )
 
diff --git a/tests/test_client.py b/tests/test_client.py
index 6afbe70..859db7a 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -29,6 +29,7 @@ def test_client_chat(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -75,6 +76,7 @@ def generate():
       'stream': True,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -103,6 +105,7 @@ def test_client_chat_images(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -139,6 +142,7 @@ def test_client_generate(httpserver: HTTPServer):
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -183,6 +187,7 @@ def generate():
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -210,6 +215,7 @@ def test_client_generate_images(httpserver: HTTPServer):
       'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json(
     {
@@ -465,6 +471,7 @@ async def test_async_client_chat(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
 
@@ -502,6 +509,7 @@ def generate():
       'stream': True,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -531,6 +539,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
       'stream': False,
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
 
@@ -558,6 +567,7 @@ async def test_async_client_generate(httpserver: HTTPServer):
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
 
@@ -597,6 +607,7 @@ def generate():
       'images': [],
       'format': '',
       'options': {},
+      'keep_alive': None,
     },
   ).respond_with_handler(stream_handler)
 
@@ -625,6 +636,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
       'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
       'format': '',
      'options': {},
+      'keep_alive': None,
     },
   ).respond_with_json({})
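
Usage note (not part of the patch): a minimal sketch of calling the new parameter,
assuming the ollama.Client and generate names shown in this diff. The model name,
prompt, and '5m' value are hypothetical examples, and the interpretation of
keep_alive (a float presumably meaning seconds, or a duration string such as '5m')
is assumed from the Ollama server API rather than stated in this diff.

    import ollama

    client = ollama.Client()

    # keep_alive is Optional[Union[float, str]] per the annotation above;
    # None (the default) leaves the server's unload behavior unchanged.
    response = client.generate(
      model='llama2',  # hypothetical model name
      prompt='Why is the sky blue?',
      keep_alive='5m',  # assumed duration-string form
    )
    print(response['response'])

The same keyword applies to chat and embeddings, which this patch extends identically.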