-
Notifications
You must be signed in to change notification settings - Fork 2
Add Type Annotations #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,53 +1,60 @@ | ||
| """A medium-faithful port of https://github.com/weinberg/SQLToy to Python""" | ||
|
|
||
| from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple | ||
|
|
||
|
|
||
| class Table: | ||
| def __init__(self, name: str, rows: tuple[dict] = ()): | ||
| def __init__(self, name: str, rows: Iterable[Dict[str, Any]] = ()): | ||
| self.name = name | ||
| self.rows = tuple(rows) | ||
| self._colnames = () | ||
| self.rows: Tuple[Dict[str, Any], ...] = tuple(rows) | ||
| self._colnames: Tuple[str, ...] = () | ||
|
|
||
| def set_colnames(self, colnames): | ||
| def set_colnames(self, colnames: Iterable[str]) -> None: | ||
| self._colnames = tuple(sorted(colnames)) | ||
|
|
||
| def colnames(self): | ||
| def colnames(self) -> Tuple[str, ...]: | ||
| if self._colnames: | ||
| return self._colnames | ||
| if not self.rows: | ||
| raise ValueError("Need either rows or manually specified column names") | ||
| return tuple(sorted(self.rows[0].keys())) | ||
|
|
||
| def filter(self, pred): | ||
| def filter(self, pred: Callable[[Dict[str, Any]], bool]) -> "Table": | ||
rosmur marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return Table(self.name, [row for row in self.rows if pred(row)]) | ||
|
|
||
| def __repr__(self): | ||
| def __repr__(self) -> str: | ||
| if not self.name: | ||
| return f"Table({list(self.rows)!r})" | ||
| return f"Table({self.name!r}, {list(self.rows)!r})" | ||
|
|
||
|
|
||
| class Database: | ||
| def __init__(self): | ||
| self.tables = {} | ||
| def __init__(self) -> None: | ||
| self.tables: Dict[str, Table] = {} | ||
|
|
||
| def CREATE_TABLE(self, name, colnames=()): | ||
| def CREATE_TABLE(self, name: str, colnames: Iterable[str] = ()) -> Table: | ||
| table = Table(name) | ||
| if colnames: | ||
| table.set_colnames(colnames) | ||
| self.tables[name] = table | ||
| return table | ||
|
|
||
| def DROP_TABLE(self, name): | ||
| def DROP_TABLE(self, name: str) -> None: | ||
| del self.tables[name] | ||
|
|
||
| def FROM(self, first_table, *rest): | ||
| def FROM(self, first_table: str, *rest: str) -> Table: | ||
| match rest: | ||
| case (): | ||
| return self.tables[first_table] | ||
| case _: | ||
| return self.CROSS_JOIN(self.tables[first_table], self.FROM(*rest)) | ||
|
|
||
| def SELECT(self, table, columns, aliases=None): | ||
| def SELECT( | ||
| self, | ||
| table: Table, | ||
| columns: Iterable[str], | ||
| aliases: Optional[Dict[str, str]] = None, | ||
| ) -> Table: | ||
| if aliases is None: | ||
| aliases = {} | ||
| return Table( | ||
|
|
@@ -58,19 +65,25 @@ def SELECT(self, table, columns, aliases=None): | |
| ], | ||
| ) | ||
|
|
||
| def WHERE(self, table, pred): | ||
| def WHERE(self, table: Table, pred: Callable[[Dict[str, Any]], bool]) -> Table: | ||
| return table.filter(pred) | ||
|
|
||
| def INSERT_INTO(self, table_name, rows): | ||
| def INSERT_INTO(self, table_name: str, rows: Iterable[Dict[str, Any]]) -> None: | ||
| table = self.tables[table_name] | ||
| table.rows = (*table.rows, *rows) | ||
|
|
||
| def UPDATE(self, table, set, pred=lambda _: True): | ||
| def UPDATE( | ||
| self, | ||
| table: Table, | ||
| set_values: Dict[str, Any], | ||
| pred: Callable[[Dict[str, Any]], bool] = lambda _: True, | ||
| ) -> Table: | ||
| return Table( | ||
| table.name, [{**row, **set} if pred(row) else row for row in table.rows] | ||
| table.name, | ||
| [{**row, **set_values} if pred(row) else row for row in table.rows], | ||
| ) | ||
|
Comment on lines
+80
to
86
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/set_values/row_updates/g maybe |
||
|
|
||
| def CROSS_JOIN(self, a, b): | ||
| def CROSS_JOIN(self, a: Table, b: Table) -> Table: | ||
| rows = [] | ||
| a_prefix = f"{a.name}." if a.name else "" | ||
| b_prefix = f"{b.name}." if b.name else "" | ||
|
|
@@ -84,12 +97,16 @@ def CROSS_JOIN(self, a, b): | |
| ) | ||
| return Table("", rows) | ||
|
|
||
| def INNER_JOIN(self, a, b, pred): | ||
| def INNER_JOIN( | ||
| self, a: Table, b: Table, pred: Callable[[Dict[str, Any]], bool] | ||
| ) -> Table: | ||
| return self.CROSS_JOIN(a, b).filter(pred) | ||
|
|
||
| JOIN = INNER_JOIN | ||
|
|
||
| def LEFT_JOIN(self, a, b, pred): | ||
| def LEFT_JOIN( | ||
| self, a: Table, b: Table, pred: Callable[[Dict[str, Any]], bool] | ||
| ) -> Table: | ||
| rows = [] | ||
| empty_b_row = {f"{b.name}.{k}": None for k in b.colnames()} | ||
| for a_row in a.rows: | ||
|
|
@@ -107,24 +124,26 @@ def LEFT_JOIN(self, a, b, pred): | |
| rows.append({**mangled_a_row, **empty_b_row}) | ||
| return Table("", rows) | ||
|
|
||
| def RIGHT_JOIN(self, a, b, pred): | ||
| def RIGHT_JOIN( | ||
| self, a: Table, b: Table, pred: Callable[[Dict[str, Any]], bool] | ||
| ) -> Table: | ||
| return self.LEFT_JOIN(b, a, pred) | ||
|
|
||
| def LIMIT(self, table, limit): | ||
| def LIMIT(self, table: Table, limit: int) -> Table: | ||
| return Table(table.name, table.rows[:limit]) | ||
|
|
||
| def ORDER_BY(self, table, rel): | ||
| def ORDER_BY(self, table: Table, rel: Callable[[Dict[str, Any]], Any]) -> Table: | ||
| # Differs from JS version by passing the whole row to the comparator | ||
| return Table(table.name, sorted(table.rows, key=rel)) | ||
|
|
||
| def HAVING(self, table, pred): | ||
| def HAVING(self, table: Table, pred: Callable[[Dict[str, Any]], bool]) -> Table: | ||
| return table.filter(pred) | ||
|
|
||
| def OFFSET(self, table, offset): | ||
| def OFFSET(self, table: Table, offset: int) -> Table: | ||
| return Table(table.name, table.rows[offset:]) | ||
|
|
||
| def DISTINCT(self, table, columns): | ||
| seen = set() | ||
| def DISTINCT(self, table: Table, columns: Iterable[str]) -> Table: | ||
| seen: Set[Tuple[Tuple[str, Any], ...]] = set() | ||
| rows = [] | ||
| for row in table.rows: | ||
| view = tuple((col, row[col]) for col in columns) | ||
|
|
@@ -133,22 +152,28 @@ def DISTINCT(self, table, columns): | |
| rows.append(dict(view)) | ||
| return Table(table.name, rows) | ||
|
|
||
| def GROUP_BY(self, table, groupBys): | ||
| groupRows = {} | ||
| def GROUP_BY(self, table: Table, groupBys: Iterable[str]) -> Table: | ||
| groupRows: Dict[Tuple[Any, ...], list[Dict[str, Any]]] = {} | ||
| for row in table.rows: | ||
| key = tuple(row[col] for col in groupBys) | ||
| if key not in groupRows: | ||
| groupRows[key] = [] | ||
| groupRows[key].append(row.copy()) | ||
| resultRows = [] | ||
| for group in groupRows.values(): | ||
| resultRow = {"_groupRows": group} | ||
| resultRow: Dict[str, Any] = {"_groupRows": group} | ||
| for col in groupBys: | ||
| resultRow[col] = group[0][col] | ||
| resultRows.append(resultRow) | ||
| return Table(table.name, resultRows) | ||
|
|
||
| def _aggregate(self, table, col, agg_name, agg): | ||
| def _aggregate( | ||
| self, | ||
| table: Table, | ||
| col: str, | ||
| agg_name: str, | ||
| agg: Callable[[Iterable[Dict[str, Any]]], Any], | ||
| ) -> Table: | ||
| grouped = table.rows and "_groupRows" in table.rows[0] | ||
| col_name = f"{agg_name}({col})" | ||
| if not grouped: | ||
|
|
@@ -164,32 +189,36 @@ def _aggregate(self, table, col, agg_name, agg): | |
| rows.append(new_row) | ||
| return Table(table.name, rows) | ||
|
|
||
| def COUNT(self, table, col): | ||
| return self._aggregate(table, col, "COUNT", len) | ||
| def COUNT(self, table: Table, col: str) -> Table: | ||
| return self._aggregate(table, col, "COUNT", lambda rows: len(list(rows))) | ||
|
|
||
| def MAX(self, table, col): | ||
| return self._aggregate(table, col, "MAX", lambda rows: max(row[col] for row in rows)) | ||
| def MAX(self, table: Table, col: str) -> Table: | ||
| return self._aggregate( | ||
| table, col, "MAX", lambda rows: max(row[col] for row in rows) | ||
| ) | ||
|
|
||
| def SUM(self, table, col): | ||
| return self._aggregate(table, col, "SUM", lambda rows: sum(row[col] for row in rows)) | ||
| def SUM(self, table: Table, col: str) -> Table: | ||
| return self._aggregate( | ||
| table, col, "SUM", lambda rows: sum(row[col] for row in rows) | ||
| ) | ||
|
|
||
| def __repr__(self): | ||
| def __repr__(self) -> str: | ||
| return f"Database({list(self.tables.keys())!r})" | ||
|
|
||
|
|
||
| def query( | ||
| db, | ||
| select=(), | ||
| select_as=None, | ||
| distinct=None, | ||
| from_=None, | ||
| join=(), | ||
| where=(), | ||
| group_by=(), | ||
| having=None, | ||
| order_by=None, | ||
| offset=None, | ||
| limit=None, | ||
| db: Database, | ||
| select: Iterable[str] = (), | ||
| select_as: Optional[Dict[str, str]] = None, | ||
| distinct: Optional[Iterable[str]] = None, | ||
| from_: Optional[Iterable[str]] = None, | ||
| join: Iterable[Tuple[str, Callable[[Dict[str, Any]], bool]]] = (), | ||
| where: Iterable[Callable[[Dict[str, Any]], bool]] = (), | ||
| group_by: Iterable[str] = (), | ||
| having: Optional[Callable[[Dict[str, Any]], bool]] = None, | ||
| order_by: Optional[Callable[[Dict[str, Any]], Any]] = None, | ||
| offset: Optional[int] = None, | ||
| limit: Optional[int] = None, | ||
| ) -> Table: | ||
| if from_ is None: | ||
| raise ValueError("Need a FROM clause") | ||
|
|
@@ -211,10 +240,12 @@ def query( | |
| result = db.OFFSET(result, offset) | ||
| if limit: | ||
| result = db.LIMIT(result, limit) | ||
| if distinct: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should go in a separate commit and needs separate tests. Additionally, in the SQL order of execution, I think it goes between select and order by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (test is slightly less important but it should go in a separate commit than type checking for sure) |
||
| result = db.DISTINCT(result, distinct) | ||
| return result | ||
|
|
||
|
|
||
| def csv(table): | ||
| def csv(table: Table) -> None: | ||
| print(",".join(table.colnames())) | ||
| for row in table.rows: | ||
| print(",".join(str(val) for val in row.values())) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's do
Row = Dict[str, Any]andRowPredicate = Callable[[Row], bool]to make types easier to readThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry which line are you referencing/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding new type aliases so we can use those throughout the new type annotations. This will make nearly every annotation much shorter