diff --git a/README.md b/README.md new file mode 100644 index 0000000..751da81 --- /dev/null +++ b/README.md @@ -0,0 +1,121 @@ +# Python SQL-like Database + +This project is a Python-based, in-memory database that provides a simplified, SQL-like interface for data manipulation. It is a medium-faithful port of the original [SQLToy](https://github.com/weinberg/SQLToy) project. + +## Core Concepts + +The two primary components of this library are the `Database` and `Table` classes. + +- **`Database`**: Manages a collection of tables and provides methods for performing SQL-like operations. +- **`Table`**: Represents a collection of rows, where each row is a dictionary of column-value pairs. + +## Basic Operations + +### Creating a Table + +You can create a new table within a database using the `CREATE_TABLE` method. + +```python +from db import Database + +db = Database() +db.CREATE_TABLE("users", colnames=["id", "name"]) +``` + +### Inserting Data + +To add data to a table, use the `INSERT_INTO` method. + +```python +db.INSERT_INTO("users", [ + {"id": 1, "name": "Alice"}, + {"id": 2, "name": "Bob"} +]) +``` + +### Selecting Data + +The `SELECT` method allows you to retrieve specific columns from a table. + +```python +users_table = db.tables["users"] +names = db.SELECT(users_table, ["name"]) +# names.rows will be ({'name': 'Alice'}, {'name': 'Bob'}) +``` + +## Advanced Queries + +This library supports a variety of more complex query operations. + +### Joining Tables + +You can perform `INNER JOIN`, `LEFT JOIN`, and `RIGHT JOIN` operations. + +```python +# Create another table +db.CREATE_TABLE("posts", colnames=["user_id", "title"]) +db.INSERT_INTO("posts", [ + {"user_id": 1, "title": "First Post"}, + {"user_id": 2, "title": "Second Post"} +]) + +# Perform an inner join +user_posts = db.JOIN( + db.tables["users"], + db.tables["posts"], + lambda row: row["users.id"] == row["posts.user_id"] +) +``` + +### Grouping and Aggregation + +The library supports `GROUP_BY` and aggregate functions like `COUNT`, `MAX`, and `SUM`. + +```python +from db import Table + +friends = Table( + "friends", + [ + {"id": 1, "city": "Denver", "state": "Colorado"}, + {"id": 2, "city": "Houston", "state": "Texas"}, + {"id": 3, "city": "Colorado Springs", "state": "Colorado"}, + ], +) + +# Group by state +grouped_by_state = db.GROUP_BY(friends, ["state"]) + +# Count cities in each state +state_counts = db.COUNT(grouped_by_state, "city") +# state_counts.rows will be ({'state': 'Colorado', 'COUNT(city)': 2}, {'state': 'Texas', 'COUNT(city)': 1}) +``` + +### The `query` Function + +For more complex, multi-step queries, you can use the `query` function, which chains together multiple operations in a declarative way. + +```python +from db import query + +# A complex query example +result = query( + db, + select=["employee.name", "department.title"], + from_=["employee"], + join=[ + [ + "department", + lambda row: row["employee.department_id"] == row["department.id"], + ], + ], + where=[lambda row: row["employee.salary"] > 150], +) +``` + +## Running Tests + +To ensure everything is working correctly, you can run the included test suite from your terminal: + +```bash +python -m unittest db_tests.py diff --git a/db.py b/db.py index cddf6b8..10b1c60 100644 --- a/db.py +++ b/db.py @@ -1,53 +1,62 @@ """A medium-faithful port of https://github.com/weinberg/SQLToy to Python""" +from __future__ import annotations + +from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple + class Table: - def __init__(self, name: str, rows: tuple[dict] = ()): + def __init__(self, name: str, rows: Iterable[Dict[str, Any]] = ()): self.name = name - self.rows = tuple(rows) - self._colnames = () + self.rows: Tuple[Dict[str, Any], ...] = tuple(rows) + self._colnames: Tuple[str, ...] = () - def set_colnames(self, colnames): + def set_colnames(self, colnames: Iterable[str]) -> None: self._colnames = tuple(sorted(colnames)) - def colnames(self): + def colnames(self) -> Tuple[str, ...]: if self._colnames: return self._colnames if not self.rows: raise ValueError("Need either rows or manually specified column names") return tuple(sorted(self.rows[0].keys())) - def filter(self, pred): + def filter(self, pred: Callable[[Dict[str, Any]], bool]) -> Table: return Table(self.name, [row for row in self.rows if pred(row)]) - def __repr__(self): + def __repr__(self) -> str: if not self.name: return f"Table({list(self.rows)!r})" return f"Table({self.name!r}, {list(self.rows)!r})" class Database: - def __init__(self): - self.tables = {} + def __init__(self) -> None: + self.tables: Dict[str, Table] = {} - def CREATE_TABLE(self, name, colnames=()): + def CREATE_TABLE(self, name: str, colnames: Iterable[str] = ()) -> Table: table = Table(name) if colnames: table.set_colnames(colnames) self.tables[name] = table return table - def DROP_TABLE(self, name): + def DROP_TABLE(self, name: str) -> None: del self.tables[name] - def FROM(self, first_table, *rest): + def FROM(self, first_table: str, *rest: str) -> Table: match rest: case (): return self.tables[first_table] case _: return self.CROSS_JOIN(self.tables[first_table], self.FROM(*rest)) - def SELECT(self, table, columns, aliases=None): + def SELECT( + self, + table: Table, + columns: Iterable[str], + aliases: Optional[Dict[str, str]] = None, + ) -> Table: if aliases is None: aliases = {} return Table( @@ -58,19 +67,25 @@ def SELECT(self, table, columns, aliases=None): ], ) - def WHERE(self, table, pred): + def WHERE(self, table: Table, pred: Callable[[Dict[str, Any]], bool]) -> Table: return table.filter(pred) - def INSERT_INTO(self, table_name, rows): + def INSERT_INTO(self, table_name: str, rows: Iterable[Dict[str, Any]]) -> None: table = self.tables[table_name] table.rows = (*table.rows, *rows) - def UPDATE(self, table, set, pred=lambda _: True): + def UPDATE( + self, + table: Table, + set_values: Dict[str, Any], + pred: Callable[[Dict[str, Any]], bool] = lambda _: True, + ) -> Table: return Table( - table.name, [{**row, **set} if pred(row) else row for row in table.rows] + table.name, + [{**row, **set_values} if pred(row) else row for row in table.rows], ) - def CROSS_JOIN(self, a, b): + def CROSS_JOIN(self, a: Table, b: Table) -> Table: rows = [] a_prefix = f"{a.name}." if a.name else "" b_prefix = f"{b.name}." if b.name else "" @@ -84,12 +99,16 @@ def CROSS_JOIN(self, a, b): ) return Table("", rows) - def INNER_JOIN(self, a, b, pred): + def INNER_JOIN( + self, a: Table, b: Table, pred: Callable[[Dict[str, Any]], bool] + ) -> Table: return self.CROSS_JOIN(a, b).filter(pred) JOIN = INNER_JOIN - def LEFT_JOIN(self, a, b, pred): + def LEFT_JOIN( + self, a: Table, b: Table, pred: Callable[[Dict[str, Any]], bool] + ) -> Table: rows = [] empty_b_row = {f"{b.name}.{k}": None for k in b.colnames()} for a_row in a.rows: @@ -107,24 +126,26 @@ def LEFT_JOIN(self, a, b, pred): rows.append({**mangled_a_row, **empty_b_row}) return Table("", rows) - def RIGHT_JOIN(self, a, b, pred): + def RIGHT_JOIN( + self, a: Table, b: Table, pred: Callable[[Dict[str, Any]], bool] + ) -> Table: return self.LEFT_JOIN(b, a, pred) - def LIMIT(self, table, limit): + def LIMIT(self, table: Table, limit: int) -> Table: return Table(table.name, table.rows[:limit]) - def ORDER_BY(self, table, rel): + def ORDER_BY(self, table: Table, rel: Callable[[Dict[str, Any]], Any]) -> Table: # Differs from JS version by passing the whole row to the comparator return Table(table.name, sorted(table.rows, key=rel)) - def HAVING(self, table, pred): + def HAVING(self, table: Table, pred: Callable[[Dict[str, Any]], bool]) -> Table: return table.filter(pred) - def OFFSET(self, table, offset): + def OFFSET(self, table: Table, offset: int) -> Table: return Table(table.name, table.rows[offset:]) - def DISTINCT(self, table, columns): - seen = set() + def DISTINCT(self, table: Table, columns: Iterable[str]) -> Table: + seen: Set[Tuple[Tuple[str, Any], ...]] = set() rows = [] for row in table.rows: view = tuple((col, row[col]) for col in columns) @@ -133,8 +154,8 @@ def DISTINCT(self, table, columns): rows.append(dict(view)) return Table(table.name, rows) - def GROUP_BY(self, table, groupBys): - groupRows = {} + def GROUP_BY(self, table: Table, groupBys: Iterable[str]) -> Table: + groupRows: Dict[Tuple[Any, ...], list[Dict[str, Any]]] = {} for row in table.rows: key = tuple(row[col] for col in groupBys) if key not in groupRows: @@ -142,13 +163,19 @@ def GROUP_BY(self, table, groupBys): groupRows[key].append(row.copy()) resultRows = [] for group in groupRows.values(): - resultRow = {"_groupRows": group} + resultRow: Dict[str, Any] = {"_groupRows": group} for col in groupBys: resultRow[col] = group[0][col] resultRows.append(resultRow) return Table(table.name, resultRows) - def _aggregate(self, table, col, agg_name, agg): + def _aggregate( + self, + table: Table, + col: str, + agg_name: str, + agg: Callable[[Iterable[Dict[str, Any]]], Any], + ) -> Table: grouped = table.rows and "_groupRows" in table.rows[0] col_name = f"{agg_name}({col})" if not grouped: @@ -164,32 +191,36 @@ def _aggregate(self, table, col, agg_name, agg): rows.append(new_row) return Table(table.name, rows) - def COUNT(self, table, col): - return self._aggregate(table, col, "COUNT", len) + def COUNT(self, table: Table, col: str) -> Table: + return self._aggregate(table, col, "COUNT", lambda rows: len(list(rows))) - def MAX(self, table, col): - return self._aggregate(table, col, "MAX", lambda rows: max(row[col] for row in rows)) + def MAX(self, table: Table, col: str) -> Table: + return self._aggregate( + table, col, "MAX", lambda rows: max(row[col] for row in rows) + ) - def SUM(self, table, col): - return self._aggregate(table, col, "SUM", lambda rows: sum(row[col] for row in rows)) + def SUM(self, table: Table, col: str) -> Table: + return self._aggregate( + table, col, "SUM", lambda rows: sum(row[col] for row in rows) + ) - def __repr__(self): + def __repr__(self) -> str: return f"Database({list(self.tables.keys())!r})" def query( - db, - select=(), - select_as=None, - distinct=None, - from_=None, - join=(), - where=(), - group_by=(), - having=None, - order_by=None, - offset=None, - limit=None, + db: Database, + select: Iterable[str] = (), + select_as: Optional[Dict[str, str]] = None, + distinct: Optional[Iterable[str]] = None, + from_: Optional[Iterable[str]] = None, + join: Iterable[Tuple[str, Callable[[Dict[str, Any]], bool]]] = (), + where: Iterable[Callable[[Dict[str, Any]], bool]] = (), + group_by: Iterable[str] = (), + having: Optional[Callable[[Dict[str, Any]], bool]] = None, + order_by: Optional[Callable[[Dict[str, Any]], Any]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, ) -> Table: if from_ is None: raise ValueError("Need a FROM clause") @@ -211,10 +242,12 @@ def query( result = db.OFFSET(result, offset) if limit: result = db.LIMIT(result, limit) + if distinct: + result = db.DISTINCT(result, distinct) return result -def csv(table): +def csv(table: Table) -> None: print(",".join(table.colnames())) for row in table.rows: print(",".join(str(val) for val in row.values()))