From c67ef1ae340e49ef4fbef7e3c5a1cce80a951da2 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 9 Jan 2024 15:24:30 -0800 Subject: [PATCH] fix: type hints --- ollama/_client.py | 78 ++++++++++++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/ollama/_client.py b/ollama/_client.py index 9110212..19ec847 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -34,9 +34,6 @@ class Client(BaseClient): response.raise_for_status() return response - def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]: - return self._request(method, url, **kwargs).json() - def _stream(self, method: str, url: str, **kwargs) -> Iterator[Mapping[str, Any]]: with self._client.stream(method, url, **kwargs) as r: for line in r.iter_lines(): @@ -45,6 +42,14 @@ class Client(BaseClient): raise Exception(e) yield partial + def _request_stream( + self, + *args, + stream: bool = False, + **kwargs, + ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: + return self._stream(*args, **kwargs) if stream else self._request(*args, **kwargs).json() + def generate( self, model: str = '', @@ -61,8 +66,7 @@ class Client(BaseClient): if not model: raise Exception('must provide a model') - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/generate', json={ @@ -77,6 +81,7 @@ class Client(BaseClient): 'format': format, 'options': options or {}, }, + stream=stream, ) def chat( @@ -100,8 +105,7 @@ class Client(BaseClient): if images := message.get('images'): message['images'] = [_encode_image(image) for image in images] - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/chat', json={ @@ -111,6 +115,7 @@ class Client(BaseClient): 'format': format, 'options': options or {}, }, + stream=stream, ) def pull( @@ -119,8 +124,7 @@ class Client(BaseClient): insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/pull', json={ @@ -128,6 +132,7 @@ class Client(BaseClient): 'insecure': insecure, 'stream': stream, }, + stream=stream, ) def push( @@ -136,8 +141,7 @@ class Client(BaseClient): insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], Iterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/push', json={ @@ -145,6 +149,7 @@ class Client(BaseClient): 'insecure': insecure, 'stream': stream, }, + stream=stream, ) def create( @@ -161,8 +166,7 @@ class Client(BaseClient): else: raise Exception('must provide either path or modelfile') - fn = self._stream if stream else self._request_json - return fn( + return self._request_stream( 'POST', '/api/create', json={ @@ -170,6 +174,7 @@ class Client(BaseClient): 'modelfile': modelfile, 'stream': stream, }, + stream=stream, ) def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: @@ -214,14 +219,14 @@ class Client(BaseClient): return {'status': 'success' if response.status_code == 200 else 'error'} def list(self) -> Mapping[str, Any]: - return self._request_json('GET', '/api/tags').get('models', []) + return self._request('GET', '/api/tags').json().get('models', []) def copy(self, source: str, target: str) -> Mapping[str, Any]: response = self._request('POST', '/api/copy', json={'source': source, 'destination': target}) return {'status': 'success' if response.status_code == 200 else 'error'} def show(self, model: str) -> Mapping[str, Any]: - return self._request_json('GET', '/api/show', json={'name': model}) + return self._request('GET', '/api/show', json={'name': model}).json() class AsyncClient(BaseClient): @@ -233,10 +238,6 @@ class AsyncClient(BaseClient): response.raise_for_status() return response - async def _request_json(self, method: str, url: str, **kwargs) -> Mapping[str, Any]: - response = await self._request(method, url, **kwargs) - return response.json() - async def _stream(self, method: str, url: str, **kwargs) -> AsyncIterator[Mapping[str, Any]]: async def inner(): async with self._client.stream(method, url, **kwargs) as r: @@ -248,6 +249,18 @@ class AsyncClient(BaseClient): return inner() + async def _request_stream( + self, + *args, + stream: bool = False, + **kwargs, + ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: + if stream: + return await self._stream(*args, **kwargs) + + response = await self._request(*args, **kwargs) + return response.json() + async def generate( self, model: str = '', @@ -264,8 +277,7 @@ class AsyncClient(BaseClient): if not model: raise Exception('must provide a model') - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/generate', json={ @@ -280,6 +292,7 @@ class AsyncClient(BaseClient): 'format': format, 'options': options or {}, }, + stream=stream, ) async def chat( @@ -303,8 +316,7 @@ class AsyncClient(BaseClient): if images := message.get('images'): message['images'] = [_encode_image(image) for image in images] - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/chat', json={ @@ -314,6 +326,7 @@ class AsyncClient(BaseClient): 'format': format, 'options': options or {}, }, + stream=stream, ) async def pull( @@ -322,8 +335,7 @@ class AsyncClient(BaseClient): insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/pull', json={ @@ -331,6 +343,7 @@ class AsyncClient(BaseClient): 'insecure': insecure, 'stream': stream, }, + stream=stream, ) async def push( @@ -339,8 +352,7 @@ class AsyncClient(BaseClient): insecure: bool = False, stream: bool = False, ) -> Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]: - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/push', json={ @@ -348,6 +360,7 @@ class AsyncClient(BaseClient): 'insecure': insecure, 'stream': stream, }, + stream=stream, ) async def create( @@ -364,8 +377,7 @@ class AsyncClient(BaseClient): else: raise Exception('must provide either path or modelfile') - fn = self._stream if stream else self._request_json - return await fn( + return await self._request_stream( 'POST', '/api/create', json={ @@ -373,6 +385,7 @@ class AsyncClient(BaseClient): 'modelfile': modelfile, 'stream': stream, }, + stream=stream, ) async def _parse_modelfile(self, modelfile: str, base: Optional[Path] = None) -> str: @@ -424,15 +437,16 @@ class AsyncClient(BaseClient): return {'status': 'success' if response.status_code == 200 else 'error'} async def list(self) -> Mapping[str, Any]: - response = await self._request_json('GET', '/api/tags') - return response.get('models', []) + response = await self._request('GET', '/api/tags') + return response.json().get('models', []) async def copy(self, source: str, target: str) -> Mapping[str, Any]: response = await self._request('POST', '/api/copy', json={'source': source, 'destination': target}) return {'status': 'success' if response.status_code == 200 else 'error'} async def show(self, model: str) -> Mapping[str, Any]: - return await self._request_json('GET', '/api/show', json={'name': model}) + response = await self._request('GET', '/api/show', json={'name': model}) + return response.json() def _encode_image(image) -> str: