Merge pull request #1 from jmorganca/mxyng/tests

add ci tests/lint
Michael Yang · 2023-12-21 15:06:14 -08:00 · committed by GitHub
commit 187bd29b0e
13 changed files with 1654 additions and 481 deletions

.github/workflows/publish.yaml (new file)

@@ -0,0 +1,24 @@
name: publish
on:
  release:
    types:
      - created
jobs:
  publish:
    runs-on: ubuntu-latest
    environment: release
    permissions:
      id-token: write
    steps:
      - uses: actions/checkout@v4
      - run: pipx install poetry
      - uses: actions/setup-python@v5
        with:
          cache: poetry
      - run: |
          poetry version -- ${GITHUB_REF_NAME#v}
          poetry build
      - uses: pypa/gh-action-pypi-publish@release/v1
      - run: gh release upload $GITHUB_REF_NAME dist/*

.github/workflows/test.yaml (new file)

@@ -0,0 +1,32 @@
name: test
on:
  pull_request:
jobs:
  test:
    strategy:
      matrix:
        python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - run: pipx install poetry
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
          cache: poetry
      - run: poetry install --with=dev
      - run: poetry run ruff --output-format=github .
      - run: poetry run pytest . --junitxml=junit/test-results-${{ matrix.python-version }}.xml --cov=ollama --cov-report=xml --cov-report=html
      - name: check poetry.lock is up-to-date
        run: poetry check --lock
      - name: check requirements.txt is up-to-date
        run: |
          poetry export >requirements.txt
          git diff --exit-code requirements.txt
      - uses: actions/upload-artifact@v3
        with:
          name: pytest-results-${{ matrix.python-version }}
          path: junit/test-results-${{ matrix.python-version }}.xml
        if: ${{ always() }}

.gitignore (new file)

@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

README.md (new file)

@@ -0,0 +1,70 @@
# Ollama Python Library
The Ollama Python library provides the easiest way to integrate your Python 3 project with [Ollama](https://github.com/jmorganca/ollama).
## Getting Started
Requires Python 3.8 or higher.
```sh
pip install ollama
```
A global default client is provided for convenience and can be used in the same way as the synchronous client.
```python
import ollama
response = ollama.chat(model='llama2', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
```
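The non-streaming call returns a plain mapping; the reply text lives under the `message` key (the same shape the tests in this change assert against):

```python
import ollama

response = ollama.chat(model='llama2', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
print(response['message']['content'])
```

Passing `stream=True` to the global client returns a generator instead, as shown below.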
```python
import ollama
message = {'role': 'user', 'content': 'Why is the sky blue?'}
for part in ollama.chat(model='llama2', messages=[message], stream=True):
  print(part['message']['content'], end='', flush=True)
```
### Using the Synchronous Client
```python
from ollama import Client
message = {'role': 'user', 'content': 'Why is the sky blue?'}
response = Client().chat(model='llama2', messages=[message])
```
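The `Client` constructor's first argument is the server's base URL (defaulting to `http://localhost:11434` in `ollama/_client.py`), so targeting another host is a one-liner; a minimal sketch, where the address is purely illustrative:

```python
from ollama import Client

client = Client('http://192.168.0.10:11434')  # hypothetical remote Ollama host
print(client.list())  # models available on that host
```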
Response streaming can be enabled by setting `stream=True`. This modifies the function to return a Python generator where each part is an object in the stream.
```python
from ollama import Client
message = {'role': 'user', 'content': 'Why is the sky blue?'}
for part in Client().chat(model='llama2', messages=[message], stream=True):
  print(part['message']['content'], end='', flush=True)
```
### Using the Asynchronous Client
```python
import asyncio
from ollama import AsyncClient
async def chat():
  message = {'role': 'user', 'content': 'Why is the sky blue?'}
  response = await AsyncClient().chat(model='llama2', messages=[message])

asyncio.run(chat())
```
Similar to the synchronous client, setting `stream=True` modifies the function to return a Python asynchronous generator.
```python
import asyncio
from ollama import AsyncClient
async def chat():
  message = {'role': 'user', 'content': 'Why is the sky blue?'}
  async for part in await AsyncClient().chat(model='llama2', messages=[message], stream=True):
    print(part['message']['content'], end='', flush=True)

asyncio.run(chat())
```
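`chat` and `generate` are only part of the surface: `ollama/__init__.py` in this change also re-exports `pull`, `push`, `create`, `delete`, `list`, `copy`, and `show` at module level. A brief sketch (model names are examples):

```python
import ollama

ollama.pull('llama2')                 # download a model from the library
print(ollama.list())                  # list local models
print(ollama.show('llama2'))          # details for one model
ollama.copy('llama2', 'llama2-copy')  # duplicate a model under a new name
ollama.delete('llama2-copy')          # remove it again
```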

fill-in-middle example (new file)

@@ -0,0 +1,22 @@
from ollama import generate

prefix = '''def remove_non_ascii(s: str) -> str:
  """ '''

suffix = """
  return result
"""

response = generate(
  model='codellama:7b-code',
  prompt=f'<PRE> {prefix} <SUF>{suffix} <MID>',
  options={
    'num_predict': 128,
    'temperature': 0,
    'top_p': 0.9,
    'stop': ['<EOT>'],
  },
)

print(response['response'])
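The `<PRE> {prefix} <SUF>{suffix} <MID>` prompt uses Code Llama's fill-in-the-middle format: the model generates only the text that belongs between `prefix` and `suffix`, stopping at the `<EOT>` sentinel. A sketch of stitching the pieces back together, assuming `response` from the call above:

```python
# the completion fills the gap between the two fragments
print(prefix + response['response'] + suffix)
```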

ollama/__init__.py

@@ -1,30 +1,56 @@
-from ollama.client import Client
+from ollama._client import Client, AsyncClient, Message, Options

__all__ = [
  'Client',
  'AsyncClient',
  'Message',
  'Options',
  'generate',
  'chat',
  'pull',
  'push',
  'create',
  'delete',
  'list',
  'copy',
  'show',
]

_default_client = Client()


def generate(*args, **kwargs):
  return _default_client.generate(*args, **kwargs)


def chat(*args, **kwargs):
  return _default_client.chat(*args, **kwargs)


def pull(*args, **kwargs):
  return _default_client.pull(*args, **kwargs)


def push(*args, **kwargs):
  return _default_client.push(*args, **kwargs)


def create(*args, **kwargs):
  return _default_client.create(*args, **kwargs)


def delete(*args, **kwargs):
  return _default_client.delete(*args, **kwargs)


def list(*args, **kwargs):
  return _default_client.list(*args, **kwargs)


def copy(*args, **kwargs):
  return _default_client.copy(*args, **kwargs)


def show(*args, **kwargs):
  return _default_client.show(*args, **kwargs)
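Every module-level function above simply forwards to a shared default `Client`, so the two call styles are interchangeable; a sketch:

```python
import ollama
from ollama import Client

# equivalent: both POST the same payload to /api/show on localhost:11434
r1 = ollama.show('llama2')
r2 = Client().show('llama2')
```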

ollama/_client.py (new file)

@@ -0,0 +1,458 @@
import io
import json
import httpx
from os import PathLike
from pathlib import Path
from hashlib import sha256
from base64 import b64encode
from typing import Any, AnyStr, Union, Optional, List, Mapping
import sys
if sys.version_info < (3, 9):
from typing import Iterator, AsyncIterator
else:
from collections.abc import Iterator, AsyncIterator
from ollama._types import Message, Options
class BaseClient:
def __init__(self, client, base_url='http://127.0.0.1:11434') -> None:
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
class Client(BaseClient):
def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.Client, base)
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = self._client.request(method, url, **kwargs)
response.raise_for_status()
return response
def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
return self._request(method, url, **kwargs).json()
def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
with self._client.stream(method, url, **kwargs) as r:
for line in r.iter_lines():
part = json.loads(line)
if e := part.get('error'):
raise Exception(e)
yield part
def generate(
self,
model: str = '',
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[List[int]] = None,
stream: bool = False,
raw: bool = False,
format: str = '',
images: Optional[List[AnyStr]] = None,
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')
fn = self._stream if stream else self._request_json
return fn(
'POST',
'/api/generate',
json={
'model': model,
'prompt': prompt,
'system': system,
'template': template,
'context': context or [],
'stream': stream,
'raw': raw,
'images': [_encode_image(image) for image in images or []],
'format': format,
'options': options or {},
},
)
def chat(
self,
model: str = '',
messages: Optional[List[Message]] = None,
stream: bool = False,
format: str = '',
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')
for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of dicts')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
if not message.get('content'):
raise Exception('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]
fn = self._stream if stream else self._request_json
return fn(
'POST',
'/api/chat',
json={
'model': model,
'messages': messages,
'stream': stream,
'format': format,
'options': options or {},
},
)
def pull(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return fn(
'POST',
'/api/pull',
json={
'model': model,
'insecure': insecure,
'stream': stream,
},
)
def push(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return fn(
'POST',
'/api/push',
json={
'model': model,
'insecure': insecure,
'stream': stream,
},
)
def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
stream: bool = False,
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
if (realpath := _as_path(path)) and realpath.exists():
modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile:
modelfile = self._parse_modelfile(modelfile)
else:
raise Exception('must provide either path or modelfile')
fn = self._stream if stream else self._request_json
return fn(
'POST',
'/api/create',
json={
'model': model,
'modelfile': modelfile,
'stream': stream,
},
)
def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base
out = io.StringIO()
for line in io.StringIO(modelfile):
command, _, args = line.partition(' ')
if command.upper() in ['FROM', 'ADAPTER']:
path = Path(args).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{self._create_blob(path)}'
print(command, args, file=out)
return out.getvalue()
def _create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
chunk = r.read(32 * 1024)
if not chunk:
break
sha256sum.update(chunk)
digest = f'sha256:{sha256sum.hexdigest()}'
try:
self._request('HEAD', f'/api/blobs/{digest}')
except httpx.HTTPStatusError as e:
if e.response.status_code != 404:
raise
with open(path, 'rb') as r:
self._request('PUT', f'/api/blobs/{digest}', content=r)
return digest
def delete(self, model: str) -> Mapping[str, Any]:
response = self._request('DELETE', '/api/delete', json={'model': model})
return {'status': 'success' if response.status_code == 200 else 'error'}
def list(self) -> Mapping[str, Any]:
return self._request_json('GET', '/api/tags').get('models', [])
def copy(self, source: str, target: str) -> Mapping[str, Any]:
response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
return {'status': 'success' if response.status_code == 200 else 'error'}
def show(self, model: str) -> Mapping[str, Any]:
return self._request_json('GET', '/api/show', json={'model': model})
class AsyncClient(BaseClient):
def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.AsyncClient, base)
async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = await self._client.request(method, url, **kwargs)
response.raise_for_status()
return response
async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
response = await self._request(method, url, **kwargs)
return response.json()
async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]:
async def inner():
async with self._client.stream(method, url, **kwargs) as r:
async for line in r.aiter_lines():
part = json.loads(line)
if e := part.get('error'):
raise Exception(e)
yield part
return inner()
async def generate(
self,
model: str = '',
prompt: str = '',
system: str = '',
template: str = '',
context: Optional[List[int]] = None,
stream: bool = False,
raw: bool = False,
format: str = '',
images: Optional[List[AnyStr]] = None,
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')
fn = self._stream if stream else self._request_json
return await fn(
'POST',
'/api/generate',
json={
'model': model,
'prompt': prompt,
'system': system,
'template': template,
'context': context or [],
'stream': stream,
'raw': raw,
'images': [_encode_image(image) for image in images or []],
'format': format,
'options': options or {},
},
)
async def chat(
self,
model: str = '',
messages: Optional[List[Message]] = None,
stream: bool = False,
format: str = '',
options: Optional[Options] = None,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if not model:
raise Exception('must provide a model')
for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of dicts')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
if not message.get('content'):
raise Exception('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]
fn = self._stream if stream else self._request_json
return await fn(
'POST',
'/api/chat',
json={
'model': model,
'messages': messages,
'stream': stream,
'format': format,
'options': options or {},
},
)
async def pull(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return await fn(
'POST',
'/api/pull',
json={
'model': model,
'insecure': insecure,
'stream': stream,
},
)
async def push(
self,
model: str,
insecure: bool = False,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
fn = self._stream if stream else self._request_json
return await fn(
'POST',
'/api/push',
json={
'model': model,
'insecure': insecure,
'stream': stream,
},
)
async def create(
self,
model: str,
path: Optional[Union[str, PathLike]] = None,
modelfile: Optional[str] = None,
stream: bool = False,
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
if (realpath := _as_path(path)) and realpath.exists():
modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
elif modelfile:
modelfile = await self._parse_modelfile(modelfile)
else:
raise Exception('must provide either path or modelfile')
fn = self._stream if stream else self._request_json
return await fn(
'POST',
'/api/create',
json={
'model': model,
'modelfile': modelfile,
'stream': stream,
},
)
async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
base = Path.cwd() if base is None else base
out = io.StringIO()
for line in io.StringIO(modelfile):
command, _, args = line.partition(' ')
if command.upper() in ['FROM', 'ADAPTER']:
path = Path(args).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{await self._create_blob(path)}'
print(command, args, file=out)
return out.getvalue()
async def _create_blob(self, path: Union[str, Path]) -> str:
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
chunk = r.read(32 * 1024)
if not chunk:
break
sha256sum.update(chunk)
digest = f'sha256:{sha256sum.hexdigest()}'
try:
await self._request('HEAD', f'/api/blobs/{digest}')
except httpx.HTTPStatusError as e:
if e.response.status_code != 404:
raise
async def upload_bytes():
with open(path, 'rb') as r:
while True:
chunk = r.read(32 * 1024)
if not chunk:
break
yield chunk
await self._request('PUT', f'/api/blobs/{digest}', content=upload_bytes())
return digest
async def delete(self, model: str) -> Mapping[str, Any]:
response = await self._request('DELETE', '/api/delete', json={'model': model})
return {'status': 'success' if response.status_code == 200 else 'error'}
async def list(self) -> Mapping[str, Any]:
response = await self._request_json('GET', '/api/tags')
return response.get('models', [])
async def copy(self, source: str, target: str) -> Mapping[str, Any]:
response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target})
return {'status': 'success' if response.status_code == 200 else 'error'}
async def show(self, model: str) -> Mapping[str, Any]:
return await self._request_json('GET', '/api/show', json={'model': model})
def _encode_image(image) -> str:
if p := _as_path(image):
b64 = b64encode(p.read_bytes())
elif b := _as_bytesio(image):
b64 = b64encode(b.read())
else:
raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
return b64.decode('utf-8')
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
if isinstance(s, str) or isinstance(s, Path):
return Path(s)
return None
def _as_bytesio(s: Any) -> Union[io.BytesIO, None]:
if isinstance(s, io.BytesIO):
return s
elif isinstance(s, bytes):
return io.BytesIO(s)
return None
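Since `_request` calls `httpx.Response.raise_for_status`, HTTP failures surface as `httpx.HTTPStatusError`, while `error` payloads inside a stream are re-raised as plain `Exception`. A sketch of handling both around a call (the model name is illustrative):

```python
import httpx
from ollama import Client

client = Client()
try:
  client.show('no-such-model')  # hypothetical missing model
except httpx.HTTPStatusError as e:
  print('HTTP error:', e.response.status_code)
except Exception as e:  # errors surfaced mid-stream by _stream
  print('stream error:', e)
```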

ollama/_types.py (new file)

@@ -0,0 +1,53 @@
from typing import Any, TypedDict, List
import sys

if sys.version_info < (3, 11):
  from typing_extensions import NotRequired
else:
  from typing import NotRequired


class Message(TypedDict):
  role: str
  content: str
  images: NotRequired[List[Any]]


class Options(TypedDict, total=False):
  # load time options
  numa: bool
  num_ctx: int
  num_batch: int
  num_gqa: int
  num_gpu: int
  main_gpu: int
  low_vram: bool
  f16_kv: bool
  logits_all: bool
  vocab_only: bool
  use_mmap: bool
  use_mlock: bool
  embedding_only: bool
  rope_frequency_base: float
  rope_frequency_scale: float
  num_thread: int

  # runtime options
  num_keep: int
  seed: int
  num_predict: int
  top_k: int
  top_p: float
  tfs_z: float
  typical_p: float
  repeat_last_n: int
  temperature: float
  repeat_penalty: float
  presence_penalty: float
  frequency_penalty: float
  mirostat: int
  mirostat_tau: float
  mirostat_eta: float
  penalize_newline: bool
  stop: List[str]
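`Message` and `Options` are `TypedDict`s, so plain dict literals type-check against them; a sketch of passing typed values through `chat` (option values are examples):

```python
from ollama import Client, Message, Options

message: Message = {'role': 'user', 'content': 'Why is the sky blue?'}
options: Options = {'temperature': 0.2, 'num_predict': 64}

response = Client().chat(model='llama2', messages=[message], options=options)
print(response['message']['content'])
```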

ollama/client.py (deleted)

@@ -1,182 +0,0 @@
import io
import json
import httpx
from pathlib import Path
from hashlib import sha256
from base64 import b64encode
class BaseClient:
def __init__(self, client, base_url='http://127.0.0.1:11434'):
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
class Client(BaseClient):
def __init__(self, base='http://localhost:11434'):
super().__init__(httpx.Client, base)
def _request(self, method, url, **kwargs):
response = self._client.request(method, url, **kwargs)
response.raise_for_status()
return response
def _request_json(self, method, url, **kwargs):
return self._request(method, url, **kwargs).json()
def stream(self, method, url, **kwargs):
with self._client.stream(method, url, **kwargs) as r:
for line in r.iter_lines():
part = json.loads(line)
if e := part.get('error'):
raise Exception(e)
yield part
def generate(self, model, prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
fn = self.stream if stream else self._request_json
return fn('POST', '/api/generate', json={
'model': model,
'prompt': prompt,
'system': system,
'template': template,
'context': context or [],
'stream': stream,
'raw': raw,
'images': [_encode_image(image) for image in images or []],
'format': format,
'options': options or {},
})
def chat(self, model, messages=None, stream=False, format='', options=None):
for message in messages or []:
if not isinstance(message, dict):
raise TypeError('messages must be a list of strings')
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
if not message.get('content'):
raise Exception('messages must contain content')
if images := message.get('images'):
message['images'] = [_encode_image(image) for image in images]
fn = self.stream if stream else self._request_json
return fn('POST', '/api/chat', json={
'model': model,
'messages': messages,
'stream': stream,
'format': format,
'options': options or {},
})
def pull(self, model, insecure=False, stream=False):
fn = self.stream if stream else self._request_json
return fn('POST', '/api/pull', json={
'model': model,
'insecure': insecure,
'stream': stream,
})
def push(self, model, insecure=False, stream=False):
fn = self.stream if stream else self._request_json
return fn('POST', '/api/push', json={
'model': model,
'insecure': insecure,
'stream': stream,
})
def create(self, model, path=None, modelfile=None, stream=False):
if (path := _as_path(path)) and path.exists():
modelfile = _parse_modelfile(path.read_text(), self.create_blob, base=path.parent)
elif modelfile:
modelfile = _parse_modelfile(modelfile, self.create_blob)
else:
raise Exception('must provide either path or modelfile')
fn = self.stream if stream else self._request_json
return fn('POST', '/api/create', json={
'model': model,
'modelfile': modelfile,
'stream': stream,
})
def create_blob(self, path):
sha256sum = sha256()
with open(path, 'rb') as r:
while True:
chunk = r.read(32*1024)
if not chunk:
break
sha256sum.update(chunk)
digest = f'sha256:{sha256sum.hexdigest()}'
try:
self._request('HEAD', f'/api/blobs/{digest}')
except httpx.HTTPError:
with open(path, 'rb') as r:
self._request('PUT', f'/api/blobs/{digest}', content=r)
return digest
def delete(self, model):
response = self._request_json('DELETE', '/api/delete', json={'model': model})
return {'status': 'success' if response.status_code == 200 else 'error'}
def list(self):
return self._request_json('GET', '/api/tags').get('models', [])
def copy(self, source, target):
response = self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
return {'status': 'success' if response.status_code == 200 else 'error'}
def show(self, model):
return self._request_json('GET', '/api/show', json={'model': model}).json()
def _encode_image(image):
'''
_encode_images takes a list of images and returns a generator of base64 encoded images.
if the image is a bytes object, it is assumed to be the raw bytes of an image.
if the image is a string, it is assumed to be a path to a file.
if the image is a Path object, it is assumed to be a path to a file.
if the image is a file-like object, it is assumed to be a container to the raw bytes of an image.
'''
if p := _as_path(image):
b64 = b64encode(p.read_bytes())
elif b := _as_bytesio(image):
b64 = b64encode(b.read())
else:
raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
return b64.decode('utf-8')
def _parse_modelfile(modelfile, cb, base=None):
base = Path.cwd() if base is None else base
out = io.StringIO()
for line in io.StringIO(modelfile):
command, _, args = line.partition(' ')
if command.upper() in ['FROM', 'ADAPTER']:
path = Path(args).expanduser()
path = path if path.is_absolute() else base / path
if path.exists():
args = f'@{cb(path)}'
print(command, args, file=out)
return out.getvalue()
def _as_path(s):
if isinstance(s, str) or isinstance(s, Path):
return Path(s)
return None
def _as_bytesio(s):
if isinstance(s, io.BytesIO):
return s
elif isinstance(s, bytes):
return io.BytesIO(s)
return None

tests for ollama.client (deleted file)

@@ -1,292 +0,0 @@
import pytest
import os
import io
import types
import tempfile
from pathlib import Path
from ollama.client import Client
from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Response
from PIL import Image
class PrefixPattern(URIPattern):
def __init__(self, prefix: str):
self.prefix = prefix
def match(self, uri):
return uri.startswith(self.prefix)
def test_client_chat(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/chat', method='POST', json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'stream': False,
'format': '',
'options': {},
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
assert isinstance(response, dict)
def test_client_chat_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/chat', method='POST', json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'stream': True,
'format': '',
'options': {},
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
assert isinstance(response, types.GeneratorType)
def test_client_chat_images(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/chat', method='POST', json={
'model': 'dummy',
'messages': [
{
'role': 'user',
'content': 'Why is the sky blue?',
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'stream': False,
'format': '',
'options': {},
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
with io.BytesIO() as b:
Image.new('RGB', (1, 1)).save(b, 'PNG')
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
assert isinstance(response, dict)
def test_client_generate(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/generate', method='POST', json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'system': '',
'template': '',
'context': [],
'stream': False,
'raw': False,
'images': [],
'format': '',
'options': {},
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.generate('dummy', 'Why is the sky blue?')
assert isinstance(response, dict)
def test_client_generate_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/generate', method='POST', json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'system': '',
'template': '',
'context': [],
'stream': True,
'raw': False,
'images': [],
'format': '',
'options': {},
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.generate('dummy', 'Why is the sky blue?', stream=True)
assert isinstance(response, types.GeneratorType)
def test_client_generate_images(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/generate', method='POST', json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'system': '',
'template': '',
'context': [],
'stream': False,
'raw': False,
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'format': '',
'options': {},
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as temp:
Image.new('RGB', (1, 1)).save(temp, 'PNG')
response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
assert isinstance(response, dict)
def test_client_pull(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/pull', method='POST', json={
'model': 'dummy',
'insecure': False,
'stream': False,
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.pull('dummy')
assert isinstance(response, dict)
def test_client_pull_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/pull', method='POST', json={
'model': 'dummy',
'insecure': False,
'stream': True,
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.pull('dummy', stream=True)
assert isinstance(response, types.GeneratorType)
def test_client_push(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/push', method='POST', json={
'model': 'dummy',
'insecure': False,
'stream': False,
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.push('dummy')
assert isinstance(response, dict)
def test_client_push_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/push', method='POST', json={
'model': 'dummy',
'insecure': False,
'stream': True,
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.push('dummy', stream=True)
assert isinstance(response, types.GeneratorType)
def test_client_create_path(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile() as blob:
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert isinstance(response, dict)
def test_client_create_path_relative(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert isinstance(response, dict)
@pytest.fixture
def userhomedir():
with tempfile.TemporaryDirectory() as temp:
home = os.getenv('HOME', '')
os.environ['HOME'] = temp
yield Path(temp)
os.environ['HOME'] = home
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert isinstance(response, dict)
def test_client_create_modelfile(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request('/api/create', method='POST', json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = client.create('dummy', modelfile=f'FROM {blob.name}')
assert isinstance(response, dict)
def test_client_create_from_library(httpserver: HTTPServer):
httpserver.expect_ordered_request('/api/create', method='POST', json={
'model': 'dummy',
'modelfile': 'FROM llama2\n',
'stream': False,
}).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.create('dummy', modelfile='FROM llama2')
assert isinstance(response, dict)
def test_client_create_blob(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = client.create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
def test_client_create_blob_exists(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = client.create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'

poetry.lock (generated)

@@ -387,6 +387,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
[[package]]
name = "pytest-asyncio"
version = "0.23.2"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.8"
files = [
{file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"},
{file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"},
]
[package.dependencies]
pytest = ">=7.0.0"
[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[[package]]
name = "pytest-cov"
version = "4.1.0"
@@ -498,4 +516,4 @@ watchdog = ["watchdog (>=2.3)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.8"
content-hash = "b9f64e1a5795a417d2dbff7286360f8d3f8f10fdfa9580411940d144c2561e92"
content-hash = "9416a897c95d3c80cf1bfd3cc61cd19f0143c9bd0bc7c219fcb31ee27c497c9d"

pyproject.toml

@@ -1,9 +1,12 @@
[tool.poetry]
name = "ollama"
version = "0.1.0"
version = "0.0.0"
description = "The official Python client for Ollama."
authors = ["Ollama <hello@ollama.com>"]
license = "MIT"
readme = "README.md"
homepage = "https://ollama.ai"
repository = "https://github.com/jmorganca/ollama-python"
[tool.poetry.dependencies]
python = "^3.8"
@@ -11,12 +14,18 @@ httpx = "^0.25.2"
[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
pytest-asyncio = "^0.23.2"
pytest-cov = "^4.1.0"
pytest-httpserver = "^1.0.8"
pillow = "^10.1.0"
ruff = "^0.1.8"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.ruff]
line-length = 999
indent-width = 2
[tool.ruff.format]
@@ -26,7 +35,3 @@ indent-style = "space"
[tool.ruff.lint]
select = ["E", "F", "B"]
ignore = ["E501"]
-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"

tests/test_client.py (new file)

@@ -0,0 +1,779 @@
import os
import io
import json
import types
import pytest
import tempfile
from pathlib import Path
from pytest_httpserver import HTTPServer, URIPattern
from werkzeug.wrappers import Request, Response
from PIL import Image
from ollama._client import Client, AsyncClient
class PrefixPattern(URIPattern):
def __init__(self, prefix: str):
self.prefix = prefix
def match(self, uri):
return uri.startswith(self.prefix)
def test_client_chat(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'stream': False,
'format': '',
'options': {},
},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': "I don't know.",
},
}
)
client = Client(httpserver.url_for('/'))
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == "I don't know."
def test_client_chat_stream(httpserver: HTTPServer):
def stream_handler(_: Request):
def generate():
for message in ['I ', "don't ", 'know.']:
yield (
json.dumps(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': message,
},
}
)
+ '\n'
)
return Response(generate())
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'stream': True,
'format': '',
'options': {},
},
).respond_with_handler(stream_handler)
client = Client(httpserver.url_for('/'))
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
for part in response:
assert part['message']['role'] == 'assistant'
assert part['message']['content'] in ['I ', "don't ", 'know.']
def test_client_chat_images(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [
{
'role': 'user',
'content': 'Why is the sky blue?',
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'stream': False,
'format': '',
'options': {},
},
).respond_with_json(
{
'model': 'dummy',
'message': {
'role': 'assistant',
'content': "I don't know.",
},
}
)
client = Client(httpserver.url_for('/'))
with io.BytesIO() as b:
Image.new('RGB', (1, 1)).save(b, 'PNG')
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
assert response['model'] == 'dummy'
assert response['message']['role'] == 'assistant'
assert response['message']['content'] == "I don't know."
def test_client_generate(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'system': '',
'template': '',
'context': [],
'stream': False,
'raw': False,
'images': [],
'format': '',
'options': {},
},
).respond_with_json(
{
'model': 'dummy',
'response': 'Because it is.',
}
)
client = Client(httpserver.url_for('/'))
response = client.generate('dummy', 'Why is the sky blue?')
assert response['model'] == 'dummy'
assert response['response'] == 'Because it is.'
def test_client_generate_stream(httpserver: HTTPServer):
def stream_handler(_: Request):
def generate():
for message in ['Because ', 'it ', 'is.']:
yield (
json.dumps(
{
'model': 'dummy',
'response': message,
}
)
+ '\n'
)
return Response(generate())
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'system': '',
'template': '',
'context': [],
'stream': True,
'raw': False,
'images': [],
'format': '',
'options': {},
},
).respond_with_handler(stream_handler)
client = Client(httpserver.url_for('/'))
response = client.generate('dummy', 'Why is the sky blue?', stream=True)
for part in response:
assert part['model'] == 'dummy'
assert part['response'] in ['Because ', 'it ', 'is.']
def test_client_generate_images(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'system': '',
'template': '',
'context': [],
'stream': False,
'raw': False,
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'format': '',
'options': {},
},
).respond_with_json(
{
'model': 'dummy',
'response': 'Because it is.',
}
)
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as temp:
Image.new('RGB', (1, 1)).save(temp, 'PNG')
response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
assert response['model'] == 'dummy'
assert response['response'] == 'Because it is.'
def test_client_pull(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/pull',
method='POST',
json={
'model': 'dummy',
'insecure': False,
'stream': False,
},
).respond_with_json(
{
'status': 'success',
}
)
client = Client(httpserver.url_for('/'))
response = client.pull('dummy')
assert response['status'] == 'success'
def test_client_pull_stream(httpserver: HTTPServer):
def stream_handler(_: Request):
def generate():
yield json.dumps({'status': 'pulling manifest'}) + '\n'
yield json.dumps({'status': 'verifying sha256 digest'}) + '\n'
yield json.dumps({'status': 'writing manifest'}) + '\n'
yield json.dumps({'status': 'removing any unused layers'}) + '\n'
yield json.dumps({'status': 'success'}) + '\n'
return Response(generate())
httpserver.expect_ordered_request(
'/api/pull',
method='POST',
json={
'model': 'dummy',
'insecure': False,
'stream': True,
},
).respond_with_handler(stream_handler)
client = Client(httpserver.url_for('/'))
response = client.pull('dummy', stream=True)
assert isinstance(response, types.GeneratorType)
def test_client_push(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/push',
method='POST',
json={
'model': 'dummy',
'insecure': False,
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.push('dummy')
assert isinstance(response, dict)
def test_client_push_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/push',
method='POST',
json={
'model': 'dummy',
'insecure': False,
'stream': True,
},
).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.push('dummy', stream=True)
assert isinstance(response, types.GeneratorType)
def test_client_create_path(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile() as blob:
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert isinstance(response, dict)
def test_client_create_path_relative(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert isinstance(response, dict)
@pytest.fixture
def userhomedir():
with tempfile.TemporaryDirectory() as temp:
home = os.getenv('HOME', '')
os.environ['HOME'] = temp
yield Path(temp)
os.environ['HOME'] = home
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = client.create('dummy', path=modelfile.name)
assert isinstance(response, dict)
def test_client_create_modelfile(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = client.create('dummy', modelfile=f'FROM {blob.name}')
assert isinstance(response, dict)
def test_client_create_from_library(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM llama2\n',
'stream': False,
},
).respond_with_json({})
client = Client(httpserver.url_for('/'))
response = client.create('dummy', modelfile='FROM llama2')
assert isinstance(response, dict)
def test_client_create_blob(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = client._create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
def test_client_create_blob_exists(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
client = Client(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = client._create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
@pytest.mark.asyncio
async def test_async_client_chat(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'stream': False,
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_chat_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
'stream': True,
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
assert isinstance(response, types.AsyncGeneratorType)
@pytest.mark.asyncio
async def test_async_client_chat_images(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/chat',
method='POST',
json={
'model': 'dummy',
'messages': [
{
'role': 'user',
'content': 'Why is the sky blue?',
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
},
],
'stream': False,
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
with io.BytesIO() as b:
Image.new('RGB', (1, 1)).save(b, 'PNG')
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_generate(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'system': '',
'template': '',
'context': [],
'stream': False,
'raw': False,
'images': [],
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
response = await client.generate('dummy', 'Why is the sky blue?')
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_generate_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'system': '',
'template': '',
'context': [],
'stream': True,
'raw': False,
'images': [],
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
response = await client.generate('dummy', 'Why is the sky blue?', stream=True)
assert isinstance(response, types.AsyncGeneratorType)
@pytest.mark.asyncio
async def test_async_client_generate_images(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',
method='POST',
json={
'model': 'dummy',
'prompt': 'Why is the sky blue?',
'system': '',
'template': '',
'context': [],
'stream': False,
'raw': False,
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
'format': '',
'options': {},
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as temp:
Image.new('RGB', (1, 1)).save(temp, 'PNG')
response = await client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_pull(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/pull',
method='POST',
json={
'model': 'dummy',
'insecure': False,
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
response = await client.pull('dummy')
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_pull_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/pull',
method='POST',
json={
'model': 'dummy',
'insecure': False,
'stream': True,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
response = await client.pull('dummy', stream=True)
assert isinstance(response, types.AsyncGeneratorType)
@pytest.mark.asyncio
async def test_async_client_push(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/push',
method='POST',
json={
'model': 'dummy',
'insecure': False,
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
response = await client.push('dummy')
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_push_stream(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/push',
method='POST',
json={
'model': 'dummy',
'insecure': False,
'stream': True,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
response = await client.push('dummy', stream=True)
assert isinstance(response, types.AsyncGeneratorType)
@pytest.mark.asyncio
async def test_async_client_create_path(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile() as blob:
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
modelfile.flush()
response = await client.create('dummy', path=modelfile.name)
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_create_path_relative(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = await client.create('dummy', path=modelfile.name)
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as modelfile:
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
modelfile.flush()
response = await client.create('dummy', path=modelfile.name)
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_create_modelfile(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = await client.create('dummy', modelfile=f'FROM {blob.name}')
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_create_from_library(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/create',
method='POST',
json={
'model': 'dummy',
'modelfile': 'FROM llama2\n',
'stream': False,
},
).respond_with_json({})
client = AsyncClient(httpserver.url_for('/'))
response = await client.create('dummy', modelfile='FROM llama2')
assert isinstance(response, dict)
@pytest.mark.asyncio
async def test_async_client_create_blob(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = await client._create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
@pytest.mark.asyncio
async def test_async_client_create_blob_exists(httpserver: HTTPServer):
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
client = AsyncClient(httpserver.url_for('/'))
with tempfile.NamedTemporaryFile() as blob:
response = await client._create_blob(blob.name)
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'