mirror of
https://github.com/ollama/ollama-python.git
synced 2026-02-04 02:54:08 -06:00
commit
187bd29b0e
24
.github/workflows/publish.yaml
vendored
Normal file
24
.github/workflows/publish.yaml
vendored
Normal file
@ -0,0 +1,24 @@
|
||||
name: publish
|
||||
|
||||
on:
|
||||
release:
|
||||
types:
|
||||
- created
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: ubuntu-latest
|
||||
environment: release
|
||||
permissions:
|
||||
id-token: write
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: pipx install poetry
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
cache: poetry
|
||||
- run: |
|
||||
poetry version -- ${GIT_REF_NAME#v}
|
||||
poetry build
|
||||
- uses: pypa/gh-action-pypi-publish@release/v1
|
||||
- run: gh release upload $GIT_REF_NAME dist/*
|
||||
32
.github/workflows/test.yaml
vendored
Normal file
32
.github/workflows/test.yaml
vendored
Normal file
@ -0,0 +1,32 @@
|
||||
name: test
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: pipx install poetry
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: poetry
|
||||
- run: poetry install --with=dev
|
||||
- run: poetry run ruff --output-format=github .
|
||||
- run: poetry run pytest . --junitxml=junit/test-results-${{ matrix.python-version }}.xml --cov=ollama --cov-report=xml --cov-report=html
|
||||
- name: check poetry.lock is up-to-date
|
||||
run: poetry check --lock
|
||||
- name: check requirements.txt is up-to-date
|
||||
run: |
|
||||
poetry export >requirements.txt
|
||||
git diff --exit-code requirements.txt
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: pytest-results-${{ matrix.python-version }}
|
||||
path: junit/test-results-${{ matrix.python-version }}.xml
|
||||
if: ${{ always() }}
|
||||
160
.gitignore
vendored
Normal file
160
.gitignore
vendored
Normal file
@ -0,0 +1,160 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
70
README.md
70
README.md
@ -0,0 +1,70 @@
|
||||
# Ollama Python Library
|
||||
|
||||
The Ollama Python library provides the easiest way to integrate your Python 3 project with [Ollama](https://github.com/jmorganca/ollama).
|
||||
|
||||
## Getting Started
|
||||
|
||||
Requires Python 3.8 or higher.
|
||||
|
||||
```sh
|
||||
pip install ollama
|
||||
```
|
||||
|
||||
A global default client is provided for convenience and can be used in the same way as the synchronous client.
|
||||
|
||||
```python
|
||||
import ollama
|
||||
response = ollama.chat(model='llama2', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
|
||||
```
|
||||
|
||||
```python
|
||||
import ollama
|
||||
message = {'role': 'user', 'content': 'Why is the sky blue?'}
|
||||
for part in ollama.chat(model='llama2', messages=[message], stream=True):
|
||||
print(part['message']['content'], end='', flush=True)
|
||||
```
|
||||
|
||||
|
||||
### Using the Synchronous Client
|
||||
|
||||
```python
|
||||
from ollama import Client
|
||||
message = {'role': 'user', 'content': 'Why is the sky blue?'}
|
||||
response = Client().chat(model='llama2', messages=[message])
|
||||
```
|
||||
|
||||
Response streaming can be enabled by setting `stream=True`. This modifies the function to return a Python generator where each part is an object in the stream.
|
||||
|
||||
```python
|
||||
from ollama import Client
|
||||
message = {'role': 'user', 'content': 'Why is the sky blue?'}
|
||||
for part in Client().chat(model='llama2', messages=[message], stream=True):
|
||||
print(part['message']['content'], end='', flush=True)
|
||||
```
|
||||
|
||||
### Using the Asynchronous Client
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from ollama import AsyncClient
|
||||
|
||||
async def chat():
|
||||
message = {'role': 'user', 'content': 'Why is the sky blue?'}
|
||||
response = await AsyncClient().chat(model='llama2', messages=[message])
|
||||
|
||||
asyncio.run(chat())
|
||||
```
|
||||
|
||||
Similar to the synchronous client, setting `stream=True` modifies the function to return a Python asynchronous generator.
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from ollama import AsyncClient
|
||||
|
||||
async def chat():
|
||||
message = {'role': 'user', 'content': 'Why is the sky blue?'}
|
||||
async for part in await AsyncClient().chat(model='llama2', messages=[message], stream=True):
|
||||
print(part['message']['content'], end='', flush=True)
|
||||
|
||||
asyncio.run(chat())
|
||||
```
|
||||
22
examples/simple-fill-in-middle/main.py
Normal file
22
examples/simple-fill-in-middle/main.py
Normal file
@ -0,0 +1,22 @@
|
||||
from ollama import generate
|
||||
|
||||
prefix = '''def remove_non_ascii(s: str) -> str:
|
||||
""" '''
|
||||
|
||||
suffix = """
|
||||
return result
|
||||
"""
|
||||
|
||||
|
||||
response = generate(
|
||||
model='codellama:7b-code',
|
||||
prompt=f'<PRE> {prefix} <SUF>{suffix} <MID>',
|
||||
options={
|
||||
'num_predict': 128,
|
||||
'temperature': 0,
|
||||
'top_p': 0.9,
|
||||
'stop': ['<EOT>'],
|
||||
},
|
||||
)
|
||||
|
||||
print(response['response'])
|
||||
@ -1,30 +1,56 @@
|
||||
from ollama.client import Client
|
||||
from ollama._client import Client, AsyncClient, Message, Options
|
||||
|
||||
__all__ = [
|
||||
'Client',
|
||||
'AsyncClient',
|
||||
'Message',
|
||||
'Options',
|
||||
'generate',
|
||||
'chat',
|
||||
'pull',
|
||||
'push',
|
||||
'create',
|
||||
'delete',
|
||||
'list',
|
||||
'copy',
|
||||
'show',
|
||||
]
|
||||
|
||||
|
||||
_default_client = Client()
|
||||
|
||||
|
||||
def generate(*args, **kwargs):
|
||||
return _default_client.generate(*args, **kwargs)
|
||||
|
||||
|
||||
def chat(*args, **kwargs):
|
||||
return _default_client.chat(*args, **kwargs)
|
||||
|
||||
|
||||
def pull(*args, **kwargs):
|
||||
return _default_client.pull(*args, **kwargs)
|
||||
|
||||
|
||||
def push(*args, **kwargs):
|
||||
return _default_client.push(*args, **kwargs)
|
||||
|
||||
|
||||
def create(*args, **kwargs):
|
||||
return _default_client.create(*args, **kwargs)
|
||||
|
||||
|
||||
def delete(*args, **kwargs):
|
||||
return _default_client.delete(*args, **kwargs)
|
||||
|
||||
|
||||
def list(*args, **kwargs):
|
||||
return _default_client.list(*args, **kwargs)
|
||||
|
||||
|
||||
def copy(*args, **kwargs):
|
||||
return _default_client.copy(*args, **kwargs)
|
||||
|
||||
|
||||
def show(*args, **kwargs):
|
||||
return _default_client.show(*args, **kwargs)
|
||||
|
||||
458
ollama/_client.py
Normal file
458
ollama/_client.py
Normal file
@ -0,0 +1,458 @@
|
||||
import io
|
||||
import json
|
||||
import httpx
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from hashlib import sha256
|
||||
from base64 import b64encode
|
||||
|
||||
from typing import Any, AnyStr, Union, Optional, List, Mapping
|
||||
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
from typing import Iterator, AsyncIterator
|
||||
else:
|
||||
from collections.abc import Iterator, AsyncIterator
|
||||
|
||||
from ollama._types import Message, Options
|
||||
|
||||
|
||||
class BaseClient:
|
||||
def __init__(self, client, base_url='http://127.0.0.1:11434') -> None:
|
||||
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
|
||||
|
||||
|
||||
class Client(BaseClient):
|
||||
def __init__(self, base='http://localhost:11434') -> None:
|
||||
super().__init__(httpx.Client, base)
|
||||
|
||||
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
|
||||
response = self._client.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
|
||||
return self._request(method, url, **kwargs).json()
|
||||
|
||||
def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]:
|
||||
with self._client.stream(method, url, **kwargs) as r:
|
||||
for line in r.iter_lines():
|
||||
part = json.loads(line)
|
||||
if e := part.get('error'):
|
||||
raise Exception(e)
|
||||
yield part
|
||||
|
||||
def generate(
|
||||
self,
|
||||
model: str = '',
|
||||
prompt: str = '',
|
||||
system: str = '',
|
||||
template: str = '',
|
||||
context: Optional[List[int]] = None,
|
||||
stream: bool = False,
|
||||
raw: bool = False,
|
||||
format: str = '',
|
||||
images: Optional[List[AnyStr]] = None,
|
||||
options: Optional[Options] = None,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn(
|
||||
'POST',
|
||||
'/api/generate',
|
||||
json={
|
||||
'model': model,
|
||||
'prompt': prompt,
|
||||
'system': system,
|
||||
'template': template,
|
||||
'context': context or [],
|
||||
'stream': stream,
|
||||
'raw': raw,
|
||||
'images': [_encode_image(image) for image in images or []],
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
},
|
||||
)
|
||||
|
||||
def chat(
|
||||
self,
|
||||
model: str = '',
|
||||
messages: Optional[List[Message]] = None,
|
||||
stream: bool = False,
|
||||
format: str = '',
|
||||
options: Optional[Options] = None,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
for message in messages or []:
|
||||
if not isinstance(message, dict):
|
||||
raise TypeError('messages must be a list of strings')
|
||||
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
|
||||
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
|
||||
if not message.get('content'):
|
||||
raise Exception('messages must contain content')
|
||||
if images := message.get('images'):
|
||||
message['images'] = [_encode_image(image) for image in images]
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn(
|
||||
'POST',
|
||||
'/api/chat',
|
||||
json={
|
||||
'model': model,
|
||||
'messages': messages,
|
||||
'stream': stream,
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
},
|
||||
)
|
||||
|
||||
def pull(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn(
|
||||
'POST',
|
||||
'/api/pull',
|
||||
json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
},
|
||||
)
|
||||
|
||||
def push(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn(
|
||||
'POST',
|
||||
'/api/push',
|
||||
json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
},
|
||||
)
|
||||
|
||||
def create(
|
||||
self,
|
||||
model: str,
|
||||
path: Optional[Union[str, PathLike]] = None,
|
||||
modelfile: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]:
|
||||
if (realpath := _as_path(path)) and realpath.exists():
|
||||
modelfile = self._parse_modelfile(realpath.read_text(), base=realpath.parent)
|
||||
elif modelfile:
|
||||
modelfile = self._parse_modelfile(modelfile)
|
||||
else:
|
||||
raise Exception('must provide either path or modelfile')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return fn(
|
||||
'POST',
|
||||
'/api/create',
|
||||
json={
|
||||
'model': model,
|
||||
'modelfile': modelfile,
|
||||
'stream': stream,
|
||||
},
|
||||
)
|
||||
|
||||
def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
|
||||
base = Path.cwd() if base is None else base
|
||||
|
||||
out = io.StringIO()
|
||||
for line in io.StringIO(modelfile):
|
||||
command, _, args = line.partition(' ')
|
||||
if command.upper() in ['FROM', 'ADAPTER']:
|
||||
path = Path(args).expanduser()
|
||||
path = path if path.is_absolute() else base / path
|
||||
if path.exists():
|
||||
args = f'@{self._create_blob(path)}'
|
||||
|
||||
print(command, args, file=out)
|
||||
return out.getvalue()
|
||||
|
||||
def _create_blob(self, path: Union[str, Path]) -> str:
|
||||
sha256sum = sha256()
|
||||
with open(path, 'rb') as r:
|
||||
while True:
|
||||
chunk = r.read(32 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
sha256sum.update(chunk)
|
||||
|
||||
digest = f'sha256:{sha256sum.hexdigest()}'
|
||||
|
||||
try:
|
||||
self._request('HEAD', f'/api/blobs/{digest}')
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code != 404:
|
||||
raise
|
||||
|
||||
with open(path, 'rb') as r:
|
||||
self._request('PUT', f'/api/blobs/{digest}', content=r)
|
||||
|
||||
return digest
|
||||
|
||||
def delete(self, model: str) -> Mapping[str, Any]:
|
||||
response = self._request('DELETE', '/api/delete', json={'model': model})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
def list(self) -> Mapping[str, Any]:
|
||||
return self._request_json('GET', '/api/tags').get('models', [])
|
||||
|
||||
def copy(self, source: str, target: str) -> Mapping[str, Any]:
|
||||
response = self._request('POST', '/api/copy', json={'source': source, 'destination': target})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
def show(self, model: str) -> Mapping[str, Any]:
|
||||
return self._request_json('GET', '/api/show', json={'model': model})
|
||||
|
||||
|
||||
class AsyncClient(BaseClient):
|
||||
def __init__(self, base='http://localhost:11434') -> None:
|
||||
super().__init__(httpx.AsyncClient, base)
|
||||
|
||||
async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
|
||||
response = await self._client.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]:
|
||||
response = await self._request(method, url, **kwargs)
|
||||
return response.json()
|
||||
|
||||
async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]:
|
||||
async def inner():
|
||||
async with self._client.stream(method, url, **kwargs) as r:
|
||||
async for line in r.aiter_lines():
|
||||
part = json.loads(line)
|
||||
if e := part.get('error'):
|
||||
raise Exception(e)
|
||||
yield part
|
||||
|
||||
return inner()
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
model: str = '',
|
||||
prompt: str = '',
|
||||
system: str = '',
|
||||
template: str = '',
|
||||
context: Optional[List[int]] = None,
|
||||
stream: bool = False,
|
||||
raw: bool = False,
|
||||
format: str = '',
|
||||
images: Optional[List[AnyStr]] = None,
|
||||
options: Optional[Options] = None,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn(
|
||||
'POST',
|
||||
'/api/generate',
|
||||
json={
|
||||
'model': model,
|
||||
'prompt': prompt,
|
||||
'system': system,
|
||||
'template': template,
|
||||
'context': context or [],
|
||||
'stream': stream,
|
||||
'raw': raw,
|
||||
'images': [_encode_image(image) for image in images or []],
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
},
|
||||
)
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
model: str = '',
|
||||
messages: Optional[List[Message]] = None,
|
||||
stream: bool = False,
|
||||
format: str = '',
|
||||
options: Optional[Options] = None,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
if not model:
|
||||
raise Exception('must provide a model')
|
||||
|
||||
for message in messages or []:
|
||||
if not isinstance(message, dict):
|
||||
raise TypeError('messages must be a list of strings')
|
||||
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
|
||||
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
|
||||
if not message.get('content'):
|
||||
raise Exception('messages must contain content')
|
||||
if images := message.get('images'):
|
||||
message['images'] = [_encode_image(image) for image in images]
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn(
|
||||
'POST',
|
||||
'/api/chat',
|
||||
json={
|
||||
'model': model,
|
||||
'messages': messages,
|
||||
'stream': stream,
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
},
|
||||
)
|
||||
|
||||
async def pull(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn(
|
||||
'POST',
|
||||
'/api/pull',
|
||||
json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
},
|
||||
)
|
||||
|
||||
async def push(
|
||||
self,
|
||||
model: str,
|
||||
insecure: bool = False,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn(
|
||||
'POST',
|
||||
'/api/push',
|
||||
json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
},
|
||||
)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
model: str,
|
||||
path: Optional[Union[str, PathLike]] = None,
|
||||
modelfile: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]:
|
||||
if (realpath := _as_path(path)) and realpath.exists():
|
||||
modelfile = await self._parse_modelfile(realpath.read_text(), base=realpath.parent)
|
||||
elif modelfile:
|
||||
modelfile = await self._parse_modelfile(modelfile)
|
||||
else:
|
||||
raise Exception('must provide either path or modelfile')
|
||||
|
||||
fn = self._stream if stream else self._request_json
|
||||
return await fn(
|
||||
'POST',
|
||||
'/api/create',
|
||||
json={
|
||||
'model': model,
|
||||
'modelfile': modelfile,
|
||||
'stream': stream,
|
||||
},
|
||||
)
|
||||
|
||||
async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str:
|
||||
base = Path.cwd() if base is None else base
|
||||
|
||||
out = io.StringIO()
|
||||
for line in io.StringIO(modelfile):
|
||||
command, _, args = line.partition(' ')
|
||||
if command.upper() in ['FROM', 'ADAPTER']:
|
||||
path = Path(args).expanduser()
|
||||
path = path if path.is_absolute() else base / path
|
||||
if path.exists():
|
||||
args = f'@{await self._create_blob(path)}'
|
||||
|
||||
print(command, args, file=out)
|
||||
return out.getvalue()
|
||||
|
||||
async def _create_blob(self, path: Union[str, Path]) -> str:
|
||||
sha256sum = sha256()
|
||||
with open(path, 'rb') as r:
|
||||
while True:
|
||||
chunk = r.read(32 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
sha256sum.update(chunk)
|
||||
|
||||
digest = f'sha256:{sha256sum.hexdigest()}'
|
||||
|
||||
try:
|
||||
await self._request('HEAD', f'/api/blobs/{digest}')
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code != 404:
|
||||
raise
|
||||
|
||||
async def upload_bytes():
|
||||
with open(path, 'rb') as r:
|
||||
while True:
|
||||
chunk = r.read(32 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
await self._request('PUT', f'/api/blobs/{digest}', content=upload_bytes())
|
||||
|
||||
return digest
|
||||
|
||||
async def delete(self, model: str) -> Mapping[str, Any]:
|
||||
response = await self._request('DELETE', '/api/delete', json={'model': model})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
async def list(self) -> Mapping[str, Any]:
|
||||
response = await self._request_json('GET', '/api/tags')
|
||||
return response.get('models', [])
|
||||
|
||||
async def copy(self, source: str, target: str) -> Mapping[str, Any]:
|
||||
response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
async def show(self, model: str) -> Mapping[str, Any]:
|
||||
return await self._request_json('GET', '/api/show', json={'model': model})
|
||||
|
||||
|
||||
def _encode_image(image) -> str:
|
||||
if p := _as_path(image):
|
||||
b64 = b64encode(p.read_bytes())
|
||||
elif b := _as_bytesio(image):
|
||||
b64 = b64encode(b.read())
|
||||
else:
|
||||
raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
|
||||
|
||||
return b64.decode('utf-8')
|
||||
|
||||
|
||||
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:
|
||||
if isinstance(s, str) or isinstance(s, Path):
|
||||
return Path(s)
|
||||
return None
|
||||
|
||||
|
||||
def _as_bytesio(s: Any) -> Union[io.BytesIO, None]:
|
||||
if isinstance(s, io.BytesIO):
|
||||
return s
|
||||
elif isinstance(s, bytes):
|
||||
return io.BytesIO(s)
|
||||
return None
|
||||
53
ollama/_types.py
Normal file
53
ollama/_types.py
Normal file
@ -0,0 +1,53 @@
|
||||
from typing import Any, TypedDict, List
|
||||
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from typing_extensions import NotRequired
|
||||
else:
|
||||
from typing import NotRequired
|
||||
|
||||
|
||||
class Message(TypedDict):
|
||||
role: str
|
||||
content: str
|
||||
images: NotRequired[List[Any]]
|
||||
|
||||
|
||||
class Options(TypedDict, total=False):
|
||||
# load time options
|
||||
numa: bool
|
||||
num_ctx: int
|
||||
num_batch: int
|
||||
num_gqa: int
|
||||
num_gpu: int
|
||||
main_gpu: int
|
||||
low_vram: bool
|
||||
f16_kv: bool
|
||||
logits_all: bool
|
||||
vocab_only: bool
|
||||
use_mmap: bool
|
||||
use_mlock: bool
|
||||
embedding_only: bool
|
||||
rope_frequency_base: float
|
||||
rope_frequency_scale: float
|
||||
num_thread: int
|
||||
|
||||
# runtime options
|
||||
num_keep: int
|
||||
seed: int
|
||||
num_predict: int
|
||||
top_k: int
|
||||
top_p: float
|
||||
tfs_z: float
|
||||
typical_p: float
|
||||
repeat_last_n: int
|
||||
temperature: float
|
||||
repeat_penalty: float
|
||||
presence_penalty: float
|
||||
frequency_penalty: float
|
||||
mirostat: int
|
||||
mirostat_tau: float
|
||||
mirostat_eta: float
|
||||
penalize_newline: bool
|
||||
stop: List[str]
|
||||
182
ollama/client.py
182
ollama/client.py
@ -1,182 +0,0 @@
|
||||
import io
|
||||
import json
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from hashlib import sha256
|
||||
from base64 import b64encode
|
||||
|
||||
|
||||
class BaseClient:
|
||||
|
||||
def __init__(self, client, base_url='http://127.0.0.1:11434'):
|
||||
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
|
||||
|
||||
|
||||
class Client(BaseClient):
|
||||
|
||||
def __init__(self, base='http://localhost:11434'):
|
||||
super().__init__(httpx.Client, base)
|
||||
|
||||
def _request(self, method, url, **kwargs):
|
||||
response = self._client.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
def _request_json(self, method, url, **kwargs):
|
||||
return self._request(method, url, **kwargs).json()
|
||||
|
||||
def stream(self, method, url, **kwargs):
|
||||
with self._client.stream(method, url, **kwargs) as r:
|
||||
for line in r.iter_lines():
|
||||
part = json.loads(line)
|
||||
if e := part.get('error'):
|
||||
raise Exception(e)
|
||||
yield part
|
||||
|
||||
def generate(self, model, prompt='', system='', template='', context=None, stream=False, raw=False, format='', images=None, options=None):
|
||||
fn = self.stream if stream else self._request_json
|
||||
return fn('POST', '/api/generate', json={
|
||||
'model': model,
|
||||
'prompt': prompt,
|
||||
'system': system,
|
||||
'template': template,
|
||||
'context': context or [],
|
||||
'stream': stream,
|
||||
'raw': raw,
|
||||
'images': [_encode_image(image) for image in images or []],
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
})
|
||||
|
||||
def chat(self, model, messages=None, stream=False, format='', options=None):
|
||||
for message in messages or []:
|
||||
if not isinstance(message, dict):
|
||||
raise TypeError('messages must be a list of strings')
|
||||
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
|
||||
raise Exception('messages must contain a role and it must be one of "system", "user", or "assistant"')
|
||||
if not message.get('content'):
|
||||
raise Exception('messages must contain content')
|
||||
if images := message.get('images'):
|
||||
message['images'] = [_encode_image(image) for image in images]
|
||||
|
||||
fn = self.stream if stream else self._request_json
|
||||
return fn('POST', '/api/chat', json={
|
||||
'model': model,
|
||||
'messages': messages,
|
||||
'stream': stream,
|
||||
'format': format,
|
||||
'options': options or {},
|
||||
})
|
||||
|
||||
def pull(self, model, insecure=False, stream=False):
|
||||
fn = self.stream if stream else self._request_json
|
||||
return fn('POST', '/api/pull', json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
def push(self, model, insecure=False, stream=False):
|
||||
fn = self.stream if stream else self._request_json
|
||||
return fn('POST', '/api/push', json={
|
||||
'model': model,
|
||||
'insecure': insecure,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
def create(self, model, path=None, modelfile=None, stream=False):
|
||||
if (path := _as_path(path)) and path.exists():
|
||||
modelfile = _parse_modelfile(path.read_text(), self.create_blob, base=path.parent)
|
||||
elif modelfile:
|
||||
modelfile = _parse_modelfile(modelfile, self.create_blob)
|
||||
else:
|
||||
raise Exception('must provide either path or modelfile')
|
||||
|
||||
fn = self.stream if stream else self._request_json
|
||||
return fn('POST', '/api/create', json={
|
||||
'model': model,
|
||||
'modelfile': modelfile,
|
||||
'stream': stream,
|
||||
})
|
||||
|
||||
|
||||
def create_blob(self, path):
|
||||
sha256sum = sha256()
|
||||
with open(path, 'rb') as r:
|
||||
while True:
|
||||
chunk = r.read(32*1024)
|
||||
if not chunk:
|
||||
break
|
||||
sha256sum.update(chunk)
|
||||
|
||||
digest = f'sha256:{sha256sum.hexdigest()}'
|
||||
|
||||
try:
|
||||
self._request('HEAD', f'/api/blobs/{digest}')
|
||||
except httpx.HTTPError:
|
||||
with open(path, 'rb') as r:
|
||||
self._request('PUT', f'/api/blobs/{digest}', content=r)
|
||||
|
||||
return digest
|
||||
|
||||
def delete(self, model):
|
||||
response = self._request_json('DELETE', '/api/delete', json={'model': model})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
def list(self):
|
||||
return self._request_json('GET', '/api/tags').get('models', [])
|
||||
|
||||
def copy(self, source, target):
|
||||
response = self._request_json('POST', '/api/copy', json={'source': source, 'destination': target})
|
||||
return {'status': 'success' if response.status_code == 200 else 'error'}
|
||||
|
||||
def show(self, model):
|
||||
return self._request_json('GET', '/api/show', json={'model': model}).json()
|
||||
|
||||
|
||||
def _encode_image(image):
|
||||
'''
|
||||
_encode_images takes a list of images and returns a generator of base64 encoded images.
|
||||
if the image is a bytes object, it is assumed to be the raw bytes of an image.
|
||||
if the image is a string, it is assumed to be a path to a file.
|
||||
if the image is a Path object, it is assumed to be a path to a file.
|
||||
if the image is a file-like object, it is assumed to be a container to the raw bytes of an image.
|
||||
'''
|
||||
|
||||
if p := _as_path(image):
|
||||
b64 = b64encode(p.read_bytes())
|
||||
elif b := _as_bytesio(image):
|
||||
b64 = b64encode(b.read())
|
||||
else:
|
||||
raise Exception('images must be a list of bytes, path-like objects, or file-like objects')
|
||||
|
||||
return b64.decode('utf-8')
|
||||
|
||||
|
||||
def _parse_modelfile(modelfile, cb, base=None):
|
||||
base = Path.cwd() if base is None else base
|
||||
|
||||
out = io.StringIO()
|
||||
for line in io.StringIO(modelfile):
|
||||
command, _, args = line.partition(' ')
|
||||
if command.upper() in ['FROM', 'ADAPTER']:
|
||||
path = Path(args).expanduser()
|
||||
path = path if path.is_absolute() else base / path
|
||||
if path.exists():
|
||||
args = f'@{cb(path)}'
|
||||
|
||||
print(command, args, file=out)
|
||||
return out.getvalue()
|
||||
|
||||
|
||||
def _as_path(s):
|
||||
if isinstance(s, str) or isinstance(s, Path):
|
||||
return Path(s)
|
||||
return None
|
||||
|
||||
def _as_bytesio(s):
|
||||
if isinstance(s, io.BytesIO):
|
||||
return s
|
||||
elif isinstance(s, bytes):
|
||||
return io.BytesIO(s)
|
||||
return None
|
||||
@ -1,292 +0,0 @@
|
||||
import pytest
|
||||
import os
|
||||
import io
|
||||
import types
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from ollama.client import Client
|
||||
from pytest_httpserver import HTTPServer, URIPattern
|
||||
from werkzeug.wrappers import Response
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class PrefixPattern(URIPattern):
|
||||
def __init__(self, prefix: str):
|
||||
self.prefix = prefix
|
||||
|
||||
def match(self, uri):
|
||||
return uri.startswith(self.prefix)
|
||||
|
||||
|
||||
def test_client_chat(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_chat_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': True,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_chat_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/chat', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'messages': [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
},
|
||||
],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with io.BytesIO() as b:
|
||||
Image.new('RGB', (1, 1)).save(b, 'PNG')
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_generate(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_generate_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': True,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?', stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_generate_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/generate', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'format': '',
|
||||
'options': {},
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
Image.new('RGB', (1, 1)).save(temp, 'PNG')
|
||||
response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_pull(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/pull', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.pull('dummy')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_pull_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/pull', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.pull('dummy', stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_push(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/push', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.push('dummy')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_push_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/push', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.push('dummy', stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_create_path(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_path_relative(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
|
||||
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def userhomedir():
|
||||
with tempfile.TemporaryDirectory() as temp:
|
||||
home = os.getenv('HOME', '')
|
||||
os.environ['HOME'] = temp
|
||||
yield Path(temp)
|
||||
os.environ['HOME'] = home
|
||||
|
||||
|
||||
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
|
||||
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_modelfile(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client.create('dummy', modelfile=f'FROM {blob.name}')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_from_library(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request('/api/create', method='POST', json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM llama2\n',
|
||||
'stream': False,
|
||||
}).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
response = client.create('dummy', modelfile='FROM llama2')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_blob(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client.create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
|
||||
|
||||
def test_client_create_blob_exists(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client.create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
20
poetry.lock
generated
20
poetry.lock
generated
@ -387,6 +387,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-asyncio"
|
||||
version = "0.23.2"
|
||||
description = "Pytest support for asyncio"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"},
|
||||
{file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pytest = ">=7.0.0"
|
||||
|
||||
[package.extras]
|
||||
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
|
||||
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "4.1.0"
|
||||
@ -498,4 +516,4 @@ watchdog = ["watchdog (>=2.3)"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.8"
|
||||
content-hash = "b9f64e1a5795a417d2dbff7286360f8d3f8f10fdfa9580411940d144c2561e92"
|
||||
content-hash = "9416a897c95d3c80cf1bfd3cc61cd19f0143c9bd0bc7c219fcb31ee27c497c9d"
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
[tool.poetry]
|
||||
name = "ollama"
|
||||
version = "0.1.0"
|
||||
version = "0.0.0"
|
||||
description = "The official Python client for Ollama."
|
||||
authors = ["Ollama <hello@ollama.com>"]
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
homepage = "https://ollama.ai"
|
||||
repository = "https://github.com/jmorganca/ollama-python"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.8"
|
||||
@ -11,12 +14,18 @@ httpx = "^0.25.2"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^7.4.3"
|
||||
pytest-asyncio = "^0.23.2"
|
||||
pytest-cov = "^4.1.0"
|
||||
pytest-httpserver = "^1.0.8"
|
||||
pillow = "^10.1.0"
|
||||
ruff = "^0.1.8"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 999
|
||||
indent-width = 2
|
||||
|
||||
[tool.ruff.format]
|
||||
@ -26,7 +35,3 @@ indent-style = "space"
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "B"]
|
||||
ignore = ["E501"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
779
tests/test_client.py
Normal file
779
tests/test_client.py
Normal file
@ -0,0 +1,779 @@
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import types
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from pytest_httpserver import HTTPServer, URIPattern
|
||||
from werkzeug.wrappers import Request, Response
|
||||
from PIL import Image
|
||||
|
||||
from ollama._client import Client, AsyncClient
|
||||
|
||||
|
||||
class PrefixPattern(URIPattern):
|
||||
def __init__(self, prefix: str):
|
||||
self.prefix = prefix
|
||||
|
||||
def match(self, uri):
|
||||
return uri.startswith(self.prefix)
|
||||
|
||||
|
||||
def test_client_chat(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': "I don't know.",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == "I don't know."
|
||||
|
||||
|
||||
def test_client_chat_stream(httpserver: HTTPServer):
|
||||
def stream_handler(_: Request):
|
||||
def generate():
|
||||
for message in ['I ', "don't ", 'know.']:
|
||||
yield (
|
||||
json.dumps(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': message,
|
||||
},
|
||||
}
|
||||
)
|
||||
+ '\n'
|
||||
)
|
||||
|
||||
return Response(generate())
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': True,
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_handler(stream_handler)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
|
||||
for part in response:
|
||||
assert part['message']['role'] in 'assistant'
|
||||
assert part['message']['content'] in ['I ', "don't ", 'know.']
|
||||
|
||||
|
||||
def test_client_chat_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
},
|
||||
],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': "I don't know.",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with io.BytesIO() as b:
|
||||
Image.new('RGB', (1, 1)).save(b, 'PNG')
|
||||
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['message']['role'] == 'assistant'
|
||||
assert response['message']['content'] == "I don't know."
|
||||
|
||||
|
||||
def test_client_generate(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': 'Because it is.',
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?')
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == 'Because it is.'
|
||||
|
||||
|
||||
def test_client_generate_stream(httpserver: HTTPServer):
|
||||
def stream_handler(_: Request):
|
||||
def generate():
|
||||
for message in ['Because ', 'it ', 'is.']:
|
||||
yield (
|
||||
json.dumps(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': message,
|
||||
}
|
||||
)
|
||||
+ '\n'
|
||||
)
|
||||
|
||||
return Response(generate())
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': True,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_handler(stream_handler)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.generate('dummy', 'Why is the sky blue?', stream=True)
|
||||
for part in response:
|
||||
assert part['model'] == 'dummy'
|
||||
assert part['response'] in ['Because ', 'it ', 'is.']
|
||||
|
||||
|
||||
def test_client_generate_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'model': 'dummy',
|
||||
'response': 'Because it is.',
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
Image.new('RGB', (1, 1)).save(temp, 'PNG')
|
||||
response = client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
|
||||
assert response['model'] == 'dummy'
|
||||
assert response['response'] == 'Because it is.'
|
||||
|
||||
|
||||
def test_client_pull(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/pull',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json(
|
||||
{
|
||||
'status': 'success',
|
||||
}
|
||||
)
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.pull('dummy')
|
||||
assert response['status'] == 'success'
|
||||
|
||||
|
||||
def test_client_pull_stream(httpserver: HTTPServer):
|
||||
def stream_handler(_: Request):
|
||||
def generate():
|
||||
yield json.dumps({'status': 'pulling manifest'}) + '\n'
|
||||
yield json.dumps({'status': 'verifying sha256 digest'}) + '\n'
|
||||
yield json.dumps({'status': 'writing manifest'}) + '\n'
|
||||
yield json.dumps({'status': 'removing any unused layers'}) + '\n'
|
||||
yield json.dumps({'status': 'success'}) + '\n'
|
||||
|
||||
return Response(generate())
|
||||
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/pull',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.pull('dummy', stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_push(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/push',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.push('dummy')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_push_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/push',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
response = client.push('dummy', stream=True)
|
||||
assert isinstance(response, types.GeneratorType)
|
||||
|
||||
|
||||
def test_client_create_path(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_path_relative(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
|
||||
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def userhomedir():
|
||||
with tempfile.TemporaryDirectory() as temp:
|
||||
home = os.getenv('HOME', '')
|
||||
os.environ['HOME'] = temp
|
||||
yield Path(temp)
|
||||
os.environ['HOME'] = home
|
||||
|
||||
|
||||
def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
|
||||
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_modelfile(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client.create('dummy', modelfile=f'FROM {blob.name}')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_from_library(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM llama2\n',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
response = client.create('dummy', modelfile='FROM llama2')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
def test_client_create_blob(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client._create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
|
||||
|
||||
def test_client_create_blob_exists(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
|
||||
client = Client(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = client._create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_chat(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_chat_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
|
||||
'stream': True,
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
|
||||
assert isinstance(response, types.AsyncGeneratorType)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_chat_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/chat',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'messages': [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Why is the sky blue?',
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
},
|
||||
],
|
||||
'stream': False,
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with io.BytesIO() as b:
|
||||
Image.new('RGB', (1, 1)).save(b, 'PNG')
|
||||
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?', 'images': [b.getvalue()]}])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_generate(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.generate('dummy', 'Why is the sky blue?')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_generate_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': True,
|
||||
'raw': False,
|
||||
'images': [],
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.generate('dummy', 'Why is the sky blue?', stream=True)
|
||||
assert isinstance(response, types.AsyncGeneratorType)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_generate_images(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/generate',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'prompt': 'Why is the sky blue?',
|
||||
'system': '',
|
||||
'template': '',
|
||||
'context': [],
|
||||
'stream': False,
|
||||
'raw': False,
|
||||
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
|
||||
'format': '',
|
||||
'options': {},
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as temp:
|
||||
Image.new('RGB', (1, 1)).save(temp, 'PNG')
|
||||
response = await client.generate('dummy', 'Why is the sky blue?', images=[temp.name])
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_pull(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/pull',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.pull('dummy')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_pull_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/pull',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.pull('dummy', stream=True)
|
||||
assert isinstance(response, types.AsyncGeneratorType)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_push(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/push',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.push('dummy')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_push_stream(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/push',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'insecure': False,
|
||||
'stream': True,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
response = await client.push('dummy', stream=True)
|
||||
assert isinstance(response, types.AsyncGeneratorType)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_path(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
modelfile.write(f'FROM {blob.name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = await client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_path_relative(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=Path(modelfile.name).parent) as blob:
|
||||
modelfile.write(f'FROM {Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = await client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as modelfile:
|
||||
with tempfile.NamedTemporaryFile(dir=userhomedir) as blob:
|
||||
modelfile.write(f'FROM ~/{Path(blob.name).name}'.encode('utf-8'))
|
||||
modelfile.flush()
|
||||
|
||||
response = await client.create('dummy', path=modelfile.name)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_modelfile(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = await client.create('dummy', modelfile=f'FROM {blob.name}')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_from_library(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(
|
||||
'/api/create',
|
||||
method='POST',
|
||||
json={
|
||||
'model': 'dummy',
|
||||
'modelfile': 'FROM llama2\n',
|
||||
'stream': False,
|
||||
},
|
||||
).respond_with_json({})
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
response = await client.create('dummy', modelfile='FROM llama2')
|
||||
assert isinstance(response, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_blob(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=404))
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='PUT').respond_with_response(Response(status=201))
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = await client._create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_client_create_blob_exists(httpserver: HTTPServer):
|
||||
httpserver.expect_ordered_request(PrefixPattern('/api/blobs/'), method='HEAD').respond_with_response(Response(status=200))
|
||||
|
||||
client = AsyncClient(httpserver.url_for('/'))
|
||||
|
||||
with tempfile.NamedTemporaryFile() as blob:
|
||||
response = await client._create_blob(blob.name)
|
||||
assert response == 'sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
|
||||
Loading…
x
Reference in New Issue
Block a user