Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Python SQL-like Database

This project is a Python-based, in-memory database that provides a simplified, SQL-like interface for data manipulation. It is a medium-faithful port of the original [SQLToy](https://github.com/weinberg/SQLToy) project.

## Core Concepts

The two primary components of this library are the `Database` and `Table` classes.

- **`Database`**: Manages a collection of tables and provides methods for performing SQL-like operations.
- **`Table`**: Represents a collection of rows, where each row is a dictionary of column-value pairs.

## Basic Operations

### Creating a Table

You can create a new table within a database using the `CREATE_TABLE` method.

```python
from db import Database

db = Database()
db.CREATE_TABLE("users", colnames=["id", "name"])
```

### Inserting Data

To add data to a table, use the `INSERT_INTO` method.

```python
db.INSERT_INTO("users", [
{"id": 1, "name": "Alice"},
{"id": 2, "name": "Bob"}
])
```

### Selecting Data

The `SELECT` method allows you to retrieve specific columns from a table.

```python
users_table = db.tables["users"]
names = db.SELECT(users_table, ["name"])
# names.rows will be ({'name': 'Alice'}, {'name': 'Bob'})
```

## Advanced Queries

This library supports a variety of more complex query operations.

### Joining Tables

You can perform `INNER JOIN`, `LEFT JOIN`, and `RIGHT JOIN` operations.

```python
# Create another table
db.CREATE_TABLE("posts", colnames=["user_id", "title"])
db.INSERT_INTO("posts", [
{"user_id": 1, "title": "First Post"},
{"user_id": 2, "title": "Second Post"}
])

# Perform an inner join
user_posts = db.JOIN(
db.tables["users"],
db.tables["posts"],
lambda row: row["users.id"] == row["posts.user_id"]
)
```

### Grouping and Aggregation

The library supports `GROUP_BY` and aggregate functions like `COUNT`, `MAX`, and `SUM`.

```python
from db import Table

friends = Table(
"friends",
[
{"id": 1, "city": "Denver", "state": "Colorado"},
{"id": 2, "city": "Houston", "state": "Texas"},
{"id": 3, "city": "Colorado Springs", "state": "Colorado"},
],
)

# Group by state
grouped_by_state = db.GROUP_BY(friends, ["state"])

# Count cities in each state
state_counts = db.COUNT(grouped_by_state, "city")
# state_counts.rows will be ({'state': 'Colorado', 'COUNT(city)': 2}, {'state': 'Texas', 'COUNT(city)': 1})
```

### The `query` Function

For more complex, multi-step queries, you can use the `query` function, which chains together multiple operations in a declarative way.

```python
from db import query

# A complex query example
result = query(
db,
select=["employee.name", "department.title"],
from_=["employee"],
join=[
[
"department",
lambda row: row["employee.department_id"] == row["department.id"],
],
],
where=[lambda row: row["employee.salary"] > 150],
)
```

## Running Tests

To ensure everything is working correctly, you can run the included test suite from your terminal:

```bash
python -m unittest db_tests.py
133 changes: 82 additions & 51 deletions db.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,60 @@
"""A medium-faithful port of https://github.com/weinberg/SQLToy to Python"""

from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple


Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do Row = Dict[str, Any] and RowPredicate = Callable[[Row], bool] to make types easier to read

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry which line are you referencing/

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding new type aliases so we can use those throughout the new type annotations. This will make nearly every annotation much shorter

class Table:
def __init__(self, name: str, rows: tuple[dict] = ()):
def __init__(self, name: str, rows: Iterable[Dict[str, Any]] = ()):
self.name = name
self.rows = tuple(rows)
self._colnames = ()
self.rows: Tuple[Dict[str, Any], ...] = tuple(rows)
self._colnames: Tuple[str, ...] = ()

def set_colnames(self, colnames):
def set_colnames(self, colnames: Iterable[str]) -> None:
self._colnames = tuple(sorted(colnames))

def colnames(self):
def colnames(self) -> Tuple[str, ...]:
if self._colnames:
return self._colnames
if not self.rows:
raise ValueError("Need either rows or manually specified column names")
return tuple(sorted(self.rows[0].keys()))

def filter(self, pred):
def filter(self, pred: Callable[[Dict[str, Any]], bool]) -> "Table":
return Table(self.name, [row for row in self.rows if pred(row)])

def __repr__(self):
def __repr__(self) -> str:
if not self.name:
return f"Table({list(self.rows)!r})"
return f"Table({self.name!r}, {list(self.rows)!r})"


class Database:
def __init__(self):
self.tables = {}
def __init__(self) -> None:
self.tables: Dict[str, Table] = {}

def CREATE_TABLE(self, name, colnames=()):
def CREATE_TABLE(self, name: str, colnames: Iterable[str] = ()) -> Table:
table = Table(name)
if colnames:
table.set_colnames(colnames)
self.tables[name] = table
return table

def DROP_TABLE(self, name):
def DROP_TABLE(self, name: str) -> None:
del self.tables[name]

def FROM(self, first_table, *rest):
def FROM(self, first_table: str, *rest: str) -> Table:
match rest:
case ():
return self.tables[first_table]
case _:
return self.CROSS_JOIN(self.tables[first_table], self.FROM(*rest))

def SELECT(self, table, columns, aliases=None):
def SELECT(
self,
table: Table,
columns: Iterable[str],
aliases: Optional[Dict[str, str]] = None,
) -> Table:
if aliases is None:
aliases = {}
return Table(
Expand All @@ -58,19 +65,25 @@ def SELECT(self, table, columns, aliases=None):
],
)

def WHERE(self, table, pred):
def WHERE(self, table: Table, pred: Callable[[Dict[str, Any]], bool]) -> Table:
return table.filter(pred)

def INSERT_INTO(self, table_name, rows):
def INSERT_INTO(self, table_name: str, rows: Iterable[Dict[str, Any]]) -> None:
table = self.tables[table_name]
table.rows = (*table.rows, *rows)

def UPDATE(self, table, set, pred=lambda _: True):
def UPDATE(
self,
table: Table,
set_values: Dict[str, Any],
pred: Callable[[Dict[str, Any]], bool] = lambda _: True,
) -> Table:
return Table(
table.name, [{**row, **set} if pred(row) else row for row in table.rows]
table.name,
[{**row, **set_values} if pred(row) else row for row in table.rows],
)
Comment on lines +80 to 86
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/set_values/row_updates/g maybe


def CROSS_JOIN(self, a, b):
def CROSS_JOIN(self, a: Table, b: Table) -> Table:
rows = []
a_prefix = f"{a.name}." if a.name else ""
b_prefix = f"{b.name}." if b.name else ""
Expand All @@ -84,12 +97,16 @@ def CROSS_JOIN(self, a, b):
)
return Table("", rows)

def INNER_JOIN(self, a, b, pred):
def INNER_JOIN(
self, a: Table, b: Table, pred: Callable[[Dict[str, Any]], bool]
) -> Table:
return self.CROSS_JOIN(a, b).filter(pred)

JOIN = INNER_JOIN

def LEFT_JOIN(self, a, b, pred):
def LEFT_JOIN(
self, a: Table, b: Table, pred: Callable[[Dict[str, Any]], bool]
) -> Table:
rows = []
empty_b_row = {f"{b.name}.{k}": None for k in b.colnames()}
for a_row in a.rows:
Expand All @@ -107,24 +124,26 @@ def LEFT_JOIN(self, a, b, pred):
rows.append({**mangled_a_row, **empty_b_row})
return Table("", rows)

def RIGHT_JOIN(self, a, b, pred):
def RIGHT_JOIN(
self, a: Table, b: Table, pred: Callable[[Dict[str, Any]], bool]
) -> Table:
return self.LEFT_JOIN(b, a, pred)

def LIMIT(self, table, limit):
def LIMIT(self, table: Table, limit: int) -> Table:
return Table(table.name, table.rows[:limit])

def ORDER_BY(self, table, rel):
def ORDER_BY(self, table: Table, rel: Callable[[Dict[str, Any]], Any]) -> Table:
# Differs from JS version by passing the whole row to the comparator
return Table(table.name, sorted(table.rows, key=rel))

def HAVING(self, table, pred):
def HAVING(self, table: Table, pred: Callable[[Dict[str, Any]], bool]) -> Table:
return table.filter(pred)

def OFFSET(self, table, offset):
def OFFSET(self, table: Table, offset: int) -> Table:
return Table(table.name, table.rows[offset:])

def DISTINCT(self, table, columns):
seen = set()
def DISTINCT(self, table: Table, columns: Iterable[str]) -> Table:
seen: Set[Tuple[Tuple[str, Any], ...]] = set()
rows = []
for row in table.rows:
view = tuple((col, row[col]) for col in columns)
Expand All @@ -133,22 +152,28 @@ def DISTINCT(self, table, columns):
rows.append(dict(view))
return Table(table.name, rows)

def GROUP_BY(self, table, groupBys):
groupRows = {}
def GROUP_BY(self, table: Table, groupBys: Iterable[str]) -> Table:
groupRows: Dict[Tuple[Any, ...], list[Dict[str, Any]]] = {}
for row in table.rows:
key = tuple(row[col] for col in groupBys)
if key not in groupRows:
groupRows[key] = []
groupRows[key].append(row.copy())
resultRows = []
for group in groupRows.values():
resultRow = {"_groupRows": group}
resultRow: Dict[str, Any] = {"_groupRows": group}
for col in groupBys:
resultRow[col] = group[0][col]
resultRows.append(resultRow)
return Table(table.name, resultRows)

def _aggregate(self, table, col, agg_name, agg):
def _aggregate(
self,
table: Table,
col: str,
agg_name: str,
agg: Callable[[Iterable[Dict[str, Any]]], Any],
) -> Table:
grouped = table.rows and "_groupRows" in table.rows[0]
col_name = f"{agg_name}({col})"
if not grouped:
Expand All @@ -164,32 +189,36 @@ def _aggregate(self, table, col, agg_name, agg):
rows.append(new_row)
return Table(table.name, rows)

def COUNT(self, table, col):
return self._aggregate(table, col, "COUNT", len)
def COUNT(self, table: Table, col: str) -> Table:
return self._aggregate(table, col, "COUNT", lambda rows: len(list(rows)))

def MAX(self, table, col):
return self._aggregate(table, col, "MAX", lambda rows: max(row[col] for row in rows))
def MAX(self, table: Table, col: str) -> Table:
return self._aggregate(
table, col, "MAX", lambda rows: max(row[col] for row in rows)
)

def SUM(self, table, col):
return self._aggregate(table, col, "SUM", lambda rows: sum(row[col] for row in rows))
def SUM(self, table: Table, col: str) -> Table:
return self._aggregate(
table, col, "SUM", lambda rows: sum(row[col] for row in rows)
)

def __repr__(self):
def __repr__(self) -> str:
return f"Database({list(self.tables.keys())!r})"


def query(
db,
select=(),
select_as=None,
distinct=None,
from_=None,
join=(),
where=(),
group_by=(),
having=None,
order_by=None,
offset=None,
limit=None,
db: Database,
select: Iterable[str] = (),
select_as: Optional[Dict[str, str]] = None,
distinct: Optional[Iterable[str]] = None,
from_: Optional[Iterable[str]] = None,
join: Iterable[Tuple[str, Callable[[Dict[str, Any]], bool]]] = (),
where: Iterable[Callable[[Dict[str, Any]], bool]] = (),
group_by: Iterable[str] = (),
having: Optional[Callable[[Dict[str, Any]], bool]] = None,
order_by: Optional[Callable[[Dict[str, Any]], Any]] = None,
offset: Optional[int] = None,
limit: Optional[int] = None,
) -> Table:
if from_ is None:
raise ValueError("Need a FROM clause")
Expand All @@ -211,10 +240,12 @@ def query(
result = db.OFFSET(result, offset)
if limit:
result = db.LIMIT(result, limit)
if distinct:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go in a separate commit and needs separate tests. Additionally, in the SQL order of execution, I think it goes between select and order by

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(test is slightly less important but it should go in a separate commit than type checking for sure)

result = db.DISTINCT(result, distinct)
return result


def csv(table):
def csv(table: Table) -> None:
print(",".join(table.colnames()))
for row in table.rows:
print(",".join(str(val) for val in row.values()))