api.py (89 changes: 4 additions, 85 deletions)
@@ -67,8 +67,8 @@
 from transformerlab.shared import galleries
 from transformerlab.shared.constants import WORKSPACE_DIR
 from lab import dirs as lab_dirs
-from lab.dataset import Dataset as dataset_service
 from transformerlab.shared import dirs
+from transformerlab.db.filesystem_migrations import migrate_datasets_table_to_filesystem, migrate_tasks_table_to_filesystem
 
 from dotenv import load_dotenv
 
@@ -102,6 +102,7 @@ async def lifespan(app: FastAPI):
     # run the migrations
     asyncio.create_task(migrate())
     asyncio.create_task(migrate_datasets_table_to_filesystem())
+    asyncio.create_task(migrate_tasks_table_to_filesystem())
     asyncio.create_task(run_over_and_over())
     print("FastAPI LIFESPAN: 🏁 🏁 🏁 Begin API Server 🏁 🏁 🏁", flush=True)
     yield
@@ -114,94 +115,12 @@ async def lifespan(app: FastAPI):
 
 # the migrate function only runs the conversion function if no tasks are already present
 async def migrate():
-    if len(await tasks.tasks_get_all()) == 0:
+    from transformerlab.services.tasks_service import tasks_service
+    if len(await tasks_service.tasks_get_all()) == 0:
         for exp in await experiment.experiments_get_all():
             await tasks.convert_all_to_tasks(exp["id"])
 
 
-async def migrate_datasets_table_to_filesystem():
-    """
-    One-time migration: copy rows from the legacy dataset DB table into the filesystem
-    registry via transformerlab-sdk, then drop the table.
-    Safe to run multiple times; it will no-op if table is missing or empty.
-    """
-    try:
-        # Late import to avoid hard dependency during tests without DB
-        from transformerlab.db.session import async_session
-        from sqlalchemy import text as sqlalchemy_text
-
-        # Read existing rows
-        rows = []
-        try:
-            # First check if the table exists
-            async with async_session() as session:
-                result = await session.execute(
-                    sqlalchemy_text("SELECT name FROM sqlite_master WHERE type='table' AND name='dataset'")
-                )
-                exists = result.fetchone() is not None
-                if not exists:
-                    return
-            # Migrated db.dataset.get_datasets() to run here as we are deleting that code
-            rows = []
-            async with async_session() as session:
-                result = await session.execute(sqlalchemy_text("SELECT * FROM dataset"))
-                datasets = result.mappings().all()
-                dict_rows = [dict(dataset) for dataset in datasets]
-                for row in dict_rows:
-                    if "json_data" in row and row["json_data"]:
-                        if isinstance(row["json_data"], str):
-                            row["json_data"] = json.loads(row["json_data"])
-                    rows.append(row)
-        except Exception as e:
-            print(f"Failed to read datasets for migration: {e}")
-            rows = []
-
-        migrated = 0
-        for row in rows:
-            dataset_id = str(row.get("dataset_id")) if row.get("dataset_id") is not None else None
-            if not dataset_id:
-                continue
-            location = row.get("location", "local")
-            description = row.get("description", "")
-            size = int(row.get("size", -1)) if row.get("size") is not None else -1
-            json_data = row.get("json_data", {})
-            if isinstance(json_data, str):
-                try:
-                    json_data = json.loads(json_data)
-                except Exception:
-                    json_data = {}
-
-            try:
-                try:
-                    ds = dataset_service.get(dataset_id)
-                except FileNotFoundError:
-                    ds = dataset_service.create(dataset_id)
-                ds.set_metadata(
-                    location=location,
-                    description=description,
-                    size=size,
-                    json_data=json_data,
-                )
-                migrated += 1
-            except Exception:
-                # Best-effort migration; continue
-                continue
-
-        # Drop the legacy table if present
-        try:
-            async with async_session() as session:
-                await session.execute(sqlalchemy_text("ALTER TABLE dataset RENAME TO zzz_archived_dataset"))
-                await session.commit()
-        except Exception:
-            pass
-
-        if migrated:
-            print(f"Datasets migration completed: {migrated} entries migrated to filesystem store.")
-    except Exception as e:
-        # Do not block startup on migration issues
-        print(f"Datasets migration skipped due to error: {e}")
-
-
 async def run_over_and_over():
     """Every three seconds, check for new jobs to run."""
     while True:
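The 85 deleted lines above are the dataset migration moving out of api.py; the new import pulls both it and the added migrate_tasks_table_to_filesystem from transformerlab/db/filesystem_migrations.py, whose body is not shown in this PR view. As a rough sketch only, assuming the task migration mirrors the dataset one above, that the legacy table is named task with name/type/inputs/config/plugin/outputs/experiment_id columns, and that it writes through the tasks_service used in the tests below, it could look like this:

import json


async def migrate_tasks_table_to_filesystem():
    """Hypothetical sketch: copy rows from a legacy `task` table into the
    filesystem store via tasks_service, then archive the table. The real
    implementation lives in transformerlab/db/filesystem_migrations.py."""
    try:
        # Late imports, mirroring the dataset migration above
        from sqlalchemy import text as sqlalchemy_text
        from transformerlab.db.session import async_session
        from transformerlab.services.tasks_service import tasks_service

        async with async_session() as session:
            result = await session.execute(
                sqlalchemy_text("SELECT name FROM sqlite_master WHERE type='table' AND name='task'")
            )
            if result.fetchone() is None:
                return  # no legacy table; nothing to migrate
            rows = (await session.execute(sqlalchemy_text("SELECT * FROM task"))).mappings().all()

        migrated = 0
        for row in map(dict, rows):
            try:
                # Column names here are assumptions based on the add_task call sites.
                await tasks_service.add_task(
                    name=row.get("name", ""),
                    task_type=row.get("type", ""),
                    inputs=json.loads(row.get("inputs") or "{}"),
                    config=json.loads(row.get("config") or "{}"),
                    plugin=row.get("plugin", ""),
                    outputs=json.loads(row.get("outputs") or "{}"),
                    experiment_id=row.get("experiment_id"),
                )
                migrated += 1
            except Exception:
                continue  # best-effort, as in the dataset migration

        # Archive rather than drop, as the dataset migration does
        async with async_session() as session:
            await session.execute(sqlalchemy_text("ALTER TABLE task RENAME TO zzz_archived_task"))
            await session.commit()

        if migrated:
            print(f"Tasks migration completed: {migrated} entries migrated to filesystem store.")
    except Exception as e:
        print(f"Tasks migration skipped due to error: {e}")  # never block startup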
requirements-no-gpu-uv.txt (2 changes: 1 addition, 1 deletion)
@@ -548,7 +548,7 @@ tqdm==4.66.5
     #   peft
     #   sentence-transformers
     #   transformers
-transformerlab==0.0.9
+transformerlab==0.0.12
     # via -r requirements.in
 transformerlab-inference==0.2.49
     # via -r requirements.in
requirements-rocm-uv.txt (2 changes: 1 addition, 1 deletion)
@@ -551,7 +551,7 @@ tqdm==4.66.5
     #   peft
     #   sentence-transformers
     #   transformers
-transformerlab==0.0.9
+transformerlab==0.0.12
     # via -r requirements-rocm.in
 transformerlab-inference==0.2.49
     # via -r requirements-rocm.in
requirements-rocm.in (2 changes: 1 addition, 1 deletion)
@@ -32,7 +32,7 @@ hf_xet
 macmon-python
 mcp[cli]
 transformerlab-inference==0.2.49
-transformerlab==0.0.9
+transformerlab==0.0.12
 diffusers==0.33.1
 pyrsmi
 controlnet_aux==0.0.10
requirements-uv.txt (2 changes: 1 addition, 1 deletion)
@@ -586,7 +586,7 @@ tqdm==4.66.5
     #   peft
     #   sentence-transformers
     #   transformers
-transformerlab==0.0.9
+transformerlab==0.0.12
     # via -r requirements.in
 transformerlab-inference==0.2.49
     # via -r requirements.in
requirements.in (2 changes: 1 addition, 1 deletion)
@@ -31,7 +31,7 @@ markitdown[all]
 hf_xet
 macmon-python
 transformerlab-inference==0.2.49
-transformerlab==0.0.9
+transformerlab==0.0.12
 diffusers==0.33.1
 nvidia-ml-py
 mcp[cli]
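All five requirements files bump the transformerlab SDK pin from 0.0.9 to 0.0.12 in lockstep; this is presumably the release that ships tasks_service and the relocated filesystem migrations. If a startup guard for the new floor were ever wanted (hypothetical, not part of this PR), compare parsed versions rather than strings:

from importlib.metadata import version
from packaging.version import Version

# Version-aware comparison: Version("0.0.12") > Version("0.0.9"),
# whereas the lexical string comparison "0.0.12" > "0.0.9" is False.
assert Version(version("transformerlab")) >= Version("0.0.12")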
test/api/test_experiment_export.py (21 changes: 11 additions, 10 deletions)
@@ -3,6 +3,7 @@
 import transformerlab.db.db as db
 import transformerlab.db.workflows as db_workflows
 from transformerlab.shared.constants import WORKSPACE_DIR
+from transformerlab.services.tasks_service import tasks_service
 
 
 async def test_export_experiment(client):
@@ -21,13 +22,13 @@ async def test_export_experiment(client):
         "batch_size": "4",
         "learning_rate": "0.0001",
     }
-    await db.add_task(
+    await tasks_service.add_task(
         name="test_train_task",
-        Type="TRAIN",
-        inputs=json.dumps({"model_name": "test-model", "dataset_name": "test-dataset"}),
-        config=json.dumps(train_config),
+        task_type="TRAIN",
+        inputs={"model_name": "test-model", "dataset_name": "test-dataset"},
+        config=train_config,
         plugin="test_trainer",
-        outputs="{}",
+        outputs={},
         experiment_id=experiment_id,
     )
 
@@ -40,13 +41,13 @@ async def test_export_experiment(client):
         "script_parameters": {"tasks": ["mmlu"], "limit": 0.5},
         "eval_dataset": "test-eval-dataset",
     }
-    await db.add_task(
+    await tasks_service.add_task(
         name="test_eval_task",
-        Type="EVAL",
-        inputs=json.dumps({"model_name": "test-model-2", "dataset_name": "test-eval-dataset"}),
-        config=json.dumps(eval_config),
+        task_type="EVAL",
+        inputs={"model_name": "test-model-2", "dataset_name": "test-eval-dataset"},
+        config=eval_config,
         plugin="test_evaluator",
-        outputs=json.dumps({"eval_results": "{}"}),
+        outputs={"eval_results": "{}"},
         experiment_id=experiment_id,
     )
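Both call sites show the shape of the new service API: Type becomes task_type, and inputs/config/outputs are passed as plain dicts instead of json.dumps strings. Inferred from these calls alone (the actual definition is not in this diff), the signature is presumably something like:

from typing import Any


class TasksService:
    async def add_task(
        self,
        name: str,
        task_type: str,            # was the `Type` keyword in the DB layer
        inputs: dict[str, Any],    # plain dicts; serialization is the service's job now
        config: dict[str, Any],
        plugin: str,
        outputs: dict[str, Any],
        experiment_id: int,
    ) -> None: ...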
test/db/test_db.py (39 changes: 0 additions, 39 deletions)
@@ -31,14 +31,6 @@
     save_plugin,
     config_get,
     config_set,
-    add_task,
-    update_task,
-    tasks_get_all,
-    tasks_get_by_type,
-    tasks_get_by_type_in_experiment,
-    delete_task,
-    tasks_delete_all,
-    tasks_get_by_id,
     get_training_template,
     get_training_templates,
     create_training_template,
@@ -149,12 +141,6 @@ async def test_job_cancel_in_progress_jobs_sets_cancelled():
     assert job.get("status") == "CANCELLED"
 
 
-@pytest.mark.asyncio
-async def test_tasks_get_by_id_returns_none_for_missing():
-    task = await tasks_get_by_id(999999)
-    assert task is None
-
-
 @pytest.mark.asyncio
 async def test_get_training_template_and_by_name_returns_none_for_missing():
     tmpl = await get_training_template(999999)
@@ -269,31 +255,6 @@ async def test_experiment_update_and_update_config_and_save_prompt_template(test
     assert exp_config.get("prompt_template") == test_prompt
 
 
-@pytest.mark.asyncio
-async def test_task_crud(test_experiment):
-    await add_task("task1", "TYPE", "{}", "{}", "plugin", "{}", test_experiment)
-    tasks = await tasks_get_all()
-    assert any(t["name"] == "task1" for t in tasks)
-    task = tasks[0]
-    await update_task(task["id"], {"inputs": "[]", "config": "{}", "outputs": "[]", "name": "task1_updated"})
-    updated = await tasks_get_by_id(task["id"])
-    assert updated["name"] == "task1_updated"
-    await delete_task(task["id"])
-    deleted = await tasks_get_by_id(task["id"])
-    assert deleted is None
-
-
-@pytest.mark.asyncio
-async def test_tasks_get_by_type_and_in_experiment(test_experiment):
-    await add_task("task2", "TYPE2", "{}", "{}", "plugin", "{}", test_experiment)
-    by_type = await tasks_get_by_type("TYPE2")
-    assert any(t["name"] == "task2" for t in by_type)
-    by_type_exp = await tasks_get_by_type_in_experiment("TYPE2", test_experiment)
-    assert any(t["name"] == "task2" for t in by_type_exp)
-    await tasks_delete_all()
-    all_tasks = await tasks_get_all()
-    assert len(all_tasks) == 0
-
 
 @pytest.mark.asyncio
 async def test_training_template_crud():
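The deleted tests covered DB-level task CRUD, which no longer applies now that tasks live behind the service layer. A service-level replacement would presumably look like the sketch below; it is hypothetical, and assumes tasks_service keeps async query helpers named like the old DB ones (tasks_get_all, tasks_get_by_id, delete_task):

import pytest

from transformerlab.services.tasks_service import tasks_service


@pytest.mark.asyncio
async def test_task_crud_via_service(test_experiment):
    # Create through the service instead of the dropped add_task DB helper.
    await tasks_service.add_task(
        name="task1",
        task_type="TRAIN",
        inputs={},
        config={},
        plugin="plugin",
        outputs={},
        experiment_id=test_experiment,
    )
    tasks = await tasks_service.tasks_get_all()
    assert any(t["name"] == "task1" for t in tasks)
    task = next(t for t in tasks if t["name"] == "task1")
    await tasks_service.delete_task(task["id"])
    assert await tasks_service.tasks_get_by_id(task["id"]) is None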