Merge pull request #2 from jmorganca/mxyng/fix

fix endpoints
This commit is contained in:
Michael Yang 2023-12-22 16:32:51 -08:00 committed by GitHub
commit 349d9c3023
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 49 additions and 41 deletions

View File

@@ -25,7 +25,7 @@ for part in ollama.chat(model='llama2', messages=[message], stream=True):
```
### Using the Synchronous Client
## Using the Synchronous Client
```python
from ollama import Client
@@ -42,7 +42,7 @@ for part in Client().chat(model='llama2', messages=[message], stream=True):
print(part['message']['content'], end='', flush=True)
```
### Using the Asynchronous Client
## Using the Asynchronous Client
```python
import asyncio

View File

@@ -8,10 +8,8 @@ messages = [
},
]
for message in chat('mistral', messages=messages, stream=True):
if message := message.get('message'):
if message.get('role') == 'assistant':
print(message.get('content', ''), end='', flush=True)
for part in chat('mistral', messages=messages, stream=True):
print(part['message']['content'], end='', flush=True)
# end with a newline
print()

View File

@@ -9,4 +9,4 @@ messages = [
]
response = chat('mistral', messages=messages)
print(response['message'])
print(response['message']['content'])

View File

@@ -0,0 +1,5 @@
from ollama import generate
for part in generate('mistral', 'Why is the sky blue?', stream=True):
print(part['response'], end='', flush=True)

View File

@@ -0,0 +1,5 @@
from ollama import generate
response = generate('mistral', 'Why is the sky blue?')
print(response['response'])

View File

@@ -24,6 +24,6 @@ raw = httpx.get(comic.json().get('img'))
raw.raise_for_status()
for response in generate('llava', 'explain this comic:', images=[raw.content], stream=True):
print(response.get('response'), end='', flush=True)
print(response['response'], end='', flush=True)
print()

View File

@@ -19,13 +19,13 @@ from ollama._types import Message, Options
class BaseClient:
def __init__(self, client, base_url='http://127.0.0.1:11434') -> None:
def __init__(self, client, base_url: str = 'http://127.0.0.1:11434') -> None:
self._client = client(base_url=base_url, follow_redirects=True, timeout=None)
class Client(BaseClient):
def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.Client, base)
def __init__(self, base_url: str = 'http://localhost:11434') -> None:
super().__init__(httpx.Client, base_url)
def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = self._client.request(method, url, **kwargs)
@@ -122,7 +122,7 @@ class Client(BaseClient):
'POST',
'/api/pull',
json={
'model': model,
'name': model,
'insecure': insecure,
'stream': stream,
},
@@ -139,7 +139,7 @@ class Client(BaseClient):
'POST',
'/api/push',
json={
'model': model,
'name': model,
'insecure': insecure,
'stream': stream,
},
@@ -164,7 +164,7 @@ class Client(BaseClient):
'POST',
'/api/create',
json={
'model': model,
'name': model,
'modelfile': modelfile,
'stream': stream,
},
@@ -208,7 +208,7 @@ class Client(BaseClient):
return digest
def delete(self, model: str) -> Mapping[str, Any]:
response = self._request('DELETE', '/api/delete', json={'model': model})
response = self._request('DELETE', '/api/delete', json={'name': model})
return {'status': 'success' if response.status_code == 200 else 'error'}
def list(self) -> Mapping[str, Any]:
@@ -219,12 +219,12 @@ class Client(BaseClient):
return {'status': 'success' if response.status_code == 200 else 'error'}
def show(self, model: str) -> Mapping[str, Any]:
return self._request_json('GET', '/api/show', json={'model': model})
return self._request_json('GET', '/api/show', json={'name': model})
class AsyncClient(BaseClient):
def __init__(self, base='http://localhost:11434') -> None:
super().__init__(httpx.AsyncClient, base)
def __init__(self, base_url: str = 'http://localhost:11434') -> None:
super().__init__(httpx.AsyncClient, base_url)
async def _request(self, method: str, url: str, **kwargs) -> httpx.Response:
response = await self._client.request(method, url, **kwargs)
@@ -325,7 +325,7 @@ class AsyncClient(BaseClient):
'POST',
'/api/pull',
json={
'model': model,
'name': model,
'insecure': insecure,
'stream': stream,
},
@@ -342,7 +342,7 @@ class AsyncClient(BaseClient):
'POST',
'/api/push',
json={
'model': model,
'name': model,
'insecure': insecure,
'stream': stream,
},
@@ -367,7 +367,7 @@ class AsyncClient(BaseClient):
'POST',
'/api/create',
json={
'model': model,
'name': model,
'modelfile': modelfile,
'stream': stream,
},
@@ -418,7 +418,7 @@ class AsyncClient(BaseClient):
return digest
async def delete(self, model: str) -> Mapping[str, Any]:
response = await self._request('DELETE', '/api/delete', json={'model': model})
response = await self._request('DELETE', '/api/delete', json={'name': model})
return {'status': 'success' if response.status_code == 200 else 'error'}
async def list(self) -> Mapping[str, Any]:
@@ -430,7 +430,7 @@ class AsyncClient(BaseClient):
return {'status': 'success' if response.status_code == 200 else 'error'}
async def show(self, model: str) -> Mapping[str, Any]:
return await self._request_json('GET', '/api/show', json={'model': model})
return await self._request_json('GET', '/api/show', json={'name': model})
def _encode_image(image) -> str:

View File

@@ -229,7 +229,7 @@ def test_client_pull(httpserver: HTTPServer):
'/api/pull',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': False,
},
@@ -259,7 +259,7 @@ def test_client_pull_stream(httpserver: HTTPServer):
'/api/pull',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': True,
},
@@ -275,7 +275,7 @@ def test_client_push(httpserver: HTTPServer):
'/api/push',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': False,
},
@@ -291,7 +291,7 @@ def test_client_push_stream(httpserver: HTTPServer):
'/api/push',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': True,
},
@@ -308,7 +308,7 @@ def test_client_create_path(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
@@ -331,7 +331,7 @@ def test_client_create_path_relative(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
@@ -363,7 +363,7 @@ def test_client_create_path_user_home(httpserver: HTTPServer, userhomedir):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
@@ -386,7 +386,7 @@ def test_client_create_modelfile(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
@@ -404,7 +404,7 @@ def test_client_create_from_library(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM llama2\n',
'stream': False,
},
@@ -584,7 +584,7 @@ async def test_async_client_pull(httpserver: HTTPServer):
'/api/pull',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': False,
},
@@ -601,7 +601,7 @@ async def test_async_client_pull_stream(httpserver: HTTPServer):
'/api/pull',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': True,
},
@@ -618,7 +618,7 @@ async def test_async_client_push(httpserver: HTTPServer):
'/api/push',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': False,
},
@@ -635,7 +635,7 @@ async def test_async_client_push_stream(httpserver: HTTPServer):
'/api/push',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'insecure': False,
'stream': True,
},
@@ -653,7 +653,7 @@ async def test_async_client_create_path(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
@@ -677,7 +677,7 @@ async def test_async_client_create_path_relative(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
@@ -701,7 +701,7 @@ async def test_async_client_create_path_user_home(httpserver: HTTPServer, userho
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
@@ -725,7 +725,7 @@ async def test_async_client_create_modelfile(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM @sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855\n',
'stream': False,
},
@@ -744,7 +744,7 @@ async def test_async_client_create_from_library(httpserver: HTTPServer):
'/api/create',
method='POST',
json={
'model': 'dummy',
'name': 'dummy',
'modelfile': 'FROM llama2\n',
'stream': False,
},