diff --git a/ollama/_client.py b/ollama/_client.py index 87fa881..4647784 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -115,11 +115,13 @@ def __init__(self, host: Optional[str] = None, **kwargs) -> None: super().__init__(httpx.Client, host, **kwargs) def _request_raw(self, *args, **kwargs): - r = self._client.request(*args, **kwargs) try: + r = self._client.request(*args, **kwargs) r.raise_for_status() except httpx.HTTPStatusError as e: raise ResponseError(e.response.text, e.response.status_code) from None + except httpx.ConnectError: + raise ResponseError("error connecting to ollama server: have you checked that it's running?", 500) return r @overload @@ -617,11 +619,13 @@ def __init__(self, host: Optional[str] = None, **kwargs) -> None: super().__init__(httpx.AsyncClient, host, **kwargs) async def _request_raw(self, *args, **kwargs): - r = await self._client.request(*args, **kwargs) try: + r = await self._client.request(*args, **kwargs) r.raise_for_status() except httpx.HTTPStatusError as e: raise ResponseError(e.response.text, e.response.status_code) from None + except httpx.ConnectError: + raise ResponseError("error connecting to ollama server: have you checked that it's running?", 500) return r @overload diff --git a/tests/test_client.py b/tests/test_client.py index aab2f2e..41e6be3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,6 +1,8 @@ import os import io import json + +import httpx from pydantic import ValidationError, BaseModel import pytest import tempfile @@ -10,6 +12,7 @@ from PIL import Image from ollama._client import Client, AsyncClient, _copy_tools +from ollama._types import ResponseError class PrefixPattern(URIPattern): @@ -182,6 +185,29 @@ class ResponseFormat(BaseModel): assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}' +def test_client_gracefully_handles_ollama_server_not_running(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/chat', + method='POST', + json={ + 'model': 'dummy', + 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], + 'tools': [], + 'format': 'json', + 'stream': False, + }, + ).respond_with_handler(lambda _: Response()) + + def _monkey_patched_request_func(*args, **kwargs): + raise httpx.ConnectError("[Errno 111] Connection refused") + + client = Client(httpserver.url_for('/')) + client._client.request = _monkey_patched_request_func + + with pytest.raises(ResponseError): + client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json') + + @pytest.mark.asyncio async def test_async_client_chat_format_json(httpserver: HTTPServer): httpserver.expect_ordered_request( @@ -244,6 +270,30 @@ class ResponseFormat(BaseModel): assert response['message']['content'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}' +@pytest.mark.asyncio +async def test_async_client_chat_format_json(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/chat', + method='POST', + json={ + 'model': 'dummy', + 'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}], + 'tools': [], + 'format': 'json', + 'stream': False, + }, + ).respond_with_handler(lambda _: Response()) + + async def _monkey_patched_request_func(*args, **kwargs): + raise httpx.ConnectError("[Errno 111] Connection refused") + + client = AsyncClient(httpserver.url_for('/')) + client._client.request = _monkey_patched_request_func + + with pytest.raises(ResponseError): + await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], format='json') + + def test_client_generate(httpserver: HTTPServer): httpserver.expect_ordered_request( '/api/generate',