mirror of
https://github.com/ollama/ollama-js.git
synced 2025-12-15 13:54:15 -06:00
browser/utils: updates for web search and loading OLLAMA_API_KEY from the environment (#250)
--------- Co-authored-by: jmorganca <jmorganca@gmail.com>
This commit is contained in:
parent
0ce6961552
commit
a6689ac591
@ -1,12 +1,17 @@
|
||||
import { Ollama, type Message, type SearchResponse, type FetchResponse } from 'ollama'
|
||||
import {
|
||||
Ollama,
|
||||
type Message,
|
||||
type WebSearchResponse,
|
||||
type WebFetchResponse,
|
||||
} from 'ollama'
|
||||
|
||||
async function main() {
|
||||
|
||||
if (!process.env.OLLAMA_API_KEY) throw new Error('Set OLLAMA_API_KEY to use web search tools')
|
||||
|
||||
const client = new Ollama({
|
||||
headers: { Authorization: `Bearer ${process.env.OLLAMA_API_KEY}` },
|
||||
})
|
||||
// Set enviornment variable OLLAMA_API_KEY=<YOUR>.<KEY>
|
||||
// or set the header manually
|
||||
// const client = new Ollama({
|
||||
// headers: { Authorization: `Bearer ${process.env.OLLAMA_API_KEY}` },
|
||||
// })
|
||||
const client = new Ollama()
|
||||
|
||||
// Tool schemas
|
||||
const webSearchTool = {
|
||||
@ -20,7 +25,7 @@ async function main() {
|
||||
query: { type: 'string', description: 'Search query string.' },
|
||||
max_results: {
|
||||
type: 'number',
|
||||
description: 'The maximum number of results to return per query (default 5, max 10).',
|
||||
description: 'The maximum number of results to return per query (default 3).',
|
||||
},
|
||||
},
|
||||
required: ['query'],
|
||||
@ -43,28 +48,32 @@ async function main() {
|
||||
},
|
||||
}
|
||||
|
||||
const availableTools = {
|
||||
webSearch: async (args: { query: string; max_results?: number }): Promise<SearchResponse> => {
|
||||
const res = await client.webSearch(args)
|
||||
return res as SearchResponse
|
||||
},
|
||||
webFetch: async (args: { url: string }): Promise<FetchResponse> => {
|
||||
const res = await client.webFetch(args)
|
||||
return res as FetchResponse
|
||||
},
|
||||
}
|
||||
const availableTools = {
|
||||
webSearch: async (args: {
|
||||
query: string
|
||||
max_results?: number
|
||||
}): Promise<WebSearchResponse> => {
|
||||
const res = await client.webSearch(args)
|
||||
return res as WebSearchResponse
|
||||
},
|
||||
webFetch: async (args: { url: string }): Promise<WebFetchResponse> => {
|
||||
const res = await client.webFetch(args)
|
||||
return res as WebFetchResponse
|
||||
},
|
||||
}
|
||||
|
||||
const query = 'What is Ollama?'
|
||||
console.log('Prompt:', query, '\n')
|
||||
|
||||
const messages: Message[] = [
|
||||
{
|
||||
role: 'user',
|
||||
content: 'What is Ollama?',
|
||||
content: query,
|
||||
},
|
||||
]
|
||||
|
||||
console.log('----- Prompt:', messages.find((m) => m.role === 'user')?.content, '\n')
|
||||
|
||||
while (true) {
|
||||
const response = await client.chat({
|
||||
const response = await client.chat({
|
||||
model: 'qwen3',
|
||||
messages: messages,
|
||||
tools: [webSearchTool, webFetchTool],
|
||||
@ -76,7 +85,6 @@ async function main() {
|
||||
var content = ''
|
||||
var thinking = ''
|
||||
for await (const chunk of response) {
|
||||
|
||||
if (chunk.message.thinking) {
|
||||
thinking += chunk.message.thinking
|
||||
}
|
||||
@ -97,14 +105,19 @@ async function main() {
|
||||
const functionToCall = availableTools[toolCall.function.name]
|
||||
if (functionToCall) {
|
||||
const args = toolCall.function.arguments as any
|
||||
console.log('\nCalling function:', toolCall.function.name, 'with arguments:', args)
|
||||
console.log(
|
||||
'\nCalling function:',
|
||||
toolCall.function.name,
|
||||
'with arguments:',
|
||||
args,
|
||||
)
|
||||
const output = await functionToCall(args)
|
||||
console.log('Function result:', JSON.stringify(output).slice(0, 200), '\n')
|
||||
|
||||
|
||||
messages.push(chunk.message)
|
||||
messages.push({
|
||||
role: 'tool',
|
||||
content: JSON.stringify(output),
|
||||
content: JSON.stringify(output).slice(0, 2000 * 4), // cap at ~2000 tokens
|
||||
tool_name: toolCall.function.name,
|
||||
})
|
||||
}
|
||||
@ -116,9 +129,7 @@ async function main() {
|
||||
process.stdout.write('\n')
|
||||
break
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
main().catch(console.error)
|
||||
main().catch(console.error)
|
||||
|
||||
@ -24,10 +24,10 @@ import type {
|
||||
ShowRequest,
|
||||
ShowResponse,
|
||||
StatusResponse,
|
||||
SearchRequest,
|
||||
SearchResponse,
|
||||
FetchRequest,
|
||||
FetchResponse,
|
||||
WebSearchRequest,
|
||||
WebSearchResponse,
|
||||
WebFetchRequest,
|
||||
WebFetchResponse,
|
||||
} from './interfaces.js'
|
||||
import { defaultHost } from './constant.js'
|
||||
|
||||
@ -49,6 +49,8 @@ export class Ollama {
|
||||
this.fetch = config?.fetch ?? fetch
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Abort any ongoing streamed requests to Ollama
|
||||
public abort() {
|
||||
for (const request of this.ongoingStreamedRequests) {
|
||||
@ -327,32 +329,32 @@ async encodeImage(image: Uint8Array | string): Promise<string> {
|
||||
|
||||
/**
|
||||
* Performs web search using the Ollama web search API
|
||||
* @param request {SearchRequest} - The search request containing query and options
|
||||
* @returns {Promise<SearchResponse>} - The search results
|
||||
* @param request {WebSearchRequest} - The search request containing query and options
|
||||
* @returns {Promise<WebSearchResponse>} - The search results
|
||||
* @throws {Error} - If the request is invalid or the server returns an error
|
||||
*/
|
||||
async webSearch(request: SearchRequest): Promise<SearchResponse> {
|
||||
async webSearch(request: WebSearchRequest): Promise<WebSearchResponse> {
|
||||
if (!request.query || request.query.length === 0) {
|
||||
throw new Error('Query is required')
|
||||
}
|
||||
const response = await utils.post(this.fetch, `https://ollama.com/api/web_search`, { ...request }, {
|
||||
headers: this.config.headers
|
||||
})
|
||||
return (await response.json()) as SearchResponse
|
||||
return (await response.json()) as WebSearchResponse
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches a single page using the Ollama web fetch API
|
||||
* @param request {FetchRequest} - The fetch request containing a URL
|
||||
* @returns {Promise<FetchResponse>} - The fetch result
|
||||
* @param request {WebFetchRequest} - The fetch request containing a URL
|
||||
* @returns {Promise<WebFetchResponse>} - The fetch result
|
||||
* @throws {Error} - If the request is invalid or the server returns an error
|
||||
*/
|
||||
async webFetch(request: FetchRequest): Promise<FetchResponse> {
|
||||
async webFetch(request: WebFetchRequest): Promise<WebFetchResponse> {
|
||||
if (!request.url || request.url.length === 0) {
|
||||
throw new Error('URL is required')
|
||||
}
|
||||
const response = await utils.post(this.fetch, `https://ollama.com/api/web_fetch`, { ...request }, { headers: this.config.headers })
|
||||
return (await response.json()) as FetchResponse
|
||||
return (await response.json()) as WebFetchResponse
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -270,26 +270,25 @@ export interface StatusResponse {
|
||||
status: string
|
||||
}
|
||||
|
||||
// Web Search types
|
||||
export interface SearchRequest {
|
||||
export interface WebSearchRequest {
|
||||
query: string
|
||||
max_results?: number
|
||||
maxResults?: number
|
||||
}
|
||||
|
||||
export interface SearchResult {
|
||||
export interface WebSearchResult {
|
||||
content: string
|
||||
}
|
||||
|
||||
export interface SearchResponse {
|
||||
results: SearchResult[]
|
||||
export interface WebSearchResponse {
|
||||
results: WebSearchResult[]
|
||||
}
|
||||
|
||||
// Fetch types
|
||||
export interface FetchRequest {
|
||||
export interface WebFetchRequest {
|
||||
url: string
|
||||
}
|
||||
|
||||
export interface FetchResponse {
|
||||
export interface WebFetchResponse {
|
||||
title: string
|
||||
url: string
|
||||
content: string
|
||||
|
||||
91
src/utils.ts
91
src/utils.ts
@ -28,7 +28,11 @@ export class AbortableAsyncIterator<T extends object> {
|
||||
private readonly itr: AsyncGenerator<T | ErrorResponse>
|
||||
private readonly doneCallback: () => void
|
||||
|
||||
constructor(abortController: AbortController, itr: AsyncGenerator<T | ErrorResponse>, doneCallback: () => void) {
|
||||
constructor(
|
||||
abortController: AbortController,
|
||||
itr: AsyncGenerator<T | ErrorResponse>,
|
||||
doneCallback: () => void,
|
||||
) {
|
||||
this.abortController = abortController
|
||||
this.itr = itr
|
||||
this.doneCallback = doneCallback
|
||||
@ -119,23 +123,27 @@ function getPlatform(): string {
|
||||
* - An array of key-value pairs representing headers.
|
||||
* @returns {Record<string,string>} - A plain object representing the normalized headers.
|
||||
*/
|
||||
function normalizeHeaders(headers?: HeadersInit | undefined): Record<string,string> {
|
||||
function normalizeHeaders(headers?: HeadersInit | undefined): Record<string, string> {
|
||||
if (headers instanceof Headers) {
|
||||
// If headers are an instance of Headers, convert it to an object
|
||||
const obj: Record<string, string> = {};
|
||||
headers.forEach((value, key) => {
|
||||
obj[key] = value;
|
||||
});
|
||||
return obj;
|
||||
// If headers are an instance of Headers, convert it to an object
|
||||
const obj: Record<string, string> = {}
|
||||
headers.forEach((value, key) => {
|
||||
obj[key] = value
|
||||
})
|
||||
return obj
|
||||
} else if (Array.isArray(headers)) {
|
||||
// If headers are in array format, convert them to an object
|
||||
return Object.fromEntries(headers);
|
||||
// If headers are in array format, convert them to an object
|
||||
return Object.fromEntries(headers)
|
||||
} else {
|
||||
// Otherwise assume it's already a plain object
|
||||
return headers || {};
|
||||
// Otherwise assume it's already a plain object
|
||||
return headers || {}
|
||||
}
|
||||
}
|
||||
|
||||
const readEnvVar = (obj: object, key: string): string | undefined => {
|
||||
return obj[key]
|
||||
}
|
||||
|
||||
/**
|
||||
* A wrapper around fetch that adds default headers.
|
||||
* @param fetch {Fetch} - The fetch function to use
|
||||
@ -155,16 +163,41 @@ const fetchWithHeaders = async (
|
||||
} as HeadersInit
|
||||
|
||||
// Normalizes headers into a plain object format.
|
||||
options.headers = normalizeHeaders(options.headers);
|
||||
|
||||
// Filter out default headers from custom headers
|
||||
options.headers = normalizeHeaders(options.headers)
|
||||
|
||||
// Automatically add the API key to the headers if the URL is https://ollama.com
|
||||
try {
|
||||
const parsed = new URL(url)
|
||||
if (parsed.protocol === 'https:' && parsed.hostname === 'ollama.com') {
|
||||
const apiKey =
|
||||
typeof process === 'object' &&
|
||||
process !== null &&
|
||||
typeof process.env === 'object' &&
|
||||
process.env !== null
|
||||
? readEnvVar(process.env, 'OLLAMA_API_KEY')
|
||||
: undefined
|
||||
const authorization =
|
||||
options.headers['authorization'] || options.headers['Authorization']
|
||||
if (!authorization && apiKey) {
|
||||
options.headers['Authorization'] = `Bearer ${apiKey}`
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('error parsing url', error)
|
||||
}
|
||||
|
||||
const customHeaders = Object.fromEntries(
|
||||
Object.entries(options.headers).filter(([key]) => !Object.keys(defaultHeaders).some(defaultKey => defaultKey.toLowerCase() === key.toLowerCase()))
|
||||
Object.entries(options.headers).filter(
|
||||
([key]) =>
|
||||
!Object.keys(defaultHeaders).some(
|
||||
(defaultKey) => defaultKey.toLowerCase() === key.toLowerCase(),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
options.headers = {
|
||||
...defaultHeaders,
|
||||
...customHeaders
|
||||
...customHeaders,
|
||||
}
|
||||
|
||||
return fetch(url, options)
|
||||
@ -176,9 +209,13 @@ const fetchWithHeaders = async (
|
||||
* @param host {string} - The host to fetch
|
||||
* @returns {Promise<Response>} - The fetch response
|
||||
*/
|
||||
export const get = async (fetch: Fetch, host: string, options?: { headers?: HeadersInit }): Promise<Response> => {
|
||||
export const get = async (
|
||||
fetch: Fetch,
|
||||
host: string,
|
||||
options?: { headers?: HeadersInit },
|
||||
): Promise<Response> => {
|
||||
const response = await fetchWithHeaders(fetch, host, {
|
||||
headers: options?.headers
|
||||
headers: options?.headers,
|
||||
})
|
||||
|
||||
await checkOk(response)
|
||||
@ -212,7 +249,7 @@ export const post = async (
|
||||
fetch: Fetch,
|
||||
host: string,
|
||||
data?: Record<string, unknown> | BodyInit,
|
||||
options?: { signal?: AbortSignal, headers?: HeadersInit },
|
||||
options?: { signal?: AbortSignal; headers?: HeadersInit },
|
||||
): Promise<Response> => {
|
||||
const isRecord = (input: any): input is Record<string, unknown> => {
|
||||
return input !== null && typeof input === 'object' && !Array.isArray(input)
|
||||
@ -224,7 +261,7 @@ export const post = async (
|
||||
method: 'POST',
|
||||
body: formattedData,
|
||||
signal: options?.signal,
|
||||
headers: options?.headers
|
||||
headers: options?.headers,
|
||||
})
|
||||
|
||||
await checkOk(response)
|
||||
@ -247,7 +284,7 @@ export const del = async (
|
||||
const response = await fetchWithHeaders(fetch, host, {
|
||||
method: 'DELETE',
|
||||
body: JSON.stringify(data),
|
||||
headers: options?.headers
|
||||
headers: options?.headers,
|
||||
})
|
||||
|
||||
await checkOk(response)
|
||||
@ -332,16 +369,16 @@ export const formatHost = (host: string): string => {
|
||||
}
|
||||
|
||||
// Build basic auth part if present
|
||||
let auth = '';
|
||||
let auth = ''
|
||||
if (url.username) {
|
||||
auth = url.username;
|
||||
auth = url.username
|
||||
if (url.password) {
|
||||
auth += `:${url.password}`;
|
||||
auth += `:${url.password}`
|
||||
}
|
||||
auth += '@';
|
||||
auth += '@'
|
||||
}
|
||||
|
||||
let formattedHost = `${url.protocol}//${auth}${url.hostname}:${port}${url.pathname}`;
|
||||
let formattedHost = `${url.protocol}//${auth}${url.hostname}:${port}${url.pathname}`
|
||||
// remove trailing slashes
|
||||
if (formattedHost.endsWith('/')) {
|
||||
formattedHost = formattedHost.slice(0, -1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user