diff --git a/docs/general/user-guide/1_classes.md b/docs/general/user-guide/1_classes.md
index f85c062d..9ac446de 100644
--- a/docs/general/user-guide/1_classes.md
+++ b/docs/general/user-guide/1_classes.md
@@ -27,9 +27,9 @@ You can access the underlying DataFrame where agents are stored with `self.df`.
## Model 🏗️
-To add your AgentSet to your Model, you should also add it to the sets with `+=` or `add`.
+To add your AgentSet to your Model, use the registry `self.sets` with `+=` or `add`.
-NOTE: Model.sets are stored in a class which is entirely similar to AgentSet called AgentSetRegistry. The API of the two are the same. If you try accessing AgentSetRegistry.df, you will get a dictionary of `[AgentSet, DataFrame]`.
+Note: All agent sets live inside `AgentSetRegistry` (available as `model.sets`). Access sets through the registry, and access DataFrames from the set itself. For example: `self.sets["Preys"].df`.
Example:
@@ -43,7 +43,8 @@ class EcosystemModel(Model):
def step(self):
self.sets.do("move")
self.sets.do("hunt")
- self.prey.do("reproduce")
+ # Access specific sets via the registry
+ self.sets["Preys"].do("reproduce")
```
## Space: Grid 🌐
@@ -76,18 +77,23 @@ Example:
class ExampleModel(Model):
def __init__(self):
super().__init__()
- self.sets = MoneyAgent(self)
+ # Add the set to the registry
+ self.sets.add(MoneyAgents(100, self))
+ # Configure reporters: use the registry to locate sets; get df from the set
self.datacollector = DataCollector(
model=self,
- model_reporters={"total_wealth": lambda m: lambda m: list(m.sets.df.values())[0]["wealth"].sum()},
+ model_reporters={
+ "total_wealth": lambda m: m.sets["MoneyAgents"].df["wealth"].sum(),
+ },
agent_reporters={"wealth": "wealth"},
storage="csv",
storage_uri="./data",
- trigger=lambda m: m.schedule.steps % 2 == 0
+ trigger=lambda m: m.steps % 2 == 0,
)
def step(self):
- self.sets.step()
+ # Step all sets via the registry
+ self.sets.do("step")
self.datacollector.conditional_collect()
self.datacollector.flush()
```
diff --git a/docs/general/user-guide/2_introductory-tutorial.ipynb b/docs/general/user-guide/2_introductory-tutorial.ipynb
index ec1165da..11391f9d 100644
--- a/docs/general/user-guide/2_introductory-tutorial.ipynb
+++ b/docs/general/user-guide/2_introductory-tutorial.ipynb
@@ -74,7 +74,9 @@
" self.sets += agents_cls(N, self)\n",
" self.datacollector = DataCollector(\n",
" model=self,\n",
- " model_reporters={\"total_wealth\": lambda m: m.agents[\"wealth\"].sum()},\n",
+ " model_reporters={\n",
+ " \"total_wealth\": lambda m: m.sets[\"MoneyAgents\"].df[\"wealth\"].sum()\n",
+ " },\n",
" agent_reporters={\"wealth\": \"wealth\"},\n",
" storage=\"csv\",\n",
" storage_uri=\"./data\",\n",
diff --git a/docs/general/user-guide/4_datacollector.ipynb b/docs/general/user-guide/4_datacollector.ipynb
index 3fa16b49..0809caa2 100644
--- a/docs/general/user-guide/4_datacollector.ipynb
+++ b/docs/general/user-guide/4_datacollector.ipynb
@@ -26,7 +26,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 6,
"id": "9a63283cbaf04dbcab1f6479b197f3a8",
"metadata": {
"editable": true
@@ -48,7 +48,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 7,
"id": "72eea5119410473aa328ad9291626812",
"metadata": {
"editable": true
@@ -63,11 +63,11 @@
" │ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
" │ i64 ┆ str ┆ i64 ┆ f64 ┆ i64 │\n",
" ╞══════╪═════════════════════════════════╪═══════╪══════════════╪══════════╡\n",
- " │ 2 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n",
- " │ 4 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n",
- " │ 6 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n",
- " │ 8 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n",
- " │ 10 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n",
+ " │ 2 ┆ 332212815818606584686857770936… ┆ 0 ┆ 1000.0 ┆ 1000 │\n",
+ " │ 4 ┆ 332212815818606584686857770936… ┆ 0 ┆ 1000.0 ┆ 1000 │\n",
+ " │ 6 ┆ 332212815818606584686857770936… ┆ 0 ┆ 1000.0 ┆ 1000 │\n",
+ " │ 8 ┆ 332212815818606584686857770936… ┆ 0 ┆ 1000.0 ┆ 1000 │\n",
+ " │ 10 ┆ 332212815818606584686857770936… ┆ 0 ┆ 1000.0 ┆ 1000 │\n",
" └──────┴─────────────────────────────────┴───────┴──────────────┴──────────┘,\n",
" 'agent': shape: (5_000, 4)\n",
" ┌────────────────────┬──────┬─────────────────────────────────┬───────┐\n",
@@ -75,21 +75,21 @@
" │ --- ┆ --- ┆ --- ┆ --- │\n",
" │ f64 ┆ i32 ┆ str ┆ i32 │\n",
" ╞════════════════════╪══════╪═════════════════════════════════╪═══════╡\n",
- " │ 0.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n",
- " │ 3.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n",
- " │ 1.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n",
- " │ 3.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n",
- " │ 6.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n",
+ " │ 3.0 ┆ 2 ┆ 332212815818606584686857770936… ┆ 0 │\n",
+ " │ 0.0 ┆ 2 ┆ 332212815818606584686857770936… ┆ 0 │\n",
+ " │ 2.0 ┆ 2 ┆ 332212815818606584686857770936… ┆ 0 │\n",
+ " │ 1.0 ┆ 2 ┆ 332212815818606584686857770936… ┆ 0 │\n",
+ " │ 0.0 ┆ 2 ┆ 332212815818606584686857770936… ┆ 0 │\n",
" │ … ┆ … ┆ … ┆ … │\n",
- " │ 4.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n",
- " │ 1.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n",
- " │ 0.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n",
- " │ 0.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n",
- " │ 0.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n",
+ " │ 0.0 ┆ 10 ┆ 332212815818606584686857770936… ┆ 0 │\n",
+ " │ 0.0 ┆ 10 ┆ 332212815818606584686857770936… ┆ 0 │\n",
+ " │ 0.0 ┆ 10 ┆ 332212815818606584686857770936… ┆ 0 │\n",
+ " │ 0.0 ┆ 10 ┆ 332212815818606584686857770936… ┆ 0 │\n",
+ " │ 0.0 ┆ 10 ┆ 332212815818606584686857770936… ┆ 0 │\n",
" └────────────────────┴──────┴─────────────────────────────────┴───────┘}"
]
},
- "execution_count": 19,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -120,8 +120,8 @@
" self.dc = DataCollector(\n",
" model=self,\n",
" model_reporters={\n",
- " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n",
- " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n",
+ " \"total_wealth\": lambda m: m.sets[\"MoneyAgents\"].df[\"wealth\"].sum(),\n",
+ " \"n_agents\": lambda m: len(m.sets[\"MoneyAgents\"]),\n",
" },\n",
" agent_reporters={\n",
" \"wealth\": \"wealth\", # pull existing column\n",
@@ -175,7 +175,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": null,
"id": "5f14f38c",
"metadata": {},
"outputs": [
@@ -185,7 +185,7 @@
"[]"
]
},
- "execution_count": 20,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -198,8 +198,8 @@
"model_csv.dc = DataCollector(\n",
" model=model_csv,\n",
" model_reporters={\n",
- " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n",
- " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n",
+ " \"total_wealth\": lambda m: m.sets[\"MoneyAgents\"].df[\"wealth\"].sum(),\n",
+ " \"n_agents\": lambda m: len(m.sets[\"MoneyAgents\"]),\n",
" },\n",
" agent_reporters={\n",
" \"wealth\": \"wealth\",\n",
@@ -226,7 +226,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"id": "8763a12b2bbd4a93a75aff182afb95dc",
"metadata": {
"editable": true
@@ -238,7 +238,7 @@
"[]"
]
},
- "execution_count": 21,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -249,8 +249,8 @@
"model_parq.dc = DataCollector(\n",
" model=model_parq,\n",
" model_reporters={\n",
- " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n",
- " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n",
+ " \"total_wealth\": lambda m: m.sets[\"MoneyAgents\"].df[\"wealth\"].sum(),\n",
+ " \"n_agents\": lambda m: len(m.sets[\"MoneyAgents\"]),\n",
" },\n",
" agent_reporters={\n",
" \"wealth\": \"wealth\",\n",
@@ -279,7 +279,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": null,
"id": "7cdc8c89c7104fffa095e18ddfef8986",
"metadata": {
"editable": true
@@ -290,8 +290,8 @@
"model_s3.dc = DataCollector(\n",
" model=model_s3,\n",
" model_reporters={\n",
- " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n",
- " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n",
+ " \"total_wealth\": lambda m: m.sets[\"MoneyAgents\"].df[\"wealth\"].sum(),\n",
+ " \"n_agents\": lambda m: len(m.sets[\"MoneyAgents\"]),\n",
" },\n",
" agent_reporters={\n",
" \"wealth\": \"wealth\",\n",
@@ -319,7 +319,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 11,
"id": "938c804e27f84196a10c8828c723f798",
"metadata": {
"editable": true
@@ -381,7 +381,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 12,
"id": "59bbdb311c014d738909a11f9e486628",
"metadata": {
"editable": true
@@ -410,7 +410,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 13,
"id": "8a65eabff63a45729fe45fb5ade58bdc",
"metadata": {
"editable": true
@@ -426,7 +426,7 @@
" white-space: pre-wrap;\n",
"}\n",
"\n",
- "shape: (5, 5)
step | seed | batch | total_wealth | n_agents |
---|
i64 | str | i64 | f64 | i64 |
2 | "732054881101029867447298951813… | 0 | 100.0 | 100 |
4 | "732054881101029867447298951813… | 0 | 100.0 | 100 |
6 | "732054881101029867447298951813… | 0 | 100.0 | 100 |
8 | "732054881101029867447298951813… | 0 | 100.0 | 100 |
10 | "732054881101029867447298951813… | 0 | 100.0 | 100 |
"
+ "shape: (5, 5)step | seed | batch | total_wealth | n_agents |
---|
i64 | str | i64 | f64 | i64 |
2 | "540832786058427425452319829502… | 0 | 100.0 | 100 |
4 | "540832786058427425452319829502… | 0 | 100.0 | 100 |
6 | "540832786058427425452319829502… | 0 | 100.0 | 100 |
8 | "540832786058427425452319829502… | 0 | 100.0 | 100 |
10 | "540832786058427425452319829502… | 0 | 100.0 | 100 |
"
],
"text/plain": [
"shape: (5, 5)\n",
@@ -435,15 +435,15 @@
"│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ i64 ┆ str ┆ i64 ┆ f64 ┆ i64 │\n",
"╞══════╪═════════════════════════════════╪═══════╪══════════════╪══════════╡\n",
- "│ 2 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n",
- "│ 4 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n",
- "│ 6 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n",
- "│ 8 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n",
- "│ 10 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n",
+ "│ 2 ┆ 540832786058427425452319829502… ┆ 0 ┆ 100.0 ┆ 100 │\n",
+ "│ 4 ┆ 540832786058427425452319829502… ┆ 0 ┆ 100.0 ┆ 100 │\n",
+ "│ 6 ┆ 540832786058427425452319829502… ┆ 0 ┆ 100.0 ┆ 100 │\n",
+ "│ 8 ┆ 540832786058427425452319829502… ┆ 0 ┆ 100.0 ┆ 100 │\n",
+ "│ 10 ┆ 540832786058427425452319829502… ┆ 0 ┆ 100.0 ┆ 100 │\n",
"└──────┴─────────────────────────────────┴───────┴──────────────┴──────────┘"
]
},
- "execution_count": 25,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
diff --git a/examples/sugarscape_ig/ss_polars/agents.py b/examples/sugarscape_ig/ss_polars/agents.py
index b0ecbe90..32ca91f5 100644
--- a/examples/sugarscape_ig/ss_polars/agents.py
+++ b/examples/sugarscape_ig/ss_polars/agents.py
@@ -35,12 +35,15 @@ def __init__(
self.add(agents)
def eat(self):
+ # Only consider cells currently occupied by agents of this set
cells = self.space.cells.filter(pl.col("agent_id").is_not_null())
- self[cells["agent_id"], "sugar"] = (
- self[cells["agent_id"], "sugar"]
- + cells["sugar"]
- - self[cells["agent_id"], "metabolism"]
- )
+ mask_in_set = cells["agent_id"].is_in(self.index)
+ if mask_in_set.any():
+ cells = cells.filter(mask_in_set)
+ ids = cells["agent_id"]
+ self[ids, "sugar"] = (
+ self[ids, "sugar"] + cells["sugar"] - self[ids, "metabolism"]
+ )
def step(self):
self.shuffle().do("move").do("eat")
diff --git a/examples/sugarscape_ig/ss_polars/model.py b/examples/sugarscape_ig/ss_polars/model.py
index 56a3a83b..36b2718e 100644
--- a/examples/sugarscape_ig/ss_polars/model.py
+++ b/examples/sugarscape_ig/ss_polars/model.py
@@ -33,7 +33,10 @@ def __init__(
sugar=sugar_grid.flatten(), max_sugar=sugar_grid.flatten()
)
self.space.set_cells(sugar_grid)
- self.sets += agent_type(self, n_agents, initial_sugar, metabolism, vision)
+ # Create and register the main agent set; keep its name for later lookups
+ main_set = agent_type(self, n_agents, initial_sugar, metabolism, vision)
+ self.sets += main_set
+ self._main_set_name = main_set.name
if initial_positions is not None:
self.space.place_agents(self.sets, initial_positions)
else:
@@ -41,7 +44,8 @@ def __init__(
def run_model(self, steps: int) -> list[int]:
for _ in range(steps):
- if len(list(self.sets.df.values())[0]) == 0:
+ # Stop if the main agent set is empty
+ if len(self.sets[self._main_set_name]) == 0: # type: ignore[index]
return
empty_cells = self.space.empty_cells
full_cells = self.space.full_cells
diff --git a/mesa_frames/abstract/agentset.py b/mesa_frames/abstract/agentset.py
index a7da9097..ae5db2db 100644
--- a/mesa_frames/abstract/agentset.py
+++ b/mesa_frames/abstract/agentset.py
@@ -20,10 +20,12 @@
from abc import abstractmethod
from collections.abc import Collection, Iterable, Iterator
+from contextlib import suppress
from typing import Any, Literal, Self, overload
-from mesa_frames.abstract.agentsetregistry import AbstractAgentSetRegistry
-from mesa_frames.abstract.mixin import DataFrameMixin
+from numpy.random import Generator
+
+from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin
from mesa_frames.types_ import (
AgentMask,
BoolSeries,
@@ -35,7 +37,7 @@
)
-class AbstractAgentSet(AbstractAgentSetRegistry, DataFrameMixin):
+class AbstractAgentSet(CopyMixin, DataFrameMixin):
"""The AbstractAgentSet class is a container for agents of the same type.
Parameters
@@ -44,6 +46,7 @@ class AbstractAgentSet(AbstractAgentSetRegistry, DataFrameMixin):
The model that the agent set belongs to.
"""
+ _copy_only_reference: list[str] = ["_model"]
_df: DataFrame # The agents in the AbstractAgentSet
_mask: AgentMask # The underlying mask used for the active agents in the AbstractAgentSet.
_model: (
@@ -75,7 +78,31 @@ def add(
Returns
-------
Self
- A new AbstractAgentSetRegistry with the added agents.
+ A new AbstractAgentSet with the added agents.
+ """
+ ...
+
+ @overload
+ @abstractmethod
+ def contains(self, agents: int) -> bool: ...
+
+ @overload
+ @abstractmethod
+ def contains(self, agents: IdsLike) -> BoolSeries: ...
+
+ @abstractmethod
+ def contains(self, agents: IdsLike) -> bool | BoolSeries:
+ """Check if agents with the specified IDs are in the AgentSet.
+
+ Parameters
+ ----------
+ agents : IdsLike
+ The ID(s) to check for.
+
+ Returns
+ -------
+ bool | BoolSeries
+ True if the agent is in the AgentSet, False otherwise.
"""
...
@@ -94,65 +121,85 @@ def discard(self, agents: IdsLike | AgentMask, inplace: bool = True) -> Self:
Self
The updated AbstractAgentSet.
"""
- return super().discard(agents, inplace)
+ with suppress(KeyError, ValueError):
+ return self.remove(agents, inplace=inplace)
+ return self._get_obj(inplace)
+
+ @abstractmethod
+ def remove(self, agents: IdsLike | AgentMask, inplace: bool = True) -> Self:
+ """Remove agents from this AbstractAgentSet.
+
+ Parameters
+ ----------
+ agents : IdsLike | AgentMask
+ The agents or mask to remove.
+ inplace : bool, optional
+ Whether to remove in place, by default True.
+
+ Returns
+ -------
+ Self
+ The updated agent set.
+ """
+ ...
@overload
+ @abstractmethod
def do(
self,
method_name: str,
- *args,
+ *args: Any,
mask: AgentMask | None = None,
return_results: Literal[False] = False,
inplace: bool = True,
- **kwargs,
+ **kwargs: Any,
) -> Self: ...
@overload
+ @abstractmethod
def do(
self,
method_name: str,
- *args,
+ *args: Any,
mask: AgentMask | None = None,
return_results: Literal[True],
inplace: bool = True,
- **kwargs,
+ **kwargs: Any,
) -> Any: ...
+ @abstractmethod
def do(
self,
method_name: str,
- *args,
+ *args: Any,
mask: AgentMask | None = None,
return_results: bool = False,
inplace: bool = True,
- **kwargs,
+ **kwargs: Any,
) -> Self | Any:
- masked_df = self._get_masked_df(mask)
- # If the mask is empty, we can use the object as is
- if len(masked_df) == len(self._df):
- obj = self._get_obj(inplace)
- method = getattr(obj, method_name)
- result = method(*args, **kwargs)
- else: # If the mask is not empty, we need to create a new masked AbstractAgentSet and concatenate the AbstractAgentSets at the end
- obj = self._get_obj(inplace=False)
- obj._df = masked_df
- original_masked_index = obj._get_obj_copy(obj.index)
- method = getattr(obj, method_name)
- result = method(*args, **kwargs)
- obj._concatenate_agentsets(
- [self],
- duplicates_allowed=True,
- keep_first_only=True,
- original_masked_index=original_masked_index,
- )
- if inplace:
- for key, value in obj.__dict__.items():
- setattr(self, key, value)
- obj = self
- if return_results:
- return result
- else:
- return obj
+ """Invoke a method on the AgentSet.
+
+ Parameters
+ ----------
+ method_name : str
+ The name of the method to invoke.
+ *args : Any
+ Positional arguments to pass to the method
+ mask : AgentMask | None, optional
+ The subset of agents on which to apply the method
+ return_results : bool, optional
+ Whether to return the result of the method, by default False
+ inplace : bool, optional
+ Whether the operation should be done inplace, by default False
+ **kwargs : Any
+ Keyword arguments to pass to the method
+
+ Returns
+ -------
+ Self | Any
+ The updated AgentSet or the result of the method.
+ """
+ ...
@abstractmethod
@overload
@@ -182,20 +229,6 @@ def step(self) -> None:
"""Run a single step of the AbstractAgentSet. This method should be overridden by subclasses."""
...
- def remove(self, agents: IdsLike | AgentMask, inplace: bool = True) -> Self:
- if isinstance(agents, str) and agents == "active":
- agents = self.active_agents
- if agents is None or (isinstance(agents, Iterable) and len(agents) == 0):
- return self._get_obj(inplace)
- agents = self._df_index(self._get_masked_df(agents), "unique_id")
- sets = self.model.sets.remove(agents, inplace=inplace)
- # TODO: Refactor AgentSetRegistry to return dict[str, AbstractAgentSet] instead of dict[AbstractAgentSet, DataFrame]
- # And assign a name to AbstractAgentSet? This has to be replaced by a nicer API of AgentSetRegistry
- for agentset in sets.df.keys():
- if isinstance(agentset, self.__class__):
- return agentset
- return self
-
@abstractmethod
def _concatenate_agentsets(
self,
@@ -285,9 +318,9 @@ def __add__(self, other: DataFrame | DataFrameInput) -> Self:
Returns
-------
Self
- A new AbstractAgentSetRegistry with the added agents.
+ A new AbstractAgentSet with the added agents.
"""
- return super().__add__(other)
+ return self.add(other, inplace=False)
def __iadd__(self, other: DataFrame | DataFrameInput) -> Self:
"""
@@ -305,9 +338,17 @@ def __iadd__(self, other: DataFrame | DataFrameInput) -> Self:
Returns
-------
Self
- The updated AbstractAgentSetRegistry.
+ The updated AbstractAgentSet.
"""
- return super().__iadd__(other)
+ return self.add(other, inplace=True)
+
+ def __isub__(self, other: IdsLike | AgentMask | DataFrame) -> Self:
+ """Remove agents via -= operator."""
+ return self.discard(other, inplace=True)
+
+ def __sub__(self, other: IdsLike | AgentMask | DataFrame) -> Self:
+ """Return a new set with agents removed via - operator."""
+ return self.discard(other, inplace=False)
@abstractmethod
def __getattr__(self, name: str) -> Any:
@@ -336,9 +377,20 @@ def __getitem__(
| tuple[AgentMask, Collection[str]]
),
) -> Series | DataFrame:
- attr = super().__getitem__(key)
- assert isinstance(attr, (Series, DataFrame, Index))
- return attr
+ # Mirror registry/old container behavior: delegate to get()
+ if isinstance(key, tuple):
+ return self.get(mask=key[0], attr_names=key[1])
+ else:
+ if isinstance(key, str) or (
+ isinstance(key, Collection) and all(isinstance(k, str) for k in key)
+ ):
+ return self.get(attr_names=key)
+ else:
+ return self.get(mask=key)
+
+ def __contains__(self, agents: int) -> bool:
+ """Membership test for an agent id in this set."""
+ return bool(self.contains(agents))
def __len__(self) -> int:
return len(self._df)
@@ -376,6 +428,7 @@ def active_agents(self) -> DataFrame: ...
def inactive_agents(self) -> DataFrame: ...
@property
+ @abstractmethod
def index(self) -> Index: ...
@property
@@ -391,3 +444,101 @@ def pos(self) -> DataFrame:
pos, self.index, new_index_cols="unique_id", original_index_cols="agent_id"
)
return pos
+
+ @property
+ def name(self) -> str:
+ """The name of the agent set.
+
+ Returns
+ -------
+ str
+ The name of the agent set
+ """
+ return self._name
+
+ @property
+ def model(self) -> mesa_frames.concrete.model.Model:
+ return self._model
+
+ @property
+ def random(self) -> Generator:
+ return self.model.random
+
+ @property
+ def space(self) -> mesa_frames.abstract.space.Space | None:
+ return self.model.space
+
+ @abstractmethod
+ def rename(self, new_name: str, inplace: bool = True) -> Self:
+ """Rename this AgentSet.
+
+ Concrete subclasses must implement the mechanics for coordinating with
+ any containing registry and managing ``inplace`` semantics. The method
+ should update the set's name (or return a renamed copy when
+ ``inplace=False``) while preserving registry invariants.
+
+ Parameters
+ ----------
+ new_name : str
+ Desired new name for this AgentSet.
+ inplace : bool, optional
+ Whether to perform the rename in place. If False, a renamed copy is
+ returned, by default True.
+
+ Returns
+ -------
+ Self
+ The updated AgentSet (or a renamed copy when ``inplace=False``).
+ """
+ ...
+
+ @abstractmethod
+ def set(
+ self,
+ attr_names: str | Collection[str] | dict[str, Any] | None = None,
+ values: Any | None = None,
+ mask: AgentMask | None = None,
+ inplace: bool = True,
+ ) -> Self:
+ """Update agent attributes, optionally on a masked subset.
+
+ Parameters
+ ----------
+ attr_names : str | Collection[str] | dict[str, Any] | None, optional
+ Attribute(s) to assign. When ``None``, concrete implementations may
+ derive targets from ``values``.
+ values : Any | None, optional
+ Replacement value(s) aligned with ``attr_names``.
+ mask : AgentMask | None, optional
+ Subset selector limiting which agents are updated.
+ inplace : bool, optional
+ Whether to mutate in place or return an updated copy, by default True.
+
+ Returns
+ -------
+ Self
+ The updated AgentSet (or a modified copy when ``inplace=False``).
+ """
+ ...
+
+ def __setitem__(
+ self,
+ key: str
+ | Collection[str]
+ | AgentMask
+ | tuple[AgentMask, str | Collection[str]],
+ values: Any,
+ ) -> None:
+ """Set values using [] syntax, delegating to set()."""
+ if isinstance(key, tuple):
+ self.set(mask=key[0], attr_names=key[1], values=values)
+ else:
+ if isinstance(key, str) or (
+ isinstance(key, Collection) and all(isinstance(k, str) for k in key)
+ ):
+ try:
+ self.set(attr_names=key, values=values)
+ except KeyError: # key may actually be a mask
+ self.set(attr_names=None, mask=key, values=values)
+ else:
+ self.set(attr_names=None, mask=key, values=values)
diff --git a/mesa_frames/abstract/agentsetregistry.py b/mesa_frames/abstract/agentsetregistry.py
index c8fa6c60..abb0ef69 100644
--- a/mesa_frames/abstract/agentsetregistry.py
+++ b/mesa_frames/abstract/agentsetregistry.py
@@ -43,7 +43,7 @@ def __init__(self, model):
from __future__ import annotations # PEP 563: postponed evaluation of type annotations
from abc import abstractmethod
-from collections.abc import Callable, Collection, Iterator, Sequence
+from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from contextlib import suppress
from typing import Any, Literal, Self, overload
@@ -51,12 +51,12 @@ def __init__(self, model):
from mesa_frames.abstract.mixin import CopyMixin
from mesa_frames.types_ import (
- AgentMask,
+ AbstractAgentSetSelector as AgentSetSelector,
+)
+from mesa_frames.types_ import (
BoolSeries,
- DataFrame,
- DataFrameInput,
- IdsLike,
Index,
+ KeyBy,
Series,
)
@@ -74,20 +74,17 @@ def __init__(self) -> None: ...
def discard(
self,
- agents: IdsLike
- | AgentMask
- | mesa_frames.abstract.agentset.AbstractAgentSet
- | Collection[mesa_frames.abstract.agentset.AbstractAgentSet],
+ sets: AgentSetSelector,
inplace: bool = True,
) -> Self:
- """Remove agents from the AbstractAgentSetRegistry. Does not raise an error if the agent is not found.
+ """Remove AgentSets selected by ``sets``. Ignores missing.
Parameters
----------
- agents : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
- The agents to remove
+ sets : AgentSetSelector
+ Which AgentSets to remove (instance, type, name, or collection thereof).
inplace : bool
- Whether to remove the agent in place. Defaults to True.
+ Whether to remove in place. Defaults to True.
Returns
-------
@@ -95,26 +92,70 @@ def discard(
The updated AbstractAgentSetRegistry.
"""
with suppress(KeyError, ValueError):
- return self.remove(agents, inplace=inplace)
+ return self.remove(sets, inplace=inplace)
return self._get_obj(inplace)
+ @abstractmethod
+ def rename(
+ self,
+ target: (
+ mesa_frames.abstract.agentset.AbstractAgentSet
+ | str
+ | dict[mesa_frames.abstract.agentset.AbstractAgentSet | str, str]
+ | list[tuple[mesa_frames.abstract.agentset.AbstractAgentSet | str, str]]
+ ),
+ new_name: str | None = None,
+ *,
+ on_conflict: Literal["canonicalize", "raise"] = "canonicalize",
+ mode: Literal["atomic", "best_effort"] = "atomic",
+ inplace: bool = True,
+ ) -> Self:
+ """Rename AgentSets in this registry, handling conflicts.
+
+ Parameters
+ ----------
+ target : mesa_frames.abstract.agentset.AbstractAgentSet | str | dict[mesa_frames.abstract.agentset.AbstractAgentSet | str, str] | list[tuple[mesa_frames.abstract.agentset.AbstractAgentSet | str, str]]
+ Single target (instance or existing name) with ``new_name`` provided,
+ or a mapping/sequence of (target, new_name) pairs for batch rename.
+ new_name : str | None
+ New name for single-target rename.
+ on_conflict : Literal["canonicalize", "raise"]
+ When a desired name collides, either canonicalize by appending a
+ numeric suffix (default) or raise ``ValueError``.
+ mode : Literal["atomic", "best_effort"]
+ In "atomic" mode, validate all renames before applying any. In
+ "best_effort" mode, apply what can be applied and skip failures.
+
+ Returns
+ -------
+ Self
+ Updated registry (or a renamed copy when ``inplace=False``).
+
+ Parameters
+ ----------
+ inplace : bool, optional
+ Whether to perform the rename in place. If False, a renamed copy is
+ returned, by default True.
+ """
+ ...
+
@abstractmethod
def add(
self,
- agents: DataFrame
- | DataFrameInput
- | mesa_frames.abstract.agentset.AbstractAgentSet
- | Collection[mesa_frames.abstract.agentset.AbstractAgentSet],
+ sets: (
+ mesa_frames.abstract.agentset.AbstractAgentSet
+ | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
+ ),
inplace: bool = True,
) -> Self:
- """Add agents to the AbstractAgentSetRegistry.
+ """Add AgentSets to the AbstractAgentSetRegistry.
Parameters
----------
- agents : DataFrame | DataFrameInput | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
- The agents to add.
+ sets : mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
+ The AgentSet(s) to add.
inplace : bool
- Whether to add the agents in place. Defaults to True.
+ Whether to add in place. Defaults to True.
Returns
-------
@@ -125,29 +166,40 @@ def add(
@overload
@abstractmethod
- def contains(self, agents: int) -> bool: ...
+ def contains(
+ self,
+ sets: (
+ mesa_frames.abstract.agentset.AbstractAgentSet
+ | type[mesa_frames.abstract.agentset.AbstractAgentSet]
+ | str
+ ),
+ ) -> bool: ...
@overload
@abstractmethod
def contains(
- self, agents: mesa_frames.abstract.agentset.AbstractAgentSet | IdsLike
+ self,
+ sets: Collection[
+ mesa_frames.abstract.agentset.AbstractAgentSet
+ | type[mesa_frames.abstract.agentset.AbstractAgentSet]
+ | str
+ ],
) -> BoolSeries: ...
@abstractmethod
- def contains(
- self, agents: mesa_frames.abstract.agentset.AbstractAgentSet | IdsLike
- ) -> bool | BoolSeries:
- """Check if agents with the specified IDs are in the AbstractAgentSetRegistry.
+ def contains(self, sets: AgentSetSelector) -> bool | BoolSeries:
+ """Check if selected AgentSets are present in the registry.
Parameters
----------
- agents : mesa_frames.abstract.agentset.AbstractAgentSet | IdsLike
- The ID(s) to check for.
+ sets : AgentSetSelector
+ An AgentSet instance, class/type, name string, or a collection of
+ those. For collections, returns a BoolSeries aligned with input order.
Returns
-------
bool | BoolSeries
- True if the agent is in the AbstractAgentSetRegistry, False otherwise.
+ Boolean for single selector values; BoolSeries for collections.
"""
@overload
@@ -156,9 +208,10 @@ def do(
self,
method_name: str,
*args: Any,
- mask: AgentMask | None = None,
+ sets: AgentSetSelector | None = None,
return_results: Literal[False] = False,
inplace: bool = True,
+ key_by: KeyBy = "name",
**kwargs: Any,
) -> Self: ...
@@ -168,22 +221,35 @@ def do(
self,
method_name: str,
*args: Any,
- mask: AgentMask | None = None,
+ sets: AgentSetSelector,
return_results: Literal[True],
inplace: bool = True,
+ key_by: KeyBy = "name",
**kwargs: Any,
- ) -> Any | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Any]: ...
+ ) -> (
+ Any
+ | dict[str, Any]
+ | dict[int, Any]
+ | dict[type[mesa_frames.abstract.agentset.AbstractAgentSet], Any]
+ ): ...
@abstractmethod
def do(
self,
method_name: str,
*args: Any,
- mask: AgentMask | None = None,
+ sets: AgentSetSelector = None,
return_results: bool = False,
inplace: bool = True,
+ key_by: KeyBy = "name",
**kwargs: Any,
- ) -> Self | Any | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Any]:
+ ) -> (
+ Self
+ | Any
+ | dict[str, Any]
+ | dict[int, Any]
+ | dict[type[mesa_frames.abstract.agentset.AbstractAgentSet], Any]
+ ):
"""Invoke a method on the AbstractAgentSetRegistry.
Parameters
@@ -192,71 +258,88 @@ def do(
The name of the method to invoke.
*args : Any
Positional arguments to pass to the method
- mask : AgentMask | None, optional
- The subset of agents on which to apply the method
+ sets : AgentSetSelector, optional
+ Which AgentSets to target (instance, type, name, or collection thereof). Defaults to all.
return_results : bool, optional
- Whether to return the result of the method, by default False
+ Whether to return per-set results as a dictionary, by default False.
inplace : bool, optional
- Whether the operation should be done inplace, by default False
+ Whether the operation should be done inplace, by default True
+ key_by : KeyBy, optional
+ Key domain for the returned mapping when ``return_results`` is True.
+ - "name" (default) → keys are set names (str)
+ - "index" → keys are positional indices (int)
+ - "type" → keys are concrete set classes (type)
**kwargs : Any
Keyword arguments to pass to the method
Returns
-------
- Self | Any | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Any]
- The updated AbstractAgentSetRegistry or the result of the method.
+ Self | Any | dict[str, Any] | dict[int, Any] | dict[type[mesa_frames.abstract.agentset.AbstractAgentSet], Any]
+ The updated registry, or the method result(s). When ``return_results``
+ is True, returns a dictionary keyed per ``key_by``.
"""
...
- @abstractmethod
@overload
- def get(self, attr_names: str) -> Series | dict[str, Series]: ...
-
@abstractmethod
+ def get(
+ self, key: int, default: None = ...
+ ) -> mesa_frames.abstract.agentset.AbstractAgentSet | None: ...
+
@overload
+ @abstractmethod
def get(
- self, attr_names: Collection[str] | None = None
- ) -> DataFrame | dict[str, DataFrame]: ...
+ self, key: str, default: None = ...
+ ) -> mesa_frames.abstract.agentset.AbstractAgentSet | None: ...
+ @overload
@abstractmethod
def get(
self,
- attr_names: str | Collection[str] | None = None,
- mask: AgentMask | None = None,
- ) -> Series | dict[str, Series] | DataFrame | dict[str, DataFrame]:
- """Retrieve the value of a specified attribute for each agent in the AbstractAgentSetRegistry.
+ key: type[mesa_frames.abstract.agentset.AbstractAgentSet],
+ default: None = ...,
+ ) -> list[mesa_frames.abstract.agentset.AbstractAgentSet]: ...
- Parameters
- ----------
- attr_names : str | Collection[str] | None, optional
- The attributes to retrieve. If None, all attributes are retrieved. Defaults to None.
- mask : AgentMask | None, optional
- The AgentMask of agents to retrieve the attribute for. If None, attributes of all agents are returned. Defaults to None.
+ @overload
+ @abstractmethod
+ def get(
+ self,
+ key: int | str | type[mesa_frames.abstract.agentset.AbstractAgentSet],
+ default: mesa_frames.abstract.agentset.AbstractAgentSet
+ | list[mesa_frames.abstract.agentset.AbstractAgentSet]
+ | None,
+ ) -> (
+ mesa_frames.abstract.agentset.AbstractAgentSet
+ | list[mesa_frames.abstract.agentset.AbstractAgentSet]
+ | None
+ ): ...
- Returns
- -------
- Series | dict[str, Series] | DataFrame | dict[str, DataFrame]
- The attribute values.
- """
- ...
+ @abstractmethod
+ def get(
+ self,
+ key: int | str | type[mesa_frames.abstract.agentset.AbstractAgentSet],
+ default: mesa_frames.abstract.agentset.AbstractAgentSet
+ | list[mesa_frames.abstract.agentset.AbstractAgentSet]
+ | None = None,
+ ) -> (
+ mesa_frames.abstract.agentset.AbstractAgentSet
+ | list[mesa_frames.abstract.agentset.AbstractAgentSet]
+ | None
+ ):
+ """Safe lookup for AgentSet(s) by index, name, or type."""
@abstractmethod
def remove(
self,
- agents: (
- IdsLike
- | AgentMask
- | mesa_frames.abstract.agentset.AbstractAgentSet
- | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
- ),
+ sets: AgentSetSelector,
inplace: bool = True,
) -> Self:
- """Remove the agents from the AbstractAgentSetRegistry.
+ """Remove AgentSets from the AbstractAgentSetRegistry.
Parameters
----------
- agents : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
- The agents to remove.
+ sets : AgentSetSelector
+ Which AgentSets to remove (instance, type, name, or collection thereof).
inplace : bool, optional
Whether to remove the agent in place.
@@ -267,96 +350,46 @@ def remove(
"""
...
- @abstractmethod
- def select(
- self,
- mask: AgentMask | None = None,
- filter_func: Callable[[Self], AgentMask] | None = None,
- n: int | None = None,
- negate: bool = False,
- inplace: bool = True,
- ) -> Self:
- """Select agents in the AbstractAgentSetRegistry based on the given criteria.
-
- Parameters
- ----------
- mask : AgentMask | None, optional
- The AgentMask of agents to be selected, by default None
- filter_func : Callable[[Self], AgentMask] | None, optional
- A function which takes as input the AbstractAgentSetRegistry and returns a AgentMask, by default None
- n : int | None, optional
- The maximum number of agents to be selected, by default None
- negate : bool, optional
- If the selection should be negated, by default False
- inplace : bool, optional
- If the operation should be performed on the same object, by default True
-
- Returns
- -------
- Self
- A new or updated AbstractAgentSetRegistry.
- """
- ...
-
- @abstractmethod
- @overload
- def set(
- self,
- attr_names: dict[str, Any],
- values: None,
- mask: AgentMask | None = None,
- inplace: bool = True,
- ) -> Self: ...
+ # select() intentionally removed from the abstract API.
@abstractmethod
- @overload
- def set(
+ def replace(
self,
- attr_names: str | Collection[str],
- values: Any,
- mask: AgentMask | None = None,
- inplace: bool = True,
- ) -> Self: ...
-
- @abstractmethod
- def set(
- self,
- attr_names: DataFrameInput | str | Collection[str],
- values: Any | None = None,
- mask: AgentMask | None = None,
+ mapping: (
+ dict[int | str, mesa_frames.abstract.agentset.AbstractAgentSet]
+ | list[tuple[int | str, mesa_frames.abstract.agentset.AbstractAgentSet]]
+ ),
+ *,
inplace: bool = True,
+ atomic: bool = True,
) -> Self:
- """Set the value of a specified attribute or attributes for each agent in the mask in AbstractAgentSetRegistry.
+ """Batch assign/replace AgentSets by index or name.
Parameters
----------
- attr_names : DataFrameInput | str | Collection[str]
- The key can be:
- - A string: sets the specified column of the agents in the AbstractAgentSetRegistry.
- - A collection of strings: sets the specified columns of the agents in the AbstractAgentSetRegistry.
- - A dictionary: keys should be attributes and values should be the values to set. Value should be None.
- values : Any | None
- The value to set the attribute to. If None, attr_names must be a dictionary.
- mask : AgentMask | None
- The AgentMask of agents to set the attribute for.
- inplace : bool
- Whether to set the attribute in place.
+ mapping : dict[int | str, mesa_frames.abstract.agentset.AbstractAgentSet] | list[tuple[int | str, mesa_frames.abstract.agentset.AbstractAgentSet]]
+ Keys are indices or names to assign; values are AgentSets bound to the same model.
+ inplace : bool, optional
+ Whether to apply on this registry or return a copy, by default True.
+ atomic : bool, optional
+ When True, validates all keys and name invariants before applying any
+ change; either all assignments succeed or none are applied.
Returns
-------
Self
- The updated agent set.
+ Updated registry.
"""
...
@abstractmethod
def shuffle(self, inplace: bool = False) -> Self:
- """Shuffles the order of agents in the AbstractAgentSetRegistry.
+ """Shuffle the order of AgentSets in the registry.
Parameters
----------
inplace : bool
- Whether to shuffle the agents in place.
+ Whether to shuffle in place.
Returns
-------
@@ -373,7 +406,7 @@ def sort(
**kwargs,
) -> Self:
"""
- Sorts the agents in the agent set based on the given criteria.
+ Sort the AgentSets in the registry based on the given criteria.
Parameters
----------
@@ -394,145 +427,75 @@ def sort(
def __add__(
self,
- other: DataFrame
- | DataFrameInput
- | mesa_frames.abstract.agentset.AbstractAgentSet
+ other: mesa_frames.abstract.agentset.AbstractAgentSet
| Collection[mesa_frames.abstract.agentset.AbstractAgentSet],
) -> Self:
- """Add agents to a new AbstractAgentSetRegistry through the + operator.
-
- Parameters
- ----------
- other : DataFrame | DataFrameInput | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
- The agents to add.
-
- Returns
- -------
- Self
- A new AbstractAgentSetRegistry with the added agents.
- """
- return self.add(agents=other, inplace=False)
+ """Add AgentSets to a new AbstractAgentSetRegistry through the + operator."""
+ return self.add(sets=other, inplace=False)
def __contains__(
- self, agents: int | mesa_frames.abstract.agentset.AbstractAgentSet
+ self, sets: mesa_frames.abstract.agentset.AbstractAgentSet
) -> bool:
- """Check if an agent is in the AbstractAgentSetRegistry.
-
- Parameters
- ----------
- agents : int | mesa_frames.abstract.agentset.AbstractAgentSet
- The ID(s) or AbstractAgentSet to check for.
+ """Check if an AgentSet is in the AbstractAgentSetRegistry."""
+ return bool(self.contains(sets=sets))
- Returns
- -------
- bool
- True if the agent is in the AbstractAgentSetRegistry, False otherwise.
- """
- return self.contains(agents=agents)
+ @overload
+ def __getitem__(
+ self, key: int
+ ) -> mesa_frames.abstract.agentset.AbstractAgentSet: ...
@overload
def __getitem__(
- self, key: str | tuple[AgentMask, str]
- ) -> Series | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Series]: ...
+ self, key: str
+ ) -> mesa_frames.abstract.agentset.AbstractAgentSet: ...
@overload
def __getitem__(
- self,
- key: AgentMask | Collection[str] | tuple[AgentMask, Collection[str]],
- ) -> (
- DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame]
- ): ...
+ self, key: type[mesa_frames.abstract.agentset.AbstractAgentSet]
+ ) -> list[mesa_frames.abstract.agentset.AbstractAgentSet]: ...
def __getitem__(
- self,
- key: (
- str
- | Collection[str]
- | AgentMask
- | tuple[AgentMask, str]
- | tuple[AgentMask, Collection[str]]
- | tuple[
- dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str
- ]
- | tuple[
- dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask],
- Collection[str],
- ]
- ),
+ self, key: int | str | type[mesa_frames.abstract.agentset.AbstractAgentSet]
) -> (
- Series
- | DataFrame
- | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Series]
- | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame]
+ mesa_frames.abstract.agentset.AbstractAgentSet
+ | list[mesa_frames.abstract.agentset.AbstractAgentSet]
):
- """Implement the [] operator for the AbstractAgentSetRegistry.
-
- The key can be:
- - An attribute or collection of attributes (eg. AbstractAgentSetRegistry["str"], AbstractAgentSetRegistry[["str1", "str2"]]): returns the specified column(s) of the agents in the AbstractAgentSetRegistry.
- - An AgentMask (eg. AbstractAgentSetRegistry[AgentMask]): returns the agents in the AbstractAgentSetRegistry that satisfy the AgentMask.
- - A tuple (eg. AbstractAgentSetRegistry[AgentMask, "str"]): returns the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask.
- - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, "str"]): returns the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary.
- - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, Collection[str]]): returns the specified columns of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary.
-
- Parameters
- ----------
- key : str | Collection[str] | AgentMask | tuple[AgentMask, str] | tuple[AgentMask, Collection[str]] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], Collection[str]]
- The key to retrieve.
-
- Returns
- -------
- Series | DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Series] | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame]
- The attribute values.
- """
- # TODO: fix types
- if isinstance(key, tuple):
- return self.get(mask=key[0], attr_names=key[1])
- else:
- if isinstance(key, str) or (
- isinstance(key, Collection) and all(isinstance(k, str) for k in key)
- ):
- return self.get(attr_names=key)
- else:
- return self.get(mask=key)
+ """Retrieve AgentSet(s) by index, name, or type."""
def __iadd__(
self,
other: (
- DataFrame
- | DataFrameInput
- | mesa_frames.abstract.agentset.AbstractAgentSet
+ mesa_frames.abstract.agentset.AbstractAgentSet
| Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
),
) -> Self:
- """Add agents to the AbstractAgentSetRegistry through the += operator.
+ """Add AgentSets to the registry through the += operator.
Parameters
----------
- other : DataFrame | DataFrameInput | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
- The agents to add.
+ other : mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
+ The AgentSets to add.
Returns
-------
Self
The updated AbstractAgentSetRegistry.
"""
- return self.add(agents=other, inplace=True)
+ return self.add(sets=other, inplace=True)
def __isub__(
self,
other: (
- IdsLike
- | AgentMask
- | mesa_frames.abstract.agentset.AbstractAgentSet
+ mesa_frames.abstract.agentset.AbstractAgentSet
| Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
),
) -> Self:
- """Remove agents from the AbstractAgentSetRegistry through the -= operator.
+ """Remove AgentSets from the registry through the -= operator.
Parameters
----------
- other : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
- The agents to remove.
+ other : mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
+ The AgentSets to remove.
Returns
-------
@@ -544,144 +507,130 @@ def __isub__(
def __sub__(
self,
other: (
- IdsLike
- | AgentMask
- | mesa_frames.abstract.agentset.AbstractAgentSet
+ mesa_frames.abstract.agentset.AbstractAgentSet
| Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
),
) -> Self:
- """Remove agents from a new AbstractAgentSetRegistry through the - operator.
+ """Remove AgentSets from a new registry through the - operator.
Parameters
----------
- other : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
- The agents to remove.
+ other : mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet]
+ The AgentSets to remove.
Returns
-------
Self
- A new AbstractAgentSetRegistry with the removed agents.
+ A new AbstractAgentSetRegistry with the removed AgentSets.
"""
return self.discard(other, inplace=False)
def __setitem__(
self,
- key: (
- str
- | Collection[str]
- | AgentMask
- | tuple[AgentMask, str | Collection[str]]
- | tuple[
- dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str
- ]
- | tuple[
- dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask],
- Collection[str],
- ]
- ),
- values: Any,
+ key: int | str,
+ value: mesa_frames.abstract.agentset.AbstractAgentSet,
) -> None:
- """Implement the [] operator for setting values in the AbstractAgentSetRegistry.
-
- The key can be:
- - A string (eg. AbstractAgentSetRegistry["str"]): sets the specified column of the agents in the AbstractAgentSetRegistry.
- - A list of strings(eg. AbstractAgentSetRegistry[["str1", "str2"]]): sets the specified columns of the agents in the AbstractAgentSetRegistry.
- - A tuple (eg. AbstractAgentSetRegistry[AgentMask, "str"]): sets the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask.
- - A AgentMask (eg. AbstractAgentSetRegistry[AgentMask]): sets the attributes of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask.
- - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, "str"]): sets the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary.
- - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, Collection[str]]): sets the specified columns of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary.
-
- Parameters
- ----------
- key : str | Collection[str] | AgentMask | tuple[AgentMask, str | Collection[str]] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], Collection[str]]
- The key to set.
- values : Any
- The values to set for the specified key.
- """
- # TODO: fix types as in __getitem__
- if isinstance(key, tuple):
- self.set(mask=key[0], attr_names=key[1], values=values)
- else:
- if isinstance(key, str) or (
- isinstance(key, Collection) and all(isinstance(k, str) for k in key)
- ):
- try:
- self.set(attr_names=key, values=values)
- except KeyError: # key=AgentMask
- self.set(attr_names=None, mask=key, values=values)
- else:
- self.set(attr_names=None, mask=key, values=values)
+ """Assign/replace a single AgentSet at an index or name.
+
+ Mirrors the invariants of ``replace`` for single-key assignment:
+ - Names remain unique across the registry
+ - ``value.model is self.model``
+ - For name keys, the key is authoritative for the assigned set's name
+ - For index keys, collisions on a different entry's name must raise
+ """
+ if value.model is not self.model:
+ raise TypeError("Assigned AgentSet must belong to the same model")
+ if isinstance(key, int):
+ # Delegate to replace() so subclasses centralize invariant handling.
+ self.replace({key: value}, inplace=True, atomic=True)
+ return
+ if isinstance(key, str):
+ for existing in self:
+ if existing.name == key:
+ self.replace({key: value}, inplace=True, atomic=True)
+ return
+ try:
+ value.rename(key, inplace=True)
+ except Exception:
+ if hasattr(value, "_name"):
+ value._name = key # type: ignore[attr-defined]
+ self.add(value, inplace=True)
+ return
+ raise TypeError("Key must be int index or str name")
@abstractmethod
def __getattr__(self, name: str) -> Any | dict[str, Any]:
- """Fallback for retrieving attributes of the AbstractAgentSetRegistry. Retrieve an attribute of the underlying DataFrame(s).
-
- Parameters
- ----------
- name : str
- The name of the attribute to retrieve.
-
- Returns
- -------
- Any | dict[str, Any]
- The attribute value
- """
+ """Fallback for retrieving attributes of the AgentSetRegistry."""
@abstractmethod
- def __iter__(self) -> Iterator[dict[str, Any]]:
- """Iterate over the agents in the AbstractAgentSetRegistry.
-
- Returns
- -------
- Iterator[dict[str, Any]]
- An iterator over the agents.
- """
+ def __iter__(self) -> Iterator[mesa_frames.abstract.agentset.AbstractAgentSet]:
+ """Iterate over AgentSets in the registry."""
...
@abstractmethod
def __len__(self) -> int:
- """Get the number of agents in the AbstractAgentSetRegistry.
-
- Returns
- -------
- int
- The number of agents in the AbstractAgentSetRegistry.
- """
+ """Get the number of AgentSets in the registry."""
...
@abstractmethod
def __repr__(self) -> str:
- """Get a string representation of the DataFrame in the AbstractAgentSetRegistry.
-
- Returns
- -------
- str
- A string representation of the DataFrame in the AbstractAgentSetRegistry.
- """
+ """Get a string representation of the AgentSets in the registry."""
pass
@abstractmethod
def __reversed__(self) -> Iterator:
- """Iterate over the agents in the AbstractAgentSetRegistry in reverse order.
-
- Returns
- -------
- Iterator
- An iterator over the agents in reverse order.
- """
+ """Iterate over AgentSets in reverse order."""
...
@abstractmethod
def __str__(self) -> str:
- """Get a string representation of the agents in the AbstractAgentSetRegistry.
-
- Returns
- -------
- str
- A string representation of the agents in the AbstractAgentSetRegistry.
- """
+ """Get a string representation of the AgentSets in the registry."""
...
+ def keys(
+ self, *, key_by: KeyBy = "name"
+ ) -> Iterable[str | int | type[mesa_frames.abstract.agentset.AbstractAgentSet]]:
+ """Iterate keys for contained AgentSets (by name|index|type)."""
+ if key_by == "index":
+ yield from range(len(self))
+ return
+ if key_by == "type":
+ for agentset in self:
+ yield type(agentset)
+ return
+ if key_by != "name":
+ raise ValueError("key_by must be 'name'|'index'|'type'")
+ for agentset in self:
+ if agentset.name is not None:
+ yield agentset.name
+
+ def items(
+ self, *, key_by: KeyBy = "name"
+ ) -> Iterable[
+ tuple[
+ str | int | type[mesa_frames.abstract.agentset.AbstractAgentSet],
+ mesa_frames.abstract.agentset.AbstractAgentSet,
+ ]
+ ]:
+ """Iterate (key, AgentSet) pairs for contained sets."""
+ if key_by == "index":
+ for idx, agentset in enumerate(self):
+ yield idx, agentset
+ return
+ if key_by == "type":
+ for agentset in self:
+ yield type(agentset), agentset
+ return
+ if key_by != "name":
+ raise ValueError("key_by must be 'name'|'index'|'type'")
+ for agentset in self:
+ if agentset.name is not None:
+ yield agentset.name, agentset
+
+ def values(self) -> Iterable[mesa_frames.abstract.agentset.AbstractAgentSet]:
+ """Iterate contained AgentSets (values view)."""
+ yield from self
+
@property
def model(self) -> mesa_frames.concrete.model.Model:
"""The model that the AbstractAgentSetRegistry belongs to.
@@ -714,85 +663,12 @@ def space(self) -> mesa_frames.abstract.space.Space | None:
@property
@abstractmethod
- def df(self) -> DataFrame | dict[str, DataFrame]:
- """The agents in the AbstractAgentSetRegistry.
-
- Returns
- -------
- DataFrame | dict[str, DataFrame]
- """
-
- @df.setter
- @abstractmethod
- def df(
- self, agents: DataFrame | list[mesa_frames.abstract.agentset.AbstractAgentSet]
- ) -> None:
- """Set the agents in the AbstractAgentSetRegistry.
-
- Parameters
- ----------
- agents : DataFrame | list[mesa_frames.abstract.agentset.AbstractAgentSet]
- """
-
- @property
- @abstractmethod
- def active_agents(self) -> DataFrame | dict[str, DataFrame]:
- """The active agents in the AbstractAgentSetRegistry.
-
- Returns
- -------
- DataFrame | dict[str, DataFrame]
- """
-
- @active_agents.setter
- @abstractmethod
- def active_agents(
- self,
- mask: AgentMask,
- ) -> None:
- """Set the active agents in the AbstractAgentSetRegistry.
-
- Parameters
- ----------
- mask : AgentMask
- The mask to apply.
- """
- self.select(mask=mask, inplace=True)
-
- @property
- @abstractmethod
- def inactive_agents(
- self,
- ) -> DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame]:
- """The inactive agents in the AbstractAgentSetRegistry.
+ def ids(self) -> Series:
+ """Public view of all agent unique_id values across contained sets.
Returns
-------
- DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame]
- """
-
- @property
- @abstractmethod
- def index(
- self,
- ) -> Index | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Index]:
- """The ids in the AbstractAgentSetRegistry.
-
- Returns
- -------
- Index | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Index]
- """
- ...
-
- @property
- @abstractmethod
- def pos(
- self,
- ) -> DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame]:
- """The position of the agents in the AbstractAgentSetRegistry.
-
- Returns
- -------
- DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame]
+ Series
+ Concatenated unique_id Series for all AgentSets.
"""
...
diff --git a/mesa_frames/abstract/datacollector.py b/mesa_frames/abstract/datacollector.py
index edbfb11f..6505408f 100644
--- a/mesa_frames/abstract/datacollector.py
+++ b/mesa_frames/abstract/datacollector.py
@@ -91,7 +91,11 @@ def __init__(
model_reporters : dict[str, Callable] | None
Functions to collect data at the model level.
agent_reporters : dict[str, str | Callable] | None
- Attributes or functions to collect data at the agent level.
+ Agent-level reporters. Values may be:
+ - str or list[str]: pull existing columns from each set; columns are suffixed per-set.
+ - Callable[[AbstractAgentSetRegistry], Series | DataFrame | dict[str, Series|DataFrame]]: registry-level, runs once per step.
+ - Callable[[mesa_frames.abstract.agentset.AbstractAgentSet], Series | DataFrame]: set-level, runs once per set.
+ Note: model-level callables are not supported for agent reporters.
trigger : Callable[[Any], bool] | None
A function(model) -> bool that determines whether to collect data.
reset_memory : bool
diff --git a/mesa_frames/abstract/space.py b/mesa_frames/abstract/space.py
index 74df16e8..39abe6bd 100644
--- a/mesa_frames/abstract/space.py
+++ b/mesa_frames/abstract/space.py
@@ -7,14 +7,14 @@
performance and scalability.
Classes:
- SpaceDF(CopyMixin, DataFrameMixin):
+ Space(CopyMixin, DataFrameMixin):
An abstract base class that defines the common interface for all space
classes in mesa-frames. It combines fast copying functionality with
DataFrame operations.
- AbstractDiscreteSpace(SpaceDF):
+ AbstractDiscreteSpace(Space):
An abstract base class for discrete space implementations, such as grids
- and networks. It extends SpaceDF with methods specific to discrete spaces.
+ and networks. It extends Space with methods specific to discrete spaces.
AbstractGrid(AbstractDiscreteSpace):
An abstract base class for grid-based spaces. It inherits from
@@ -52,7 +52,7 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type):
from abc import abstractmethod
from collections.abc import Callable, Collection, Sequence, Sized
from itertools import product
-from typing import Any, Literal, Self
+from typing import Any, Literal, Self, cast
from warnings import warn
import numpy as np
@@ -64,7 +64,6 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type):
AbstractAgentSetRegistry,
)
from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin
-from mesa_frames.concrete.agentsetregistry import AgentSetRegistry
from mesa_frames.types_ import (
ArrayLike,
BoolSeries,
@@ -98,7 +97,7 @@ class Space(CopyMixin, DataFrameMixin):
] # The column names of the positions in the _agents dataframe (eg. ['dim_0', 'dim_1', ...] in Grids, ['node_id', 'edge_id'] in Networks)
def __init__(self, model: mesa_frames.concrete.model.Model) -> None:
- """Create a new SpaceDF.
+ """Create a new Space.
Parameters
----------
@@ -109,7 +108,9 @@ def __init__(self, model: mesa_frames.concrete.model.Model) -> None:
def move_agents(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
pos: SpaceCoordinate | SpaceCoordinates,
inplace: bool = True,
@@ -120,7 +121,7 @@ def move_agents(
Parameters
----------
- agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry]
+ agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry]
The agents to move
pos : SpaceCoordinate | SpaceCoordinates
The coordinates for each agents. The length of the coordinates must match the number of agents.
@@ -145,7 +146,9 @@ def move_agents(
def place_agents(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
pos: SpaceCoordinate | SpaceCoordinates,
inplace: bool = True,
@@ -154,7 +157,7 @@ def place_agents(
Parameters
----------
- agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry]
+ agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry]
The agents to place in the space
pos : SpaceCoordinate | SpaceCoordinates
The coordinates for each agents. The length of the coordinates must match the number of agents.
@@ -198,10 +201,14 @@ def random_agents(
def swap_agents(
self,
agents0: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
agents1: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
inplace: bool = True,
) -> Self:
@@ -211,9 +218,9 @@ def swap_agents(
Parameters
----------
- agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry]
+ agents0 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry]
The first set of agents to swap
- agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry]
+ agents1 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry]
The second set of agents to swap
inplace : bool, optional
Whether to perform the operation inplace, by default True
@@ -222,26 +229,27 @@ def swap_agents(
-------
Self
"""
- agents0 = self._get_ids_srs(agents0)
- agents1 = self._get_ids_srs(agents1)
+ # Normalize inputs to Series of ids for validation and operations
+ ids0 = self._get_ids_srs(agents0)
+ ids1 = self._get_ids_srs(agents1)
if __debug__:
- if len(agents0) != len(agents1):
+ if len(ids0) != len(ids1):
raise ValueError("The two sets of agents must have the same length")
- if not self._df_contains(self._agents, "agent_id", agents0).all():
+ if not self._df_contains(self._agents, "agent_id", ids0).all():
raise ValueError("Some agents in agents0 are not in the space")
- if not self._df_contains(self._agents, "agent_id", agents1).all():
+ if not self._df_contains(self._agents, "agent_id", ids1).all():
raise ValueError("Some agents in agents1 are not in the space")
- if self._srs_contains(agents0, agents1).any():
+ if self._srs_contains(ids0, ids1).any():
raise ValueError("Some agents are present in both agents0 and agents1")
obj = self._get_obj(inplace)
agents0_df = obj._df_get_masked_df(
- obj._agents, index_cols="agent_id", mask=agents0
+ obj._agents, index_cols="agent_id", mask=ids0
)
agents1_df = obj._df_get_masked_df(
- obj._agents, index_cols="agent_id", mask=agents1
+ obj._agents, index_cols="agent_id", mask=ids1
)
- agents0_df = obj._df_set_index(agents0_df, "agent_id", agents1)
- agents1_df = obj._df_set_index(agents1_df, "agent_id", agents0)
+ agents0_df = obj._df_set_index(agents0_df, "agent_id", ids1)
+ agents1_df = obj._df_set_index(agents1_df, "agent_id", ids0)
obj._agents = obj._df_combine_first(
agents0_df, obj._agents, index_cols="agent_id"
)
@@ -257,11 +265,15 @@ def get_directions(
pos0: SpaceCoordinate | SpaceCoordinates | None = None,
pos1: SpaceCoordinate | SpaceCoordinates | None = None,
agents0: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None = None,
agents1: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None = None,
normalize: bool = False,
@@ -278,9 +290,9 @@ def get_directions(
The starting positions
pos1 : SpaceCoordinate | SpaceCoordinates | None, optional
The ending positions
- agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional
+ agents0 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional
The starting agents
- agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional
+ agents1 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional
The ending agents
normalize : bool, optional
Whether to normalize the vectors to unit norm. By default False
@@ -298,11 +310,15 @@ def get_distances(
pos0: SpaceCoordinate | SpaceCoordinates | None = None,
pos1: SpaceCoordinate | SpaceCoordinates | None = None,
agents0: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None = None,
agents1: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None = None,
) -> DataFrame:
@@ -318,9 +334,9 @@ def get_distances(
The starting positions
pos1 : SpaceCoordinate | SpaceCoordinates | None, optional
The ending positions
- agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional
+ agents0 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional
The starting agents
- agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional
+ agents1 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional
The ending agents
Returns
@@ -336,7 +352,9 @@ def get_neighbors(
radius: int | float | Sequence[int] | Sequence[float] | ArrayLike,
pos: SpaceCoordinate | SpaceCoordinates | None = None,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None = None,
include_center: bool = False,
@@ -351,7 +369,7 @@ def get_neighbors(
The radius(es) of the neighborhood
pos : SpaceCoordinate | SpaceCoordinates | None, optional
The coordinates of the cell to get the neighborhood from, by default None
- agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional
+ agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional
The id of the agents to get the neighborhood from, by default None
include_center : bool, optional
If the center cells or agents should be included in the result, by default False
@@ -373,7 +391,9 @@ def get_neighbors(
def move_to_empty(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
inplace: bool = True,
) -> Self:
@@ -381,7 +401,7 @@ def move_to_empty(
Parameters
----------
- agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry]
+ agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry]
The agents to move to empty cells/positions
inplace : bool, optional
Whether to perform the operation inplace, by default True
@@ -396,7 +416,9 @@ def move_to_empty(
def place_to_empty(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
inplace: bool = True,
) -> Self:
@@ -404,7 +426,7 @@ def place_to_empty(
Parameters
----------
- agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry]
+ agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry]
The agents to place in empty cells/positions
inplace : bool, optional
Whether to perform the operation inplace, by default True
@@ -438,7 +460,9 @@ def random_pos(
def remove_agents(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
inplace: bool = True,
) -> Self:
@@ -448,7 +472,7 @@ def remove_agents(
Parameters
----------
- agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry]
+ agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry]
The agents to remove from the space
inplace : bool, optional
Whether to perform the operation inplace, by default True
@@ -467,7 +491,9 @@ def remove_agents(
def _get_ids_srs(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
) -> Series:
if isinstance(agents, Sized) and len(agents) == 0:
@@ -478,10 +504,11 @@ def _get_ids_srs(
name="agent_id",
dtype="uint64",
)
- elif isinstance(agents, AgentSetRegistry):
- return self._srs_constructor(agents._ids, name="agent_id", dtype="uint64")
+ elif isinstance(agents, AbstractAgentSetRegistry):
+ return self._srs_constructor(agents.ids, name="agent_id", dtype="uint64")
elif isinstance(agents, Collection) and (
- isinstance(agents[0], AbstractAgentSetRegistry)
+ isinstance(agents[0], AbstractAgentSet)
+ or isinstance(agents[0], AbstractAgentSetRegistry)
):
ids = []
for a in agents:
@@ -493,9 +520,9 @@ def _get_ids_srs(
dtype="uint64",
)
)
- elif isinstance(a, AgentSetRegistry):
+ elif isinstance(a, AbstractAgentSetRegistry):
ids.append(
- self._srs_constructor(a._ids, name="agent_id", dtype="uint64")
+ self._srs_constructor(a.ids, name="agent_id", dtype="uint64")
)
return self._df_concat(ids, ignore_index=True)
elif isinstance(agents, int):
@@ -657,7 +684,9 @@ def move_to_empty(
self,
agents: IdsLike
| AbstractAgentSetRegistry
- | Collection[AbstractAgentSetRegistry],
+ | Collection[AbstractAgentSetRegistry]
+ | AbstractAgentSet
+ | Collection[AbstractAgentSet],
inplace: bool = True,
) -> Self:
obj = self._get_obj(inplace)
@@ -668,7 +697,9 @@ def move_to_empty(
def move_to_available(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
inplace: bool = True,
) -> Self:
@@ -676,7 +707,7 @@ def move_to_available(
Parameters
----------
- agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry]
+ agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry]
The agents to move to available cells/positions
inplace : bool, optional
Whether to perform the operation inplace, by default True
@@ -686,6 +717,7 @@ def move_to_available(
Self
"""
obj = self._get_obj(inplace)
+
return obj._place_or_move_agents_to_cells(
agents, cell_type="available", is_move=True
)
@@ -693,11 +725,14 @@ def move_to_available(
def place_to_empty(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
inplace: bool = True,
) -> Self:
obj = self._get_obj(inplace)
+
return obj._place_or_move_agents_to_cells(
agents, cell_type="empty", is_move=False
)
@@ -705,7 +740,9 @@ def place_to_empty(
def place_to_available(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
inplace: bool = True,
) -> Self:
@@ -823,7 +860,9 @@ def get_neighborhood(
radius: int | float | Sequence[int] | Sequence[float] | ArrayLike,
pos: DiscreteCoordinate | DiscreteCoordinates | None = None,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry] = None,
include_center: bool = False,
) -> DataFrame:
@@ -837,7 +876,7 @@ def get_neighborhood(
The radius(es) of the neighborhoods
pos : DiscreteCoordinate | DiscreteCoordinates | None, optional
The coordinates of the cell(s) to get the neighborhood from
- agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry], optional
+ agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], optional
The agent(s) to get the neighborhood from
include_center : bool, optional
If the cell in the center of the neighborhood should be included in the result, by default False
@@ -933,7 +972,9 @@ def _check_cells(
def _place_or_move_agents_to_cells(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
cell_type: Literal["any", "empty", "available"],
is_move: bool,
@@ -942,8 +983,8 @@ def _place_or_move_agents_to_cells(
agents = self._get_ids_srs(agents)
if __debug__:
- # Check ids presence in model
- b_contained = self.model.sets.contains(agents)
+ # Check ids presence in model using public API
+ b_contained = agents.is_in(self.model.sets.ids)
if (isinstance(b_contained, Series) and not b_contained.all()) or (
isinstance(b_contained, bool) and not b_contained
):
@@ -994,7 +1035,7 @@ def _sample_cells(
self,
n: int | None,
with_replacement: bool,
- condition: Callable[[DiscreteSpaceCapacity], BoolSeries],
+ condition: Callable[[DiscreteSpaceCapacity], BoolSeries | np.ndarray],
respect_capacity: bool = True,
) -> DataFrame:
"""Sample cells from the grid according to a condition on the capacity.
@@ -1005,7 +1046,7 @@ def _sample_cells(
The number of cells to sample. If None, samples the maximum available.
with_replacement : bool
If the sampling should be with replacement
- condition : Callable[[DiscreteSpaceCapacity], BoolSeries]
+ condition : Callable[[DiscreteSpaceCapacity], BoolSeries | np.ndarray]
The condition to apply on the capacity
respect_capacity : bool, optional
If the capacity should be respected in the sampling.
@@ -1259,11 +1300,15 @@ def get_directions(
pos0: GridCoordinate | GridCoordinates | None = None,
pos1: GridCoordinate | GridCoordinates | None = None,
agents0: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None = None,
agents1: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None = None,
normalize: bool = False,
@@ -1278,11 +1323,15 @@ def get_distances(
pos0: GridCoordinate | GridCoordinates | None = None,
pos1: GridCoordinate | GridCoordinates | None = None,
agents0: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None = None,
agents1: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None = None,
) -> DataFrame:
@@ -1311,7 +1360,7 @@ def get_neighbors(
def get_neighborhood(
self,
radius: int | Sequence[int] | ArrayLike,
- pos: GridCoordinate | GridCoordinates | None = None,
+ pos: DiscreteCoordinate | DiscreteCoordinates | None = None,
agents: IdsLike
| AbstractAgentSetRegistry
| Collection[AbstractAgentSetRegistry]
@@ -1549,7 +1598,9 @@ def out_of_bounds(self, pos: GridCoordinate | GridCoordinates) -> DataFrame:
def remove_agents(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
inplace: bool = True,
) -> Self:
@@ -1558,8 +1609,8 @@ def remove_agents(
agents = obj._get_ids_srs(agents)
if __debug__:
- # Check ids presence in model
- b_contained = obj.model.sets.contains(agents)
+ # Check ids presence in model via public ids
+ b_contained = agents.is_in(obj.model.sets.ids)
if (isinstance(b_contained, Series) and not b_contained.all()) or (
isinstance(b_contained, bool) and not b_contained
):
@@ -1594,11 +1645,15 @@ def _calculate_differences(
pos0: GridCoordinate | GridCoordinates | None,
pos1: GridCoordinate | GridCoordinates | None,
agents0: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None,
agents1: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None,
) -> DataFrame:
@@ -1610,9 +1665,9 @@ def _calculate_differences(
The starting positions
pos1 : GridCoordinate | GridCoordinates | None
The ending positions
- agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None
+ agents0 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None
The starting agents
- agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None
+ agents1 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None
The ending agents
Returns
@@ -1694,7 +1749,9 @@ def _get_df_coords(
self,
pos: GridCoordinate | GridCoordinates | None = None,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry]
| None = None,
check_bounds: bool = True,
@@ -1705,7 +1762,7 @@ def _get_df_coords(
----------
pos : GridCoordinate | GridCoordinates | None, optional
The positions to get the DataFrame from, by default None
- agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional
+ agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional
The agents to get the DataFrame from, by default None
check_bounds: bool, optional
If the positions should be checked for out-of-bounds in non-toroidal grids, by default True
@@ -1735,7 +1792,7 @@ def _get_df_coords(
if agents is not None:
agents = self._get_ids_srs(agents)
# Check ids presence in model
- b_contained = self.model.sets.contains(agents)
+ b_contained = agents.is_in(self.model.sets.ids)
if (isinstance(b_contained, Series) and not b_contained.all()) or (
isinstance(b_contained, bool) and not b_contained
):
@@ -1796,7 +1853,9 @@ def _get_df_coords(
def _place_or_move_agents(
self,
agents: IdsLike
+ | AbstractAgentSet
| AbstractAgentSetRegistry
+ | Collection[AbstractAgentSet]
| Collection[AbstractAgentSetRegistry],
pos: GridCoordinate | GridCoordinates,
is_move: bool,
@@ -1812,8 +1871,8 @@ def _place_or_move_agents(
if self._df_contains(self._agents, "agent_id", agents).any():
warn("Some agents are already present in the grid", RuntimeWarning)
- # Check if agents are present in the model
- b_contained = self.model.sets.contains(agents)
+ # Check if agents are present in the model using the public ids
+ b_contained = agents.is_in(self.model.sets.ids)
if (isinstance(b_contained, Series) and not b_contained.all()) or (
isinstance(b_contained, bool) and not b_contained
):
diff --git a/mesa_frames/concrete/agentset.py b/mesa_frames/concrete/agentset.py
index 5c64aef6..2a9b1a55 100644
--- a/mesa_frames/concrete/agentset.py
+++ b/mesa_frames/concrete/agentset.py
@@ -67,7 +67,7 @@ def step(self):
from mesa_frames.abstract.agentset import AbstractAgentSet
from mesa_frames.concrete.mixin import PolarsMixin
-from mesa_frames.types_ import AgentPolarsMask, IntoExpr, PolarsIdsLike
+from mesa_frames.types_ import AgentMask, AgentPolarsMask, IntoExpr, PolarsIdsLike
from mesa_frames.utils import copydoc
@@ -82,19 +82,76 @@ class AgentSet(AbstractAgentSet, PolarsMixin):
_copy_only_reference: list[str] = ["_model", "_mask"]
_mask: pl.Expr | pl.Series
- def __init__(self, model: mesa_frames.concrete.model.Model) -> None:
+ def __init__(
+ self, model: mesa_frames.concrete.model.Model, name: str | None = None
+ ) -> None:
"""Initialize a new AgentSet.
Parameters
----------
model : "mesa_frames.concrete.model.Model"
The model that the agent set belongs to.
+ name : str | None, optional
+ Name for this agent set. If None, class name is used.
+ Will be converted to snake_case if in camelCase.
"""
+ # Model reference
self._model = model
+ # Set proposed name (no uniqueness guarantees here)
+ self._name = name if name is not None else self.__class__.__name__
# No definition of schema with unique_id, as it becomes hard to add new agents
self._df = pl.DataFrame()
self._mask = pl.repeat(True, len(self._df), dtype=pl.Boolean, eager=True)
+ def rename(self, new_name: str, inplace: bool = True) -> Self:
+ """Rename this agent set. If attached to AgentSetRegistry, delegate for uniqueness enforcement.
+
+ Parameters
+ ----------
+ new_name : str
+ Desired new name.
+
+ inplace : bool, optional
+ Whether to perform the rename in place. If False, a renamed copy is
+ returned, by default True.
+
+ Returns
+ -------
+ Self
+ The updated AgentSet (or a renamed copy when ``inplace=False``).
+
+ Raises
+ ------
+ ValueError
+ If name conflicts occur and delegate encounters errors.
+ """
+ # Respect inplace semantics consistently with other mutators
+ obj = self._get_obj(inplace)
+
+ # Always delegate to the container's accessor if available through the model's sets
+ # Check if we have a model and can find the AgentSetRegistry that contains this set
+ try:
+ if self in self.model.sets:
+ # Save index to locate the copy on non-inplace path
+ try:
+ idx = list(self.model.sets).index(self) # type: ignore[arg-type]
+ except Exception:
+ idx = None
+ reg = self.model.sets.rename(self, new_name, inplace=inplace)
+ if inplace:
+ return self
+ if idx is not None:
+ return reg[idx]
+ return reg.get(new_name) # type: ignore[return-value]
+ except Exception:
+ # Fall back to local rename if delegation fails
+ obj._name = new_name
+ return obj
+
+ # Set name locally if no container found
+ obj._name = new_name
+ return obj
+
def add(
self,
agents: pl.DataFrame | Sequence[Any] | dict[str, Any],
@@ -181,6 +238,64 @@ def contains(
else:
return agents in self._df["unique_id"]
+ @overload
+ def do(
+ self,
+ method_name: str,
+ *args,
+ mask: AgentMask | None = None,
+ return_results: Literal[False] = False,
+ inplace: bool = True,
+ **kwargs,
+ ) -> Self: ...
+
+ @overload
+ def do(
+ self,
+ method_name: str,
+ *args,
+ mask: AgentMask | None = None,
+ return_results: Literal[True],
+ inplace: bool = True,
+ **kwargs,
+ ) -> Any: ...
+
+ def do(
+ self,
+ method_name: str,
+ *args,
+ mask: AgentMask | None = None,
+ return_results: bool = False,
+ inplace: bool = True,
+ **kwargs,
+ ) -> Self | Any:
+ masked_df = self._get_masked_df(mask)
+ # If the mask is empty, we can use the object as is
+ if len(masked_df) == len(self._df):
+ obj = self._get_obj(inplace)
+ method = getattr(obj, method_name)
+ result = method(*args, **kwargs)
+ else: # If the mask is not empty, we need to create a new masked AbstractAgentSet and concatenate the AbstractAgentSets at the end
+ obj = self._get_obj(inplace=False)
+ obj._df = masked_df
+ original_masked_index = obj._get_obj_copy(obj.index)
+ method = getattr(obj, method_name)
+ result = method(*args, **kwargs)
+ obj._concatenate_agentsets(
+ [self],
+ duplicates_allowed=True,
+ keep_first_only=True,
+ original_masked_index=original_masked_index,
+ )
+ if inplace:
+ for key, value in obj.__dict__.items():
+ setattr(self, key, value)
+ obj = self
+ if return_results:
+ return result
+ else:
+ return obj
+
def get(
self,
attr_names: IntoExpr | Iterable[IntoExpr] | None,
@@ -198,6 +313,20 @@ def get(
return masked_df[masked_df.columns[0]]
return masked_df
+ def remove(self, agents: PolarsIdsLike | AgentMask, inplace: bool = True) -> Self:
+ if isinstance(agents, str) and agents == "active":
+ agents = self.active_agents
+ if agents is None or (isinstance(agents, Iterable) and len(agents) == 0):
+ return self._get_obj(inplace)
+ obj = self._get_obj(inplace)
+ # Normalize to Series of unique_ids
+ ids = obj._df_index(obj._get_masked_df(agents), "unique_id")
+ # Validate presence
+ if not ids.is_in(obj._df["unique_id"]).all():
+ raise KeyError("Some 'unique_id' of mask are not present in this AgentSet.")
+ # Remove by ids
+ return obj._discard(ids)
+
def set(
self,
attr_names: str | Collection[str] | dict[str, Any] | None = None,
@@ -463,7 +592,9 @@ def _update_mask(
else:
self._mask = self._df["unique_id"].is_in(original_active_indices)
- def __getattr__(self, key: str) -> pl.Series:
+ def __getattr__(self, key: str) -> Any:
+ if key == "name":
+ return self.name
super().__getattr__(key)
return self._df[key]
@@ -546,3 +677,13 @@ def index(self) -> pl.Series:
@property
def pos(self) -> pl.DataFrame:
return super().pos
+
+ @property
+ def name(self) -> str:
+ """Return the name of the AgentSet."""
+ return self._name
+
+ @name.setter
+ def name(self, value: str) -> None:
+ """Set the name of the AgentSet."""
+ self._name = value
diff --git a/mesa_frames/concrete/agentsetregistry.py b/mesa_frames/concrete/agentsetregistry.py
index b9ed1563..4b486ba2 100644
--- a/mesa_frames/concrete/agentsetregistry.py
+++ b/mesa_frames/concrete/agentsetregistry.py
@@ -46,26 +46,17 @@ def step(self):
from __future__ import annotations # For forward references
-from collections import defaultdict
-from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
-from typing import Any, Literal, Self, cast, overload
-
-import numpy as np
+from collections.abc import Collection, Iterable, Iterator, Sequence
+from typing import Any, Literal, Self, overload, cast
+from collections.abc import Sized
+from itertools import chain
import polars as pl
from mesa_frames.abstract.agentsetregistry import (
AbstractAgentSetRegistry,
)
from mesa_frames.concrete.agentset import AgentSet
-from mesa_frames.types_ import (
- AgentMask,
- AgnosticAgentMask,
- BoolSeries,
- DataFrame,
- IdsLike,
- Index,
- Series,
-)
+from mesa_frames.types_ import BoolSeries, KeyBy, AgentSetSelector
class AgentSetRegistry(AbstractAgentSetRegistry):
@@ -88,34 +79,31 @@ def __init__(self, model: mesa_frames.concrete.model.Model) -> None:
def add(
self,
- agents: AgentSet | Iterable[AgentSet],
+ sets: AgentSet | Iterable[AgentSet],
inplace: bool = True,
) -> Self:
- """Add an AgentSet to the AgentSetRegistry.
-
- Parameters
- ----------
- agents : AgentSet | Iterable[AgentSet]
- The AgentSets to add.
- inplace : bool, optional
- Whether to add the AgentSets in place. Defaults to True.
-
- Returns
- -------
- Self
- The updated AgentSetRegistry.
-
- Raises
- ------
- ValueError
- If any AgentSets are already present or if IDs are not unique.
- """
obj = self._get_obj(inplace)
- other_list = obj._return_agentsets_list(agents)
+ other_list = obj._return_agentsets_list(sets)
if obj._check_agentsets_presence(other_list).any():
raise ValueError(
"Some agentsets are already present in the AgentSetRegistry."
)
+ # Ensure unique names across existing and to-be-added sets
+ existing_names = {s.name for s in obj._agentsets}
+ for agentset in other_list:
+ base_name = agentset.name or agentset.__class__.__name__
+ name = base_name
+ if name in existing_names:
+ counter = 1
+ candidate = f"{base_name}_{counter}"
+ while candidate in existing_names:
+ counter += 1
+ candidate = f"{base_name}_{counter}"
+ name = candidate
+ # Assign back if changed or was None
+ if name != (agentset.name or base_name):
+ agentset.name = name
+ existing_names.add(name)
new_ids = pl.concat(
[obj._ids] + [pl.Series(agentset["unique_id"]) for agentset in other_list]
)
@@ -125,215 +113,427 @@ def add(
obj._ids = new_ids
return obj
+ def rename(
+ self,
+ target: (
+ AgentSet
+ | str
+ | dict[AgentSet | str, str]
+ | list[tuple[AgentSet | str, str]]
+ ),
+ new_name: str | None = None,
+ *,
+ on_conflict: Literal["canonicalize", "raise"] = "canonicalize",
+ mode: Literal["atomic", "best_effort"] = "atomic",
+ inplace: bool = True,
+ ) -> Self:
+ """Rename AgentSets with conflict handling.
+
+ Supports single-target ``(set | old_name, new_name)`` and batch rename via
+ dict or list of pairs. Names remain unique across the registry.
+ """
+
+ # Normalize to list of (index_in_self, desired_name) using the original registry
+ def _resolve_one(x: AgentSet | str) -> int:
+ if isinstance(x, AgentSet):
+ for i, s in enumerate(self._agentsets):
+ if s is x:
+ return i
+ raise KeyError("AgentSet not found in registry")
+ # name lookup on original registry
+ for i, s in enumerate(self._agentsets):
+ if s.name == x:
+ return i
+ raise KeyError(f"Agent set '{x}' not found")
+
+ if isinstance(target, (AgentSet, str)):
+ if new_name is None:
+ raise TypeError("new_name must be provided for single rename")
+ pairs_idx: list[tuple[int, str]] = [(_resolve_one(target), new_name)]
+ single = True
+ elif isinstance(target, dict):
+ pairs_idx = [(_resolve_one(k), v) for k, v in target.items()]
+ single = False
+ else:
+ pairs_idx = [(_resolve_one(k), v) for k, v in target]
+ single = False
+
+ # Choose object to mutate
+ obj = self._get_obj(inplace)
+ # Translate indices to object AgentSets in the selected registry object
+ target_sets = [obj._agentsets[i] for i, _ in pairs_idx]
+
+ # Build the set of names that remain fixed (exclude targets' current names)
+ targets_set = set(target_sets)
+ fixed_names: set[str] = {
+ s.name
+ for s in obj._agentsets
+ if s.name is not None and s not in targets_set
+ } # type: ignore[comparison-overlap]
+
+ # Plan final names
+ final: list[tuple[AgentSet, str]] = []
+ used = set(fixed_names)
+
+ def _canonicalize(base: str) -> str:
+ if base not in used:
+ used.add(base)
+ return base
+ counter = 1
+ cand = f"{base}_{counter}"
+ while cand in used:
+ counter += 1
+ cand = f"{base}_{counter}"
+ used.add(cand)
+ return cand
+
+ errors: list[Exception] = []
+ for aset, (_idx, desired) in zip(target_sets, pairs_idx):
+ if on_conflict == "canonicalize":
+ final_name = _canonicalize(desired)
+ final.append((aset, final_name))
+ else: # on_conflict == 'raise'
+ if desired in used:
+ err = ValueError(
+ f"Duplicate agent set name disallowed: '{desired}'"
+ )
+ if mode == "atomic":
+ errors.append(err)
+ else:
+ # best_effort: skip this rename
+ continue
+ else:
+ used.add(desired)
+ final.append((aset, desired))
+
+ if errors and mode == "atomic":
+ # Surface first meaningful error
+ raise errors[0]
+
+ # Apply renames
+ for aset, newn in final:
+ # Set the private name directly to avoid external uniqueness hooks
+ if hasattr(aset, "_name"):
+ aset._name = newn # type: ignore[attr-defined]
+
+ return obj
+
+ def replace(
+ self,
+ mapping: (dict[int | str, AgentSet] | list[tuple[int | str, AgentSet]]),
+ *,
+ inplace: bool = True,
+ atomic: bool = True,
+ ) -> Self:
+ # Normalize to list of (key, value)
+ items: list[tuple[int | str, AgentSet]]
+ if isinstance(mapping, dict):
+ items = list(mapping.items())
+ else:
+ items = list(mapping)
+
+ obj = self._get_obj(inplace)
+
+ # Helpers (build name->idx map only if needed)
+ has_str_keys = any(isinstance(k, str) for k, _ in items)
+ if has_str_keys:
+ name_to_idx = {
+ s.name: i for i, s in enumerate(obj._agentsets) if s.name is not None
+ }
+
+ def _find_index_by_name(name: str) -> int:
+ try:
+ return name_to_idx[name]
+ except KeyError:
+ raise KeyError(f"Agent set '{name}' not found")
+ else:
+
+ def _find_index_by_name(name: str) -> int:
+ for i, s in enumerate(obj._agentsets):
+ if s.name == name:
+ return i
+
+ raise KeyError(f"Agent set '{name}' not found")
+
+ if atomic:
+ n = len(obj._agentsets)
+ # Map existing object identity -> index (for aliasing checks)
+ id_to_idx = {id(s): i for i, s in enumerate(obj._agentsets)}
+
+ for k, v in items:
+ if not isinstance(v, AgentSet):
+ raise TypeError("Values must be AgentSet instances")
+ if v.model is not obj.model:
+ raise TypeError(
+ "All AgentSets must belong to the same model as the registry"
+ )
+
+ v_idx_existing = id_to_idx.get(id(v))
+
+ if isinstance(k, int):
+ if not (0 <= k < n):
+ raise IndexError(
+ f"Index {k} out of range for AgentSetRegistry of size {n}"
+ )
+
+ # Prevent aliasing: the same object cannot appear in two positions
+ if v_idx_existing is not None and v_idx_existing != k:
+ raise ValueError(
+ f"This AgentSet instance already exists at index {v_idx_existing}; cannot also place it at {k}."
+ )
+
+ # Preserve name uniqueness when assigning by index
+ vname = v.name
+ if vname is not None:
+ try:
+ other_idx = _find_index_by_name(vname)
+ if other_idx != k:
+ raise ValueError(
+ f"Duplicate agent set name disallowed: '{vname}' already at index {other_idx}"
+ )
+ except KeyError:
+ # name not present elsewhere -> OK
+ pass
+
+ elif isinstance(k, str):
+ # Locate the slot by name; replacing that slot preserves uniqueness
+ idx = _find_index_by_name(k)
+
+ # Prevent aliasing: if the same object already exists at a different slot, forbid
+ if v_idx_existing is not None and v_idx_existing != idx:
+ raise ValueError(
+ f"This AgentSet instance already exists at index {v_idx_existing}; cannot also place it at {idx}."
+ )
+
+ else:
+ raise TypeError("Keys must be int indices or str names")
+
+ # Apply
+ target = obj if inplace else obj.copy(deep=False)
+ if not inplace:
+ target._agentsets = list(obj._agentsets)
+
+ for k, v in items:
+ if isinstance(k, int):
+ target._agentsets[k] = v # keep v.name as-is (validated above)
+ else:
+ idx = _find_index_by_name(k)
+ # Force the authoritative name without triggering external uniqueness checks
+ if hasattr(v, "_name"):
+ v._name = k # type: ignore[attr-defined]
+ target._agentsets[idx] = v
+
+ # Recompute ids cache
+ target._recompute_ids()
+
+ return target
+
@overload
- def contains(self, agents: int | AgentSet) -> bool: ...
+ def contains(self, sets: AgentSet | type[AgentSet] | str) -> bool: ...
@overload
- def contains(self, agents: IdsLike | Iterable[AgentSet]) -> pl.Series: ...
+ def contains(
+ self,
+ sets: Iterable[AgentSet] | Iterable[type[AgentSet]] | Iterable[str],
+ ) -> pl.Series: ...
def contains(
- self, agents: IdsLike | AgentSet | Iterable[AgentSet]
+ self,
+ sets: AgentSet
+ | type[AgentSet]
+ | str
+ | Iterable[AgentSet]
+ | Iterable[type[AgentSet]]
+ | Iterable[str],
) -> bool | pl.Series:
- if isinstance(agents, int):
- return agents in self._ids
- elif isinstance(agents, AgentSet):
- return self._check_agentsets_presence([agents]).any()
- elif isinstance(agents, Iterable):
- if len(agents) == 0:
- return True
- elif isinstance(next(iter(agents)), AgentSet):
- agents = cast(Iterable[AgentSet], agents)
- return self._check_agentsets_presence(list(agents))
- else: # IdsLike
- agents = cast(IdsLike, agents)
-
- return pl.Series(agents, dtype=pl.UInt64).is_in(self._ids)
+ # Single value fast paths
+ if isinstance(sets, AgentSet):
+ return self._check_agentsets_presence([sets]).any()
+ if isinstance(sets, type) and issubclass(sets, AgentSet):
+ return any(isinstance(s, sets) for s in self._agentsets)
+ if isinstance(sets, str):
+ return any(s.name == sets for s in self._agentsets)
+
+ # Iterable paths without materializing unnecessarily
+
+ if isinstance(sets, Sized) and len(sets) == 0: # type: ignore[arg-type]
+ return True
+ it = iter(sets) # type: ignore[arg-type]
+ try:
+ first = next(it)
+ except StopIteration:
+ return True
+
+ if isinstance(first, AgentSet):
+ lst = [first, *it]
+ return self._check_agentsets_presence(lst)
+
+ if isinstance(first, type) and issubclass(first, AgentSet):
+ present_types = {type(s) for s in self._agentsets}
+
+ def has_type(t: type[AgentSet]) -> bool:
+ return any(issubclass(pt, t) for pt in present_types)
+
+ return pl.Series(
+ (has_type(t) for t in chain([first], it)), dtype=pl.Boolean
+ )
+
+ if isinstance(first, str):
+ names = {s.name for s in self._agentsets if s.name is not None}
+ return pl.Series((x in names for x in chain([first], it)), dtype=pl.Boolean)
+
+ raise TypeError("Unsupported type for contains()")
@overload
def do(
self,
method_name: str,
- *args,
- mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None,
+ *args: Any,
+ sets: AgentSetSelector | None = None,
return_results: Literal[False] = False,
inplace: bool = True,
- **kwargs,
+ key_by: KeyBy = "name",
+ **kwargs: Any,
) -> Self: ...
@overload
def do(
self,
method_name: str,
- *args,
- mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None,
+ *args: Any,
+ sets: AgentSetSelector,
return_results: Literal[True],
inplace: bool = True,
- **kwargs,
- ) -> dict[AgentSet, Any]: ...
+ key_by: KeyBy = "name",
+ **kwargs: Any,
+ ) -> dict[str, Any] | dict[int, Any] | dict[type[AgentSet], Any]: ...
def do(
self,
method_name: str,
- *args,
- mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None,
+ *args: Any,
+ sets: AgentSetSelector = None,
return_results: bool = False,
inplace: bool = True,
- **kwargs,
+ key_by: KeyBy = "name",
+ **kwargs: Any,
) -> Self | Any:
obj = self._get_obj(inplace)
- agentsets_masks = obj._get_bool_masks(mask)
+ target_sets = obj._resolve_selector(sets)
+
+ if not target_sets:
+ return {} if return_results else obj
+
+ index_lookup = {id(s): idx for idx, s in enumerate(obj._agentsets)}
+
if return_results:
- return {
- agentset: agentset.do(
+
+ def make_key(agentset: AgentSet) -> Any:
+ if key_by == "name":
+ return agentset.name
+ if key_by == "index":
+ try:
+ return index_lookup[id(agentset)]
+ except KeyError as exc: # pragma: no cover - defensive
+ raise ValueError(
+ "AgentSet not found in registry; cannot key by index."
+ ) from exc
+ if key_by == "type":
+ return type(agentset)
+ return agentset # backward-compatible: key by object
+
+ results: dict[Any, Any] = {}
+ for agentset in target_sets:
+ key = make_key(agentset)
+ if key_by == "type" and key in results:
+ raise ValueError(
+ "Multiple agent sets of the same type were selected; "
+ "use key_by='name' or key_by='index' instead."
+ )
+ results[key] = agentset.do(
method_name,
*args,
- mask=mask,
- return_results=return_results,
- **kwargs,
+ return_results=True,
inplace=inplace,
- )
- for agentset, mask in agentsets_masks.items()
- }
- else:
- obj._agentsets = [
- agentset.do(
- method_name,
- *args,
- mask=mask,
- return_results=return_results,
**kwargs,
- inplace=inplace,
)
- for agentset, mask in agentsets_masks.items()
- ]
- return obj
+ return results
+
+ updates: list[tuple[int, AgentSet]] = []
+ for agentset in target_sets:
+ try:
+ registry_index = index_lookup[id(agentset)]
+ except KeyError as exc: # pragma: no cover - defensive
+ raise ValueError(
+ "AgentSet not found in registry; cannot apply operation."
+ ) from exc
+ updated = agentset.do(
+ method_name,
+ *args,
+ return_results=False,
+ inplace=inplace,
+ **kwargs,
+ )
+ updates.append((registry_index, updated))
- def get(
- self,
- attr_names: str | Collection[str] | None = None,
- mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None,
- ) -> dict[AgentSet, Series] | dict[AgentSet, DataFrame]:
- agentsets_masks = self._get_bool_masks(mask)
- result = {}
-
- # Convert attr_names to list for consistent checking
- if attr_names is None:
- # None means get all data - no column filtering needed
- required_columns = []
- elif isinstance(attr_names, str):
- required_columns = [attr_names]
- else:
- required_columns = list(attr_names)
+ for registry_index, updated in updates:
+ obj._agentsets[registry_index] = updated
+ obj._recompute_ids()
+ return obj
- for agentset, mask in agentsets_masks.items():
- # Fast column existence check - no data processing, just property access
- agentset_columns = agentset.df.columns
+ @overload
+ def get(self, key: int, default: None = ...) -> AgentSet | None: ...
- # Check if all required columns exist in this agent set
- if not required_columns or all(
- col in agentset_columns for col in required_columns
- ):
- result[agentset] = agentset.get(attr_names, mask)
+ @overload
+ def get(self, key: str, default: None = ...) -> AgentSet | None: ...
- return result
+ @overload
+ def get(self, key: type[AgentSet], default: None = ...) -> list[AgentSet]: ...
- def remove(
+ @overload
+ def get(
self,
- agents: AgentSet | Iterable[AgentSet] | IdsLike,
- inplace: bool = True,
- ) -> Self:
- obj = self._get_obj(inplace)
- if agents is None or (isinstance(agents, Iterable) and len(agents) == 0):
- return obj
- if isinstance(agents, AgentSet):
- agents = [agents]
- if isinstance(agents, Iterable) and isinstance(next(iter(agents)), AgentSet):
- # We have to get the index of the original AgentSet because the copy made AgentSets with different hash
- ids = [self._agentsets.index(agentset) for agentset in iter(agents)]
- ids.sort(reverse=True)
- removed_ids = pl.Series(dtype=pl.UInt64)
- for id in ids:
- removed_ids = pl.concat(
- [
- removed_ids,
- pl.Series(obj._agentsets[id]["unique_id"], dtype=pl.UInt64),
- ]
- )
- obj._agentsets.pop(id)
-
- else: # IDsLike
- if isinstance(agents, (int, np.uint64)):
- agents = [agents]
- elif isinstance(agents, DataFrame):
- agents = agents["unique_id"]
- removed_ids = pl.Series(agents, dtype=pl.UInt64)
- deleted = 0
-
- for agentset in obj._agentsets:
- initial_len = len(agentset)
- agentset._discard(removed_ids)
- deleted += initial_len - len(agentset)
- if deleted == len(removed_ids):
- break
- if deleted < len(removed_ids): # TODO: fix type hint
- raise KeyError(
- "There exist some IDs which are not present in any agentset"
- )
- try:
- obj.space.remove_agents(removed_ids, inplace=True)
- except ValueError:
- pass
- obj._ids = obj._ids.filter(obj._ids.is_in(removed_ids).not_())
- return obj
+ key: int | str | type[AgentSet],
+ default: AgentSet | list[AgentSet] | None,
+ ) -> AgentSet | list[AgentSet] | None: ...
- def select(
+ def get(
self,
- mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None,
- filter_func: Callable[[AgentSet], AgentMask] | None = None,
- n: int | None = None,
- inplace: bool = True,
- negate: bool = False,
- ) -> Self:
- obj = self._get_obj(inplace)
- agentsets_masks = obj._get_bool_masks(mask)
- if n is not None:
- n = n // len(agentsets_masks)
- obj._agentsets = [
- agentset.select(
- mask=mask, filter_func=filter_func, n=n, negate=negate, inplace=inplace
- )
- for agentset, mask in agentsets_masks.items()
- ]
- return obj
+ key: int | str | type[AgentSet],
+ default: AgentSet | list[AgentSet] | None = None,
+ ) -> AgentSet | list[AgentSet] | None:
+ try:
+ if isinstance(key, int):
+ return self._agentsets[key]
+ if isinstance(key, str):
+ for s in self._agentsets:
+ if s.name == key:
+ return s
+ return default
+ if isinstance(key, type) and issubclass(key, AgentSet):
+ return [s for s in self._agentsets if isinstance(s, key)]
+ except (IndexError, KeyError, TypeError):
+ return default
+ return default
- def set(
+ def remove(
self,
- attr_names: str | dict[AgentSet, Any] | Collection[str],
- values: Any | None = None,
- mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None,
+ sets: AgentSetSelector,
inplace: bool = True,
) -> Self:
obj = self._get_obj(inplace)
- agentsets_masks = obj._get_bool_masks(mask)
- if isinstance(attr_names, dict):
- for agentset, values in attr_names.items():
- if not inplace:
- # We have to get the index of the original AgentSet because the copy made AgentSets with different hash
- id = self._agentsets.index(agentset)
- agentset = obj._agentsets[id]
- agentset.set(
- attr_names=values, mask=agentsets_masks[agentset], inplace=True
- )
- else:
- obj._agentsets = [
- agentset.set(
- attr_names=attr_names, values=values, mask=mask, inplace=True
- )
- for agentset, mask in agentsets_masks.items()
- ]
+ # Normalize to a list of AgentSet instances using _resolve_selector
+ selected = obj._resolve_selector(sets) # type: ignore[arg-type]
+ # Remove in reverse positional order
+ indices = [i for i, s in enumerate(obj._agentsets) if s in selected]
+ indices.sort(reverse=True)
+ for idx in indices:
+ obj._agentsets.pop(idx)
+ # Recompute ids cache
+ obj._recompute_ids()
return obj
- def shuffle(self, inplace: bool = True) -> Self:
+ def shuffle(self, inplace: bool = False) -> Self:
obj = self._get_obj(inplace)
obj._agentsets = [agentset.shuffle(inplace=True) for agentset in obj._agentsets]
return obj
@@ -343,7 +543,7 @@ def sort(
by: str | Sequence[str],
ascending: bool | Sequence[bool] = True,
inplace: bool = True,
- **kwargs,
+ **kwargs: Any,
) -> Self:
obj = self._get_obj(inplace)
obj._agentsets = [
@@ -352,23 +552,6 @@ def sort(
]
return obj
- def step(self, inplace: bool = True) -> Self:
- """Advance the state of the agents in the AgentSetRegistry by one step.
-
- Parameters
- ----------
- inplace : bool, optional
- Whether to update the AgentSetRegistry in place, by default True
-
- Returns
- -------
- Self
- """
- obj = self._get_obj(inplace)
- for agentset in obj._agentsets:
- agentset.step()
- return obj
-
def _check_ids_presence(self, other: list[AgentSet]) -> pl.DataFrame:
"""Check if the IDs of the agents to be added are unique.
@@ -425,17 +608,54 @@ def _check_agentsets_presence(self, other: list[AgentSet]) -> pl.Series:
[agentset in other_set for agentset in self._agentsets], dtype=pl.Boolean
)
- def _get_bool_masks(
- self,
- mask: (AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask]) = None,
- ) -> dict[AgentSet, BoolSeries]:
- return_dictionary = {}
- if not isinstance(mask, dict):
- # No need to convert numpy integers - let polars handle them directly
- mask = {agentset: mask for agentset in self._agentsets}
- for agentset, mask_value in mask.items():
- return_dictionary[agentset] = agentset._get_bool_mask(mask_value)
- return return_dictionary
+ def _recompute_ids(self) -> None:
+ """Rebuild the registry-level `unique_id` cache from current AgentSets.
+
+ Ensures `self._ids` stays a `pl.UInt64` Series and empty when no sets.
+ """
+ if self._agentsets:
+ cols = [pl.Series(s["unique_id"]) for s in self._agentsets]
+ self._ids = (
+ pl.concat(cols)
+ if cols
+ else pl.Series(name="unique_id", dtype=pl.UInt64)
+ )
+ else:
+ self._ids = pl.Series(name="unique_id", dtype=pl.UInt64)
+
+ def _resolve_selector(self, selector: AgentSetSelector = None) -> list[AgentSet]:
+ """Resolve a selector (instance/type/name or collection) to a list of AgentSets."""
+ if selector is None:
+ return list(self._agentsets)
+ # Single instance
+ if isinstance(selector, AgentSet):
+ return [selector] if selector in self._agentsets else []
+ # Single type
+ if isinstance(selector, type) and issubclass(selector, AgentSet):
+ return [s for s in self._agentsets if isinstance(s, selector)]
+ # Single name
+ if isinstance(selector, str):
+ return [s for s in self._agentsets if s.name == selector]
+ # Collection of mixed selectors
+ selected: list[AgentSet] = []
+ for item in selector: # type: ignore[assignment]
+ if isinstance(item, AgentSet):
+ if item in self._agentsets:
+ selected.append(item)
+ elif isinstance(item, type) and issubclass(item, AgentSet):
+ selected.extend([s for s in self._agentsets if isinstance(s, item)])
+ elif isinstance(item, str):
+ selected.extend([s for s in self._agentsets if s.name == item])
+ else:
+ raise TypeError("Unsupported selector element type")
+ # Deduplicate while preserving order
+ seen = set()
+ result = []
+ for s in selected:
+ if s not in seen:
+ seen.add(s)
+ result.append(s)
+ return result
def _return_agentsets_list(
self, agentsets: AgentSet | Iterable[AgentSet]
@@ -452,192 +672,67 @@ def _return_agentsets_list(
"""
return [agentsets] if isinstance(agentsets, AgentSet) else list(agentsets)
- def __add__(self, other: AgentSet | Iterable[AgentSet]) -> Self:
- """Add AgentSets to a new AgentSetRegistry through the + operator.
-
- Parameters
- ----------
- other : AgentSet | Iterable[AgentSet]
- The AgentSets to add.
-
- Returns
- -------
- Self
- A new AgentSetRegistry with the added AgentSets.
- """
- return super().__add__(other)
-
- def __getattr__(self, name: str) -> dict[AgentSet, Any]:
+ def _generate_name(self, base_name: str) -> str:
+ """Generate a unique name for an agent set."""
+ existing_names = [
+ agentset.name for agentset in self._agentsets if agentset.name is not None
+ ]
+ if base_name not in existing_names:
+ return base_name
+ counter = 1
+ candidate = f"{base_name}_{counter}"
+ while candidate in existing_names:
+ counter += 1
+ candidate = f"{base_name}_{counter}"
+ return candidate
+
+ def __getattr__(self, name: str) -> Any | dict[str, Any]:
# Avoids infinite recursion of private attributes
- if __debug__: # Only execute in non-optimized mode
- if name.startswith("_"):
- raise AttributeError(
- f"'{self.__class__.__name__}' object has no attribute '{name}'"
- )
- return {agentset: getattr(agentset, name) for agentset in self._agentsets}
-
- @overload
- def __getitem__(
- self, key: str | tuple[dict[AgentSet, AgentMask], str]
- ) -> dict[AgentSet, Series | pl.Expr]: ...
-
- @overload
- def __getitem__(
- self,
- key: (
- Collection[str]
- | AgnosticAgentMask
- | IdsLike
- | tuple[dict[AgentSet, AgentMask], Collection[str]]
- ),
- ) -> dict[AgentSet, DataFrame]: ...
-
- def __getitem__(
- self,
- key: (
- str
- | Collection[str]
- | AgnosticAgentMask
- | IdsLike
- | tuple[dict[AgentSet, AgentMask], str]
- | tuple[dict[AgentSet, AgentMask], Collection[str]]
- ),
- ) -> dict[AgentSet, Series | pl.Expr] | dict[AgentSet, DataFrame]:
- return super().__getitem__(key)
-
- def __iadd__(self, agents: AgentSet | Iterable[AgentSet]) -> Self:
- """Add AgentSets to the AgentSetRegistry through the += operator.
-
- Parameters
- ----------
- agents : AgentSet | Iterable[AgentSet]
- The AgentSets to add.
-
- Returns
- -------
- Self
- The updated AgentSetRegistry.
- """
- return super().__iadd__(agents)
-
- def __iter__(self) -> Iterator[dict[str, Any]]:
- return (agent for agentset in self._agentsets for agent in iter(agentset))
-
- def __isub__(self, agents: AgentSet | Iterable[AgentSet] | IdsLike) -> Self:
- """Remove AgentSets from the AgentSetRegistry through the -= operator.
+ if name.startswith("_"):
+ raise AttributeError(
+ f"'{self.__class__.__name__}' object has no attribute '{name}'"
+ )
+ # Delegate attribute access to sets; map results by set name
+ return {cast(str, s.name): getattr(s, name) for s in self._agentsets}
- Parameters
- ----------
- agents : AgentSet | Iterable[AgentSet] | IdsLike
- The AgentSets or agent IDs to remove.
-
- Returns
- -------
- Self
- The updated AgentSetRegistry.
- """
- return super().__isub__(agents)
+ def __iter__(self) -> Iterator[AgentSet]:
+ return iter(self._agentsets)
def __len__(self) -> int:
- return sum(len(agentset._df) for agentset in self._agentsets)
+ return len(self._agentsets)
def __repr__(self) -> str:
return "\n".join([repr(agentset) for agentset in self._agentsets])
- def __reversed__(self) -> Iterator:
- return (
- agent
- for agentset in self._agentsets
- for agent in reversed(agentset._backend)
- )
-
- def __setitem__(
- self,
- key: (
- str
- | Collection[str]
- | AgnosticAgentMask
- | IdsLike
- | tuple[dict[AgentSet, AgentMask], str]
- | tuple[dict[AgentSet, AgentMask], Collection[str]]
- ),
- values: Any,
- ) -> None:
- super().__setitem__(key, values)
+ def __reversed__(self) -> Iterator[AgentSet]:
+ return reversed(self._agentsets)
def __str__(self) -> str:
return "\n".join([str(agentset) for agentset in self._agentsets])
- def __sub__(self, agents: AgentSet | Iterable[AgentSet] | IdsLike) -> Self:
- """Remove AgentSets from a new AgentSetRegistry through the - operator.
-
- Parameters
- ----------
- agents : AgentSet | Iterable[AgentSet] | IdsLike
- The AgentSets or agent IDs to remove. Supports NumPy integer types.
-
- Returns
- -------
- Self
- A new AgentSetRegistry with the removed AgentSets.
- """
- return super().__sub__(agents)
-
- @property
- def df(self) -> dict[AgentSet, DataFrame]:
- return {agentset: agentset.df for agentset in self._agentsets}
-
- @df.setter
- def df(self, other: Iterable[AgentSet]) -> None:
- """Set the agents in the AgentSetRegistry.
-
- Parameters
- ----------
- other : Iterable[AgentSet]
- The AgentSets to set.
- """
- self._agentsets = list(other)
-
@property
- def active_agents(self) -> dict[AgentSet, DataFrame]:
- return {agentset: agentset.active_agents for agentset in self._agentsets}
-
- @active_agents.setter
- def active_agents(
- self, agents: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask]
- ) -> None:
- self.select(agents, inplace=True)
-
- @property
- def agentsets_by_type(self) -> dict[type[AgentSet], Self]:
- """Get the agent sets in the AgentSetRegistry grouped by type.
-
- Returns
- -------
- dict[type[AgentSet], Self]
- A dictionary mapping agent set types to the corresponding AgentSetRegistry.
- """
-
- def copy_without_agentsets() -> Self:
- return self.copy(deep=False, skip=["_agentsets"])
+ def ids(self) -> pl.Series:
+ """Public view of all agent unique_id values across contained sets."""
+ return self._ids
- dictionary = defaultdict(copy_without_agentsets)
-
- for agentset in self._agentsets:
- agents_df = dictionary[agentset.__class__]
- agents_df._agentsets = []
- agents_df._agentsets = agents_df._agentsets + [agentset]
- dictionary[agentset.__class__] = agents_df
- return dictionary
-
- @property
- def inactive_agents(self) -> dict[AgentSet, DataFrame]:
- return {agentset: agentset.inactive_agents for agentset in self._agentsets}
+ @overload
+ def __getitem__(self, key: int) -> AgentSet: ...
- @property
- def index(self) -> dict[AgentSet, Index]:
- return {agentset: agentset.index for agentset in self._agentsets}
+ @overload
+ def __getitem__(self, key: str) -> AgentSet: ...
- @property
- def pos(self) -> dict[AgentSet, DataFrame]:
- return {agentset: agentset.pos for agentset in self._agentsets}
+ @overload
+ def __getitem__(self, key: type[AgentSet]) -> list[AgentSet]: ...
+
+ def __getitem__(self, key: int | str | type[AgentSet]) -> AgentSet | list[AgentSet]:
+ """Retrieve AgentSet(s) by index, name, or type."""
+ if isinstance(key, int):
+ return self._agentsets[key]
+ if isinstance(key, str):
+ for s in self._agentsets:
+ if s.name == key:
+ return s
+ raise KeyError(f"Agent set '{key}' not found")
+ if isinstance(key, type) and issubclass(key, AgentSet):
+ return [s for s in self._agentsets if isinstance(s, key)]
+ raise TypeError("Key must be int, str (name), or AgentSet type")
diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py
index 2b50c76d..f7db338e 100644
--- a/mesa_frames/concrete/datacollector.py
+++ b/mesa_frames/concrete/datacollector.py
@@ -177,13 +177,94 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
Constructs a LazyFrame with one column per reporter and
includes `step` and `seed` metadata. Appends it to internal storage.
"""
- agent_data_dict = {}
+
+ def _is_str_collection(x: Any) -> bool:
+ try:
+ from collections.abc import Collection
+
+ if isinstance(x, str):
+ return False
+ return isinstance(x, Collection) and all(isinstance(i, str) for i in x)
+ except Exception:
+ return False
+
+ agent_data_dict: dict[str, pl.Series] = {}
+
for col_name, reporter in self._agent_reporters.items():
- if isinstance(reporter, str):
- for k, v in self._model.sets[reporter].items():
- agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v
- else:
- agent_data_dict[col_name] = reporter(self._model)
+ # 1) String or collection[str]: shorthand to fetch columns
+ if isinstance(reporter, str) or _is_str_collection(reporter):
+ # If a single string, fetch that attribute from each set
+ if isinstance(reporter, str):
+ values_by_set = getattr(self._model.sets, reporter)
+ for set_name, series in values_by_set.items():
+ agent_data_dict[f"{col_name}_{set_name}"] = series
+ else:
+ # Collection of strings: pull multiple columns from each set via set.get([...])
+ for set_name, aset in self._model.sets.items(): # type: ignore[attr-defined]
+ df = aset.get(list(reporter)) # DataFrame of requested attrs
+ if isinstance(df, pl.Series):
+ # Defensive, though get(list) should yield DataFrame
+ agent_data_dict[f"{col_name}_{df.name}_{set_name}"] = df
+ else:
+ for subcol in df.columns:
+ agent_data_dict[f"{col_name}_{subcol}_{set_name}"] = df[
+ subcol
+ ]
+ continue
+
+ # 2) Callables: prefer registry-level; then set-level
+ if callable(reporter):
+ called = False
+ # Try registry-level callable: reporter(AgentSetRegistry)
+ try:
+ reg_result = reporter(self._model.sets)
+ # Accept Series | DataFrame | dict[str, Series|DataFrame]
+ if isinstance(reg_result, pl.Series):
+ agent_data_dict[col_name] = reg_result
+ called = True
+ elif isinstance(reg_result, pl.DataFrame):
+ for subcol in reg_result.columns:
+ agent_data_dict[f"{col_name}_{subcol}"] = reg_result[subcol]
+ called = True
+ elif isinstance(reg_result, dict):
+ for key, val in reg_result.items():
+ if isinstance(val, pl.Series):
+ agent_data_dict[f"{col_name}_{key}"] = val
+ elif isinstance(val, pl.DataFrame):
+ for subcol in val.columns:
+ agent_data_dict[f"{col_name}_{key}_{subcol}"] = val[
+ subcol
+ ]
+ else:
+ raise TypeError(
+ "Registry-level reporter dict values must be Series or DataFrame"
+ )
+ called = True
+ except TypeError:
+ called = False
+
+ if not called:
+ # Fallback: set-level callable, run once per set and suffix by set name
+ for set_name, aset in self._model.sets.items(): # type: ignore[attr-defined]
+ set_result = reporter(aset)
+ if isinstance(set_result, pl.Series):
+ agent_data_dict[f"{col_name}_{set_name}"] = set_result
+ elif isinstance(set_result, pl.DataFrame):
+ for subcol in set_result.columns:
+ agent_data_dict[f"{col_name}_{subcol}_{set_name}"] = (
+ set_result[subcol]
+ )
+ else:
+ raise TypeError(
+ "Set-level reporter must return polars Series or DataFrame"
+ )
+ continue
+
+ # Unknown type
+ raise TypeError(
+ "agent_reporters values must be str, collection[str], or callable"
+ )
+
agent_lazy_frame = pl.LazyFrame(agent_data_dict)
agent_lazy_frame = agent_lazy_frame.with_columns(
[
@@ -441,7 +522,10 @@ def _validate_reporter_table(self, conn: connection, table_name: str):
)
def _validate_reporter_table_columns(
- self, conn: connection, table_name: str, reporter: dict[str, Callable | str]
+ self,
+ conn: connection,
+ table_name: str,
+ reporter: dict[str, Callable | str],
):
"""
Check if the expected columns are present in a given PostgreSQL table.
@@ -460,15 +544,35 @@ def _validate_reporter_table_columns(
ValueError
If any expected columns are missing from the table.
"""
- expected_columns = set()
- for col_name, required_column in reporter.items():
- if isinstance(required_column, str):
- for k, v in self._model.sets[required_column].items():
- expected_columns.add(
- (col_name + "_" + str(k.__class__.__name__)).lower()
- )
- else:
- expected_columns.add(col_name.lower())
+
+ def _is_str_collection(x: Any) -> bool:
+ try:
+ from collections.abc import Collection
+
+ if isinstance(x, str):
+ return False
+ return isinstance(x, Collection) and all(isinstance(i, str) for i in x)
+ except Exception:
+ return False
+
+ expected_columns: set[str] = set()
+ for col_name, req in reporter.items():
+ # Strings → one column per set with suffix
+ if isinstance(req, str):
+ for set_name, _ in self._model.sets.items(): # type: ignore[attr-defined]
+ expected_columns.add(f"{col_name}_{set_name}".lower())
+ continue
+
+ # Collection[str] → one column per attribute per set
+ if _is_str_collection(req):
+ for set_name, _ in self._model.sets.items(): # type: ignore[attr-defined]
+ for subcol in req: # type: ignore[assignment]
+ expected_columns.add(f"{col_name}_{subcol}_{set_name}".lower())
+ continue
+
+ # Callable: conservative default → require 'col_name' to exist
+ # We cannot know the dynamic column explosion without running model code safely here.
+ expected_columns.add(col_name.lower())
query = f"""
SELECT column_name
diff --git a/mesa_frames/concrete/model.py b/mesa_frames/concrete/model.py
index a10ce240..b91db207 100644
--- a/mesa_frames/concrete/model.py
+++ b/mesa_frames/concrete/model.py
@@ -64,7 +64,7 @@ class Model:
running: bool
_seed: int | Sequence[int]
_sets: AgentSetRegistry # Where the agent sets are stored
- _space: Space | None # This will be a MultiSpaceDF object
+ _space: Space | None # This will be a Space object
def __init__(self, seed: int | Sequence[int] | None = None) -> None:
"""Create a new model.
@@ -99,24 +99,6 @@ def steps(self) -> int:
"""Get the current step count."""
return self._steps
- def get_sets_of_type(self, agent_type: type) -> AgentSet:
- """Retrieve the AgentSet of a specified type.
-
- Parameters
- ----------
- agent_type : type
- The type of AgentSet to retrieve.
-
- Returns
- -------
- AgentSet
- The AgentSet of the specified type.
- """
- for agentset in self._sets._agentsets:
- if isinstance(agentset, agent_type):
- return agentset
- raise ValueError(f"No agent sets of type {agent_type} found in the model.")
-
def reset_randomizer(self, seed: int | Sequence[int] | None) -> None:
"""Reset the model random number generator.
@@ -144,7 +126,8 @@ def step(self) -> None:
The default method calls the step() method of all agents. Overload as needed.
"""
- self.sets.step()
+ # Invoke step on all contained AgentSets via the public registry API
+ self.sets.do("step")
@property
def steps(self) -> int:
@@ -187,17 +170,6 @@ def sets(self, sets: AgentSetRegistry) -> None:
self._sets = sets
- @property
- def set_types(self) -> list[type]:
- """Get a list of different agent set types present in the model.
-
- Returns
- -------
- list[type]
- A list of the different agent set types present in the model.
- """
- return [agent.__class__ for agent in self._sets._agentsets]
-
@property
def space(self) -> Space:
"""Get the space object associated with the model.
diff --git a/mesa_frames/types_.py b/mesa_frames/types_.py
index 05ab1b3f..5873e034 100644
--- a/mesa_frames/types_.py
+++ b/mesa_frames/types_.py
@@ -1,15 +1,17 @@
"""Type aliases for the mesa_frames package."""
from __future__ import annotations
-from collections.abc import Collection, Sequence
-from datetime import date, datetime, time, timedelta
-from typing import Literal, Annotated, Union, Any
-from collections.abc import Mapping
-from beartype.vale import IsEqual
+
import math
+from collections.abc import Collection, Mapping, Sequence
+from datetime import date, datetime, time, timedelta
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
+
+import numpy as np
import polars as pl
+from beartype.vale import IsEqual
from numpy import ndarray
-import numpy as np
+
# import geopolars as gpl # TODO: Uncomment when geopolars is available
###----- Optional Types -----###
@@ -83,6 +85,46 @@
ArrayLike = ndarray | Series | Sequence
Infinity = Annotated[float, IsEqual[math.inf]] # Only accepts math.inf
+# Common option types
+KeyBy = Literal["name", "index", "type"]
+
+# Selectors for choosing AgentSets at the registry level
+# Abstract (for abstract layer APIs)
+if TYPE_CHECKING:
+ from mesa_frames.abstract.agentset import AbstractAgentSet as _AAS
+
+ AbstractAgentSetSelector = (
+ _AAS | type[_AAS] | str | Collection[_AAS | type[_AAS] | str] | None
+ )
+else:
+ AbstractAgentSetSelector = Any # runtime fallback to avoid import cycles
+
+# Concrete (for concrete layer APIs)
+if TYPE_CHECKING:
+ from mesa_frames.concrete.agentset import AgentSet as _CAS
+
+ AgentSetSelector = (
+ _CAS | type[_CAS] | str | Collection[_CAS | type[_CAS] | str] | None
+ )
+else:
+ AgentSetSelector = Any # runtime fallback to avoid import cycles
+
+__all__ = [
+ # common
+ "DataFrame",
+ "Series",
+ "Index",
+ "BoolSeries",
+ "Mask",
+ "AgentMask",
+ "IdsLike",
+ "ArrayLike",
+ "KeyBy",
+ # selectors
+ "AbstractAgentSetSelector",
+ "AgentSetSelector",
+]
+
###----- Time ------###
TimeT = float | int
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..fd84a7ac
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,11 @@
+"""Conftest for tests.
+
+Ensure beartype runtime checking is enabled before importing the package.
+
+This module sets MESA_FRAMES_RUNTIME_TYPECHECKING=1 at import time so tests that
+assert beartype failures at import or construct time behave deterministically.
+"""
+
+import os
+
+os.environ.setdefault("MESA_FRAMES_RUNTIME_TYPECHECKING", "1")
diff --git a/tests/test_agents.py b/tests/test_agents.py
deleted file mode 100644
index f43d94f6..00000000
--- a/tests/test_agents.py
+++ /dev/null
@@ -1,1036 +0,0 @@
-from copy import copy, deepcopy
-
-import polars as pl
-import pytest
-
-from mesa_frames import AgentSetRegistry, Model
-from mesa_frames import AgentSet
-from mesa_frames.types_ import AgentMask
-from tests.test_agentset import (
- ExampleAgentSet,
- ExampleAgentSetNoWealth,
- fix1_AgentSet_no_wealth,
- fix1_AgentSet,
- fix2_AgentSet,
- fix3_AgentSet,
-)
-
-
-@pytest.fixture
-def fix_AgentSetRegistry(
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
-) -> AgentSetRegistry:
- model = Model()
- agents = AgentSetRegistry(model)
- agents.add([fix1_AgentSet, fix2_AgentSet])
- return agents
-
-
-class Test_AgentSetRegistry:
- def test___init__(self):
- model = Model()
- agents = AgentSetRegistry(model)
- assert agents.model == model
- assert isinstance(agents._agentsets, list)
- assert len(agents._agentsets) == 0
- assert isinstance(agents._ids, pl.Series)
- assert agents._ids.is_empty()
- assert agents._ids.name == "unique_id"
-
- def test_add(
- self,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- ):
- model = Model()
- agents = AgentSetRegistry(model)
- agentset_polars1 = fix1_AgentSet
- agentset_polars2 = fix2_AgentSet
-
- # Test with a single AgentSet
- result = agents.add(agentset_polars1, inplace=False)
- assert result._agentsets[0] is agentset_polars1
- assert result._ids.to_list() == agentset_polars1._df["unique_id"].to_list()
-
- # Test with a list of AgentSets
- result = agents.add([agentset_polars1, agentset_polars2], inplace=True)
- assert result._agentsets[0] is agentset_polars1
- assert result._agentsets[1] is agentset_polars2
- assert (
- result._ids.to_list()
- == agentset_polars1._df["unique_id"].to_list()
- + agentset_polars2._df["unique_id"].to_list()
- )
-
- # Test if adding the same AgentSet raises ValueError
- with pytest.raises(ValueError):
- agents.add(agentset_polars1, inplace=False)
-
- def test_contains(
- self,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- fix3_AgentSet: ExampleAgentSet,
- fix_AgentSetRegistry: AgentSetRegistry,
- ):
- agents = fix_AgentSetRegistry
- agentset_polars1 = agents._agentsets[0]
-
- # Test with an AgentSet
- assert agents.contains(agentset_polars1)
- assert agents.contains(fix1_AgentSet)
- assert agents.contains(fix2_AgentSet)
-
- # Test with an AgentSet not present
- assert not agents.contains(fix3_AgentSet)
-
- # Test with an iterable of AgentSets
- assert agents.contains([agentset_polars1, fix3_AgentSet]).to_list() == [
- True,
- False,
- ]
-
- # Test with single id
- assert agents.contains(agentset_polars1["unique_id"][0])
-
- # Test with a list of ids
- assert agents.contains([agentset_polars1["unique_id"][0], 0]).to_list() == [
- True,
- False,
- ]
-
- def test_copy(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
- agents.test_list = [[1, 2, 3]]
-
- # Test with deep=False
- agents2 = agents.copy(deep=False)
- agents2.test_list[0].append(4)
- assert agents.test_list[0][-1] == agents2.test_list[0][-1]
- assert agents.model == agents2.model
- assert agents._agentsets[0] == agents2._agentsets[0]
- assert (agents._ids == agents2._ids).all()
-
- # Test with deep=True
- agents2 = fix_AgentSetRegistry.copy(deep=True)
- agents2.test_list[0].append(4)
- assert agents.test_list[-1] != agents2.test_list[-1]
- assert agents.model == agents2.model
- assert agents._agentsets[0] != agents2._agentsets[0]
- assert (agents._ids == agents2._ids).all()
-
- def test_discard(
- self, fix_AgentSetRegistry: AgentSetRegistry, fix2_AgentSet: ExampleAgentSet
- ):
- agents = fix_AgentSetRegistry
- # Test with a single AgentSet
- agentset_polars2 = agents._agentsets[1]
- result = agents.discard(agents._agentsets[0], inplace=False)
- assert isinstance(result._agentsets[0], ExampleAgentSet)
- assert len(result._agentsets) == 1
-
- # Test with a list of AgentSets
- result = agents.discard(agents._agentsets.copy(), inplace=False)
- assert len(result._agentsets) == 0
-
- # Test with IDs
- ids = [
- agents._agentsets[0]._df["unique_id"][0],
- agents._agentsets[1]._df["unique_id"][0],
- ]
- agentset_polars1 = agents._agentsets[0]
- agentset_polars2 = agents._agentsets[1]
- result = agents.discard(ids, inplace=False)
- assert (
- result._agentsets[0]["unique_id"][0]
- == agentset_polars1._df.select("unique_id").row(1)[0]
- )
- assert (
- result._agentsets[1].df["unique_id"][0]
- == agentset_polars2._df["unique_id"][1]
- )
-
- # Test if removing an AgentSet not present raises ValueError
- result = agents.discard(fix2_AgentSet, inplace=False)
-
- # Test if removing an ID not present raises KeyError
- assert 0 not in agents._ids
- result = agents.discard(0, inplace=False)
-
- def test_do(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
-
- expected_result_0 = agents._agentsets[0].df["wealth"]
- expected_result_0 += 1
-
- expected_result_1 = agents._agentsets[1].df["wealth"]
- expected_result_1 += 1
-
- # Test with no return_results, no mask, inplace
- agents.do("add_wealth", 1)
- assert (
- agents._agentsets[0].df["wealth"].to_list() == expected_result_0.to_list()
- )
- assert (
- agents._agentsets[1].df["wealth"].to_list() == expected_result_1.to_list()
- )
-
- # Test with return_results=True, no mask, inplace
- expected_result_0 = agents._agentsets[0].df["wealth"]
- expected_result_0 += 1
-
- expected_result_1 = agents._agentsets[1].df["wealth"]
- expected_result_1 += 1
- assert agents.do("add_wealth", 1, return_results=True) == {
- agents._agentsets[0]: None,
- agents._agentsets[1]: None,
- }
- assert (
- agents._agentsets[0].df["wealth"].to_list() == expected_result_0.to_list()
- )
- assert (
- agents._agentsets[1].df["wealth"].to_list() == expected_result_1.to_list()
- )
-
- # Test with a mask, inplace
- mask0 = agents._agentsets[0].df["wealth"] > 10 # No agent should be selected
- mask1 = agents._agentsets[1].df["wealth"] > 10 # All agents should be selected
- mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1}
-
- expected_result_0 = agents._agentsets[0].df["wealth"]
- expected_result_1 = agents._agentsets[1].df["wealth"]
- expected_result_1 += 1
-
- agents.do("add_wealth", 1, mask=mask_dictionary)
- assert (
- agents._agentsets[0].df["wealth"].to_list() == expected_result_0.to_list()
- )
- assert (
- agents._agentsets[1].df["wealth"].to_list() == expected_result_1.to_list()
- )
-
- def test_get(
- self,
- fix_AgentSetRegistry: AgentSetRegistry,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- fix1_AgentSet_no_wealth: ExampleAgentSetNoWealth,
- ):
- agents = fix_AgentSetRegistry
-
- # Test with a single attribute
- assert (
- agents.get("wealth")[fix1_AgentSet].to_list()
- == fix1_AgentSet._df["wealth"].to_list()
- )
- assert (
- agents.get("wealth")[fix2_AgentSet].to_list()
- == fix2_AgentSet._df["wealth"].to_list()
- )
-
- # Test with a list of attributes
- result = agents.get(["wealth", "age"])
- assert result[fix1_AgentSet].columns == ["wealth", "age"]
- assert (
- result[fix1_AgentSet]["wealth"].to_list()
- == fix1_AgentSet._df["wealth"].to_list()
- )
- assert (
- result[fix1_AgentSet]["age"].to_list() == fix1_AgentSet._df["age"].to_list()
- )
-
- assert result[fix2_AgentSet].columns == ["wealth", "age"]
- assert (
- result[fix2_AgentSet]["wealth"].to_list()
- == fix2_AgentSet._df["wealth"].to_list()
- )
- assert (
- result[fix2_AgentSet]["age"].to_list() == fix2_AgentSet._df["age"].to_list()
- )
-
- # Test with a single attribute and a mask
- mask0 = fix1_AgentSet._df["wealth"] > fix1_AgentSet._df["wealth"][0]
- mask1 = fix2_AgentSet._df["wealth"] > fix2_AgentSet._df["wealth"][0]
- mask_dictionary = {fix1_AgentSet: mask0, fix2_AgentSet: mask1}
- result = agents.get("wealth", mask=mask_dictionary)
- assert (
- result[fix1_AgentSet].to_list() == fix1_AgentSet._df["wealth"].to_list()[1:]
- )
- assert (
- result[fix2_AgentSet].to_list() == fix2_AgentSet._df["wealth"].to_list()[1:]
- )
-
- # Test heterogeneous agent sets (different columns)
- # This tests the fix for the bug where agents_df["column"] would raise
- # ColumnNotFoundError when some agent sets didn't have that column.
-
- # Create a new AgentSetRegistry with heterogeneous agent sets
- model = Model()
- hetero_agents = AgentSetRegistry(model)
- hetero_agents.add([fix1_AgentSet, fix1_AgentSet_no_wealth])
-
- # Test 1: Access column that exists in only one agent set
- result_wealth = hetero_agents.get("wealth")
- assert len(result_wealth) == 1, (
- "Should only return agent sets that have 'wealth'"
- )
- assert fix1_AgentSet in result_wealth, (
- "Should include the agent set with wealth"
- )
- assert fix1_AgentSet_no_wealth not in result_wealth, (
- "Should not include agent set without wealth"
- )
- assert result_wealth[fix1_AgentSet].to_list() == [1, 2, 3, 4]
-
- # Test 2: Access column that exists in all agent sets
- result_age = hetero_agents.get("age")
- assert len(result_age) == 2, "Should return both agent sets that have 'age'"
- assert fix1_AgentSet in result_age
- assert fix1_AgentSet_no_wealth in result_age
- assert result_age[fix1_AgentSet].to_list() == [10, 20, 30, 40]
- assert result_age[fix1_AgentSet_no_wealth].to_list() == [1, 2, 3, 4]
-
- # Test 3: Access column that exists in no agent sets
- result_nonexistent = hetero_agents.get("nonexistent_column")
- assert len(result_nonexistent) == 0, (
- "Should return empty dict for non-existent column"
- )
-
- # Test 4: Access multiple columns (mixed availability)
- result_multi = hetero_agents.get(["wealth", "age"])
- assert len(result_multi) == 1, (
- "Should only include agent sets that have ALL requested columns"
- )
- assert fix1_AgentSet in result_multi
- assert fix1_AgentSet_no_wealth not in result_multi
- assert result_multi[fix1_AgentSet].columns == ["wealth", "age"]
-
- # Test 5: Access multiple columns where some exist in different sets
- result_mixed = hetero_agents.get(["age", "income"])
- assert len(result_mixed) == 1, (
- "Should only include agent set that has both 'age' and 'income'"
- )
- assert fix1_AgentSet_no_wealth in result_mixed
- assert fix1_AgentSet not in result_mixed
-
- # Test 6: Test via __getitem__ syntax (the original bug report case)
- wealth_via_getitem = hetero_agents["wealth"]
- assert len(wealth_via_getitem) == 1
- assert fix1_AgentSet in wealth_via_getitem
- assert wealth_via_getitem[fix1_AgentSet].to_list() == [1, 2, 3, 4]
-
- # Test 7: Test get(None) - should return all columns for all agent sets
- result_none = hetero_agents.get(None)
- assert len(result_none) == 2, (
- "Should return both agent sets when attr_names=None"
- )
- assert fix1_AgentSet in result_none
- assert fix1_AgentSet_no_wealth in result_none
-
- # Verify each agent set returns all its columns (excluding unique_id)
- wealth_set_result = result_none[fix1_AgentSet]
- assert isinstance(wealth_set_result, pl.DataFrame), (
- "Should return DataFrame when attr_names=None"
- )
- expected_wealth_cols = {"wealth", "age"} # unique_id should be excluded
- assert set(wealth_set_result.columns) == expected_wealth_cols
-
- no_wealth_set_result = result_none[fix1_AgentSet_no_wealth]
- assert isinstance(no_wealth_set_result, pl.DataFrame), (
- "Should return DataFrame when attr_names=None"
- )
- expected_no_wealth_cols = {"income", "age"} # unique_id should be excluded
- assert set(no_wealth_set_result.columns) == expected_no_wealth_cols
-
- def test_remove(
- self,
- fix_AgentSetRegistry: AgentSetRegistry,
- fix3_AgentSet: ExampleAgentSet,
- ):
- agents = fix_AgentSetRegistry
-
- # Test with a single AgentSet
- agentset_polars = agents._agentsets[1]
- result = agents.remove(agents._agentsets[0], inplace=False)
- assert isinstance(result._agentsets[0], ExampleAgentSet)
- assert len(result._agentsets) == 1
-
- # Test with a list of AgentSets
- result = agents.remove(agents._agentsets.copy(), inplace=False)
- assert len(result._agentsets) == 0
-
- # Test with IDs
- ids = [
- agents._agentsets[0]._df["unique_id"][0],
- agents._agentsets[1]._df["unique_id"][0],
- ]
- agentset_polars1 = agents._agentsets[0]
- agentset_polars2 = agents._agentsets[1]
- result = agents.remove(ids, inplace=False)
- assert (
- result._agentsets[0]["unique_id"][0]
- == agentset_polars1._df.select("unique_id").row(1)[0]
- )
- assert (
- result._agentsets[1].df["unique_id"][0]
- == agentset_polars2._df["unique_id"][1]
- )
-
- # Test if removing an AgentSet not present raises ValueError
- with pytest.raises(ValueError):
- result = agents.remove(fix3_AgentSet, inplace=False)
-
- # Test if removing an ID not present raises KeyError
- assert 0 not in agents._ids
- with pytest.raises(KeyError):
- result = agents.remove(0, inplace=False)
-
- def test_select(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
-
- # Test with default arguments. Should select all agents
- selected = agents.select(inplace=False)
- active_agents_dict = selected.active_agents
- agents_dict = selected.df
- assert active_agents_dict.keys() == agents_dict.keys()
- # Using assert to compare all DataFrames in the dictionaries
-
- assert (
- list(active_agents_dict.values())[0].rows()
- == list(agents_dict.values())[0].rows()
- )
-
- assert all(
- series.all()
- for series in (
- list(active_agents_dict.values())[1] == list(agents_dict.values())[1]
- )
- )
-
- # Test with a mask
- mask0 = pl.Series("mask", [True, False, True, True], dtype=pl.Boolean)
- mask1 = pl.Series("mask", [True, False, True, True], dtype=pl.Boolean)
- mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1}
- selected = agents.select(mask_dictionary, inplace=False)
- assert (
- selected.active_agents[selected._agentsets[0]]["wealth"].to_list()[0]
- == agents._agentsets[0]["wealth"].to_list()[0]
- )
- assert (
- selected.active_agents[selected._agentsets[0]]["wealth"].to_list()[-1]
- == agents._agentsets[0]["wealth"].to_list()[-1]
- )
-
- assert (
- selected.active_agents[selected._agentsets[1]]["wealth"].to_list()[0]
- == agents._agentsets[1]["wealth"].to_list()[0]
- )
- assert (
- selected.active_agents[selected._agentsets[1]]["wealth"].to_list()[-1]
- == agents._agentsets[1]["wealth"].to_list()[-1]
- )
-
- # Test with filter_func
-
- def filter_func(agentset: AgentSet) -> pl.Series:
- return agentset.df["wealth"] > agentset.df["wealth"].to_list()[0]
-
- selected = agents.select(filter_func=filter_func, inplace=False)
- assert (
- selected.active_agents[selected._agentsets[0]]["wealth"].to_list()
- == agents._agentsets[0]["wealth"].to_list()[1:]
- )
- assert (
- selected.active_agents[selected._agentsets[1]]["wealth"].to_list()
- == agents._agentsets[1]["wealth"].to_list()[1:]
- )
-
- # Test with n
- selected = agents.select(n=3, inplace=False)
- assert sum(len(df) for df in selected.active_agents.values()) in [2, 3]
-
- # Test with n, filter_func and mask
- selected = agents.select(
- mask_dictionary, filter_func=filter_func, n=2, inplace=False
- )
- assert any(
- el in selected.active_agents[selected._agentsets[0]]["wealth"].to_list()
- for el in agents.active_agents[agents._agentsets[0]]["wealth"].to_list()[
- 2:4
- ]
- )
-
- assert any(
- el in selected.active_agents[selected._agentsets[1]]["wealth"].to_list()
- for el in agents.active_agents[agents._agentsets[1]]["wealth"].to_list()[
- 2:4
- ]
- )
-
- def test_set(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
-
- # Test with a single attribute
- result = agents.set("wealth", 0, inplace=False)
- assert result._agentsets[0].df["wealth"].to_list() == [0] * len(
- agents._agentsets[0]
- )
- assert result._agentsets[1].df["wealth"].to_list() == [0] * len(
- agents._agentsets[1]
- )
-
- # Test with a list of attributes
- agents.set(["wealth", "age"], 1, inplace=True)
- assert agents._agentsets[0].df["wealth"].to_list() == [1] * len(
- agents._agentsets[0]
- )
- assert agents._agentsets[0].df["age"].to_list() == [1] * len(
- agents._agentsets[0]
- )
-
- # Test with a single attribute and a mask
- mask0 = pl.Series(
- "mask", [True] + [False] * (len(agents._agentsets[0]) - 1), dtype=pl.Boolean
- )
- mask1 = pl.Series(
- "mask", [True] + [False] * (len(agents._agentsets[1]) - 1), dtype=pl.Boolean
- )
- mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1}
- result = agents.set("wealth", 0, mask=mask_dictionary, inplace=False)
- assert result._agentsets[0].df["wealth"].to_list() == [0] + [1] * (
- len(agents._agentsets[0]) - 1
- )
- assert result._agentsets[1].df["wealth"].to_list() == [0] + [1] * (
- len(agents._agentsets[1]) - 1
- )
-
- # Test with a dictionary
- agents.set(
- {agents._agentsets[0]: {"wealth": 0}, agents._agentsets[1]: {"wealth": 1}},
- inplace=True,
- )
- assert agents._agentsets[0].df["wealth"].to_list() == [0] * len(
- agents._agentsets[0]
- )
- assert agents._agentsets[1].df["wealth"].to_list() == [1] * len(
- agents._agentsets[1]
- )
-
- def test_shuffle(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
- for _ in range(100):
- original_order_0 = agents._agentsets[0].df["unique_id"].to_list()
- original_order_1 = agents._agentsets[1].df["unique_id"].to_list()
- agents.shuffle(inplace=True)
- if (
- original_order_0 != agents._agentsets[0].df["unique_id"].to_list()
- and original_order_1 != agents._agentsets[1].df["unique_id"].to_list()
- ):
- return
- assert False
-
- def test_sort(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
- agents.sort("wealth", ascending=False, inplace=True)
- assert pl.Series(agents._agentsets[0].df["wealth"]).is_sorted(descending=True)
- assert pl.Series(agents._agentsets[1].df["wealth"]).is_sorted(descending=True)
-
- def test_step(
- self,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- fix_AgentSetRegistry: AgentSetRegistry,
- ):
- previous_wealth_0 = fix1_AgentSet._df["wealth"].clone()
- previous_wealth_1 = fix2_AgentSet._df["wealth"].clone()
-
- agents = fix_AgentSetRegistry
- agents.step()
-
- assert (
- agents._agentsets[0].df["wealth"].to_list()
- == (previous_wealth_0 + 1).to_list()
- )
- assert (
- agents._agentsets[1].df["wealth"].to_list()
- == (previous_wealth_1 + 1).to_list()
- )
-
- def test__check_ids_presence(
- self,
- fix_AgentSetRegistry: AgentSetRegistry,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- ):
- agents = fix_AgentSetRegistry.remove(fix2_AgentSet, inplace=False)
- agents_different_index = deepcopy(fix2_AgentSet)
- result = agents._check_ids_presence([fix1_AgentSet])
- assert result.filter(pl.col("unique_id").is_in(fix1_AgentSet._df["unique_id"]))[
- "present"
- ].all()
-
- assert not result.filter(
- pl.col("unique_id").is_in(agents_different_index._df["unique_id"])
- )["present"].any()
-
- def test__check_agentsets_presence(
- self,
- fix_AgentSetRegistry: AgentSetRegistry,
- fix1_AgentSet: ExampleAgentSet,
- fix3_AgentSet: ExampleAgentSet,
- ):
- agents = fix_AgentSetRegistry
- result = agents._check_agentsets_presence([fix1_AgentSet, fix3_AgentSet])
- assert result[0]
- assert not result[1]
-
- def test__get_bool_masks(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
- # Test with mask = None
- result = agents._get_bool_masks(mask=None)
- truth_value = True
- for i, mask in enumerate(result.values()):
- if isinstance(mask, pl.Expr):
- mask = agents._agentsets[i]._df.select(mask).to_series()
- truth_value &= mask.all()
- assert truth_value
-
- # Test with mask = "all"
- result = agents._get_bool_masks(mask="all")
- truth_value = True
- for i, mask in enumerate(result.values()):
- if isinstance(mask, pl.Expr):
- mask = agents._agentsets[i]._df.select(mask).to_series()
- truth_value &= mask.all()
- assert truth_value
-
- # Test with mask = "active"
- mask0 = (
- agents._agentsets[0].df["wealth"]
- > agents._agentsets[0].df["wealth"].to_list()[0]
- )
- mask1 = agents._agentsets[1].df["wealth"] > agents._agentsets[1].df["wealth"][0]
- mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1}
- agents.select(mask=mask_dictionary)
- result = agents._get_bool_masks(mask="active")
- assert result[agents._agentsets[0]].to_list() == mask0.to_list()
- assert result[agents._agentsets[1]].to_list() == mask1.to_list()
-
- # Test with mask = IdsLike
- result = agents._get_bool_masks(
- mask=[
- agents._agentsets[0]["unique_id"][0],
- agents._agentsets[1].df["unique_id"][0],
- ]
- )
- assert result[agents._agentsets[0]].to_list() == [True] + [False] * (
- len(agents._agentsets[0]) - 1
- )
- assert result[agents._agentsets[1]].to_list() == [True] + [False] * (
- len(agents._agentsets[1]) - 1
- )
-
- # Test with mask = dict[AgentSet, AgentMask]
- result = agents._get_bool_masks(mask=mask_dictionary)
- assert result[agents._agentsets[0]].to_list() == mask0.to_list()
- assert result[agents._agentsets[1]].to_list() == mask1.to_list()
-
- def test__get_obj(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
- assert agents._get_obj(inplace=True) is agents
- assert agents._get_obj(inplace=False) is not agents
-
- def test__return_agentsets_list(
- self,
- fix_AgentSetRegistry: AgentSetRegistry,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- ):
- agents = fix_AgentSetRegistry
- result = agents._return_agentsets_list(fix1_AgentSet)
- assert result == [fix1_AgentSet]
- result = agents._return_agentsets_list([fix1_AgentSet, fix2_AgentSet])
- assert result == [fix1_AgentSet, fix2_AgentSet]
-
- def test___add__(
- self,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- ):
- model = Model()
- agents = AgentSetRegistry(model)
- agentset_polars1 = fix1_AgentSet
- agentset_polars2 = fix2_AgentSet
-
- # Test with a single AgentSet
- result = agents + agentset_polars1
- assert result._agentsets[0] is agentset_polars1
- assert result._ids.to_list() == agentset_polars1._df["unique_id"].to_list()
-
- # Test with a single AgentSet same as above
- result = agents + agentset_polars2
- assert result._agentsets[0] is agentset_polars2
- assert result._ids.to_list() == agentset_polars2._df["unique_id"].to_list()
-
- # Test with a list of AgentSets
- result = agents + [agentset_polars1, agentset_polars2]
- assert result._agentsets[0] is agentset_polars1
- assert result._agentsets[1] is agentset_polars2
- assert (
- result._ids.to_list()
- == agentset_polars1._df["unique_id"].to_list()
- + agentset_polars2._df["unique_id"].to_list()
- )
-
- # Test if adding the same AgentSet raises ValueError
- with pytest.raises(ValueError):
- result + agentset_polars1
-
- def test___contains__(
- self, fix_AgentSetRegistry: AgentSetRegistry, fix3_AgentSet: ExampleAgentSet
- ):
- # Test with a single value
- agents = fix_AgentSetRegistry
- agentset_polars1 = agents._agentsets[0]
-
- # Test with an AgentSet
- assert agentset_polars1 in agents
- # Test with an AgentSet not present
- assert fix3_AgentSet not in agents
-
- # Test with single id present
- assert agentset_polars1["unique_id"][0] in agents
-
- # Test with single id not present
- assert 0 not in agents
-
- def test___copy__(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
- agents.test_list = [[1, 2, 3]]
-
- # Test with deep=False
- agents2 = copy(agents)
- agents2.test_list[0].append(4)
- assert agents.test_list[0][-1] == agents2.test_list[0][-1]
- assert agents.model == agents2.model
- assert agents._agentsets[0] == agents2._agentsets[0]
- assert (agents._ids == agents2._ids).all()
-
- def test___deepcopy__(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
- agents.test_list = [[1, 2, 3]]
-
- agents2 = deepcopy(agents)
- agents2.test_list[0].append(4)
- assert agents.test_list[-1] != agents2.test_list[-1]
- assert agents.model == agents2.model
- assert agents._agentsets[0] != agents2._agentsets[0]
- assert (agents._ids == agents2._ids).all()
-
- def test___getattr__(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
- assert isinstance(agents.model, Model)
- result = agents.wealth
- assert (
- result[agents._agentsets[0]].to_list()
- == agents._agentsets[0].df["wealth"].to_list()
- )
- assert (
- result[agents._agentsets[1]].to_list()
- == agents._agentsets[1].df["wealth"].to_list()
- )
-
- def test___getitem__(
- self,
- fix_AgentSetRegistry: AgentSetRegistry,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- ):
- agents = fix_AgentSetRegistry
-
- # Test with a single attribute
- assert (
- agents["wealth"][fix1_AgentSet].to_list()
- == fix1_AgentSet._df["wealth"].to_list()
- )
- assert (
- agents["wealth"][fix2_AgentSet].to_list()
- == fix2_AgentSet._df["wealth"].to_list()
- )
-
- # Test with a list of attributes
- result = agents[["wealth", "age"]]
- assert result[fix1_AgentSet].columns == ["wealth", "age"]
- assert (
- result[fix1_AgentSet]["wealth"].to_list()
- == fix1_AgentSet._df["wealth"].to_list()
- )
- assert (
- result[fix1_AgentSet]["age"].to_list() == fix1_AgentSet._df["age"].to_list()
- )
- assert result[fix2_AgentSet].columns == ["wealth", "age"]
- assert (
- result[fix2_AgentSet]["wealth"].to_list()
- == fix2_AgentSet._df["wealth"].to_list()
- )
- assert (
- result[fix2_AgentSet]["age"].to_list() == fix2_AgentSet._df["age"].to_list()
- )
-
- # Test with a single attribute and a mask
- mask0 = fix1_AgentSet._df["wealth"] > fix1_AgentSet._df["wealth"][0]
- mask1 = fix2_AgentSet._df["wealth"] > fix2_AgentSet._df["wealth"][0]
- mask_dictionary: dict[AgentSet, AgentMask] = {
- fix1_AgentSet: mask0,
- fix2_AgentSet: mask1,
- }
- result = agents[mask_dictionary, "wealth"]
- assert (
- result[fix1_AgentSet].to_list() == fix1_AgentSet.df["wealth"].to_list()[1:]
- )
- assert (
- result[fix2_AgentSet].to_list() == fix2_AgentSet.df["wealth"].to_list()[1:]
- )
-
- def test___iadd__(
- self,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- ):
- model = Model()
- agents = AgentSetRegistry(model)
- agentset_polars1 = fix1_AgentSet
- agentset_polars = fix2_AgentSet
-
- # Test with a single AgentSet
- agents_copy = deepcopy(agents)
- agents_copy += agentset_polars
- assert agents_copy._agentsets[0] is agentset_polars
- assert agents_copy._ids.to_list() == agentset_polars._df["unique_id"].to_list()
-
- # Test with a list of AgentSets
- agents_copy = deepcopy(agents)
- agents_copy += [agentset_polars1, agentset_polars]
- assert agents_copy._agentsets[0] is agentset_polars1
- assert agents_copy._agentsets[1] is agentset_polars
- assert (
- agents_copy._ids.to_list()
- == agentset_polars1._df["unique_id"].to_list()
- + agentset_polars._df["unique_id"].to_list()
- )
-
- # Test if adding the same AgentSet raises ValueError
- with pytest.raises(ValueError):
- agents_copy += agentset_polars1
-
- def test___iter__(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
- len_agentset0 = len(agents._agentsets[0])
- len_agentset1 = len(agents._agentsets[1])
- for i, agent in enumerate(agents):
- assert isinstance(agent, dict)
- if i < len_agentset0:
- assert agent["unique_id"] == agents._agentsets[0].df["unique_id"][i]
- else:
- assert (
- agent["unique_id"]
- == agents._agentsets[1].df["unique_id"][i - len_agentset0]
- )
- assert i == len_agentset0 + len_agentset1 - 1
-
- def test___isub__(
- self,
- fix_AgentSetRegistry: AgentSetRegistry,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- ):
- # Test with an AgentSet and a DataFrame
- agents = fix_AgentSetRegistry
- agents -= fix1_AgentSet
- assert agents._agentsets[0] == fix2_AgentSet
- assert len(agents._agentsets) == 1
-
- def test___len__(
- self,
- fix_AgentSetRegistry: AgentSetRegistry,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- ):
- assert len(fix_AgentSetRegistry) == len(fix1_AgentSet) + len(fix2_AgentSet)
-
- def test___repr__(self, fix_AgentSetRegistry: AgentSetRegistry):
- repr(fix_AgentSetRegistry)
-
- def test___reversed__(self, fix2_AgentSet: AgentSetRegistry):
- agents = fix2_AgentSet
- reversed_wealth = []
- for agent in reversed(list(agents)):
- reversed_wealth.append(agent["wealth"])
- assert reversed_wealth == list(reversed(agents["wealth"]))
-
- def test___setitem__(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
-
- # Test with a single attribute
- agents["wealth"] = 0
- assert agents._agentsets[0].df["wealth"].to_list() == [0] * len(
- agents._agentsets[0]
- )
- assert agents._agentsets[1].df["wealth"].to_list() == [0] * len(
- agents._agentsets[1]
- )
-
- # Test with a list of attributes
- agents[["wealth", "age"]] = 1
- assert agents._agentsets[0].df["wealth"].to_list() == [1] * len(
- agents._agentsets[0]
- )
- assert agents._agentsets[0].df["age"].to_list() == [1] * len(
- agents._agentsets[0]
- )
-
- # Test with a single attribute and a mask
- mask0 = pl.Series(
- "mask", [True] + [False] * (len(agents._agentsets[0]) - 1), dtype=pl.Boolean
- )
- mask1 = pl.Series(
- "mask", [True] + [False] * (len(agents._agentsets[1]) - 1), dtype=pl.Boolean
- )
- mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1}
- agents[mask_dictionary, "wealth"] = 0
- assert agents._agentsets[0].df["wealth"].to_list() == [0] + [1] * (
- len(agents._agentsets[0]) - 1
- )
- assert agents._agentsets[1].df["wealth"].to_list() == [0] + [1] * (
- len(agents._agentsets[1]) - 1
- )
-
- def test___str__(self, fix_AgentSetRegistry: AgentSetRegistry):
- str(fix_AgentSetRegistry)
-
- def test___sub__(
- self,
- fix_AgentSetRegistry: AgentSetRegistry,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- ):
- # Test with an AgentSet and a DataFrame
- result = fix_AgentSetRegistry - fix1_AgentSet
- assert isinstance(result._agentsets[0], ExampleAgentSet)
- assert len(result._agentsets) == 1
-
- def test_agents(
- self,
- fix_AgentSetRegistry: AgentSetRegistry,
- fix1_AgentSet: ExampleAgentSet,
- fix2_AgentSet: ExampleAgentSet,
- ):
- assert isinstance(fix_AgentSetRegistry.df, dict)
- assert len(fix_AgentSetRegistry.df) == 2
- assert fix_AgentSetRegistry.df[fix1_AgentSet] is fix1_AgentSet._df
- assert fix_AgentSetRegistry.df[fix2_AgentSet] is fix2_AgentSet._df
-
- # Test agents.setter
- fix_AgentSetRegistry.df = [fix1_AgentSet, fix2_AgentSet]
- assert fix_AgentSetRegistry._agentsets[0] == fix1_AgentSet
- assert fix_AgentSetRegistry._agentsets[1] == fix2_AgentSet
-
- def test_active_agents(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
-
- # Test with select
- mask0 = (
- agents._agentsets[0].df["wealth"]
- > agents._agentsets[0].df["wealth"].to_list()[0]
- )
- mask1 = (
- agents._agentsets[1].df["wealth"]
- > agents._agentsets[1].df["wealth"].to_list()[0]
- )
- mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1}
-
- agents1 = agents.select(mask=mask_dictionary, inplace=False)
-
- result = agents1.active_agents
- assert isinstance(result, dict)
- assert isinstance(result[agents1._agentsets[0]], pl.DataFrame)
- assert isinstance(result[agents1._agentsets[1]], pl.DataFrame)
-
- assert all(
- series.all()
- for series in (
- result[agents1._agentsets[0]] == agents1._agentsets[0]._df.filter(mask0)
- )
- )
-
- assert all(
- series.all()
- for series in (
- result[agents1._agentsets[1]] == agents1._agentsets[1]._df.filter(mask1)
- )
- )
-
- # Test with active_agents.setter
- agents1.active_agents = mask_dictionary
- result = agents1.active_agents
- assert isinstance(result, dict)
- assert isinstance(result[agents1._agentsets[0]], pl.DataFrame)
- assert isinstance(result[agents1._agentsets[1]], pl.DataFrame)
- assert all(
- series.all()
- for series in (
- result[agents1._agentsets[0]] == agents1._agentsets[0]._df.filter(mask0)
- )
- )
- assert all(
- series.all()
- for series in (
- result[agents1._agentsets[1]] == agents1._agentsets[1]._df.filter(mask1)
- )
- )
-
- def test_agentsets_by_type(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
-
- result = agents.agentsets_by_type
- assert isinstance(result, dict)
- assert isinstance(result[ExampleAgentSet], AgentSetRegistry)
-
- assert (
- result[ExampleAgentSet]._agentsets[0].df.rows()
- == agents._agentsets[1].df.rows()
- )
-
- def test_inactive_agents(self, fix_AgentSetRegistry: AgentSetRegistry):
- agents = fix_AgentSetRegistry
-
- # Test with select
- mask0 = (
- agents._agentsets[0].df["wealth"]
- > agents._agentsets[0].df["wealth"].to_list()[0]
- )
- mask1 = (
- agents._agentsets[1].df["wealth"]
- > agents._agentsets[1].df["wealth"].to_list()[0]
- )
- mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1}
- agents1 = agents.select(mask=mask_dictionary, inplace=False)
- result = agents1.inactive_agents
- assert isinstance(result, dict)
- assert isinstance(result[agents1._agentsets[0]], pl.DataFrame)
- assert isinstance(result[agents1._agentsets[1]], pl.DataFrame)
- assert all(
- series.all()
- for series in (
- result[agents1._agentsets[0]]
- == agents1._agentsets[0].select(mask0, negate=True).active_agents
- )
- )
- assert all(
- series.all()
- for series in (
- result[agents1._agentsets[1]]
- == agents1._agentsets[1].select(mask1, negate=True).active_agents
- )
- )
diff --git a/tests/test_agentset.py b/tests/test_agentset.py
index d475a4fc..c8459a80 100644
--- a/tests/test_agentset.py
+++ b/tests/test_agentset.py
@@ -260,6 +260,36 @@ def test_select(self, fix1_AgentSet: ExampleAgentSet):
selected.active_agents["wealth"].to_list() == agents.df["wealth"].to_list()
)
+ def test_rename(self, fix1_AgentSet: ExampleAgentSet) -> None:
+ agents = fix1_AgentSet
+ reg = agents.model.sets
+ # Inplace rename returns self and updates registry
+ old_name = agents.name
+ result = agents.rename("alpha", inplace=True)
+ assert result is agents
+ assert agents.name == "alpha"
+ assert reg.get("alpha") is agents
+ assert reg.get(old_name) is None
+
+ # Add a second set and claim the same name via registry first
+ other = ExampleAgentSet(agents.model)
+ other["wealth"] = other.starting_wealth
+ other["age"] = [1, 2, 3, 4]
+ reg.add(other)
+ reg.rename(other, "omega")
+ # Now rename the first to an existing name; should canonicalize to omega_1
+ agents.rename("omega", inplace=True)
+ assert agents.name != "omega"
+ assert agents.name.startswith("omega_")
+ assert reg.get(agents.name) is agents
+
+ # Non-inplace: returns a renamed copy of the set
+ copy_set = agents.rename("beta", inplace=False)
+ assert copy_set is not agents
+ assert copy_set.name in ("beta", "beta_1")
+ # Original remains unchanged
+ assert agents.name not in ("beta", "beta_1")
+
# Test with a pl.Series[bool]
mask = pl.Series("mask", [True, False, True, True], dtype=pl.Boolean)
selected = agents.select(mask, inplace=False)
diff --git a/tests/test_agentsetregistry.py b/tests/test_agentsetregistry.py
new file mode 100644
index 00000000..07422fb6
--- /dev/null
+++ b/tests/test_agentsetregistry.py
@@ -0,0 +1,436 @@
+import polars as pl
+import pytest
+import beartype.roar as bear_roar
+
+from mesa_frames import AgentSet, AgentSetRegistry, Model
+
+
+class ExampleAgentSetA(AgentSet):
+ def __init__(self, model: Model):
+ super().__init__(model)
+ self["wealth"] = pl.Series("wealth", [1, 2, 3, 4])
+ self["age"] = pl.Series("age", [10, 20, 30, 40])
+
+ def add_wealth(self, amount: int) -> None:
+ self["wealth"] += amount
+
+ def step(self) -> None:
+ self.add_wealth(1)
+
+ def count(self) -> int:
+ return len(self)
+
+
+class ExampleAgentSetB(AgentSet):
+ def __init__(self, model: Model):
+ super().__init__(model)
+ self["wealth"] = pl.Series("wealth", [10, 20, 30, 40])
+ self["age"] = pl.Series("age", [11, 22, 33, 44])
+
+ def add_wealth(self, amount: int) -> None:
+ self["wealth"] += amount
+
+ def step(self) -> None:
+ self.add_wealth(2)
+
+ def count(self) -> int:
+ return len(self)
+
+
+@pytest.fixture
+def fix_model() -> Model:
+ return Model()
+
+
+@pytest.fixture
+def fix_set_a(fix_model: Model) -> ExampleAgentSetA:
+ return ExampleAgentSetA(fix_model)
+
+
+@pytest.fixture
+def fix_set_b(fix_model: Model) -> ExampleAgentSetB:
+ return ExampleAgentSetB(fix_model)
+
+
+@pytest.fixture
+def fix_registry_with_two(
+ fix_model: Model, fix_set_a: ExampleAgentSetA, fix_set_b: ExampleAgentSetB
+) -> AgentSetRegistry:
+ reg = AgentSetRegistry(fix_model)
+ reg.add([fix_set_a, fix_set_b])
+ return reg
+
+
+class TestAgentSetRegistry:
+ # Dunder: __init__
+ def test__init__(self):
+ model = Model()
+ reg = AgentSetRegistry(model)
+ assert reg.model is model
+ assert len(reg) == 0
+ assert reg.ids.len() == 0
+
+ # Public: add
+ def test_add(self, fix_model: Model) -> None:
+ reg = AgentSetRegistry(fix_model)
+ a1 = ExampleAgentSetA(fix_model)
+ a2 = ExampleAgentSetA(fix_model)
+ # Add single
+ reg.add(a1)
+ assert len(reg) == 1
+ assert a1 in reg
+ # Add list; second should be auto-renamed with suffix
+ reg.add([a2])
+ assert len(reg) == 2
+ names = [s.name for s in reg]
+ assert names[0] == "ExampleAgentSetA"
+ assert names[1] in ("ExampleAgentSetA_1", "ExampleAgentSetA_2")
+ # ids concatenated
+ assert reg.ids.len() == len(a1) + len(a2)
+ # Duplicate instance rejected
+ with pytest.raises(ValueError, match="already present in the AgentSetRegistry"):
+ reg.add([a1])
+ # Duplicate unique_id space rejected
+ a3 = ExampleAgentSetB(fix_model)
+ a3.df = a1.df
+ with pytest.raises(ValueError, match="agent IDs are not unique"):
+ reg.add(a3)
+
+ # Public: contains
+ def test_contains(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ a_name = next(iter(reg)).name
+ # Single instance
+ assert reg.contains(reg[0]) is True
+ # Single type
+ assert reg.contains(ExampleAgentSetA) is True
+ # Single name
+ assert reg.contains(a_name) is True
+ # Iterable: instances
+ assert reg.contains([reg[0], reg[1]]).to_list() == [True, True]
+ # Iterable: types
+ types_result = reg.contains([ExampleAgentSetA, ExampleAgentSetB])
+ assert types_result.dtype == pl.Boolean
+ assert types_result.to_list() == [True, True]
+ # Iterable: names
+ names = [s.name for s in reg]
+ assert reg.contains(names).to_list() == [True, True]
+ # Empty iterable is vacuously true
+ assert reg.contains([]) is True
+ # Unsupported element type (rejected by runtime type checking)
+ with pytest.raises(bear_roar.BeartypeCallHintParamViolation):
+ reg.contains([object()])
+
+ # Public: do
+ def test_do(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ # Inplace operation across both sets
+ reg.do("add_wealth", 5)
+ assert reg[0]["wealth"].to_list() == [6, 7, 8, 9]
+ assert reg[1]["wealth"].to_list() == [15, 25, 35, 45]
+ # return_results with different key domains
+ res_by_name = reg.do("count", return_results=True, key_by="name")
+ assert set(res_by_name.keys()) == {s.name for s in reg}
+ assert all(v == 4 for v in res_by_name.values())
+ res_by_index = reg.do("count", return_results=True, key_by="index")
+ assert set(res_by_index.keys()) == {0, 1}
+ res_by_type = reg.do("count", return_results=True, key_by="type")
+ assert set(res_by_type.keys()) == {ExampleAgentSetA, ExampleAgentSetB}
+
+ # Public: get
+ def test_get(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ # By index
+ assert isinstance(reg.get(0), AgentSet)
+ # By name
+ name = reg[0].name
+ assert reg.get(name) is reg[0]
+ # By type returns list
+ aset_list = reg.get(ExampleAgentSetA)
+ assert isinstance(aset_list, list) and all(
+ isinstance(s, ExampleAgentSetA) for s in aset_list
+ )
+ # Missing returns default None
+ assert reg.get(9999) is None
+ # Out-of-range index handled without raising
+ assert reg.get(10) is None
+
+ # Public: remove
+ def test_remove(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ total_ids = reg.ids.len()
+ # By instance
+ reg.remove(reg[0])
+ assert len(reg) == 1
+ # By type
+ reg.add(ExampleAgentSetA(reg.model))
+ assert len(reg.get(ExampleAgentSetA)) == 1
+ reg.remove(ExampleAgentSetA)
+ assert all(not isinstance(s, ExampleAgentSetA) for s in reg)
+ # By name (no error if not present)
+ reg.remove("nonexistent")
+ # ids recomputed and not equal to previous total
+ assert reg.ids.len() != total_ids
+
+ # Public: shuffle
+ def test_shuffle(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ reg.shuffle(inplace=True)
+ assert len(reg) == 2
+
+ # Public: sort
+ def test_sort(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ reg.sort(by="wealth", ascending=False)
+ assert reg[0]["wealth"].to_list() == sorted(
+ reg[0]["wealth"].to_list(), reverse=True
+ )
+
+ # Public: rename
+ def test_rename(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ # Single rename by instance, inplace
+ a0 = reg[0]
+ reg.rename(a0, "X")
+ assert a0.name == "X"
+ assert reg.get("X") is a0
+
+ # Rename second to same name should canonicalize
+ a1 = reg[1]
+ reg.rename(a1, "X")
+ assert a1.name != "X" and a1.name.startswith("X_")
+ assert reg.get(a1.name) is a1
+
+ # Non-inplace copy
+ reg2 = reg.rename(a0, "Y", inplace=False)
+ assert reg2 is not reg
+ assert reg.get("Y") is None
+ assert reg2.get("Y") is not None
+
+ # Atomic conflict raise: attempt to rename to existing name
+ with pytest.raises(ValueError):
+ reg.rename({a0: a1.name}, on_conflict="raise", mode="atomic")
+ # Names unchanged
+ assert reg.get(a1.name) is a1
+
+ # Best-effort: one ok, one conflicting → only ok applied
+ unique_name = "Z_unique"
+ reg.rename(
+ {a0: unique_name, a1: unique_name}, on_conflict="raise", mode="best_effort"
+ )
+ assert a0.name == unique_name
+ # a1 stays with its previous (non-unique_name) value
+ assert a1.name != unique_name
+
+ # Dunder: __getattr__
+ def test__getattr__(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ ages = reg.age
+ assert isinstance(ages, dict)
+ assert set(ages.keys()) == {s.name for s in reg}
+
+ # Dunder: __iter__
+ def test__iter__(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ it = list(iter(reg))
+ assert it[0] is reg[0]
+ assert all(isinstance(s, AgentSet) for s in it)
+
+ # Dunder: __len__
+ def test__len__(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ assert len(reg) == 2
+
+ # Dunder: __repr__
+ def test__repr__(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ repr(reg)
+
+ # Dunder: __str__
+ def test__str__(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ str(reg)
+
+ # Dunder: __reversed__
+ def test__reversed__(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ list(reversed(reg))
+
+ # Dunder: __setitem__
+ def test__setitem__(self, fix_model: Model) -> None:
+ reg = AgentSetRegistry(fix_model)
+ a1 = ExampleAgentSetA(fix_model)
+ a2 = ExampleAgentSetB(fix_model)
+ reg.add([a1, a2])
+ # Assign by index with duplicate name should raise
+ a_dup = ExampleAgentSetA(fix_model)
+ a_dup.name = reg[1].name # create name collision
+ with pytest.raises(ValueError, match="Duplicate agent set name disallowed"):
+ reg[0] = a_dup
+ # Assign by name: replace existing slot, authoritative name should be key
+ new_set = ExampleAgentSetA(fix_model)
+ reg[reg[1].name] = new_set
+ assert reg[1] is new_set
+ assert reg[1].name == reg[1].name
+ # Assign new name appends
+ extra = ExampleAgentSetA(fix_model)
+ reg["extra_set"] = extra
+ assert reg["extra_set"] is extra
+ # Model mismatch raises
+ other_model_set = ExampleAgentSetA(Model())
+ with pytest.raises(
+ TypeError, match="Assigned AgentSet must belong to the same model"
+ ):
+ reg[0] = other_model_set
+
+ # Public: keys
+ def test_keys(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ # keys by name
+ names = list(reg.keys())
+ assert names == [s.name for s in reg]
+ # keys by index
+ assert list(reg.keys(key_by="index")) == [0, 1]
+ # keys by type
+ assert set(reg.keys(key_by="type")) == {ExampleAgentSetA, ExampleAgentSetB}
+ # invalid key_by
+ with pytest.raises(bear_roar.BeartypeCallHintParamViolation):
+ list(reg.keys(key_by="bad"))
+
+ # Public: items
+ def test_items(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ items_name = list(reg.items())
+ assert [k for k, _ in items_name] == [s.name for s in reg]
+ items_idx = list(reg.items(key_by="index"))
+ assert [k for k, _ in items_idx] == [0, 1]
+ items_type = list(reg.items(key_by="type"))
+ assert {k for k, _ in items_type} == {ExampleAgentSetA, ExampleAgentSetB}
+
+ # Public: values
+ def test_values(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ assert list(reg.values())[0] is reg[0]
+
+ # Public: discard
+ def test_discard(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ original_len = len(reg)
+ # Missing selector ignored without error
+ reg.discard("missing_name")
+ assert len(reg) == original_len
+ # Remove by instance
+ reg.discard(reg[0])
+ assert len(reg) == original_len - 1
+ # Non-inplace returns new copy
+ reg2 = reg.discard("missing_name", inplace=False)
+ assert len(reg2) == len(reg)
+
+ # Public: ids (property)
+ def test_ids(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ assert isinstance(reg.ids, pl.Series)
+ before = reg.ids.len()
+ reg.remove(reg[0])
+ assert reg.ids.len() < before
+
+ # Dunder: __getitem__
+ def test__getitem__(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ # By index
+ assert reg[0] is next(iter(reg))
+ # By name
+ name0 = reg[0].name
+ assert reg[name0] is reg[0]
+ # By type
+ lst = reg[ExampleAgentSetA]
+ assert isinstance(lst, list) and all(
+ isinstance(s, ExampleAgentSetA) for s in lst
+ )
+ # Missing name raises KeyError
+ with pytest.raises(KeyError):
+ _ = reg["missing"]
+
+ # Dunder: __contains__ (membership)
+ def test__contains__(self, fix_registry_with_two: AgentSetRegistry) -> None:
+ reg = fix_registry_with_two
+ assert reg[0] in reg
+ new_set = ExampleAgentSetA(reg.model)
+ assert new_set not in reg
+
+ # Dunder: __add__
+ def test__add__(self, fix_model: Model) -> None:
+ reg = AgentSetRegistry(fix_model)
+ a1 = ExampleAgentSetA(fix_model)
+ a2 = ExampleAgentSetB(fix_model)
+ reg.add(a1)
+ reg_new = reg + a2
+ # original unchanged, new has two
+ assert len(reg) == 1
+ assert len(reg_new) == 2
+ # Presence by type/name (instances are deep-copied)
+ assert reg_new.contains(ExampleAgentSetA) is True
+ assert reg_new.contains(ExampleAgentSetB) is True
+
+ # Dunder: __iadd__
+ def test__iadd__(self, fix_model: Model) -> None:
+ reg = AgentSetRegistry(fix_model)
+ a1 = ExampleAgentSetA(fix_model)
+ a2 = ExampleAgentSetB(fix_model)
+ reg += a1
+ assert len(reg) == 1
+ reg += [a2]
+ assert len(reg) == 2
+ assert reg.contains([a1, a2]).all()
+
+ # Dunder: __sub__
+ def test__sub__(self, fix_model: Model) -> None:
+ reg = AgentSetRegistry(fix_model)
+ a1 = ExampleAgentSetA(fix_model)
+ a2 = ExampleAgentSetB(fix_model)
+ reg.add([a1, a2])
+ reg_new = reg - a1
+ # original unchanged
+ assert len(reg) == 2
+ # In current implementation, subtraction with instance returns a copy
+ # without mutation due to deep-copied identity; ensure new object
+ assert isinstance(reg_new, AgentSetRegistry) and reg_new is not reg
+ assert len(reg_new) == len(reg)
+ # subtract list of instances also yields unchanged copy
+ reg_new2 = reg - [a1, a2]
+ assert len(reg_new2) == len(reg)
+
+ # Dunder: __isub__
+ def test__isub__(self, fix_model: Model) -> None:
+ reg = AgentSetRegistry(fix_model)
+ a1 = ExampleAgentSetA(fix_model)
+ a2 = ExampleAgentSetB(fix_model)
+ reg.add([a1, a2])
+ reg -= a1
+ assert len(reg) == 1 and a1 not in reg
+ reg -= [a2]
+ assert len(reg) == 0
+
+ # Public: replace
+ def test_replace(self, fix_model: Model) -> None:
+ reg = AgentSetRegistry(fix_model)
+ a1 = ExampleAgentSetA(fix_model)
+ a2 = ExampleAgentSetB(fix_model)
+ a3 = ExampleAgentSetA(fix_model)
+ reg.add([a1, a2])
+ # Replace by index
+ reg.replace({0: a3})
+ assert reg[0] is a3
+ # Replace by name (authoritative)
+ reg.replace({reg[1].name: a2})
+ assert reg[1] is a2
+ # Atomic aliasing error: same object in two positions
+ with pytest.raises(ValueError, match="already exists at index"):
+ reg.replace({0: a2, 1: a2})
+ # Model mismatch
+ with pytest.raises(TypeError, match="must belong to the same model"):
+ reg.replace({0: ExampleAgentSetA(Model())})
+ # Non-atomic: only applies valid keys to copy
+ reg2 = reg.replace({0: a1}, inplace=False, atomic=False)
+ assert reg2[0] is a1
+ assert reg[0] is not a1
diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py
index b7407711..b2ac3279 100644
--- a/tests/test_datacollector.py
+++ b/tests/test_datacollector.py
@@ -164,7 +164,7 @@ def test_collect(self, fix1_model):
)
},
agent_reporters={
- "wealth": lambda model: model.sets._agentsets[0]["wealth"],
+ "wealth": lambda sets: sets[0]["wealth"],
"age": "age",
},
)
@@ -223,7 +223,7 @@ def test_collect_step(self, fix1_model):
)
},
agent_reporters={
- "wealth": lambda model: model.sets._agentsets[0]["wealth"],
+ "wealth": lambda sets: sets[0]["wealth"],
"age": "age",
},
)
@@ -279,7 +279,7 @@ def test_conditional_collect(self, fix1_model):
)
},
agent_reporters={
- "wealth": lambda model: model.sets._agentsets[0]["wealth"],
+ "wealth": lambda sets: sets[0]["wealth"],
"age": "age",
},
)
@@ -361,7 +361,7 @@ def test_flush_local_csv(self, fix1_model):
)
},
agent_reporters={
- "wealth": lambda model: model.sets._agentsets[0]["wealth"],
+ "wealth": lambda sets: sets[0]["wealth"],
"age": "age",
},
storage="csv",
@@ -437,7 +437,7 @@ def test_flush_local_parquet(self, fix1_model):
)
},
agent_reporters={
- "wealth": lambda model: model.sets._agentsets[0]["wealth"],
+ "wealth": lambda sets: sets[0]["wealth"],
},
storage="parquet",
storage_uri=tmpdir,
@@ -513,7 +513,7 @@ def test_postgress(self, fix1_model, postgres_uri):
)
},
agent_reporters={
- "wealth": lambda model: model.sets._agentsets[0]["wealth"],
+ "wealth": lambda sets: sets[0]["wealth"],
"age": "age",
},
storage="postgresql",
@@ -562,7 +562,7 @@ def test_batch_memory(self, fix2_model):
)
},
agent_reporters={
- "wealth": lambda model: model.sets._agentsets[0]["wealth"],
+ "wealth": lambda sets: sets[0]["wealth"],
"age": "age",
},
)
@@ -707,7 +707,7 @@ def test_batch_save(self, fix2_model):
)
},
agent_reporters={
- "wealth": lambda model: model.sets._agentsets[0]["wealth"],
+ "wealth": lambda sets: sets[0]["wealth"],
"age": "age",
},
storage="csv",
diff --git a/tests/test_grid.py b/tests/test_grid.py
index 6d75f3cc..904efdb0 100644
--- a/tests/test_grid.py
+++ b/tests/test_grid.py
@@ -12,10 +12,8 @@
def get_unique_ids(model: Model) -> pl.Series:
- # return model.get_sets_of_type(model.set_types[0])["unique_id"]
- series_list = [
- agent_set["unique_id"].cast(pl.UInt64) for agent_set in model.sets.df.values()
- ]
+ # Collect unique_id across all concrete AgentSets in the registry
+ series_list = [aset["unique_id"].cast(pl.UInt64) for aset in model.sets]
return pl.concat(series_list)