11import base64
22import json
3+ import os
34import re
45import sys
56import warnings
67from typing import Any , Optional
7- import os
8- from dotenv import load_dotenv
98
109import click
1110import litellm
1211import traitlets
13- from typing import Optional
12+ from dotenv import load_dotenv
1413from IPython .core .magic import Magics , line_cell_magic , magics_class
1514from IPython .display import HTML , JSON , Markdown , Math
1615from jupyter_ai .model_providers .model_list import CHAT_MODELS
3332# Load the .env file from the workspace root
3433dotenv_path = os .path .join (os .getcwd (), ".env" )
3534
35+
3636class TextOrMarkdown :
3737 def __init__ (self , text , markdown ):
3838 self .text = text
@@ -128,12 +128,14 @@ class AiMagics(Magics):
128128 # This should only set the "starting set" of aliases
129129 initial_aliases = traitlets .Dict (
130130 default_value = {},
131- value_trait = traitlets .Unicode (),
131+ value_trait = traitlets .Dict (),
132132 key_trait = traitlets .Unicode (),
133133 help = """Aliases for model identifiers.
134134
135- Keys define aliases, values define the provider and the model to use.
136- The values should include identifiers in in the `provider:model` format.
135+ Keys define aliases, values define a dictionary containing:
136+ - target: The provider and model to use in the `provider:model` format
137+ - api_base: Optional base URL for the API endpoint
138+ - api_key_name: Optional name of the environment variable containing the API key
137139 """ ,
138140 config = True ,
139141 )
@@ -183,8 +185,11 @@ def __init__(self, shell):
183185 # This is useful for users to know that they can set API keys in the JupyterLab
184186 # UI, but it is not always required to run the extension.
185187 if not os .path .isfile (dotenv_path ):
186- print (f"No `.env` file containing provider API keys found at { dotenv_path } . \
187- You can add API keys to the `.env` file via the AI Settings in the JupyterLab UI." , file = sys .stderr )
188+ print (
189+ f"No `.env` file containing provider API keys found at { dotenv_path } . \
190+ You can add API keys to the `.env` file via the AI Settings in the JupyterLab UI." ,
191+ file = sys .stderr ,
192+ )
188193
189194 # TODO: use LiteLLM aliases to provide this
190195 # https://docs.litellm.ai/docs/completion/model_alias
@@ -240,7 +245,10 @@ def ai(self, line: str, cell: Optional[str] = None) -> Any:
240245 print (error_msg , file = sys .stderr )
241246 return
242247 if not args :
243- print ("No valid %ai magics arguments given, run `%ai help` for all options." , file = sys .stderr )
248+ print (
249+ "No valid %ai magics arguments given, run `%ai help` for all options." ,
250+ file = sys .stderr ,
251+ )
244252 return
245253 raise e
246254
@@ -306,21 +314,23 @@ def run_ai_cell(self, args: CellArgs, prompt: str):
306314
307315 # Resolve model_id: check if it's in CHAT_MODELS or an alias
308316 model_id = args .model_id
309- if model_id not in CHAT_MODELS :
310- # Check if it's an alias
311- if model_id in self .aliases :
312- model_id = self .aliases [model_id ]
313- else :
314- error_msg = f"Model ID '{ model_id } ' is not a known model or alias. Run '%ai list' to see available models and aliases."
315- print (error_msg , file = sys .stderr ) # Log to stderr
316- return
317+ # Check if model_id is an alias and get stored configuration
318+ alias_config = None
319+ if model_id not in CHAT_MODELS and model_id in self .aliases :
320+ alias_config = self .aliases [model_id ]
321+ model_id = alias_config ["target" ]
322+ # Use stored api_base and api_key_name if not provided in current call
323+ if not args .api_base and alias_config ["api_base" ]:
324+ args .api_base = alias_config ["api_base" ]
325+ if not args .api_key_name and alias_config ["api_key_name" ]:
326+ args .api_key_name = alias_config ["api_key_name" ]
327+ elif model_id not in CHAT_MODELS :
328+ error_msg = f"Model ID '{ model_id } ' is not a known model or alias. Run '%ai list' to see available models and aliases."
329+ print (error_msg , file = sys .stderr ) # Log to stderr
330+ return
317331 try :
318332 # Prepare litellm completion arguments
319- completion_args = {
320- "model" : model_id ,
321- "messages" : messages ,
322- "stream" : False
323- }
333+ completion_args = {"model" : model_id , "messages" : messages , "stream" : False }
324334
325335 # Add api_base if provided
326336 if args .api_base :
@@ -493,8 +503,12 @@ def handle_alias(self, args: RegisterArgs) -> TextOrMarkdown:
493503 if args .name in AI_COMMANDS :
494504 raise ValueError (f"The name { args .name } is reserved for a command" )
495505
496- # Store the alias
497- self .aliases [args .name ] = args .target
506+ # Store the alias with its configuration
507+ self .aliases [args .name ] = {
508+ "target" : args .target ,
509+ "api_base" : args .api_base ,
510+ "api_key_name" : args .api_key_name ,
511+ }
498512
499513 output = f"Registered new alias `{ args .name } `"
500514 return TextOrMarkdown (output , output )
@@ -508,7 +522,7 @@ def handle_version(self, args: VersionArgs) -> str:
508522
509523 def handle_list (self , args : ListArgs ):
510524 """
511- Handles `%ai list`.
525+ Handles `%ai list`.
512526 - `%ai list` shows all providers by default, and ask the user to run %ai list <provider-name>.
513527 - `%ai list <provider-name>` shows all models available from one provider. It should also note that the list is not comprehensive, and include a reference to the upstream LiteLLM docs.
514528 - `%ai list all` should list all models.
@@ -517,12 +531,12 @@ def handle_list(self, args: ListArgs):
517531 models = CHAT_MODELS
518532
519533 # If provider_id is None, only return provider IDs
520- if getattr (args , ' provider_id' , None ) is None :
534+ if getattr (args , " provider_id" , None ) is None :
521535 # Extract unique provider IDs from model IDs
522536 provider_ids = set ()
523537 for model in models :
524- if '/' in model :
525- provider_ids .add (model .split ('/' )[0 ])
538+ if "/" in model :
539+ provider_ids .add (model .split ("/" )[0 ])
526540
527541 # Format output for both text and markdown
528542 text_output = "Available providers\n \n (Run `%ai list <provider_name>` to see models for a specific provider)\n \n "
@@ -533,9 +547,9 @@ def handle_list(self, args: ListArgs):
533547 markdown_output += f"* `{ provider_id } `\n "
534548
535549 return TextOrMarkdown (text_output , markdown_output )
536-
537- elif getattr (args , ' provider_id' , None ) == ' all' :
538- # Otherwise show all models and aliases
550+
551+ elif getattr (args , " provider_id" , None ) == " all" :
552+ # Otherwise show all models and aliases
539553 text_output = "All available models\n \n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n \n "
540554 markdown_output = "## All available models \n \n (The list is not comprehensive, a list of models is available at https://docs.litellm.ai/docs/providers)\n \n "
541555
@@ -547,12 +561,25 @@ def handle_list(self, args: ListArgs):
547561 if len (self .aliases ) > 0 :
548562 text_output += "\n Aliases:\n "
549563 markdown_output += "\n ### Aliases\n \n "
550- for alias , target in self .aliases .items ():
551- text_output += f"* { alias } -> { target } \n "
552- markdown_output += f"* `{ alias } ` -> `{ target } `\n "
564+ for alias , config in self .aliases .items ():
565+ text_output += f"* { alias } :\n "
566+ text_output += f" - target: { config ['target' ]} \n "
567+ if config ["api_base" ]:
568+ text_output += f" - api_base: { config ['api_base' ]} \n "
569+ if config ["api_key_name" ]:
570+ text_output += f" - api_key_name: { config ['api_key_name' ]} \n "
571+
572+ markdown_output += f"* `{ alias } `:\n "
573+ markdown_output += f" - target: `{ config ['target' ]} `\n "
574+ if config ["api_base" ]:
575+ markdown_output += f" - api_base: `{ config ['api_base' ]} `\n "
576+ if config ["api_key_name" ]:
577+ markdown_output += (
578+ f" - api_key_name: `{ config ['api_key_name' ]} `\n "
579+ )
553580
554581 return TextOrMarkdown (text_output , markdown_output )
555-
582+
556583 else :
557584 # If a specific provider_id is given, filter models by that provider
558585 provider_id = args .provider_id
@@ -575,10 +602,24 @@ def handle_list(self, args: ListArgs):
575602 if len (self .aliases ) > 0 :
576603 text_output += "\n Aliases:\n "
577604 markdown_output += "\n ### Aliases\n \n "
578- for alias , target in self .aliases .items ():
579- if target .startswith (provider_id + "/" ):
580- text_output += f"* { alias } -> { target } \n "
581- markdown_output += f"* `{ alias } ` -> `{ target } `\n "
582-
605+ for alias , config in self .aliases .items ():
606+ if config ["target" ].startswith (provider_id + "/" ):
607+ text_output += f"* { alias } :\n "
608+ text_output += f" - target: { config ['target' ]} \n "
609+ if config ["api_base" ]:
610+ text_output += f" - api_base: { config ['api_base' ]} \n "
611+ if config ["api_key_name" ]:
612+ text_output += (
613+ f" - api_key_name: { config ['api_key_name' ]} \n "
614+ )
615+
616+ markdown_output += f"* `{ alias } `:\n "
617+ markdown_output += f" - target: `{ config ['target' ]} `\n "
618+ if config ["api_base" ]:
619+ markdown_output += f" - api_base: `{ config ['api_base' ]} `\n "
620+ if config ["api_key_name" ]:
621+ markdown_output += (
622+ f" - api_key_name: `{ config ['api_key_name' ]} `\n "
623+ )
583624
584625 return TextOrMarkdown (text_output , markdown_output )
0 commit comments