|
56 | 56 | from sqlalchemy.sql.expression import false, select |
57 | 57 | from sqlalchemy.sql.functions import coalesce |
58 | 58 | from sqlalchemy_utils import UUIDType |
| 59 | +from sqlalchemy.engine import ScalarResult |
59 | 60 |
|
60 | 61 | from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext |
61 | 62 | from airflow.configuration import conf as airflow_conf |
@@ -351,7 +352,7 @@ def __init__( |
351 | 352 | if queued_at is NOTSET: |
352 | 353 | self.queued_at = pendulum.now(tz="UTC") if state == DagRunState.QUEUED else None |
353 | 354 | elif queued_at is not None: |
354 | | - self.queued_at = queued_at |
| 355 | + self.queued_at = cast(datetime, queued_at) |
355 | 356 | if run_type is not None: |
356 | 357 | self.run_type = run_type |
357 | 358 | self.creating_job_id = creating_job_id |
@@ -572,11 +573,12 @@ def active_runs_of_dags( |
572 | 573 | ) |
573 | 574 | if exclude_backfill: |
574 | 575 | query = query.where(cls.run_type != DagRunType.BACKFILL_JOB) |
575 | | - return dict(session.execute(query).all()) |
| 576 | + result = session.execute(query).all() |
| 577 | + return {dag_id:count for dag_id, count in result} |
576 | 578 |
|
577 | 579 | @classmethod |
578 | 580 | @retry_db_transaction |
579 | | - def get_running_dag_runs_to_examine(cls, session: Session) -> Query: |
| 581 | + def get_running_dag_runs_to_examine(cls, session: Session) -> ScalarResult[DagRun]: |
580 | 582 | """ |
581 | 583 | Return the next DagRuns that the scheduler should attempt to schedule. |
582 | 584 |
|
@@ -615,7 +617,7 @@ def get_running_dag_runs_to_examine(cls, session: Session) -> Query: |
615 | 617 |
|
616 | 618 | @classmethod |
617 | 619 | @retry_db_transaction |
618 | | - def get_queued_dag_runs_to_set_running(cls, session: Session) -> Query: |
| 620 | + def get_queued_dag_runs_to_set_running(cls, session: Session) -> ScalarResult[DagRun]: |
619 | 621 | """ |
620 | 622 | Return the next queued DagRuns that the scheduler should attempt to schedule. |
621 | 623 |
|
@@ -732,11 +734,13 @@ def find( |
732 | 734 | qry = qry.where(cls.dag_id.in_(dag_ids)) |
733 | 735 |
|
734 | 736 | if is_container(run_id): |
735 | | - qry = qry.where(cls.run_id.in_(run_id)) |
| 737 | + run_ids = cast(Iterable[str], run_id) |
| 738 | + qry = qry.where(cls.run_id.in_(run_ids)) |
736 | 739 | elif run_id is not None: |
737 | 740 | qry = qry.where(cls.run_id == run_id) |
738 | 741 | if is_container(logical_date): |
739 | | - qry = qry.where(cls.logical_date.in_(logical_date)) |
| 742 | + logical_dates = cast(Iterable[datetime], logical_date) |
| 743 | + qry = qry.where(cls.logical_date.in_(logical_dates)) |
740 | 744 | elif logical_date is not None: |
741 | 745 | qry = qry.where(cls.logical_date == logical_date) |
742 | 746 | if logical_start_date and logical_end_date: |
|
0 commit comments