diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 20d3101e..193f450c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -4,39 +4,50 @@ on: push: branches: - main - - 'release**' + - "release**" paths-ignore: - - '**.md' + - "**.md" pull_request: paths-ignore: - - '**.md' + - "**.md" workflow_dispatch: schedule: - - cron: '0 6 * * 1' + - cron: "0 6 * * 1" concurrency: group: "${{ github.workflow }}-${{ github.head_ref || github.run_id }}" cancel-in-progress: true jobs: - test_on_matrix: - name: build (${{ matrix.os }}, py${{ matrix.python-version }}) - runs-on: ${{ matrix.os }}-latest + test_on_ubuntu: + name: build (ubuntu, py${{ matrix.python-version }}) + runs-on: ubuntu-latest timeout-minutes: 6 strategy: fail-fast: false matrix: - os: [windows, ubuntu, macos] - python-version: ['3.13'] - include: - - os: ubuntu - python-version: '3.11' - - os: ubuntu - python-version: '3.12' + python-version: ["3.11", "3.12", "3.13"] env: MESA_FRAMES_RUNTIME_TYPECHECKING: "true" + POSTGRES_URI: postgresql://user:password@localhost:5432/testdb + SKIP_PG_TESTS: "false" + + services: + postgres: + image: postgres:15 + ports: + - 5432:5432 + env: + POSTGRES_USER: user + POSTGRES_PASSWORD: password + POSTGRES_DB: testdb + options: >- + --health-cmd="pg_isready" + --health-interval=10s + --health-timeout=5s + --health-retries=5 steps: - uses: actions/checkout@v4 @@ -53,24 +64,80 @@ jobs: - name: Install mesa-frames + dev dependencies run: | - # 1. Install the project itself uv pip install --system . - # 2. Install everything under the "dev" dependency group uv pip install --group dev --system - name: Test with pytest run: pytest --durations=10 --cov=mesa_frames tests/ --cov-report=xml - - if: matrix.os == 'ubuntu' - name: Codecov + - name: Upload coverage to Codecov uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} + test_on_mac_windows: + name: build (${{ matrix.os }}, py3.13) + runs-on: ${{ matrix.os }}-latest + timeout-minutes: 10 + + strategy: + matrix: + include: + - os: windows + skip_pg: true + - os: macos + skip_pg: false + + env: + MESA_FRAMES_RUNTIME_TYPECHECKING: "true" + POSTGRES_URI: postgresql://user:password@localhost:5432/testdb + SKIP_PG_TESTS: ${{ matrix.skip_pg }} + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Install uv via GitHub Action + uses: astral-sh/setup-uv@v6 + with: + cache: true + + - name: Install mesa-frames + dev dependencies + run: | + uv pip install --system . + uv pip install --group dev --system + + - name: Install and Start PostgreSQL (macOS) + if: matrix.os == 'macos' + run: | + brew install postgresql@15 + + export PATH="/opt/homebrew/opt/postgresql@15/bin:$PATH" + export PGDATA="/opt/homebrew/var/postgresql@15" + + # Ensure a clean database directory + rm -rf $PGDATA + mkdir -p $PGDATA + + initdb --username=user --auth=trust --encoding=UTF8 $PGDATA + + pg_ctl -D $PGDATA -l logfile start + + sleep 5 + + createdb testdb -U user + shell: bash + + - name: Test with pytest + run: pytest --durations=10 --cov=mesa_frames tests/ --cov-report=xml + build: name: build runs-on: ubuntu-latest - needs: [test_on_matrix] + needs: [test_on_ubuntu, test_on_mac_windows] steps: - - name: All matrix builds and tests passed - run: echo "All matrix jobs completed successfully." \ No newline at end of file + - run: echo "All matrix jobs completed successfully." 
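Reviewer note: the POSTGRES_URI and SKIP_PG_TESTS variables exported by these jobs are read by the new test suite (tests/test_datacollector.py, added below). A minimal sketch of how the suite consumes them — the actual fixture and marker live in the test file, and the function name here is only illustrative:

    import os
    import pytest

    # The URI falls back to the same credentials the CI Postgres service is configured with.
    @pytest.fixture(scope="session")
    def postgres_uri():
        return os.getenv("POSTGRES_URI", "postgresql://user:password@localhost:5432/testdb")

    # PostgreSQL-backed tests are skipped when SKIP_PG_TESTS is "true" (the Windows job).
    @pytest.mark.skipif(
        os.getenv("SKIP_PG_TESTS") == "true",
        reason="PostgreSQL tests are skipped on Windows runners",
    )
    def test_postgres_roundtrip(postgres_uri):
        ...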
diff --git a/mesa_frames/abstract/datacollector.py b/mesa_frames/abstract/datacollector.py
index cf23831e..9066e429 100644
--- a/mesa_frames/abstract/datacollector.py
+++ b/mesa_frames/abstract/datacollector.py
@@ -64,7 +64,7 @@ class AbstractDataCollector(ABC):
     _agent_reporters: dict[str, str | Callable] | None
     _trigger: Callable[..., bool]
     _reset_memory = bool
-    _storage_uri: Literal["memory:", "csv:", "postgresql:"]
+    _storage: Literal["memory", "csv", "parquet", "S3-csv", "S3-parquet", "postgresql"]
     _frames: list[pl.DataFrame]
 
     def __init__(
@@ -74,7 +74,9 @@ def __init__(
         agent_reporters: dict[str, str | Callable] | None = None,
         trigger: Callable[[Any], bool] | None = None,
         reset_memory: bool = True,
-        storage: Literal["memory:", "csv:", "postgresql:"] = "memory:",
+        storage: Literal[
+            "memory", "csv", "parquet", "S3-csv", "S3-parquet", "postgresql"
+        ] = "memory",
     ):
         """
         Initialize a Datacollector.
@@ -91,7 +93,7 @@ def __init__(
             A function(model) -> bool that determines whether to collect data.
         reset_memory : bool
             Whether to reset in-memory data after flushing. Default is True.
-        storage : Literal["memory:", "csv:", "postgresql:"]
-            Storage backend URI (e.g. 'memory:', 'csv:', 'postgresql:').
+        storage : Literal["memory", "csv", "parquet", "S3-csv", "S3-parquet", "postgresql"]
+            Storage backend to flush collected data to (e.g. 'memory', 'csv', 'postgresql').
         """
         self._model = model
@@ -99,7 +101,7 @@ def __init__(
         self._agent_reporters = agent_reporters or {}
         self._trigger = trigger or (lambda model: False)
         self._reset_memory = reset_memory
-        self._storage_uri = storage or "memory:"
+        self._storage = storage or "memory"
         self._frames = []
 
     def collect(self) -> None:
diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py
new file mode 100644
index 00000000..942a3be3
--- /dev/null
+++ b/mesa_frames/concrete/datacollector.py
@@ -0,0 +1,502 @@
+"""
+Concrete class for data collection in mesa-frames.
+
+This module defines a `DataCollector` implementation that gathers and optionally persists
+model-level and agent-level data during simulations. It supports multiple storage backends,
+including in-memory, CSV, Parquet, S3, and PostgreSQL, using Polars for efficient lazy
+data processing.
+
+Classes:
+    DataCollector:
+        A concrete implementation of `AbstractDataCollector` backed by Polars.
+        It supports flexible reporting of model and agent attributes, conditional
+        data collection using a trigger function, and pluggable backends for storage.
+
+Supported Storage Backends:
+    - memory : In-memory collection (default)
+    - csv : Local CSV file output
+    - parquet : Local Parquet file output
+    - S3-csv : CSV files stored on Amazon S3
+    - S3-parquet : Parquet files stored on Amazon S3
+    - postgresql : PostgreSQL database with schema support
+
+Triggers:
+    - A `trigger` parameter can be provided to control conditional collection.
+      This is a callable taking the model as input and returning a boolean.
+      If it returns True, data is collected during `conditional_collect()`.
+
+Usage:
+    The `DataCollector` class is designed to be used within a `ModelDF` instance
+    to collect model-level and/or agent-level data.
+
+    Example:
+    --------
+    from mesa_frames.concrete.model import ModelDF
+    from mesa_frames.concrete.datacollector import DataCollector
+
+    class ExampleModel(ModelDF):
+        def __init__(self, agents: AgentsDF):
+            super().__init__()
+            self.agents = agents
+            self.dc = DataCollector(
+                model=self,
+                # other required arguments
+            )
+
+        def step(self):
+            # Option 1: collect immediately
+            self.dc.collect()
+
+            # Option 2: collect based on condition
+            self.dc.conditional_collect()
+
+            # Write the collected data to the destination
+            self.dc.flush()
+"""
+
+import polars as pl
+import boto3
+from urllib.parse import urlparse
+import tempfile
+import psycopg2
+from mesa_frames.abstract.datacollector import AbstractDataCollector
+from typing import Any, Literal
+from collections.abc import Callable
+from mesa_frames import ModelDF
+from psycopg2.extensions import connection
+
+
+class DataCollector(AbstractDataCollector):
+    def __init__(
+        self,
+        model: ModelDF,
+        model_reporters: dict[str, Callable] | None = None,
+        agent_reporters: dict[str, str | Callable] | None = None,
+        trigger: Callable[[Any], bool] | None = None,
+        reset_memory: bool = True,
+        storage: Literal[
+            "memory", "csv", "parquet", "S3-csv", "S3-parquet", "postgresql"
+        ] = "memory",
+        storage_uri: str | None = None,
+        schema: str = "public",
+    ):
+        """
+        Initialize the DataCollector with configuration options.
+
+        Parameters
+        ----------
+        model : ModelDF
+            The model object from which data is collected.
+        model_reporters : dict[str, Callable] | None
+            Functions to collect data at the model level.
+        agent_reporters : dict[str, str | Callable] | None
+            Attributes or functions to collect data at the agent level.
+        trigger : Callable[[Any], bool] | None
+            A function(model) -> bool that determines whether to collect data.
+        reset_memory : bool
+            Whether to reset in-memory data after flushing. Default is True.
+        storage : Literal["memory", "csv", "parquet", "S3-csv", "S3-parquet", "postgresql"]
+            Storage backend to flush collected data to (e.g. 'memory', 'csv', 'postgresql').
+        storage_uri : str | None
+            URI or path corresponding to the selected storage backend.
+        schema : str
+            Schema name used for PostgreSQL storage.
+
+        """
+        super().__init__(
+            model=model,
+            model_reporters=model_reporters,
+            agent_reporters=agent_reporters,
+            trigger=trigger,
+            reset_memory=reset_memory,
+            storage=storage,
+        )
+        self._writers = {
+            "csv": self._write_csv_local,
+            "parquet": self._write_parquet_local,
+            "S3-csv": self._write_csv_s3,
+            "S3-parquet": self._write_parquet_s3,
+            "postgresql": self._write_postgres,
+        }
+        self._storage_uri = storage_uri
+        self._schema = schema
+
+        self._validate_inputs()
+
+    def _collect(self):
+        """
+        Collect data from the model and agents for the current step.
+
+        This method checks for the presence of model and agent reporters
+        and calls the appropriate collection routines for each.
+        """
+        if self._model_reporters:
+            self._collect_model_reporters()
+
+        if self._agent_reporters:
+            self._collect_agent_reporters()
+
+    def _collect_model_reporters(self):
+        """
+        Collect model-level data using the model_reporters.
+
+        Creates a LazyFrame containing the step, seed, and values
+        returned by each model reporter. Appends the LazyFrame to internal storage.
+ """ + model_data_dict = {} + model_data_dict["step"] = self._model._steps + model_data_dict["seed"] = str(self.seed) + for column_name, reporter in self._model_reporters.items(): + model_data_dict[column_name] = reporter(self._model) + model_lazy_frame = pl.LazyFrame([model_data_dict]) + self._frames.append(("model", str(self._model._steps), model_lazy_frame)) + + def _collect_agent_reporters(self): + """ + Collect agent-level data using the agent_reporters. + + Constructs a LazyFrame with one column per reporter and + includes `step` and `seed` metadata. Appends it to internal storage. + """ + agent_data_dict = {} + for col_name, reporter in self._agent_reporters.items(): + if isinstance(reporter, str): + for k, v in self._model.agents[reporter].items(): + agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v + else: + agent_data_dict[col_name] = reporter(self._model.agents) + agent_lazy_frame = pl.LazyFrame(agent_data_dict) + agent_lazy_frame = agent_lazy_frame.with_columns( + [ + pl.lit(self._model._steps).alias("step"), + pl.lit(str(self.seed)).alias("seed"), + ] + ) + self._frames.append(("agent", str(self._model._steps), agent_lazy_frame)) + + @property + def data(self) -> dict[str, pl.DataFrame]: + """ + Retrieve the collected data as eagerly evaluated Polars DataFrames. + + Returns + ------- + dict[str, pl.DataFrame] + A dictionary with keys "model" and "agent" mapping to concatenated DataFrames of collected data. + """ + model_frames = [ + lf.collect() for kind, step, lf in self._frames if kind == "model" + ] + agent_frames = [ + lf.collect() for kind, step, lf in self._frames if kind == "agent" + ] + return { + "model": pl.concat(model_frames) if model_frames else pl.DataFrame(), + "agent": pl.concat(agent_frames) if agent_frames else pl.DataFrame(), + } + + def _flush(self): + """ + Flush the collected data to the configured external storage backend. + + Uses the appropriate writer function based on the specified storage option. + """ + self._writers[self._storage](self._storage_uri) + + def _write_csv_local(self, uri: str): + """ + Write collected data to local CSV files. + + Parameters + ---------- + uri : str + Local directory path to write files into. + """ + for kind, step, df in self._frames: + df.collect().write_csv(f"{uri}/{kind}_step{step}.csv") + + def _write_parquet_local(self, uri: str): + """ + Write collected data to local Parquet files. + + Parameters + ---------- + uri: str + Local directory path to write files into. + """ + for kind, step, df in self._frames: + df.collect().write_parquet(f"{uri}/{kind}_step{step}.parquet") + + def _write_csv_s3(self, uri: str): + """ + Write collected data to AWS S3 in CSV format. + + Parameters + ---------- + uri: str + S3 URI (e.g., s3://bucket/path) to upload files to. + """ + self._write_s3(uri, format_="csv") + + def _write_parquet_s3(self, uri: str): + """ + Write collected data to AWS S3 in Parquet format. + + Parameters + ---------- + uri: str + S3 URI (e.g., s3://bucket/path) to upload files to. + """ + self._write_s3(uri, format_="parquet") + + def _write_s3(self, uri: str, format_: str): + """ + Upload collected data to S3 in a specified format. + + Parameters + ---------- + uri: str + S3 URI to upload to. + format_: str + Format of the output files ("csv" or "parquet"). 
+        """
+        s3 = boto3.client("s3")
+        parsed = urlparse(uri)
+        bucket = parsed.netloc
+        prefix = parsed.path.lstrip("/")
+        for kind, step, lf in self._frames:
+            df = lf.collect()
+            with tempfile.NamedTemporaryFile(suffix=f".{format_}") as tmp:
+                if format_ == "csv":
+                    df.write_csv(tmp.name)
+                elif format_ == "parquet":
+                    df.write_parquet(tmp.name)
+                key = f"{prefix}/{kind}_step{step}.{format_}"
+                s3.upload_file(tmp.name, bucket, key)
+
+    def _write_postgres(self, uri: str):
+        """
+        Write collected data to a PostgreSQL database.
+
+        Each frame is inserted into the appropriate table (`model_data` or `agent_data`)
+        using batched insert queries.
+
+        Parameters
+        ----------
+        uri : str
+            PostgreSQL connection URI in the form postgresql://testuser:testpass@localhost:5432/testdb
+        """
+        conn = self._get_db_connection(uri=uri)
+        cur = conn.cursor()
+        for kind, step, lf in self._frames:
+            df = lf.collect()
+            table = f"{kind}_data"
+            cols = df.columns
+            values = [tuple(row) for row in df.rows()]
+            placeholders = ", ".join(["%s"] * len(cols))
+            columns = ", ".join(cols)
+            cur.executemany(
+                f"INSERT INTO {self._schema}.{table} ({columns}) VALUES ({placeholders})",
+                values,
+            )
+        conn.commit()
+        cur.close()
+        conn.close()
+
+    def _get_db_connection(self, uri: str) -> connection:
+        """
+        Open a psycopg2 connection from a URI of the form postgresql://user:pass@host:port/dbname.
+
+        Parameters
+        ----------
+        uri : str
+            PostgreSQL connection URI in the form postgresql://testuser:testpass@localhost:5432/testdb
+
+        Returns
+        -------
+        connection
+            An open psycopg2 connection.
+        """
+        parsed = urlparse(uri)
+        conn = psycopg2.connect(
+            dbname=parsed.path[1:],  # remove leading slash
+            user=parsed.username,
+            password=parsed.password,
+            host=parsed.hostname,
+            port=parsed.port,
+        )
+        return conn
+
+    def _validate_inputs(self):
+        """
+        Validate configuration and required schema for non-memory storage backends.
+
+        - Ensures a `storage_uri` is provided if needed.
+        - For PostgreSQL, validates that required tables and columns exist.
+        """
+        if self._storage != "memory" and self._storage_uri is None:
+            raise ValueError(
+                "Please define a storage_uri to if to be stored not in memory"
+            )
+
+        if self._storage == "postgresql":
+            conn = self._get_db_connection(self._storage_uri)
+            try:
+                self._validate_postgres_table_exists(conn)
+                self._validate_postgres_columns_exists(conn)
+            finally:
+                conn.close()
+
+    def _validate_postgres_table_exists(self, conn: connection):
+        """
+        Validate that the required PostgreSQL tables exist for storing model and agent data.
+
+        Parameters
+        ----------
+        conn : connection
+            Open database connection.
+        """
+        if self._model_reporters:
+            self._validate_reporter_table(conn=conn, table_name="model_data")
+        if self._agent_reporters:
+            self._validate_reporter_table(conn=conn, table_name="agent_data")
+
+    def _validate_postgres_columns_exists(self, conn: connection):
+        """
+        Validate that required columns are present in the PostgreSQL tables.
+
+        Parameters
+        ----------
+        conn : connection
+            Open database connection.
+        """
+        if self._model_reporters:
+            self._validate_reporter_table_columns(
+                conn=conn, table_name="model_data", reporter=self._model_reporters
+            )
+        if self._agent_reporters:
+            self._validate_reporter_table_columns(
+                conn=conn, table_name="agent_data", reporter=self._agent_reporters
+            )
+
+    def _validate_reporter_table(self, conn: connection, table_name: str):
+        """
+        Check if a given table exists in the PostgreSQL schema.
+
+        Parameters
+        ----------
+        conn : connection
+            Open database connection.
+ table_name : str + Name of the table to check. + + Raises + ------ + ValueError + If the table does not exist in the schema. + """ + query = f""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_schema = '{self._schema}' AND table_name = '{table_name}' + );""" + result = self._execute_query_with_result(conn, query) + if result == [(False,)]: + raise ValueError( + f"{self._schema}.{table_name} does not exist. To store collected data in DB please create a table with required columns" + ) + + def _validate_reporter_table_columns( + self, conn: connection, table_name: str, reporter: dict[str, Callable | str] + ): + """ + Check if the expected columns are present in a given PostgreSQL table. + + Parameters + ---------- + conn : connection + Open database connection. + table_name :str + Name of the table to validate. + reporter : dict[str, Callable | str] + Dictionary of reporters whose keys are expected as columns. + + Raises + ------ + ValueError + If any expected columns are missing from the table. + """ + expected_columns = set() + for col_name, required_column in reporter.items(): + if isinstance(required_column, str): + for k, v in self._model.agents[required_column].items(): + expected_columns.add( + (col_name + "_" + str(k.__class__.__name__)).lower() + ) + else: + expected_columns.add(col_name.lower()) + + query = f""" + SELECT column_name + FROM information_schema.columns + WHERE table_schema = '{self._schema}' AND table_name = '{table_name}'; + """ + + result = self._execute_query_with_result(conn, query) + if not result: + raise ValueError( + f"Could not retrieve columns for table {self._schema}.{table_name}" + ) + + existing_columns = {row[0] for row in result} + missing_columns = expected_columns - existing_columns + required_columns = { + "step": "Integer", + "seed": "Varchar", + } + + missing_required = { + col: col_type + for col, col_type in required_columns.items() + if col not in existing_columns + } + + if missing_columns or missing_required: + error_parts = [] + + if missing_columns: + error_parts.append(f"Missing columns: {sorted(missing_columns)}") + + if missing_required: + required_list = [ + f"`{col}` column of type ({col_type})" + for col, col_type in missing_required.items() + ] + error_parts.append( + "Missing specific columns: " + ", ".join(required_list) + ) + + raise ValueError( + f"Missing columns in table {self._schema}.{table_name}: " + + "; ".join(error_parts) + ) + + def _execute_query_with_result(self, conn: connection, query: str) -> list[tuple]: + """ + Execute a SQL query and return the fetched results. + + Parameters + ---------- + conn : connection + Open database connection. + query : str + SQL query string. + + Returns + ------- + list[tuple] + Query result rows. 
+ """ + with conn.cursor() as cur: + cur.execute(query) + return cur.fetchall() diff --git a/pyproject.toml b/pyproject.toml index 5da0c84e..40fafbd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,8 @@ dependencies = [ "pyarrow>=20.0.0", # polars._typing added in 1.0.0 "polars>=1.30.0", + "psycopg2-binary==2.9.10", + "boto3>=1.35.91" ] dynamic = ["version"] diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py new file mode 100644 index 00000000..401b09e1 --- /dev/null +++ b/tests/test_datacollector.py @@ -0,0 +1,480 @@ +from mesa_frames.concrete.datacollector import DataCollector +from mesa_frames import ModelDF, AgentSetPolars, AgentsDF +import pytest +import polars as pl +import beartype +import tempfile +import os + + +def custom_trigger(model): + return model._steps % 2 == 0 + + +class ExampleAgentSet1(AgentSetPolars): + def __init__(self, model: ModelDF): + super().__init__(model) + self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) + self["age"] = pl.Series("age", [10, 20, 30, 40]) + + def add_wealth(self, amount: int) -> None: + self["wealth"] += amount + + def step(self) -> None: + self.add_wealth(1) + + +class ExampleAgentSet2(AgentSetPolars): + def __init__(self, model: ModelDF): + super().__init__(model) + self["wealth"] = pl.Series("wealth", [10, 20, 30, 40]) + self["age"] = pl.Series("age", [11, 22, 33, 44]) + + def add_wealth(self, amount: int) -> None: + self["wealth"] += amount + + def step(self) -> None: + self.add_wealth(2) + + +class ExampleAgentSet3(AgentSetPolars): + def __init__(self, model: ModelDF): + super().__init__(model) + self["age"] = pl.Series("age", [1, 2, 3, 4]) + self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) + + def age_agents(self, amount: int) -> None: + self["age"] += amount + + def step(self) -> None: + self.age_agents(1) + + +class ExampleModel(ModelDF): + def __init__(self, agents: AgentsDF): + super().__init__() + self.agents = agents + + def step(self): + self.agents.do("step") + + def run_model(self, n): + for _ in range(n): + self.step() + + def run_model_with_collect(self, n): + for _ in range(n): + self.step() + self.dc.collect() + + def run_model_with_conditional_collect(self, n): + for _ in range(n): + self.step() + self.dc.conditional_collect() + + +@pytest.fixture(scope="session") +def postgres_uri(): + return os.getenv("POSTGRES_URI", "postgresql://user:password@localhost:5432/testdb") + + +@pytest.fixture +def fix1_AgentSetPolars() -> ExampleAgentSet1: + return ExampleAgentSet1(ModelDF()) + + +@pytest.fixture +def fix2_AgentSetPolars() -> ExampleAgentSet2: + return ExampleAgentSet2(ModelDF()) + + +@pytest.fixture +def fix3_AgentSetPolars() -> ExampleAgentSet3: + return ExampleAgentSet3(ModelDF()) + + +@pytest.fixture +def fix_AgentsDF( + fix1_AgentSetPolars: ExampleAgentSet1, + fix2_AgentSetPolars: ExampleAgentSet2, + fix3_AgentSetPolars: ExampleAgentSet3, +) -> AgentsDF: + model = ModelDF() + agents = AgentsDF(model) + agents.add([fix1_AgentSetPolars, fix2_AgentSetPolars, fix3_AgentSetPolars]) + return agents + + +@pytest.fixture +def fix1_model(fix_AgentsDF: AgentsDF) -> ExampleModel: + return ExampleModel(fix_AgentsDF) + + +class TestDataCollector: + def test__init__(self, fix1_model, postgres_uri): + model = fix1_model + with pytest.raises( + beartype.roar.BeartypeCallHintParamViolation, + match="not instance of .*Callable", + ): + model.test_dc = DataCollector( + model=model, model_reporters={"total_agents": "sum"} + ) + with pytest.raises( + ValueError, + match="Please define a storage_uri to if 
to be stored not in memory", + ): + model.test_dc = DataCollector(model=model, storage="S3-csv") + + with pytest.raises( + ValueError, + match="Please define a storage_uri to if to be stored not in memory", + ): + model.test_dc = DataCollector(model=model, storage="postgresql") + + def test_collect(self, fix1_model): + model = fix1_model + + model.dc = DataCollector( + model=model, + model_reporters={ + "total_agents": lambda model: sum( + len(agentset) for agentset in model.agents._agentsets + ) + }, + agent_reporters={ + "wealth": lambda agents: agents._agentsets[0]["wealth"], + "age": "age", + }, + ) + + model.dc.collect() + collected_data = model.dc.data + + # test collected_model_data + assert collected_data["model"].shape == (1, 3) + assert collected_data["model"].columns == ["step", "seed", "total_agents"] + assert collected_data["model"]["step"].to_list() == [0] + assert collected_data["model"]["total_agents"].to_list() == [12] + with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): + collected_data["model"]["max_wealth"] + + assert collected_data["agent"].shape == (4, 6) + assert list(collected_data["agent"].columns) == [ + "wealth", + "age_ExampleAgentSet1", + "age_ExampleAgentSet2", + "age_ExampleAgentSet3", + "step", + "seed", + ] + assert collected_data["agent"]["wealth"].to_list() == [1, 2, 3, 4] + assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ + 10, + 20, + 30, + 40, + ] + assert collected_data["agent"]["age_ExampleAgentSet2"].to_list() == [ + 11, + 22, + 33, + 44, + ] + assert collected_data["agent"]["age_ExampleAgentSet3"].to_list() == [1, 2, 3, 4] + assert collected_data["agent"]["step"].to_list() == [0, 0, 0, 0] + with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): + collected_data["agent"]["max_wealth"] + + def test_collect_step(self, fix1_model): + model = fix1_model + model.dc = DataCollector( + model=model, + model_reporters={ + "total_agents": lambda model: sum( + len(agentset) for agentset in model.agents._agentsets + ) + }, + agent_reporters={ + "wealth": lambda agents: agents._agentsets[0]["wealth"], + "age": "age", + }, + ) + model.run_model(5) + + model.dc.collect() + collected_data = model.dc.data + + assert collected_data["model"].shape == (1, 3) + assert collected_data["model"].columns == ["step", "seed", "total_agents"] + assert collected_data["model"]["step"].to_list() == [5] + assert collected_data["model"]["total_agents"].to_list() == [12] + + assert collected_data["agent"].shape == (4, 6) + assert list(collected_data["agent"].columns) == [ + "wealth", + "age_ExampleAgentSet1", + "age_ExampleAgentSet2", + "age_ExampleAgentSet3", + "step", + "seed", + ] + assert collected_data["agent"]["wealth"].to_list() == [6, 7, 8, 9] + assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ + 10, + 20, + 30, + 40, + ] + assert collected_data["agent"]["age_ExampleAgentSet2"].to_list() == [ + 11, + 22, + 33, + 44, + ] + assert collected_data["agent"]["age_ExampleAgentSet3"].to_list() == [6, 7, 8, 9] + assert collected_data["agent"]["step"].to_list() == [5, 5, 5, 5] + + def test_conditional_collect(self, fix1_model): + model = fix1_model + model.dc = DataCollector( + model=model, + trigger=custom_trigger, + model_reporters={ + "total_agents": lambda model: sum( + len(agentset) for agentset in model.agents._agentsets + ) + }, + agent_reporters={ + "wealth": lambda agents: agents._agentsets[0]["wealth"], + "age": "age", + }, + ) + + model.run_model_with_conditional_collect(5) + collected_data 
= model.dc.data + + assert collected_data["model"].shape == (2, 3) + assert collected_data["model"].columns == ["step", "seed", "total_agents"] + assert collected_data["model"]["step"].to_list() == [2, 4] + assert collected_data["model"]["total_agents"].to_list() == [12, 12] + + assert collected_data["agent"].shape == (8, 6) + assert list(collected_data["agent"].columns) == [ + "wealth", + "age_ExampleAgentSet1", + "age_ExampleAgentSet2", + "age_ExampleAgentSet3", + "step", + "seed", + ] + assert collected_data["agent"]["wealth"].to_list() == [3, 4, 5, 6, 5, 6, 7, 8] + assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ + 10, + 20, + 30, + 40, + 10, + 20, + 30, + 40, + ] + assert collected_data["agent"]["age_ExampleAgentSet2"].to_list() == [ + 11, + 22, + 33, + 44, + 11, + 22, + 33, + 44, + ] + assert collected_data["agent"]["age_ExampleAgentSet3"].to_list() == [ + 3, + 4, + 5, + 6, + 5, + 6, + 7, + 8, + ] + assert collected_data["agent"]["step"].to_list() == [2, 2, 2, 2, 4, 4, 4, 4] + + def test_flush_local_csv(self, fix1_model): + with tempfile.TemporaryDirectory() as tmpdir: + model = fix1_model + model.dc = DataCollector( + model=model, + trigger=custom_trigger, + model_reporters={ + "total_agents": lambda model: sum( + len(agentset) for agentset in model.agents._agentsets + ) + }, + agent_reporters={ + "wealth": lambda agents: agents._agentsets[0]["wealth"], + "age": "age", + }, + storage="csv", + storage_uri=tmpdir, + ) + + model.run_model_with_conditional_collect(4) + model.dc.flush() + + # check deletion after flush + collected_data = model.dc.data + assert collected_data["model"].shape == (0, 0) + assert collected_data["agent"].shape == (0, 0) + + created_files = os.listdir(tmpdir) + assert len(created_files) == 4, ( + f"Expected 4 files, found {len(created_files)}: {created_files}" + ) + + model_df = pl.read_csv( + os.path.join(tmpdir, "model_step2.csv"), + schema_overrides={"seed": pl.Utf8}, + ) + assert model_df.columns == ["step", "seed", "total_agents"] + assert model_df["step"].to_list() == [2] + assert model_df["total_agents"].to_list() == [12] + + agent_df = pl.read_csv( + os.path.join(tmpdir, "agent_step2.csv"), + schema_overrides={"seed": pl.Utf8}, + ) + assert agent_df.columns == [ + "wealth", + "age_ExampleAgentSet1", + "age_ExampleAgentSet2", + "age_ExampleAgentSet3", + "step", + "seed", + ] + assert agent_df["step"].to_list() == [2, 2, 2, 2] + assert agent_df["wealth"].to_list() == [3, 4, 5, 6] + assert agent_df["age_ExampleAgentSet1"].to_list() == [10, 20, 30, 40] + assert agent_df["age_ExampleAgentSet2"].to_list() == [11, 22, 33, 44] + assert agent_df["age_ExampleAgentSet3"].to_list() == [ + 3, + 4, + 5, + 6, + ] + + agent_df = pl.read_csv( + os.path.join(tmpdir, "agent_step4.csv"), + schema_overrides={"seed": pl.Utf8}, + ) + assert agent_df["step"].to_list() == [4, 4, 4, 4] + assert agent_df["wealth"].to_list() == [5, 6, 7, 8] + + def test_flush_local_parquet(self, fix1_model): + with tempfile.TemporaryDirectory() as tmpdir: + model = fix1_model + model.dc = DataCollector( + model=model, + trigger=custom_trigger, + model_reporters={ + "total_agents": lambda model: sum( + len(agentset) for agentset in model.agents._agentsets + ) + }, + agent_reporters={ + "wealth": lambda agents: agents._agentsets[0]["wealth"] + }, + storage="parquet", + storage_uri=tmpdir, + ) + + model.dc.collect() + model.dc.flush() + + created_files = os.listdir(tmpdir) + assert len(created_files) == 2, ( + f"Expected 2 files, found {len(created_files)}: {created_files}" + ) + 
+ model_df = pl.read_parquet(os.path.join(tmpdir, "model_step0.parquet")) + assert model_df["step"].to_list() == [0] + assert model_df["total_agents"].to_list() == [12] + + agent_df = pl.read_parquet(os.path.join(tmpdir, "agent_step0.parquet")) + assert agent_df["step"].to_list() == [0, 0, 0, 0] + assert agent_df["wealth"].to_list() == [1, 2, 3, 4] + + @pytest.mark.skipif( + os.getenv("SKIP_PG_TESTS") == "true", + reason="PostgreSQL tests are skipped on Windows runners", + ) + def test_postgress(self, fix1_model, postgres_uri): + model = fix1_model + + # Connect directly and validate data + import psycopg2 + + conn = psycopg2.connect(postgres_uri) + cur = conn.cursor() + + cur.execute(""" + CREATE TABLE public.model_data ( + step INTEGER, + seed VARCHAR, + total_agents INTEGER + ) + """) + + cur.execute(""" + CREATE TABLE public.agent_data ( + step INTEGER, + seed VARCHAR, + age_ExampleAgentSet1 INTEGER, + age_ExampleAgentSet2 INTEGER, + age_ExampleAgentSet3 INTEGER, + wealth INTEGER + ) + """) + conn.commit() + + model.dc = DataCollector( + model=model, + trigger=custom_trigger, + model_reporters={ + "total_agents": lambda model: sum( + len(agentset) for agentset in model.agents._agentsets + ) + }, + agent_reporters={ + "wealth": lambda agents: agents._agentsets[0]["wealth"], + "age": "age", + }, + storage="postgresql", + schema="public", + storage_uri=postgres_uri, + ) + + model.run_model_with_conditional_collect(4) + model.dc.flush() + + # Connect directly and validate data + + # Check model data + cur.execute("SELECT step, total_agents FROM model_data ORDER BY step") + model_rows = cur.fetchall() + assert model_rows == [(2, 12), (4, 12)] + + cur.execute( + "SELECT step, wealth,age_ExampleAgentSet1, age_ExampleAgentSet2, age_ExampleAgentSet3 FROM agent_data WHERE step=2 ORDER BY wealth" + ) + agent_rows = cur.fetchall() + assert agent_rows == [ + (2, 3, 10, 11, 3), + (2, 4, 20, 22, 4), + (2, 5, 30, 33, 5), + (2, 6, 40, 44, 6), + ] + + cur.close() + conn.close()
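Editor's note: the S3-csv and S3-parquet backends added in mesa_frames/concrete/datacollector.py are not exercised by this test suite, since they require AWS credentials. A minimal sketch of how a flush to S3 would look, assuming a hypothetical destination s3://my-bucket/sim-output and credentials that boto3 can resolve (environment variables, profile, or instance role); it reuses the fixtures and imports already defined in tests/test_datacollector.py:

    import pytest
    from mesa_frames.concrete.datacollector import DataCollector

    @pytest.mark.skip(reason="requires AWS credentials and a writable bucket")
    def test_s3_parquet_flush(fix1_model):
        model = fix1_model
        model.dc = DataCollector(
            model=model,
            model_reporters={
                "total_agents": lambda model: sum(
                    len(agentset) for agentset in model.agents._agentsets
                )
            },
            storage="S3-parquet",
            storage_uri="s3://my-bucket/sim-output",  # hypothetical bucket and prefix
        )
        model.dc.collect()
        # flush() routes to _write_parquet_s3, uploading one object per collected frame
        # named <kind>_step<step>.parquet under the given prefix.
        model.dc.flush()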