Mirror of https://github.com/ollama/ollama-python.git (synced 2026-02-04 02:54:08 -06:00)

Commit dabcca6a1f (parent 6f5565914f): type hints
ollama/__init__.py
@@ -1,8 +1,10 @@
-from ._client import Client, AsyncClient
+from ollama._client import Client, AsyncClient, Message, Options

 __all__ = [
   'Client',
   'AsyncClient',
+  'Message',
+  'Options',
   'generate',
   'chat',
   'pull',
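With Message and Options now re-exported from the package root alongside Client and AsyncClient, callers can annotate their own values against the library's types. A minimal sketch of the typed imports (no server interaction; the field values are illustrative):

from ollama import Client, AsyncClient, Message, Options

message: Message = {'role': 'user', 'content': 'Why is the sky blue?'}
opts: Options = {'temperature': 0.2, 'seed': 42}

client = Client()  # defaults to http://localhost:11434 per the constructor below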
ollama/_client.py
@@ -1,31 +1,42 @@
 import io
 import json
 import httpx
 from os import PathLike
 from pathlib import Path
 from hashlib import sha256
 from base64 import b64encode

+from typing import Any, AnyStr, Union, Optional, List, Mapping
+
+import sys
+if sys.version_info < (3, 9):
+  from typing import Iterator, AsyncIterator
+else:
+  from collections.abc import Iterator, AsyncIterator
+
+from ollama._types import Message, Options


 class BaseClient:

-  def __init__(self, client, base_url='http://127.0.0.1:11434'):
+  def __init__(self, client, base_url='http://127.0.0.1:11434') -> None:
     self._client = client(base_url=base_url, follow_redirects=True, timeout=None)


 class Client(BaseClient):

-  def __init__(self, base='http://localhost:11434'):
+  def __init__(self, base='http://localhost:11434') -> None:
     super().__init__(httpx.Client, base)

-  def _request(self, method, url, **kwargs):
+  def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
     response = self._client.request(method, url, **kwargs)
     response.raise_for_status()
     return response

-  def _request_json(self, method, url, **kwargs):
+  def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
     return self._request(method, url, **kwargs).json()

-  def _stream(self, method, url, **kwargs):
+  def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
     with self._client.stream(method, url, **kwargs) as r:
       for line in r.iter_lines():
         part = json.loads(line)
@@ -33,7 +44,19 @@ class Client(BaseClient):
           raise Exception(e)
         yield part

-  def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
+  def generate(
+    self,
+    model: str = '',
+    prompt: str = '',
+    system: str = '',
+    template: str = '',
+    context: Optional[List[int]] = None,
+    stream: bool = False,
+    raw: bool = False,
+    format: str = '',
+    images: Optional[List[AnyStr]] = None,
+    options: Optional[Options] = None,
+  ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
     if not model:
       raise Exception('must provide a model')

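The Union return annotation reflects that generate() hands back a single mapping when stream=False and an Iterator of partial mappings when stream=True. A hedged usage sketch, assuming a local Ollama server and a placeholder model tag 'llama2'; the 'response' key in the comments is the field the generate endpoint uses for generated text:

from ollama import Client

client = Client()

# stream=False (default): one Mapping[str, Any] with the full result.
result = client.generate(model='llama2', prompt='Why is the sky blue?')
print(result['response'])

# stream=True: an Iterator[Mapping[str, Any]] of incremental parts.
for part in client.generate(model='llama2', prompt='Why is the sky blue?', stream=True):
  print(part.get('response', ''), end='', flush=True)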
@@ -51,7 +74,14 @@ class Client(BaseClient):
       'options': options or {},
     })

-  def chat(self, model='', messages=None, stream=False, format='', options=None):
+  def chat(
+    self,
+    model: str = '',
+    messages: Optional[List[Message]] = None,
+    stream: bool = False,
+    format: str = '',
+    options: Optional[Options] = None,
+  ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
     if not model:
       raise Exception('must provide a model')

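Because messages is typed as Optional[List[Message]], the message dicts can be written against the Message TypedDict and checked statically. A sketch under the same assumptions (placeholder model, local server); the 'message' key mirrors the chat endpoint's response shape:

from typing import List

from ollama import Client, Message

client = Client()
messages: List[Message] = [{'role': 'user', 'content': 'Hello!'}]

reply = client.chat(model='llama2', messages=messages)
print(reply['message']['content'])

for part in client.chat(model='llama2', messages=messages, stream=True):
  print(part['message']['content'], end='', flush=True)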
@@ -74,7 +104,12 @@ class Client(BaseClient):
       'options': options or {},
     })

-  def pull(self, model, insecure=False, stream=False):
+  def pull(
+    self,
+    model: str,
+    insecure: bool = False,
+    stream: bool = False,
+  ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
     fn = self._stream if stream else self._request_json
     return fn('POST', '/api/pull', json={
       'model': model,
@@ -82,7 +117,12 @@ class Client(BaseClient):
       'stream': stream,
     })

-  def push(self, model, insecure=False, stream=False):
+  def push(
+    self,
+    model: str,
+    insecure: bool = False,
+    stream: bool = False,
+  ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
     fn = self._stream if stream else self._request_json
     return fn('POST', '/api/push', json={
       'model': model,
@@ -90,9 +130,15 @@ class Client(BaseClient):
       'stream': stream,
     })

-  def create(self, model, path=None, modelfile=None, stream=False):
-    if (path := _as_path(path)) and path.exists():
-      modelfile = self._parse_modelfile(path.read_text(), base=path.parent)
+  def create(
+    self,
+    model: str,
+    path: Optional[Union[str, PathLike]] = None,
+    modelfile: Optional[str] = None,
+    stream: bool = False,
+  ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
+    if (realpath := _as_path(path)) and realpath.exists():
+      modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
     elif modelfile:
       modelfile = self._parse_modelfile(modelfile)
     else:
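Renaming the walrus target to realpath keeps the path parameter at its declared Optional[Union[str, PathLike]] type instead of rebinding it to a Path. A sketch of the two ways create() can be fed a Modelfile; the file path, Modelfile text, and model names are hypothetical, and the 'status' key is how progress parts are typically labeled:

from ollama import Client

client = Client()

# From a Modelfile on disk (hypothetical path).
client.create(model='my-model', path='./Modelfile')

# From Modelfile text passed directly, streaming progress parts.
text = 'FROM llama2\nPARAMETER temperature 0.2\n'
for part in client.create(model='my-model', modelfile=text, stream=True):
  print(part.get('status', ''))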
@@ -105,7 +151,7 @@ class Client(BaseClient):
       'stream': stream,
     })

-  def _parse_modelfile(self, modelfile, base=None):
+  def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
     base = Path.cwd() if base is None else base

     out = io.StringIO()
@@ -120,7 +166,7 @@ class Client(BaseClient):
       print(command, args, file=out)
     return out.getvalue()

-  def _create_blob(self, path):
+  def _create_blob(self, path: Union[str, Path]) -> str:
     sha256sum = sha256()
     with open(path, 'rb') as r:
       while True:
@@ -142,36 +188,36 @@ class Client(BaseClient):

     return digest

-  def delete(self, model):
-    response = self._request_json('DELETE', '/api/delete', json={'model': model})
+  def delete(self, model: str) -> Mapping[str, Any]:
+    response = self._request('DELETE', '/api/delete', json={'model': model})
     return {'status': 'success' if response.status_code == 200 else 'error'}

-  def list(self):
+  def list(self) -> Mapping[str, Any]:
     return self._request_json('GET', '/api/tags').get('models', [])

-  def copy(self, source, target):
-    response = self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
+  def copy(self, source: str, target: str) -> Mapping[str, Any]:
+    response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
     return {'status': 'success' if response.status_code == 200 else 'error'}

-  def show(self, model):
+  def show(self, model: str) -> Mapping[str, Any]:
     return self._request_json('GET', '/api/show', json={'model': model})


 class AsyncClient(BaseClient):

-  def __init__(self, base='http://localhost:11434'):
+  def __init__(self, base='http://localhost:11434') -> None:
     super().__init__(httpx.AsyncClient, base)

-  async def _request(self, method, url, **kwargs):
+  async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
     response = await self._client.request(method, url, **kwargs)
     response.raise_for_status()
     return response

-  async def _request_json(self, method, url, **kwargs):
+  async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
     response = await self._request(method, url, **kwargs)
     return response.json()

-  async def _stream(self, method, url, **kwargs):
+  async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]:
     async def inner():
       async with self._client.stream(method, url, **kwargs) as r:
         async for line in r.aiter_lines():
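delete() and copy() now go through _request instead of _request_json because they inspect the raw httpx status code and synthesize a small status mapping themselves. Note that list() is annotated as Mapping[str, Any] even though it returns the 'models' list extracted from the /api/tags payload. A sketch of the declared call shapes (model names are placeholders):

from ollama import Client

client = Client()

models = client.list()                                 # the 'models' list from /api/tags
info = client.show(model='llama2')                     # Mapping[str, Any] of model details
client.copy(source='llama2', target='llama2-backup')   # {'status': 'success'} on HTTP 200
client.delete(model='llama2-backup')                   # {'status': 'success'} or {'status': 'error'}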
@@ -181,7 +227,19 @@ class AsyncClient(BaseClient):
           yield part
     return inner()

-  async def generate(self, model='', prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
+  async def generate(
+    self,
+    model: str = '',
+    prompt: str = '',
+    system: str = '',
+    template: str = '',
+    context: Optional[List[int]] = None,
+    stream: bool = False,
+    raw: bool = False,
+    format: str = '',
+    images: Optional[List[AnyStr]] = None,
+    options: Optional[Options] = None,
+  ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
     if not model:
       raise Exception('must provide a model')

@@ -199,7 +257,14 @@ class AsyncClient(BaseClient):
       'options': options or {},
     })

-  async def chat(self, model='', messages=None, stream=False, format='', options=None):
+  async def chat(
+    self,
+    model: str = '',
+    messages: Optional[List[Message]] = None,
+    stream: bool = False,
+    format: str = '',
+    options: Optional[Options] = None,
+  ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
     if not model:
       raise Exception('must provide a model')

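AsyncClient mirrors the synchronous annotations, swapping Iterator for AsyncIterator. In this version _stream returns an async generator from an awaited call, so streaming output is consumed with async for over the awaited result. A sketch, again with a placeholder model:

import asyncio

from ollama import AsyncClient


async def main() -> None:
  client = AsyncClient()
  messages = [{'role': 'user', 'content': 'Hello!'}]

  # stream=False: await returns one mapping.
  reply = await client.chat(model='llama2', messages=messages)
  print(reply['message']['content'])

  # stream=True: the awaited call hands back an AsyncIterator of partial mappings.
  async for part in await client.chat(model='llama2', messages=messages, stream=True):
    print(part['message']['content'], end='', flush=True)


asyncio.run(main())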
@@ -222,7 +287,12 @@ class AsyncClient(BaseClient):
       'options': options or {},
     })

-  async def pull(self, model, insecure=False, stream=False):
+  async def pull(
+    self,
+    model: str,
+    insecure: bool = False,
+    stream: bool = False,
+  ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
     fn = self._stream if stream else self._request_json
     return await fn('POST', '/api/pull', json={
       'model': model,
@@ -230,7 +300,12 @@ class AsyncClient(BaseClient):
       'stream': stream,
     })

-  async def push(self, model, insecure=False, stream=False):
+  async def push(
+    self,
+    model: str,
+    insecure: bool = False,
+    stream: bool = False,
+  ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
     fn = self._stream if stream else self._request_json
     return await fn('POST', '/api/push', json={
       'model': model,
@@ -238,9 +313,15 @@ class AsyncClient(BaseClient):
       'stream': stream,
     })

-  async def create(self, model, path=None, modelfile=None, stream=False):
-    if (path := _as_path(path)) and path.exists():
-      modelfile = await self._parse_modelfile(path.read_text(), base=path.parent)
+  async def create(
+    self,
+    model: str,
+    path: Optional[Union[str, PathLike]] = None,
+    modelfile: Optional[str] = None,
+    stream: bool = False,
+  ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
+    if (realpath := _as_path(path)) and realpath.exists():
+      modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
     elif modelfile:
       modelfile = await self._parse_modelfile(modelfile)
     else:
@@ -253,7 +334,7 @@ class AsyncClient(BaseClient):
       'stream': stream,
     })

-  async def _parse_modelfile(self, modelfile, base=None):
+  async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
     base = Path.cwd() if base is None else base

     out = io.StringIO()
@@ -268,7 +349,7 @@ class AsyncClient(BaseClient):
       print(command, args, file=out)
     return out.getvalue()

-  async def _create_blob(self, path):
+  async def _create_blob(self, path: Union[str, Path]) -> str:
     sha256sum = sha256()
     with open(path, 'rb') as r:
       while True:
@@ -297,23 +378,23 @@ class AsyncClient(BaseClient):

     return digest

-  async def delete(self, model):
-    response = await self._request_json('DELETE', '/api/delete', json={'model': model})
+  async def delete(self, model: str) -> Mapping[str, Any]:
+    response = await self._request('DELETE', '/api/delete', json={'model': model})
     return {'status': 'success' if response.status_code == 200 else 'error'}

-  async def list(self):
+  async def list(self) -> Mapping[str, Any]:
     response = await self._request_json('GET', '/api/tags')
     return response.get('models', [])

-  async def copy(self, source, target):
-    response = await self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
+  async def copy(self, source: str, target: str) -> Mapping[str, Any]:
+    response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target})
     return {'status': 'success' if response.status_code == 200 else 'error'}

-  async def show(self, model):
+  async def show(self, model: str) -> Mapping[str, Any]:
     return await self._request_json('GET', '/api/show', json={'model': model})


-def _encode_image(image):
+def _encode_image(image) -> str:
   if p := _as_path(image):
     b64 = b64encode(p.read_bytes())
   elif b := _as_bytesio(image):
@@ -324,12 +405,13 @@ def _encode_image(image):
   return b64.decode('utf-8')


-def _as_path(s):
+def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
   if isinstance(s, str) or isinstance(s, Path):
     return Path(s)
   return None

-def _as_bytesio(s):
+
+def _as_bytesio(s: Any) -> Union[io.BytesIO, None]:
   if isinstance(s, io.BytesIO):
     return s
   elif isinstance(s, bytes):
ollama/_types.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+from typing import Any, TypedDict, List
+
+import sys
+if sys.version_info < (3, 11):
+  from typing_extensions import NotRequired
+else:
+  from typing import NotRequired
+
+
+class Message(TypedDict):
+  role: str
+  content: str
+  images: NotRequired[List[Any]]
+
+
+class Options(TypedDict, total=False):
+  # load time options
+  numa: bool
+  num_ctx: int
+  num_batch: int
+  num_gqa: int
+  num_gpu: int
+  main_gpu: int
+  low_vram: bool
+  f16_kv: bool
+  logits_all: bool
+  vocab_only: bool
+  use_mmap: bool
+  use_mlock: bool
+  embedding_only: bool
+  rope_frequency_base: float
+  rope_frequency_scale: float
+  num_thread: int
+
+  # runtime options
+  num_keep: int
+  seed: int
+  num_predict: int
+  top_k: int
+  top_p: float
+  tfs_z: float
+  typical_p: float
+  repeat_last_n: int
+  temperature: float
+  repeat_penalty: float
+  presence_penalty: float
+  frequency_penalty: float
+  mirostat: int
+  mirostat_tau: float
+  mirostat_eta: float
+  penalize_newline: bool
+  stop: List[str]
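Message and Options are plain TypedDicts: Options uses total=False so any subset of keys is valid, and Message marks images as NotRequired. A short sketch of how the types are meant to be filled in (values are illustrative):

from ollama._types import Message, Options

# Any subset of Options keys type-checks because total=False.
opts: Options = {
  'temperature': 0.2,
  'num_ctx': 4096,
  'seed': 42,
  'stop': ['\n\n'],
}

# images may be omitted since it is NotRequired.
msg: Message = {'role': 'user', 'content': 'Describe this picture.'}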