diff --git a/ollama/_client.py b/ollama/_client.py
index 4b62765..079eda7 100644
--- a/ollama/_client.py
+++ b/ollama/_client.py
@@ -1,5 +1,4 @@
import os
-import io
import json
import platform
import ipaddress
@@ -19,6 +18,8 @@
TypeVar,
Union,
overload,
+ Dict,
+ List,
)
import sys
@@ -62,7 +63,6 @@
ProgressResponse,
PullRequest,
PushRequest,
- RequestError,
ResponseError,
ShowRequest,
ShowResponse,
@@ -476,10 +476,16 @@ def push(
def create(
self,
model: str,
- path: Optional[Union[str, PathLike]] = None,
- modelfile: Optional[str] = None,
- *,
quantize: Optional[str] = None,
+ from_: Optional[str] = None,
+ files: Optional[Dict[str, str]] = None,
+ adapters: Optional[Dict[str, str]] = None,
+ template: Optional[str] = None,
+ license: Optional[Union[str, List[str]]] = None,
+ system: Optional[str] = None,
+ parameters: Optional[Union[Mapping[str, Any], Options]] = None,
+ messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
+ *,
stream: Literal[False] = False,
) -> ProgressResponse: ...
@@ -487,20 +493,32 @@ def create(
def create(
self,
model: str,
- path: Optional[Union[str, PathLike]] = None,
- modelfile: Optional[str] = None,
- *,
quantize: Optional[str] = None,
+ from_: Optional[str] = None,
+ files: Optional[Dict[str, str]] = None,
+ adapters: Optional[Dict[str, str]] = None,
+ template: Optional[str] = None,
+ license: Optional[Union[str, List[str]]] = None,
+ system: Optional[str] = None,
+ parameters: Optional[Union[Mapping[str, Any], Options]] = None,
+ messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
+ *,
stream: Literal[True] = True,
) -> Iterator[ProgressResponse]: ...
def create(
self,
model: str,
- path: Optional[Union[str, PathLike]] = None,
- modelfile: Optional[str] = None,
- *,
quantize: Optional[str] = None,
+ from_: Optional[str] = None,
+ files: Optional[Dict[str, str]] = None,
+ adapters: Optional[Dict[str, str]] = None,
+ template: Optional[str] = None,
+ license: Optional[Union[str, List[str]]] = None,
+ system: Optional[str] = None,
+ parameters: Optional[Union[Mapping[str, Any], Options]] = None,
+ messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
+ *,
stream: bool = False,
) -> Union[ProgressResponse, Iterator[ProgressResponse]]:
"""
@@ -508,45 +526,27 @@ def create(
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
- if (realpath := _as_path(path)) and realpath.exists():
- modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
- elif modelfile:
- modelfile = self._parse_modelfile(modelfile)
- else:
- raise RequestError('must provide either path or modelfile')
-
return self._request(
ProgressResponse,
'POST',
'/api/create',
json=CreateRequest(
model=model,
- modelfile=modelfile,
stream=stream,
quantize=quantize,
+ from_=from_,
+ files=files,
+ adapters=adapters,
+ license=license,
+ template=template,
+ system=system,
+ parameters=parameters,
+ messages=messages,
).model_dump(exclude_none=True),
stream=stream,
)
- def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
- base = Path.cwd() if base is None else base
-
- out = io.StringIO()
- for line in io.StringIO(modelfile):
- command, _, args = line.partition(' ')
- if command.upper() not in ['FROM', 'ADAPTER']:
- print(line, end='', file=out)
- continue
-
- path = Path(args.strip()).expanduser()
- path = path if path.is_absolute() else base / path
- if path.exists():
- args = f'@{self._create_blob(path)}\n'
- print(command, args, end='', file=out)
-
- return out.getvalue()
-
- def _create_blob(self, path: Union[str, Path]) -> str:
+ def create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
@@ -978,31 +978,49 @@ async def push(
async def create(
self,
model: str,
- path: Optional[Union[str, PathLike]] = None,
- modelfile: Optional[str] = None,
- *,
quantize: Optional[str] = None,
- stream: Literal[False] = False,
+ from_: Optional[str] = None,
+ files: Optional[Dict[str, str]] = None,
+ adapters: Optional[Dict[str, str]] = None,
+ template: Optional[str] = None,
+ license: Optional[Union[str, List[str]]] = None,
+ system: Optional[str] = None,
+ parameters: Optional[Union[Mapping[str, Any], Options]] = None,
+ messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
+ *,
+ stream: Literal[True] = True,
) -> ProgressResponse: ...
@overload
async def create(
self,
model: str,
- path: Optional[Union[str, PathLike]] = None,
- modelfile: Optional[str] = None,
- *,
quantize: Optional[str] = None,
+ from_: Optional[str] = None,
+ files: Optional[Dict[str, str]] = None,
+ adapters: Optional[Dict[str, str]] = None,
+ template: Optional[str] = None,
+ license: Optional[Union[str, List[str]]] = None,
+ system: Optional[str] = None,
+ parameters: Optional[Union[Mapping[str, Any], Options]] = None,
+ messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
+ *,
stream: Literal[True] = True,
) -> AsyncIterator[ProgressResponse]: ...
async def create(
self,
model: str,
- path: Optional[Union[str, PathLike]] = None,
- modelfile: Optional[str] = None,
- *,
quantize: Optional[str] = None,
+ from_: Optional[str] = None,
+ files: Optional[Dict[str, str]] = None,
+ adapters: Optional[Dict[str, str]] = None,
+ template: Optional[str] = None,
+ license: Optional[Union[str, List[str]]] = None,
+ system: Optional[str] = None,
+ parameters: Optional[Union[Mapping[str, Any], Options]] = None,
+ messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
+ *,
stream: bool = False,
) -> Union[ProgressResponse, AsyncIterator[ProgressResponse]]:
"""
@@ -1010,12 +1028,6 @@ async def create(
Returns `ProgressResponse` if `stream` is `False`, otherwise returns a `ProgressResponse` generator.
"""
- if (realpath := _as_path(path)) and realpath.exists():
- modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
- elif modelfile:
- modelfile = await self._parse_modelfile(modelfile)
- else:
- raise RequestError('must provide either path or modelfile')
return await self._request(
ProgressResponse,
@@ -1023,32 +1035,21 @@ async def create(
'/api/create',
json=CreateRequest(
model=model,
- modelfile=modelfile,
stream=stream,
quantize=quantize,
+ from_=from_,
+ files=files,
+ adapters=adapters,
+ license=license,
+ template=template,
+ system=system,
+ parameters=parameters,
+ messages=messages,
).model_dump(exclude_none=True),
stream=stream,
)
- async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
- base = Path.cwd() if base is None else base
-
- out = io.StringIO()
- for line in io.StringIO(modelfile):
- command, _, args = line.partition(' ')
- if command.upper() not in ['FROM', 'ADAPTER']:
- print(line, end='', file=out)
- continue
-
- path = Path(args.strip()).expanduser()
- path = path if path.is_absolute() else base / path
- if path.exists():
- args = f'@{await self._create_blob(path)}\n'
- print(command, args, end='', file=out)
-
- return out.getvalue()
-
- async def _create_blob(self, path: Union[str, Path]) -> str:
+ async def create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
diff --git a/ollama/_types.py b/ollama/_types.py
index 3be80a7..995db14 100644
--- a/ollama/_types.py
+++ b/ollama/_types.py
@@ -2,7 +2,7 @@
from base64 import b64decode, b64encode
from pathlib import Path
from datetime import datetime
-from typing import Any, Mapping, Optional, Union, Sequence
+from typing import Any, Mapping, Optional, Union, Sequence, Dict, List
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import Annotated, Literal
@@ -401,13 +401,25 @@ class PushRequest(BaseStreamableRequest):
class CreateRequest(BaseStreamableRequest):
+ @model_serializer(mode='wrap')
+ def serialize_model(self, nxt):
+ output = nxt(self)
+ if 'from_' in output:
+ output['from'] = output.pop('from_')
+ return output
+
"""
Request to create a new model.
"""
-
- modelfile: Optional[str] = None
-
quantize: Optional[str] = None
+ from_: Optional[str] = None
+ files: Optional[Dict[str, str]] = None
+ adapters: Optional[Dict[str, str]] = None
+ template: Optional[str] = None
+ license: Optional[Union[str, List[str]]] = None
+ system: Optional[str] = None
+ parameters: Optional[Union[Mapping[str, Any], Options]] = None
+ messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None
class ModelDetails(SubscriptableBaseModel):
diff --git a/poetry.lock b/poetry.lock
index 2697430..732e589 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand.
[[package]]
name = "annotated-types"
@@ -6,6 +6,7 @@ version = "0.7.0"
description = "Reusable constraint types to use with typing.Annotated"
optional = false
python-versions = ">=3.8"
+groups = ["main"]
files = [
{file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"},
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
@@ -20,6 +21,7 @@ version = "4.5.2"
description = "High level compatibility layer for multiple asynchronous event loop implementations"
optional = false
python-versions = ">=3.8"
+groups = ["main"]
files = [
{file = "anyio-4.5.2-py3-none-any.whl", hash = "sha256:c011ee36bc1e8ba40e5a81cb9df91925c218fe9b778554e0b56a21e1b5d4716f"},
{file = "anyio-4.5.2.tar.gz", hash = "sha256:23009af4ed04ce05991845451e11ef02fc7c5ed29179ac9a420e5ad0ac7ddc5b"},
@@ -42,6 +44,7 @@ version = "2024.8.30"
description = "Python package for providing Mozilla's CA Bundle."
optional = false
python-versions = ">=3.6"
+groups = ["main"]
files = [
{file = "certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8"},
{file = "certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9"},
@@ -53,6 +56,8 @@ version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
+groups = ["dev"]
+markers = "sys_platform == \"win32\""
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
@@ -64,6 +69,7 @@ version = "7.6.1"
description = "Code coverage measurement for Python"
optional = false
python-versions = ">=3.8"
+groups = ["dev"]
files = [
{file = "coverage-7.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b06079abebbc0e89e6163b8e8f0e16270124c154dc6e4a47b413dd538859af16"},
{file = "coverage-7.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b19715bccd7ee27b6b120e7e9dd56037b9c0681dcc1adc9ba9db3d417fa36"},
@@ -151,6 +157,8 @@ version = "1.2.2"
description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
+groups = ["main", "dev"]
+markers = "python_version < \"3.11\""
files = [
{file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"},
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
@@ -165,6 +173,7 @@ version = "0.14.0"
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
optional = false
python-versions = ">=3.7"
+groups = ["main"]
files = [
{file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
{file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
@@ -176,6 +185,7 @@ version = "1.0.6"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
+groups = ["main"]
files = [
{file = "httpcore-1.0.6-py3-none-any.whl", hash = "sha256:27b59625743b85577a8c0e10e55b50b5368a4f2cfe8cc7bcfa9cf00829c2682f"},
{file = "httpcore-1.0.6.tar.gz", hash = "sha256:73f6dbd6eb8c21bbf7ef8efad555481853f5f6acdeaff1edb0694289269ee17f"},
@@ -197,6 +207,7 @@ version = "0.27.2"
description = "The next generation HTTP client."
optional = false
python-versions = ">=3.8"
+groups = ["main"]
files = [
{file = "httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0"},
{file = "httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2"},
@@ -222,6 +233,7 @@ version = "3.10"
description = "Internationalized Domain Names in Applications (IDNA)"
optional = false
python-versions = ">=3.6"
+groups = ["main"]
files = [
{file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"},
{file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"},
@@ -236,6 +248,7 @@ version = "2.0.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.7"
+groups = ["dev"]
files = [
{file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"},
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
@@ -247,6 +260,7 @@ version = "2.1.5"
description = "Safely add untrusted strings to HTML/XML markup."
optional = false
python-versions = ">=3.7"
+groups = ["dev"]
files = [
{file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"},
{file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"},
@@ -316,6 +330,7 @@ version = "24.1"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
+groups = ["dev"]
files = [
{file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"},
{file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"},
@@ -327,6 +342,7 @@ version = "1.5.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.8"
+groups = ["dev"]
files = [
{file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"},
{file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"},
@@ -342,6 +358,7 @@ version = "2.9.2"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.8"
+groups = ["main"]
files = [
{file = "pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12"},
{file = "pydantic-2.9.2.tar.gz", hash = "sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f"},
@@ -365,6 +382,7 @@ version = "2.23.4"
description = "Core functionality for Pydantic validation and serialization"
optional = false
python-versions = ">=3.8"
+groups = ["main"]
files = [
{file = "pydantic_core-2.23.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b"},
{file = "pydantic_core-2.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166"},
@@ -466,6 +484,7 @@ version = "8.3.4"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.8"
+groups = ["dev"]
files = [
{file = "pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6"},
{file = "pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761"},
@@ -488,6 +507,7 @@ version = "0.24.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.8"
+groups = ["dev"]
files = [
{file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"},
{file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"},
@@ -506,6 +526,7 @@ version = "5.0.0"
description = "Pytest plugin for measuring coverage."
optional = false
python-versions = ">=3.8"
+groups = ["dev"]
files = [
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
@@ -524,6 +545,7 @@ version = "1.1.0"
description = "pytest-httpserver is a httpserver for pytest"
optional = false
python-versions = ">=3.8"
+groups = ["dev"]
files = [
{file = "pytest_httpserver-1.1.0-py3-none-any.whl", hash = "sha256:7ef88be8ed3354b6784daa3daa75a422370327c634053cefb124903fa8d73a41"},
{file = "pytest_httpserver-1.1.0.tar.gz", hash = "sha256:6b1cb0199e2ed551b1b94d43f096863bbf6ae5bcd7c75c2c06845e5ce2dc8701"},
@@ -538,6 +560,7 @@ version = "0.7.4"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
+groups = ["dev"]
files = [
{file = "ruff-0.7.4-py3-none-linux_armv6l.whl", hash = "sha256:a4919925e7684a3f18e18243cd6bea7cfb8e968a6eaa8437971f681b7ec51478"},
{file = "ruff-0.7.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfb365c135b830778dda8c04fb7d4280ed0b984e1aec27f574445231e20d6c63"},
@@ -565,6 +588,7 @@ version = "1.3.1"
description = "Sniff out which async library your code is running under"
optional = false
python-versions = ">=3.7"
+groups = ["main"]
files = [
{file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
@@ -576,6 +600,8 @@ version = "2.0.2"
description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
+groups = ["dev"]
+markers = "python_full_version <= \"3.11.0a6\""
files = [
{file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"},
{file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"},
@@ -587,6 +613,7 @@ version = "4.12.2"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
+groups = ["main"]
files = [
{file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
@@ -598,6 +625,7 @@ version = "3.0.6"
description = "The comprehensive WSGI web application library."
optional = false
python-versions = ">=3.8"
+groups = ["dev"]
files = [
{file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"},
{file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"},
@@ -610,6 +638,6 @@ MarkupSafe = ">=2.1.1"
watchdog = ["watchdog (>=2.3)"]
[metadata]
-lock-version = "2.0"
+lock-version = "2.1"
python-versions = "^3.8"
content-hash = "8e93767305535b0a02f0d724edf1249fd928ff1021644eb9dc26dbfa191f6971"
diff --git a/pyproject.toml b/pyproject.toml
index 3a4e14e..2735ac9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,9 @@ python = "^3.8"
httpx = "^0.27.0"
pydantic = "^2.9.0"
+[tool.poetry.requires-plugins]
+poetry-plugin-export = ">=1.8"
+
[tool.poetry.group.dev.dependencies]
pytest = ">=7.4.3,<9.0.0"
pytest-asyncio = ">=0.23.2,<0.25.0"
diff --git a/tests/test_client.py b/tests/test_client.py
index d837a1a..eb18a19 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -536,52 +536,6 @@ def generate():
assert part['status'] == next(it)
-def test_client_create_path(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
- httpserver.expect_ordered_request(
- '/api/create',
- method='POST',
- json={
- 'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
- 'stream': False,
- },
- ).respond_with_json({'status': 'success'})
-
- client = Client(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as modelfile:
- with tempfile.NamedTemporaryFile() as blob:
- modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
- modelfile.flush()
-
- response = client.create('dummy', path=modelfile.name)
- assert response['status'] == 'success'
-
-
-def test_client_create_path_relative(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
- httpserver.expect_ordered_request(
- '/api/create',
- method='POST',
- json={
- 'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
- 'stream': False,
- },
- ).respond_with_json({'status': 'success'})
-
- client = Client(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as modelfile:
- with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
- modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
- modelfile.flush()
-
- response = client.create('dummy', path=modelfile.name)
- assert response['status'] == 'success'
-
-
@pytest.fixture
def userhomedir():
with tempfile.TemporaryDirectory() as temp:
@@ -591,92 +545,56 @@ def userhomedir():
os.environ['HOME'] = home
-def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
+def test_client_create_with_blob(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
+ 'files': {'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
- with tempfile.NamedTemporaryFile() as modelfile:
- with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
- modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
- modelfile.flush()
-
- response = client.create('dummy', path=modelfile.name)
- assert response['status'] == 'success'
-
-
-def test_client_create_modelfile(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
- httpserver.expect_ordered_request(
- '/api/create',
- method='POST',
- json={
- 'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
- 'stream': False,
- },
- ).respond_with_json({'status': 'success'})
-
- client = Client(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as blob:
- response = client.create('dummy', modelfile=f'FROM {blob.name}')
+ with tempfile.NamedTemporaryFile():
+ response = client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
assert response['status'] == 'success'
-def test_client_create_modelfile_roundtrip(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
+def test_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
- 'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
-TEMPLATE """[INST] <>{{.System}}<>
-{{.Prompt}} [/INST]"""
-SYSTEM """
-Use
-multiline
-strings.
-"""
-PARAMETER stop [INST]
-PARAMETER stop [/INST]
-PARAMETER stop <>
-PARAMETER stop <>''',
+ 'quantize': 'q4_k_m',
+ 'from': 'mymodel',
+ 'adapters': {'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
+ 'template': '[INST] <>{{.System}}<>\n{{.Prompt}} [/INST]',
+ 'license': 'this is my license',
+ 'system': '\nUse\nmultiline\nstrings.\n',
+ 'parameters': {'stop': ['[INST]', '[/INST]', '<>', '<>'], 'pi': 3.14159},
+ 'messages': [{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
- with tempfile.NamedTemporaryFile() as blob:
+ with tempfile.NamedTemporaryFile():
response = client.create(
'dummy',
- modelfile='\n'.join(
- [
- f'FROM {blob.name}',
- 'TEMPLATE """[INST] <>{{.System}}<>',
- '{{.Prompt}} [/INST]"""',
- 'SYSTEM """',
- 'Use',
- 'multiline',
- 'strings.',
- '"""',
- 'PARAMETER stop [INST]',
- 'PARAMETER stop [/INST]',
- 'PARAMETER stop <>',
- 'PARAMETER stop <>',
- ]
- ),
+ quantize='q4_k_m',
+ from_='mymodel',
+ adapters={'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
+ template='[INST] <>{{.System}}<>\n{{.Prompt}} [/INST]',
+ license='this is my license',
+ system='\nUse\nmultiline\nstrings.\n',
+ parameters={'stop': ['[INST]', '[/INST]', '<>', '<>'], 'pi': 3.14159},
+ messages=[{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
+ stream=False,
)
assert response['status'] == 'success'
@@ -687,14 +605,14 @@ def test_client_create_from_library(httpserver: HTTPServer):
method='POST',
json={
'model': 'dummy',
- 'modelfile': 'FROM llama2',
+ 'from': 'llama2',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = Client(httpserver.url_for('/'))
- response = client.create('dummy', modelfile='FROM llama2')
+ response = client.create('dummy', from_='llama2')
assert response['status'] == 'success'
@@ -704,7 +622,7 @@ def test_client_create_blob(httpserver: HTTPServer):
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
- response = client._create_blob(blob.name)
+ response = client.create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
@@ -714,7 +632,7 @@ def test_client_create_blob_exists(httpserver: HTTPServer):
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
- response = client._create_blob(blob.name)
+ response = client.create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
@@ -1015,142 +933,57 @@ def generate():
@pytest.mark.asyncio
-async def test_async_client_create_path(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
+async def test_async_client_create_with_blob(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
+ 'files': {'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
'stream': False,
},
).respond_with_json({'status': 'success'})
client = AsyncClient(httpserver.url_for('/'))
- with tempfile.NamedTemporaryFile() as modelfile:
- with tempfile.NamedTemporaryFile() as blob:
- modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
- modelfile.flush()
-
- response = await client.create('dummy', path=modelfile.name)
- assert response['status'] == 'success'
-
-
-@pytest.mark.asyncio
-async def test_async_client_create_path_relative(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
- httpserver.expect_ordered_request(
- '/api/create',
- method='POST',
- json={
- 'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
- 'stream': False,
- },
- ).respond_with_json({'status': 'success'})
-
- client = AsyncClient(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as modelfile:
- with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
- modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
- modelfile.flush()
-
- response = await client.create('dummy', path=modelfile.name)
- assert response['status'] == 'success'
-
-
-@pytest.mark.asyncio
-async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
- httpserver.expect_ordered_request(
- '/api/create',
- method='POST',
- json={
- 'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
- 'stream': False,
- },
- ).respond_with_json({'status': 'success'})
-
- client = AsyncClient(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as modelfile:
- with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
- modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
- modelfile.flush()
-
- response = await client.create('dummy', path=modelfile.name)
- assert response['status'] == 'success'
-
-
-@pytest.mark.asyncio
-async def test_async_client_create_modelfile(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
- httpserver.expect_ordered_request(
- '/api/create',
- method='POST',
- json={
- 'model': 'dummy',
- 'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
- 'stream': False,
- },
- ).respond_with_json({'status': 'success'})
-
- client = AsyncClient(httpserver.url_for('/'))
-
- with tempfile.NamedTemporaryFile() as blob:
- response = await client.create('dummy', modelfile=f'FROM {blob.name}')
+ with tempfile.NamedTemporaryFile():
+ response = await client.create('dummy', files={'test.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'})
assert response['status'] == 'success'
@pytest.mark.asyncio
-async def test_async_client_create_modelfile_roundtrip(httpserver: HTTPServer):
- httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='POST').respond_with_response(Response(status=200))
+async def test_async_client_create_with_parameters_roundtrip(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
- 'modelfile': '''FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
-TEMPLATE """[INST] <>{{.System}}<>
-{{.Prompt}} [/INST]"""
-SYSTEM """
-Use
-multiline
-strings.
-"""
-PARAMETER stop [INST]
-PARAMETER stop [/INST]
-PARAMETER stop <>
-PARAMETER stop <>''',
+ 'quantize': 'q4_k_m',
+ 'from': 'mymodel',
+ 'adapters': {'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
+ 'template': '[INST] <>{{.System}}<>\n{{.Prompt}} [/INST]',
+ 'license': 'this is my license',
+ 'system': '\nUse\nmultiline\nstrings.\n',
+ 'parameters': {'stop': ['[INST]', '[/INST]', '<>', '<>'], 'pi': 3.14159},
+ 'messages': [{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
'stream': False,
},
).respond_with_json({'status': 'success'})
client = AsyncClient(httpserver.url_for('/'))
- with tempfile.NamedTemporaryFile() as blob:
+ with tempfile.NamedTemporaryFile():
response = await client.create(
'dummy',
- modelfile='\n'.join(
- [
- f'FROM {blob.name}',
- 'TEMPLATE """[INST] <>{{.System}}<>',
- '{{.Prompt}} [/INST]"""',
- 'SYSTEM """',
- 'Use',
- 'multiline',
- 'strings.',
- '"""',
- 'PARAMETER stop [INST]',
- 'PARAMETER stop [/INST]',
- 'PARAMETER stop <>',
- 'PARAMETER stop <>',
- ]
- ),
+ quantize='q4_k_m',
+ from_='mymodel',
+ adapters={'someadapter.gguf': 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'},
+ template='[INST] <>{{.System}}<>\n{{.Prompt}} [/INST]',
+ license='this is my license',
+ system='\nUse\nmultiline\nstrings.\n',
+ parameters={'stop': ['[INST]', '[/INST]', '<>', '<>'], 'pi': 3.14159},
+ messages=[{'role': 'user', 'content': 'Hello there!'}, {'role': 'assistant', 'content': 'Hello there yourself!'}],
+ stream=False,
)
assert response['status'] == 'success'
@@ -1162,14 +995,14 @@ async def test_async_client_create_from_library(httpserver: HTTPServer):
method='POST',
json={
'model': 'dummy',
- 'modelfile': 'FROM llama2',
+ 'from': 'llama2',
'stream': False,
},
).respond_with_json({'status': 'success'})
client = AsyncClient(httpserver.url_for('/'))
- response = await client.create('dummy', modelfile='FROM llama2')
+ response = await client.create('dummy', from_='llama2')
assert response['status'] == 'success'
@@ -1180,7 +1013,7 @@ async def test_async_client_create_blob(httpserver: HTTPServer):
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
- response = await client._create_blob(blob.name)
+ response = await client.create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
@@ -1191,7 +1024,7 @@ async def test_async_client_create_blob_exists(httpserver: HTTPServer):
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
- response = await client._create_blob(blob.name)
+ response = await client.create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
diff --git a/tests/test_type_serialization.py b/tests/test_type_serialization.py
index 8200ce3..1ecbe08 100644
--- a/tests/test_type_serialization.py
+++ b/tests/test_type_serialization.py
@@ -2,7 +2,7 @@
from pathlib import Path
import pytest
-from ollama._types import Image
+from ollama._types import CreateRequest, Image
import tempfile
@@ -52,3 +52,42 @@ def test_image_serialization_string_path():
with pytest.raises(ValueError):
img = Image(value='not an image')
img.model_dump()
+
+
+def test_create_request_serialization():
+ request = CreateRequest(model='test-model', from_='base-model', quantize='q4_0', files={'file1': 'content1'}, adapters={'adapter1': 'content1'}, template='test template', license='MIT', system='test system', parameters={'param1': 'value1'})
+
+ serialized = request.model_dump()
+ assert serialized['from'] == 'base-model'
+ assert 'from_' not in serialized
+ assert serialized['quantize'] == 'q4_0'
+ assert serialized['files'] == {'file1': 'content1'}
+ assert serialized['adapters'] == {'adapter1': 'content1'}
+ assert serialized['template'] == 'test template'
+ assert serialized['license'] == 'MIT'
+ assert serialized['system'] == 'test system'
+ assert serialized['parameters'] == {'param1': 'value1'}
+
+
+def test_create_request_serialization_exclude_none_true():
+ request = CreateRequest(model='test-model', from_=None, quantize=None)
+ serialized = request.model_dump(exclude_none=True)
+ assert serialized == {'model': 'test-model'}
+ assert 'from' not in serialized
+ assert 'from_' not in serialized
+ assert 'quantize' not in serialized
+
+
+def test_create_request_serialization_exclude_none_false():
+ request = CreateRequest(model='test-model', from_=None, quantize=None)
+ serialized = request.model_dump(exclude_none=False)
+ assert 'from' in serialized
+ assert 'quantize' in serialized
+ assert 'adapters' in serialized
+ assert 'from_' not in serialized
+
+
+def test_create_request_serialization_license_list():
+ request = CreateRequest(model='test-model', license=['MIT', 'Apache-2.0'])
+ serialized = request.model_dump()
+ assert serialized['license'] == ['MIT', 'Apache-2.0']