From 5a71cce0b5bbf776dcde178f7105940a81748597 Mon Sep 17 00:00:00 2001 From: Sanjiv Das Date: Tue, 2 Sep 2025 12:54:25 -0700 Subject: [PATCH 1/2] [magics] Enhances the `alias` option to include the base API url and key --- .../jupyter_ai_magics/magics.py | 74 ++++++++++++++----- .../jupyter_ai_magics/parsers.py | 48 ++++++++++-- 2 files changed, 97 insertions(+), 25 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index fd932d115..12ce3458a 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -128,12 +128,14 @@ class AiMagics(Magics): # This should only set the "starting set" of aliases initial_aliases = traitlets.Dict( default_value={}, - value_trait=traitlets.Unicode(), + value_trait=traitlets.Dict(), key_trait=traitlets.Unicode(), help="""Aliases for model identifiers. - Keys define aliases, values define the provider and the model to use. - The values should include identifiers in in the `provider:model` format. + Keys define aliases, values define a dictionary containing: + - target: The provider and model to use in the `provider:model` format + - api_base: Optional base URL for the API endpoint + - api_key_name: Optional name of the environment variable containing the API key """, config=True, ) @@ -306,14 +308,20 @@ def run_ai_cell(self, args: CellArgs, prompt: str): # Resolve model_id: check if it's in CHAT_MODELS or an alias model_id = args.model_id - if model_id not in CHAT_MODELS: - # Check if it's an alias - if model_id in self.aliases: - model_id = self.aliases[model_id] - else: - error_msg = f"Model ID '{model_id}' is not a known model or alias. Run '%ai list' to see available models and aliases." - print(error_msg, file=sys.stderr) # Log to stderr - return + # Check if model_id is an alias and get stored configuration + alias_config = None + if model_id not in CHAT_MODELS and model_id in self.aliases: + alias_config = self.aliases[model_id] + model_id = alias_config["target"] + # Use stored api_base and api_key_name if not provided in current call + if not args.api_base and alias_config["api_base"]: + args.api_base = alias_config["api_base"] + if not args.api_key_name and alias_config["api_key_name"]: + args.api_key_name = alias_config["api_key_name"] + elif model_id not in CHAT_MODELS: + error_msg = f"Model ID '{model_id}' is not a known model or alias. Run '%ai list' to see available models and aliases." + print(error_msg, file=sys.stderr) # Log to stderr + return try: # Prepare litellm completion arguments completion_args = { @@ -493,8 +501,12 @@ def handle_alias(self, args: RegisterArgs) -> TextOrMarkdown: if args.name in AI_COMMANDS: raise ValueError(f"The name {args.name} is reserved for a command") - # Store the alias - self.aliases[args.name] = args.target + # Store the alias with its configuration + self.aliases[args.name] = { + "target": args.target, + "api_base": args.api_base, + "api_key_name": args.api_key_name + } output = f"Registered new alias `{args.name}`" return TextOrMarkdown(output, output) @@ -547,9 +559,20 @@ def handle_list(self, args: ListArgs): if len(self.aliases) > 0: text_output += "\nAliases:\n" markdown_output += "\n### Aliases\n\n" - for alias, target in self.aliases.items(): - text_output += f"* {alias} -> {target}\n" - markdown_output += f"* `{alias}` -> `{target}`\n" + for alias, config in self.aliases.items(): + text_output += f"* {alias}:\n" + text_output += f" - target: {config['target']}\n" + if config['api_base']: + text_output += f" - api_base: {config['api_base']}\n" + if config['api_key_name']: + text_output += f" - api_key_name: {config['api_key_name']}\n" + + markdown_output += f"* `{alias}`:\n" + markdown_output += f" - target: `{config['target']}`\n" + if config['api_base']: + markdown_output += f" - api_base: `{config['api_base']}`\n" + if config['api_key_name']: + markdown_output += f" - api_key_name: `{config['api_key_name']}`\n" return TextOrMarkdown(text_output, markdown_output) @@ -575,10 +598,21 @@ def handle_list(self, args: ListArgs): if len(self.aliases) > 0: text_output += "\nAliases:\n" markdown_output += "\n### Aliases\n\n" - for alias, target in self.aliases.items(): - if target.startswith(provider_id + "/"): - text_output += f"* {alias} -> {target}\n" - markdown_output += f"* `{alias}` -> `{target}`\n" + for alias, config in self.aliases.items(): + if config['target'].startswith(provider_id + "/"): + text_output += f"* {alias}:\n" + text_output += f" - target: {config['target']}\n" + if config['api_base']: + text_output += f" - api_base: {config['api_base']}\n" + if config['api_key_name']: + text_output += f" - api_key_name: {config['api_key_name']}\n" + + markdown_output += f"* `{alias}`:\n" + markdown_output += f" - target: `{config['target']}`\n" + if config['api_base']: + markdown_output += f" - api_base: `{config['api_base']}`\n" + if config['api_key_name']: + markdown_output += f" - api_key_name: `{config['api_key_name']}`\n" return TextOrMarkdown(text_output, markdown_output) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index 015b3e647..fa6614075 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -88,6 +88,8 @@ class RegisterArgs(BaseModel): type: Literal["alias"] = "alias" name: str target: str + api_base: Optional[str] = None + api_key_name: Optional[str] = None class DeleteArgs(BaseModel): @@ -99,6 +101,8 @@ class UpdateArgs(BaseModel): type: Literal["update"] = "update" name: str target: str + api_base: Optional[str] = None + api_key_name: Optional[str] = None class ResetArgs(BaseModel): @@ -292,8 +296,23 @@ def list_subparser(**kwargs): ) @click.argument("name") @click.argument("target") +@click.option( + "--api-base", + required=False, + help="Base URL for the API endpoint.", +) +@click.option( + "--api-key-name", + required=False, + help="Name of the environment variable containing the API key.", +) def register_subparser(**kwargs): - """Register a new alias called NAME for the model or chain named TARGET.""" + """Register a new alias called NAME for the model or chain named TARGET. + + Optional parameters: + --api-base: Base URL for the API endpoint + --api-key-name: Name of the environment variable containing the API key + """ return RegisterArgs(**kwargs) @@ -301,15 +320,34 @@ def register_subparser(**kwargs): name="dealias", short_help="Delete an alias. See `%ai dealias --help` for options." ) @click.argument("name") -def register_subparser(**kwargs): +def dealias_subparser(**kwargs): """Delete an alias called NAME.""" return DeleteArgs(**kwargs) +@line_magic_parser.command( + name="update", + short_help="Update an alias. See `%ai update --help` for options.", +) @click.argument("name") @click.argument("target") -def register_subparser(**kwargs): - """Update an alias called NAME to refer to the model or chain named TARGET.""" +@click.option( + "--api-base", + required=False, + help="Base URL for the API endpoint.", +) +@click.option( + "--api-key-name", + required=False, + help="Name of the environment variable containing the API key.", +) +def update_subparser(**kwargs): + """Update an alias called NAME to refer to the model or chain named TARGET. + + Optional parameters: + --api-base: Base URL for the API endpoint + --api-key-name: Name of the environment variable containing the API key + """ return UpdateArgs(**kwargs) @@ -317,6 +355,6 @@ def register_subparser(**kwargs): name="reset", short_help="Clear the conversation transcript.", ) -def register_subparser(**kwargs): +def reset_subparser(**kwargs): """Clear the conversation transcript.""" return ResetArgs() From 978550440ba5581454c50e02ab8d0141348e22a2 Mon Sep 17 00:00:00 2001 From: Sanjiv Das Date: Tue, 2 Sep 2025 22:07:57 -0700 Subject: [PATCH 2/2] update tests --- .../jupyter_ai_magics/magics.py | 77 ++++++++++--------- .../jupyter_ai_magics/parsers.py | 4 +- .../jupyter_ai_magics/tests/test_magics.py | 16 +++- 3 files changed, 56 insertions(+), 41 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 12ce3458a..3a4203de4 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -1,16 +1,15 @@ import base64 import json +import os import re import sys import warnings from typing import Any, Optional -import os -from dotenv import load_dotenv import click import litellm import traitlets -from typing import Optional +from dotenv import load_dotenv from IPython.core.magic import Magics, line_cell_magic, magics_class from IPython.display import HTML, JSON, Markdown, Math from jupyter_ai.model_providers.model_list import CHAT_MODELS @@ -33,6 +32,7 @@ # Load the .env file from the workspace root dotenv_path = os.path.join(os.getcwd(), ".env") + class TextOrMarkdown: def __init__(self, text, markdown): self.text = text @@ -185,8 +185,11 @@ def __init__(self, shell): # This is useful for users to know that they can set API keys in the JupyterLab # UI, but it is not always required to run the extension. if not os.path.isfile(dotenv_path): - print(f"No `.env` file containing provider API keys found at {dotenv_path}. \ - You can add API keys to the `.env` file via the AI Settings in the JupyterLab UI.", file=sys.stderr) + print( + f"No `.env` file containing provider API keys found at {dotenv_path}. \ + You can add API keys to the `.env` file via the AI Settings in the JupyterLab UI.", + file=sys.stderr, + ) # TODO: use LiteLLM aliases to provide this # https://docs.litellm.ai/docs/completion/model_alias @@ -242,7 +245,10 @@ def ai(self, line: str, cell: Optional[str] = None) -> Any: print(error_msg, file=sys.stderr) return if not args: - print("No valid %ai magics arguments given, run `%ai help` for all options.", file=sys.stderr) + print( + "No valid %ai magics arguments given, run `%ai help` for all options.", + file=sys.stderr, + ) return raise e @@ -324,11 +330,7 @@ def run_ai_cell(self, args: CellArgs, prompt: str): return try: # Prepare litellm completion arguments - completion_args = { - "model": model_id, - "messages": messages, - "stream": False - } + completion_args = {"model": model_id, "messages": messages, "stream": False} # Add api_base if provided if args.api_base: @@ -505,7 +507,7 @@ def handle_alias(self, args: RegisterArgs) -> TextOrMarkdown: self.aliases[args.name] = { "target": args.target, "api_base": args.api_base, - "api_key_name": args.api_key_name + "api_key_name": args.api_key_name, } output = f"Registered new alias `{args.name}`" @@ -520,7 +522,7 @@ def handle_version(self, args: VersionArgs) -> str: def handle_list(self, args: ListArgs): """ - Handles `%ai list`. + Handles `%ai list`. - `%ai list` shows all providers by default, and ask the user to run %ai list . - `%ai list ` shows all models available from one provider. It should also note that the list is not comprehensive, and include a reference to the upstream LiteLLM docs. - `%ai list all` should list all models. @@ -529,12 +531,12 @@ def handle_list(self, args: ListArgs): models = CHAT_MODELS # If provider_id is None, only return provider IDs - if getattr(args, 'provider_id', None) is None: + if getattr(args, "provider_id", None) is None: # Extract unique provider IDs from model IDs provider_ids = set() for model in models: - if '/' in model: - provider_ids.add(model.split('/')[0]) + if "/" in model: + provider_ids.add(model.split("/")[0]) # Format output for both text and markdown text_output = "Available providers\n\n (Run `%ai list ` to see models for a specific provider)\n\n" @@ -545,9 +547,9 @@ def handle_list(self, args: ListArgs): markdown_output += f"* `{provider_id}`\n" return TextOrMarkdown(text_output, markdown_output) - - elif getattr(args, 'provider_id', None) == 'all': - # Otherwise show all models and aliases + + elif getattr(args, "provider_id", None) == "all": + # Otherwise show all models and aliases text_output = "All available models\n\n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n\n" markdown_output = "## All available models \n\n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n\n" @@ -562,20 +564,22 @@ def handle_list(self, args: ListArgs): for alias, config in self.aliases.items(): text_output += f"* {alias}:\n" text_output += f" - target: {config['target']}\n" - if config['api_base']: + if config["api_base"]: text_output += f" - api_base: {config['api_base']}\n" - if config['api_key_name']: + if config["api_key_name"]: text_output += f" - api_key_name: {config['api_key_name']}\n" - + markdown_output += f"* `{alias}`:\n" markdown_output += f" - target: `{config['target']}`\n" - if config['api_base']: + if config["api_base"]: markdown_output += f" - api_base: `{config['api_base']}`\n" - if config['api_key_name']: - markdown_output += f" - api_key_name: `{config['api_key_name']}`\n" + if config["api_key_name"]: + markdown_output += ( + f" - api_key_name: `{config['api_key_name']}`\n" + ) return TextOrMarkdown(text_output, markdown_output) - + else: # If a specific provider_id is given, filter models by that provider provider_id = args.provider_id @@ -599,20 +603,23 @@ def handle_list(self, args: ListArgs): text_output += "\nAliases:\n" markdown_output += "\n### Aliases\n\n" for alias, config in self.aliases.items(): - if config['target'].startswith(provider_id + "/"): + if config["target"].startswith(provider_id + "/"): text_output += f"* {alias}:\n" text_output += f" - target: {config['target']}\n" - if config['api_base']: + if config["api_base"]: text_output += f" - api_base: {config['api_base']}\n" - if config['api_key_name']: - text_output += f" - api_key_name: {config['api_key_name']}\n" - + if config["api_key_name"]: + text_output += ( + f" - api_key_name: {config['api_key_name']}\n" + ) + markdown_output += f"* `{alias}`:\n" markdown_output += f" - target: `{config['target']}`\n" - if config['api_base']: + if config["api_base"]: markdown_output += f" - api_base: `{config['api_base']}`\n" - if config['api_key_name']: - markdown_output += f" - api_key_name: `{config['api_key_name']}`\n" - + if config["api_key_name"]: + markdown_output += ( + f" - api_key_name: `{config['api_key_name']}`\n" + ) return TextOrMarkdown(text_output, markdown_output) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py index fa6614075..5e08672db 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/parsers.py @@ -308,7 +308,7 @@ def list_subparser(**kwargs): ) def register_subparser(**kwargs): """Register a new alias called NAME for the model or chain named TARGET. - + Optional parameters: --api-base: Base URL for the API endpoint --api-key-name: Name of the environment variable containing the API key @@ -343,7 +343,7 @@ def dealias_subparser(**kwargs): ) def update_subparser(**kwargs): """Update an alias called NAME to refer to the model or chain named TARGET. - + Optional parameters: --api-base: Base URL for the API endpoint --api-key-name: Name of the environment variable containing the API key diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py index 34b850d1a..01aa650b0 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/tests/test_magics.py @@ -14,12 +14,20 @@ def ip() -> InteractiveShell: def test_aliases_config(ip): - ip.config.AiMagics.initial_aliases = {"my_custom_alias": "my_provider:my_model"} + ip.config.AiMagics.initial_aliases = { + "my_custom_alias": { + "target": "my_provider:my_model", + "api_base": None, + "api_key_name": None + } + } ip.extension_manager.load_extension("jupyter_ai_magics") # Use 'list all' to see all models and aliases - providers_list = ip.run_line_magic("ai", "list all").text - # Check that alias appears in the output - assert "my_custom_alias -> my_provider:my_model" in providers_list + providers_list = ip.run_line_magic("ai", "list all") + # Check that alias appears in the markdown output with correct format + assert "### Aliases" in providers_list.markdown + assert "* `my_custom_alias`:" in providers_list.markdown + assert " - target: `my_provider:my_model`" in providers_list.markdown def test_default_model_cell(ip):