diff --git a/docetl/builder.py b/docetl/builder.py
index 4d3c8d30..e25980d2 100644
--- a/docetl/builder.py
+++ b/docetl/builder.py
@@ -13,15 +13,14 @@
 from rich.status import Status
 from rich.traceback import install
 
-from docetl.dataset import Dataset, create_parsing_tool_map
+from docetl.dataset import Dataset
 from docetl.operations import get_operation
 from docetl.operations.base import BaseOperation
-from docetl.operations.utils import flush_cache
+from docetl.helper.cache import flush_cache
 from docetl.optimizers.join_optimizer import JoinOptimizer
 from docetl.optimizers.map_optimizer import MapOptimizer
 from docetl.optimizers.reduce_optimizer import ReduceOptimizer
 from docetl.optimizers.utils import LLMClient
-from docetl.config_wrapper import ConfigWrapper
 
 install(show_locals=True)
 
diff --git a/docetl/cli.py b/docetl/cli.py
index dbbe09fe..7a0dde81 100644
--- a/docetl/cli.py
+++ b/docetl/cli.py
@@ -4,7 +4,7 @@
 import os
 
 import typer
 
-from docetl.operations.utils import clear_cache as cc
+from docetl.helper.cache import clear_cache as cc
 from docetl.runner import DSLRunner
 from dotenv import load_dotenv
 
diff --git a/docetl/config_wrapper.py b/docetl/config_wrapper.py
index c650c8b5..92cd9bd7 100644
--- a/docetl/config_wrapper.py
+++ b/docetl/config_wrapper.py
@@ -2,8 +2,8 @@
 import os
 from docetl.console import get_console
 from docetl.utils import load_config
-from typing import Any, Dict, List, Optional, Tuple, Union
-from docetl.operations.utils import APIWrapper
+from typing import Dict, Optional
+from docetl.helper.api_wrapper import APIWrapper
 import pyrate_limiter
 from inspect import isawaitable
 import math
diff --git a/docetl/console.py b/docetl/console.py
index 4a07f35d..0f6aa64e 100644
--- a/docetl/console.py
+++ b/docetl/console.py
@@ -1,9 +1,13 @@
 import os
-from typing import Any, Optional
-from rich.console import Console
-from io import StringIO
 import threading
-import queue
+from io import StringIO
+from typing import Optional, Union, Any
+
+from rich.console import Console, JustifyMethod
+from rich.style import Style
+from typing_extensions import override
+
+from docetl.helper.database import DatabaseUtil
 
 
 class ThreadSafeConsole(Console):
@@ -13,12 +17,14 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.input_event = threading.Event()
         self.input_value = None
+        self.conn: Optional[DatabaseUtil] = None
+        self.is_write_to_db = False
 
     def print(self, *args, **kwargs):
         super().print(*args, **kwargs)
 
     def input(
-        self, prompt="", *, markup: bool = True, emoji: bool = True, **kwargs
+            self, prompt="", *, markup: bool = True, emoji: bool = True, **kwargs
     ) -> str:
         if prompt:
             self.print(prompt, markup=markup, emoji=emoji, end="")
@@ -36,6 +42,68 @@ def post_input(self, value: str):
         self.input_value = value
         self.input_event.set()
 
+    def with_db_logging_enabled(self, conn: DatabaseUtil, table_name: str, schema: dict) -> "ThreadSafeConsole":
+        self.conn = conn
+        self.is_write_to_db = True
+        self.schema = schema
+        self.table_name = table_name
+        return self
+
+    @override
+    def log(
+        self,
+        *objects: Any,
+        sep: str = " ",
+        end: str = "\n",
+        style: Optional[Union[str, Style]] = None,
+        justify: Optional[JustifyMethod] = None,
+        emoji: Optional[bool] = None,
+        markup: Optional[bool] = None,
+        highlight: Optional[bool] = None,
+        log_locals: bool = False,
+        _stack_offset: int = 1,
+    ):
+        # Call the parent implementation first so terminal output is unchanged
+        super().log(*objects, sep=sep, end=end, style=style, justify=justify, emoji=emoji, markup=markup,
+                    highlight=highlight, log_locals=log_locals, _stack_offset=_stack_offset)
+        if self.is_write_to_db:
+            self.conn.log_to_db(log_data={"log_message": sep.join(str(obj) for obj in objects)},
+                                schema=self.schema, table_name=self.table_name)
+
+
+class DocETLLog(Console):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conn: Optional[DatabaseUtil] = None
+        self.is_write_to_db = False
+
+    def with_db_logging_enabled(self, conn: DatabaseUtil, table_name: str, schema: dict) -> "DocETLLog":
+        self.conn = conn
+        self.is_write_to_db = True
+        self.schema = schema
+        self.table_name = table_name
+        return self
+
+    @override
+    def log(
+        self,
+        *objects: Any,
+        sep: str = " ",
+        end: str = "\n",
+        style: Optional[Union[str, Style]] = None,
+        justify: Optional[JustifyMethod] = None,
+        emoji: Optional[bool] = None,
+        markup: Optional[bool] = None,
+        highlight: Optional[bool] = None,
+        log_locals: bool = False,
+        _stack_offset: int = 1,
+    ):
+        # Call the parent implementation first so terminal output is unchanged
+        super().log(*objects, sep=sep, end=end, style=style, justify=justify, emoji=emoji, markup=markup,
+                    highlight=highlight, log_locals=log_locals, _stack_offset=_stack_offset)
+        if self.is_write_to_db:
+            # TODO: the logged data must match the table schema; a free-form
+            # user-defined schema is fragile, so consider enforcing
+            # DatabaseUtil.DEFAULT_LOG_SCHEMA's columns here.
+            self.conn.log_to_db(log_data={"log_message": sep.join(str(obj) for obj in objects)},
+                                schema=self.schema, table_name=self.table_name)
 
 def get_console():
     # Check if we're running with a frontend
@@ -50,4 +118,7 @@ def get_console():
 
     return Console()
 
 
+# The log overrides above let a console mirror its logs to a database
+# connection in addition to the terminal.
+
 DOCETL_CONSOLE = get_console()
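For reviewers: a minimal sketch of how the new console hooks are meant to compose with the `SqliteUtil` backend added later in this PR. The database path and the schema passed to `with_db_logging_enabled` are illustrative, not part of the diff:

```python
# Sketch: DB-backed console logging, assuming this PR's SqliteUtil and
# default log schema ("pipeline.db" is an illustrative path).
from docetl.console import ThreadSafeConsole
from docetl.helper.database import DatabaseUtil
from docetl.helper.sqlite import SqliteUtil

defaults = DatabaseUtil.DEFAULT_LOG_SCHEMA()
db = SqliteUtil(db_path="pipeline.db")
db.new_connection()
db.create_log_table(defaults.get_default_table_name, defaults.get_default_log_schema)

console = ThreadSafeConsole().with_db_logging_enabled(
    conn=db,
    table_name=defaults.get_default_table_name,
    schema={"process_id": "TEXT", "operation": "TEXT", "log_message": "TEXT"},
)
console.log("pipeline started")  # printed to the terminal AND inserted into pipeline.db
db.close_connection()
```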
diff --git a/docetl/helper/__init__.py b/docetl/helper/__init__.py
new file mode 100644
index 00000000..e69de29b
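The net effect of the rename and the new files that follow is that `docetl/operations/utils.py` is split across a new `docetl.helper` package. A rough orientation map, read off the diffs below (the grouping comments are mine; only the paths are authoritative):

```python
# Old: everything lived in docetl.operations.utils.
# New layout introduced by this PR:
from docetl.helper.api_wrapper import APIWrapper               # LLM call wrapper
from docetl.helper.cache import flush_cache, clear_cache, cache_key
from docetl.helper.generic import (                            # shared helpers
    freezeargs, convert_val, truncate_messages, safe_eval,
)
from docetl.helper.progress_bar import RichLoopBar, rich_as_completed
from docetl.helper.sqlite import SqliteUtil                    # new DB logging backend
```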
diff --git a/docetl/operations/utils.py b/docetl/helper/api_wrapper.py
similarity index 61%
rename from docetl/operations/utils.py
rename to docetl/helper/api_wrapper.py
index a385c5ce..961ca31a 100644
--- a/docetl/operations/utils.py
+++ b/docetl/helper/api_wrapper.py
@@ -1,381 +1,18 @@
 import ast
-import functools
 import hashlib
 import json
-import os
-import shutil
-import threading
-from concurrent.futures import as_completed
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
-
-import litellm
-import tiktoken
-from asteval import Interpreter
-from diskcache import Cache
-from dotenv import load_dotenv
-from frozendict import frozendict
-from jinja2 import Template
-from litellm import completion, embedding, model_cost, RateLimitError
-from rich import print as rprint
-from rich.console import Console
-from rich.prompt import Prompt
-from tqdm import tqdm
-from pydantic import BaseModel
-
-from docetl.console import DOCETL_CONSOLE
-from docetl.utils import completion_cost, count_tokens
 import time
-from litellm.utils import ModelResponse
-
-aeval = Interpreter()
-
-load_dotenv()
-# litellm.set_verbose = True
-DOCETL_HOME_DIR = os.path.expanduser("~/.docetl")
-
-CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "cache")
-LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm_cache")
-cache = Cache(LLM_CACHE_DIR)
-cache.close()
-
-
-class LLMResult(BaseModel):
-    response: Any
-    total_cost: float
-    validated: bool
-
-
-def freezeargs(func):
-    """
-    Decorator to convert mutable dictionary arguments into immutable.
-
-    This decorator is useful for making functions compatible with caching mechanisms
-    that require immutable arguments.
-
-    Args:
-        func (callable): The function to be wrapped.
-
-    Returns:
-        callable: The wrapped function with immutable dictionary arguments.
-    """
-
-    @functools.wraps(func)
-    def wrapped(*args, **kwargs):
-        args = tuple(
-            (
-                frozendict(arg)
-                if isinstance(arg, dict)
-                else json.dumps(arg) if isinstance(arg, list) else arg
-            )
-            for arg in args
-        )
-        kwargs = {
-            k: (
-                frozendict(v)
-                if isinstance(v, dict)
-                else json.dumps(v) if isinstance(v, list) else v
-            )
-            for k, v in kwargs.items()
-        }
-        return func(*args, **kwargs)
-
-    return wrapped
-
-
-def flush_cache(console: Console = DOCETL_CONSOLE):
-    """
-    Flush the cache to disk.
-    """
-    console.log("[bold green]Flushing cache to disk...[/bold green]")
-    cache.close()
-    console.log("[bold green]Cache flushed to disk.[/bold green]")
+from typing import List, Dict, Optional, Any
+from jinja2 import Template
+from litellm import embedding, completion, RateLimitError
+from litellm.types.utils import ModelResponse
+from rich import print as rprint
+from rich.console import Console
 
-def clear_cache(console: Console = DOCETL_CONSOLE):
-    """
-    Clear the LLM cache stored on disk.
-
-    This function removes all cached items from the disk-based cache,
-    effectively clearing the LLM's response history.
-
-    Args:
-        console (Console, optional): A Rich console object for logging.
-            Defaults to a new Console instance.
-    """
-    console.log("[bold yellow]Clearing LLM cache...[/bold yellow]")
-    try:
-        with cache as c:
-            c.clear()
-        # Remove all files in the cache directory
-        cache_dir = CACHE_DIR
-        if not os.path.exists(cache_dir):
-            os.makedirs(cache_dir)
-        for filename in os.listdir(cache_dir):
-            file_path = os.path.join(cache_dir, filename)
-            try:
-                if os.path.isfile(file_path):
-                    os.unlink(file_path)
-                elif os.path.isdir(file_path):
-                    shutil.rmtree(file_path)
-            except Exception as e:
-                console.log(
-                    f"[bold red]Error deleting {file_path}: {str(e)}[/bold red]"
-                )
-        console.log("[bold green]Cache cleared successfully.[/bold green]")
-    except Exception as e:
-        console.log(f"[bold red]Error clearing cache: {str(e)}[/bold red]")
-
-def convert_dict_schema_to_list_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
-    schema_str = "{" + ", ".join([f"{k}: {v}" for k, v in schema.items()]) + "}"
-    return {"results": f"list[{schema_str}]"}
-
-def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]:
-    """
-    Convert a string representation of a type to a dictionary representation.
-
-    This function takes a string value representing a data type and converts it
-    into a dictionary format suitable for JSON schema.
-
-    Args:
-        value (Any): A string representing a data type.
-        model (str): The model being used. Defaults to "gpt-4o-mini".
-
-    Returns:
-        Dict[str, Any]: A dictionary representing the type in JSON schema format.
-
-    Raises:
-        ValueError: If the input value is not a supported type or is improperly formatted.
- """ - value = value.strip().lower() - if value in ["str", "text", "string", "varchar"]: - return {"type": "string"} - elif value in ["int", "integer"]: - return {"type": "integer"} - elif value in ["float", "decimal", "number"]: - return {"type": "number"} - elif value in ["bool", "boolean"]: - return {"type": "boolean"} - elif value.startswith("list["): - inner_type = value[5:-1].strip() - return {"type": "array", "items": convert_val(inner_type, model)} - elif value == "list": - raise ValueError("List type must specify its elements, e.g., 'list[str]'") - elif value.startswith("{") and value.endswith("}"): - # Handle dictionary type - properties = {} - for item in value[1:-1].split(","): - key, val = item.strip().split(":") - properties[key.strip()] = convert_val(val.strip(), model) - result = { - "type": "object", - "properties": properties, - "required": list(properties.keys()), - } - # TODO: this is a hack to get around the fact that gemini doesn't support additionalProperties - if "gemini" not in model: - result["additionalProperties"] = False - return result - else: - raise ValueError(f"Unsupported value type: {value}") - - -def cache_key( - model: str, - op_type: str, - messages: List[Dict[str, str]], - output_schema: Dict[str, str], - scratchpad: Optional[str] = None, -) -> str: - """ - Generate a unique cache key based on function arguments. - - This function creates a hash-based key using the input parameters, which can - be used for caching purposes. - - Args: - model (str): The model name. - op_type (str): The operation type. - messages (List[Dict[str, str]]): The messages to send to the LLM. - output_schema (Dict[str, str]): The output schema dictionary. - scratchpad (Optional[str]): The scratchpad to use for the operation. - - Returns: - str: A unique hash string representing the cache key. - """ - # Ensure no non-serializable objects are included - key_dict = { - "model": model, - "op_type": op_type, - "messages": json.dumps(messages, sort_keys=True), - "output_schema": json.dumps(output_schema, sort_keys=True), - "scratchpad": scratchpad, - } - return hashlib.md5(json.dumps(key_dict, sort_keys=True).encode()).hexdigest() - - -def get_user_input_for_schema(schema: Dict[str, Any]) -> Dict[str, Any]: - """ - Prompt the user for input for each key in the schema using Rich, - then parse the input values with json.loads(). - - Args: - schema (Dict[str, Any]): The schema dictionary. - - Returns: - Dict[str, Any]: A dictionary with user inputs parsed according to the schema. - """ - user_input = {} - - for key, value_type in schema.items(): - prompt_text = f"Enter value for '{key}' ({value_type}): " - user_value = Prompt.ask(prompt_text) - - try: - # Parse the input value using json.loads() - parsed_value = json.loads(user_value) - - # Check if the parsed value matches the expected type - if isinstance(parsed_value, eval(value_type)): - user_input[key] = parsed_value - else: - rprint( - f"[bold red]Error:[/bold red] Input for '{key}' does not match the expected type {value_type}." - ) - return get_user_input_for_schema(schema) # Recursive call to retry - - except json.JSONDecodeError: - rprint( - f"[bold red]Error:[/bold red] Invalid JSON input for '{key}'. Please try again." - ) - return get_user_input_for_schema(schema) # Recursive call to retry - - return user_input - - -class InvalidOutputError(Exception): - """ - Custom exception raised when the LLM output is invalid or cannot be parsed. - - Attributes: - message (str): Explanation of the error. 
- output (str): The invalid output that caused the exception. - expected_schema (Dict[str, Any]): The expected schema for the output. - messages (List[Dict[str, str]]): The messages sent to the LLM. - tools (Optional[List[Dict[str, str]]]): The tool calls generated by the LLM. - """ - - def __init__( - self, - message: str, - output: str, - expected_schema: Dict[str, Any], - messages: List[Dict[str, str]], - tools: Optional[List[Dict[str, str]]] = None, - ): - self.message = message - self.output = output - self.expected_schema = expected_schema - self.messages = messages - self.tools = tools - super().__init__(self.message) - - def __str__(self): - return ( - f"{self.message}\n" - f"Invalid output: {self.output}\n" - f"Expected schema: {self.expected_schema}\n" - f"Messages sent to LLM: {self.messages}\n" - f"Tool calls generated by LLM: {self.tools}" - ) - - -def timeout(seconds): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - result = [TimeoutError("Function call timed out")] - - def target(): - try: - result[0] = func(*args, **kwargs) - except Exception as e: - result[0] = e - - thread = threading.Thread(target=target) - thread.start() - thread.join(seconds) - if isinstance(result[0], Exception): - raise result[0] - return result[0] - - return wrapper - - return decorator - - -def truncate_messages( - messages: List[Dict[str, str]], model: str, from_agent: bool = False -) -> List[Dict[str, str]]: - """ - Truncate the messages to fit the model's context length. - """ - model_input_context_length = model_cost.get(model.split("/")[-1], {}).get( - "max_input_tokens", 8192 - ) - total_tokens = sum(count_tokens(json.dumps(msg), model) for msg in messages) - - if total_tokens <= model_input_context_length - 100: - return messages - - truncated_messages = messages.copy() - longest_message = max(truncated_messages, key=lambda x: len(x["content"])) - content = longest_message["content"] - excess_tokens = total_tokens - model_input_context_length + 200 # 200 token buffer - - try: - encoder = tiktoken.encoding_for_model(model.split("/")[-1]) - except Exception: - encoder = tiktoken.encoding_for_model("gpt-4o") - encoded_content = encoder.encode(content) - tokens_to_remove = min(len(encoded_content), excess_tokens) - mid_point = len(encoded_content) // 2 - truncated_encoded = ( - encoded_content[: mid_point - tokens_to_remove // 2] - + encoder.encode(f" ... [{tokens_to_remove} tokens truncated] ... ") - + encoded_content[mid_point + tokens_to_remove // 2 :] - ) - truncated_content = encoder.decode(truncated_encoded) - # Calculate the total number of tokens in the original content - total_tokens = len(encoded_content) - - # Print the warning message using rprint - warning_type = "User" if not from_agent else "Agent" - rprint( - f"[yellow]{warning_type} Warning:[/yellow] Cutting {tokens_to_remove} tokens from a prompt with {total_tokens} tokens..." - ) - - longest_message["content"] = truncated_content - - return truncated_messages - - -def safe_eval(expression: str, output: Dict) -> bool: - """ - Safely evaluate an expression with a given output dictionary. - Uses asteval to evaluate the expression. 
- https://lmfit.github.io/asteval/index.html - """ - try: - # Add the output dictionary to the symbol table - aeval.symtable["output"] = output - # Safely evaluate the expression - return bool(aeval(expression)) - except Exception: - # try to evaluate with python eval - try: - return bool(eval(expression, locals={"output": output})) - except Exception: - return False +from docetl.utils import completion_cost +from docetl.helper.cache import cache_key +from docetl.helper.generic import freezeargs, cache, LLMResult, convert_dict_schema_to_list_schema, truncate_messages, \ + timeout, convert_val, InvalidOutputError, get_user_input_for_schema, safe_eval class APIWrapper(object): @@ -423,7 +60,7 @@ def gen_embedding(self, model: str, input: List[str]) -> List[float]: c.set(key, result) return result - + def call_llm_batch( self, model: str, @@ -437,10 +74,10 @@ def call_llm_batch( ) -> LLMResult: # Turn the output schema into a list of schemas output_schema = convert_dict_schema_to_list_schema(output_schema) - + # Invoke the LLM call return self.call_llm(model, op_type,messages, output_schema, verbose=verbose, timeout_seconds=timeout_seconds, max_retries_per_timeout=max_retries_per_timeout, bypass_cache=bypass_cache) - + def _cached_call_llm( self, @@ -708,7 +345,7 @@ def call_llm( initial_result=initial_result, ) except RateLimitError: - # TODO: this is a really hacky way to handle rate limits + # TODO: this is a really hacky way to handle rate limits # we should implement a more robust retry mechanism backoff_time = 4 * (2**rate_limited_attempt) # Exponential backoff max_backoff = 120 # Maximum backoff time of 60 seconds @@ -1026,136 +663,3 @@ def validate_output(self, operation: Dict, output: Dict, console: Console) -> bo console.log(f"[yellow]Output:[/yellow] {output}") return False return True - - -class RichLoopBar: - """ - A progress bar class that integrates with Rich console. - - This class provides a wrapper around tqdm to create progress bars that work - with Rich console output. - - Args: - iterable (Optional[Union[Iterable, range]]): An iterable to track progress. - total (Optional[int]): The total number of iterations. - desc (Optional[str]): Description to be displayed alongside the progress bar. - leave (bool): Whether to leave the progress bar on screen after completion. - console: The Rich console object to use for output. - """ - - def __init__( - self, - iterable: Optional[Union[Iterable, range]] = None, - total: Optional[int] = None, - desc: Optional[str] = None, - leave: bool = True, - console=None, - ): - if console is None: - raise ValueError("Console must be provided") - self.console = console - self.iterable = iterable - self.total = self._get_total(iterable, total) - self.description = desc - self.leave = leave - self.tqdm = None - - def _get_total(self, iterable, total): - """ - Determine the total number of iterations for the progress bar. - - Args: - iterable: The iterable to be processed. - total: The explicitly specified total, if any. - - Returns: - int or None: The total number of iterations, or None if it can't be determined. - """ - if total is not None: - return total - if isinstance(iterable, range): - return len(iterable) - try: - return len(iterable) - except TypeError: - return None - - def __iter__(self): - """ - Create and return an iterator with a progress bar. - - Returns: - Iterator: An iterator that yields items from the wrapped iterable. 
-        """
-        self.tqdm = tqdm(
-            self.iterable,
-            total=self.total,
-            desc=self.description,
-            file=self.console.file,
-        )
-        for item in self.tqdm:
-            yield item
-
-    def __enter__(self):
-        """
-        Enter the context manager, initializing the progress bar.
-
-        Returns:
-            RichLoopBar: The RichLoopBar instance.
-        """
-        self.tqdm = tqdm(
-            total=self.total,
-            desc=self.description,
-            leave=self.leave,
-            file=self.console.file,
-        )
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        """
-        Exit the context manager, closing the progress bar.
-
-        Args:
-            exc_type: The type of the exception that caused the context to be exited.
-            exc_val: The instance of the exception that caused the context to be exited.
-            exc_tb: A traceback object encoding the stack trace.
-        """
-        self.tqdm.close()
-
-    def update(self, n=1):
-        """
-        Update the progress bar.
-
-        Args:
-            n (int): The number of iterations to increment the progress bar by.
-        """
-        if self.tqdm:
-            self.tqdm.update(n)
-
-
-def rich_as_completed(futures, total=None, desc=None, leave=True, console=None):
-    """
-    Yield completed futures with a Rich progress bar.
-
-    This function wraps concurrent.futures.as_completed with a Rich progress bar.
-
-    Args:
-        futures: An iterable of Future objects to monitor.
-        total (Optional[int]): The total number of futures.
-        desc (Optional[str]): Description for the progress bar.
-        leave (bool): Whether to leave the progress bar on screen after completion.
-        console: The Rich console object to use for output.
-
-    Yields:
-        Future: Completed future objects.
-
-    Raises:
-        ValueError: If no console object is provided.
-    """
-    if console is None:
-        raise ValueError("Console must be provided")
-
-    with RichLoopBar(total=total, desc=desc, leave=leave, console=console) as pbar:
-        for future in as_completed(futures):
-            yield future
-            pbar.update()
diff --git a/docetl/helper/cache.py b/docetl/helper/cache.py
new file mode 100644
index 00000000..da3594f7
--- /dev/null
+++ b/docetl/helper/cache.py
@@ -0,0 +1,88 @@
+import hashlib
+import json
+import os
+import shutil
+from typing import List, Dict, Optional
+
+from rich.console import Console
+
+from docetl.console import DOCETL_CONSOLE
+from docetl.helper.generic import cache, CACHE_DIR
+
+
+def flush_cache(console: Console = DOCETL_CONSOLE):
+    """
+    Flush the cache to disk.
+    """
+    console.log("[bold green]Flushing cache to disk...[/bold green]")
+    cache.close()
+    console.log("[bold green]Cache flushed to disk.[/bold green]")
+
+
+def clear_cache(console: Console = DOCETL_CONSOLE):
+    """
+    Clear the LLM cache stored on disk.
+
+    This function removes all cached items from the disk-based cache,
+    effectively clearing the LLM's response history.
+
+    Args:
+        console (Console, optional): A Rich console object for logging.
+            Defaults to the shared DOCETL_CONSOLE.
+    """
+    console.log("[bold yellow]Clearing LLM cache...[/bold yellow]")
+    try:
+        with cache as c:
+            c.clear()
+        # Remove all files in the cache directory
+        cache_dir = CACHE_DIR
+        if not os.path.exists(cache_dir):
+            os.makedirs(cache_dir)
+        for filename in os.listdir(cache_dir):
+            file_path = os.path.join(cache_dir, filename)
+            try:
+                if os.path.isfile(file_path):
+                    os.unlink(file_path)
+                elif os.path.isdir(file_path):
+                    shutil.rmtree(file_path)
+            except Exception as e:
+                console.log(
+                    f"[bold red]Error deleting {file_path}: {str(e)}[/bold red]"
+                )
+        console.log("[bold green]Cache cleared successfully.[/bold green]")
+    except Exception as e:
+        console.log(f"[bold red]Error clearing cache: {str(e)}[/bold red]")
+
+
+def cache_key(
+    model: str,
+    op_type: str,
+    messages: List[Dict[str, str]],
+    output_schema: Dict[str, str],
+    scratchpad: Optional[str] = None,
+) -> str:
+    """
+    Generate a unique cache key based on function arguments.
+
+    This function creates a hash-based key using the input parameters, which can
+    be used for caching purposes.
+
+    Args:
+        model (str): The model name.
+        op_type (str): The operation type.
+        messages (List[Dict[str, str]]): The messages to send to the LLM.
+        output_schema (Dict[str, str]): The output schema dictionary.
+        scratchpad (Optional[str]): The scratchpad to use for the operation.
+
+    Returns:
+        str: A unique hash string representing the cache key.
+    """
+    # Ensure no non-serializable objects are included
+    key_dict = {
+        "model": model,
+        "op_type": op_type,
+        "messages": json.dumps(messages, sort_keys=True),
+        "output_schema": json.dumps(output_schema, sort_keys=True),
+        "scratchpad": scratchpad,
+    }
+    return hashlib.md5(json.dumps(key_dict, sort_keys=True).encode()).hexdigest()
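`cache_key` is deliberately deterministic: the arguments are JSON-serialized with sorted keys before hashing, so semantically identical calls map to the same cache entry. A small usage sketch (the argument values are illustrative):

```python
# Sketch: deriving a deterministic cache key for an LLM call.
from docetl.helper.cache import cache_key

key = cache_key(
    model="gpt-4o-mini",
    op_type="map",
    messages=[{"role": "user", "content": "Summarize: ..."}],
    output_schema={"summary": "str"},
)
# Identical inputs always hash to the same 32-char hex digest, so repeated
# calls can be served from the on-disk diskcache.
print(key)
```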
diff --git a/docetl/helper/database.py b/docetl/helper/database.py
new file mode 100644
index 00000000..a07ebaad
--- /dev/null
+++ b/docetl/helper/database.py
@@ -0,0 +1,140 @@
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Tuple
+
+from rich.console import Console
+
+
+class DatabaseUtil(ABC):
+    """
+    A utility base class for connecting to databases.
+    """
+
+    def __init__(self, host: str, port: int, username: str, password: str, database: str,
+                 console: Optional[Console] = None):
+        self.host = host
+        self.port = port
+        self.username = username
+        self.password = password
+        self.database = database
+        self.console = console
+        self.client = self.new_connection()
+
+    def __enter__(self):
+        self.console.log("[bold green]Connected to database successfully.[/bold green]")
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.console.log("[bold green]Closing connection to database.[/bold green]")
+        self.close_connection()
+
+    @abstractmethod
+    def new_connection(self) -> Any:
+        """
+        Connect to a database using the connection parameters stored on the
+        instance and return a connection object.
+
+        Returns:
+            Any: A connection object for interacting with the database.
+        """
+        pass
+
+    @abstractmethod
+    def close_connection(self):
+        """
+        Close the connection to the database.
+        """
+        pass
+
+    @abstractmethod
+    def get_query(self, query: str) -> Any:
+        """
+        Execute a read-only query and return the fetched results.
+
+        Args:
+            query (str): The SQL query to execute.
+
+        Returns:
+            Any: The rows returned by the query.
+        """
+        pass
+
+    @abstractmethod
+    def execute_query(self, query: str, columns: Optional[Tuple[Any, ...]] = None) -> Any:
+        """
+        Execute a write query against the database.
+
+        Args:
+            query (str): The SQL query to execute.
+            columns (Tuple, optional): The parameter values to bind. Defaults to None.
+
+        Returns:
+            Any: The result of the query.
+        """
+        pass
+
+    @abstractmethod
+    def execute_transaction(self, queries: list):
+        """
+        Execute a list of queries in a single transaction.
+
+        Args:
+            queries: list of queries to execute
+        """
+        pass
+
+    @abstractmethod
+    def log_to_db(self, table_name: str, schema: dict, log_data: dict):
+        """
+        Log data to a database.
+
+        :param table_name: the table to write to
+        :param schema: the table schema (column name -> SQL type)
+        :param log_data: a mapping from column name to value; every key must
+            be a column declared in `schema`
+        """
+        pass
+
+    @abstractmethod
+    def create_log_table(self, table_name: str, schema: dict) -> Any:
+        """
+        Create a log table in a database.
+
+        :param table_name: the table to create
+        :param schema: the table schema (column name -> SQL type)
+        """
+        pass
+
+    # Default log schema: exposes the default table name and column definitions.
+    class DEFAULT_LOG_SCHEMA:
+        def __init__(self, process_id: str = "TEXT", operation: str = "TEXT", log_message: str = "TEXT",
+                     table_name: str = "docETL_log"):
+            self.process_id: str = process_id
+            self.operation: str = operation
+            self.log_message: str = log_message
+            self.table_name: str = table_name
+
+        @property
+        def get_default_table_name(self) -> str:
+            """
+            Get the default log table name.
+            :return: table name
+            """
+            return "docETL_log"
+
+        @property
+        def get_default_log_schema(self) -> dict:
+            """
+            Get the default log table schema.
+            :return: column name -> SQL type mapping
+            """
+            return {
+                "id": "INTEGER PRIMARY KEY AUTOINCREMENT",
+                "process_id": "TEXT",
+                "operation": "TEXT",
+                "log_message": "TEXT",
+                "timestamp": "DATETIME DEFAULT CURRENT_TIMESTAMP"
+            }
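Since `DatabaseUtil` is now a proper ABC, adding another backend means implementing its seven abstract methods. A hypothetical skeleton for orientation (`PostgresUtil` is not part of this PR; bodies are elided to show only the contract):

```python
# Sketch: the contract a new DatabaseUtil backend must satisfy.
from typing import Any, Optional, Tuple

from docetl.helper.database import DatabaseUtil


class PostgresUtil(DatabaseUtil):  # hypothetical backend, for illustration only
    def new_connection(self) -> Any: ...
    def close_connection(self): ...
    def get_query(self, query: str) -> Any: ...
    def execute_query(self, query: str, columns: Optional[Tuple[Any, ...]] = None) -> Any: ...
    def execute_transaction(self, queries: list): ...
    def log_to_db(self, table_name: str, schema: dict, log_data: dict): ...
    def create_log_table(self, table_name: str, schema: dict) -> Any: ...
```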
diff --git a/docetl/helper/generic.py b/docetl/helper/generic.py
new file mode 100644
index 00000000..587d9c63
--- /dev/null
+++ b/docetl/helper/generic.py
@@ -0,0 +1,292 @@
+import functools
+import json
+import os
+import threading
+from typing import Any, Dict, List, Optional
+
+import tiktoken
+from asteval import Interpreter
+from diskcache import Cache
+from dotenv import load_dotenv
+from frozendict import frozendict
+from litellm import model_cost
+from rich import print as rprint
+from rich.prompt import Prompt
+from pydantic import BaseModel
+
+from docetl.utils import count_tokens
+
+aeval = Interpreter()
+
+load_dotenv()
+# litellm.set_verbose = True
+DOCETL_HOME_DIR = os.path.expanduser("~/.docetl")
+
+CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "cache")
+LLM_CACHE_DIR = os.path.join(DOCETL_HOME_DIR, "llm_cache")
+cache = Cache(LLM_CACHE_DIR)
+cache.close()
+
+
+class LLMResult(BaseModel):
+    response: Any
+    total_cost: float
+    validated: bool
+
+
+def freezeargs(func):
+    """
+    Decorator to convert mutable dictionary arguments into immutable.
+
+    This decorator is useful for making functions compatible with caching mechanisms
+    that require immutable arguments.
+
+    Args:
+        func (callable): The function to be wrapped.
+
+    Returns:
+        callable: The wrapped function with immutable dictionary arguments.
+ """ + + @functools.wraps(func) + def wrapped(*args, **kwargs): + args = tuple( + ( + frozendict(arg) + if isinstance(arg, dict) + else json.dumps(arg) if isinstance(arg, list) else arg + ) + for arg in args + ) + kwargs = { + k: ( + frozendict(v) + if isinstance(v, dict) + else json.dumps(v) if isinstance(v, list) else v + ) + for k, v in kwargs.items() + } + return func(*args, **kwargs) + + return wrapped + + +def convert_dict_schema_to_list_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + schema_str = "{" + ", ".join([f"{k}: {v}" for k, v in schema.items()]) + "}" + return {"results": f"list[{schema_str}]"} + +def convert_val(value: Any, model: str = "gpt-4o-mini") -> Dict[str, Any]: + """ + Convert a string representation of a type to a dictionary representation. + + This function takes a string value representing a data type and converts it + into a dictionary format suitable for JSON schema. + + Args: + value (Any): A string representing a data type. + model (str): The model being used. Defaults to "gpt-4o-mini". + + Returns: + Dict[str, Any]: A dictionary representing the type in JSON schema format. + + Raises: + ValueError: If the input value is not a supported type or is improperly formatted. + """ + value = value.strip().lower() + if value in ["str", "text", "string", "varchar"]: + return {"type": "string"} + elif value in ["int", "integer"]: + return {"type": "integer"} + elif value in ["float", "decimal", "number"]: + return {"type": "number"} + elif value in ["bool", "boolean"]: + return {"type": "boolean"} + elif value.startswith("list["): + inner_type = value[5:-1].strip() + return {"type": "array", "items": convert_val(inner_type, model)} + elif value == "list": + raise ValueError("List type must specify its elements, e.g., 'list[str]'") + elif value.startswith("{") and value.endswith("}"): + # Handle dictionary type + properties = {} + for item in value[1:-1].split(","): + key, val = item.strip().split(":") + properties[key.strip()] = convert_val(val.strip(), model) + result = { + "type": "object", + "properties": properties, + "required": list(properties.keys()), + } + # TODO: this is a hack to get around the fact that gemini doesn't support additionalProperties + if "gemini" not in model: + result["additionalProperties"] = False + return result + else: + raise ValueError(f"Unsupported value type: {value}") + + +def get_user_input_for_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + """ + Prompt the user for input for each key in the schema using Rich, + then parse the input values with json.loads(). + + Args: + schema (Dict[str, Any]): The schema dictionary. + + Returns: + Dict[str, Any]: A dictionary with user inputs parsed according to the schema. + """ + user_input = {} + + for key, value_type in schema.items(): + prompt_text = f"Enter value for '{key}' ({value_type}): " + user_value = Prompt.ask(prompt_text) + + try: + # Parse the input value using json.loads() + parsed_value = json.loads(user_value) + + # Check if the parsed value matches the expected type + if isinstance(parsed_value, eval(value_type)): + user_input[key] = parsed_value + else: + rprint( + f"[bold red]Error:[/bold red] Input for '{key}' does not match the expected type {value_type}." + ) + return get_user_input_for_schema(schema) # Recursive call to retry + + except json.JSONDecodeError: + rprint( + f"[bold red]Error:[/bold red] Invalid JSON input for '{key}'. Please try again." 
+ ) + return get_user_input_for_schema(schema) # Recursive call to retry + + return user_input + + +class InvalidOutputError(Exception): + """ + Custom exception raised when the LLM output is invalid or cannot be parsed. + + Attributes: + message (str): Explanation of the error. + output (str): The invalid output that caused the exception. + expected_schema (Dict[str, Any]): The expected schema for the output. + messages (List[Dict[str, str]]): The messages sent to the LLM. + tools (Optional[List[Dict[str, str]]]): The tool calls generated by the LLM. + """ + + def __init__( + self, + message: str, + output: str, + expected_schema: Dict[str, Any], + messages: List[Dict[str, str]], + tools: Optional[List[Dict[str, str]]] = None, + ): + self.message = message + self.output = output + self.expected_schema = expected_schema + self.messages = messages + self.tools = tools + super().__init__(self.message) + + def __str__(self): + return ( + f"{self.message}\n" + f"Invalid output: {self.output}\n" + f"Expected schema: {self.expected_schema}\n" + f"Messages sent to LLM: {self.messages}\n" + f"Tool calls generated by LLM: {self.tools}" + ) + + +def timeout(seconds): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result = [TimeoutError("Function call timed out")] + + def target(): + try: + result[0] = func(*args, **kwargs) + except Exception as e: + result[0] = e + + thread = threading.Thread(target=target) + thread.start() + thread.join(seconds) + if isinstance(result[0], Exception): + raise result[0] + return result[0] + + return wrapper + + return decorator + + +def truncate_messages( + messages: List[Dict[str, str]], model: str, from_agent: bool = False +) -> List[Dict[str, str]]: + """ + Truncate the messages to fit the model's context length. + """ + model_input_context_length = model_cost.get(model.split("/")[-1], {}).get( + "max_input_tokens", 8192 + ) + total_tokens = sum(count_tokens(json.dumps(msg), model) for msg in messages) + + if total_tokens <= model_input_context_length - 100: + return messages + + truncated_messages = messages.copy() + longest_message = max(truncated_messages, key=lambda x: len(x["content"])) + content = longest_message["content"] + excess_tokens = total_tokens - model_input_context_length + 200 # 200 token buffer + + try: + encoder = tiktoken.encoding_for_model(model.split("/")[-1]) + except Exception: + encoder = tiktoken.encoding_for_model("gpt-4o") + encoded_content = encoder.encode(content) + tokens_to_remove = min(len(encoded_content), excess_tokens) + mid_point = len(encoded_content) // 2 + truncated_encoded = ( + encoded_content[: mid_point - tokens_to_remove // 2] + + encoder.encode(f" ... [{tokens_to_remove} tokens truncated] ... ") + + encoded_content[mid_point + tokens_to_remove // 2 :] + ) + truncated_content = encoder.decode(truncated_encoded) + # Calculate the total number of tokens in the original content + total_tokens = len(encoded_content) + + # Print the warning message using rprint + warning_type = "User" if not from_agent else "Agent" + rprint( + f"[yellow]{warning_type} Warning:[/yellow] Cutting {tokens_to_remove} tokens from a prompt with {total_tokens} tokens..." + ) + + longest_message["content"] = truncated_content + + return truncated_messages + + +def safe_eval(expression: str, output: Dict) -> bool: + """ + Safely evaluate an expression with a given output dictionary. + Uses asteval to evaluate the expression. 
+ https://lmfit.github.io/asteval/index.html + """ + try: + # Add the output dictionary to the symbol table + aeval.symtable["output"] = output + # Safely evaluate the expression + return bool(aeval(expression)) + except Exception: + # try to evaluate with python eval + try: + return bool(eval(expression, locals={"output": output})) + except Exception: + return False + + diff --git a/docetl/helper/progress_bar.py b/docetl/helper/progress_bar.py new file mode 100644 index 00000000..beeef2b1 --- /dev/null +++ b/docetl/helper/progress_bar.py @@ -0,0 +1,137 @@ +from concurrent.futures import as_completed +from typing import Optional, Union, Iterable + +from tqdm import tqdm + + +class RichLoopBar: + """ + A progress bar class that integrates with Rich console. + + This class provides a wrapper around tqdm to create progress bars that work + with Rich console output. + + Args: + iterable (Optional[Union[Iterable, range]]): An iterable to track progress. + total (Optional[int]): The total number of iterations. + desc (Optional[str]): Description to be displayed alongside the progress bar. + leave (bool): Whether to leave the progress bar on screen after completion. + console: The Rich console object to use for output. + """ + + def __init__( + self, + iterable: Optional[Union[Iterable, range]] = None, + total: Optional[int] = None, + desc: Optional[str] = None, + leave: bool = True, + console=None, + ): + if console is None: + raise ValueError("Console must be provided") + self.console = console + self.iterable = iterable + self.total = self._get_total(iterable, total) + self.description = desc + self.leave = leave + self.tqdm = None + + def _get_total(self, iterable, total): + """ + Determine the total number of iterations for the progress bar. + + Args: + iterable: The iterable to be processed. + total: The explicitly specified total, if any. + + Returns: + int or None: The total number of iterations, or None if it can't be determined. + """ + if total is not None: + return total + if isinstance(iterable, range): + return len(iterable) + try: + return len(iterable) + except TypeError: + return None + + def __iter__(self): + """ + Create and return an iterator with a progress bar. + + Returns: + Iterator: An iterator that yields items from the wrapped iterable. + """ + self.tqdm = tqdm( + self.iterable, + total=self.total, + desc=self.description, + file=self.console.file, + ) + for item in self.tqdm: + yield item + + def __enter__(self): + """ + Enter the context manager, initializing the progress bar. + + Returns: + RichLoopBar: The RichLoopBar instance. + """ + self.tqdm = tqdm( + total=self.total, + desc=self.description, + leave=self.leave, + file=self.console.file, + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exit the context manager, closing the progress bar. + + Args: + exc_type: The type of the exception that caused the context to be exited. + exc_val: The instance of the exception that caused the context to be exited. + exc_tb: A traceback object encoding the stack trace. + """ + self.tqdm.close() + + def update(self, n=1): + """ + Update the progress bar. + + Args: + n (int): The number of iterations to increment the progress bar by. + """ + if self.tqdm: + self.tqdm.update(n) + + +def rich_as_completed(futures, total=None, desc=None, leave=True, console=None): + """ + Yield completed futures with a Rich progress bar. + + This function wraps concurrent.futures.as_completed with a Rich progress bar. 
+
+    Args:
+        futures: An iterable of Future objects to monitor.
+        total (Optional[int]): The total number of futures.
+        desc (Optional[str]): Description for the progress bar.
+        leave (bool): Whether to leave the progress bar on screen after completion.
+        console: The Rich console object to use for output.
+
+    Yields:
+        Future: Completed future objects.
+
+    Raises:
+        ValueError: If no console object is provided.
+    """
+    if console is None:
+        raise ValueError("Console must be provided")
+
+    with RichLoopBar(total=total, desc=desc, leave=leave, console=console) as pbar:
+        for future in as_completed(futures):
+            yield future
+            pbar.update()
diff --git a/docetl/helper/sqlite.py b/docetl/helper/sqlite.py
new file mode 100644
index 00000000..48acaae7
--- /dev/null
+++ b/docetl/helper/sqlite.py
@@ -0,0 +1,123 @@
+import sqlite3
+from typing import Optional, Tuple, Any
+
+from rich.console import Console
+from typing_extensions import override
+
+from docetl.helper.database import DatabaseUtil
+
+
+class SqliteUtil(DatabaseUtil):
+    def __init__(self, db_path):
+        self.db_path = db_path
+        self.conn = None
+        self.cursor = None
+        self.console = Console()
+
+    @override
+    def new_connection(self):
+        if self.conn is None:
+            try:
+                self.conn = sqlite3.connect(self.db_path)
+                self.cursor = self.conn.cursor()
+                self.console.log("[bold green]Connected to SQLite database successfully.[/bold green]")
+            except Exception as e:
+                self.console.log(f"[bold red]Error connecting to SQLite database: {str(e)}[/bold red]")
+                raise ConnectionError(f"Error connecting to SQLite database: {str(e)}")
+
+        return self.conn
+
+    @override
+    def close_connection(self):
+        if self.conn is not None:
+            try:
+                self.conn.close()
+                self.conn = None
+                self.cursor = None
+                self.console.log("[bold green]Disconnected from SQLite database successfully.[/bold green]")
+            except Exception as e:
+                self.console.log(f"[bold red]Error disconnecting from SQLite database: {str(e)}[/bold red]")
+                raise ConnectionError(f"Error disconnecting from SQLite database: {str(e)}")
+
+    @override
+    def get_query(self, query: str) -> Any:
+        if self.conn is None:
+            raise ConnectionError("No connection to database")
+        try:
+            self.console.log(f"Executing query: {query}")
+            self.cursor.execute(query)
+            return self.cursor.fetchall()
+        except Exception as e:
+            raise ValueError(f"Error executing query: {str(e)}")
+
+    @override
+    def execute_query(self, query: str, columns: Optional[Tuple[Any, ...]] = None) -> Any:
+        if self.conn is None:
+            raise ConnectionError("No connection to database")
+        try:
+            self.console.log(f"Executing query: {query}")
+            self.cursor.execute(query, columns or ())
+            return self.conn.commit()
+        except Exception as e:
+            raise ValueError(f"Error executing query: {str(e)}")
+
+    @override
+    def execute_transaction(self, queries: list):
+        if self.conn is None:
+            raise ConnectionError("No connection to database")
+        try:
+            for query in queries:
+                self.cursor.execute(query)
+            self.conn.commit()
+            self.console.log("[bold green]Transaction executed successfully.[/bold green]")
+        except Exception as e:
+            self.conn.rollback()
+            self.console.log(f"[bold red]Error executing transaction: {str(e)}[/bold red]")
+            raise ValueError(f"Error executing transaction: {str(e)}")
+
+    @override
+    def log_to_db(self, table_name: str, schema: dict, log_data: dict):
+        """
+        Log data to a SQLite database.
+
+        Args:
+            table_name (str): The name of the table to log to.
+            schema (dict): The schema of the table.
+            log_data (dict): A mapping from column name to value; every key
+                must be a column declared in `schema`.
+        """
+        for key in log_data.keys():
+            if key not in schema:
+                raise ValueError(f"Column '{key}' is not defined in the schema.")
+
+        columns = ', '.join(log_data.keys())  # Column names to insert into
+        placeholders = ', '.join('?' * len(log_data))  # One placeholder per value
+        insert_sql = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders});"
+
+        return self.execute_query(insert_sql, tuple(log_data.values()))
+
+    @override
+    def create_log_table(self, table_name: str, schema: dict) -> Any:
+        """
+        Create a log table in a SQLite database if it does not already exist.
+
+        Args:
+            table_name (str): The name of the table to create.
+            schema (dict): The schema of the table.
+        """
+        if not schema:
+            schema = DatabaseUtil.DEFAULT_LOG_SCHEMA().get_default_log_schema
+
+        if not table_name:
+            table_name = DatabaseUtil.DEFAULT_LOG_SCHEMA().get_default_table_name
+
+        columns = ", ".join([f"{column} {datatype}" for column, datatype in schema.items()])
+
+        create_table_sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns});"
+        self.execute_query(query=create_table_sql)
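A quick end-to-end sketch of the SQLite backend as reworked above; `run_logs.db` and the row values are illustrative, not part of the diff:

```python
# Sketch: direct use of SqliteUtil without going through the console.
from docetl.helper.database import DatabaseUtil
from docetl.helper.sqlite import SqliteUtil

defaults = DatabaseUtil.DEFAULT_LOG_SCHEMA()
db = SqliteUtil(db_path="run_logs.db")
db.new_connection()

# Create the default log table, write one row, then read it back.
db.create_log_table(defaults.get_default_table_name, defaults.get_default_log_schema)
db.log_to_db(
    table_name=defaults.get_default_table_name,
    schema=defaults.get_default_log_schema,
    log_data={"process_id": "1234", "operation": "map", "log_message": "ok"},
)
print(db.get_query(f"SELECT * FROM {defaults.get_default_table_name}"))
db.close_connection()
```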
diff --git a/docetl/operations/base.py b/docetl/operations/base.py
index 88377077..b0b59895 100644
--- a/docetl/operations/base.py
+++ b/docetl/operations/base.py
@@ -5,12 +5,12 @@
 from abc import ABC, ABCMeta, abstractmethod
 from typing import Dict, List, Optional, Tuple
 
-from docetl.operations.utils import APIWrapper
-from docetl.console import DOCETL_CONSOLE
-from rich.console import Console
-from rich.status import Status
 import jsonschema
 from pydantic import BaseModel
+from rich.console import Console
+from rich.status import Status
+
+from docetl.console import DOCETL_CONSOLE
 
 
 # FIXME: This should probably live in some utils module?
@@ -143,3 +143,4 @@ def gleaning_check(self) -> None: raise ValueError( "'validation_prompt' in 'gleaning' configuration cannot be empty" ) + diff --git a/docetl/operations/cluster.py b/docetl/operations/cluster.py index 033e3bf4..cf2243cc 100644 --- a/docetl/operations/cluster.py +++ b/docetl/operations/cluster.py @@ -1,9 +1,9 @@ import numpy as np -from jinja2 import Environment, Template +from jinja2 import Template from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple from .base import BaseOperation -from .utils import RichLoopBar +from ..helper.progress_bar import RichLoopBar from .clustering_utils import get_embeddings_for_clustering diff --git a/docetl/operations/clustering_utils.py b/docetl/operations/clustering_utils.py index 7663b892..ec09801e 100644 --- a/docetl/operations/clustering_utils.py +++ b/docetl/operations/clustering_utils.py @@ -6,7 +6,7 @@ from typing import Dict, List, Tuple -from docetl.operations.utils import APIWrapper +from docetl.helper.api_wrapper import APIWrapper from docetl.utils import completion_cost diff --git a/docetl/operations/code_operations.py b/docetl/operations/code_operations.py index 09a62c9a..dc597f33 100644 --- a/docetl/operations/code_operations.py +++ b/docetl/operations/code_operations.py @@ -1,7 +1,8 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple from concurrent.futures import ThreadPoolExecutor from docetl.operations.base import BaseOperation -from docetl.operations.utils import RichLoopBar +from docetl.helper.progress_bar import RichLoopBar + class CodeMapOperation(BaseOperation): class schema(BaseOperation.schema): diff --git a/docetl/operations/equijoin.py b/docetl/operations/equijoin.py index 884f0113..c8ea8b94 100644 --- a/docetl/operations/equijoin.py +++ b/docetl/operations/equijoin.py @@ -15,9 +15,7 @@ from rich.prompt import Confirm from docetl.operations.base import BaseOperation -from docetl.operations.utils import ( - rich_as_completed, -) +from docetl.helper.progress_bar import rich_as_completed from docetl.utils import completion_cost # Global variables to store shared data diff --git a/docetl/operations/link_resolve.py b/docetl/operations/link_resolve.py index 2669ac66..47ac6ea9 100644 --- a/docetl/operations/link_resolve.py +++ b/docetl/operations/link_resolve.py @@ -1,17 +1,15 @@ import random -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Tuple -import jinja2 from jinja2 import Template from rich.prompt import Confirm from docetl.operations.base import BaseOperation -from docetl.operations.utils import RichLoopBar, rich_as_completed -from docetl.utils import completion_cost, extract_jinja_variables +from ..helper.progress_bar import RichLoopBar, rich_as_completed from .clustering_utils import get_embeddings_for_clustering from sklearn.metrics.pairwise import cosine_similarity -import numpy as np + class LinkResolveOperation(BaseOperation): def syntax_check(self) -> None: diff --git a/docetl/operations/map.py b/docetl/operations/map.py index 91c2916f..53ff2fd4 100644 --- a/docetl/operations/map.py +++ b/docetl/operations/map.py @@ -9,9 +9,8 @@ from tqdm import tqdm from docetl.operations.base import BaseOperation -from docetl.operations.utils import RichLoopBar +from docetl.helper.progress_bar import RichLoopBar from docetl.base_schemas import Tool, ToolFunction -from 
docetl.utils import completion_cost from pydantic import Field, field_validator from litellm.utils import ModelResponse diff --git a/docetl/operations/reduce.py b/docetl/operations/reduce.py index e1ed4d32..34d6ac2b 100644 --- a/docetl/operations/reduce.py +++ b/docetl/operations/reduce.py @@ -23,7 +23,7 @@ cluster_documents, get_embeddings_for_clustering, ) -from docetl.operations.utils import rich_as_completed +from docetl.helper.progress_bar import rich_as_completed from docetl.utils import completion_cost diff --git a/docetl/operations/resolve.py b/docetl/operations/resolve.py index 0f896261..1dc7cffd 100644 --- a/docetl/operations/resolve.py +++ b/docetl/operations/resolve.py @@ -3,17 +3,15 @@ """ import random -import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Dict, List, Tuple, Optional import jinja2 from jinja2 import Template from rich.prompt import Confirm -import math from docetl.operations.base import BaseOperation -from docetl.operations.utils import RichLoopBar, rich_as_completed +from docetl.helper.progress_bar import RichLoopBar, rich_as_completed from docetl.utils import completion_cost, extract_jinja_variables diff --git a/docetl/optimizers/reduce_optimizer.py b/docetl/optimizers/reduce_optimizer.py index feee3f8b..b9eca492 100644 --- a/docetl/optimizers/reduce_optimizer.py +++ b/docetl/optimizers/reduce_optimizer.py @@ -12,7 +12,7 @@ from rich.status import Status from docetl.operations.base import BaseOperation -from docetl.operations.utils import truncate_messages +from docetl.helper.generic import truncate_messages from docetl.optimizers.join_optimizer import JoinOptimizer from docetl.optimizers.utils import LLMClient from docetl.utils import count_tokens, extract_jinja_variables diff --git a/docetl/optimizers/utils.py b/docetl/optimizers/utils.py index 53520c8b..0d50673f 100644 --- a/docetl/optimizers/utils.py +++ b/docetl/optimizers/utils.py @@ -2,7 +2,7 @@ from litellm import completion, completion_cost -from docetl.operations.utils import truncate_messages +from docetl.helper.generic import truncate_messages from docetl.utils import completion_cost diff --git a/docetl/runner.py b/docetl/runner.py index 5acf84cf..4eef5438 100644 --- a/docetl/runner.py +++ b/docetl/runner.py @@ -6,17 +6,15 @@ import functools from typing import Any, Dict, List, Optional, Tuple, Union from docetl.builder import Optimizer -from docetl.console import get_console from pydantic import BaseModel from dotenv import load_dotenv import hashlib from rich.console import Console -from rich.prompt import Confirm from docetl.dataset import Dataset, create_parsing_tool_map from docetl.operations import get_operation, get_operations -from docetl.operations.utils import flush_cache +from docetl.helper.cache import flush_cache from docetl.config_wrapper import ConfigWrapper from . 
import schemas
 from .utils import classproperty
diff --git a/tests/basic/test_basic_filter_split_gather.py b/tests/basic/test_basic_filter_split_gather.py
index 1c4a3ca0..4db3e512 100644
--- a/tests/basic/test_basic_filter_split_gather.py
+++ b/tests/basic/test_basic_filter_split_gather.py
@@ -4,8 +4,6 @@
 from docetl.operations.equijoin import EquijoinOperation
 from docetl.operations.split import SplitOperation
 from docetl.operations.gather import GatherOperation
-from docetl.operations.utils import APIWrapper
-from docetl.config_wrapper import ConfigWrapper
 from dotenv import load_dotenv
 
 from tests.conftest import api_wrapper
diff --git a/tests/basic/test_pipeline_with_parsing.py b/tests/basic/test_pipeline_with_parsing.py
index 3487a5de..410654c5 100644
--- a/tests/basic/test_pipeline_with_parsing.py
+++ b/tests/basic/test_pipeline_with_parsing.py
@@ -4,7 +4,7 @@
 import os
 import tempfile
 from docetl.runner import DSLRunner
 from docetl.utils import load_config
 import yaml
 from docetl.api import (
     Pipeline,
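As a closing orientation note for this refactor: `convert_val` (now in `docetl.helper.generic`, moved with unchanged behavior) expands the shorthand type strings used in operation schemas into JSON-schema fragments, for example:

```python
# Sketch: convert_val turns shorthand type strings into JSON-schema fragments.
from docetl.helper.generic import convert_val

print(convert_val("list[str]"))
# {'type': 'array', 'items': {'type': 'string'}}

print(convert_val("{name: str, age: int}"))
# {'type': 'object',
#  'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}},
#  'required': ['name', 'age'],
#  'additionalProperties': False}
```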