Skip to content

Commit

Permalink
Merge pull request #31 from ollama/keepalive
Browse files Browse the repository at this point in the history
add keep_alive
  • Loading branch information
mxyng authored Feb 2, 2024
2 parents 4a81fa4 + fbb6553 commit cdec2ad
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
26 changes: 24 additions & 2 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def generate(
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Create a response using the requested model.
Expand Down Expand Up @@ -121,6 +122,7 @@ def generate(
'images': [_encode_image(image) for image in images or []],
'format': format,
'options': options or {},
'keep_alive': keep_alive,
},
stream=stream,
)
Expand All @@ -132,6 +134,7 @@ def chat(
stream: bool = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
"""
Create a chat response using the requested model.
Expand Down Expand Up @@ -165,18 +168,26 @@ def chat(
'stream': stream,
'format': format,
'options': options or {},
'keep_alive': keep_alive,
},
stream=stream,
)

def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
def embeddings(
  self,
  model: str = '',
  prompt: str = '',
  options: Optional[Options] = None,
  keep_alive: Optional[Union[float, str]] = None,
) -> Sequence[float]:
  """
  Request an embedding for `prompt` from the requested model.

  `options` falls back to an empty mapping when not given; `keep_alive`
  is forwarded as-is (may be a duration in seconds or a duration string)
  so the server controls how long the model stays loaded.
  """
  # Assemble the request payload first, then POST it in one call.
  payload = {
    'model': model,
    'prompt': prompt,
    'options': options or {},
    'keep_alive': keep_alive,
  }
  response = self._request('POST', '/api/embeddings', json=payload)
  return response.json()

Expand Down Expand Up @@ -364,6 +375,7 @@ async def generate(
format: Literal['', 'json'] = '',
images: Optional[Sequence[AnyStr]] = None,
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Create a response using the requested model.
Expand Down Expand Up @@ -391,6 +403,7 @@ async def generate(
'images': [_encode_image(image) for image in images or []],
'format': format,
'options': options or {},
'keep_alive': keep_alive,
},
stream=stream,
)
Expand All @@ -402,6 +415,7 @@ async def chat(
stream: bool = False,
format: Literal['', 'json'] = '',
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
"""
Create a chat response using the requested model.
Expand Down Expand Up @@ -434,18 +448,26 @@ async def chat(
'stream': stream,
'format': format,
'options': options or {},
'keep_alive': keep_alive,
},
stream=stream,
)

async def embeddings(self, model: str = '', prompt: str = '', options: Optional[Options] = None) -> Sequence[float]:
async def embeddings(
self,
model: str = '',
prompt: str = '',
options: Optional[Options] = None,
keep_alive: Optional[Union[float, str]] = None,
) -> Sequence[float]:
response = await self._request(
'POST',
'/api/embeddings',
json={
'model': model,
'prompt': prompt,
'options': options or {},
'keep_alive': keep_alive,
},
)

Expand Down
12 changes: 12 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_client_chat(httpserver: HTTPServer):
'stream': False,
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_json(
{
Expand Down Expand Up @@ -75,6 +76,7 @@ def generate():
'stream': True,
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_handler(stream_handler)

Expand Down Expand Up @@ -103,6 +105,7 @@ def test_client_chat_images(httpserver: HTTPServer):
'stream': False,
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_json(
{
Expand Down Expand Up @@ -139,6 +142,7 @@ def test_client_generate(httpserver: HTTPServer):
'images': [],
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_json(
{
Expand Down Expand Up @@ -183,6 +187,7 @@ def generate():
'images': [],
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_handler(stream_handler)

Expand Down Expand Up @@ -210,6 +215,7 @@ def test_client_generate_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_json(
{
Expand Down Expand Up @@ -513,6 +519,7 @@ async def test_async_client_chat(httpserver: HTTPServer):
'stream': False,
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_json({})

Expand Down Expand Up @@ -550,6 +557,7 @@ def generate():
'stream': True,
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_handler(stream_handler)

Expand Down Expand Up @@ -579,6 +587,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
'stream': False,
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_json({})

Expand Down Expand Up @@ -606,6 +615,7 @@ async def test_async_client_generate(httpserver: HTTPServer):
'images': [],
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_json({})

Expand Down Expand Up @@ -645,6 +655,7 @@ def generate():
'images': [],
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_handler(stream_handler)

Expand Down Expand Up @@ -673,6 +684,7 @@ async def test_async_client_generate_images(httpserver: HTTPServer):
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'format': '',
'options': {},
'keep_alive': None,
},
).respond_with_json({})

Expand Down

0 comments on commit cdec2ad

Please sign in to comment.