diff --git a/airflow-core/src/airflow/models/asset.py b/airflow-core/src/airflow/models/asset.py
index 4d9bd8ae1cf8d..a28addd99819f 100644
--- a/airflow-core/src/airflow/models/asset.py
+++ b/airflow-core/src/airflow/models/asset.py
@@ -35,12 +35,12 @@
     text,
 )
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.orm import Mapped, relationship
+from sqlalchemy.orm import relationship
 from airflow._shared.timezones import timezone
 from airflow.models.base import Base, StringID
 from airflow.settings import json
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
+from airflow.utils.sqlalchemy import UtcDateTime
 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -140,7 +140,7 @@ def remove_references_to_deleted_dags(session: Session):
 class AssetWatcherModel(Base):
     """A table to store asset watchers."""
-    name: Mapped[str] = mapped_column(
+    name = Column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -152,8 +152,8 @@ class AssetWatcherModel(Base):
         ),
         nullable=False,
     )
-    asset_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
-    trigger_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
+    asset_id = Column(Integer, primary_key=True, nullable=False)
+    trigger_id = Column(Integer, primary_key=True, nullable=False)
     asset = relationship("AssetModel", back_populates="watchers")
     trigger = relationship("Trigger", back_populates="asset_watchers")
@@ -187,8 +187,8 @@ class AssetAliasModel(Base):
     :param uri: a string that uniquely identifies the asset alias
     """
-    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
-    name: Mapped[str] = mapped_column(
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    name = Column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -200,7 +200,7 @@ class AssetAliasModel(Base):
         ),
         nullable=False,
     )
-    group: Mapped[str] = mapped_column(
+    group = Column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -263,8 +263,8 @@ class AssetModel(Base):
     :param extra: JSON field for arbitrary extra info
     """
-    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
-    name: Mapped[str] = mapped_column(
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    name = Column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -276,7 +276,7 @@ class AssetModel(Base):
         ),
         nullable=False,
     )
-    uri: Mapped[str] = mapped_column(
+    uri = Column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -288,7 +288,7 @@ class AssetModel(Base):
         ),
         nullable=False,
     )
-    group: Mapped[str] = mapped_column(
+    group = Column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -301,12 +301,10 @@ class AssetModel(Base):
         default=str,
         nullable=False,
     )
-    extra: Mapped[dict] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
+    extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at: Mapped[UtcDateTime] = mapped_column(
-        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
-    )
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
     active = relationship("AssetActive", uselist=False, viewonly=True, back_populates="asset")
@@ -376,7 +374,7 @@ class AssetActive(Base):
     *name and URI are each unique* within active assets.
     """
-    name: Mapped[str] = mapped_column(
+    name = Column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -388,7 +386,7 @@ class AssetActive(Base):
         ),
         nullable=False,
     )
-    uri: Mapped[str] = mapped_column(
+    uri = Column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -424,7 +422,7 @@ def for_asset(cls, asset: AssetModel) -> AssetActive:
 class DagScheduleAssetNameReference(Base):
     """Reference from a DAG to an asset name reference of which it is a consumer."""
-    name: Mapped[str] = mapped_column(
+    name = Column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -437,8 +435,8 @@ class DagScheduleAssetNameReference(Base):
         primary_key=True,
         nullable=False,
     )
-    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    dag_id = Column(StringID(), primary_key=True, nullable=False)
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
     dag = relationship("DagModel", back_populates="schedule_asset_name_references")
@@ -470,7 +468,7 @@ def __repr__(self):
 class DagScheduleAssetUriReference(Base):
     """Reference from a DAG to an asset URI reference of which it is a consumer."""
-    uri: Mapped[str] = mapped_column(
+    uri = Column(
         String(length=1500).with_variant(
             String(
                 length=1500,
@@ -483,8 +481,8 @@ class DagScheduleAssetUriReference(Base):
         primary_key=True,
         nullable=False,
     )
-    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    dag_id = Column(StringID(), primary_key=True, nullable=False)
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
     dag = relationship("DagModel", back_populates="schedule_asset_uri_references")
@@ -516,12 +514,10 @@ def __repr__(self):
 class DagScheduleAssetAliasReference(Base):
     """References from a DAG to an asset alias of which it is a consumer."""
-    alias_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
-    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at: Mapped[UtcDateTime] = mapped_column(
-        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
-    )
+    alias_id = Column(Integer, primary_key=True, nullable=False)
+    dag_id = Column(StringID(), primary_key=True, nullable=False)
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
     asset_alias = relationship("AssetAliasModel", back_populates="scheduled_dags")
     dag = relationship("DagModel", back_populates="schedule_asset_alias_references")
@@ -560,12 +556,10 @@ def __repr__(self):
 class DagScheduleAssetReference(Base):
     """References from a DAG to an asset of which it is a consumer."""
-    asset_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
-    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at: Mapped[UtcDateTime] = mapped_column(
-        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
-    )
+    asset_id = Column(Integer, primary_key=True, nullable=False)
+    dag_id = Column(StringID(), primary_key=True, nullable=False)
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
     asset = relationship("AssetModel", back_populates="scheduled_dags")
     dag = relationship("DagModel", back_populates="schedule_asset_references")
@@ -613,13 +607,11 @@ def __repr__(self):
 class TaskOutletAssetReference(Base):
     """References from a task to an asset that it updates / produces."""
-    asset_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
-    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
-    task_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at: Mapped[UtcDateTime] = mapped_column(
-        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
-    )
+    asset_id = Column(Integer, primary_key=True, nullable=False)
+    dag_id = Column(StringID(), primary_key=True, nullable=False)
+    task_id = Column(StringID(), primary_key=True, nullable=False)
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
     asset = relationship("AssetModel", back_populates="producing_tasks")
@@ -664,13 +656,11 @@ def __repr__(self):
 class TaskInletAssetReference(Base):
     """References from a task to an asset that it references as an inlet."""
-    asset_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
-    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
-    task_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at: Mapped[UtcDateTime] = mapped_column(
-        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
-    )
+    asset_id = Column(Integer, primary_key=True, nullable=False)
+    dag_id = Column(StringID(), primary_key=True, nullable=False)
+    task_id = Column(StringID(), primary_key=True, nullable=False)
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
     asset = relationship("AssetModel", back_populates="consuming_tasks")
@@ -710,9 +700,9 @@ def __repr__(self):
 class AssetDagRunQueue(Base):
     """Model for storing asset events that need processing."""
-    asset_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
-    target_dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True, nullable=False)
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    asset_id = Column(Integer, primary_key=True, nullable=False)
+    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
     asset = relationship("AssetModel", viewonly=True)
     dag_model = relationship("DagModel", viewonly=True)
@@ -775,14 +765,14 @@ class AssetEvent(Base):
     if the foreign key object is.
""" - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - asset_id: Mapped[int] = mapped_column(Integer, nullable=False) - extra: Mapped[dict] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}) - source_task_id: Mapped[str | None] = mapped_column(StringID(), nullable=True) - source_dag_id: Mapped[str | None] = mapped_column(StringID(), nullable=True) - source_run_id: Mapped[str | None] = mapped_column(StringID(), nullable=True) - source_map_index: Mapped[int | None] = mapped_column(Integer, nullable=True, server_default=text("-1")) - timestamp: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) + id = Column(Integer, primary_key=True, autoincrement=True) + asset_id = Column(Integer, nullable=False) + extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}) + source_task_id = Column(StringID(), nullable=True) + source_dag_id = Column(StringID(), nullable=True) + source_run_id = Column(StringID(), nullable=True) + source_map_index = Column(Integer, nullable=True, server_default=text("-1")) + timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False) __tablename__ = "asset_event" __table_args__ = ( diff --git a/airflow-core/src/airflow/models/backfill.py b/airflow-core/src/airflow/models/backfill.py index a9d0f2b3e344d..7840dfaf3ec5c 100644 --- a/airflow-core/src/airflow/models/backfill.py +++ b/airflow-core/src/airflow/models/backfill.py @@ -30,6 +30,7 @@ from sqlalchemy import ( Boolean, + Column, ForeignKeyConstraint, Integer, String, @@ -39,7 +40,7 @@ select, ) from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Mapped, relationship, validates +from sqlalchemy.orm import relationship, validates from sqlalchemy_jsonfield import JSONField from airflow._shared.timezones import timezone @@ -47,7 +48,7 @@ from airflow.models.base import Base, StringID from airflow.settings import json from airflow.utils.session import create_session -from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, nulls_first, with_row_locks +from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, with_row_locks from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -115,27 +116,23 @@ class Backfill(Base): __tablename__ = "backfill" - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - dag_id: Mapped[str] = mapped_column(StringID(), nullable=False) - from_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False) - to_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False) - dag_run_conf: Mapped[JSONField] = mapped_column(JSONField(json=json), nullable=False, default={}) - is_paused: Mapped[bool] = mapped_column(Boolean, default=False) + id = Column(Integer, primary_key=True, autoincrement=True) + dag_id = Column(StringID(), nullable=False) + from_date = Column(UtcDateTime, nullable=False) + to_date = Column(UtcDateTime, nullable=False) + dag_run_conf = Column(JSONField(json=json), nullable=False, default={}) + is_paused = Column(Boolean, default=False) """ Controls whether new dag runs will be created for this backfill. Does not pause existing dag runs. 
""" - reprocess_behavior: Mapped[str] = mapped_column( - StringID(), nullable=False, default=ReprocessBehavior.NONE - ) - max_active_runs: Mapped[int] = mapped_column(Integer, default=10, nullable=False) - created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) - completed_at: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True) - updated_at: Mapped[UtcDateTime] = mapped_column( - UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False - ) - triggering_user_name: Mapped[str | None] = mapped_column( + reprocess_behavior = Column(StringID(), nullable=False, default=ReprocessBehavior.NONE) + max_active_runs = Column(Integer, default=10, nullable=False) + created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + completed_at = Column(UtcDateTime, nullable=True) + updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) + triggering_user_name = Column( String(512), nullable=True, ) # The user that triggered the Backfill, if applicable @@ -169,12 +166,12 @@ class BackfillDagRun(Base): """Mapping table between backfill run and dag run.""" __tablename__ = "backfill_dag_run" - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - backfill_id: Mapped[int] = mapped_column(Integer, nullable=False) - dag_run_id: Mapped[int | None] = mapped_column(Integer, nullable=True) - exception_reason: Mapped[str] = mapped_column(StringID()) - logical_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False) - sort_ordinal: Mapped[int] = mapped_column(Integer, nullable=False) + id = Column(Integer, primary_key=True, autoincrement=True) + backfill_id = Column(Integer, nullable=False) + dag_run_id = Column(Integer, nullable=True) + exception_reason = Column(StringID()) + logical_date = Column(UtcDateTime, nullable=False) + sort_ordinal = Column(Integer, nullable=False) backfill = relationship("Backfill", back_populates="backfill_dag_run_associations") dag_run = relationship("DagRun") diff --git a/airflow-core/src/airflow/models/base.py b/airflow-core/src/airflow/models/base.py index cc9853330625a..0548e08a8f605 100644 --- a/airflow-core/src/airflow/models/base.py +++ b/airflow-core/src/airflow/models/base.py @@ -19,11 +19,11 @@ from typing import TYPE_CHECKING, Any -from sqlalchemy import Integer, MetaData, String, text -from sqlalchemy.orm import Mapped, registry +from sqlalchemy import Column, Integer, MetaData, String, text +from sqlalchemy.orm import registry from airflow.configuration import conf -from airflow.utils.sqlalchemy import is_sqlalchemy_v1, mapped_column +from airflow.utils.sqlalchemy import is_sqlalchemy_v1 SQL_ALCHEMY_SCHEMA = conf.get("database", "SQL_ALCHEMY_SCHEMA") @@ -94,7 +94,7 @@ class TaskInstanceDependencies(Base): __abstract__ = True - task_id: Mapped[str] = mapped_column(StringID(), nullable=False) - dag_id: Mapped[str] = mapped_column(StringID(), nullable=False) - run_id: Mapped[str] = mapped_column(StringID(), nullable=False) - map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("-1")) + task_id = Column(StringID(), nullable=False) + dag_id = Column(StringID(), nullable=False) + run_id = Column(StringID(), nullable=False) + map_index = Column(Integer, nullable=False, server_default=text("-1")) diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index ca2230203e4b2..875d8d0b5711b 100644 --- 
a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -27,8 +27,8 @@ from typing import Any from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit -from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, select -from sqlalchemy.orm import Mapped, declared_attr, reconstructor, synonym +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text, select +from sqlalchemy.orm import declared_attr, reconstructor, synonym from sqlalchemy_utils import UUIDType from airflow._shared.secrets_masker import mask_secret @@ -42,7 +42,6 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import mapped_column log = logging.getLogger(__name__) # sanitize the `conn_id` pattern by allowing alphanumeric characters plus @@ -127,21 +126,19 @@ class Connection(Base, LoggingMixin): __tablename__ = "connection" - id: Mapped[int] = mapped_column(Integer(), primary_key=True) - conn_id: Mapped[str] = mapped_column(String(ID_LEN), unique=True, nullable=False) - conn_type: Mapped[str] = mapped_column(String(500), nullable=False) - description: Mapped[str] = mapped_column( - Text().with_variant(Text(5000), "mysql").with_variant(String(5000), "sqlite") - ) - host: Mapped[str] = mapped_column(String(500)) - schema: Mapped[str] = mapped_column(String(500)) - login: Mapped[str] = mapped_column(Text()) - _password: Mapped[str] = mapped_column("password", Text()) - port: Mapped[int] = mapped_column(Integer()) - is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False) - is_extra_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False) - team_id: Mapped[str | None] = mapped_column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True) - _extra: Mapped[str] = mapped_column("extra", Text()) + id = Column(Integer(), primary_key=True) + conn_id = Column(String(ID_LEN), unique=True, nullable=False) + conn_type = Column(String(500), nullable=False) + description = Column(Text().with_variant(Text(5000), "mysql").with_variant(String(5000), "sqlite")) + host = Column(String(500)) + schema = Column(String(500)) + login = Column(Text()) + _password = Column("password", Text()) + port = Column(Integer()) + is_encrypted = Column(Boolean, unique=False, default=False) + is_extra_encrypted = Column(Boolean, unique=False, default=False) + team_id = Column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True) + _extra = Column("extra", Text()) def __init__( self, diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py index b789f3f437759..844c28b724c0f 100644 --- a/airflow-core/src/airflow/models/dag.py +++ b/airflow-core/src/airflow/models/dag.py @@ -27,6 +27,7 @@ from dateutil.relativedelta import relativedelta from sqlalchemy import ( Boolean, + Column, Float, ForeignKey, Index, @@ -41,7 +42,7 @@ ) from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import Mapped, backref, load_only, relationship +from sqlalchemy.orm import backref, load_only, relationship from sqlalchemy.sql import expression from airflow import settings @@ -62,7 +63,7 @@ from airflow.timetables.simple import AssetTriggeredTimetable, NullTimetable, OnceTimetable from airflow.utils.context import Context from airflow.utils.session import NEW_SESSION, 
provide_session -from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, with_row_locks +from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks from airflow.utils.state import DagRunState from airflow.utils.types import DagRunType @@ -266,8 +267,8 @@ class DagTag(Base): """A tag name per dag, to allow quick filtering in the DAG view.""" __tablename__ = "dag_tag" - name: Mapped[str] = mapped_column(String(TAG_MAX_LEN), primary_key=True) - dag_id: Mapped[str] = mapped_column( + name = Column(String(TAG_MAX_LEN), primary_key=True) + dag_id = Column( StringID(), ForeignKey("dag.dag_id", name="dag_tag_dag_id_fkey", ondelete="CASCADE"), primary_key=True, @@ -287,14 +288,14 @@ class DagOwnerAttributes(Base): """ __tablename__ = "dag_owner_attributes" - dag_id: Mapped[str] = mapped_column( + dag_id = Column( StringID(), ForeignKey("dag.dag_id", name="dag.dag_id", ondelete="CASCADE"), nullable=False, primary_key=True, ) - owner: Mapped[str] = mapped_column(String(500), primary_key=True, nullable=False) - link: Mapped[str] = mapped_column(String(500), nullable=False) + owner = Column(String(500), primary_key=True, nullable=False) + link = Column(String(500), nullable=False) def __repr__(self): return f"" @@ -314,49 +315,43 @@ class DagModel(Base): """ These items are stored in the database for state related information. """ - dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True) + dag_id = Column(StringID(), primary_key=True) # A DAG can be paused from the UI / DB # Set this default value of is_paused based on a configuration value! is_paused_at_creation = airflow_conf.getboolean("core", "dags_are_paused_at_creation") - is_paused: Mapped[bool] = mapped_column(Boolean, default=is_paused_at_creation) + is_paused = Column(Boolean, default=is_paused_at_creation) # Whether that DAG was seen on the last DagBag load - is_stale: Mapped[bool] = mapped_column(Boolean, default=True) + is_stale = Column(Boolean, default=True) # Last time the scheduler started - last_parsed_time: Mapped[UtcDateTime] = mapped_column(UtcDateTime) + last_parsed_time = Column(UtcDateTime) # How long it took to parse this file - last_parse_duration: Mapped[float] = mapped_column(Float) + last_parse_duration = Column(Float) # Time when the DAG last received a refresh signal # (e.g. the DAG's "refresh" button was clicked in the web UI) - last_expired: Mapped[UtcDateTime] = mapped_column(UtcDateTime) + last_expired = Column(UtcDateTime) # The location of the file containing the DAG object # Note: Do not depend on fileloc pointing to a file; in the case of a # packaged DAG, it will point to the subpath of the DAG within the # associated zip. 
-    fileloc: Mapped[str] = mapped_column(String(2000))
-    relative_fileloc: Mapped[str] = mapped_column(String(2000))
-    bundle_name: Mapped[str] = mapped_column(StringID(), ForeignKey("dag_bundle.name"), nullable=False)
+    fileloc = Column(String(2000))
+    relative_fileloc = Column(String(2000))
+    bundle_name = Column(StringID(), ForeignKey("dag_bundle.name"), nullable=False)
     # The version of the bundle the last time the DAG was processed
-    bundle_version: Mapped[str | None] = mapped_column(String(200), nullable=True)
+    bundle_version = Column(String(200), nullable=True)
     # String representing the owners
-    owners: Mapped[str] = mapped_column(String(2000))
+    owners = Column(String(2000))
     # Display name of the dag
-    _dag_display_property_value: Mapped[str | None] = mapped_column(
-        "dag_display_name", String(2000), nullable=True
-    )
+    _dag_display_property_value = Column("dag_display_name", String(2000), nullable=True)
     # Description of the dag
-    description: Mapped[str] = mapped_column(Text)
+    description = Column(Text)
     # Timetable summary
-    timetable_summary: Mapped[str | None] = mapped_column(Text, nullable=True)
+    timetable_summary = Column(Text, nullable=True)
     # Timetable description
-    timetable_description: Mapped[str | None] = mapped_column(String(1000), nullable=True)
+    timetable_description = Column(String(1000), nullable=True)
     # Asset expression based on asset triggers
-    asset_expression: Mapped[dict[str, Any] | None] = mapped_column(
-        sqlalchemy_jsonfield.JSONField(json=json), nullable=True
-    )
+    asset_expression = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
     # DAG deadline information
-    _deadline: Mapped[dict[str, Any] | None] = mapped_column(
-        "deadline", sqlalchemy_jsonfield.JSONField(json=json), nullable=True
-    )
+    _deadline = Column("deadline", sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
     # Tags for view filter
     tags = relationship("DagTag", cascade="all, delete, delete-orphan", backref=backref("dag"))
     # Dag owner links for DAGs view
@@ -364,24 +359,22 @@ class DagModel(Base):
         "DagOwnerAttributes", cascade="all, delete, delete-orphan", backref=backref("dag")
     )
-    max_active_tasks: Mapped[int] = mapped_column(Integer, nullable=False)
-    max_active_runs: Mapped[int | None] = mapped_column(
-        Integer, nullable=True
-    )  # todo: should not be nullable if we have a default
-    max_consecutive_failed_dag_runs: Mapped[int] = mapped_column(Integer, nullable=False)
+    max_active_tasks = Column(Integer, nullable=False)
+    max_active_runs = Column(Integer, nullable=True)  # todo: should not be nullable if we have a default
+    max_consecutive_failed_dag_runs = Column(Integer, nullable=False)
-    has_task_concurrency_limits: Mapped[bool] = mapped_column(Boolean, nullable=False)
-    has_import_errors: Mapped[bool] = mapped_column(Boolean(), default=False, server_default="0")
+    has_task_concurrency_limits = Column(Boolean, nullable=False)
+    has_import_errors = Column(Boolean(), default=False, server_default="0")
     # The logical date of the next dag run.
-    next_dagrun: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
+    next_dagrun = Column(UtcDateTime)
     # Must be either both NULL or both datetime.
-    next_dagrun_data_interval_start: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
-    next_dagrun_data_interval_end: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
+    next_dagrun_data_interval_start = Column(UtcDateTime)
+    next_dagrun_data_interval_end = Column(UtcDateTime)
     # Earliest time at which this ``next_dagrun`` can be created.
-    next_dagrun_create_after: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
+    next_dagrun_create_after = Column(UtcDateTime)
     __table_args__ = (Index("idx_next_dagrun_create_after", next_dagrun_create_after, unique=False),)
diff --git a/airflow-core/src/airflow/models/dag_favorite.py b/airflow-core/src/airflow/models/dag_favorite.py
index 700cf8575bf00..5dfb742fdaf80 100644
--- a/airflow-core/src/airflow/models/dag_favorite.py
+++ b/airflow-core/src/airflow/models/dag_favorite.py
@@ -17,11 +17,9 @@
 # under the License.
 from __future__ import annotations
-from sqlalchemy import ForeignKey
-from sqlalchemy.orm import Mapped  # noqa: TC002
+from sqlalchemy import Column, ForeignKey
 from airflow.models.base import Base, StringID
-from airflow.utils.sqlalchemy import mapped_column
 class DagFavorite(Base):
@@ -29,7 +27,5 @@ class DagFavorite(Base):
     __tablename__ = "dag_favorite"
-    user_id: Mapped[str] = mapped_column(StringID(), primary_key=True)
-    dag_id: Mapped[str] = mapped_column(
-        StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"), primary_key=True
-    )
+    user_id = Column(StringID(), primary_key=True)
+    dag_id = Column(StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"), primary_key=True)
diff --git a/airflow-core/src/airflow/models/dag_version.py b/airflow-core/src/airflow/models/dag_version.py
index bfb210a3d5157..a51f3cb301867 100644
--- a/airflow-core/src/airflow/models/dag_version.py
+++ b/airflow-core/src/airflow/models/dag_version.py
@@ -21,20 +21,19 @@
 from typing import TYPE_CHECKING
 import uuid6
-from sqlalchemy import ForeignKey, Integer, UniqueConstraint, select
-from sqlalchemy.orm import Mapped, joinedload, relationship
+from sqlalchemy import Column, ForeignKey, Integer, UniqueConstraint, select
+from sqlalchemy.orm import joinedload, relationship
 from sqlalchemy_utils import UUIDType
 from airflow._shared.timezones import timezone
 from airflow.models.base import Base, StringID
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
     from sqlalchemy.sql import Select
-
 log = logging.getLogger(__name__)
@@ -42,14 +41,12 @@ class DagVersion(Base):
     """Model to track the versions of DAGs in the database."""
     __tablename__ = "dag_version"
-    id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
-    version_number: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
-    dag_id: Mapped[str] = mapped_column(
-        StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"), nullable=False
-    )
+    id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
+    version_number = Column(Integer, nullable=False, default=1)
+    dag_id = Column(StringID(), ForeignKey("dag.dag_id", ondelete="CASCADE"), nullable=False)
     dag_model = relationship("DagModel", back_populates="dag_versions")
-    bundle_name: Mapped[str | None] = mapped_column(StringID(), nullable=True)
-    bundle_version: Mapped[str] = mapped_column(StringID())
+    bundle_name = Column(StringID(), nullable=True)
+    bundle_version = Column(StringID())
     dag_code = relationship(
         "DagCode",
         back_populates="dag_version",
@@ -65,10 +62,8 @@ class DagVersion(Base):
         cascade_backrefs=False,
     )
     task_instances = relationship("TaskInstance", back_populates="dag_version")
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow)
-    last_updated: Mapped[UtcDateTime] = mapped_column(
-        UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow
-    )
+    created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
+    last_updated = Column(UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow)
     __table_args__ = (
         UniqueConstraint("dag_id", "version_number", name="dag_id_v_name_v_number_unique_constraint"),
diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py
index a7d76bd1f5307..928e8e406d4ec 100644
--- a/airflow-core/src/airflow/models/dagbag.py
+++ b/airflow-core/src/airflow/models/dagbag.py
@@ -20,13 +20,12 @@
 import hashlib
 from typing import TYPE_CHECKING, Any
-from sqlalchemy import String, inspect, select
-from sqlalchemy.orm import Mapped, joinedload
+from sqlalchemy import Column, String, inspect, select
+from sqlalchemy.orm import joinedload
 from sqlalchemy.orm.attributes import NO_VALUE
 from airflow.models.base import Base, StringID
 from airflow.models.dag_version import DagVersion
-from airflow.utils.sqlalchemy import mapped_column
 if TYPE_CHECKING:
     from collections.abc import Generator
@@ -116,16 +115,14 @@ class DagPriorityParsingRequest(Base):
     # Adding a unique constraint to fileloc results in the creation of an index and we have a limitation
     # on the size of the string we can use in the index for MySQL DB. We also have to keep the fileloc
     # size consistent with other tables. This is a workaround to enforce the unique constraint.
-    id: Mapped[str] = mapped_column(
-        String(32), primary_key=True, default=generate_md5_hash, onupdate=generate_md5_hash
-    )
+    id = Column(String(32), primary_key=True, default=generate_md5_hash, onupdate=generate_md5_hash)
-    bundle_name: Mapped[str] = mapped_column(StringID(), nullable=False)
+    bundle_name = Column(StringID(), nullable=False)
     # The location of the file containing the DAG object
     # Note: Do not depend on fileloc pointing to a file; in the case of a
     # packaged DAG, it will point to the subpath of the DAG within the
     # associated zip.
-    relative_fileloc: Mapped[str] = mapped_column(String(2000), nullable=False)
+    relative_fileloc = Column(String(2000), nullable=False)
     def __init__(self, bundle_name: str, relative_fileloc: str) -> None:
         super().__init__()
diff --git a/airflow-core/src/airflow/models/dagbundle.py b/airflow-core/src/airflow/models/dagbundle.py
index 95302ecfa479c..91e3f97c402b3 100644
--- a/airflow-core/src/airflow/models/dagbundle.py
+++ b/airflow-core/src/airflow/models/dagbundle.py
@@ -16,14 +16,14 @@
 # under the License.
 from __future__ import annotations
-from sqlalchemy import Boolean, String
-from sqlalchemy.orm import Mapped, relationship
+from sqlalchemy import Boolean, Column, String
+from sqlalchemy.orm import relationship
 from sqlalchemy_utils import JSONType
 from airflow.models.base import Base, StringID
 from airflow.models.team import dag_bundle_team_association_table
 from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
+from airflow.utils.sqlalchemy import UtcDateTime
 class DagBundleModel(Base, LoggingMixin):
@@ -42,12 +42,12 @@ class DagBundleModel(Base, LoggingMixin):
     """
     __tablename__ = "dag_bundle"
-    name: Mapped[str] = mapped_column(StringID(length=250), primary_key=True, nullable=False)
-    active: Mapped[bool] = mapped_column(Boolean, default=True)
-    version: Mapped[str | None] = mapped_column(String(200), nullable=True)
-    last_refreshed: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
-    signed_url_template: Mapped[str | None] = mapped_column(String(200), nullable=True)
-    template_params: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
+    name = Column(StringID(length=250), primary_key=True, nullable=False)
+    active = Column(Boolean, default=True)
+    version = Column(String(200), nullable=True)
+    last_refreshed = Column(UtcDateTime, nullable=True)
+    signed_url_template = Column(String(200), nullable=True)
+    template_params = Column(JSONType, nullable=True)
     teams = relationship("Team", secondary=dag_bundle_team_association_table, back_populates="dag_bundles")
     def __init__(self, *, name: str, version: str | None = None):
diff --git a/airflow-core/src/airflow/models/dagcode.py b/airflow-core/src/airflow/models/dagcode.py
index 5be3c5395f274..a68885e1dfdf2 100644
--- a/airflow-core/src/airflow/models/dagcode.py
+++ b/airflow-core/src/airflow/models/dagcode.py
@@ -20,9 +20,9 @@
 from typing import TYPE_CHECKING
 import uuid6
-from sqlalchemy import ForeignKey, String, Text, select
+from sqlalchemy import Column, ForeignKey, String, Text, select
 from sqlalchemy.dialects.mysql import MEDIUMTEXT
-from sqlalchemy.orm import Mapped, relationship
+from sqlalchemy.orm import relationship
 from sqlalchemy.sql.expression import literal
 from sqlalchemy_utils import UUIDType
@@ -33,7 +33,7 @@
 from airflow.utils.file import open_maybe_zipped
 from airflow.utils.hashlib_wrapper import md5
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
+from airflow.utils.sqlalchemy import UtcDateTime
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -54,17 +54,15 @@ class DagCode(Base):
     """
     __tablename__ = "dag_code"
-    id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
-    dag_id: Mapped[str] = mapped_column(String(ID_LEN), nullable=False)
-    fileloc: Mapped[str] = mapped_column(String(2000), nullable=False)
+    id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
+    dag_id = Column(String(ID_LEN), nullable=False)
+    fileloc = Column(String(2000), nullable=False)
     # The max length of fileloc exceeds the limit of indexing.
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow)
-    last_updated: Mapped[UtcDateTime] = mapped_column(
-        UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow
-    )
-    source_code: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)
-    source_code_hash: Mapped[str] = mapped_column(String(32), nullable=False)
-    dag_version_id: Mapped[str] = mapped_column(
+    created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
+    last_updated = Column(UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow)
+    source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)
+    source_code_hash = Column(String(32), nullable=False)
+    dag_version_id = Column(
         UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="CASCADE"), nullable=False, unique=True
     )
     dag_version = relationship("DagVersion", back_populates="dag_code", uselist=False)
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 0c8ff67d2312b..b5d0a715b7171 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -29,6 +29,7 @@
 from natsort import natsorted
 from sqlalchemy import (
     JSON,
+    Column,
     Enum,
     ForeignKey,
     ForeignKeyConstraint,
@@ -51,7 +52,7 @@
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.orm import Mapped, declared_attr, joinedload, relationship, synonym, validates
+from sqlalchemy.orm import declared_attr, joinedload, relationship, synonym, validates
 from sqlalchemy.sql.expression import false, select
 from sqlalchemy.sql.functions import coalesce
 from sqlalchemy_utils import UUIDType
@@ -79,7 +80,7 @@
 from airflow.utils.retries import retry_db_transaction
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.span_status import SpanStatus
-from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, mapped_column, nulls_first, with_row_locks
+from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, nulls_first, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.strings import get_random_string
 from airflow.utils.thread_safe_dict import ThreadSafeDict
@@ -147,61 +148,57 @@ class DagRun(Base, LoggingMixin):
     __tablename__ = "dag_run"
-    id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
-    queued_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
-    logical_date: Mapped[datetime | None] = mapped_column(UtcDateTime, nullable=True)
-    start_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
-    end_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
-    _state: Mapped[str] = mapped_column("state", String(50), default=DagRunState.QUEUED)
-    run_id: Mapped[str] = mapped_column(StringID(), nullable=False)
-    creating_job_id: Mapped[int] = mapped_column(Integer)
-    run_type: Mapped[str] = mapped_column(String(50), nullable=False)
-    triggered_by: Mapped[DagRunTriggeredByType] = mapped_column(
+    id = Column(Integer, primary_key=True)
+    dag_id = Column(StringID(), nullable=False)
+    queued_at = Column(UtcDateTime)
+    logical_date = Column(UtcDateTime, nullable=True)
+    start_date = Column(UtcDateTime)
+    end_date = Column(UtcDateTime)
+    _state = Column("state", String(50), default=DagRunState.QUEUED)
+    run_id = Column(StringID(), nullable=False)
+    creating_job_id = Column(Integer)
+    run_type = Column(String(50), nullable=False)
+    triggered_by = Column(
         Enum(DagRunTriggeredByType, native_enum=False, length=50)
     )  # Airflow component that triggered the run.
-    triggering_user_name: Mapped[str | None] = mapped_column(
+    triggering_user_name = Column(
         String(512),
         nullable=True,
     )  # The user that triggered the DagRun, if applicable
-    conf: Mapped[dict[str, Any]] = mapped_column(JSON().with_variant(postgresql.JSONB, "postgresql"))
+    conf = Column(JSON().with_variant(postgresql.JSONB, "postgresql"))
     # These two must be either both NULL or both datetime.
-    data_interval_start: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
-    data_interval_end: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
+    data_interval_start = Column(UtcDateTime)
+    data_interval_end = Column(UtcDateTime)
     # Earliest time when this DagRun can start running.
-    run_after: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=_default_run_after, nullable=False)
+    run_after = Column(UtcDateTime, default=_default_run_after, nullable=False)
     # When a scheduler last attempted to schedule TIs for this DagRun
-    last_scheduling_decision: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
+    last_scheduling_decision = Column(UtcDateTime)
     # Foreign key to LogTemplate. DagRun rows created prior to this column's
     # existence have this set to NULL. Later rows automatically populate this on
     # insert to point to the latest LogTemplate entry.
-    log_template_id: Mapped[int] = mapped_column(
+    log_template_id = Column(
         Integer,
         ForeignKey("log_template.id", name="task_instance_log_template_id_fkey", ondelete="NO ACTION"),
         default=select(func.max(LogTemplate.__table__.c.id)),
     )
-    updated_at: Mapped[UtcDateTime] = mapped_column(
-        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow
-    )
+    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
     # Keeps track of the number of times the dagrun had been cleared.
     # This number is incremented only when the DagRun is re-Queued,
     # when the DagRun is cleared.
-    clear_number: Mapped[int] = mapped_column(Integer, default=0, nullable=False, server_default="0")
-    backfill_id: Mapped[int | None] = mapped_column(Integer, ForeignKey("backfill.id"), nullable=True)
+    clear_number = Column(Integer, default=0, nullable=False, server_default="0")
+    backfill_id = Column(Integer, ForeignKey("backfill.id"), nullable=True)
     """
     The backfill this DagRun is currently associated with.
     It's possible this could change if e.g. the dag run is cleared to be rerun, or perhaps re-backfilled.
     """
-    bundle_version: Mapped[str] = mapped_column(StringID())
+    bundle_version = Column(StringID())
-    scheduled_by_job_id: Mapped[int] = mapped_column(Integer)
+    scheduled_by_job_id = Column(Integer)
     # Span context carrier, used for context propagation.
-    context_carrier: Mapped[dict[str, Any]] = mapped_column(MutableDict.as_mutable(ExtendedJSON))
-    span_status: Mapped[str] = mapped_column(
-        String(250), server_default=SpanStatus.NOT_STARTED, nullable=False
-    )
-    created_dag_version_id: Mapped[str | None] = mapped_column(
+    context_carrier = Column(MutableDict.as_mutable(ExtendedJSON))
+    span_status = Column(String(250), server_default=SpanStatus.NOT_STARTED, nullable=False)
+    created_dag_version_id = Column(
         UUIDType(binary=False),
         ForeignKey("dag_version.id", name="created_dag_version_id_fkey", ondelete="set null"),
         nullable=True,
@@ -2113,13 +2110,11 @@ class DagRunNote(Base):
     __tablename__ = "dag_run_note"
-    user_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
-    dag_run_id: Mapped[int] = mapped_column(Integer, primary_key=True, nullable=False)
-    content: Mapped[str] = mapped_column(String(1000).with_variant(Text(1000), "mysql"))
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
-    updated_at: Mapped[UtcDateTime] = mapped_column(
-        UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
-    )
+    user_id = Column(String(128), nullable=True)
+    dag_run_id = Column(Integer, primary_key=True, nullable=False)
+    content = Column(String(1000).with_variant(Text(1000), "mysql"))
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
     dag_run = relationship("DagRun", back_populates="dag_run_note")
diff --git a/airflow-core/src/airflow/models/dagwarning.py b/airflow-core/src/airflow/models/dagwarning.py
index cec81ce8abeaf..f11b5037a9090 100644
--- a/airflow-core/src/airflow/models/dagwarning.py
+++ b/airflow-core/src/airflow/models/dagwarning.py
@@ -20,15 +20,15 @@
 from enum import Enum
 from typing import TYPE_CHECKING
-from sqlalchemy import ForeignKeyConstraint, Index, String, Text, delete, select, true
-from sqlalchemy.orm import Mapped, relationship
+from sqlalchemy import Column, ForeignKeyConstraint, Index, String, Text, delete, select, true
+from sqlalchemy.orm import relationship
 from airflow._shared.timezones import timezone
 from airflow.models.base import Base, StringID
 from airflow.models.dag import DagModel
 from airflow.utils.retries import retry_db_transaction
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
+from airflow.utils.sqlalchemy import UtcDateTime
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -43,10 +43,10 @@ class DagWarning(Base):
     when parsing DAG and displayed on the Webserver in a flash message.
""" - dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True) - warning_type: Mapped[str] = mapped_column(String(50), primary_key=True) - message: Mapped[str] = mapped_column(Text, nullable=False) - timestamp: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow) + dag_id = Column(StringID(), primary_key=True) + warning_type = Column(String(50), primary_key=True) + message = Column(Text, nullable=False) + timestamp = Column(UtcDateTime, nullable=False, default=timezone.utcnow) dag_model = relationship("DagModel", viewonly=True, lazy="selectin") diff --git a/airflow-core/src/airflow/models/db_callback_request.py b/airflow-core/src/airflow/models/db_callback_request.py index 0dc2a287ca805..f1009ca1babe4 100644 --- a/airflow-core/src/airflow/models/db_callback_request.py +++ b/airflow-core/src/airflow/models/db_callback_request.py @@ -20,12 +20,11 @@ from importlib import import_module from typing import TYPE_CHECKING -from sqlalchemy import Integer, String -from sqlalchemy.orm import Mapped # noqa: TC002 +from sqlalchemy import Column, Integer, String from airflow._shared.timezones import timezone from airflow.models.base import Base -from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, mapped_column +from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime if TYPE_CHECKING: from airflow.callbacks.callback_requests import CallbackRequest @@ -36,11 +35,11 @@ class DbCallbackRequest(Base): __tablename__ = "callback_request" - id: Mapped[int] = mapped_column(Integer(), nullable=False, primary_key=True) - created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) - priority_weight: Mapped[int] = mapped_column(Integer(), nullable=False) - callback_data: Mapped[dict] = mapped_column(ExtendedJSON, nullable=False) - callback_type: Mapped[str] = mapped_column(String(20), nullable=False) + id = Column(Integer(), nullable=False, primary_key=True) + created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + priority_weight = Column(Integer(), nullable=False) + callback_data = Column(ExtendedJSON, nullable=False) + callback_type = Column(String(20), nullable=False) def __init__(self, priority_weight: int, callback: CallbackRequest): self.created_at = timezone.utcnow() diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index f91b3228bc1d2..21e49a36ed1e7 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -26,9 +26,9 @@ import sqlalchemy_jsonfield import uuid6 -from sqlalchemy import ForeignKey, Index, Integer, String, and_, func, select, text +from sqlalchemy import Column, ForeignKey, Index, Integer, String, and_, func, select, text from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import Mapped, relationship +from sqlalchemy.orm import relationship from sqlalchemy_utils import UUIDType from airflow._shared.timezones import timezone @@ -40,7 +40,7 @@ from airflow.triggers.deadline import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY, DeadlineCallbackTrigger from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import UtcDateTime, mapped_column +from airflow.utils.sqlalchemy import UtcDateTime if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -97,24 +97,22 @@ class Deadline(Base): __tablename__ = "deadline" - id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, 
default=uuid6.uuid7) + id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7) # If the Deadline Alert is for a DAG, store the DAG run ID from the dag_run. - dagrun_id: Mapped[int] = mapped_column(Integer, ForeignKey("dag_run.id", ondelete="CASCADE")) + dagrun_id = Column(Integer, ForeignKey("dag_run.id", ondelete="CASCADE")) # The time after which the Deadline has passed and the callback should be triggered. - deadline_time: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False) + deadline_time = Column(UtcDateTime, nullable=False) # The (serialized) callback to be called when the Deadline has passed. - _callback: Mapped[dict] = mapped_column( - "callback", sqlalchemy_jsonfield.JSONField(json=json), nullable=False - ) + _callback = Column("callback", sqlalchemy_jsonfield.JSONField(json=json), nullable=False) # The state of the deadline callback - callback_state: Mapped[str] = mapped_column(String(20)) + callback_state = Column(String(20)) dagrun = relationship("DagRun", back_populates="deadlines") # The Trigger where the callback is running - trigger_id: Mapped[int | None] = mapped_column(Integer, ForeignKey("trigger.id"), nullable=True) + trigger_id = Column(Integer, ForeignKey("trigger.id"), nullable=True) trigger = relationship("Trigger", back_populates="deadline") __table_args__ = (Index("deadline_callback_state_time_idx", callback_state, deadline_time, unique=False),) @@ -147,7 +145,7 @@ def _determine_resource() -> tuple[str, str]: ) @classmethod - def prune_deadlines(cls, *, session: Session, conditions: dict[Mapped, Any]) -> int: + def prune_deadlines(cls, *, session: Session, conditions: dict[Column, Any]) -> int: """ Remove deadlines from the table which match the provided conditions and return the number removed. @@ -481,7 +479,7 @@ def deserialize_reference(cls, reference_data: dict): @provide_session -def _fetch_from_db(model_reference: Mapped, session=None, **conditions) -> datetime: +def _fetch_from_db(model_reference: Column, session=None, **conditions) -> datetime: """ Fetch a datetime value from the database using the provided model reference and filtering conditions. diff --git a/airflow-core/src/airflow/models/errors.py b/airflow-core/src/airflow/models/errors.py index 0dbd8d5de6b76..6670df1dfaf62 100644 --- a/airflow-core/src/airflow/models/errors.py +++ b/airflow-core/src/airflow/models/errors.py @@ -17,23 +17,22 @@ # under the License. 
 from __future__ import annotations
-from sqlalchemy import Integer, String, Text
-from sqlalchemy.orm import Mapped  # noqa: TC002
+from sqlalchemy import Column, Integer, String, Text
 from airflow.dag_processing.bundles.manager import DagBundlesManager
 from airflow.models.base import Base, StringID
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
+from airflow.utils.sqlalchemy import UtcDateTime
 class ParseImportError(Base):
     """Stores all Import Errors which are recorded when parsing DAGs and displayed on the Webserver."""
     __tablename__ = "import_error"
-    id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    timestamp: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
-    filename: Mapped[str] = mapped_column(String(1024))
-    bundle_name: Mapped[str] = mapped_column(StringID())
-    stacktrace: Mapped[str] = mapped_column(Text)
+    id = Column(Integer, primary_key=True)
+    timestamp = Column(UtcDateTime)
+    filename = Column(String(1024))
+    bundle_name = Column(StringID())
+    stacktrace = Column(Text)
     def full_file_path(self) -> str:
         """Return the full file path of the dag."""
diff --git a/airflow-core/src/airflow/models/hitl.py b/airflow-core/src/airflow/models/hitl.py
index da31d3be52204..b6bbb2bc402b0 100644
--- a/airflow-core/src/airflow/models/hitl.py
+++ b/airflow-core/src/airflow/models/hitl.py
@@ -19,17 +19,17 @@
 from typing import TYPE_CHECKING, Any, TypedDict
 import sqlalchemy_jsonfield
-from sqlalchemy import Boolean, ForeignKeyConstraint, String, Text, func, literal
+from sqlalchemy import Boolean, Column, ForeignKeyConstraint, String, Text, func, literal
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import Mapped, relationship
+from sqlalchemy.orm import relationship
 from sqlalchemy.sql.functions import FunctionElement
 from airflow._shared.timezones import timezone
 from airflow.models.base import Base
 from airflow.settings import json
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
+from airflow.utils.sqlalchemy import UtcDateTime
 if TYPE_CHECKING:
     from sqlalchemy.sql import ColumnElement
@@ -84,37 +84,31 @@ class HITLDetail(Base):
     """Human-in-the-loop request and corresponding response."""
     __tablename__ = "hitl_detail"
-    ti_id: Mapped[str] = mapped_column(
+    ti_id = Column(
         String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"),
         primary_key=True,
         nullable=False,
     )
     # User Request Detail
-    options: Mapped[dict] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
-    subject: Mapped[str] = mapped_column(Text, nullable=False)
-    body: Mapped[str | None] = mapped_column(Text, nullable=True)
-    defaults: Mapped[dict | None] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
-    multiple: Mapped[bool] = mapped_column(Boolean, unique=False, default=False)
-    params: Mapped[dict] = mapped_column(
-        sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}
-    )
-    assignees: Mapped[dict | None] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+    options = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
+    subject = Column(Text, nullable=False)
+    body = Column(Text, nullable=True)
+    defaults = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
+    multiple = Column(Boolean, unique=False, default=False)
+    params = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
+    assignees = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
+    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
     # Response Content Detail
-    responded_at: Mapped[UtcDateTime | None] = mapped_column(UtcDateTime, nullable=True)
-    responded_by: Mapped[dict | None] = mapped_column(
-        sqlalchemy_jsonfield.JSONField(json=json), nullable=True
-    )
-    chosen_options: Mapped[dict | None] = mapped_column(
+    responded_at = Column(UtcDateTime, nullable=True)
+    responded_by = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
+    chosen_options = Column(
         sqlalchemy_jsonfield.JSONField(json=json),
         nullable=True,
         default=None,
     )
-    params_input: Mapped[dict] = mapped_column(
-        sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}
-    )
+    params_input = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
     task_instance = relationship(
         "TaskInstance",
         lazy="joined",
diff --git a/airflow-core/src/airflow/models/log.py b/airflow-core/src/airflow/models/log.py
index fee291ce97689..645c3ee6c7783 100644
--- a/airflow-core/src/airflow/models/log.py
+++ b/airflow-core/src/airflow/models/log.py
@@ -19,12 +19,12 @@
 from typing import TYPE_CHECKING
-from sqlalchemy import Index, Integer, String, Text
-from sqlalchemy.orm import Mapped, relationship
+from sqlalchemy import Column, Index, Integer, String, Text
+from sqlalchemy.orm import relationship
 from airflow._shared.timezones import timezone
 from airflow.models.base import Base, StringID
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
+from airflow.utils.sqlalchemy import UtcDateTime
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
@@ -36,18 +36,18 @@ class Log(Base):
     __tablename__ = "log"
-    id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    dttm: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
-    dag_id: Mapped[str] = mapped_column(StringID())
-    task_id: Mapped[str] = mapped_column(StringID())
-    map_index: Mapped[int] = mapped_column(Integer)
-    event: Mapped[str] = mapped_column(String(60))
-    logical_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime)
-    run_id: Mapped[str] = mapped_column(StringID())
-    owner: Mapped[str] = mapped_column(String(500))
-    owner_display_name: Mapped[str] = mapped_column(String(500))
-    extra: Mapped[str] = mapped_column(Text)
-    try_number: Mapped[int] = mapped_column(Integer)
+    id = Column(Integer, primary_key=True)
+    dttm = Column(UtcDateTime)
+    dag_id = Column(StringID())
+    task_id = Column(StringID())
+    map_index = Column(Integer)
+    event = Column(String(60))
+    logical_date = Column(UtcDateTime)
+    run_id = Column(StringID())
+    owner = Column(String(500))
+    owner_display_name = Column(String(500))
+    extra = Column(Text)
+    try_number = Column(Integer)
     dag_model = relationship(
         "DagModel",
diff --git a/airflow-core/src/airflow/models/pool.py b/airflow-core/src/airflow/models/pool.py
index 058a14f7883b6..6f9f159c51857 100644
--- a/airflow-core/src/airflow/models/pool.py
+++ b/airflow-core/src/airflow/models/pool.py
@@ -19,8 +19,7 @@
 from typing import TYPE_CHECKING, Any, TypedDict
-from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, func, select
-from sqlalchemy.orm import Mapped  # noqa: TC002
+from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text, func, select
 from sqlalchemy_utils import UUIDType
 from airflow.exceptions import AirflowException, PoolNotFound
@@ -29,7 +28,7 @@
 from airflow.ti_deps.dependencies_states import EXECUTION_STATES
 from airflow.utils.db import exists_query
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import mapped_column, with_row_locks
+from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import TaskInstanceState
 if TYPE_CHECKING:
@@ -52,13 +51,13 @@ class Pool(Base):
     __tablename__ = "slot_pool"
-    id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    pool: Mapped[str] = mapped_column(String(256), unique=True)
+    id = Column(Integer, primary_key=True)
+    pool = Column(String(256), unique=True)
     # -1 for infinite
-    slots: Mapped[int] = mapped_column(Integer, default=0)
-    description: Mapped[str] = mapped_column(Text)
-    include_deferred: Mapped[bool] = mapped_column(Boolean, nullable=False)
-    team_id: Mapped[str | None] = mapped_column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True)
+    slots = Column(Integer, default=0)
+    description = Column(Text)
+    include_deferred = Column(Boolean, nullable=False)
+    team_id = Column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True)
     DEFAULT_POOL_NAME = "default_pool"
diff --git a/airflow-core/src/airflow/models/renderedtifields.py b/airflow-core/src/airflow/models/renderedtifields.py
index 1f55d946cb0d2..b19b7f1f6edf9 100644
--- a/airflow-core/src/airflow/models/renderedtifields.py
+++ b/airflow-core/src/airflow/models/renderedtifields.py
@@ -24,6 +24,7 @@
 import sqlalchemy_jsonfield
 from sqlalchemy import (
+    Column,
     ForeignKeyConstraint,
     Integer,
     PrimaryKeyConstraint,
@@ -33,7 +34,7 @@
     text,
 )
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.orm import Mapped, relationship
+from sqlalchemy.orm import relationship
 from airflow.configuration import conf
 from airflow.models.base import StringID, TaskInstanceDependencies
@@ -41,7 +42,6 @@
 from airflow.settings import json
 from airflow.utils.retries import retry_db_transaction
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import mapped_column
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -69,14 +69,12 @@ class RenderedTaskInstanceFields(TaskInstanceDependencies):
     __tablename__ = "rendered_task_instance_fields"
-    dag_id: Mapped[str] = mapped_column(StringID(), primary_key=True)
-    task_id: Mapped[str] = mapped_column(StringID(), primary_key=True)
-    run_id: Mapped[str] = mapped_column(StringID(), primary_key=True)
-    map_index: Mapped[int] = mapped_column(Integer, primary_key=True, server_default=text("-1"))
-    rendered_fields: Mapped[dict] = mapped_column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
-    k8s_pod_yaml: Mapped[dict | None] = mapped_column(
-        sqlalchemy_jsonfield.JSONField(json=json), nullable=True
-    )
+    dag_id = Column(StringID(), primary_key=True)
+    task_id = Column(StringID(), primary_key=True)
+    run_id = Column(StringID(), primary_key=True)
+    map_index = Column(Integer, primary_key=True, server_default=text("-1"))
+    rendered_fields = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
+    k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
     __table_args__ = (
         PrimaryKeyConstraint(
diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py
index c0bf5be286009..32a44baaffccf 100644
--- a/airflow-core/src/airflow/models/serialized_dag.py
+++ b/airflow-core/src/airflow/models/serialized_dag.py
@@ -27,9 +27,9 @@
 import sqlalchemy_jsonfield
 import uuid6
-from sqlalchemy import ForeignKey, LargeBinary, String, select, tuple_
+from sqlalchemy import Column, ForeignKey, LargeBinary, String, select, tuple_
 from sqlalchemy.dialects.postgresql import JSONB
-from sqlalchemy.orm import Mapped, backref, foreign, relationship
+from sqlalchemy.orm import backref, foreign, relationship
 from sqlalchemy.sql.expression import func, literal
 from sqlalchemy_utils import UUIDType
@@ -49,7 +49,7 @@
 from airflow.settings import COMPRESS_SERIALIZED_DAGS, json
 from airflow.utils.hashlib_wrapper import md5
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
+from airflow.utils.sqlalchemy import UtcDateTime
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
@@ -280,17 +280,15 @@ class SerializedDagModel(Base):
     """
     __tablename__ = "serialized_dag"
-    id: Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
-    dag_id: Mapped[str] = mapped_column(String(ID_LEN), nullable=False)
-    _data: Mapped[dict | None] = mapped_column(
+    id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7)
+    dag_id = Column(String(ID_LEN), nullable=False)
+    _data = Column(
         "data", sqlalchemy_jsonfield.JSONField(json=json).with_variant(JSONB, "postgresql"), nullable=True
     )
-    _data_compressed: Mapped[bytes | None] = mapped_column("data_compressed", LargeBinary, nullable=True)
-    created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow)
-    last_updated: Mapped[UtcDateTime] = mapped_column(
-        UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow
-    )
-    dag_hash: Mapped[str] = mapped_column(String(32), nullable=False)
+    _data_compressed = Column("data_compressed", LargeBinary, nullable=True)
+    created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
+    last_updated = Column(UtcDateTime, nullable=False, default=timezone.utcnow, onupdate=timezone.utcnow)
+    dag_hash = Column(String(32), nullable=False)
     dag_runs = relationship(
         DagRun,
@@ -306,7 +304,7 @@ class SerializedDagModel(Base):
         innerjoin=True,
         backref=backref("serialized_dag", uselist=False, innerjoin=True),
     )
-    dag_version_id: Mapped[str] = mapped_column(
+    dag_version_id = Column(
         UUIDType(binary=False),
         ForeignKey("dag_version.id", ondelete="CASCADE"),
         nullable=False,
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 90423914da586..e9806696e243b 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -25,7 +25,7 @@
 import uuid
 from collections import defaultdict
 from collections.abc import Collection, Iterable
-from datetime import datetime, timedelta
+from datetime import timedelta
 from functools import cache
 from typing import TYPE_CHECKING, Any
 from urllib.parse import quote
@@ -35,6 +35,7 @@
 import lazy_object_proxy
 import uuid6
 from sqlalchemy import (
+    Column,
     Float,
     ForeignKey,
     ForeignKeyConstraint,
@@ -61,7 +62,7 @@
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.orm import Mapped, lazyload, reconstructor, relationship
+from sqlalchemy.orm import lazyload, reconstructor, relationship
 from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
 from sqlalchemy_utils import UUIDType
@@ -92,7 +93,7 @@
 from airflow.utils.retries import run_with_db_retries
 from 
airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.span_status import SpanStatus -from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime, mapped_column +from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, UtcDateTime from airflow.utils.state import DagRunState, State, TaskInstanceState TR = TaskReschedule @@ -373,65 +374,59 @@ class TaskInstance(Base, LoggingMixin): """ __tablename__ = "task_instance" - id: Mapped[str] = mapped_column( + id = Column( String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"), primary_key=True, default=uuid7, nullable=False, ) - task_id: Mapped[str] = mapped_column(StringID(), nullable=False) - dag_id: Mapped[str] = mapped_column(StringID(), nullable=False) - run_id: Mapped[str] = mapped_column(StringID(), nullable=False) - map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("-1")) - - start_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime) - end_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime) - duration: Mapped[float] = mapped_column(Float) - state: Mapped[str] = mapped_column(String(20)) - try_number: Mapped[int] = mapped_column(Integer, default=0) - max_tries: Mapped[int] = mapped_column(Integer, server_default=text("-1")) - hostname: Mapped[str] = mapped_column(String(1000)) - unixname: Mapped[str] = mapped_column(String(1000)) - pool: Mapped[str] = mapped_column(String(256), nullable=False) - pool_slots: Mapped[int] = mapped_column(Integer, default=1, nullable=False) - queue: Mapped[str] = mapped_column(String(256)) - priority_weight: Mapped[int] = mapped_column(Integer) - operator: Mapped[str] = mapped_column(String(1000)) - custom_operator_name: Mapped[str] = mapped_column(String(1000)) - queued_dttm: Mapped[UtcDateTime] = mapped_column(UtcDateTime) - scheduled_dttm: Mapped[UtcDateTime] = mapped_column(UtcDateTime) - queued_by_job_id: Mapped[int] = mapped_column(Integer) - - last_heartbeat_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime) - pid: Mapped[int] = mapped_column(Integer) - executor: Mapped[str] = mapped_column(String(1000)) - executor_config: Mapped[dict] = mapped_column(ExecutorConfigType(pickler=dill)) - updated_at: Mapped[UtcDateTime] = mapped_column( - UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow - ) - _rendered_map_index: Mapped[str] = mapped_column("rendered_map_index", String(250)) - context_carrier: Mapped[dict] = mapped_column(MutableDict.as_mutable(ExtendedJSON)) - span_status: Mapped[str] = mapped_column( - String(250), server_default=SpanStatus.NOT_STARTED, nullable=False - ) - - external_executor_id: Mapped[str] = mapped_column(StringID()) + task_id = Column(StringID(), nullable=False) + dag_id = Column(StringID(), nullable=False) + run_id = Column(StringID(), nullable=False) + map_index = Column(Integer, nullable=False, server_default=text("-1")) + + start_date = Column(UtcDateTime) + end_date = Column(UtcDateTime) + duration = Column(Float) + state = Column(String(20)) + try_number = Column(Integer, default=0) + max_tries = Column(Integer, server_default=text("-1")) + hostname = Column(String(1000)) + unixname = Column(String(1000)) + pool = Column(String(256), nullable=False) + pool_slots = Column(Integer, default=1, nullable=False) + queue = Column(String(256)) + priority_weight = Column(Integer) + operator = Column(String(1000)) + custom_operator_name = Column(String(1000)) + queued_dttm = Column(UtcDateTime) + scheduled_dttm = Column(UtcDateTime) + 
queued_by_job_id = Column(Integer) + + last_heartbeat_at = Column(UtcDateTime) + pid = Column(Integer) + executor = Column(String(1000)) + executor_config = Column(ExecutorConfigType(pickler=dill)) + updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow) + _rendered_map_index = Column("rendered_map_index", String(250)) + context_carrier = Column(MutableDict.as_mutable(ExtendedJSON)) + span_status = Column(String(250), server_default=SpanStatus.NOT_STARTED, nullable=False) + + external_executor_id = Column(StringID()) # The trigger to resume on if we are in state DEFERRED - trigger_id: Mapped[int] = mapped_column(Integer) + trigger_id = Column(Integer) # Optional timeout utcdatetime for the trigger (past this, we'll fail) - trigger_timeout: Mapped[UtcDateTime] = mapped_column(UtcDateTime) + trigger_timeout = Column(UtcDateTime) # The method to call next, and any extra arguments to pass to it. # Usually used when resuming from DEFERRED. - next_method: Mapped[str] = mapped_column(String(1000)) - next_kwargs: Mapped[dict] = mapped_column(MutableDict.as_mutable(ExtendedJSON)) + next_method = Column(String(1000)) + next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON)) - _task_display_property_value: Mapped[str | None] = mapped_column( - "task_display_name", String(2000), nullable=True - ) - dag_version_id: Mapped[str] = mapped_column( + _task_display_property_value = Column("task_display_name", String(2000), nullable=True) + dag_version_id = Column( UUIDType(binary=False), ForeignKey("dag_version.id", ondelete="RESTRICT"), ) @@ -2207,17 +2202,15 @@ class TaskInstanceNote(Base): """For storage of arbitrary notes concerning the task instance.""" __tablename__ = "task_instance_note" - ti_id: Mapped[str] = mapped_column( + ti_id = Column( String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"), primary_key=True, nullable=False, ) - user_id: Mapped[str | None] = mapped_column(String(128), nullable=True) - content: Mapped[str] = mapped_column(String(1000).with_variant(Text(1000), "mysql")) - created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) - updated_at: Mapped[UtcDateTime] = mapped_column( - UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False - ) + user_id = Column(String(128), nullable=True) + content = Column(String(1000).with_variant(Text(1000), "mysql")) + created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) + updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) task_instance = relationship("TaskInstance", back_populates="task_instance_note", uselist=False) diff --git a/airflow-core/src/airflow/models/taskinstancehistory.py b/airflow-core/src/airflow/models/taskinstancehistory.py index 701abef564a57..b5932b28c58a3 100644 --- a/airflow-core/src/airflow/models/taskinstancehistory.py +++ b/airflow-core/src/airflow/models/taskinstancehistory.py @@ -21,6 +21,7 @@ import dill from sqlalchemy import ( + Column, DateTime, Float, ForeignKeyConstraint, @@ -34,7 +35,7 @@ ) from sqlalchemy.dialects import postgresql from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.orm import Mapped, relationship +from sqlalchemy.orm import relationship from sqlalchemy_utils import UUIDType from airflow._shared.timezones import timezone @@ -45,7 +46,6 @@ ExecutorConfigType, ExtendedJSON, UtcDateTime, - mapped_column, ) from airflow.utils.state import State, TaskInstanceState @@ -64,52 +64,48 @@ class 
TaskInstanceHistory(Base): """ __tablename__ = "task_instance_history" - task_instance_id: Mapped[str] = mapped_column( + task_instance_id = Column( String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"), nullable=False, primary_key=True, ) - task_id: Mapped[str] = mapped_column(StringID(), nullable=False) - dag_id: Mapped[str] = mapped_column(StringID(), nullable=False) - run_id: Mapped[str] = mapped_column(StringID(), nullable=False) - map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("-1")) - try_number: Mapped[int] = mapped_column(Integer, nullable=False) - start_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime) - end_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime) - duration: Mapped[float] = mapped_column(Float) - state: Mapped[str] = mapped_column(String(20)) - max_tries: Mapped[int] = mapped_column(Integer, server_default=text("-1")) - hostname: Mapped[str] = mapped_column(String(1000)) - unixname: Mapped[str] = mapped_column(String(1000)) - pool: Mapped[str] = mapped_column(String(256), nullable=False) - pool_slots: Mapped[int] = mapped_column(Integer, default=1, nullable=False) - queue: Mapped[str] = mapped_column(String(256)) - priority_weight: Mapped[int] = mapped_column(Integer) - operator: Mapped[str] = mapped_column(String(1000)) - custom_operator_name: Mapped[str] = mapped_column(String(1000)) - queued_dttm: Mapped[UtcDateTime] = mapped_column(UtcDateTime) - scheduled_dttm: Mapped[UtcDateTime] = mapped_column(UtcDateTime) - queued_by_job_id: Mapped[int] = mapped_column(Integer) - pid: Mapped[int] = mapped_column(Integer) - executor: Mapped[str] = mapped_column(String(1000)) - executor_config: Mapped[dict] = mapped_column(ExecutorConfigType(pickler=dill)) - updated_at: Mapped[UtcDateTime] = mapped_column( - UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow - ) - rendered_map_index: Mapped[str] = mapped_column(String(250)) - context_carrier: Mapped[dict] = mapped_column(MutableDict.as_mutable(ExtendedJSON)) - span_status: Mapped[str] = mapped_column( - String(250), server_default=SpanStatus.NOT_STARTED, nullable=False - ) - - external_executor_id: Mapped[str] = mapped_column(StringID()) - trigger_id: Mapped[int] = mapped_column(Integer) - trigger_timeout: Mapped[DateTime] = mapped_column(DateTime) - next_method: Mapped[str] = mapped_column(String(1000)) - next_kwargs: Mapped[dict] = mapped_column(MutableDict.as_mutable(ExtendedJSON)) - - task_display_name: Mapped[str] = mapped_column(String(2000), nullable=True) - dag_version_id: Mapped[str] = mapped_column(UUIDType(binary=False)) + task_id = Column(StringID(), nullable=False) + dag_id = Column(StringID(), nullable=False) + run_id = Column(StringID(), nullable=False) + map_index = Column(Integer, nullable=False, server_default=text("-1")) + try_number = Column(Integer, nullable=False) + start_date = Column(UtcDateTime) + end_date = Column(UtcDateTime) + duration = Column(Float) + state = Column(String(20)) + max_tries = Column(Integer, server_default=text("-1")) + hostname = Column(String(1000)) + unixname = Column(String(1000)) + pool = Column(String(256), nullable=False) + pool_slots = Column(Integer, default=1, nullable=False) + queue = Column(String(256)) + priority_weight = Column(Integer) + operator = Column(String(1000)) + custom_operator_name = Column(String(1000)) + queued_dttm = Column(UtcDateTime) + scheduled_dttm = Column(UtcDateTime) + queued_by_job_id = Column(Integer) + pid = Column(Integer) + executor = Column(String(1000)) + 
executor_config = Column(ExecutorConfigType(pickler=dill)) + updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow) + rendered_map_index = Column(String(250)) + context_carrier = Column(MutableDict.as_mutable(ExtendedJSON)) + span_status = Column(String(250), server_default=SpanStatus.NOT_STARTED, nullable=False) + + external_executor_id = Column(StringID()) + trigger_id = Column(Integer) + trigger_timeout = Column(DateTime) + next_method = Column(String(1000)) + next_kwargs = Column(MutableDict.as_mutable(ExtendedJSON)) + + task_display_name = Column(String(2000), nullable=True) + dag_version_id = Column(UUIDType(binary=False)) dag_version = relationship( "DagVersion", diff --git a/airflow-core/src/airflow/models/tasklog.py b/airflow-core/src/airflow/models/tasklog.py index 51ec10db1808b..d9a5c57c30ac5 100644 --- a/airflow-core/src/airflow/models/tasklog.py +++ b/airflow-core/src/airflow/models/tasklog.py @@ -17,12 +17,11 @@ # under the License. from __future__ import annotations -from sqlalchemy import Integer, Text -from sqlalchemy.orm import Mapped # noqa: TC002 +from sqlalchemy import Column, Integer, Text from airflow._shared.timezones import timezone from airflow.models.base import Base -from airflow.utils.sqlalchemy import UtcDateTime, mapped_column +from airflow.utils.sqlalchemy import UtcDateTime class LogTemplate(Base): @@ -35,10 +34,10 @@ class LogTemplate(Base): __tablename__ = "log_template" - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - filename: Mapped[str] = mapped_column(Text, nullable=False) - elasticsearch_id: Mapped[str] = mapped_column(Text, nullable=False) - created_at: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False, default=timezone.utcnow) + id = Column(Integer, primary_key=True, autoincrement=True) + filename = Column(Text, nullable=False) + elasticsearch_id = Column(Text, nullable=False) + created_at = Column(UtcDateTime, nullable=False, default=timezone.utcnow) def __repr__(self) -> str: attrs = ", ".join(f"{k}={getattr(self, k)}" for k in ("filename", "elasticsearch_id")) diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py index 8d6500747b727..edd0b21b114ea 100644 --- a/airflow-core/src/airflow/models/taskmap.py +++ b/airflow-core/src/airflow/models/taskmap.py @@ -24,13 +24,12 @@ from collections.abc import Collection, Iterable, Sequence from typing import TYPE_CHECKING, Any -from sqlalchemy import CheckConstraint, ForeignKeyConstraint, Integer, String, func, or_, select -from sqlalchemy.orm import Mapped # noqa: TC002 +from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String, func, or_, select from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies from airflow.models.dag_version import DagVersion from airflow.utils.db import exists_query -from airflow.utils.sqlalchemy import ExtendedJSON, mapped_column, with_row_locks +from airflow.utils.sqlalchemy import ExtendedJSON, with_row_locks from airflow.utils.state import State, TaskInstanceState if TYPE_CHECKING: @@ -64,13 +63,13 @@ class TaskMap(TaskInstanceDependencies): __tablename__ = "task_map" # Link to upstream TaskInstance creating this dynamic mapping information. 
- dag_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) - task_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) - run_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) - map_index: Mapped[int] = mapped_column(Integer, primary_key=True) + dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) + task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) + run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) + map_index = Column(Integer, primary_key=True) - length: Mapped[int] = mapped_column(Integer, nullable=False) - keys: Mapped[list | None] = mapped_column(ExtendedJSON, nullable=True) + length = Column(Integer, nullable=False) + keys = Column(ExtendedJSON, nullable=True) __table_args__ = ( CheckConstraint(length >= 0, name="task_map_length_not_negative"), diff --git a/airflow-core/src/airflow/models/taskreschedule.py b/airflow-core/src/airflow/models/taskreschedule.py index 2dc6bcb18b257..e07d750ebdef7 100644 --- a/airflow-core/src/airflow/models/taskreschedule.py +++ b/airflow-core/src/airflow/models/taskreschedule.py @@ -19,11 +19,11 @@ from __future__ import annotations -import datetime import uuid from typing import TYPE_CHECKING from sqlalchemy import ( + Column, ForeignKey, Integer, String, @@ -32,10 +32,10 @@ select, ) from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import Mapped, relationship +from sqlalchemy.orm import relationship from airflow.models.base import Base -from airflow.utils.sqlalchemy import UtcDateTime, mapped_column +from airflow.utils.sqlalchemy import UtcDateTime if TYPE_CHECKING: import datetime @@ -49,16 +49,16 @@ class TaskReschedule(Base): """TaskReschedule tracks rescheduled task instances.""" __tablename__ = "task_reschedule" - id: Mapped[int] = mapped_column(Integer, primary_key=True) - ti_id: Mapped[str] = mapped_column( + id = Column(Integer, primary_key=True) + ti_id = Column( String(36).with_variant(postgresql.UUID(as_uuid=False), "postgresql"), ForeignKey("task_instance.id", ondelete="CASCADE", name="task_reschedule_ti_fkey"), nullable=False, ) - start_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False) - end_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False) - duration: Mapped[int] = mapped_column(Integer, nullable=False) - reschedule_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False) + start_date = Column(UtcDateTime, nullable=False) + end_date = Column(UtcDateTime, nullable=False) + duration = Column(Integer, nullable=False) + reschedule_date = Column(UtcDateTime, nullable=False) task_instance = relationship( "TaskInstance", primaryjoin="TaskReschedule.ti_id == foreign(TaskInstance.id)", uselist=False diff --git a/airflow-core/src/airflow/models/team.py b/airflow-core/src/airflow/models/team.py index 551b426abcace..3ef31b434095a 100644 --- a/airflow-core/src/airflow/models/team.py +++ b/airflow-core/src/airflow/models/team.py @@ -21,12 +21,11 @@ import uuid6 from sqlalchemy import Column, ForeignKey, Index, String, Table, select -from sqlalchemy.orm import Mapped, relationship +from sqlalchemy.orm import relationship from sqlalchemy_utils import UUIDType from airflow.models.base import Base, StringID from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import mapped_column if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -55,8 +54,8 @@ class Team(Base): __tablename__ = "team" - id: 
Mapped[str] = mapped_column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7) - name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False) + id = Column(UUIDType(binary=False), primary_key=True, default=uuid6.uuid7) + name = Column(String(50), unique=True, nullable=False) dag_bundles = relationship( "DagBundleModel", secondary=dag_bundle_team_association_table, back_populates="teams" ) diff --git a/airflow-core/src/airflow/models/trigger.py b/airflow-core/src/airflow/models/trigger.py index 94da2acc6aa84..8de31b312d985 100644 --- a/airflow-core/src/airflow/models/trigger.py +++ b/airflow-core/src/airflow/models/trigger.py @@ -24,9 +24,9 @@ from traceback import format_exception from typing import TYPE_CHECKING, Any -from sqlalchemy import Integer, String, Text, delete, func, or_, select, update +from sqlalchemy import Column, Integer, String, Text, delete, func, or_, select, update from sqlalchemy.ext.associationproxy import association_proxy -from sqlalchemy.orm import Mapped, Session, relationship, selectinload +from sqlalchemy.orm import Session, relationship, selectinload from sqlalchemy.sql.functions import coalesce from airflow._shared.timezones import timezone @@ -37,7 +37,7 @@ from airflow.triggers.base import BaseTaskEndEvent from airflow.utils.retries import run_with_db_retries from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, with_row_locks +from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: @@ -90,11 +90,11 @@ class Trigger(Base): __tablename__ = "trigger" - id: Mapped[int] = mapped_column(Integer, primary_key=True) - classpath: Mapped[str] = mapped_column(String(1000), nullable=False) - encrypted_kwargs: Mapped[str] = mapped_column("kwargs", Text, nullable=False) - created_date: Mapped[UtcDateTime] = mapped_column(UtcDateTime, nullable=False) - triggerer_id: Mapped[int | None] = mapped_column(Integer, nullable=True) + id = Column(Integer, primary_key=True) + classpath = Column(String(1000), nullable=False) + encrypted_kwargs = Column("kwargs", Text, nullable=False) + created_date = Column(UtcDateTime, nullable=False) + triggerer_id = Column(Integer, nullable=True) triggerer_job = relationship( "Job", diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index ff58370e87707..2c50da9add721 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -24,9 +24,9 @@ import warnings from typing import TYPE_CHECKING, Any -from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, delete, select +from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text, delete, select from sqlalchemy.dialects.mysql import MEDIUMTEXT -from sqlalchemy.orm import Mapped, declared_attr, reconstructor, synonym +from sqlalchemy.orm import declared_attr, reconstructor, synonym from sqlalchemy_utils import UUIDType from airflow._shared.secrets_masker import mask_secret @@ -38,7 +38,6 @@ from airflow.secrets.metastore import MetastoreBackend from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.sqlalchemy import mapped_column if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -52,12 +51,12 @@ class Variable(Base, LoggingMixin): __tablename__ = "variable" __NO_DEFAULT_SENTINEL = object() - id: 
Mapped[int] = mapped_column(Integer, primary_key=True) - key: Mapped[str] = mapped_column(String(ID_LEN), unique=True) - _val: Mapped[str] = mapped_column("val", Text().with_variant(MEDIUMTEXT, "mysql")) - description: Mapped[str] = mapped_column(Text) - is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False) - team_id: Mapped[str | None] = mapped_column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True) + id = Column(Integer, primary_key=True) + key = Column(String(ID_LEN), unique=True) + _val = Column("val", Text().with_variant(MEDIUMTEXT, "mysql")) + description = Column(Text) + is_encrypted = Column(Boolean, unique=False, default=False) + team_id = Column(UUIDType(binary=False), ForeignKey("team.id"), nullable=True) def __init__(self, key=None, val=None, description=None, team_id=None): super().__init__() diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index 8141c8e2f7af6..709d6bf69030c 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -24,6 +24,7 @@ from sqlalchemy import ( JSON, + Column, ForeignKeyConstraint, Index, Integer, @@ -36,7 +37,7 @@ ) from sqlalchemy.dialects import postgresql from sqlalchemy.ext.associationproxy import association_proxy -from sqlalchemy.orm import Mapped, relationship +from sqlalchemy.orm import relationship from airflow._shared.timezones import timezone from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies @@ -44,7 +45,7 @@ from airflow.utils.helpers import is_container from airflow.utils.json import XComDecoder, XComEncoder from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import UtcDateTime, mapped_column +from airflow.utils.sqlalchemy import UtcDateTime log = logging.getLogger(__name__) @@ -62,19 +63,17 @@ class XComModel(TaskInstanceDependencies): __tablename__ = "xcom" - dag_run_id: Mapped[int] = mapped_column(Integer(), nullable=False, primary_key=True) - task_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True) - map_index: Mapped[int] = mapped_column( - Integer, primary_key=True, nullable=False, server_default=text("-1") - ) - key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True) + dag_run_id = Column(Integer(), nullable=False, primary_key=True) + task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True) + map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1")) + key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True) # Denormalized for easier lookup. 
- dag_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), nullable=False) - run_id: Mapped[str] = mapped_column(String(ID_LEN, **COLLATION_ARGS), nullable=False) + dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) + run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False) - value: Mapped[Any] = mapped_column(JSON().with_variant(postgresql.JSONB, "postgresql")) - timestamp: Mapped[UtcDateTime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) + value = Column(JSON().with_variant(postgresql.JSONB, "postgresql")) + timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False) __table_args__ = ( # Ideally we should create a unique index over (key, dag_id, task_id, run_id), diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py index 98bb68843e812..072f767335c6f 100644 --- a/airflow-core/src/airflow/utils/sqlalchemy.py +++ b/airflow-core/src/airflow/utils/sqlalchemy.py @@ -48,15 +48,6 @@ log = logging.getLogger(__name__) -try: - from sqlalchemy.orm import mapped_column -except ImportError: - # fallback for SQLAlchemy < 2.0 - def mapped_column(*args, **kwargs): - from sqlalchemy import Column - - return Column(*args, **kwargs) - class UtcDateTime(TypeDecorator): """