|
97 | 97 |
|
98 | 98 | from opentelemetry.sdk.trace import Span |
99 | 99 | from pydantic import NonNegativeInt |
100 | | - from sqlalchemy.orm import Query, Session |
| 100 | + from sqlalchemy.engine import ScalarResult |
| 101 | + from sqlalchemy.orm import Session |
101 | 102 | from sqlalchemy.sql.elements import Case, ColumnElement |
102 | 103 |
|
103 | 104 | from airflow.models.dag_version import DagVersion |
@@ -348,7 +349,7 @@ def __init__( |
348 | 349 | self.conf = conf or {} |
349 | 350 | if state is not None: |
350 | 351 | self.state = state |
351 | | - if queued_at is NOTSET: |
| 352 | + if isinstance(queued_at, ArgNotSet): |
352 | 353 | self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None |
353 | 354 | elif queued_at is not None: |
354 | 355 | self.queued_at = queued_at |
@@ -572,11 +573,11 @@ def active_runs_of_dags( |
572 | 573 | ) |
573 | 574 | if exclude_backfill: |
574 | 575 | query = query.where(cls.run_type != DagRunType.BACKFILL_JOB) |
575 | | - return dict(session.execute(query).all()) |
| 576 | + return {dag_id: count for dag_id, count in session.execute(query)} |
576 | 577 |
|
577 | 578 | @classmethod |
578 | 579 | @retry_db_transaction |
579 | | - def get_running_dag_runs_to_examine(cls, session: Session) -> Query: |
| 580 | + def get_running_dag_runs_to_examine(cls, session: Session) -> ScalarResult[DagRun]: |
580 | 581 | """ |
581 | 582 | Return the next DagRuns that the scheduler should attempt to schedule. |
582 | 583 |
|
@@ -615,7 +616,7 @@ def get_running_dag_runs_to_examine(cls, session: Session) -> Query: |
615 | 616 |
|
616 | 617 | @classmethod |
617 | 618 | @retry_db_transaction |
618 | | - def get_queued_dag_runs_to_set_running(cls, session: Session) -> Query: |
| 619 | + def get_queued_dag_runs_to_set_running(cls, session: Session) -> ScalarResult[DagRun]: |
619 | 620 | """ |
620 | 621 | Return the next queued DagRuns that the scheduler should attempt to schedule. |
621 | 622 |
|
|
0 commit comments