From dabcca6a1f69e888eb055a44876144f04a2e0ffe Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 21 Dec 2023 09:46:05 -0800 Subject: [PATCH] type hints --- ollama/__init__.py | 4 +- ollama/_client.py | 166 +++++++++++++++++++++++++++++++++------------ ollama/_types.py | 52 ++++++++++++++ 3 files changed, 179 insertions(+), 43 deletions(-) create mode 100644 ollama/_types.py diff --git a/ollama/__init__.py b/ollama/__init__.py index c6fe6d4..a66f1d0 100644 --- a/ollama/__init__.py +++ b/ollama/__init__.py @@ -1,8 +1,10 @@ -from ._client import Client, AsyncClient +from ollama._client import Client, AsyncClient, Message, Options __all__ = [ 'Client', 'AsyncClient', + 'Message', + 'Options', 'generate', 'chat', 'pull', diff --git a/ollama/_client.py b/ollama/_client.py index 3882ac3..9e5ee32 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -1,31 +1,42 @@ import io import json import httpx +from os import PathLike from pathlib import Path from hashlib import sha256 from base64 import b64encode +from typing import Any, AnyStr, Union, Optional, List, Mapping + +import sys +if sys.version_info < (3, 9): + from typing import Iterator, AsyncIterator +else: + from collections.abc import Iterator, AsyncIterator + +from ollama._types import Message, Options + class BaseClient: - def __init__(self, client, base_url='http://127.0.0.1:11434'): + def __init__(self, client, base_url='http://127.0.0.1:11434') -> None: self._client = client(base_url=base_url, follow_redirects=True, timeout=None) class Client(BaseClient): - def __init__(self, base='http://localhost:11434'): + def __init__(self, base='http://localhost:11434') -> None: super().__init__(httpx.Client, base) - def _request(self, method, url, **kwargs): + def _request(self, method: str, url: str, **kwargs) -> httpx.Response: response = self._client.request(method, url, **kwargs) response.raise_for_status() return response - def _request_json(self, method, url, **kwargs): + def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]: return self._request(method, url, **kwargs).json() - def _stream(self, method, url, **kwargs): + def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]: with self._client.stream(method, url, **kwargs) as r: for line in r.iter_lines(): part = json.loads(line) @@ -33,7 +44,19 @@ class Client(BaseClient): raise Exception(e) yield part - def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None): + def generate( + self, + model: str = '', + prompt: str = '', + system: str = '', + template: str = '', + context: Optional[List[int]] = None, + stream: bool = False, + raw: bool = False, + format: str = '', + images: Optional[List[AnyStr]] = None, + options: Optional[Options] = None, + ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: if not model: raise Exception('must provide a model') @@ -51,7 +74,14 @@ class Client(BaseClient): 'options': options or {}, }) - def chat(self, model='', messages=None, stream=False, format='', options=None): + def chat( + self, + model: str = '', + messages: Optional[List[Message]] = None, + stream: bool = False, + format: str = '', + options: Optional[Options] = None, + ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: if not model: raise Exception('must provide a model') @@ -74,7 +104,12 @@ class Client(BaseClient): 'options': options or {}, }) - def pull(self, model, insecure=False, stream=False): + def pull( + self, + model: str, + insecure: bool = False, + stream: bool = False, + ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: fn = self._stream if stream else self._request_json return fn('POST', '/api/pull', json={ 'model': model, @@ -82,7 +117,12 @@ class Client(BaseClient): 'stream': stream, }) - def push(self, model, insecure=False, stream=False): + def push( + self, + model: str, + insecure: bool = False, + stream: bool = False, + ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: fn = self._stream if stream else self._request_json return fn('POST', '/api/push', json={ 'model': model, @@ -90,9 +130,15 @@ class Client(BaseClient): 'stream': stream, }) - def create(self, model, path=None, modelfile=None, stream=False): - if (path := _as_path(path)) and path.exists(): - modelfile = self._parse_modelfile(path.read_text(), base=path.parent) + def create( + self, + model: str, + path: Optional[Union[str, PathLike]] = None, + modelfile: Optional[str] = None, + stream: bool = False, + ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: + if (realpath := _as_path(path)) and realpath.exists(): + modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent) elif modelfile: modelfile = self._parse_modelfile(modelfile) else: @@ -105,7 +151,7 @@ class Client(BaseClient): 'stream': stream, }) - def _parse_modelfile(self, modelfile, base=None): + def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: base = Path.cwd() if base is None else base out = io.StringIO() @@ -120,7 +166,7 @@ class Client(BaseClient): print(command, args, file=out) return out.getvalue() - def _create_blob(self, path): + def _create_blob(self, path: Union[str, Path]) -> str: sha256sum = sha256() with open(path, 'rb') as r: while True: @@ -142,36 +188,36 @@ class Client(BaseClient): return digest - def delete(self, model): - response = self._request_json('DELETE', '/api/delete', json={'model': model}) + def delete(self, model: str) -> Mapping[str, Any]: + response = self._request('DELETE', '/api/delete', json={'model': model}) return {'status': 'success' if response.status_code == 200 else 'error'} - def list(self): + def list(self) -> Mapping[str, Any]: return self._request_json('GET', '/api/tags').get('models', []) - def copy(self, source, target): - response = self._request_json('POST', '/api/copy', json={'source': source, 'destination': target}) + def copy(self, source: str, target: str) -> Mapping[str, Any]: + response = self._request('POST', '/api/copy', json={'source': source, 'destination': target}) return {'status': 'success' if response.status_code == 200 else 'error'} - def show(self, model): + def show(self, model: str) -> Mapping[str, Any]: return self._request_json('GET', '/api/show', json={'model': model}) class AsyncClient(BaseClient): - def __init__(self, base='http://localhost:11434'): + def __init__(self, base='http://localhost:11434') -> None: super().__init__(httpx.AsyncClient, base) - async def _request(self, method, url, **kwargs): + async def _request(self, method: str, url: str, **kwargs) -> httpx.Response: response = await self._client.request(method, url, **kwargs) response.raise_for_status() return response - async def _request_json(self, method, url, **kwargs): + async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]: response = await self._request(method, url, **kwargs) return response.json() - async def _stream(self, method, url, **kwargs): + async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]: async def inner(): async with self._client.stream(method, url, **kwargs) as r: async for line in r.aiter_lines(): @@ -181,7 +227,19 @@ class AsyncClient(BaseClient): yield part return inner() - async def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None): + async def generate( + self, + model: str = '', + prompt: str = '', + system: str = '', + template: str = '', + context: Optional[List[int]] = None, + stream: bool = False, + raw: bool = False, + format: str = '', + images: Optional[List[AnyStr]] = None, + options: Optional[Options] = None, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: if not model: raise Exception('must provide a model') @@ -199,7 +257,14 @@ class AsyncClient(BaseClient): 'options': options or {}, }) - async def chat(self, model='', messages=None, stream=False, format='', options=None): + async def chat( + self, + model: str = '', + messages: Optional[List[Message]] = None, + stream: bool = False, + format: str = '', + options: Optional[Options] = None, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: if not model: raise Exception('must provide a model') @@ -222,7 +287,12 @@ class AsyncClient(BaseClient): 'options': options or {}, }) - async def pull(self, model, insecure=False, stream=False): + async def pull( + self, + model: str, + insecure: bool = False, + stream: bool = False, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: fn = self._stream if stream else self._request_json return await fn('POST', '/api/pull', json={ 'model': model, @@ -230,7 +300,12 @@ class AsyncClient(BaseClient): 'stream': stream, }) - async def push(self, model, insecure=False, stream=False): + async def push( + self, + model: str, + insecure: bool = False, + stream: bool = False, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: fn = self._stream if stream else self._request_json return await fn('POST', '/api/push', json={ 'model': model, @@ -238,9 +313,15 @@ class AsyncClient(BaseClient): 'stream': stream, }) - async def create(self, model, path=None, modelfile=None, stream=False): - if (path := _as_path(path)) and path.exists(): - modelfile = await self._parse_modelfile(path.read_text(), base=path.parent) + async def create( + self, + model: str, + path: Optional[Union[str, PathLike]] = None, + modelfile: Optional[str] = None, + stream: bool = False, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + if (realpath := _as_path(path)) and realpath.exists(): + modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent) elif modelfile: modelfile = await self._parse_modelfile(modelfile) else: @@ -253,7 +334,7 @@ class AsyncClient(BaseClient): 'stream': stream, }) - async def _parse_modelfile(self, modelfile, base=None): + async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: base = Path.cwd() if base is None else base out = io.StringIO() @@ -268,7 +349,7 @@ class AsyncClient(BaseClient): print(command, args, file=out) return out.getvalue() - async def _create_blob(self, path): + async def _create_blob(self, path: Union[str, Path]) -> str: sha256sum = sha256() with open(path, 'rb') as r: while True: @@ -297,23 +378,23 @@ class AsyncClient(BaseClient): return digest - async def delete(self, model): - response = await self._request_json('DELETE', '/api/delete', json={'model': model}) + async def delete(self, model: str) -> Mapping[str, Any]: + response = await self._request('DELETE', '/api/delete', json={'model': model}) return {'status': 'success' if response.status_code == 200 else 'error'} - async def list(self): + async def list(self) -> Mapping[str, Any]: response = await self._request_json('GET', '/api/tags') return response.get('models', []) - async def copy(self, source, target): - response = await self._request_json('POST', '/api/copy', json={'source': source, 'destination': target}) + async def copy(self, source: str, target: str) -> Mapping[str, Any]: + response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target}) return {'status': 'success' if response.status_code == 200 else 'error'} - async def show(self, model): + async def show(self, model: str) -> Mapping[str, Any]: return await self._request_json('GET', '/api/show', json={'model': model}) -def _encode_image(image): +def _encode_image(image) -> str: if p := _as_path(image): b64 = b64encode(p.read_bytes()) elif b := _as_bytesio(image): @@ -324,12 +405,13 @@ def _encode_image(image): return b64.decode('utf-8') -def _as_path(s): +def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]: if isinstance(s, str) or isinstance(s, Path): return Path(s) return None -def _as_bytesio(s): + +def _as_bytesio(s: Any) -> Union[io.BytesIO, None]: if isinstance(s, io.BytesIO): return s elif isinstance(s, bytes): diff --git a/ollama/_types.py b/ollama/_types.py new file mode 100644 index 0000000..7fe3bf0 --- /dev/null +++ b/ollama/_types.py @@ -0,0 +1,52 @@ +from typing import Any, TypedDict, List + +import sys +if sys.version_info < (3, 11): + from typing_extensions import NotRequired +else: + from typing import NotRequired + + +class Message(TypedDict): + role: str + content: str + images: NotRequired[List[Any]] + + +class Options(TypedDict, total=False): + # load time options + numa: bool + num_ctx: int + num_batch: int + num_gqa: int + num_gpu: int + main_gpu: int + low_vram: bool + f16_kv: bool + logits_all: bool + vocab_only: bool + use_mmap: bool + use_mlock: bool + embedding_only: bool + rope_frequency_base: float + rope_frequency_scale: float + num_thread: int + + # runtime options + num_keep: int + seed: int + num_predict: int + top_k: int + top_p: float + tfs_z: float + typical_p: float + repeat_last_n: int + temperature: float + repeat_penalty: float + presence_penalty: float + frequency_penalty: float + mirostat: int + mirostat_tau: float + mirostat_eta: float + penalize_newline: bool + stop: List[str]