type hints

This commit is contained in:
Michael Yang 2023-12-21 09:46:05 -08:00
parent 6f5565914f
commit dabcca6a1f
3 changed files with 179 additions and 43 deletions

View File

@ -1,8 +1,10 @@
from ._client import Client, AsyncClient
from ollama._client import Client, AsyncClient, Message, Options
__all__ = [
'Client',
'AsyncClient',
'Message',
'Options',
'generate',
'chat',
'pull',

View File

@ -1,31 +1,42 @@
import io
import json
import httpx
from os import PathLike
from pathlib import Path
from hashlib import sha256
from base64 import b64encode
from typing import Any, AnyStr, Union, Optional, List, Mapping
import sys
if sys.version_info < (3, 9):
from typing import Iterator, AsyncIterator
else:
from collections.abc import Iterator, AsyncIterator
from ollama._types import Message, Options
class BaseClient:
def __init__(self, client, base_url='http://127.0.0.1:11434'):
def __init__(self, client, base_url='http://127.0.0.1:11434') -> None:
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
class Client(BaseClient):
def __init__(self, base='http://localhost:11434'):
def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.Client, base)
def _request(self, method, url, **kwargs):
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = self._client.request(method, url, **kwargs)
response.raise_for_status()
return response
def _request_json(self, method, url, **kwargs):
def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
return self._request(method, url, **kwargs).json()
def _stream(self, method, url, **kwargs):
def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
with self._client.stream(method, url, **kwargs) as r:
for line in r.iter_lines():
part = json.loads(line)
@ -33,7 +44,19 @@ class Client(BaseClient):
raise Exception(e)
yield part
def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
def generate(
self,
model: str = '',
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[List[int]] = None,
stream: bool = False,
raw: bool = False,
format: str = '',
images: Optional[List[AnyStr]] = None,
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')
@ -51,7 +74,14 @@ class Client(BaseClient):
'options': options or {},
})
def chat(self, model='', messages=None, stream=False, format='', options=None):
def chat(
self,
model: str = '',
messages: Optional[List[Message]] = None,
stream: bool = False,
format: str = '',
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')
@ -74,7 +104,12 @@ class Client(BaseClient):
'options': options or {},
})
def pull(self, model, insecure=False, stream=False):
def pull(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return fn('POST', '/api/pull', json={
'model': model,
@ -82,7 +117,12 @@ class Client(BaseClient):
'stream': stream,
})
def push(self, model, insecure=False, stream=False):
def push(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return fn('POST', '/api/push', json={
'model': model,
@ -90,9 +130,15 @@ class Client(BaseClient):
'stream': stream,
})
def create(self, model, path=None, modelfile=None, stream=False):
if (path := _as_path(path)) and path.exists():
modelfile = self._parse_modelfile(path.read_text(), base=path.parent)
def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if (realpath := _as_path(path)) and realpath.exists():
modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile:
modelfile = self._parse_modelfile(modelfile)
else:
@ -105,7 +151,7 @@ class Client(BaseClient):
'stream': stream,
})
def _parse_modelfile(self, modelfile, base=None):
def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base
out = io.StringIO()
@ -120,7 +166,7 @@ class Client(BaseClient):
print(command, args, file=out)
return out.getvalue()
def _create_blob(self, path):
def _create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
@ -142,36 +188,36 @@ class Client(BaseClient):
return digest
def delete(self, model):
response = self._request_json('DELETE', '/api/delete', json={'model': model})
def delete(self, model: str) -> Mapping[str, Any]:
response = self._request('DELETE', '/api/delete', json={'model': model})
return {'status': 'success' if response.status_code == 200 else 'error'}
def list(self):
def list(self) -> Mapping[str, Any]:
return self._request_json('GET', '/api/tags').get('models', [])
def copy(self, source, target):
response = self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
def copy(self, source: str, target: str) -> Mapping[str, Any]:
response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
return {'status': 'success' if response.status_code == 200 else 'error'}
def show(self, model):
def show(self, model: str) -> Mapping[str, Any]:
return self._request_json('GET', '/api/show', json={'model': model})
class AsyncClient(BaseClient):
def __init__(self, base='http://localhost:11434'):
def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.AsyncClient, base)
async def _request(self, method, url, **kwargs):
async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = await self._client.request(method, url, **kwargs)
response.raise_for_status()
return response
async def _request_json(self, method, url, **kwargs):
async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
response = await self._request(method, url, **kwargs)
return response.json()
async def _stream(self, method, url, **kwargs):
async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]:
async def inner():
async with self._client.stream(method, url, **kwargs) as r:
async for line in r.aiter_lines():
@ -181,7 +227,19 @@ class AsyncClient(BaseClient):
yield part
return inner()
async def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
async def generate(
self,
model: str = '',
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[List[int]] = None,
stream: bool = False,
raw: bool = False,
format: str = '',
images: Optional[List[AnyStr]] = None,
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')
@ -199,7 +257,14 @@ class AsyncClient(BaseClient):
'options': options or {},
})
async def chat(self, model='', messages=None, stream=False, format='', options=None):
async def chat(
self,
model: str = '',
messages: Optional[List[Message]] = None,
stream: bool = False,
format: str = '',
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')
@ -222,7 +287,12 @@ class AsyncClient(BaseClient):
'options': options or {},
})
async def pull(self, model, insecure=False, stream=False):
async def pull(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return await fn('POST', '/api/pull', json={
'model': model,
@ -230,7 +300,12 @@ class AsyncClient(BaseClient):
'stream': stream,
})
async def push(self, model, insecure=False, stream=False):
async def push(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return await fn('POST', '/api/push', json={
'model': model,
@ -238,9 +313,15 @@ class AsyncClient(BaseClient):
'stream': stream,
})
async def create(self, model, path=None, modelfile=None, stream=False):
if (path := _as_path(path)) and path.exists():
modelfile = await self._parse_modelfile(path.read_text(), base=path.parent)
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if (realpath := _as_path(path)) and realpath.exists():
modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile:
modelfile = await self._parse_modelfile(modelfile)
else:
@ -253,7 +334,7 @@ class AsyncClient(BaseClient):
'stream': stream,
})
async def _parse_modelfile(self, modelfile, base=None):
async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base
out = io.StringIO()
@ -268,7 +349,7 @@ class AsyncClient(BaseClient):
print(command, args, file=out)
return out.getvalue()
async def _create_blob(self, path):
async def _create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
@ -297,23 +378,23 @@ class AsyncClient(BaseClient):
return digest
async def delete(self, model):
response = await self._request_json('DELETE', '/api/delete', json={'model': model})
async def delete(self, model: str) -> Mapping[str, Any]:
response = await self._request('DELETE', '/api/delete', json={'model': model})
return {'status': 'success' if response.status_code == 200 else 'error'}
async def list(self):
async def list(self) -> Mapping[str, Any]:
response = await self._request_json('GET', '/api/tags')
return response.get('models', [])
async def copy(self, source, target):
response = await self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
async def copy(self, source: str, target: str) -> Mapping[str, Any]:
response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target})
return {'status': 'success' if response.status_code == 200 else 'error'}
async def show(self, model):
async def show(self, model: str) -> Mapping[str, Any]:
return await self._request_json('GET', '/api/show', json={'model': model})
def _encode_image(image):
def _encode_image(image) -> str:
if p := _as_path(image):
b64 = b64encode(p.read_bytes())
elif b := _as_bytesio(image):
@ -324,12 +405,13 @@ def _encode_image(image):
return b64.decode('utf-8')
def _as_path(s):
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
if isinstance(s, str) or isinstance(s, Path):
return Path(s)
return None
def _as_bytesio(s):
def _as_bytesio(s: Any) -> Union[io.BytesIO, None]:
if isinstance(s, io.BytesIO):
return s
elif isinstance(s, bytes):

52
ollama/_types.py Normal file
View File

@ -0,0 +1,52 @@
from typing import Any, TypedDict, List
import sys
if sys.version_info < (3, 11):
from typing_extensions import NotRequired
else:
from typing import NotRequired
class Message(TypedDict):
role: str
content: str
images: NotRequired[List[Any]]
class Options(TypedDict, total=False):
# load time options
numa: bool
num_ctx: int
num_batch: int
num_gqa: int
num_gpu: int
main_gpu: int
low_vram: bool
f16_kv: bool
logits_all: bool
vocab_only: bool
use_mmap: bool
use_mlock: bool
embedding_only: bool
rope_frequency_base: float
rope_frequency_scale: float
num_thread: int
# runtime options
num_keep: int
seed: int
num_predict: int
top_k: int
top_p: float
tfs_z: float
typical_p: float
repeat_last_n: int
temperature: float
repeat_penalty: float
presence_penalty: float
frequency_penalty: float
mirostat: int
mirostat_tau: float
mirostat_eta: float
penalize_newline: bool
stop: List[str]