Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 82 additions & 51 deletions db.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,60 @@
"""A medium-faithful port of https://github.com/weinberg/SQLToy to Python"""

from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple


Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do `Row = Dict[str, Any]` and `RowPredicate = Callable[[Row], bool]` to make the types easier to read.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, which line are you referencing?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding new type aliases so we can use those throughout the new type annotations. This will make nearly every annotation much shorter

class Table:
    """A named, ordered collection of rows, each row a ``dict`` keyed by column name.

    Column names are either set explicitly via :meth:`set_colnames` or
    derived from the keys of the first row.
    """

    def __init__(self, name: str, rows: Iterable[Dict[str, Any]] = ()) -> None:
        """Create a table called *name* holding *rows*.

        Args:
            name: Table name; may be empty (``__repr__`` then omits it).
            rows: Any iterable of row dicts; it is consumed once and stored
                as a tuple.
        """
        self.name = name
        # Materialise so one-shot iterables work and rows can be indexed/sliced.
        self.rows: Tuple[Dict[str, Any], ...] = tuple(rows)
        # Empty tuple means "not set"; colnames() then falls back to the rows.
        self._colnames: Tuple[str, ...] = ()

    def set_colnames(self, colnames: Iterable[str]) -> None:
        """Record explicit column names (stored sorted)."""
        self._colnames = tuple(sorted(colnames))

    def colnames(self) -> Tuple[str, ...]:
        """Return the sorted column names.

        Explicitly set names win; otherwise the keys of the first row are used.

        Raises:
            ValueError: If no names were set and the table has no rows.
        """
        if self._colnames:
            return self._colnames
        if not self.rows:
            raise ValueError("Need either rows or manually specified column names")
        return tuple(sorted(self.rows[0]))

    def filter(self, pred: Callable[[Dict[str, Any]], bool]) -> "Table":
        """Return a new table, with the same name, of the rows for which *pred* is truthy."""
        return Table(self.name, [row for row in self.rows if pred(row)])

    def __repr__(self) -> str:
        if not self.name:
            return f"Table({list(self.rows)!r})"
        return f"Table({self.name!r}, {list(self.rows)!r})"


class Database:
    """A registry of tables addressed by name, with SQL-like operations."""

    def __init__(self) -> None:
        """Start with no tables."""
        # Maps table name -> Table object.
        self.tables: Dict[str, "Table"] = {}

def CREATE_TABLE(self, name: str, colnames: Iterable[str] = ()) -> "Table":
    """Create an empty table, register it under *name*, and return it.

    Args:
        name: Key under which the table is stored in ``self.tables``; an
            existing entry with the same name is replaced.
        colnames: Optional explicit column names; only applied when truthy.
    """
    table = Table(name)
    if colnames:
        table.set_colnames(colnames)
    self.tables[name] = table
    return table

def DROP_TABLE(self, name: str) -> None:
    """Remove the table registered under *name*.

    Raises:
        KeyError: If no table with that name exists.
    """
    del self.tables[name]

def FROM(self, first_table: str, *rest: str) -> "Table":
    """Look up one table by name, or cross-join several.

    With a single name the stored table itself is returned. With more names
    the result is the right-nested cross join ``first x FROM(*rest)``.

    Raises:
        KeyError: If a name is not present in ``self.tables``.
    """
    first = self.tables[first_table]
    if not rest:
        return first
    return self.CROSS_JOIN(first, self.FROM(*rest))

def SELECT(self, table, columns, aliases=None):
def SELECT(
self,
table: Table,
columns: Iterable[str],
aliases: Optional[Dict[str, str]] = None,
) -> Table:
if aliases is None:
aliases = {}
return Table(
Expand All @@ -58,19 +65,25 @@ def SELECT(self, table, columns, aliases=None):
],
)

def WHERE(self, table: "Table", pred: Callable[[Dict[str, Any]], bool]) -> "Table":
    """Return the rows of *table* for which *pred* is truthy (delegates to ``table.filter``)."""
    return table.filter(pred)

def INSERT_INTO(self, table_name: str, rows: Iterable[Dict[str, Any]]) -> None:
    """Append *rows* to the table registered under *table_name*.

    The table's ``rows`` attribute is rebound to a new tuple; the previous
    tuple is not modified.

    Raises:
        KeyError: If *table_name* is not present in ``self.tables``.
    """
    table = self.tables[table_name]
    table.rows = (*table.rows, *rows)

def UPDATE(
    self,
    table: "Table",
    set_values: Dict[str, Any],
    pred: Callable[[Dict[str, Any]], bool] = lambda _: True,
) -> "Table":
    """Return a new table with *set_values* merged into every row matching *pred*.

    Args:
        table: Source table; it is not modified.
        set_values: Column -> new value; these override the row's own values.
        pred: Row predicate; defaults to matching every row.

    Returns:
        A table with the same name. Matching rows are fresh merged dicts;
        non-matching rows are the same dict objects as in *table*.
    """
    updated = [
        {**row, **set_values} if pred(row) else row
        for row in table.rows
    ]
    return Table(table.name, updated)
Comment on lines +80 to 86
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/set_values/row_updates/g maybe


def CROSS_JOIN(self, a, b):
def CROSS_JOIN(self, a: Table, b: Table) -> Table:
rows = []
a_prefix = f"{a.name}." if a.name else ""
b_prefix = f"{b.name}." if b.name else ""
Expand All @@ -84,12 +97,16 @@ def CROSS_JOIN(self, a, b):
)
return Table("", rows)

def INNER_JOIN(
    self, a: "Table", b: "Table", pred: Callable[[Dict[str, Any]], bool]
) -> "Table":
    """Return the rows of ``CROSS_JOIN(a, b)`` for which *pred* is truthy.

    *pred* receives the joined row as produced by ``CROSS_JOIN``.
    """
    crossed = self.CROSS_JOIN(a, b)
    return crossed.filter(pred)

JOIN = INNER_JOIN

def LEFT_JOIN(self, a, b, pred):
def LEFT_JOIN(
self, a: Table, b: Table, pred: Callable[[Dict[str, Any]], bool]
) -> Table:
rows = []
empty_b_row = {f"{b.name}.{k}": None for k in b.colnames()}
for a_row in a.rows:
Expand All @@ -107,24 +124,26 @@ def LEFT_JOIN(self, a, b, pred):
rows.append({**mangled_a_row, **empty_b_row})
return Table("", rows)

def RIGHT_JOIN(
    self, a: "Table", b: "Table", pred: Callable[[Dict[str, Any]], bool]
) -> "Table":
    """Right join of *a* and *b*, implemented as ``LEFT_JOIN`` with the operands swapped."""
    return self.LEFT_JOIN(b, a, pred)

def LIMIT(self, table: "Table", limit: int) -> "Table":
    """Return a new table holding ``table.rows[:limit]`` (ordinary slice semantics)."""
    return Table(table.name, table.rows[:limit])

def ORDER_BY(self, table: "Table", rel: Callable[[Dict[str, Any]], Any]) -> "Table":
    """Return a new table whose rows are sorted by the key function *rel*.

    Differs from the JS version: *rel* is called with the whole row and its
    return value is used as the sort key (``sorted(..., key=rel)``).
    """
    return Table(table.name, sorted(table.rows, key=rel))

def HAVING(self, table: "Table", pred: Callable[[Dict[str, Any]], bool]) -> "Table":
    """Return the rows of *table* for which *pred* is truthy (delegates to ``table.filter``)."""
    return table.filter(pred)

def OFFSET(self, table: "Table", offset: int) -> "Table":
    """Return a new table holding ``table.rows[offset:]`` (ordinary slice semantics)."""
    return Table(table.name, table.rows[offset:])

def DISTINCT(self, table, columns):
seen = set()
def DISTINCT(self, table: Table, columns: Iterable[str]) -> Table:
seen: Set[Tuple[Tuple[str, Any], ...]] = set()
rows = []
for row in table.rows:
view = tuple((col, row[col]) for col in columns)
Expand All @@ -133,22 +152,28 @@ def DISTINCT(self, table, columns):
rows.append(dict(view))
return Table(table.name, rows)

def GROUP_BY(self, table: "Table", groupBys: Iterable[str]) -> "Table":
    """Group the rows of *table* by the values of the *groupBys* columns.

    Args:
        table: Source table; its rows are copied, not modified.
        groupBys: Column names forming the group key.

    Returns:
        A table with the same name and one row per distinct key, in order of
        first appearance. Each result row has ``"_groupRows"`` (list of
        copies of the member rows) plus each grouping column, taken from the
        first member.

    Raises:
        KeyError: If a row lacks one of the grouping columns.
    """
    # groupBys is walked once per row and again per group, so a one-shot
    # iterator must be materialised or every row after the first gets key ().
    group_cols = tuple(groupBys)
    groups: Dict[Tuple[Any, ...], list[Dict[str, Any]]] = {}
    for row in table.rows:
        key = tuple(row[col] for col in group_cols)
        groups.setdefault(key, []).append(row.copy())
    result_rows = []
    for members in groups.values():
        result_row: Dict[str, Any] = {"_groupRows": members}
        for col in group_cols:
            result_row[col] = members[0][col]
        result_rows.append(result_row)
    return Table(table.name, result_rows)

def _aggregate(self, table, col, agg_name, agg):
def _aggregate(
self,
table: Table,
col: str,
agg_name: str,
agg: Callable[[Iterable[Dict[str, Any]]], Any],
) -> Table:
grouped = table.rows and "_groupRows" in table.rows[0]
col_name = f"{agg_name}({col})"
if not grouped:
Expand All @@ -164,32 +189,36 @@ def _aggregate(self, table, col, agg_name, agg):
rows.append(new_row)
return Table(table.name, rows)

def COUNT(self, table: "Table", col: str) -> "Table":
    """Aggregate *table* with a row count, labelled ``"COUNT"`` for column *col*.

    The counting callback accepts any iterable of rows; it does not look at
    the value of *col*.
    """
    return self._aggregate(table, col, "COUNT", lambda rows: sum(1 for _ in rows))

def MAX(self, table: "Table", col: str) -> "Table":
    """Aggregate *table* with the maximum of column *col*, labelled ``"MAX"``.

    The callback raises ``ValueError`` (from ``max``) when given no rows.
    """
    return self._aggregate(
        table, col, "MAX", lambda rows: max(row[col] for row in rows)
    )

def SUM(self, table: "Table", col: str) -> "Table":
    """Aggregate *table* with the sum of column *col*, labelled ``"SUM"``.

    The callback returns ``0`` for an empty group (the ``sum`` default).
    """
    return self._aggregate(
        table, col, "SUM", lambda rows: sum(row[col] for row in rows)
    )

def __repr__(self) -> str:
    """Show the database as ``Database([...table names...])`` in insertion order."""
    return f"Database({list(self.tables)!r})"


def query(
db,
select=(),
select_as=None,
distinct=None,
from_=None,
join=(),
where=(),
group_by=(),
having=None,
order_by=None,
offset=None,
limit=None,
db: Database,
select: Iterable[str] = (),
select_as: Optional[Dict[str, str]] = None,
distinct: Optional[Iterable[str]] = None,
from_: Optional[Iterable[str]] = None,
join: Iterable[Tuple[str, Callable[[Dict[str, Any]], bool]]] = (),
where: Iterable[Callable[[Dict[str, Any]], bool]] = (),
group_by: Iterable[str] = (),
having: Optional[Callable[[Dict[str, Any]], bool]] = None,
order_by: Optional[Callable[[Dict[str, Any]], Any]] = None,
offset: Optional[int] = None,
limit: Optional[int] = None,
) -> Table:
if from_ is None:
raise ValueError("Need a FROM clause")
Expand All @@ -211,10 +240,12 @@ def query(
result = db.OFFSET(result, offset)
if limit:
result = db.LIMIT(result, limit)
if distinct:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go in a separate commit and needs separate tests. Additionally, in the SQL order of execution, I think it goes between select and order by

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(The test is slightly less important, but this should definitely go in a separate commit from the type checking.)

result = db.DISTINCT(result, distinct)
return result


def csv(table: "Table") -> None:
    """Print *table* to stdout as comma-separated text: a header line, then one line per row.

    Values are written in the same order as the header returned by
    ``table.colnames()``, so each value lines up with its column. A row
    missing a header column gets an empty field.

    NOTE: values are converted with ``str`` and joined as-is; commas or
    newlines inside values are not quoted or escaped.

    Raises:
        ValueError: Propagated from ``table.colnames()`` when the table has
            neither rows nor explicit column names.
    """
    columns = table.colnames()
    print(",".join(columns))
    for row in table.rows:
        # Index by header column rather than relying on dict insertion order,
        # which need not match the sorted header.
        print(",".join(str(row.get(col, "")) for col in columns))