Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions ollama/_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import base64
import os
import time
from pathlib import Path
from typing import Optional

from cryptography.hazmat.primitives import serialization


class OllamaAuth:
def __init__(self, key_path: Optional[str] = None):
"""Initialize the OllamaAuth class.

Args:
key_path: Optional path to the private key file. If not provided,
defaults to ~/.ollama/id_ed25519
"""
if key_path is None:
home = str(Path.home())
self.key_path = os.path.join(home, '.ollama', 'id_ed25519')
else:
# Expand ~ and environment variables in the path
self.key_path = os.path.expanduser(os.path.expandvars(key_path))

def load_private_key(self):
"""Read and load the private key.

Returns:
The loaded Ed25519 private key.

Raises:
FileNotFoundError: If the key file doesn't exist
ValueError: If the key file is invalid
"""
try:
with open(self.key_path, 'rb') as f:
private_key_data = f.read()

private_key = serialization.load_ssh_private_key(
private_key_data,
password=None,
)
return private_key
except FileNotFoundError:
raise FileNotFoundError(f"Could not find Ollama private key at {self.key_path}. Please generate one using: ssh-keygen -t ed25519 -f ~/.ollama/id_ed25519 -N ''") from None
except Exception as e:
raise ValueError(f'Invalid private key at {self.key_path}: {e!s}') from e

def get_public_key_b64(self, private_key):
"""Get the base64 encoded public key.

Args:
private_key: The Ed25519 private key

Returns:
Base64 encoded public key string
"""
# Get the public key in OpenSSH format and extract the second field (base64-encoded key)
public_key = private_key.public_key()
openssh_pub = (
public_key.public_bytes(
encoding=serialization.Encoding.OpenSSH,
format=serialization.PublicFormat.OpenSSH,
)
.decode('utf-8')
.strip()
)
parts = openssh_pub.split(' ')
if len(parts) < 2:
raise ValueError('Malformed OpenSSH public key')
public_key_b64 = parts[1]
return public_key_b64

def sign_request(self, method: str, path: str):
"""Sign an HTTP request.

Args:
method: The HTTP method (e.g. 'GET', 'POST')
path: The request path (e.g. '/api/chat')

Returns:
A tuple of (auth_token, timestamp) where auth_token is the
authorization header value and timestamp is the request timestamp.

Raises:
FileNotFoundError: If the key file doesn't exist
ValueError: If the key file is invalid
"""
timestamp = str(int(time.time()))
path_with_ts = f'{path}&ts={timestamp}' if '?' in path else f'{path}?ts={timestamp}'
challenge = f'{method},{path_with_ts}'

private_key = self.load_private_key()
signature = private_key.sign(challenge.encode())

public_key_b64 = self.get_public_key_b64(private_key)

auth_token = f'{public_key_b64}:{base64.b64encode(signature).decode("utf-8")}'

return auth_token, timestamp
44 changes: 37 additions & 7 deletions ollama/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import anyio
from pydantic.json_schema import JsonSchemaValue

from ollama._auth import OllamaAuth
from ollama._utils import convert_function_to_tool

if sys.version_info < (3, 9):
Expand Down Expand Up @@ -80,16 +81,18 @@ def __init__(
follow_redirects: bool = True,
timeout: Any = None,
headers: Optional[Mapping[str, str]] = None,
auth_key_path: Optional[str] = None,
**kwargs,
) -> None:
"""
Creates a httpx client. Default parameters are the same as those defined in httpx
except for the following:
- `follow_redirects`: True
- `timeout`: None
- `auth_key_path`: Optional path to the ed25519 private key for authentication
`kwargs` are passed to the httpx client.
"""

self._auth = OllamaAuth(auth_key_path)
self._client = client(
base_url=_parse_host(host or os.getenv('OLLAMA_HOST')),
follow_redirects=follow_redirects,
Expand All @@ -107,6 +110,24 @@ def __init__(
**kwargs,
)

def _prepare_request(self, method: str, path: str, **kwargs) -> Dict[str, Any]:
if self._auth:
url = str(self._client.build_request(method, path).url)
parsed = urllib.parse.urlparse(url)
full_path = parsed.path
if parsed.query:
full_path = f'{full_path}?{parsed.query}'

auth_token, timestamp = self._auth.sign_request(method, full_path)

if 'headers' not in kwargs:
kwargs['headers'] = {}
kwargs['headers']['Authorization'] = auth_token

path = f'{path}&ts={timestamp}' if '?' in path else f'{path}?ts={timestamp}'

return {'method': method, 'url': path, **kwargs}


CONNECTION_ERROR_MESSAGE = 'Failed to connect to Ollama. Please check that Ollama is downloaded, running and accessible. https://ollama.com/download'

Expand Down Expand Up @@ -155,14 +176,18 @@ def _request(
def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
*,
stream: bool = False,
**kwargs,
) -> Union[T, Iterator[T]]:
request_params = self._prepare_request(method, path, **kwargs)

if stream:

def inner():
with self._client.stream(*args, **kwargs) as r:
with self._client.stream(**request_params) as r:
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
Expand All @@ -177,7 +202,7 @@ def inner():

return inner()

return cls(**self._request_raw(*args, **kwargs).json())
return cls(**self._request_raw(**request_params).json())

@overload
def generate(
Expand Down Expand Up @@ -669,14 +694,19 @@ async def _request(
async def _request(
self,
cls: Type[T],
*args,
method: str,
path: str,
*,
stream: bool = False,
**kwargs,
) -> Union[T, AsyncIterator[T]]:
"""Make a request with optional authentication."""
request_params = self._prepare_request(method, path, **kwargs)

if stream:

async def inner():
async with self._client.stream(*args, **kwargs) as r:
async with self._client.stream(**request_params) as r:
try:
r.raise_for_status()
except httpx.HTTPStatusError as e:
Expand All @@ -691,7 +721,7 @@ async def inner():

return inner()

return cls(**(await self._request_raw(*args, **kwargs)).json())
return cls(**(await self._request_raw(**request_params)).json())

@overload
async def generate(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ requires-python = '>=3.8'
dependencies = [
'httpx>=0.27',
'pydantic>=2.9',
'cryptography>=46.0.1',
]
dynamic = [ 'version' ]
license = "MIT"
Expand Down
Loading
Loading