|
97 | 97 |
|
98 | 98 | from opentelemetry.sdk.trace import Span |
99 | 99 | from pydantic import NonNegativeInt |
100 | | - from sqlalchemy.orm import Query, Session |
| 100 | + from sqlalchemy.engine import ScalarResult |
| 101 | + from sqlalchemy.orm import Session |
101 | 102 | from sqlalchemy.sql.elements import Case, ColumnElement |
102 | 103 |
|
103 | 104 | from airflow.models.dag_version import DagVersion |
@@ -351,7 +352,7 @@ def __init__( |
351 | 352 | if queued_at is NOTSET: |
352 | 353 | self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None |
353 | 354 | elif queued_at is not None: |
354 | | - self.queued_at = queued_at |
| 355 | + self.queued_at = cast("datetime", queued_at) |
355 | 356 | if run_type is not None: |
356 | 357 | self.run_type = run_type |
357 | 358 | self.creating_job_id = creating_job_id |
@@ -572,11 +573,11 @@ def active_runs_of_dags( |
572 | 573 | ) |
573 | 574 | if exclude_backfill: |
574 | 575 | query = query.where(cls.run_type != DagRunType.BACKFILL_JOB) |
575 | | - return dict(session.execute(query).all()) |
| 576 | + return {dag_id: count for dag_id, count in session.execute(query)} |
576 | 577 |
|
577 | 578 | @classmethod |
578 | 579 | @retry_db_transaction |
579 | | - def get_running_dag_runs_to_examine(cls, session: Session) -> Query: |
| 580 | + def get_running_dag_runs_to_examine(cls, session: Session) -> ScalarResult[DagRun]: |
580 | 581 | """ |
581 | 582 | Return the next DagRuns that the scheduler should attempt to schedule. |
582 | 583 |
|
@@ -615,7 +616,7 @@ def get_running_dag_runs_to_examine(cls, session: Session) -> Query: |
615 | 616 |
|
616 | 617 | @classmethod |
617 | 618 | @retry_db_transaction |
618 | | - def get_queued_dag_runs_to_set_running(cls, session: Session) -> Query: |
| 619 | + def get_queued_dag_runs_to_set_running(cls, session: Session) -> ScalarResult[DagRun]: |
619 | 620 | """ |
620 | 621 | Return the next queued DagRuns that the scheduler should attempt to schedule. |
621 | 622 |
|
|
0 commit comments