Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 81 additions & 40 deletions packages/jupyter-ai-magics/jupyter_ai_magics/magics.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import base64
import json
import os
import re
import sys
import warnings
from typing import Any, Optional
import os
from dotenv import load_dotenv

import click
import litellm
import traitlets
from typing import Optional
from dotenv import load_dotenv
from IPython.core.magic import Magics, line_cell_magic, magics_class
from IPython.display import HTML, JSON, Markdown, Math
from jupyter_ai.model_providers.model_list import CHAT_MODELS
Expand All @@ -33,6 +32,7 @@
# Load the .env file from the workspace root
dotenv_path = os.path.join(os.getcwd(), ".env")


class TextOrMarkdown:
def __init__(self, text, markdown):
self.text = text
Expand Down Expand Up @@ -128,12 +128,14 @@ class AiMagics(Magics):
# This should only set the "starting set" of aliases
initial_aliases = traitlets.Dict(
default_value={},
value_trait=traitlets.Unicode(),
value_trait=traitlets.Dict(),
key_trait=traitlets.Unicode(),
help="""Aliases for model identifiers.

Keys define aliases, values define the provider and the model to use.
The values should include identifiers in in the `provider:model` format.
Keys define aliases, values define a dictionary containing:
- target: The provider and model to use in the `provider:model` format
- api_base: Optional base URL for the API endpoint
- api_key_name: Optional name of the environment variable containing the API key
""",
config=True,
)
Expand Down Expand Up @@ -183,8 +185,11 @@ def __init__(self, shell):
# This is useful for users to know that they can set API keys in the JupyterLab
# UI, but it is not always required to run the extension.
if not os.path.isfile(dotenv_path):
print(f"No `.env` file containing provider API keys found at {dotenv_path}. \
You can add API keys to the `.env` file via the AI Settings in the JupyterLab UI.", file=sys.stderr)
print(
f"No `.env` file containing provider API keys found at {dotenv_path}. \
You can add API keys to the `.env` file via the AI Settings in the JupyterLab UI.",
file=sys.stderr,
)

# TODO: use LiteLLM aliases to provide this
# https://docs.litellm.ai/docs/completion/model_alias
Expand Down Expand Up @@ -240,7 +245,10 @@ def ai(self, line: str, cell: Optional[str] = None) -> Any:
print(error_msg, file=sys.stderr)
return
if not args:
print("No valid %ai magics arguments given, run `%ai help` for all options.", file=sys.stderr)
print(
"No valid %ai magics arguments given, run `%ai help` for all options.",
file=sys.stderr,
)
return
raise e

Expand Down Expand Up @@ -306,21 +314,23 @@ def run_ai_cell(self, args: CellArgs, prompt: str):

# Resolve model_id: check if it's in CHAT_MODELS or an alias
model_id = args.model_id
if model_id not in CHAT_MODELS:
# Check if it's an alias
if model_id in self.aliases:
model_id = self.aliases[model_id]
else:
error_msg = f"Model ID '{model_id}' is not a known model or alias. Run '%ai list' to see available models and aliases."
print(error_msg, file=sys.stderr) # Log to stderr
return
# Check if model_id is an alias and get stored configuration
alias_config = None
if model_id not in CHAT_MODELS and model_id in self.aliases:
alias_config = self.aliases[model_id]
model_id = alias_config["target"]
# Use stored api_base and api_key_name if not provided in current call
if not args.api_base and alias_config["api_base"]:
args.api_base = alias_config["api_base"]
if not args.api_key_name and alias_config["api_key_name"]:
args.api_key_name = alias_config["api_key_name"]
elif model_id not in CHAT_MODELS:
error_msg = f"Model ID '{model_id}' is not a known model or alias. Run '%ai list' to see available models and aliases."
print(error_msg, file=sys.stderr) # Log to stderr
return
try:
# Prepare litellm completion arguments
completion_args = {
"model": model_id,
"messages": messages,
"stream": False
}
completion_args = {"model": model_id, "messages": messages, "stream": False}

# Add api_base if provided
if args.api_base:
Expand Down Expand Up @@ -493,8 +503,12 @@ def handle_alias(self, args: RegisterArgs) -> TextOrMarkdown:
if args.name in AI_COMMANDS:
raise ValueError(f"The name {args.name} is reserved for a command")

# Store the alias
self.aliases[args.name] = args.target
# Store the alias with its configuration
self.aliases[args.name] = {
"target": args.target,
"api_base": args.api_base,
"api_key_name": args.api_key_name,
}

output = f"Registered new alias `{args.name}`"
return TextOrMarkdown(output, output)
Expand All @@ -508,7 +522,7 @@ def handle_version(self, args: VersionArgs) -> str:

def handle_list(self, args: ListArgs):
"""
Handles `%ai list`.
Handles `%ai list`.
- `%ai list` shows all providers by default, and ask the user to run %ai list <provider-name>.
- `%ai list <provider-name>` shows all models available from one provider. It should also note that the list is not comprehensive, and include a reference to the upstream LiteLLM docs.
- `%ai list all` should list all models.
Expand All @@ -517,12 +531,12 @@ def handle_list(self, args: ListArgs):
models = CHAT_MODELS

# If provider_id is None, only return provider IDs
if getattr(args, 'provider_id', None) is None:
if getattr(args, "provider_id", None) is None:
# Extract unique provider IDs from model IDs
provider_ids = set()
for model in models:
if '/' in model:
provider_ids.add(model.split('/')[0])
if "/" in model:
provider_ids.add(model.split("/")[0])

# Format output for both text and markdown
text_output = "Available providers\n\n (Run `%ai list <provider_name>` to see models for a specific provider)\n\n"
Expand All @@ -533,9 +547,9 @@ def handle_list(self, args: ListArgs):
markdown_output += f"* `{provider_id}`\n"

return TextOrMarkdown(text_output, markdown_output)
elif getattr(args, 'provider_id', None) == 'all':
# Otherwise show all models and aliases

elif getattr(args, "provider_id", None) == "all":
# Otherwise show all models and aliases
text_output = "All available models\n\n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n\n"
markdown_output = "## All available models \n\n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n\n"

Expand All @@ -547,12 +561,25 @@ def handle_list(self, args: ListArgs):
if len(self.aliases) > 0:
text_output += "\nAliases:\n"
markdown_output += "\n### Aliases\n\n"
for alias, target in self.aliases.items():
text_output += f"* {alias} -> {target}\n"
markdown_output += f"* `{alias}` -> `{target}`\n"
for alias, config in self.aliases.items():
text_output += f"* {alias}:\n"
text_output += f" - target: {config['target']}\n"
if config["api_base"]:
text_output += f" - api_base: {config['api_base']}\n"
if config["api_key_name"]:
text_output += f" - api_key_name: {config['api_key_name']}\n"

markdown_output += f"* `{alias}`:\n"
markdown_output += f" - target: `{config['target']}`\n"
if config["api_base"]:
markdown_output += f" - api_base: `{config['api_base']}`\n"
if config["api_key_name"]:
markdown_output += (
f" - api_key_name: `{config['api_key_name']}`\n"
)

return TextOrMarkdown(text_output, markdown_output)

else:
# If a specific provider_id is given, filter models by that provider
provider_id = args.provider_id
Expand All @@ -575,10 +602,24 @@ def handle_list(self, args: ListArgs):
if len(self.aliases) > 0:
text_output += "\nAliases:\n"
markdown_output += "\n### Aliases\n\n"
for alias, target in self.aliases.items():
if target.startswith(provider_id + "/"):
text_output += f"* {alias} -> {target}\n"
markdown_output += f"* `{alias}` -> `{target}`\n"

for alias, config in self.aliases.items():
if config["target"].startswith(provider_id + "/"):
text_output += f"* {alias}:\n"
text_output += f" - target: {config['target']}\n"
if config["api_base"]:
text_output += f" - api_base: {config['api_base']}\n"
if config["api_key_name"]:
text_output += (
f" - api_key_name: {config['api_key_name']}\n"
)

markdown_output += f"* `{alias}`:\n"
markdown_output += f" - target: `{config['target']}`\n"
if config["api_base"]:
markdown_output += f" - api_base: `{config['api_base']}`\n"
if config["api_key_name"]:
markdown_output += (
f" - api_key_name: `{config['api_key_name']}`\n"
)

return TextOrMarkdown(text_output, markdown_output)
48 changes: 43 additions & 5 deletions packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class RegisterArgs(BaseModel):
type: Literal["alias"] = "alias"
name: str
target: str
api_base: Optional[str] = None
api_key_name: Optional[str] = None


class DeleteArgs(BaseModel):
Expand All @@ -99,6 +101,8 @@ class UpdateArgs(BaseModel):
type: Literal["update"] = "update"
name: str
target: str
api_base: Optional[str] = None
api_key_name: Optional[str] = None


class ResetArgs(BaseModel):
Expand Down Expand Up @@ -292,31 +296,65 @@ def list_subparser(**kwargs):
)
@click.argument("name")
@click.argument("target")
@click.option(
"--api-base",
required=False,
help="Base URL for the API endpoint.",
)
@click.option(
"--api-key-name",
required=False,
help="Name of the environment variable containing the API key.",
)
def register_subparser(**kwargs):
"""Register a new alias called NAME for the model or chain named TARGET."""
"""Register a new alias called NAME for the model or chain named TARGET.

Optional parameters:
--api-base: Base URL for the API endpoint
--api-key-name: Name of the environment variable containing the API key
"""
return RegisterArgs(**kwargs)


@line_magic_parser.command(
name="dealias", short_help="Delete an alias. See `%ai dealias --help` for options."
)
@click.argument("name")
def register_subparser(**kwargs):
def dealias_subparser(**kwargs):
"""Delete an alias called NAME."""
return DeleteArgs(**kwargs)


@line_magic_parser.command(
name="update",
short_help="Update an alias. See `%ai update --help` for options.",
)
@click.argument("name")
@click.argument("target")
def register_subparser(**kwargs):
"""Update an alias called NAME to refer to the model or chain named TARGET."""
@click.option(
"--api-base",
required=False,
help="Base URL for the API endpoint.",
)
@click.option(
"--api-key-name",
required=False,
help="Name of the environment variable containing the API key.",
)
def update_subparser(**kwargs):
"""Update an alias called NAME to refer to the model or chain named TARGET.

Optional parameters:
--api-base: Base URL for the API endpoint
--api-key-name: Name of the environment variable containing the API key
"""
return UpdateArgs(**kwargs)


@line_magic_parser.command(
name="reset",
short_help="Clear the conversation transcript.",
)
def register_subparser(**kwargs):
def reset_subparser(**kwargs):
"""Clear the conversation transcript."""
return ResetArgs()
16 changes: 12 additions & 4 deletions packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,20 @@ def ip() -> InteractiveShell:


def test_aliases_config(ip):
ip.config.AiMagics.initial_aliases = {"my_custom_alias": "my_provider:my_model"}
ip.config.AiMagics.initial_aliases = {
"my_custom_alias": {
"target": "my_provider:my_model",
"api_base": None,
"api_key_name": None
}
}
ip.extension_manager.load_extension("jupyter_ai_magics")
# Use 'list all' to see all models and aliases
providers_list = ip.run_line_magic("ai", "list all").text
# Check that alias appears in the output
assert "my_custom_alias -> my_provider:my_model" in providers_list
providers_list = ip.run_line_magic("ai", "list all")
# Check that alias appears in the markdown output with correct format
assert "### Aliases" in providers_list.markdown
assert "* `my_custom_alias`:" in providers_list.markdown
assert " - target: `my_provider:my_model`" in providers_list.markdown


def test_default_model_cell(ip):
Expand Down
Loading