diff --git a/README.md b/README.md index 190131a32..d61b86aae 100644 --- a/README.md +++ b/README.md @@ -233,6 +233,28 @@ uv pip install -e . uv run pytest # Run tests ``` +### Database Migrations (Alembic) + +Our Alembic setup lives under `api/alembic`. Use the API Conda/uv environment before running any commands. + +- **Autogenerate a migration** + + ```bash + cd api + alembic revision --autogenerate -m "describe change" + ``` + + The autogenerator respects the exclusions configured in `api/alembic/env.py` (e.g., exclusion of tables `workflows`, `workflow_runs`). Always review the generated file before committing. + +- **Run migrations locally for testing** + + ```bash + cd api + alembic upgrade head + ``` + + To roll back the most recent migration while iterating, run `alembic downgrade -1`. + ## License diff --git a/api/alembic.ini b/api/alembic.ini new file mode 100644 index 000000000..079181673 --- /dev/null +++ b/api/alembic.ini @@ -0,0 +1,147 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s/alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/api/alembic/README b/api/alembic/README new file mode 100644 index 000000000..98e4f9c44 --- /dev/null +++ b/api/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/api/alembic/env.py b/api/alembic/env.py new file mode 100644 index 000000000..ea63560ec --- /dev/null +++ b/api/alembic/env.py @@ -0,0 +1,94 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# Import all models to ensure they're registered with Base.metadata +from transformerlab.shared.models.models import Base + +# Override sqlalchemy.url from environment or use the one from constants +from transformerlab.db.constants import DATABASE_URL + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# Set target_metadata to Base.metadata for autogenerate support +target_metadata = Base.metadata + +EXCLUDED_TABLES = {"workflows", "workflow_runs"} + + +def include_object(object, name, type_, reflected, compare_to): + """Skip objects Alembic should not track.""" + if type_ == "table" and name in EXCLUDED_TABLES: + return False + return True + + +# Remove the sqlite+aiosqlite:// prefix and use sqlite:// for Alembic +# Alembic needs a synchronous connection URL (uses sqlite3, not aiosqlite) +sync_url = DATABASE_URL.replace("sqlite+aiosqlite:///", "sqlite:///") +config.set_main_option("sqlalchemy.url", sync_url) + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + include_object=include_object, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + include_object=include_object, + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/api/alembic/script.py.mako b/api/alembic/script.py.mako new file mode 100644 index 000000000..11016301e --- /dev/null +++ b/api/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/api/alembic/versions/f7661070ec23_initial_migration_create_all_tables.py b/api/alembic/versions/f7661070ec23_initial_migration_create_all_tables.py new file mode 100644 index 000000000..782f0de53 --- /dev/null +++ b/api/alembic/versions/f7661070ec23_initial_migration_create_all_tables.py @@ -0,0 +1,244 @@ +"""Initial migration - create all tables + +Revision ID: f7661070ec23 +Revises: +Create Date: 2025-11-21 15:04:59.420186 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision: str = "f7661070ec23" +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create all initial tables.""" + connection = op.get_bind() + + # Helper function to check if table exists + def table_exists(table_name: str) -> bool: + result = connection.execute( + sa.text("SELECT name FROM sqlite_master WHERE type='table' AND name=:name"), {"name": table_name} + ) + return result.fetchone() is not None + + # Config table + if not table_exists("config"): + op.create_table( + "config", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("key", sa.String(), nullable=False), + sa.Column("value", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("key"), + ) + op.create_index(op.f("ix_config_key"), "config", ["key"], unique=True) + + # Plugin table + if table_exists("plugins"): + # Drop all indexes on the table + op.drop_index(op.f("ix_plugins_name"), table_name="plugins") + op.drop_index(op.f("ix_plugins_type"), table_name="plugins") + # Drop the table + op.drop_table("plugins") + # Create the table again + # op.create_table( + # "plugins", + # sa.Column("id", sa.Integer(), nullable=False), + # sa.Column("name", sa.String(), nullable=False), + # sa.Column("type", sa.String(), nullable=False), + # sa.PrimaryKeyConstraint("id"), + # sa.UniqueConstraint("name"), + # ) + # op.create_index(op.f("ix_plugins_name"), "plugins", ["name"], unique=True) + # op.create_index(op.f("ix_plugins_type"), "plugins", ["type"], unique=False) + + # TrainingTemplate table + if table_exists("training_template"): + # Drop all indexes on the table + op.drop_index(op.f("ix_training_template_name"), table_name="training_template") + op.drop_index(op.f("ix_training_template_created_at"), table_name="training_template") + op.drop_index(op.f("ix_training_template_type"), table_name="training_template") + op.drop_index(op.f("ix_training_template_updated_at"), table_name="training_template") + # Drop the table + op.drop_table("training_template") + # Create the table again + # op.create_table( + # "training_template", + # sa.Column("id", sa.Integer(), nullable=False), + # sa.Column("name", sa.String(), nullable=False), + # sa.Column("description", sa.String(), nullable=True), + # sa.Column("type", sa.String(), nullable=True), + # sa.Column("datasets", sa.String(), nullable=True), + # sa.Column("config", sa.String(), nullable=True), + # sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + # sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + # sa.PrimaryKeyConstraint("id"), + # sa.UniqueConstraint("name"), + # ) + # op.create_index(op.f("ix_training_template_name"), "training_template", ["name"], unique=True) + # op.create_index(op.f("ix_training_template_created_at"), "training_template", ["created_at"], unique=False) + # op.create_index(op.f("ix_training_template_type"), "training_template", ["type"], unique=False) + # op.create_index(op.f("ix_training_template_updated_at"), "training_template", ["updated_at"], unique=False) + + # Workflow table + if not table_exists("workflows"): + op.create_table( + "workflows", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=True), + sa.Column("config", sa.JSON(), nullable=True), + sa.Column("status", sa.String(), nullable=True), + sa.Column("experiment_id", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_workflows_status"), "workflows", ["status"], unique=False) + op.create_index("idx_workflow_id_experiment", "workflows", ["id", "experiment_id"], unique=False) + + # WorkflowRun table + if not table_exists("workflow_runs"): + op.create_table( + "workflow_runs", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("workflow_id", sa.Integer(), nullable=True), + sa.Column("workflow_name", sa.String(), nullable=True), + sa.Column("job_ids", sa.JSON(), nullable=True), + sa.Column("node_ids", sa.JSON(), nullable=True), + sa.Column("status", sa.String(), nullable=True), + sa.Column("current_tasks", sa.JSON(), nullable=True), + sa.Column("current_job_ids", sa.JSON(), nullable=True), + sa.Column("experiment_id", sa.Integer(), nullable=True), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_workflow_runs_status"), "workflow_runs", ["status"], unique=False) + + # Team table + if not table_exists("teams"): + op.create_table( + "teams", + sa.Column("id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + + # UserTeam table + if not table_exists("users_teams"): + op.create_table( + "users_teams", + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("team_id", sa.String(), nullable=False), + sa.Column("role", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("user_id", "team_id"), + ) + + # TeamInvitation table + if not table_exists("team_invitations"): + op.create_table( + "team_invitations", + sa.Column("id", sa.String(), nullable=False), + sa.Column("token", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column("team_id", sa.String(), nullable=False), + sa.Column("invited_by_user_id", sa.String(), nullable=False), + sa.Column("role", sa.String(), nullable=False), + sa.Column("status", sa.String(), nullable=False), + sa.Column("expires_at", sa.DateTime(), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("(CURRENT_TIMESTAMP)"), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("token"), + ) + op.create_index(op.f("ix_team_invitations_email"), "team_invitations", ["email"], unique=False) + op.create_index(op.f("ix_team_invitations_status"), "team_invitations", ["status"], unique=False) + op.create_index(op.f("ix_team_invitations_team_id"), "team_invitations", ["team_id"], unique=False) + op.create_index(op.f("ix_team_invitations_token"), "team_invitations", ["token"], unique=True) + + # User table (from fastapi-users) + # Check if table exists first to avoid errors on existing databases + if not table_exists("user"): + # Create new user table with correct schema + op.create_table( + "user", + sa.Column("id", sa.CHAR(length=36), nullable=False), + sa.Column("email", sa.String(length=320), nullable=False), + sa.Column("hashed_password", sa.String(length=1024), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.text("1")), + sa.Column("is_superuser", sa.Boolean(), nullable=False, server_default=sa.text("0")), + sa.Column("is_verified", sa.Boolean(), nullable=False, server_default=sa.text("0")), + sa.Column("first_name", sa.String(length=100), nullable=True), + sa.Column("last_name", sa.String(length=100), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_user_email"), "user", ["email"], unique=True) + else: + # Table exists - check the schema + result = connection.execute(sa.text("PRAGMA table_info(user)")) + existing_columns = [row[1] for row in result.fetchall()] + + # Check if it's the old schema with 'name' column instead of 'first_name'/'last_name' + has_old_schema = "name" in existing_columns and ( + "first_name" not in existing_columns or "last_name" not in existing_columns + ) + + if has_old_schema: + # Drop the old table and create a new one with correct schema + # Note: This will lose user data, but the schema is incompatible + op.drop_index(op.f("ix_user_email"), table_name="user", if_exists=True) + op.drop_table("user") + + # Create new user table with correct schema + op.create_table( + "user", + sa.Column("id", sa.CHAR(length=36), nullable=False), + sa.Column("email", sa.String(length=320), nullable=False), + sa.Column("hashed_password", sa.String(length=1024), nullable=False), + sa.Column("is_active", sa.Boolean(), nullable=False, server_default=sa.text("1")), + sa.Column("is_superuser", sa.Boolean(), nullable=False, server_default=sa.text("0")), + sa.Column("is_verified", sa.Boolean(), nullable=False, server_default=sa.text("0")), + sa.Column("first_name", sa.String(length=100), nullable=True), + sa.Column("last_name", sa.String(length=100), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_user_email"), "user", ["email"], unique=True) + else: + # Schema is compatible - just add missing columns if needed + if "first_name" not in existing_columns: + op.add_column("user", sa.Column("first_name", sa.String(length=100), nullable=True)) + if "last_name" not in existing_columns: + op.add_column("user", sa.Column("last_name", sa.String(length=100), nullable=True)) + + +def downgrade() -> None: + """Drop all tables.""" + op.drop_index(op.f("ix_team_invitations_token"), table_name="team_invitations") + op.drop_index(op.f("ix_team_invitations_team_id"), table_name="team_invitations") + op.drop_index(op.f("ix_team_invitations_status"), table_name="team_invitations") + op.drop_index(op.f("ix_team_invitations_email"), table_name="team_invitations") + op.drop_table("team_invitations") + op.drop_table("users_teams") + op.drop_table("teams") + op.drop_index(op.f("ix_workflow_runs_status"), table_name="workflow_runs") + op.drop_table("workflow_runs") + op.drop_index("idx_workflow_id_experiment", table_name="workflows") + op.drop_index(op.f("ix_workflows_status"), table_name="workflows") + op.drop_table("workflows") + + op.drop_index(op.f("ix_config_key"), table_name="config") + op.drop_table("config") + # User table - only drop if it was created by this migration + try: + op.drop_index(op.f("ix_user_email"), table_name="user") + op.drop_table("user") + except Exception: + pass diff --git a/api/api.py b/api/api.py index b4145b98d..f7e8ee68d 100644 --- a/api/api.py +++ b/api/api.py @@ -82,7 +82,6 @@ from lab.dirs import set_organization_id as lab_set_org_id from lab import storage -from transformerlab.shared.models.user_model import create_db_and_tables from dotenv import load_dotenv @@ -113,8 +112,8 @@ async def lifespan(app: FastAPI): print_launch_message() galleries.update_gallery_cache() spawn_fastchat_controller_subprocess() - await db.init() - await create_db_and_tables() + await db.init() # This now runs Alembic migrations internally + # create_db_and_tables() is deprecated - migrations are handled in db.init() print("✅ SEED DATA") # Initialize experiments and cancel any running jobs seed_default_experiments() diff --git a/api/requirements-no-gpu-uv.txt b/api/requirements-no-gpu-uv.txt index 35d740416..2c9123657 100644 --- a/api/requirements-no-gpu-uv.txt +++ b/api/requirements-no-gpu-uv.txt @@ -24,6 +24,8 @@ aiosignal==1.3.2 # via aiohttp aiosqlite==0.20.0 # via -r requirements.in +alembic==1.17.2 + # via -r requirements.in annotated-types==0.7.0 # via pydantic anyio==4.8.0 @@ -240,6 +242,8 @@ magika==0.6.1 # via markitdown makefun==1.16.0 # via fastapi-users +mako==1.3.10 + # via alembic mammoth==1.9.0 # via markitdown markdown==3.7 @@ -255,6 +259,7 @@ markitdown==0.1.1 markupsafe==2.1.5 # via # jinja2 + # mako # werkzeug mcp==1.8.1 # via -r requirements.in @@ -543,6 +548,7 @@ speechrecognition==3.14.2 sqlalchemy==2.0.38 # via # -r requirements.in + # alembic # fastapi-users-db-sqlalchemy sse-starlette==2.3.5 # via mcp @@ -619,6 +625,7 @@ typer==0.15.4 typing-extensions==4.12.2 # via # aiosqlite + # alembic # anyio # azure-ai-documentintelligence # azure-core diff --git a/api/requirements-rocm-uv.txt b/api/requirements-rocm-uv.txt index 0b08f2330..a331d8ec8 100644 --- a/api/requirements-rocm-uv.txt +++ b/api/requirements-rocm-uv.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile requirements-rocm.in -o requirements-rocm-uv.txt --index-strategy unsafe-best-match --python-platform linux +# uv pip compile requirements-rocm.in -o requirements-rocm-uv.txt --index-strategy unsafe-best-match absl-py==2.2.2 # via tensorboard accelerate==1.6.0 @@ -24,6 +24,8 @@ aiosignal==1.3.2 # via aiohttp aiosqlite==0.21.0 # via -r requirements-rocm.in +alembic==1.17.2 + # via -r requirements-rocm.in annotated-types==0.7.0 # via pydantic anyio==4.9.0 @@ -240,6 +242,8 @@ magika==0.6.1 # via markitdown makefun==1.16.0 # via fastapi-users +mako==1.3.10 + # via alembic mammoth==1.9.0 # via markitdown markdown==3.8 @@ -255,6 +259,7 @@ markitdown==0.1.1 markupsafe==2.1.5 # via # jinja2 + # mako # werkzeug mcp==1.9.1 # via -r requirements-rocm.in @@ -546,6 +551,7 @@ speechrecognition==3.14.2 sqlalchemy==2.0.40 # via # -r requirements-rocm.in + # alembic # fastapi-users-db-sqlalchemy sse-starlette==2.3.5 # via mcp @@ -622,6 +628,7 @@ typer==0.16.0 typing-extensions==4.12.2 # via # aiosqlite + # alembic # anyio # azure-ai-documentintelligence # azure-core diff --git a/api/requirements-rocm.in b/api/requirements-rocm.in index a9c53300d..42a626af7 100644 --- a/api/requirements-rocm.in +++ b/api/requirements-rocm.in @@ -39,4 +39,5 @@ controlnet_aux==0.0.10 timm==1.0.15 librosa==0.11.0 soundfile==0.13.1 -fastapi-users[sqlalchemy] \ No newline at end of file +fastapi-users[sqlalchemy] +alembic \ No newline at end of file diff --git a/api/requirements-uv.txt b/api/requirements-uv.txt index 69a6392f6..aebf0b086 100644 --- a/api/requirements-uv.txt +++ b/api/requirements-uv.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile requirements.in -o requirements-uv.txt --index-strategy unsafe-best-match --python-platform linux +# uv pip compile requirements.in -o requirements-uv.txt absl-py==2.1.0 # via tensorboard accelerate==1.3.0 @@ -24,6 +24,8 @@ aiosignal==1.3.2 # via aiohttp aiosqlite==0.20.0 # via -r requirements.in +alembic==1.17.2 + # via -r requirements.in annotated-types==0.7.0 # via pydantic anyio==4.8.0 @@ -240,6 +242,8 @@ magika==0.6.1 # via markitdown makefun==1.16.0 # via fastapi-users +mako==1.3.10 + # via alembic mammoth==1.9.0 # via markitdown markdown==3.7 @@ -255,6 +259,7 @@ markitdown==0.1.1 markupsafe==2.1.5 # via # jinja2 + # mako # werkzeug mcp==1.8.1 # via -r requirements.in @@ -581,6 +586,7 @@ speechrecognition==3.14.2 sqlalchemy==2.0.38 # via # -r requirements.in + # alembic # fastapi-users-db-sqlalchemy sse-starlette==2.3.5 # via mcp @@ -659,6 +665,7 @@ typer==0.15.4 typing-extensions==4.12.2 # via # aiosqlite + # alembic # anyio # azure-ai-documentintelligence # azure-core diff --git a/api/requirements.in b/api/requirements.in index 2404d5ce1..543a6d5f2 100644 --- a/api/requirements.in +++ b/api/requirements.in @@ -40,3 +40,4 @@ timm==1.0.15 librosa==0.11.0 soundfile==0.13.1 fastapi-users[sqlalchemy] +alembic diff --git a/api/test/api/test_experiment_jobs.py b/api/test/api/test_experiment_jobs.py index d20ff921c..748f5157d 100644 --- a/api/test/api/test_experiment_jobs.py +++ b/api/test/api/test_experiment_jobs.py @@ -92,19 +92,6 @@ def test_job_evaluation_images(client): assert resp.status_code in (200, 404) -def test_training_template_endpoints(client): - """Test training template endpoints""" - # Test get training template - resp = client.get("/experiment/alpha/jobs/template/1") - assert resp.status_code in (200, 404) - - # Test update training template with valid config - resp = client.put( - '/experiment/alpha/jobs/template/update?template_id=1&name=test&description=test&type=test&config={"valid": "config"}' - ) - assert resp.status_code in (200, 404, 400) # 400 for invalid config - - def test_job_get_by_id(client): """Test getting job by ID""" resp = client.get("/experiment/alpha/jobs/1") diff --git a/api/test/api/test_train.py b/api/test/api/test_train.py index e08e28888..de923e91b 100644 --- a/api/test/api/test_train.py +++ b/api/test/api/test_train.py @@ -1,15 +1,3 @@ -def test_train_templates(client): - resp = client.get("/train/templates") - assert resp.status_code == 200 - assert isinstance(resp.json(), list) or isinstance(resp.json(), dict) - - def test_train_export_recipe(client): resp = client.get("/train/template/1/export") assert resp.status_code in (200, 404) - - -def test_train_create_template(client): - data = {"name": "test_template", "description": "desc", "type": "test", "config": "{}"} - resp = client.post("/train/template/create", data=data) - assert resp.status_code in (200, 422, 400) diff --git a/api/test/db/test_db.py b/api/test/db/test_db.py index 85831c044..59fec6ed0 100644 --- a/api/test/db/test_db.py +++ b/api/test/db/test_db.py @@ -10,14 +10,8 @@ from transformerlab.db.db import ( # noqa: E402 - get_training_template_by_name, config_get, config_set, - get_training_template, - get_training_templates, - create_training_template, - update_training_template, - delete_training_template, ) from transformerlab.services import experiment_service # noqa: E402 @@ -54,14 +48,6 @@ import pytest # noqa: E402 -@pytest.mark.asyncio -async def test_get_training_template_and_by_name_returns_none_for_missing(): - tmpl = await get_training_template(999999) - assert tmpl is None - tmpl = await get_training_template_by_name("does_not_exist") - assert tmpl is None - - @pytest.mark.asyncio @pytest.mark.skip("skipping workflow tests") async def test_workflows_get_by_id_returns_none_for_missing(): @@ -85,20 +71,6 @@ async def test_config_get_returns_none_for_missing(): pytest_plugins = ("pytest_asyncio",) -@pytest.mark.asyncio -async def test_training_template_crud(): - await create_training_template("tmpl", "desc", "type", "[]", "{}") - templates = await get_training_templates() - assert any(t.get("name") == "tmpl" for t in templates) - tmpl_id = templates[0].get("id") - await update_training_template(tmpl_id, "tmpl2", "desc2", "type2", "[]", "{}") - tmpl = await get_training_template(tmpl_id) - assert tmpl["name"] == "tmpl2" - await delete_training_template(tmpl_id) - tmpl = await get_training_template(tmpl_id) - assert tmpl is None - - pytest_plugins = ("pytest_asyncio",) @@ -354,91 +326,3 @@ async def test_sync_job_functions_trigger_workflows(self, test_experiment): # Check that workflow was triggered again workflow_runs = await workflow_runs_get_from_experiment(test_experiment) assert len(workflow_runs) >= 3 - - -# @pytest.mark.skip(reason="Skipping because I can't get it to work") -# @pytest.mark.asyncio -# async def test_workflow_run_get_running(setup_db): -# """Test the workflow_run_get_running function.""" -# # Create a workflow and workflow_run using db methods -# workflow_id = await db.workflow_create("test_workflow", "{}", "test_experiment") -# await db.workflow_queue(workflow_id) - -# # Sleep for 3 seconds, async: -# await asyncio.sleep(3) - -# # Test the function -# running_workflow = await db.workflow_run_get_running() - -# # Verify results -# assert running_workflow is not None -# assert running_workflow["status"] == "RUNNING" -# assert running_workflow["workflow_name"] == "test_workflow" - - -# @pytest.mark.skip(reason="Skipping because I can't get it to work") -# @pytest.mark.asyncio -# async def test_training_jobs_get_all(setup_db): -# """Test the training_jobs_get_all function.""" -# # Create a training template using db method -# template_id = await db.create_training_template("test_template", "Test description", "fine-tuning", "[]", "{}") - -# # Create a job that references this training template -# job_data = {"template_id": template_id, "description": "Test training job"} -# job_id = await db.job_create("TRAIN", "QUEUED", json.dumps(job_data), "test_experiment") - -# # Test the function -# training_jobs = await db.training_jobs_get_all() - -# # Verify results -# assert len(training_jobs) > 0 -# found_job = False -# for job in training_jobs: -# if job["id"] == job_id: -# found_job = True -# assert job["type"] == "TRAIN" -# assert job.get("status") == "QUEUED" -# assert job["job_data"]["template_id"] == template_id -# assert job["job_data"]["description"] == "Test training job" -# assert "config" in job - -# assert found_job, "The created training job was not found in the results" - - -# @pytest.mark.skip(reason="Skipping test_workflow_run_get_running because I can't get it to work") -# @pytest.mark.asyncio -# async def test_workflow_run_get_queued(setup_db): -# """Test the workflow_run_get_queued function.""" -# # Create a workflow and workflow_run using db methods -# workflow_id = await db.workflow_create("queued_workflow", "{}", "test_experiment") - -# # Test the function -# queued_workflow = await db.workflow_run_get_queued() - -# # Verify results -# assert queued_workflow is not None -# assert queued_workflow["status"] == "QUEUED" -# assert queued_workflow["workflow_name"] == "queued_workflow" - - -# @pytest.mark.skip(reason="Skipping test_workflow_run_get_running because I can't get it to work") -# @pytest.mark.asyncio -# async def test_workflow_run_update_with_new_job(setup_db): -# """Test the workflow_run_update_with_new_job function.""" -# # Create a workflow and workflow_run using db methods -# workflow_id = await db.workflow_create("test_workflow", "{}", "test_experiment") # noqa: F841 - -# # New task and job IDs -# current_task = '["task1"]' -# current_job_id = "[1]" - -# # Test the function -# await db.workflow_run_update_with_new_job(workflow_run_id, current_task, current_job_id) - -# # Verify results -# updated_workflow_run = await db.workflow_run_get_by_id(workflow_run_id) -# assert updated_workflow_run is not None -# assert updated_workflow_run["current_tasks"] == current_task -# assert updated_workflow_run["current_job_ids"] == current_job_id -# assert json.loads(updated_workflow_run["job_ids"]) == [1] -# assert json.loads(updated_workflow_run["node_ids"]) == ["task1"] diff --git a/api/transformerlab/db/db.py b/api/transformerlab/db/db.py index 26ed8781c..bfc32f2c1 100644 --- a/api/transformerlab/db/db.py +++ b/api/transformerlab/db/db.py @@ -1,6 +1,4 @@ -import json - -from sqlalchemy import select, delete, text, update +from sqlalchemy import select from sqlalchemy.dialects.sqlite import insert # Correct import for SQLite upsert # from sqlalchemy import create_engine @@ -12,9 +10,7 @@ from typing import AsyncGenerator -from transformerlab.shared.models import models from transformerlab.shared.models.models import Config -from transformerlab.db.utils import sqlalchemy_to_dict, sqlalchemy_list_to_dict from transformerlab.db.session import async_session @@ -24,113 +20,6 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]: yield session -async def get_training_template(id): - async with async_session() as session: - result = await session.execute(select(models.TrainingTemplate).where(models.TrainingTemplate.id == id)) - template = result.scalar_one_or_none() - if template is None: - return None - # Convert ORM object to dict - return sqlalchemy_to_dict(template) - - -async def get_training_template_by_name(name): - async with async_session() as session: - result = await session.execute(select(models.TrainingTemplate).where(models.TrainingTemplate.name == name)) - template = result.scalar_one_or_none() - if template is None: - return None - # Convert ORM object to dict - return sqlalchemy_to_dict(template) - - -async def get_training_templates(): - async with async_session() as session: - result = await session.execute( - select(models.TrainingTemplate).order_by(models.TrainingTemplate.created_at.desc()) - ) - templates = result.scalars().all() - # Convert ORM objects to dicts if needed - return sqlalchemy_list_to_dict(templates) - - -async def create_training_template(name, description, type, datasets, config): - async with async_session() as session: - template = models.TrainingTemplate( - name=name, - description=description, - type=type, - datasets=datasets, - config=config, - ) - session.add(template) - await session.commit() - return - - -async def update_training_template(id, name, description, type, datasets, config): - async with async_session() as session: - await session.execute( - update(models.TrainingTemplate) - .where(models.TrainingTemplate.id == id) - .values( - name=name, - description=description, - type=type, - datasets=datasets, - config=config, - ) - ) - await session.commit() - return - - -async def delete_training_template(id): - async with async_session() as session: - await session.execute(delete(models.TrainingTemplate).where(models.TrainingTemplate.id == id)) - await session.commit() - return - - -async def training_jobs_get_all(): - async with async_session() as session: - # Select jobs of type "TRAIN" and join with TrainingTemplate using the template_id from job_data JSON - stmt = ( - select( - models.Job, - models.TrainingTemplate.id.label("tt_id"), - models.TrainingTemplate.config, - ) - .join( - models.TrainingTemplate, - text("json_extract(job.job_data, '$.template_id') = training_template.id"), - ) - .where(models.Job.type == "TRAIN") - ) - result = await session.execute(stmt) - rows = result.all() - - data = [] - for job, tt_id, config in rows: - row = sqlalchemy_to_dict(job) - row["tt_id"] = tt_id - # Convert job_data and config from JSON string to Python object - if "job_data" in row and row["job_data"]: - try: - row["job_data"] = json.loads(row["job_data"]) - except Exception: - pass - if config: - try: - row["config"] = json.loads(config) - except Exception: - row["config"] = config - else: - row["config"] = None - data.append(row) - return data - - ############### # Config MODEL ############### diff --git a/api/transformerlab/db/session.py b/api/transformerlab/db/session.py index 4c7a9119c..5ba03153d 100644 --- a/api/transformerlab/db/session.py +++ b/api/transformerlab/db/session.py @@ -1,12 +1,13 @@ import os import shutil import aiosqlite +import subprocess +import sys from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker from transformerlab.db.constants import DATABASE_FILE_NAME, DATABASE_URL from lab.dirs import get_workspace_dir -from transformerlab.shared.models import models # --- SQLAlchemy Async Engine --- @@ -24,6 +25,41 @@ ) +async def run_alembic_migrations(): + """ + Run Alembic migrations to create/update database schema. + This replaces the previous create_all() approach. + """ + try: + # Get the directory containing this file (transformerlab/db) + current_dir = os.path.dirname(os.path.abspath(__file__)) + # Go up to api directory where alembic.ini is located + api_dir = os.path.dirname(os.path.dirname(current_dir)) + + # Run alembic upgrade head + result = subprocess.run( + [sys.executable, "-m", "alembic", "upgrade", "head"], + cwd=api_dir, + capture_output=True, + text=True, + check=False + ) + + if result.returncode != 0: + print(f"⚠️ Alembic migration warning: {result.stderr}") + # Don't fail completely - the database might already be up to date + # or there might be a minor issue + if "Target database is not up to date" not in result.stderr: + print(f"Migration output: {result.stdout}") + else: + print("✅ Database migrations applied") + except Exception as e: + print(f"⚠️ Error running Alembic migrations: {e}") + print("Continuing with startup - database may need manual migration") + # Don't raise - allow the app to continue + # The database might already be in the correct state + + async def init(): """ Create the database, tables, and workspace folder if they don't exist. @@ -52,9 +88,9 @@ async def init(): await db.execute("PRAGMA synchronous=normal") await db.execute("PRAGMA busy_timeout = 30000") - # Create the tables if they don't exist - async with async_engine.begin() as conn: - await conn.run_sync(models.Base.metadata.create_all) + # Run Alembic migrations to create/update tables + # This replaces the previous create_all() call + await run_alembic_migrations() # Check if experiment_id column exists in workflow_runs table cursor = await db.execute("PRAGMA table_info(workflow_runs)") @@ -152,9 +188,9 @@ async def migrate_workflows_non_preserving(): # Rename current table as backup await db.execute("ALTER TABLE workflows RENAME TO workflows_backup") - # Create new workflows table using SQLAlchemy schema - async with async_engine.begin() as conn: - await conn.run_sync(models.Base.metadata.create_all) + # Note: Table creation is now handled by Alembic migrations + # If we need to recreate the workflows table, it should be done via a migration + pass await db.commit() print("Successfully created new workflows table with correct schema. Old table saved as workflows_backup.") diff --git a/api/transformerlab/routers/experiment/jobs.py b/api/transformerlab/routers/experiment/jobs.py index 3f6dec4cb..d88df2625 100644 --- a/api/transformerlab/routers/experiment/jobs.py +++ b/api/transformerlab/routers/experiment/jobs.py @@ -4,24 +4,22 @@ import os import csv import pandas as pd -from fastapi import APIRouter, Body, Response, Request +from fastapi import APIRouter, Response, Request from fastapi.responses import StreamingResponse, FileResponse from lab import storage from transformerlab.shared import shared -from typing import Annotated from json import JSONDecodeError from werkzeug.utils import secure_filename from transformerlab.routers.serverinfo import watch_file -from transformerlab.db.db import get_training_template from datetime import datetime import transformerlab.services.job_service as job_service from transformerlab.services.job_service import job_update_status -from lab import dirs, Job +from lab import Job from lab.dirs import get_workspace_dir router = APIRouter(prefix="/jobs", tags=["train"]) @@ -120,54 +118,6 @@ async def get_training_job(job_id: str): return job -@router.get("/{job_id}/output") -async def get_training_job_output(job_id: str, sweeps: bool = False): - # First get the template Id from this job: - job = job_service.job_get(job_id) - if job is None: - return {"checkpoints": []} - job_data = job["job_data"] - - if not isinstance(job_data, dict): - try: - job_data = json.loads(job_data) - except JSONDecodeError: - print(f"Error decoding job_data for job {job_id}. Using empty job_data.") - job_data = {} - - if sweeps: - output_file = job_data.get("sweep_output_file", None) - if output_file is not None and storage.exists(output_file): - with storage.open(output_file, "r") as f: - output = f.read() - return output - - if "template_id" not in job_data: - return {"error": "true"} - - template_id = job_data["template_id"] - # Then get the template: - template = await get_training_template(template_id) - # Then get the plugin name from the template: - if not isinstance(template["config"], dict): - template_config = json.loads(template["config"]) - else: - template_config = template["config"] - if "plugin_name" not in template_config: - return {"error": "true"} - - plugin_name = template_config["plugin_name"] - - # Now we can get the output.txt from the plugin which is stored in - # /workspace/experiments/{experiment_name}/plugins/{plugin_name}/output.txt - output_file = storage.join(dirs.plugin_dir_by_name(plugin_name), "output.txt") - if storage.exists(output_file): - with storage.open(output_file, "r") as f: - output = f.read() - return output - return "" - - @router.get("/{job_id}/tasks_output") async def get_tasks_job_output(job_id: str, sweeps: bool = False): """ @@ -251,30 +201,30 @@ async def get_tasks_job_output(job_id: str, sweeps: bool = False): # Templates -@router.get("/template/{template_id}") -async def get_train_template(template_id: str): - return await get_training_template(template_id) - - -@router.put("/template/update") -async def update_training_template( - template_id: str, - name: str, - description: str, - type: str, - config: Annotated[str, Body(embed=True)], -): - try: - configObject = json.loads(config) - datasets = configObject["dataset_name"] - job_service.update_training_template(template_id, name, description, type, datasets, config) - except JSONDecodeError as e: - print(f"JSON decode error: {e}") - return {"status": "error", "message": "An error occurred while processing the request."} - except Exception as e: - print(f"Unexpected error: {e}") - return {"status": "error", "message": "An internal error has occurred."} - return {"status": "success"} +# @router.get("/template/{template_id}") +# async def get_train_template(template_id: str): +# return await get_training_template(template_id) + + +# @router.put("/template/update") +# async def update_training_template( +# template_id: str, +# name: str, +# description: str, +# type: str, +# config: Annotated[str, Body(embed=True)], +# ): +# try: +# configObject = json.loads(config) +# datasets = configObject["dataset_name"] +# job_service.update_training_template(template_id, name, description, type, datasets, config) +# except JSONDecodeError as e: +# print(f"JSON decode error: {e}") +# return {"status": "error", "message": "An error occurred while processing the request."} +# except Exception as e: +# print(f"Unexpected error: {e}") +# return {"status": "error", "message": "An internal error has occurred."} +# return {"status": "success"} @router.get("/{job_id}/stream_output") diff --git a/api/transformerlab/routers/train.py b/api/transformerlab/routers/train.py index 28eed3acd..835dcff10 100644 --- a/api/transformerlab/routers/train.py +++ b/api/transformerlab/routers/train.py @@ -1,61 +1,15 @@ -import json import subprocess -from typing import Annotated -from fastapi import APIRouter, Body -import transformerlab.db.db as db +from fastapi import APIRouter import transformerlab.services.job_service as job_service from lab import Experiment, storage from werkzeug.utils import secure_filename -# @TODO hook this up to an endpoint so we can cancel a finetune - - -def abort_fine_tune(): - print("Aborting training...") - return "abort" - router = APIRouter(prefix="/train", tags=["train"]) -# @router.post("/finetune_lora") -# def finetune_lora( -# model: str, -# adaptor_name: str, -# text: Annotated[str, Body()], -# background_tasks: BackgroundTasks, -# ): -# background_tasks.add_task(finetune, model, text, adaptor_name) - -# return {"message": "OK"} - - -@router.post("/template/create") -async def create_training_template( - name: str, - description: str, - type: str, - config: Annotated[str, Body(embed=True)], -): - configObject = json.loads(config) - datasets = configObject["dataset_name"] - await db.create_training_template(name, description, type, datasets, config) - return {"message": "OK"} - - -@router.get("/templates") -async def get_training_templates(): - return await db.get_training_templates() - - -@router.get("/template/{template_id}/delete") -async def delete_training_template(template_id: str): - await db.delete_training_template(template_id) - return {"message": "OK"} - - tensorboard_process = None diff --git a/api/transformerlab/shared/models/models.py b/api/transformerlab/shared/models/models.py index 92331c11b..8993a27e9 100644 --- a/api/transformerlab/shared/models/models.py +++ b/api/transformerlab/shared/models/models.py @@ -19,35 +19,6 @@ class Config(Base): value: Mapped[Optional[str]] = mapped_column(String, nullable=True) -# I believe we are not using the following table anymore as the filesystem -# is being used to track plugins -class Plugin(Base): - """Plugin definition model.""" - - __tablename__ = "plugins" - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - name: Mapped[str] = mapped_column(String, unique=True, index=True, nullable=False) - type: Mapped[str] = mapped_column(String, index=True, nullable=False) - - -class TrainingTemplate(Base): - """Training template model.""" - - __tablename__ = "training_template" - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - name: Mapped[str] = mapped_column(String, unique=True, index=True, nullable=False) - description: Mapped[Optional[str]] = mapped_column(String, nullable=True) - type: Mapped[Optional[str]] = mapped_column(String, index=True, nullable=True) - datasets: Mapped[Optional[str]] = mapped_column(String, nullable=True) - config: Mapped[Optional[str]] = mapped_column(String, nullable=True) - created_at: Mapped[DateTime] = mapped_column(DateTime, index=True, server_default=func.now(), nullable=False) - updated_at: Mapped[DateTime] = mapped_column( - DateTime, index=True, server_default=func.now(), onupdate=func.now(), nullable=False - ) - - class Workflow(Base): """Workflow model.""" @@ -85,6 +56,7 @@ class WorkflowRun(Base): DateTime, server_default=func.now(), onupdate=func.now(), nullable=False ) + class Team(Base): """Team model.""" @@ -96,6 +68,7 @@ class Team(Base): class TeamRole(str, enum.Enum): """Enum for user roles within a team.""" + OWNER = "owner" MEMBER = "member" @@ -112,6 +85,7 @@ class UserTeam(Base): class InvitationStatus(str, enum.Enum): """Enum for invitation status.""" + PENDING = "pending" ACCEPTED = "accepted" REJECTED = "rejected" @@ -125,7 +99,9 @@ class TeamInvitation(Base): __tablename__ = "team_invitations" id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: str(uuid.uuid4())) - token: Mapped[str] = mapped_column(String, unique=True, index=True, nullable=False, default=lambda: str(uuid.uuid4())) + token: Mapped[str] = mapped_column( + String, unique=True, index=True, nullable=False, default=lambda: str(uuid.uuid4()) + ) email: Mapped[str] = mapped_column(String, nullable=False, index=True) team_id: Mapped[str] = mapped_column(String, nullable=False, index=True) invited_by_user_id: Mapped[str] = mapped_column(String, nullable=False) @@ -133,4 +109,6 @@ class TeamInvitation(Base): status: Mapped[str] = mapped_column(String, nullable=False, default=InvitationStatus.PENDING.value, index=True) expires_at: Mapped[DateTime] = mapped_column(DateTime, nullable=False) created_at: Mapped[DateTime] = mapped_column(DateTime, server_default=func.now(), nullable=False) - updated_at: Mapped[DateTime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now(), nullable=False) \ No newline at end of file + updated_at: Mapped[DateTime] = mapped_column( + DateTime, server_default=func.now(), onupdate=func.now(), nullable=False + ) diff --git a/api/transformerlab/shared/models/user_model.py b/api/transformerlab/shared/models/user_model.py index ce0bf9496..d1d996d17 100644 --- a/api/transformerlab/shared/models/user_model.py +++ b/api/transformerlab/shared/models/user_model.py @@ -33,9 +33,15 @@ class User(SQLAlchemyBaseUserTableUUID, Base): # 3. Utility to create tables (run this on app startup) +# NOTE: This function is deprecated. Database schema is now managed by Alembic migrations. +# See transformerlab.db.session.run_alembic_migrations() for the migration function. async def create_db_and_tables(): - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) + """ + Deprecated: Database tables are now created via Alembic migrations. + This function is kept for backwards compatibility but does nothing. + """ + # Tables are now created via Alembic migrations in db.session.init() + pass # 4. Database session dependency diff --git a/src/renderer/components/Experiment/Train/ViewOutputModal.tsx b/src/renderer/components/Experiment/Train/ViewOutputModal.tsx deleted file mode 100644 index 68244450d..000000000 --- a/src/renderer/components/Experiment/Train/ViewOutputModal.tsx +++ /dev/null @@ -1,64 +0,0 @@ -import useSWR from 'swr'; - -import { Button, Modal, ModalClose, ModalDialog, Typography } from '@mui/joy'; - -import * as chatAPI from 'renderer/lib/transformerlab-api-sdk'; -import { Editor } from '@monaco-editor/react'; -import { useExperimentInfo } from 'renderer/lib/ExperimentInfoContext'; -import { fetcher } from 'renderer/lib/transformerlab-api-sdk'; - -export default function ViewOutputModal({ jobId, setJobId }) { - const { experimentInfo } = useExperimentInfo(); - const { data, error, isLoading, isValidating, mutate } = useSWR( - jobId == -1 - ? null - : chatAPI.Endpoints.Experiment.GetOutputFromJob(experimentInfo.id, jobId), - fetcher, - { - refreshInterval: 5000, //refresh every 5 seconds - }, - ); - - // The following code prevents a crash if the output file doesn't exist - var dataChecked = ''; - if (data?.status) { - dataChecked = ''; - } else { - dataChecked = data; - } - - return ( - setJobId(-1)}> - - - - Output from job: {jobId} {isValidating && <>Refreshing...} - - - - - - - - ); -}