mirror of
https://github.com/ollama/ollama-python.git
synced 2026-02-03 18:08:43 -06:00
Add image generation support (#616)
This commit is contained in:
parent
60e7b2f9ce
commit
dbccf192ac
@ -250,7 +250,6 @@ ollama.embed(model='gemma3', input=['The sky is blue because of rayleigh scatter
|
|||||||
ollama.ps()
|
ollama.ps()
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Errors
|
## Errors
|
||||||
|
|
||||||
Errors are raised if requests return an error status or if an error is detected while streaming.
|
Errors are raised if requests return an error status or if an error is detected while streaming.
|
||||||
|
|||||||
@ -78,6 +78,12 @@ Configuration to use with an MCP client:
|
|||||||
- [multimodal-chat.py](multimodal-chat.py)
|
- [multimodal-chat.py](multimodal-chat.py)
|
||||||
- [multimodal-generate.py](multimodal-generate.py)
|
- [multimodal-generate.py](multimodal-generate.py)
|
||||||
|
|
||||||
|
### Image Generation (Experimental) - Generate images with a model
|
||||||
|
|
||||||
|
> **Note:** Image generation is experimental and currently only available on macOS.
|
||||||
|
|
||||||
|
- [generate-image.py](generate-image.py)
|
||||||
|
|
||||||
### Structured Outputs - Generate structured outputs with a model
|
### Structured Outputs - Generate structured outputs with a model
|
||||||
|
|
||||||
- [structured-outputs.py](structured-outputs.py)
|
- [structured-outputs.py](structured-outputs.py)
|
||||||
|
|||||||
18
examples/generate-image.py
Normal file
18
examples/generate-image.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Image generation is experimental and currently only available on macOS
|
||||||
|
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from ollama import generate
|
||||||
|
|
||||||
|
prompt = 'a sunset over mountains'
|
||||||
|
print(f'Prompt: {prompt}')
|
||||||
|
|
||||||
|
for response in generate(model='x/z-image-turbo', prompt=prompt, stream=True):
|
||||||
|
if response.image:
|
||||||
|
# Final response contains the image
|
||||||
|
with open('output.png', 'wb') as f:
|
||||||
|
f.write(base64.b64decode(response.image))
|
||||||
|
print('\nImage saved to output.png')
|
||||||
|
elif response.total:
|
||||||
|
# Progress update
|
||||||
|
print(f'Progress: {response.completed or 0}/{response.total}', end='\r')
|
||||||
@ -217,6 +217,9 @@ class Client(BaseClient):
|
|||||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
steps: Optional[int] = None,
|
||||||
) -> GenerateResponse: ...
|
) -> GenerateResponse: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@ -238,6 +241,9 @@ class Client(BaseClient):
|
|||||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
steps: Optional[int] = None,
|
||||||
) -> Iterator[GenerateResponse]: ...
|
) -> Iterator[GenerateResponse]: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
@ -258,6 +264,9 @@ class Client(BaseClient):
|
|||||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
steps: Optional[int] = None,
|
||||||
) -> Union[GenerateResponse, Iterator[GenerateResponse]]:
|
) -> Union[GenerateResponse, Iterator[GenerateResponse]]:
|
||||||
"""
|
"""
|
||||||
Create a response using the requested model.
|
Create a response using the requested model.
|
||||||
@ -289,6 +298,9 @@ class Client(BaseClient):
|
|||||||
images=list(_copy_images(images)) if images else None,
|
images=list(_copy_images(images)) if images else None,
|
||||||
options=options,
|
options=options,
|
||||||
keep_alive=keep_alive,
|
keep_alive=keep_alive,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
steps=steps,
|
||||||
).model_dump(exclude_none=True),
|
).model_dump(exclude_none=True),
|
||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
@ -838,6 +850,9 @@ class AsyncClient(BaseClient):
|
|||||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
steps: Optional[int] = None,
|
||||||
) -> GenerateResponse: ...
|
) -> GenerateResponse: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@ -859,6 +874,9 @@ class AsyncClient(BaseClient):
|
|||||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
steps: Optional[int] = None,
|
||||||
) -> AsyncIterator[GenerateResponse]: ...
|
) -> AsyncIterator[GenerateResponse]: ...
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
@ -879,6 +897,9 @@ class AsyncClient(BaseClient):
|
|||||||
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
|
||||||
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
options: Optional[Union[Mapping[str, Any], Options]] = None,
|
||||||
keep_alive: Optional[Union[float, str]] = None,
|
keep_alive: Optional[Union[float, str]] = None,
|
||||||
|
width: Optional[int] = None,
|
||||||
|
height: Optional[int] = None,
|
||||||
|
steps: Optional[int] = None,
|
||||||
) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]:
|
) -> Union[GenerateResponse, AsyncIterator[GenerateResponse]]:
|
||||||
"""
|
"""
|
||||||
Create a response using the requested model.
|
Create a response using the requested model.
|
||||||
@ -909,6 +930,9 @@ class AsyncClient(BaseClient):
|
|||||||
images=list(_copy_images(images)) if images else None,
|
images=list(_copy_images(images)) if images else None,
|
||||||
options=options,
|
options=options,
|
||||||
keep_alive=keep_alive,
|
keep_alive=keep_alive,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
steps=steps,
|
||||||
).model_dump(exclude_none=True),
|
).model_dump(exclude_none=True),
|
||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -216,6 +216,16 @@ class GenerateRequest(BaseGenerateRequest):
|
|||||||
top_logprobs: Optional[int] = None
|
top_logprobs: Optional[int] = None
|
||||||
'Number of alternative tokens and log probabilities to include per position (0-20).'
|
'Number of alternative tokens and log probabilities to include per position (0-20).'
|
||||||
|
|
||||||
|
# Experimental image generation parameters
|
||||||
|
width: Optional[int] = None
|
||||||
|
'Width of the generated image in pixels (for image generation models).'
|
||||||
|
|
||||||
|
height: Optional[int] = None
|
||||||
|
'Height of the generated image in pixels (for image generation models).'
|
||||||
|
|
||||||
|
steps: Optional[int] = None
|
||||||
|
'Number of diffusion steps (for image generation models).'
|
||||||
|
|
||||||
|
|
||||||
class BaseGenerateResponse(SubscriptableBaseModel):
|
class BaseGenerateResponse(SubscriptableBaseModel):
|
||||||
model: Optional[str] = None
|
model: Optional[str] = None
|
||||||
@ -267,7 +277,7 @@ class GenerateResponse(BaseGenerateResponse):
|
|||||||
Response returned by generate requests.
|
Response returned by generate requests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response: str
|
response: Optional[str] = None
|
||||||
'Response content. When streaming, this contains a fragment of the response.'
|
'Response content. When streaming, this contains a fragment of the response.'
|
||||||
|
|
||||||
thinking: Optional[str] = None
|
thinking: Optional[str] = None
|
||||||
@ -279,6 +289,17 @@ class GenerateResponse(BaseGenerateResponse):
|
|||||||
logprobs: Optional[Sequence[Logprob]] = None
|
logprobs: Optional[Sequence[Logprob]] = None
|
||||||
'Log probabilities for generated tokens.'
|
'Log probabilities for generated tokens.'
|
||||||
|
|
||||||
|
# Image generation response fields
|
||||||
|
image: Optional[str] = None
|
||||||
|
'Base64-encoded generated image data (for image generation models).'
|
||||||
|
|
||||||
|
# Streaming progress fields (for image generation)
|
||||||
|
completed: Optional[int] = None
|
||||||
|
'Number of completed steps (for image generation streaming).'
|
||||||
|
|
||||||
|
total: Optional[int] = None
|
||||||
|
'Total number of steps (for image generation streaming).'
|
||||||
|
|
||||||
|
|
||||||
class Message(SubscriptableBaseModel):
|
class Message(SubscriptableBaseModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -568,6 +568,115 @@ async def test_async_client_generate_format_pydantic(httpserver: HTTPServer):
|
|||||||
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
assert response['response'] == '{"answer": "Because of Rayleigh scattering", "confidence": 0.95}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_generate_image(httpserver: HTTPServer):
|
||||||
|
httpserver.expect_ordered_request(
|
||||||
|
'/api/generate',
|
||||||
|
method='POST',
|
||||||
|
json={
|
||||||
|
'model': 'dummy-image',
|
||||||
|
'prompt': 'a sunset over mountains',
|
||||||
|
'stream': False,
|
||||||
|
'width': 1024,
|
||||||
|
'height': 768,
|
||||||
|
'steps': 20,
|
||||||
|
},
|
||||||
|
).respond_with_json(
|
||||||
|
{
|
||||||
|
'model': 'dummy-image',
|
||||||
|
'image': PNG_BASE64,
|
||||||
|
'done': True,
|
||||||
|
'done_reason': 'stop',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
client = Client(httpserver.url_for('/'))
|
||||||
|
response = client.generate('dummy-image', 'a sunset over mountains', width=1024, height=768, steps=20)
|
||||||
|
assert response['model'] == 'dummy-image'
|
||||||
|
assert response['image'] == PNG_BASE64
|
||||||
|
assert response['done'] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_generate_image_stream(httpserver: HTTPServer):
|
||||||
|
def stream_handler(_: Request):
|
||||||
|
def generate():
|
||||||
|
# Progress updates
|
||||||
|
for i in range(1, 4):
|
||||||
|
yield (
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
'model': 'dummy-image',
|
||||||
|
'completed': i,
|
||||||
|
'total': 3,
|
||||||
|
'done': False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ '\n'
|
||||||
|
)
|
||||||
|
# Final response with image
|
||||||
|
yield (
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
'model': 'dummy-image',
|
||||||
|
'image': PNG_BASE64,
|
||||||
|
'done': True,
|
||||||
|
'done_reason': 'stop',
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ '\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(generate())
|
||||||
|
|
||||||
|
httpserver.expect_ordered_request(
|
||||||
|
'/api/generate',
|
||||||
|
method='POST',
|
||||||
|
json={
|
||||||
|
'model': 'dummy-image',
|
||||||
|
'prompt': 'a sunset over mountains',
|
||||||
|
'stream': True,
|
||||||
|
'width': 512,
|
||||||
|
'height': 512,
|
||||||
|
},
|
||||||
|
).respond_with_handler(stream_handler)
|
||||||
|
|
||||||
|
client = Client(httpserver.url_for('/'))
|
||||||
|
response = client.generate('dummy-image', 'a sunset over mountains', stream=True, width=512, height=512)
|
||||||
|
|
||||||
|
parts = list(response)
|
||||||
|
# Check progress updates
|
||||||
|
assert parts[0]['completed'] == 1
|
||||||
|
assert parts[0]['total'] == 3
|
||||||
|
assert parts[0]['done'] is False
|
||||||
|
# Check final response
|
||||||
|
assert parts[-1]['image'] == PNG_BASE64
|
||||||
|
assert parts[-1]['done'] is True
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_client_generate_image(httpserver: HTTPServer):
|
||||||
|
httpserver.expect_ordered_request(
|
||||||
|
'/api/generate',
|
||||||
|
method='POST',
|
||||||
|
json={
|
||||||
|
'model': 'dummy-image',
|
||||||
|
'prompt': 'a robot painting',
|
||||||
|
'stream': False,
|
||||||
|
'width': 1024,
|
||||||
|
'height': 1024,
|
||||||
|
},
|
||||||
|
).respond_with_json(
|
||||||
|
{
|
||||||
|
'model': 'dummy-image',
|
||||||
|
'image': PNG_BASE64,
|
||||||
|
'done': True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
client = AsyncClient(httpserver.url_for('/'))
|
||||||
|
response = await client.generate('dummy-image', 'a robot painting', width=1024, height=1024)
|
||||||
|
assert response['model'] == 'dummy-image'
|
||||||
|
assert response['image'] == PNG_BASE64
|
||||||
|
|
||||||
|
|
||||||
def test_client_pull(httpserver: HTTPServer):
|
def test_client_pull(httpserver: HTTPServer):
|
||||||
httpserver.expect_ordered_request(
|
httpserver.expect_ordered_request(
|
||||||
'/api/pull',
|
'/api/pull',
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user