diff --git a/docs/general/user-guide/1_classes.md b/docs/general/user-guide/1_classes.md index f85c062d..9ac446de 100644 --- a/docs/general/user-guide/1_classes.md +++ b/docs/general/user-guide/1_classes.md @@ -27,9 +27,9 @@ You can access the underlying DataFrame where agents are stored with `self.df`. ## Model 🏗️ -To add your AgentSet to your Model, you should also add it to the sets with `+=` or `add`. +To add your AgentSet to your Model, use the registry `self.sets` with `+=` or `add`. -NOTE: Model.sets are stored in a class which is entirely similar to AgentSet called AgentSetRegistry. The API of the two are the same. If you try accessing AgentSetRegistry.df, you will get a dictionary of `[AgentSet, DataFrame]`. +Note: All agent sets live inside `AgentSetRegistry` (available as `model.sets`). Access sets through the registry, and access DataFrames from the set itself. For example: `self.sets["Preys"].df`. Example: @@ -43,7 +43,8 @@ class EcosystemModel(Model): def step(self): self.sets.do("move") self.sets.do("hunt") - self.prey.do("reproduce") + # Access specific sets via the registry + self.sets["Preys"].do("reproduce") ``` ## Space: Grid 🌐 @@ -76,18 +77,23 @@ Example: class ExampleModel(Model): def __init__(self): super().__init__() - self.sets = MoneyAgent(self) + # Add the set to the registry + self.sets.add(MoneyAgents(100, self)) + # Configure reporters: use the registry to locate sets; get df from the set self.datacollector = DataCollector( model=self, - model_reporters={"total_wealth": lambda m: lambda m: list(m.sets.df.values())[0]["wealth"].sum()}, + model_reporters={ + "total_wealth": lambda m: m.sets["MoneyAgents"].df["wealth"].sum(), + }, agent_reporters={"wealth": "wealth"}, storage="csv", storage_uri="./data", - trigger=lambda m: m.schedule.steps % 2 == 0 + trigger=lambda m: m.steps % 2 == 0, ) def step(self): - self.sets.step() + # Step all sets via the registry + self.sets.do("step") self.datacollector.conditional_collect() self.datacollector.flush() ``` diff --git a/docs/general/user-guide/2_introductory-tutorial.ipynb b/docs/general/user-guide/2_introductory-tutorial.ipynb index ec1165da..11391f9d 100644 --- a/docs/general/user-guide/2_introductory-tutorial.ipynb +++ b/docs/general/user-guide/2_introductory-tutorial.ipynb @@ -74,7 +74,9 @@ " self.sets += agents_cls(N, self)\n", " self.datacollector = DataCollector(\n", " model=self,\n", - " model_reporters={\"total_wealth\": lambda m: m.agents[\"wealth\"].sum()},\n", + " model_reporters={\n", + " \"total_wealth\": lambda m: m.sets[\"MoneyAgents\"].df[\"wealth\"].sum()\n", + " },\n", " agent_reporters={\"wealth\": \"wealth\"},\n", " storage=\"csv\",\n", " storage_uri=\"./data\",\n", diff --git a/docs/general/user-guide/4_datacollector.ipynb b/docs/general/user-guide/4_datacollector.ipynb index 3fa16b49..0809caa2 100644 --- a/docs/general/user-guide/4_datacollector.ipynb +++ b/docs/general/user-guide/4_datacollector.ipynb @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 6, "id": "9a63283cbaf04dbcab1f6479b197f3a8", "metadata": { "editable": true @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "72eea5119410473aa328ad9291626812", "metadata": { "editable": true @@ -63,11 +63,11 @@ " │ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", " │ i64 ┆ str ┆ i64 ┆ f64 ┆ i64 │\n", " ╞══════╪═════════════════════════════════╪═══════╪══════════════╪══════════╡\n", - " │ 2 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", - " │ 4 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", - " │ 6 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", - " │ 8 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", - " │ 10 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", + " │ 2 ┆ 332212815818606584686857770936… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", + " │ 4 ┆ 332212815818606584686857770936… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", + " │ 6 ┆ 332212815818606584686857770936… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", + " │ 8 ┆ 332212815818606584686857770936… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", + " │ 10 ┆ 332212815818606584686857770936… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", " └──────┴─────────────────────────────────┴───────┴──────────────┴──────────┘,\n", " 'agent': shape: (5_000, 4)\n", " ┌────────────────────┬──────┬─────────────────────────────────┬───────┐\n", @@ -75,21 +75,21 @@ " │ --- ┆ --- ┆ --- ┆ --- │\n", " │ f64 ┆ i32 ┆ str ┆ i32 │\n", " ╞════════════════════╪══════╪═════════════════════════════════╪═══════╡\n", - " │ 0.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n", - " │ 3.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n", - " │ 1.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n", - " │ 3.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n", - " │ 6.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " │ 3.0 ┆ 2 ┆ 332212815818606584686857770936… ┆ 0 │\n", + " │ 0.0 ┆ 2 ┆ 332212815818606584686857770936… ┆ 0 │\n", + " │ 2.0 ┆ 2 ┆ 332212815818606584686857770936… ┆ 0 │\n", + " │ 1.0 ┆ 2 ┆ 332212815818606584686857770936… ┆ 0 │\n", + " │ 0.0 ┆ 2 ┆ 332212815818606584686857770936… ┆ 0 │\n", " │ … ┆ … ┆ … ┆ … │\n", - " │ 4.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n", - " │ 1.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n", - " │ 0.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n", - " │ 0.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n", - " │ 0.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " │ 0.0 ┆ 10 ┆ 332212815818606584686857770936… ┆ 0 │\n", + " │ 0.0 ┆ 10 ┆ 332212815818606584686857770936… ┆ 0 │\n", + " │ 0.0 ┆ 10 ┆ 332212815818606584686857770936… ┆ 0 │\n", + " │ 0.0 ┆ 10 ┆ 332212815818606584686857770936… ┆ 0 │\n", + " │ 0.0 ┆ 10 ┆ 332212815818606584686857770936… ┆ 0 │\n", " └────────────────────┴──────┴─────────────────────────────────┴───────┘}" ] }, - "execution_count": 19, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -120,8 +120,8 @@ " self.dc = DataCollector(\n", " model=self,\n", " model_reporters={\n", - " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n", - " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n", + " \"total_wealth\": lambda m: m.sets[\"MoneyAgents\"].df[\"wealth\"].sum(),\n", + " \"n_agents\": lambda m: len(m.sets[\"MoneyAgents\"]),\n", " },\n", " agent_reporters={\n", " \"wealth\": \"wealth\", # pull existing column\n", @@ -175,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "5f14f38c", "metadata": {}, "outputs": [ @@ -185,7 +185,7 @@ "[]" ] }, - "execution_count": 20, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -198,8 +198,8 @@ "model_csv.dc = DataCollector(\n", " model=model_csv,\n", " model_reporters={\n", - " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n", - " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n", + " \"total_wealth\": lambda m: m.sets[\"MoneyAgents\"].df[\"wealth\"].sum(),\n", + " \"n_agents\": lambda m: len(m.sets[\"MoneyAgents\"]),\n", " },\n", " agent_reporters={\n", " \"wealth\": \"wealth\",\n", @@ -226,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "id": "8763a12b2bbd4a93a75aff182afb95dc", "metadata": { "editable": true @@ -238,7 +238,7 @@ "[]" ] }, - "execution_count": 21, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -249,8 +249,8 @@ "model_parq.dc = DataCollector(\n", " model=model_parq,\n", " model_reporters={\n", - " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n", - " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n", + " \"total_wealth\": lambda m: m.sets[\"MoneyAgents\"].df[\"wealth\"].sum(),\n", + " \"n_agents\": lambda m: len(m.sets[\"MoneyAgents\"]),\n", " },\n", " agent_reporters={\n", " \"wealth\": \"wealth\",\n", @@ -279,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "7cdc8c89c7104fffa095e18ddfef8986", "metadata": { "editable": true @@ -290,8 +290,8 @@ "model_s3.dc = DataCollector(\n", " model=model_s3,\n", " model_reporters={\n", - " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n", - " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n", + " \"total_wealth\": lambda m: m.sets[\"MoneyAgents\"].df[\"wealth\"].sum(),\n", + " \"n_agents\": lambda m: len(m.sets[\"MoneyAgents\"]),\n", " },\n", " agent_reporters={\n", " \"wealth\": \"wealth\",\n", @@ -319,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 11, "id": "938c804e27f84196a10c8828c723f798", "metadata": { "editable": true @@ -381,7 +381,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 12, "id": "59bbdb311c014d738909a11f9e486628", "metadata": { "editable": true @@ -410,7 +410,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 13, "id": "8a65eabff63a45729fe45fb5ade58bdc", "metadata": { "editable": true @@ -426,7 +426,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 5)
stepseedbatchtotal_wealthn_agents
i64stri64f64i64
2"732054881101029867447298951813…0100.0100
4"732054881101029867447298951813…0100.0100
6"732054881101029867447298951813…0100.0100
8"732054881101029867447298951813…0100.0100
10"732054881101029867447298951813…0100.0100
" + "shape: (5, 5)
stepseedbatchtotal_wealthn_agents
i64stri64f64i64
2"540832786058427425452319829502…0100.0100
4"540832786058427425452319829502…0100.0100
6"540832786058427425452319829502…0100.0100
8"540832786058427425452319829502…0100.0100
10"540832786058427425452319829502…0100.0100
" ], "text/plain": [ "shape: (5, 5)\n", @@ -435,15 +435,15 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ i64 ┆ str ┆ i64 ┆ f64 ┆ i64 │\n", "╞══════╪═════════════════════════════════╪═══════╪══════════════╪══════════╡\n", - "│ 2 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n", - "│ 4 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n", - "│ 6 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n", - "│ 8 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n", - "│ 10 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n", + "│ 2 ┆ 540832786058427425452319829502… ┆ 0 ┆ 100.0 ┆ 100 │\n", + "│ 4 ┆ 540832786058427425452319829502… ┆ 0 ┆ 100.0 ┆ 100 │\n", + "│ 6 ┆ 540832786058427425452319829502… ┆ 0 ┆ 100.0 ┆ 100 │\n", + "│ 8 ┆ 540832786058427425452319829502… ┆ 0 ┆ 100.0 ┆ 100 │\n", + "│ 10 ┆ 540832786058427425452319829502… ┆ 0 ┆ 100.0 ┆ 100 │\n", "└──────┴─────────────────────────────────┴───────┴──────────────┴──────────┘" ] }, - "execution_count": 25, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/sugarscape_ig/ss_polars/agents.py b/examples/sugarscape_ig/ss_polars/agents.py index b0ecbe90..32ca91f5 100644 --- a/examples/sugarscape_ig/ss_polars/agents.py +++ b/examples/sugarscape_ig/ss_polars/agents.py @@ -35,12 +35,15 @@ def __init__( self.add(agents) def eat(self): + # Only consider cells currently occupied by agents of this set cells = self.space.cells.filter(pl.col("agent_id").is_not_null()) - self[cells["agent_id"], "sugar"] = ( - self[cells["agent_id"], "sugar"] - + cells["sugar"] - - self[cells["agent_id"], "metabolism"] - ) + mask_in_set = cells["agent_id"].is_in(self.index) + if mask_in_set.any(): + cells = cells.filter(mask_in_set) + ids = cells["agent_id"] + self[ids, "sugar"] = ( + self[ids, "sugar"] + cells["sugar"] - self[ids, "metabolism"] + ) def step(self): self.shuffle().do("move").do("eat") diff --git a/examples/sugarscape_ig/ss_polars/model.py b/examples/sugarscape_ig/ss_polars/model.py index 56a3a83b..36b2718e 100644 --- a/examples/sugarscape_ig/ss_polars/model.py +++ b/examples/sugarscape_ig/ss_polars/model.py @@ -33,7 +33,10 @@ def __init__( sugar=sugar_grid.flatten(), max_sugar=sugar_grid.flatten() ) self.space.set_cells(sugar_grid) - self.sets += agent_type(self, n_agents, initial_sugar, metabolism, vision) + # Create and register the main agent set; keep its name for later lookups + main_set = agent_type(self, n_agents, initial_sugar, metabolism, vision) + self.sets += main_set + self._main_set_name = main_set.name if initial_positions is not None: self.space.place_agents(self.sets, initial_positions) else: @@ -41,7 +44,8 @@ def __init__( def run_model(self, steps: int) -> list[int]: for _ in range(steps): - if len(list(self.sets.df.values())[0]) == 0: + # Stop if the main agent set is empty + if len(self.sets[self._main_set_name]) == 0: # type: ignore[index] return empty_cells = self.space.empty_cells full_cells = self.space.full_cells diff --git a/mesa_frames/abstract/agentset.py b/mesa_frames/abstract/agentset.py index a7da9097..ae5db2db 100644 --- a/mesa_frames/abstract/agentset.py +++ b/mesa_frames/abstract/agentset.py @@ -20,10 +20,12 @@ from abc import abstractmethod from collections.abc import Collection, Iterable, Iterator +from contextlib import suppress from typing import Any, Literal, Self, overload -from mesa_frames.abstract.agentsetregistry import AbstractAgentSetRegistry -from mesa_frames.abstract.mixin import DataFrameMixin +from numpy.random import Generator + +from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin from mesa_frames.types_ import ( AgentMask, BoolSeries, @@ -35,7 +37,7 @@ ) -class AbstractAgentSet(AbstractAgentSetRegistry, DataFrameMixin): +class AbstractAgentSet(CopyMixin, DataFrameMixin): """The AbstractAgentSet class is a container for agents of the same type. Parameters @@ -44,6 +46,7 @@ class AbstractAgentSet(AbstractAgentSetRegistry, DataFrameMixin): The model that the agent set belongs to. """ + _copy_only_reference: list[str] = ["_model"] _df: DataFrame # The agents in the AbstractAgentSet _mask: AgentMask # The underlying mask used for the active agents in the AbstractAgentSet. _model: ( @@ -75,7 +78,31 @@ def add( Returns ------- Self - A new AbstractAgentSetRegistry with the added agents. + A new AbstractAgentSet with the added agents. + """ + ... + + @overload + @abstractmethod + def contains(self, agents: int) -> bool: ... + + @overload + @abstractmethod + def contains(self, agents: IdsLike) -> BoolSeries: ... + + @abstractmethod + def contains(self, agents: IdsLike) -> bool | BoolSeries: + """Check if agents with the specified IDs are in the AgentSet. + + Parameters + ---------- + agents : IdsLike + The ID(s) to check for. + + Returns + ------- + bool | BoolSeries + True if the agent is in the AgentSet, False otherwise. """ ... @@ -94,65 +121,85 @@ def discard(self, agents: IdsLike | AgentMask, inplace: bool = True) -> Self: Self The updated AbstractAgentSet. """ - return super().discard(agents, inplace) + with suppress(KeyError, ValueError): + return self.remove(agents, inplace=inplace) + return self._get_obj(inplace) + + @abstractmethod + def remove(self, agents: IdsLike | AgentMask, inplace: bool = True) -> Self: + """Remove agents from this AbstractAgentSet. + + Parameters + ---------- + agents : IdsLike | AgentMask + The agents or mask to remove. + inplace : bool, optional + Whether to remove in place, by default True. + + Returns + ------- + Self + The updated agent set. + """ + ... @overload + @abstractmethod def do( self, method_name: str, - *args, + *args: Any, mask: AgentMask | None = None, return_results: Literal[False] = False, inplace: bool = True, - **kwargs, + **kwargs: Any, ) -> Self: ... @overload + @abstractmethod def do( self, method_name: str, - *args, + *args: Any, mask: AgentMask | None = None, return_results: Literal[True], inplace: bool = True, - **kwargs, + **kwargs: Any, ) -> Any: ... + @abstractmethod def do( self, method_name: str, - *args, + *args: Any, mask: AgentMask | None = None, return_results: bool = False, inplace: bool = True, - **kwargs, + **kwargs: Any, ) -> Self | Any: - masked_df = self._get_masked_df(mask) - # If the mask is empty, we can use the object as is - if len(masked_df) == len(self._df): - obj = self._get_obj(inplace) - method = getattr(obj, method_name) - result = method(*args, **kwargs) - else: # If the mask is not empty, we need to create a new masked AbstractAgentSet and concatenate the AbstractAgentSets at the end - obj = self._get_obj(inplace=False) - obj._df = masked_df - original_masked_index = obj._get_obj_copy(obj.index) - method = getattr(obj, method_name) - result = method(*args, **kwargs) - obj._concatenate_agentsets( - [self], - duplicates_allowed=True, - keep_first_only=True, - original_masked_index=original_masked_index, - ) - if inplace: - for key, value in obj.__dict__.items(): - setattr(self, key, value) - obj = self - if return_results: - return result - else: - return obj + """Invoke a method on the AgentSet. + + Parameters + ---------- + method_name : str + The name of the method to invoke. + *args : Any + Positional arguments to pass to the method + mask : AgentMask | None, optional + The subset of agents on which to apply the method + return_results : bool, optional + Whether to return the result of the method, by default False + inplace : bool, optional + Whether the operation should be done inplace, by default False + **kwargs : Any + Keyword arguments to pass to the method + + Returns + ------- + Self | Any + The updated AgentSet or the result of the method. + """ + ... @abstractmethod @overload @@ -182,20 +229,6 @@ def step(self) -> None: """Run a single step of the AbstractAgentSet. This method should be overridden by subclasses.""" ... - def remove(self, agents: IdsLike | AgentMask, inplace: bool = True) -> Self: - if isinstance(agents, str) and agents == "active": - agents = self.active_agents - if agents is None or (isinstance(agents, Iterable) and len(agents) == 0): - return self._get_obj(inplace) - agents = self._df_index(self._get_masked_df(agents), "unique_id") - sets = self.model.sets.remove(agents, inplace=inplace) - # TODO: Refactor AgentSetRegistry to return dict[str, AbstractAgentSet] instead of dict[AbstractAgentSet, DataFrame] - # And assign a name to AbstractAgentSet? This has to be replaced by a nicer API of AgentSetRegistry - for agentset in sets.df.keys(): - if isinstance(agentset, self.__class__): - return agentset - return self - @abstractmethod def _concatenate_agentsets( self, @@ -285,9 +318,9 @@ def __add__(self, other: DataFrame | DataFrameInput) -> Self: Returns ------- Self - A new AbstractAgentSetRegistry with the added agents. + A new AbstractAgentSet with the added agents. """ - return super().__add__(other) + return self.add(other, inplace=False) def __iadd__(self, other: DataFrame | DataFrameInput) -> Self: """ @@ -305,9 +338,17 @@ def __iadd__(self, other: DataFrame | DataFrameInput) -> Self: Returns ------- Self - The updated AbstractAgentSetRegistry. + The updated AbstractAgentSet. """ - return super().__iadd__(other) + return self.add(other, inplace=True) + + def __isub__(self, other: IdsLike | AgentMask | DataFrame) -> Self: + """Remove agents via -= operator.""" + return self.discard(other, inplace=True) + + def __sub__(self, other: IdsLike | AgentMask | DataFrame) -> Self: + """Return a new set with agents removed via - operator.""" + return self.discard(other, inplace=False) @abstractmethod def __getattr__(self, name: str) -> Any: @@ -336,9 +377,20 @@ def __getitem__( | tuple[AgentMask, Collection[str]] ), ) -> Series | DataFrame: - attr = super().__getitem__(key) - assert isinstance(attr, (Series, DataFrame, Index)) - return attr + # Mirror registry/old container behavior: delegate to get() + if isinstance(key, tuple): + return self.get(mask=key[0], attr_names=key[1]) + else: + if isinstance(key, str) or ( + isinstance(key, Collection) and all(isinstance(k, str) for k in key) + ): + return self.get(attr_names=key) + else: + return self.get(mask=key) + + def __contains__(self, agents: int) -> bool: + """Membership test for an agent id in this set.""" + return bool(self.contains(agents)) def __len__(self) -> int: return len(self._df) @@ -376,6 +428,7 @@ def active_agents(self) -> DataFrame: ... def inactive_agents(self) -> DataFrame: ... @property + @abstractmethod def index(self) -> Index: ... @property @@ -391,3 +444,101 @@ def pos(self) -> DataFrame: pos, self.index, new_index_cols="unique_id", original_index_cols="agent_id" ) return pos + + @property + def name(self) -> str: + """The name of the agent set. + + Returns + ------- + str + The name of the agent set + """ + return self._name + + @property + def model(self) -> mesa_frames.concrete.model.Model: + return self._model + + @property + def random(self) -> Generator: + return self.model.random + + @property + def space(self) -> mesa_frames.abstract.space.Space | None: + return self.model.space + + @abstractmethod + def rename(self, new_name: str, inplace: bool = True) -> Self: + """Rename this AgentSet. + + Concrete subclasses must implement the mechanics for coordinating with + any containing registry and managing ``inplace`` semantics. The method + should update the set's name (or return a renamed copy when + ``inplace=False``) while preserving registry invariants. + + Parameters + ---------- + new_name : str + Desired new name for this AgentSet. + inplace : bool, optional + Whether to perform the rename in place. If False, a renamed copy is + returned, by default True. + + Returns + ------- + Self + The updated AgentSet (or a renamed copy when ``inplace=False``). + """ + ... + + @abstractmethod + def set( + self, + attr_names: str | Collection[str] | dict[str, Any] | None = None, + values: Any | None = None, + mask: AgentMask | None = None, + inplace: bool = True, + ) -> Self: + """Update agent attributes, optionally on a masked subset. + + Parameters + ---------- + attr_names : str | Collection[str] | dict[str, Any] | None, optional + Attribute(s) to assign. When ``None``, concrete implementations may + derive targets from ``values``. + values : Any | None, optional + Replacement value(s) aligned with ``attr_names``. + mask : AgentMask | None, optional + Subset selector limiting which agents are updated. + inplace : bool, optional + Whether to mutate in place or return an updated copy, by default True. + + Returns + ------- + Self + The updated AgentSet (or a modified copy when ``inplace=False``). + """ + ... + + def __setitem__( + self, + key: str + | Collection[str] + | AgentMask + | tuple[AgentMask, str | Collection[str]], + values: Any, + ) -> None: + """Set values using [] syntax, delegating to set().""" + if isinstance(key, tuple): + self.set(mask=key[0], attr_names=key[1], values=values) + else: + if isinstance(key, str) or ( + isinstance(key, Collection) and all(isinstance(k, str) for k in key) + ): + try: + self.set(attr_names=key, values=values) + except KeyError: # key may actually be a mask + self.set(attr_names=None, mask=key, values=values) + else: + self.set(attr_names=None, mask=key, values=values) diff --git a/mesa_frames/abstract/agentsetregistry.py b/mesa_frames/abstract/agentsetregistry.py index c8fa6c60..abb0ef69 100644 --- a/mesa_frames/abstract/agentsetregistry.py +++ b/mesa_frames/abstract/agentsetregistry.py @@ -43,7 +43,7 @@ def __init__(self, model): from __future__ import annotations # PEP 563: postponed evaluation of type annotations from abc import abstractmethod -from collections.abc import Callable, Collection, Iterator, Sequence +from collections.abc import Callable, Collection, Iterable, Iterator, Sequence from contextlib import suppress from typing import Any, Literal, Self, overload @@ -51,12 +51,12 @@ def __init__(self, model): from mesa_frames.abstract.mixin import CopyMixin from mesa_frames.types_ import ( - AgentMask, + AbstractAgentSetSelector as AgentSetSelector, +) +from mesa_frames.types_ import ( BoolSeries, - DataFrame, - DataFrameInput, - IdsLike, Index, + KeyBy, Series, ) @@ -74,20 +74,17 @@ def __init__(self) -> None: ... def discard( self, - agents: IdsLike - | AgentMask - | mesa_frames.abstract.agentset.AbstractAgentSet - | Collection[mesa_frames.abstract.agentset.AbstractAgentSet], + sets: AgentSetSelector, inplace: bool = True, ) -> Self: - """Remove agents from the AbstractAgentSetRegistry. Does not raise an error if the agent is not found. + """Remove AgentSets selected by ``sets``. Ignores missing. Parameters ---------- - agents : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] - The agents to remove + sets : AgentSetSelector + Which AgentSets to remove (instance, type, name, or collection thereof). inplace : bool - Whether to remove the agent in place. Defaults to True. + Whether to remove in place. Defaults to True. Returns ------- @@ -95,26 +92,70 @@ def discard( The updated AbstractAgentSetRegistry. """ with suppress(KeyError, ValueError): - return self.remove(agents, inplace=inplace) + return self.remove(sets, inplace=inplace) return self._get_obj(inplace) + @abstractmethod + def rename( + self, + target: ( + mesa_frames.abstract.agentset.AbstractAgentSet + | str + | dict[mesa_frames.abstract.agentset.AbstractAgentSet | str, str] + | list[tuple[mesa_frames.abstract.agentset.AbstractAgentSet | str, str]] + ), + new_name: str | None = None, + *, + on_conflict: Literal["canonicalize", "raise"] = "canonicalize", + mode: Literal["atomic", "best_effort"] = "atomic", + inplace: bool = True, + ) -> Self: + """Rename AgentSets in this registry, handling conflicts. + + Parameters + ---------- + target : mesa_frames.abstract.agentset.AbstractAgentSet | str | dict[mesa_frames.abstract.agentset.AbstractAgentSet | str, str] | list[tuple[mesa_frames.abstract.agentset.AbstractAgentSet | str, str]] + Single target (instance or existing name) with ``new_name`` provided, + or a mapping/sequence of (target, new_name) pairs for batch rename. + new_name : str | None + New name for single-target rename. + on_conflict : Literal["canonicalize", "raise"] + When a desired name collides, either canonicalize by appending a + numeric suffix (default) or raise ``ValueError``. + mode : Literal["atomic", "best_effort"] + In "atomic" mode, validate all renames before applying any. In + "best_effort" mode, apply what can be applied and skip failures. + + Returns + ------- + Self + Updated registry (or a renamed copy when ``inplace=False``). + + Parameters + ---------- + inplace : bool, optional + Whether to perform the rename in place. If False, a renamed copy is + returned, by default True. + """ + ... + @abstractmethod def add( self, - agents: DataFrame - | DataFrameInput - | mesa_frames.abstract.agentset.AbstractAgentSet - | Collection[mesa_frames.abstract.agentset.AbstractAgentSet], + sets: ( + mesa_frames.abstract.agentset.AbstractAgentSet + | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + ), inplace: bool = True, ) -> Self: - """Add agents to the AbstractAgentSetRegistry. + """Add AgentSets to the AbstractAgentSetRegistry. Parameters ---------- - agents : DataFrame | DataFrameInput | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] - The agents to add. + sets : mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + The AgentSet(s) to add. inplace : bool - Whether to add the agents in place. Defaults to True. + Whether to add in place. Defaults to True. Returns ------- @@ -125,29 +166,40 @@ def add( @overload @abstractmethod - def contains(self, agents: int) -> bool: ... + def contains( + self, + sets: ( + mesa_frames.abstract.agentset.AbstractAgentSet + | type[mesa_frames.abstract.agentset.AbstractAgentSet] + | str + ), + ) -> bool: ... @overload @abstractmethod def contains( - self, agents: mesa_frames.abstract.agentset.AbstractAgentSet | IdsLike + self, + sets: Collection[ + mesa_frames.abstract.agentset.AbstractAgentSet + | type[mesa_frames.abstract.agentset.AbstractAgentSet] + | str + ], ) -> BoolSeries: ... @abstractmethod - def contains( - self, agents: mesa_frames.abstract.agentset.AbstractAgentSet | IdsLike - ) -> bool | BoolSeries: - """Check if agents with the specified IDs are in the AbstractAgentSetRegistry. + def contains(self, sets: AgentSetSelector) -> bool | BoolSeries: + """Check if selected AgentSets are present in the registry. Parameters ---------- - agents : mesa_frames.abstract.agentset.AbstractAgentSet | IdsLike - The ID(s) to check for. + sets : AgentSetSelector + An AgentSet instance, class/type, name string, or a collection of + those. For collections, returns a BoolSeries aligned with input order. Returns ------- bool | BoolSeries - True if the agent is in the AbstractAgentSetRegistry, False otherwise. + Boolean for single selector values; BoolSeries for collections. """ @overload @@ -156,9 +208,10 @@ def do( self, method_name: str, *args: Any, - mask: AgentMask | None = None, + sets: AgentSetSelector | None = None, return_results: Literal[False] = False, inplace: bool = True, + key_by: KeyBy = "name", **kwargs: Any, ) -> Self: ... @@ -168,22 +221,35 @@ def do( self, method_name: str, *args: Any, - mask: AgentMask | None = None, + sets: AgentSetSelector, return_results: Literal[True], inplace: bool = True, + key_by: KeyBy = "name", **kwargs: Any, - ) -> Any | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Any]: ... + ) -> ( + Any + | dict[str, Any] + | dict[int, Any] + | dict[type[mesa_frames.abstract.agentset.AbstractAgentSet], Any] + ): ... @abstractmethod def do( self, method_name: str, *args: Any, - mask: AgentMask | None = None, + sets: AgentSetSelector = None, return_results: bool = False, inplace: bool = True, + key_by: KeyBy = "name", **kwargs: Any, - ) -> Self | Any | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Any]: + ) -> ( + Self + | Any + | dict[str, Any] + | dict[int, Any] + | dict[type[mesa_frames.abstract.agentset.AbstractAgentSet], Any] + ): """Invoke a method on the AbstractAgentSetRegistry. Parameters @@ -192,71 +258,88 @@ def do( The name of the method to invoke. *args : Any Positional arguments to pass to the method - mask : AgentMask | None, optional - The subset of agents on which to apply the method + sets : AgentSetSelector, optional + Which AgentSets to target (instance, type, name, or collection thereof). Defaults to all. return_results : bool, optional - Whether to return the result of the method, by default False + Whether to return per-set results as a dictionary, by default False. inplace : bool, optional - Whether the operation should be done inplace, by default False + Whether the operation should be done inplace, by default True + key_by : KeyBy, optional + Key domain for the returned mapping when ``return_results`` is True. + - "name" (default) → keys are set names (str) + - "index" → keys are positional indices (int) + - "type" → keys are concrete set classes (type) **kwargs : Any Keyword arguments to pass to the method Returns ------- - Self | Any | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Any] - The updated AbstractAgentSetRegistry or the result of the method. + Self | Any | dict[str, Any] | dict[int, Any] | dict[type[mesa_frames.abstract.agentset.AbstractAgentSet], Any] + The updated registry, or the method result(s). When ``return_results`` + is True, returns a dictionary keyed per ``key_by``. """ ... - @abstractmethod @overload - def get(self, attr_names: str) -> Series | dict[str, Series]: ... - @abstractmethod + def get( + self, key: int, default: None = ... + ) -> mesa_frames.abstract.agentset.AbstractAgentSet | None: ... + @overload + @abstractmethod def get( - self, attr_names: Collection[str] | None = None - ) -> DataFrame | dict[str, DataFrame]: ... + self, key: str, default: None = ... + ) -> mesa_frames.abstract.agentset.AbstractAgentSet | None: ... + @overload @abstractmethod def get( self, - attr_names: str | Collection[str] | None = None, - mask: AgentMask | None = None, - ) -> Series | dict[str, Series] | DataFrame | dict[str, DataFrame]: - """Retrieve the value of a specified attribute for each agent in the AbstractAgentSetRegistry. + key: type[mesa_frames.abstract.agentset.AbstractAgentSet], + default: None = ..., + ) -> list[mesa_frames.abstract.agentset.AbstractAgentSet]: ... - Parameters - ---------- - attr_names : str | Collection[str] | None, optional - The attributes to retrieve. If None, all attributes are retrieved. Defaults to None. - mask : AgentMask | None, optional - The AgentMask of agents to retrieve the attribute for. If None, attributes of all agents are returned. Defaults to None. + @overload + @abstractmethod + def get( + self, + key: int | str | type[mesa_frames.abstract.agentset.AbstractAgentSet], + default: mesa_frames.abstract.agentset.AbstractAgentSet + | list[mesa_frames.abstract.agentset.AbstractAgentSet] + | None, + ) -> ( + mesa_frames.abstract.agentset.AbstractAgentSet + | list[mesa_frames.abstract.agentset.AbstractAgentSet] + | None + ): ... - Returns - ------- - Series | dict[str, Series] | DataFrame | dict[str, DataFrame] - The attribute values. - """ - ... + @abstractmethod + def get( + self, + key: int | str | type[mesa_frames.abstract.agentset.AbstractAgentSet], + default: mesa_frames.abstract.agentset.AbstractAgentSet + | list[mesa_frames.abstract.agentset.AbstractAgentSet] + | None = None, + ) -> ( + mesa_frames.abstract.agentset.AbstractAgentSet + | list[mesa_frames.abstract.agentset.AbstractAgentSet] + | None + ): + """Safe lookup for AgentSet(s) by index, name, or type.""" @abstractmethod def remove( self, - agents: ( - IdsLike - | AgentMask - | mesa_frames.abstract.agentset.AbstractAgentSet - | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] - ), + sets: AgentSetSelector, inplace: bool = True, ) -> Self: - """Remove the agents from the AbstractAgentSetRegistry. + """Remove AgentSets from the AbstractAgentSetRegistry. Parameters ---------- - agents : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] - The agents to remove. + sets : AgentSetSelector + Which AgentSets to remove (instance, type, name, or collection thereof). inplace : bool, optional Whether to remove the agent in place. @@ -267,96 +350,46 @@ def remove( """ ... - @abstractmethod - def select( - self, - mask: AgentMask | None = None, - filter_func: Callable[[Self], AgentMask] | None = None, - n: int | None = None, - negate: bool = False, - inplace: bool = True, - ) -> Self: - """Select agents in the AbstractAgentSetRegistry based on the given criteria. - - Parameters - ---------- - mask : AgentMask | None, optional - The AgentMask of agents to be selected, by default None - filter_func : Callable[[Self], AgentMask] | None, optional - A function which takes as input the AbstractAgentSetRegistry and returns a AgentMask, by default None - n : int | None, optional - The maximum number of agents to be selected, by default None - negate : bool, optional - If the selection should be negated, by default False - inplace : bool, optional - If the operation should be performed on the same object, by default True - - Returns - ------- - Self - A new or updated AbstractAgentSetRegistry. - """ - ... - - @abstractmethod - @overload - def set( - self, - attr_names: dict[str, Any], - values: None, - mask: AgentMask | None = None, - inplace: bool = True, - ) -> Self: ... + # select() intentionally removed from the abstract API. @abstractmethod - @overload - def set( + def replace( self, - attr_names: str | Collection[str], - values: Any, - mask: AgentMask | None = None, - inplace: bool = True, - ) -> Self: ... - - @abstractmethod - def set( - self, - attr_names: DataFrameInput | str | Collection[str], - values: Any | None = None, - mask: AgentMask | None = None, + mapping: ( + dict[int | str, mesa_frames.abstract.agentset.AbstractAgentSet] + | list[tuple[int | str, mesa_frames.abstract.agentset.AbstractAgentSet]] + ), + *, inplace: bool = True, + atomic: bool = True, ) -> Self: - """Set the value of a specified attribute or attributes for each agent in the mask in AbstractAgentSetRegistry. + """Batch assign/replace AgentSets by index or name. Parameters ---------- - attr_names : DataFrameInput | str | Collection[str] - The key can be: - - A string: sets the specified column of the agents in the AbstractAgentSetRegistry. - - A collection of strings: sets the specified columns of the agents in the AbstractAgentSetRegistry. - - A dictionary: keys should be attributes and values should be the values to set. Value should be None. - values : Any | None - The value to set the attribute to. If None, attr_names must be a dictionary. - mask : AgentMask | None - The AgentMask of agents to set the attribute for. - inplace : bool - Whether to set the attribute in place. + mapping : dict[int | str, mesa_frames.abstract.agentset.AbstractAgentSet] | list[tuple[int | str, mesa_frames.abstract.agentset.AbstractAgentSet]] + Keys are indices or names to assign; values are AgentSets bound to the same model. + inplace : bool, optional + Whether to apply on this registry or return a copy, by default True. + atomic : bool, optional + When True, validates all keys and name invariants before applying any + change; either all assignments succeed or none are applied. Returns ------- Self - The updated agent set. + Updated registry. """ ... @abstractmethod def shuffle(self, inplace: bool = False) -> Self: - """Shuffles the order of agents in the AbstractAgentSetRegistry. + """Shuffle the order of AgentSets in the registry. Parameters ---------- inplace : bool - Whether to shuffle the agents in place. + Whether to shuffle in place. Returns ------- @@ -373,7 +406,7 @@ def sort( **kwargs, ) -> Self: """ - Sorts the agents in the agent set based on the given criteria. + Sort the AgentSets in the registry based on the given criteria. Parameters ---------- @@ -394,145 +427,75 @@ def sort( def __add__( self, - other: DataFrame - | DataFrameInput - | mesa_frames.abstract.agentset.AbstractAgentSet + other: mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet], ) -> Self: - """Add agents to a new AbstractAgentSetRegistry through the + operator. - - Parameters - ---------- - other : DataFrame | DataFrameInput | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] - The agents to add. - - Returns - ------- - Self - A new AbstractAgentSetRegistry with the added agents. - """ - return self.add(agents=other, inplace=False) + """Add AgentSets to a new AbstractAgentSetRegistry through the + operator.""" + return self.add(sets=other, inplace=False) def __contains__( - self, agents: int | mesa_frames.abstract.agentset.AbstractAgentSet + self, sets: mesa_frames.abstract.agentset.AbstractAgentSet ) -> bool: - """Check if an agent is in the AbstractAgentSetRegistry. - - Parameters - ---------- - agents : int | mesa_frames.abstract.agentset.AbstractAgentSet - The ID(s) or AbstractAgentSet to check for. + """Check if an AgentSet is in the AbstractAgentSetRegistry.""" + return bool(self.contains(sets=sets)) - Returns - ------- - bool - True if the agent is in the AbstractAgentSetRegistry, False otherwise. - """ - return self.contains(agents=agents) + @overload + def __getitem__( + self, key: int + ) -> mesa_frames.abstract.agentset.AbstractAgentSet: ... @overload def __getitem__( - self, key: str | tuple[AgentMask, str] - ) -> Series | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Series]: ... + self, key: str + ) -> mesa_frames.abstract.agentset.AbstractAgentSet: ... @overload def __getitem__( - self, - key: AgentMask | Collection[str] | tuple[AgentMask, Collection[str]], - ) -> ( - DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame] - ): ... + self, key: type[mesa_frames.abstract.agentset.AbstractAgentSet] + ) -> list[mesa_frames.abstract.agentset.AbstractAgentSet]: ... def __getitem__( - self, - key: ( - str - | Collection[str] - | AgentMask - | tuple[AgentMask, str] - | tuple[AgentMask, Collection[str]] - | tuple[ - dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str - ] - | tuple[ - dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], - Collection[str], - ] - ), + self, key: int | str | type[mesa_frames.abstract.agentset.AbstractAgentSet] ) -> ( - Series - | DataFrame - | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Series] - | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame] + mesa_frames.abstract.agentset.AbstractAgentSet + | list[mesa_frames.abstract.agentset.AbstractAgentSet] ): - """Implement the [] operator for the AbstractAgentSetRegistry. - - The key can be: - - An attribute or collection of attributes (eg. AbstractAgentSetRegistry["str"], AbstractAgentSetRegistry[["str1", "str2"]]): returns the specified column(s) of the agents in the AbstractAgentSetRegistry. - - An AgentMask (eg. AbstractAgentSetRegistry[AgentMask]): returns the agents in the AbstractAgentSetRegistry that satisfy the AgentMask. - - A tuple (eg. AbstractAgentSetRegistry[AgentMask, "str"]): returns the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask. - - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, "str"]): returns the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary. - - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, Collection[str]]): returns the specified columns of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary. - - Parameters - ---------- - key : str | Collection[str] | AgentMask | tuple[AgentMask, str] | tuple[AgentMask, Collection[str]] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], Collection[str]] - The key to retrieve. - - Returns - ------- - Series | DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Series] | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame] - The attribute values. - """ - # TODO: fix types - if isinstance(key, tuple): - return self.get(mask=key[0], attr_names=key[1]) - else: - if isinstance(key, str) or ( - isinstance(key, Collection) and all(isinstance(k, str) for k in key) - ): - return self.get(attr_names=key) - else: - return self.get(mask=key) + """Retrieve AgentSet(s) by index, name, or type.""" def __iadd__( self, other: ( - DataFrame - | DataFrameInput - | mesa_frames.abstract.agentset.AbstractAgentSet + mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] ), ) -> Self: - """Add agents to the AbstractAgentSetRegistry through the += operator. + """Add AgentSets to the registry through the += operator. Parameters ---------- - other : DataFrame | DataFrameInput | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] - The agents to add. + other : mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + The AgentSets to add. Returns ------- Self The updated AbstractAgentSetRegistry. """ - return self.add(agents=other, inplace=True) + return self.add(sets=other, inplace=True) def __isub__( self, other: ( - IdsLike - | AgentMask - | mesa_frames.abstract.agentset.AbstractAgentSet + mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] ), ) -> Self: - """Remove agents from the AbstractAgentSetRegistry through the -= operator. + """Remove AgentSets from the registry through the -= operator. Parameters ---------- - other : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] - The agents to remove. + other : mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + The AgentSets to remove. Returns ------- @@ -544,144 +507,130 @@ def __isub__( def __sub__( self, other: ( - IdsLike - | AgentMask - | mesa_frames.abstract.agentset.AbstractAgentSet + mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] ), ) -> Self: - """Remove agents from a new AbstractAgentSetRegistry through the - operator. + """Remove AgentSets from a new registry through the - operator. Parameters ---------- - other : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] - The agents to remove. + other : mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + The AgentSets to remove. Returns ------- Self - A new AbstractAgentSetRegistry with the removed agents. + A new AbstractAgentSetRegistry with the removed AgentSets. """ return self.discard(other, inplace=False) def __setitem__( self, - key: ( - str - | Collection[str] - | AgentMask - | tuple[AgentMask, str | Collection[str]] - | tuple[ - dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str - ] - | tuple[ - dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], - Collection[str], - ] - ), - values: Any, + key: int | str, + value: mesa_frames.abstract.agentset.AbstractAgentSet, ) -> None: - """Implement the [] operator for setting values in the AbstractAgentSetRegistry. - - The key can be: - - A string (eg. AbstractAgentSetRegistry["str"]): sets the specified column of the agents in the AbstractAgentSetRegistry. - - A list of strings(eg. AbstractAgentSetRegistry[["str1", "str2"]]): sets the specified columns of the agents in the AbstractAgentSetRegistry. - - A tuple (eg. AbstractAgentSetRegistry[AgentMask, "str"]): sets the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask. - - A AgentMask (eg. AbstractAgentSetRegistry[AgentMask]): sets the attributes of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask. - - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, "str"]): sets the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary. - - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, Collection[str]]): sets the specified columns of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary. - - Parameters - ---------- - key : str | Collection[str] | AgentMask | tuple[AgentMask, str | Collection[str]] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], Collection[str]] - The key to set. - values : Any - The values to set for the specified key. - """ - # TODO: fix types as in __getitem__ - if isinstance(key, tuple): - self.set(mask=key[0], attr_names=key[1], values=values) - else: - if isinstance(key, str) or ( - isinstance(key, Collection) and all(isinstance(k, str) for k in key) - ): - try: - self.set(attr_names=key, values=values) - except KeyError: # key=AgentMask - self.set(attr_names=None, mask=key, values=values) - else: - self.set(attr_names=None, mask=key, values=values) + """Assign/replace a single AgentSet at an index or name. + + Mirrors the invariants of ``replace`` for single-key assignment: + - Names remain unique across the registry + - ``value.model is self.model`` + - For name keys, the key is authoritative for the assigned set's name + - For index keys, collisions on a different entry's name must raise + """ + if value.model is not self.model: + raise TypeError("Assigned AgentSet must belong to the same model") + if isinstance(key, int): + # Delegate to replace() so subclasses centralize invariant handling. + self.replace({key: value}, inplace=True, atomic=True) + return + if isinstance(key, str): + for existing in self: + if existing.name == key: + self.replace({key: value}, inplace=True, atomic=True) + return + try: + value.rename(key, inplace=True) + except Exception: + if hasattr(value, "_name"): + value._name = key # type: ignore[attr-defined] + self.add(value, inplace=True) + return + raise TypeError("Key must be int index or str name") @abstractmethod def __getattr__(self, name: str) -> Any | dict[str, Any]: - """Fallback for retrieving attributes of the AbstractAgentSetRegistry. Retrieve an attribute of the underlying DataFrame(s). - - Parameters - ---------- - name : str - The name of the attribute to retrieve. - - Returns - ------- - Any | dict[str, Any] - The attribute value - """ + """Fallback for retrieving attributes of the AgentSetRegistry.""" @abstractmethod - def __iter__(self) -> Iterator[dict[str, Any]]: - """Iterate over the agents in the AbstractAgentSetRegistry. - - Returns - ------- - Iterator[dict[str, Any]] - An iterator over the agents. - """ + def __iter__(self) -> Iterator[mesa_frames.abstract.agentset.AbstractAgentSet]: + """Iterate over AgentSets in the registry.""" ... @abstractmethod def __len__(self) -> int: - """Get the number of agents in the AbstractAgentSetRegistry. - - Returns - ------- - int - The number of agents in the AbstractAgentSetRegistry. - """ + """Get the number of AgentSets in the registry.""" ... @abstractmethod def __repr__(self) -> str: - """Get a string representation of the DataFrame in the AbstractAgentSetRegistry. - - Returns - ------- - str - A string representation of the DataFrame in the AbstractAgentSetRegistry. - """ + """Get a string representation of the AgentSets in the registry.""" pass @abstractmethod def __reversed__(self) -> Iterator: - """Iterate over the agents in the AbstractAgentSetRegistry in reverse order. - - Returns - ------- - Iterator - An iterator over the agents in reverse order. - """ + """Iterate over AgentSets in reverse order.""" ... @abstractmethod def __str__(self) -> str: - """Get a string representation of the agents in the AbstractAgentSetRegistry. - - Returns - ------- - str - A string representation of the agents in the AbstractAgentSetRegistry. - """ + """Get a string representation of the AgentSets in the registry.""" ... + def keys( + self, *, key_by: KeyBy = "name" + ) -> Iterable[str | int | type[mesa_frames.abstract.agentset.AbstractAgentSet]]: + """Iterate keys for contained AgentSets (by name|index|type).""" + if key_by == "index": + yield from range(len(self)) + return + if key_by == "type": + for agentset in self: + yield type(agentset) + return + if key_by != "name": + raise ValueError("key_by must be 'name'|'index'|'type'") + for agentset in self: + if agentset.name is not None: + yield agentset.name + + def items( + self, *, key_by: KeyBy = "name" + ) -> Iterable[ + tuple[ + str | int | type[mesa_frames.abstract.agentset.AbstractAgentSet], + mesa_frames.abstract.agentset.AbstractAgentSet, + ] + ]: + """Iterate (key, AgentSet) pairs for contained sets.""" + if key_by == "index": + for idx, agentset in enumerate(self): + yield idx, agentset + return + if key_by == "type": + for agentset in self: + yield type(agentset), agentset + return + if key_by != "name": + raise ValueError("key_by must be 'name'|'index'|'type'") + for agentset in self: + if agentset.name is not None: + yield agentset.name, agentset + + def values(self) -> Iterable[mesa_frames.abstract.agentset.AbstractAgentSet]: + """Iterate contained AgentSets (values view).""" + yield from self + @property def model(self) -> mesa_frames.concrete.model.Model: """The model that the AbstractAgentSetRegistry belongs to. @@ -714,85 +663,12 @@ def space(self) -> mesa_frames.abstract.space.Space | None: @property @abstractmethod - def df(self) -> DataFrame | dict[str, DataFrame]: - """The agents in the AbstractAgentSetRegistry. - - Returns - ------- - DataFrame | dict[str, DataFrame] - """ - - @df.setter - @abstractmethod - def df( - self, agents: DataFrame | list[mesa_frames.abstract.agentset.AbstractAgentSet] - ) -> None: - """Set the agents in the AbstractAgentSetRegistry. - - Parameters - ---------- - agents : DataFrame | list[mesa_frames.abstract.agentset.AbstractAgentSet] - """ - - @property - @abstractmethod - def active_agents(self) -> DataFrame | dict[str, DataFrame]: - """The active agents in the AbstractAgentSetRegistry. - - Returns - ------- - DataFrame | dict[str, DataFrame] - """ - - @active_agents.setter - @abstractmethod - def active_agents( - self, - mask: AgentMask, - ) -> None: - """Set the active agents in the AbstractAgentSetRegistry. - - Parameters - ---------- - mask : AgentMask - The mask to apply. - """ - self.select(mask=mask, inplace=True) - - @property - @abstractmethod - def inactive_agents( - self, - ) -> DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame]: - """The inactive agents in the AbstractAgentSetRegistry. + def ids(self) -> Series: + """Public view of all agent unique_id values across contained sets. Returns ------- - DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame] - """ - - @property - @abstractmethod - def index( - self, - ) -> Index | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Index]: - """The ids in the AbstractAgentSetRegistry. - - Returns - ------- - Index | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Index] - """ - ... - - @property - @abstractmethod - def pos( - self, - ) -> DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame]: - """The position of the agents in the AbstractAgentSetRegistry. - - Returns - ------- - DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame] + Series + Concatenated unique_id Series for all AgentSets. """ ... diff --git a/mesa_frames/abstract/datacollector.py b/mesa_frames/abstract/datacollector.py index edbfb11f..6505408f 100644 --- a/mesa_frames/abstract/datacollector.py +++ b/mesa_frames/abstract/datacollector.py @@ -91,7 +91,11 @@ def __init__( model_reporters : dict[str, Callable] | None Functions to collect data at the model level. agent_reporters : dict[str, str | Callable] | None - Attributes or functions to collect data at the agent level. + Agent-level reporters. Values may be: + - str or list[str]: pull existing columns from each set; columns are suffixed per-set. + - Callable[[AbstractAgentSetRegistry], Series | DataFrame | dict[str, Series|DataFrame]]: registry-level, runs once per step. + - Callable[[mesa_frames.abstract.agentset.AbstractAgentSet], Series | DataFrame]: set-level, runs once per set. + Note: model-level callables are not supported for agent reporters. trigger : Callable[[Any], bool] | None A function(model) -> bool that determines whether to collect data. reset_memory : bool diff --git a/mesa_frames/abstract/space.py b/mesa_frames/abstract/space.py index 74df16e8..39abe6bd 100644 --- a/mesa_frames/abstract/space.py +++ b/mesa_frames/abstract/space.py @@ -7,14 +7,14 @@ performance and scalability. Classes: - SpaceDF(CopyMixin, DataFrameMixin): + Space(CopyMixin, DataFrameMixin): An abstract base class that defines the common interface for all space classes in mesa-frames. It combines fast copying functionality with DataFrame operations. - AbstractDiscreteSpace(SpaceDF): + AbstractDiscreteSpace(Space): An abstract base class for discrete space implementations, such as grids - and networks. It extends SpaceDF with methods specific to discrete spaces. + and networks. It extends Space with methods specific to discrete spaces. AbstractGrid(AbstractDiscreteSpace): An abstract base class for grid-based spaces. It inherits from @@ -52,7 +52,7 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type): from abc import abstractmethod from collections.abc import Callable, Collection, Sequence, Sized from itertools import product -from typing import Any, Literal, Self +from typing import Any, Literal, Self, cast from warnings import warn import numpy as np @@ -64,7 +64,6 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type): AbstractAgentSetRegistry, ) from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin -from mesa_frames.concrete.agentsetregistry import AgentSetRegistry from mesa_frames.types_ import ( ArrayLike, BoolSeries, @@ -98,7 +97,7 @@ class Space(CopyMixin, DataFrameMixin): ] # The column names of the positions in the _agents dataframe (eg. ['dim_0', 'dim_1', ...] in Grids, ['node_id', 'edge_id'] in Networks) def __init__(self, model: mesa_frames.concrete.model.Model) -> None: - """Create a new SpaceDF. + """Create a new Space. Parameters ---------- @@ -109,7 +108,9 @@ def __init__(self, model: mesa_frames.concrete.model.Model) -> None: def move_agents( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], pos: SpaceCoordinate | SpaceCoordinates, inplace: bool = True, @@ -120,7 +121,7 @@ def move_agents( Parameters ---------- - agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] + agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] The agents to move pos : SpaceCoordinate | SpaceCoordinates The coordinates for each agents. The length of the coordinates must match the number of agents. @@ -145,7 +146,9 @@ def move_agents( def place_agents( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], pos: SpaceCoordinate | SpaceCoordinates, inplace: bool = True, @@ -154,7 +157,7 @@ def place_agents( Parameters ---------- - agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] + agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] The agents to place in the space pos : SpaceCoordinate | SpaceCoordinates The coordinates for each agents. The length of the coordinates must match the number of agents. @@ -198,10 +201,14 @@ def random_agents( def swap_agents( self, agents0: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], agents1: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: @@ -211,9 +218,9 @@ def swap_agents( Parameters ---------- - agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] + agents0 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] The first set of agents to swap - agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] + agents1 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] The second set of agents to swap inplace : bool, optional Whether to perform the operation inplace, by default True @@ -222,26 +229,27 @@ def swap_agents( ------- Self """ - agents0 = self._get_ids_srs(agents0) - agents1 = self._get_ids_srs(agents1) + # Normalize inputs to Series of ids for validation and operations + ids0 = self._get_ids_srs(agents0) + ids1 = self._get_ids_srs(agents1) if __debug__: - if len(agents0) != len(agents1): + if len(ids0) != len(ids1): raise ValueError("The two sets of agents must have the same length") - if not self._df_contains(self._agents, "agent_id", agents0).all(): + if not self._df_contains(self._agents, "agent_id", ids0).all(): raise ValueError("Some agents in agents0 are not in the space") - if not self._df_contains(self._agents, "agent_id", agents1).all(): + if not self._df_contains(self._agents, "agent_id", ids1).all(): raise ValueError("Some agents in agents1 are not in the space") - if self._srs_contains(agents0, agents1).any(): + if self._srs_contains(ids0, ids1).any(): raise ValueError("Some agents are present in both agents0 and agents1") obj = self._get_obj(inplace) agents0_df = obj._df_get_masked_df( - obj._agents, index_cols="agent_id", mask=agents0 + obj._agents, index_cols="agent_id", mask=ids0 ) agents1_df = obj._df_get_masked_df( - obj._agents, index_cols="agent_id", mask=agents1 + obj._agents, index_cols="agent_id", mask=ids1 ) - agents0_df = obj._df_set_index(agents0_df, "agent_id", agents1) - agents1_df = obj._df_set_index(agents1_df, "agent_id", agents0) + agents0_df = obj._df_set_index(agents0_df, "agent_id", ids1) + agents1_df = obj._df_set_index(agents1_df, "agent_id", ids0) obj._agents = obj._df_combine_first( agents0_df, obj._agents, index_cols="agent_id" ) @@ -257,11 +265,15 @@ def get_directions( pos0: SpaceCoordinate | SpaceCoordinates | None = None, pos1: SpaceCoordinate | SpaceCoordinates | None = None, agents0: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None = None, agents1: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None = None, normalize: bool = False, @@ -278,9 +290,9 @@ def get_directions( The starting positions pos1 : SpaceCoordinate | SpaceCoordinates | None, optional The ending positions - agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional + agents0 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional The starting agents - agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional + agents1 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional The ending agents normalize : bool, optional Whether to normalize the vectors to unit norm. By default False @@ -298,11 +310,15 @@ def get_distances( pos0: SpaceCoordinate | SpaceCoordinates | None = None, pos1: SpaceCoordinate | SpaceCoordinates | None = None, agents0: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None = None, agents1: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None = None, ) -> DataFrame: @@ -318,9 +334,9 @@ def get_distances( The starting positions pos1 : SpaceCoordinate | SpaceCoordinates | None, optional The ending positions - agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional + agents0 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional The starting agents - agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional + agents1 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional The ending agents Returns @@ -336,7 +352,9 @@ def get_neighbors( radius: int | float | Sequence[int] | Sequence[float] | ArrayLike, pos: SpaceCoordinate | SpaceCoordinates | None = None, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None = None, include_center: bool = False, @@ -351,7 +369,7 @@ def get_neighbors( The radius(es) of the neighborhood pos : SpaceCoordinate | SpaceCoordinates | None, optional The coordinates of the cell to get the neighborhood from, by default None - agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional + agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional The id of the agents to get the neighborhood from, by default None include_center : bool, optional If the center cells or agents should be included in the result, by default False @@ -373,7 +391,9 @@ def get_neighbors( def move_to_empty( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: @@ -381,7 +401,7 @@ def move_to_empty( Parameters ---------- - agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] + agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] The agents to move to empty cells/positions inplace : bool, optional Whether to perform the operation inplace, by default True @@ -396,7 +416,9 @@ def move_to_empty( def place_to_empty( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: @@ -404,7 +426,7 @@ def place_to_empty( Parameters ---------- - agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] + agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] The agents to place in empty cells/positions inplace : bool, optional Whether to perform the operation inplace, by default True @@ -438,7 +460,9 @@ def random_pos( def remove_agents( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: @@ -448,7 +472,7 @@ def remove_agents( Parameters ---------- - agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] + agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] The agents to remove from the space inplace : bool, optional Whether to perform the operation inplace, by default True @@ -467,7 +491,9 @@ def remove_agents( def _get_ids_srs( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], ) -> Series: if isinstance(agents, Sized) and len(agents) == 0: @@ -478,10 +504,11 @@ def _get_ids_srs( name="agent_id", dtype="uint64", ) - elif isinstance(agents, AgentSetRegistry): - return self._srs_constructor(agents._ids, name="agent_id", dtype="uint64") + elif isinstance(agents, AbstractAgentSetRegistry): + return self._srs_constructor(agents.ids, name="agent_id", dtype="uint64") elif isinstance(agents, Collection) and ( - isinstance(agents[0], AbstractAgentSetRegistry) + isinstance(agents[0], AbstractAgentSet) + or isinstance(agents[0], AbstractAgentSetRegistry) ): ids = [] for a in agents: @@ -493,9 +520,9 @@ def _get_ids_srs( dtype="uint64", ) ) - elif isinstance(a, AgentSetRegistry): + elif isinstance(a, AbstractAgentSetRegistry): ids.append( - self._srs_constructor(a._ids, name="agent_id", dtype="uint64") + self._srs_constructor(a.ids, name="agent_id", dtype="uint64") ) return self._df_concat(ids, ignore_index=True) elif isinstance(agents, int): @@ -657,7 +684,9 @@ def move_to_empty( self, agents: IdsLike | AbstractAgentSetRegistry - | Collection[AbstractAgentSetRegistry], + | Collection[AbstractAgentSetRegistry] + | AbstractAgentSet + | Collection[AbstractAgentSet], inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) @@ -668,7 +697,9 @@ def move_to_empty( def move_to_available( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: @@ -676,7 +707,7 @@ def move_to_available( Parameters ---------- - agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] + agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] The agents to move to available cells/positions inplace : bool, optional Whether to perform the operation inplace, by default True @@ -686,6 +717,7 @@ def move_to_available( Self """ obj = self._get_obj(inplace) + return obj._place_or_move_agents_to_cells( agents, cell_type="available", is_move=True ) @@ -693,11 +725,14 @@ def move_to_available( def place_to_empty( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) + return obj._place_or_move_agents_to_cells( agents, cell_type="empty", is_move=False ) @@ -705,7 +740,9 @@ def place_to_empty( def place_to_available( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: @@ -823,7 +860,9 @@ def get_neighborhood( radius: int | float | Sequence[int] | Sequence[float] | ArrayLike, pos: DiscreteCoordinate | DiscreteCoordinates | None = None, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] = None, include_center: bool = False, ) -> DataFrame: @@ -837,7 +876,7 @@ def get_neighborhood( The radius(es) of the neighborhoods pos : DiscreteCoordinate | DiscreteCoordinates | None, optional The coordinates of the cell(s) to get the neighborhood from - agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry], optional + agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], optional The agent(s) to get the neighborhood from include_center : bool, optional If the cell in the center of the neighborhood should be included in the result, by default False @@ -933,7 +972,9 @@ def _check_cells( def _place_or_move_agents_to_cells( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], cell_type: Literal["any", "empty", "available"], is_move: bool, @@ -942,8 +983,8 @@ def _place_or_move_agents_to_cells( agents = self._get_ids_srs(agents) if __debug__: - # Check ids presence in model - b_contained = self.model.sets.contains(agents) + # Check ids presence in model using public API + b_contained = agents.is_in(self.model.sets.ids) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): @@ -994,7 +1035,7 @@ def _sample_cells( self, n: int | None, with_replacement: bool, - condition: Callable[[DiscreteSpaceCapacity], BoolSeries], + condition: Callable[[DiscreteSpaceCapacity], BoolSeries | np.ndarray], respect_capacity: bool = True, ) -> DataFrame: """Sample cells from the grid according to a condition on the capacity. @@ -1005,7 +1046,7 @@ def _sample_cells( The number of cells to sample. If None, samples the maximum available. with_replacement : bool If the sampling should be with replacement - condition : Callable[[DiscreteSpaceCapacity], BoolSeries] + condition : Callable[[DiscreteSpaceCapacity], BoolSeries | np.ndarray] The condition to apply on the capacity respect_capacity : bool, optional If the capacity should be respected in the sampling. @@ -1259,11 +1300,15 @@ def get_directions( pos0: GridCoordinate | GridCoordinates | None = None, pos1: GridCoordinate | GridCoordinates | None = None, agents0: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None = None, agents1: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None = None, normalize: bool = False, @@ -1278,11 +1323,15 @@ def get_distances( pos0: GridCoordinate | GridCoordinates | None = None, pos1: GridCoordinate | GridCoordinates | None = None, agents0: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None = None, agents1: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None = None, ) -> DataFrame: @@ -1311,7 +1360,7 @@ def get_neighbors( def get_neighborhood( self, radius: int | Sequence[int] | ArrayLike, - pos: GridCoordinate | GridCoordinates | None = None, + pos: DiscreteCoordinate | DiscreteCoordinates | None = None, agents: IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] @@ -1549,7 +1598,9 @@ def out_of_bounds(self, pos: GridCoordinate | GridCoordinates) -> DataFrame: def remove_agents( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: @@ -1558,8 +1609,8 @@ def remove_agents( agents = obj._get_ids_srs(agents) if __debug__: - # Check ids presence in model - b_contained = obj.model.sets.contains(agents) + # Check ids presence in model via public ids + b_contained = agents.is_in(obj.model.sets.ids) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): @@ -1594,11 +1645,15 @@ def _calculate_differences( pos0: GridCoordinate | GridCoordinates | None, pos1: GridCoordinate | GridCoordinates | None, agents0: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, agents1: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, ) -> DataFrame: @@ -1610,9 +1665,9 @@ def _calculate_differences( The starting positions pos1 : GridCoordinate | GridCoordinates | None The ending positions - agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None + agents0 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None The starting agents - agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None + agents1 : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None The ending agents Returns @@ -1694,7 +1749,9 @@ def _get_df_coords( self, pos: GridCoordinate | GridCoordinates | None = None, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None = None, check_bounds: bool = True, @@ -1705,7 +1762,7 @@ def _get_df_coords( ---------- pos : GridCoordinate | GridCoordinates | None, optional The positions to get the DataFrame from, by default None - agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional + agents : IdsLike | AbstractAgentSet | AbstractAgentSetRegistry | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry] | None, optional The agents to get the DataFrame from, by default None check_bounds: bool, optional If the positions should be checked for out-of-bounds in non-toroidal grids, by default True @@ -1735,7 +1792,7 @@ def _get_df_coords( if agents is not None: agents = self._get_ids_srs(agents) # Check ids presence in model - b_contained = self.model.sets.contains(agents) + b_contained = agents.is_in(self.model.sets.ids) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): @@ -1796,7 +1853,9 @@ def _get_df_coords( def _place_or_move_agents( self, agents: IdsLike + | AbstractAgentSet | AbstractAgentSetRegistry + | Collection[AbstractAgentSet] | Collection[AbstractAgentSetRegistry], pos: GridCoordinate | GridCoordinates, is_move: bool, @@ -1812,8 +1871,8 @@ def _place_or_move_agents( if self._df_contains(self._agents, "agent_id", agents).any(): warn("Some agents are already present in the grid", RuntimeWarning) - # Check if agents are present in the model - b_contained = self.model.sets.contains(agents) + # Check if agents are present in the model using the public ids + b_contained = agents.is_in(self.model.sets.ids) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): diff --git a/mesa_frames/concrete/agentset.py b/mesa_frames/concrete/agentset.py index 5c64aef6..2a9b1a55 100644 --- a/mesa_frames/concrete/agentset.py +++ b/mesa_frames/concrete/agentset.py @@ -67,7 +67,7 @@ def step(self): from mesa_frames.abstract.agentset import AbstractAgentSet from mesa_frames.concrete.mixin import PolarsMixin -from mesa_frames.types_ import AgentPolarsMask, IntoExpr, PolarsIdsLike +from mesa_frames.types_ import AgentMask, AgentPolarsMask, IntoExpr, PolarsIdsLike from mesa_frames.utils import copydoc @@ -82,19 +82,76 @@ class AgentSet(AbstractAgentSet, PolarsMixin): _copy_only_reference: list[str] = ["_model", "_mask"] _mask: pl.Expr | pl.Series - def __init__(self, model: mesa_frames.concrete.model.Model) -> None: + def __init__( + self, model: mesa_frames.concrete.model.Model, name: str | None = None + ) -> None: """Initialize a new AgentSet. Parameters ---------- model : "mesa_frames.concrete.model.Model" The model that the agent set belongs to. + name : str | None, optional + Name for this agent set. If None, class name is used. + Will be converted to snake_case if in camelCase. """ + # Model reference self._model = model + # Set proposed name (no uniqueness guarantees here) + self._name = name if name is not None else self.__class__.__name__ # No definition of schema with unique_id, as it becomes hard to add new agents self._df = pl.DataFrame() self._mask = pl.repeat(True, len(self._df), dtype=pl.Boolean, eager=True) + def rename(self, new_name: str, inplace: bool = True) -> Self: + """Rename this agent set. If attached to AgentSetRegistry, delegate for uniqueness enforcement. + + Parameters + ---------- + new_name : str + Desired new name. + + inplace : bool, optional + Whether to perform the rename in place. If False, a renamed copy is + returned, by default True. + + Returns + ------- + Self + The updated AgentSet (or a renamed copy when ``inplace=False``). + + Raises + ------ + ValueError + If name conflicts occur and delegate encounters errors. + """ + # Respect inplace semantics consistently with other mutators + obj = self._get_obj(inplace) + + # Always delegate to the container's accessor if available through the model's sets + # Check if we have a model and can find the AgentSetRegistry that contains this set + try: + if self in self.model.sets: + # Save index to locate the copy on non-inplace path + try: + idx = list(self.model.sets).index(self) # type: ignore[arg-type] + except Exception: + idx = None + reg = self.model.sets.rename(self, new_name, inplace=inplace) + if inplace: + return self + if idx is not None: + return reg[idx] + return reg.get(new_name) # type: ignore[return-value] + except Exception: + # Fall back to local rename if delegation fails + obj._name = new_name + return obj + + # Set name locally if no container found + obj._name = new_name + return obj + def add( self, agents: pl.DataFrame | Sequence[Any] | dict[str, Any], @@ -181,6 +238,64 @@ def contains( else: return agents in self._df["unique_id"] + @overload + def do( + self, + method_name: str, + *args, + mask: AgentMask | None = None, + return_results: Literal[False] = False, + inplace: bool = True, + **kwargs, + ) -> Self: ... + + @overload + def do( + self, + method_name: str, + *args, + mask: AgentMask | None = None, + return_results: Literal[True], + inplace: bool = True, + **kwargs, + ) -> Any: ... + + def do( + self, + method_name: str, + *args, + mask: AgentMask | None = None, + return_results: bool = False, + inplace: bool = True, + **kwargs, + ) -> Self | Any: + masked_df = self._get_masked_df(mask) + # If the mask is empty, we can use the object as is + if len(masked_df) == len(self._df): + obj = self._get_obj(inplace) + method = getattr(obj, method_name) + result = method(*args, **kwargs) + else: # If the mask is not empty, we need to create a new masked AbstractAgentSet and concatenate the AbstractAgentSets at the end + obj = self._get_obj(inplace=False) + obj._df = masked_df + original_masked_index = obj._get_obj_copy(obj.index) + method = getattr(obj, method_name) + result = method(*args, **kwargs) + obj._concatenate_agentsets( + [self], + duplicates_allowed=True, + keep_first_only=True, + original_masked_index=original_masked_index, + ) + if inplace: + for key, value in obj.__dict__.items(): + setattr(self, key, value) + obj = self + if return_results: + return result + else: + return obj + def get( self, attr_names: IntoExpr | Iterable[IntoExpr] | None, @@ -198,6 +313,20 @@ def get( return masked_df[masked_df.columns[0]] return masked_df + def remove(self, agents: PolarsIdsLike | AgentMask, inplace: bool = True) -> Self: + if isinstance(agents, str) and agents == "active": + agents = self.active_agents + if agents is None or (isinstance(agents, Iterable) and len(agents) == 0): + return self._get_obj(inplace) + obj = self._get_obj(inplace) + # Normalize to Series of unique_ids + ids = obj._df_index(obj._get_masked_df(agents), "unique_id") + # Validate presence + if not ids.is_in(obj._df["unique_id"]).all(): + raise KeyError("Some 'unique_id' of mask are not present in this AgentSet.") + # Remove by ids + return obj._discard(ids) + def set( self, attr_names: str | Collection[str] | dict[str, Any] | None = None, @@ -463,7 +592,9 @@ def _update_mask( else: self._mask = self._df["unique_id"].is_in(original_active_indices) - def __getattr__(self, key: str) -> pl.Series: + def __getattr__(self, key: str) -> Any: + if key == "name": + return self.name super().__getattr__(key) return self._df[key] @@ -546,3 +677,13 @@ def index(self) -> pl.Series: @property def pos(self) -> pl.DataFrame: return super().pos + + @property + def name(self) -> str: + """Return the name of the AgentSet.""" + return self._name + + @name.setter + def name(self, value: str) -> None: + """Set the name of the AgentSet.""" + self._name = value diff --git a/mesa_frames/concrete/agentsetregistry.py b/mesa_frames/concrete/agentsetregistry.py index b9ed1563..4b486ba2 100644 --- a/mesa_frames/concrete/agentsetregistry.py +++ b/mesa_frames/concrete/agentsetregistry.py @@ -46,26 +46,17 @@ def step(self): from __future__ import annotations # For forward references -from collections import defaultdict -from collections.abc import Callable, Collection, Iterable, Iterator, Sequence -from typing import Any, Literal, Self, cast, overload - -import numpy as np +from collections.abc import Collection, Iterable, Iterator, Sequence +from typing import Any, Literal, Self, overload, cast +from collections.abc import Sized +from itertools import chain import polars as pl from mesa_frames.abstract.agentsetregistry import ( AbstractAgentSetRegistry, ) from mesa_frames.concrete.agentset import AgentSet -from mesa_frames.types_ import ( - AgentMask, - AgnosticAgentMask, - BoolSeries, - DataFrame, - IdsLike, - Index, - Series, -) +from mesa_frames.types_ import BoolSeries, KeyBy, AgentSetSelector class AgentSetRegistry(AbstractAgentSetRegistry): @@ -88,34 +79,31 @@ def __init__(self, model: mesa_frames.concrete.model.Model) -> None: def add( self, - agents: AgentSet | Iterable[AgentSet], + sets: AgentSet | Iterable[AgentSet], inplace: bool = True, ) -> Self: - """Add an AgentSet to the AgentSetRegistry. - - Parameters - ---------- - agents : AgentSet | Iterable[AgentSet] - The AgentSets to add. - inplace : bool, optional - Whether to add the AgentSets in place. Defaults to True. - - Returns - ------- - Self - The updated AgentSetRegistry. - - Raises - ------ - ValueError - If any AgentSets are already present or if IDs are not unique. - """ obj = self._get_obj(inplace) - other_list = obj._return_agentsets_list(agents) + other_list = obj._return_agentsets_list(sets) if obj._check_agentsets_presence(other_list).any(): raise ValueError( "Some agentsets are already present in the AgentSetRegistry." ) + # Ensure unique names across existing and to-be-added sets + existing_names = {s.name for s in obj._agentsets} + for agentset in other_list: + base_name = agentset.name or agentset.__class__.__name__ + name = base_name + if name in existing_names: + counter = 1 + candidate = f"{base_name}_{counter}" + while candidate in existing_names: + counter += 1 + candidate = f"{base_name}_{counter}" + name = candidate + # Assign back if changed or was None + if name != (agentset.name or base_name): + agentset.name = name + existing_names.add(name) new_ids = pl.concat( [obj._ids] + [pl.Series(agentset["unique_id"]) for agentset in other_list] ) @@ -125,215 +113,427 @@ def add( obj._ids = new_ids return obj + def rename( + self, + target: ( + AgentSet + | str + | dict[AgentSet | str, str] + | list[tuple[AgentSet | str, str]] + ), + new_name: str | None = None, + *, + on_conflict: Literal["canonicalize", "raise"] = "canonicalize", + mode: Literal["atomic", "best_effort"] = "atomic", + inplace: bool = True, + ) -> Self: + """Rename AgentSets with conflict handling. + + Supports single-target ``(set | old_name, new_name)`` and batch rename via + dict or list of pairs. Names remain unique across the registry. + """ + + # Normalize to list of (index_in_self, desired_name) using the original registry + def _resolve_one(x: AgentSet | str) -> int: + if isinstance(x, AgentSet): + for i, s in enumerate(self._agentsets): + if s is x: + return i + raise KeyError("AgentSet not found in registry") + # name lookup on original registry + for i, s in enumerate(self._agentsets): + if s.name == x: + return i + raise KeyError(f"Agent set '{x}' not found") + + if isinstance(target, (AgentSet, str)): + if new_name is None: + raise TypeError("new_name must be provided for single rename") + pairs_idx: list[tuple[int, str]] = [(_resolve_one(target), new_name)] + single = True + elif isinstance(target, dict): + pairs_idx = [(_resolve_one(k), v) for k, v in target.items()] + single = False + else: + pairs_idx = [(_resolve_one(k), v) for k, v in target] + single = False + + # Choose object to mutate + obj = self._get_obj(inplace) + # Translate indices to object AgentSets in the selected registry object + target_sets = [obj._agentsets[i] for i, _ in pairs_idx] + + # Build the set of names that remain fixed (exclude targets' current names) + targets_set = set(target_sets) + fixed_names: set[str] = { + s.name + for s in obj._agentsets + if s.name is not None and s not in targets_set + } # type: ignore[comparison-overlap] + + # Plan final names + final: list[tuple[AgentSet, str]] = [] + used = set(fixed_names) + + def _canonicalize(base: str) -> str: + if base not in used: + used.add(base) + return base + counter = 1 + cand = f"{base}_{counter}" + while cand in used: + counter += 1 + cand = f"{base}_{counter}" + used.add(cand) + return cand + + errors: list[Exception] = [] + for aset, (_idx, desired) in zip(target_sets, pairs_idx): + if on_conflict == "canonicalize": + final_name = _canonicalize(desired) + final.append((aset, final_name)) + else: # on_conflict == 'raise' + if desired in used: + err = ValueError( + f"Duplicate agent set name disallowed: '{desired}'" + ) + if mode == "atomic": + errors.append(err) + else: + # best_effort: skip this rename + continue + else: + used.add(desired) + final.append((aset, desired)) + + if errors and mode == "atomic": + # Surface first meaningful error + raise errors[0] + + # Apply renames + for aset, newn in final: + # Set the private name directly to avoid external uniqueness hooks + if hasattr(aset, "_name"): + aset._name = newn # type: ignore[attr-defined] + + return obj + + def replace( + self, + mapping: (dict[int | str, AgentSet] | list[tuple[int | str, AgentSet]]), + *, + inplace: bool = True, + atomic: bool = True, + ) -> Self: + # Normalize to list of (key, value) + items: list[tuple[int | str, AgentSet]] + if isinstance(mapping, dict): + items = list(mapping.items()) + else: + items = list(mapping) + + obj = self._get_obj(inplace) + + # Helpers (build name->idx map only if needed) + has_str_keys = any(isinstance(k, str) for k, _ in items) + if has_str_keys: + name_to_idx = { + s.name: i for i, s in enumerate(obj._agentsets) if s.name is not None + } + + def _find_index_by_name(name: str) -> int: + try: + return name_to_idx[name] + except KeyError: + raise KeyError(f"Agent set '{name}' not found") + else: + + def _find_index_by_name(name: str) -> int: + for i, s in enumerate(obj._agentsets): + if s.name == name: + return i + + raise KeyError(f"Agent set '{name}' not found") + + if atomic: + n = len(obj._agentsets) + # Map existing object identity -> index (for aliasing checks) + id_to_idx = {id(s): i for i, s in enumerate(obj._agentsets)} + + for k, v in items: + if not isinstance(v, AgentSet): + raise TypeError("Values must be AgentSet instances") + if v.model is not obj.model: + raise TypeError( + "All AgentSets must belong to the same model as the registry" + ) + + v_idx_existing = id_to_idx.get(id(v)) + + if isinstance(k, int): + if not (0 <= k < n): + raise IndexError( + f"Index {k} out of range for AgentSetRegistry of size {n}" + ) + + # Prevent aliasing: the same object cannot appear in two positions + if v_idx_existing is not None and v_idx_existing != k: + raise ValueError( + f"This AgentSet instance already exists at index {v_idx_existing}; cannot also place it at {k}." + ) + + # Preserve name uniqueness when assigning by index + vname = v.name + if vname is not None: + try: + other_idx = _find_index_by_name(vname) + if other_idx != k: + raise ValueError( + f"Duplicate agent set name disallowed: '{vname}' already at index {other_idx}" + ) + except KeyError: + # name not present elsewhere -> OK + pass + + elif isinstance(k, str): + # Locate the slot by name; replacing that slot preserves uniqueness + idx = _find_index_by_name(k) + + # Prevent aliasing: if the same object already exists at a different slot, forbid + if v_idx_existing is not None and v_idx_existing != idx: + raise ValueError( + f"This AgentSet instance already exists at index {v_idx_existing}; cannot also place it at {idx}." + ) + + else: + raise TypeError("Keys must be int indices or str names") + + # Apply + target = obj if inplace else obj.copy(deep=False) + if not inplace: + target._agentsets = list(obj._agentsets) + + for k, v in items: + if isinstance(k, int): + target._agentsets[k] = v # keep v.name as-is (validated above) + else: + idx = _find_index_by_name(k) + # Force the authoritative name without triggering external uniqueness checks + if hasattr(v, "_name"): + v._name = k # type: ignore[attr-defined] + target._agentsets[idx] = v + + # Recompute ids cache + target._recompute_ids() + + return target + @overload - def contains(self, agents: int | AgentSet) -> bool: ... + def contains(self, sets: AgentSet | type[AgentSet] | str) -> bool: ... @overload - def contains(self, agents: IdsLike | Iterable[AgentSet]) -> pl.Series: ... + def contains( + self, + sets: Iterable[AgentSet] | Iterable[type[AgentSet]] | Iterable[str], + ) -> pl.Series: ... def contains( - self, agents: IdsLike | AgentSet | Iterable[AgentSet] + self, + sets: AgentSet + | type[AgentSet] + | str + | Iterable[AgentSet] + | Iterable[type[AgentSet]] + | Iterable[str], ) -> bool | pl.Series: - if isinstance(agents, int): - return agents in self._ids - elif isinstance(agents, AgentSet): - return self._check_agentsets_presence([agents]).any() - elif isinstance(agents, Iterable): - if len(agents) == 0: - return True - elif isinstance(next(iter(agents)), AgentSet): - agents = cast(Iterable[AgentSet], agents) - return self._check_agentsets_presence(list(agents)) - else: # IdsLike - agents = cast(IdsLike, agents) - - return pl.Series(agents, dtype=pl.UInt64).is_in(self._ids) + # Single value fast paths + if isinstance(sets, AgentSet): + return self._check_agentsets_presence([sets]).any() + if isinstance(sets, type) and issubclass(sets, AgentSet): + return any(isinstance(s, sets) for s in self._agentsets) + if isinstance(sets, str): + return any(s.name == sets for s in self._agentsets) + + # Iterable paths without materializing unnecessarily + + if isinstance(sets, Sized) and len(sets) == 0: # type: ignore[arg-type] + return True + it = iter(sets) # type: ignore[arg-type] + try: + first = next(it) + except StopIteration: + return True + + if isinstance(first, AgentSet): + lst = [first, *it] + return self._check_agentsets_presence(lst) + + if isinstance(first, type) and issubclass(first, AgentSet): + present_types = {type(s) for s in self._agentsets} + + def has_type(t: type[AgentSet]) -> bool: + return any(issubclass(pt, t) for pt in present_types) + + return pl.Series( + (has_type(t) for t in chain([first], it)), dtype=pl.Boolean + ) + + if isinstance(first, str): + names = {s.name for s in self._agentsets if s.name is not None} + return pl.Series((x in names for x in chain([first], it)), dtype=pl.Boolean) + + raise TypeError("Unsupported type for contains()") @overload def do( self, method_name: str, - *args, - mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, + *args: Any, + sets: AgentSetSelector | None = None, return_results: Literal[False] = False, inplace: bool = True, - **kwargs, + key_by: KeyBy = "name", + **kwargs: Any, ) -> Self: ... @overload def do( self, method_name: str, - *args, - mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, + *args: Any, + sets: AgentSetSelector, return_results: Literal[True], inplace: bool = True, - **kwargs, - ) -> dict[AgentSet, Any]: ... + key_by: KeyBy = "name", + **kwargs: Any, + ) -> dict[str, Any] | dict[int, Any] | dict[type[AgentSet], Any]: ... def do( self, method_name: str, - *args, - mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, + *args: Any, + sets: AgentSetSelector = None, return_results: bool = False, inplace: bool = True, - **kwargs, + key_by: KeyBy = "name", + **kwargs: Any, ) -> Self | Any: obj = self._get_obj(inplace) - agentsets_masks = obj._get_bool_masks(mask) + target_sets = obj._resolve_selector(sets) + + if not target_sets: + return {} if return_results else obj + + index_lookup = {id(s): idx for idx, s in enumerate(obj._agentsets)} + if return_results: - return { - agentset: agentset.do( + + def make_key(agentset: AgentSet) -> Any: + if key_by == "name": + return agentset.name + if key_by == "index": + try: + return index_lookup[id(agentset)] + except KeyError as exc: # pragma: no cover - defensive + raise ValueError( + "AgentSet not found in registry; cannot key by index." + ) from exc + if key_by == "type": + return type(agentset) + return agentset # backward-compatible: key by object + + results: dict[Any, Any] = {} + for agentset in target_sets: + key = make_key(agentset) + if key_by == "type" and key in results: + raise ValueError( + "Multiple agent sets of the same type were selected; " + "use key_by='name' or key_by='index' instead." + ) + results[key] = agentset.do( method_name, *args, - mask=mask, - return_results=return_results, - **kwargs, + return_results=True, inplace=inplace, - ) - for agentset, mask in agentsets_masks.items() - } - else: - obj._agentsets = [ - agentset.do( - method_name, - *args, - mask=mask, - return_results=return_results, **kwargs, - inplace=inplace, ) - for agentset, mask in agentsets_masks.items() - ] - return obj + return results + + updates: list[tuple[int, AgentSet]] = [] + for agentset in target_sets: + try: + registry_index = index_lookup[id(agentset)] + except KeyError as exc: # pragma: no cover - defensive + raise ValueError( + "AgentSet not found in registry; cannot apply operation." + ) from exc + updated = agentset.do( + method_name, + *args, + return_results=False, + inplace=inplace, + **kwargs, + ) + updates.append((registry_index, updated)) - def get( - self, - attr_names: str | Collection[str] | None = None, - mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, - ) -> dict[AgentSet, Series] | dict[AgentSet, DataFrame]: - agentsets_masks = self._get_bool_masks(mask) - result = {} - - # Convert attr_names to list for consistent checking - if attr_names is None: - # None means get all data - no column filtering needed - required_columns = [] - elif isinstance(attr_names, str): - required_columns = [attr_names] - else: - required_columns = list(attr_names) + for registry_index, updated in updates: + obj._agentsets[registry_index] = updated + obj._recompute_ids() + return obj - for agentset, mask in agentsets_masks.items(): - # Fast column existence check - no data processing, just property access - agentset_columns = agentset.df.columns + @overload + def get(self, key: int, default: None = ...) -> AgentSet | None: ... - # Check if all required columns exist in this agent set - if not required_columns or all( - col in agentset_columns for col in required_columns - ): - result[agentset] = agentset.get(attr_names, mask) + @overload + def get(self, key: str, default: None = ...) -> AgentSet | None: ... - return result + @overload + def get(self, key: type[AgentSet], default: None = ...) -> list[AgentSet]: ... - def remove( + @overload + def get( self, - agents: AgentSet | Iterable[AgentSet] | IdsLike, - inplace: bool = True, - ) -> Self: - obj = self._get_obj(inplace) - if agents is None or (isinstance(agents, Iterable) and len(agents) == 0): - return obj - if isinstance(agents, AgentSet): - agents = [agents] - if isinstance(agents, Iterable) and isinstance(next(iter(agents)), AgentSet): - # We have to get the index of the original AgentSet because the copy made AgentSets with different hash - ids = [self._agentsets.index(agentset) for agentset in iter(agents)] - ids.sort(reverse=True) - removed_ids = pl.Series(dtype=pl.UInt64) - for id in ids: - removed_ids = pl.concat( - [ - removed_ids, - pl.Series(obj._agentsets[id]["unique_id"], dtype=pl.UInt64), - ] - ) - obj._agentsets.pop(id) - - else: # IDsLike - if isinstance(agents, (int, np.uint64)): - agents = [agents] - elif isinstance(agents, DataFrame): - agents = agents["unique_id"] - removed_ids = pl.Series(agents, dtype=pl.UInt64) - deleted = 0 - - for agentset in obj._agentsets: - initial_len = len(agentset) - agentset._discard(removed_ids) - deleted += initial_len - len(agentset) - if deleted == len(removed_ids): - break - if deleted < len(removed_ids): # TODO: fix type hint - raise KeyError( - "There exist some IDs which are not present in any agentset" - ) - try: - obj.space.remove_agents(removed_ids, inplace=True) - except ValueError: - pass - obj._ids = obj._ids.filter(obj._ids.is_in(removed_ids).not_()) - return obj + key: int | str | type[AgentSet], + default: AgentSet | list[AgentSet] | None, + ) -> AgentSet | list[AgentSet] | None: ... - def select( + def get( self, - mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, - filter_func: Callable[[AgentSet], AgentMask] | None = None, - n: int | None = None, - inplace: bool = True, - negate: bool = False, - ) -> Self: - obj = self._get_obj(inplace) - agentsets_masks = obj._get_bool_masks(mask) - if n is not None: - n = n // len(agentsets_masks) - obj._agentsets = [ - agentset.select( - mask=mask, filter_func=filter_func, n=n, negate=negate, inplace=inplace - ) - for agentset, mask in agentsets_masks.items() - ] - return obj + key: int | str | type[AgentSet], + default: AgentSet | list[AgentSet] | None = None, + ) -> AgentSet | list[AgentSet] | None: + try: + if isinstance(key, int): + return self._agentsets[key] + if isinstance(key, str): + for s in self._agentsets: + if s.name == key: + return s + return default + if isinstance(key, type) and issubclass(key, AgentSet): + return [s for s in self._agentsets if isinstance(s, key)] + except (IndexError, KeyError, TypeError): + return default + return default - def set( + def remove( self, - attr_names: str | dict[AgentSet, Any] | Collection[str], - values: Any | None = None, - mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, + sets: AgentSetSelector, inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) - agentsets_masks = obj._get_bool_masks(mask) - if isinstance(attr_names, dict): - for agentset, values in attr_names.items(): - if not inplace: - # We have to get the index of the original AgentSet because the copy made AgentSets with different hash - id = self._agentsets.index(agentset) - agentset = obj._agentsets[id] - agentset.set( - attr_names=values, mask=agentsets_masks[agentset], inplace=True - ) - else: - obj._agentsets = [ - agentset.set( - attr_names=attr_names, values=values, mask=mask, inplace=True - ) - for agentset, mask in agentsets_masks.items() - ] + # Normalize to a list of AgentSet instances using _resolve_selector + selected = obj._resolve_selector(sets) # type: ignore[arg-type] + # Remove in reverse positional order + indices = [i for i, s in enumerate(obj._agentsets) if s in selected] + indices.sort(reverse=True) + for idx in indices: + obj._agentsets.pop(idx) + # Recompute ids cache + obj._recompute_ids() return obj - def shuffle(self, inplace: bool = True) -> Self: + def shuffle(self, inplace: bool = False) -> Self: obj = self._get_obj(inplace) obj._agentsets = [agentset.shuffle(inplace=True) for agentset in obj._agentsets] return obj @@ -343,7 +543,7 @@ def sort( by: str | Sequence[str], ascending: bool | Sequence[bool] = True, inplace: bool = True, - **kwargs, + **kwargs: Any, ) -> Self: obj = self._get_obj(inplace) obj._agentsets = [ @@ -352,23 +552,6 @@ def sort( ] return obj - def step(self, inplace: bool = True) -> Self: - """Advance the state of the agents in the AgentSetRegistry by one step. - - Parameters - ---------- - inplace : bool, optional - Whether to update the AgentSetRegistry in place, by default True - - Returns - ------- - Self - """ - obj = self._get_obj(inplace) - for agentset in obj._agentsets: - agentset.step() - return obj - def _check_ids_presence(self, other: list[AgentSet]) -> pl.DataFrame: """Check if the IDs of the agents to be added are unique. @@ -425,17 +608,54 @@ def _check_agentsets_presence(self, other: list[AgentSet]) -> pl.Series: [agentset in other_set for agentset in self._agentsets], dtype=pl.Boolean ) - def _get_bool_masks( - self, - mask: (AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask]) = None, - ) -> dict[AgentSet, BoolSeries]: - return_dictionary = {} - if not isinstance(mask, dict): - # No need to convert numpy integers - let polars handle them directly - mask = {agentset: mask for agentset in self._agentsets} - for agentset, mask_value in mask.items(): - return_dictionary[agentset] = agentset._get_bool_mask(mask_value) - return return_dictionary + def _recompute_ids(self) -> None: + """Rebuild the registry-level `unique_id` cache from current AgentSets. + + Ensures `self._ids` stays a `pl.UInt64` Series and empty when no sets. + """ + if self._agentsets: + cols = [pl.Series(s["unique_id"]) for s in self._agentsets] + self._ids = ( + pl.concat(cols) + if cols + else pl.Series(name="unique_id", dtype=pl.UInt64) + ) + else: + self._ids = pl.Series(name="unique_id", dtype=pl.UInt64) + + def _resolve_selector(self, selector: AgentSetSelector = None) -> list[AgentSet]: + """Resolve a selector (instance/type/name or collection) to a list of AgentSets.""" + if selector is None: + return list(self._agentsets) + # Single instance + if isinstance(selector, AgentSet): + return [selector] if selector in self._agentsets else [] + # Single type + if isinstance(selector, type) and issubclass(selector, AgentSet): + return [s for s in self._agentsets if isinstance(s, selector)] + # Single name + if isinstance(selector, str): + return [s for s in self._agentsets if s.name == selector] + # Collection of mixed selectors + selected: list[AgentSet] = [] + for item in selector: # type: ignore[assignment] + if isinstance(item, AgentSet): + if item in self._agentsets: + selected.append(item) + elif isinstance(item, type) and issubclass(item, AgentSet): + selected.extend([s for s in self._agentsets if isinstance(s, item)]) + elif isinstance(item, str): + selected.extend([s for s in self._agentsets if s.name == item]) + else: + raise TypeError("Unsupported selector element type") + # Deduplicate while preserving order + seen = set() + result = [] + for s in selected: + if s not in seen: + seen.add(s) + result.append(s) + return result def _return_agentsets_list( self, agentsets: AgentSet | Iterable[AgentSet] @@ -452,192 +672,67 @@ def _return_agentsets_list( """ return [agentsets] if isinstance(agentsets, AgentSet) else list(agentsets) - def __add__(self, other: AgentSet | Iterable[AgentSet]) -> Self: - """Add AgentSets to a new AgentSetRegistry through the + operator. - - Parameters - ---------- - other : AgentSet | Iterable[AgentSet] - The AgentSets to add. - - Returns - ------- - Self - A new AgentSetRegistry with the added AgentSets. - """ - return super().__add__(other) - - def __getattr__(self, name: str) -> dict[AgentSet, Any]: + def _generate_name(self, base_name: str) -> str: + """Generate a unique name for an agent set.""" + existing_names = [ + agentset.name for agentset in self._agentsets if agentset.name is not None + ] + if base_name not in existing_names: + return base_name + counter = 1 + candidate = f"{base_name}_{counter}" + while candidate in existing_names: + counter += 1 + candidate = f"{base_name}_{counter}" + return candidate + + def __getattr__(self, name: str) -> Any | dict[str, Any]: # Avoids infinite recursion of private attributes - if __debug__: # Only execute in non-optimized mode - if name.startswith("_"): - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) - return {agentset: getattr(agentset, name) for agentset in self._agentsets} - - @overload - def __getitem__( - self, key: str | tuple[dict[AgentSet, AgentMask], str] - ) -> dict[AgentSet, Series | pl.Expr]: ... - - @overload - def __getitem__( - self, - key: ( - Collection[str] - | AgnosticAgentMask - | IdsLike - | tuple[dict[AgentSet, AgentMask], Collection[str]] - ), - ) -> dict[AgentSet, DataFrame]: ... - - def __getitem__( - self, - key: ( - str - | Collection[str] - | AgnosticAgentMask - | IdsLike - | tuple[dict[AgentSet, AgentMask], str] - | tuple[dict[AgentSet, AgentMask], Collection[str]] - ), - ) -> dict[AgentSet, Series | pl.Expr] | dict[AgentSet, DataFrame]: - return super().__getitem__(key) - - def __iadd__(self, agents: AgentSet | Iterable[AgentSet]) -> Self: - """Add AgentSets to the AgentSetRegistry through the += operator. - - Parameters - ---------- - agents : AgentSet | Iterable[AgentSet] - The AgentSets to add. - - Returns - ------- - Self - The updated AgentSetRegistry. - """ - return super().__iadd__(agents) - - def __iter__(self) -> Iterator[dict[str, Any]]: - return (agent for agentset in self._agentsets for agent in iter(agentset)) - - def __isub__(self, agents: AgentSet | Iterable[AgentSet] | IdsLike) -> Self: - """Remove AgentSets from the AgentSetRegistry through the -= operator. + if name.startswith("_"): + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + # Delegate attribute access to sets; map results by set name + return {cast(str, s.name): getattr(s, name) for s in self._agentsets} - Parameters - ---------- - agents : AgentSet | Iterable[AgentSet] | IdsLike - The AgentSets or agent IDs to remove. - - Returns - ------- - Self - The updated AgentSetRegistry. - """ - return super().__isub__(agents) + def __iter__(self) -> Iterator[AgentSet]: + return iter(self._agentsets) def __len__(self) -> int: - return sum(len(agentset._df) for agentset in self._agentsets) + return len(self._agentsets) def __repr__(self) -> str: return "\n".join([repr(agentset) for agentset in self._agentsets]) - def __reversed__(self) -> Iterator: - return ( - agent - for agentset in self._agentsets - for agent in reversed(agentset._backend) - ) - - def __setitem__( - self, - key: ( - str - | Collection[str] - | AgnosticAgentMask - | IdsLike - | tuple[dict[AgentSet, AgentMask], str] - | tuple[dict[AgentSet, AgentMask], Collection[str]] - ), - values: Any, - ) -> None: - super().__setitem__(key, values) + def __reversed__(self) -> Iterator[AgentSet]: + return reversed(self._agentsets) def __str__(self) -> str: return "\n".join([str(agentset) for agentset in self._agentsets]) - def __sub__(self, agents: AgentSet | Iterable[AgentSet] | IdsLike) -> Self: - """Remove AgentSets from a new AgentSetRegistry through the - operator. - - Parameters - ---------- - agents : AgentSet | Iterable[AgentSet] | IdsLike - The AgentSets or agent IDs to remove. Supports NumPy integer types. - - Returns - ------- - Self - A new AgentSetRegistry with the removed AgentSets. - """ - return super().__sub__(agents) - - @property - def df(self) -> dict[AgentSet, DataFrame]: - return {agentset: agentset.df for agentset in self._agentsets} - - @df.setter - def df(self, other: Iterable[AgentSet]) -> None: - """Set the agents in the AgentSetRegistry. - - Parameters - ---------- - other : Iterable[AgentSet] - The AgentSets to set. - """ - self._agentsets = list(other) - @property - def active_agents(self) -> dict[AgentSet, DataFrame]: - return {agentset: agentset.active_agents for agentset in self._agentsets} - - @active_agents.setter - def active_agents( - self, agents: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] - ) -> None: - self.select(agents, inplace=True) - - @property - def agentsets_by_type(self) -> dict[type[AgentSet], Self]: - """Get the agent sets in the AgentSetRegistry grouped by type. - - Returns - ------- - dict[type[AgentSet], Self] - A dictionary mapping agent set types to the corresponding AgentSetRegistry. - """ - - def copy_without_agentsets() -> Self: - return self.copy(deep=False, skip=["_agentsets"]) + def ids(self) -> pl.Series: + """Public view of all agent unique_id values across contained sets.""" + return self._ids - dictionary = defaultdict(copy_without_agentsets) - - for agentset in self._agentsets: - agents_df = dictionary[agentset.__class__] - agents_df._agentsets = [] - agents_df._agentsets = agents_df._agentsets + [agentset] - dictionary[agentset.__class__] = agents_df - return dictionary - - @property - def inactive_agents(self) -> dict[AgentSet, DataFrame]: - return {agentset: agentset.inactive_agents for agentset in self._agentsets} + @overload + def __getitem__(self, key: int) -> AgentSet: ... - @property - def index(self) -> dict[AgentSet, Index]: - return {agentset: agentset.index for agentset in self._agentsets} + @overload + def __getitem__(self, key: str) -> AgentSet: ... - @property - def pos(self) -> dict[AgentSet, DataFrame]: - return {agentset: agentset.pos for agentset in self._agentsets} + @overload + def __getitem__(self, key: type[AgentSet]) -> list[AgentSet]: ... + + def __getitem__(self, key: int | str | type[AgentSet]) -> AgentSet | list[AgentSet]: + """Retrieve AgentSet(s) by index, name, or type.""" + if isinstance(key, int): + return self._agentsets[key] + if isinstance(key, str): + for s in self._agentsets: + if s.name == key: + return s + raise KeyError(f"Agent set '{key}' not found") + if isinstance(key, type) and issubclass(key, AgentSet): + return [s for s in self._agentsets if isinstance(s, key)] + raise TypeError("Key must be int, str (name), or AgentSet type") diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index 2b50c76d..f7db338e 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -177,13 +177,94 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): Constructs a LazyFrame with one column per reporter and includes `step` and `seed` metadata. Appends it to internal storage. """ - agent_data_dict = {} + + def _is_str_collection(x: Any) -> bool: + try: + from collections.abc import Collection + + if isinstance(x, str): + return False + return isinstance(x, Collection) and all(isinstance(i, str) for i in x) + except Exception: + return False + + agent_data_dict: dict[str, pl.Series] = {} + for col_name, reporter in self._agent_reporters.items(): - if isinstance(reporter, str): - for k, v in self._model.sets[reporter].items(): - agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v - else: - agent_data_dict[col_name] = reporter(self._model) + # 1) String or collection[str]: shorthand to fetch columns + if isinstance(reporter, str) or _is_str_collection(reporter): + # If a single string, fetch that attribute from each set + if isinstance(reporter, str): + values_by_set = getattr(self._model.sets, reporter) + for set_name, series in values_by_set.items(): + agent_data_dict[f"{col_name}_{set_name}"] = series + else: + # Collection of strings: pull multiple columns from each set via set.get([...]) + for set_name, aset in self._model.sets.items(): # type: ignore[attr-defined] + df = aset.get(list(reporter)) # DataFrame of requested attrs + if isinstance(df, pl.Series): + # Defensive, though get(list) should yield DataFrame + agent_data_dict[f"{col_name}_{df.name}_{set_name}"] = df + else: + for subcol in df.columns: + agent_data_dict[f"{col_name}_{subcol}_{set_name}"] = df[ + subcol + ] + continue + + # 2) Callables: prefer registry-level; then set-level + if callable(reporter): + called = False + # Try registry-level callable: reporter(AgentSetRegistry) + try: + reg_result = reporter(self._model.sets) + # Accept Series | DataFrame | dict[str, Series|DataFrame] + if isinstance(reg_result, pl.Series): + agent_data_dict[col_name] = reg_result + called = True + elif isinstance(reg_result, pl.DataFrame): + for subcol in reg_result.columns: + agent_data_dict[f"{col_name}_{subcol}"] = reg_result[subcol] + called = True + elif isinstance(reg_result, dict): + for key, val in reg_result.items(): + if isinstance(val, pl.Series): + agent_data_dict[f"{col_name}_{key}"] = val + elif isinstance(val, pl.DataFrame): + for subcol in val.columns: + agent_data_dict[f"{col_name}_{key}_{subcol}"] = val[ + subcol + ] + else: + raise TypeError( + "Registry-level reporter dict values must be Series or DataFrame" + ) + called = True + except TypeError: + called = False + + if not called: + # Fallback: set-level callable, run once per set and suffix by set name + for set_name, aset in self._model.sets.items(): # type: ignore[attr-defined] + set_result = reporter(aset) + if isinstance(set_result, pl.Series): + agent_data_dict[f"{col_name}_{set_name}"] = set_result + elif isinstance(set_result, pl.DataFrame): + for subcol in set_result.columns: + agent_data_dict[f"{col_name}_{subcol}_{set_name}"] = ( + set_result[subcol] + ) + else: + raise TypeError( + "Set-level reporter must return polars Series or DataFrame" + ) + continue + + # Unknown type + raise TypeError( + "agent_reporters values must be str, collection[str], or callable" + ) + agent_lazy_frame = pl.LazyFrame(agent_data_dict) agent_lazy_frame = agent_lazy_frame.with_columns( [ @@ -441,7 +522,10 @@ def _validate_reporter_table(self, conn: connection, table_name: str): ) def _validate_reporter_table_columns( - self, conn: connection, table_name: str, reporter: dict[str, Callable | str] + self, + conn: connection, + table_name: str, + reporter: dict[str, Callable | str], ): """ Check if the expected columns are present in a given PostgreSQL table. @@ -460,15 +544,35 @@ def _validate_reporter_table_columns( ValueError If any expected columns are missing from the table. """ - expected_columns = set() - for col_name, required_column in reporter.items(): - if isinstance(required_column, str): - for k, v in self._model.sets[required_column].items(): - expected_columns.add( - (col_name + "_" + str(k.__class__.__name__)).lower() - ) - else: - expected_columns.add(col_name.lower()) + + def _is_str_collection(x: Any) -> bool: + try: + from collections.abc import Collection + + if isinstance(x, str): + return False + return isinstance(x, Collection) and all(isinstance(i, str) for i in x) + except Exception: + return False + + expected_columns: set[str] = set() + for col_name, req in reporter.items(): + # Strings → one column per set with suffix + if isinstance(req, str): + for set_name, _ in self._model.sets.items(): # type: ignore[attr-defined] + expected_columns.add(f"{col_name}_{set_name}".lower()) + continue + + # Collection[str] → one column per attribute per set + if _is_str_collection(req): + for set_name, _ in self._model.sets.items(): # type: ignore[attr-defined] + for subcol in req: # type: ignore[assignment] + expected_columns.add(f"{col_name}_{subcol}_{set_name}".lower()) + continue + + # Callable: conservative default → require 'col_name' to exist + # We cannot know the dynamic column explosion without running model code safely here. + expected_columns.add(col_name.lower()) query = f""" SELECT column_name diff --git a/mesa_frames/concrete/model.py b/mesa_frames/concrete/model.py index a10ce240..b91db207 100644 --- a/mesa_frames/concrete/model.py +++ b/mesa_frames/concrete/model.py @@ -64,7 +64,7 @@ class Model: running: bool _seed: int | Sequence[int] _sets: AgentSetRegistry # Where the agent sets are stored - _space: Space | None # This will be a MultiSpaceDF object + _space: Space | None # This will be a Space object def __init__(self, seed: int | Sequence[int] | None = None) -> None: """Create a new model. @@ -99,24 +99,6 @@ def steps(self) -> int: """Get the current step count.""" return self._steps - def get_sets_of_type(self, agent_type: type) -> AgentSet: - """Retrieve the AgentSet of a specified type. - - Parameters - ---------- - agent_type : type - The type of AgentSet to retrieve. - - Returns - ------- - AgentSet - The AgentSet of the specified type. - """ - for agentset in self._sets._agentsets: - if isinstance(agentset, agent_type): - return agentset - raise ValueError(f"No agent sets of type {agent_type} found in the model.") - def reset_randomizer(self, seed: int | Sequence[int] | None) -> None: """Reset the model random number generator. @@ -144,7 +126,8 @@ def step(self) -> None: The default method calls the step() method of all agents. Overload as needed. """ - self.sets.step() + # Invoke step on all contained AgentSets via the public registry API + self.sets.do("step") @property def steps(self) -> int: @@ -187,17 +170,6 @@ def sets(self, sets: AgentSetRegistry) -> None: self._sets = sets - @property - def set_types(self) -> list[type]: - """Get a list of different agent set types present in the model. - - Returns - ------- - list[type] - A list of the different agent set types present in the model. - """ - return [agent.__class__ for agent in self._sets._agentsets] - @property def space(self) -> Space: """Get the space object associated with the model. diff --git a/mesa_frames/types_.py b/mesa_frames/types_.py index 05ab1b3f..5873e034 100644 --- a/mesa_frames/types_.py +++ b/mesa_frames/types_.py @@ -1,15 +1,17 @@ """Type aliases for the mesa_frames package.""" from __future__ import annotations -from collections.abc import Collection, Sequence -from datetime import date, datetime, time, timedelta -from typing import Literal, Annotated, Union, Any -from collections.abc import Mapping -from beartype.vale import IsEqual + import math +from collections.abc import Collection, Mapping, Sequence +from datetime import date, datetime, time, timedelta +from typing import TYPE_CHECKING, Annotated, Any, Literal, Union + +import numpy as np import polars as pl +from beartype.vale import IsEqual from numpy import ndarray -import numpy as np + # import geopolars as gpl # TODO: Uncomment when geopolars is available ###----- Optional Types -----### @@ -83,6 +85,46 @@ ArrayLike = ndarray | Series | Sequence Infinity = Annotated[float, IsEqual[math.inf]] # Only accepts math.inf +# Common option types +KeyBy = Literal["name", "index", "type"] + +# Selectors for choosing AgentSets at the registry level +# Abstract (for abstract layer APIs) +if TYPE_CHECKING: + from mesa_frames.abstract.agentset import AbstractAgentSet as _AAS + + AbstractAgentSetSelector = ( + _AAS | type[_AAS] | str | Collection[_AAS | type[_AAS] | str] | None + ) +else: + AbstractAgentSetSelector = Any # runtime fallback to avoid import cycles + +# Concrete (for concrete layer APIs) +if TYPE_CHECKING: + from mesa_frames.concrete.agentset import AgentSet as _CAS + + AgentSetSelector = ( + _CAS | type[_CAS] | str | Collection[_CAS | type[_CAS] | str] | None + ) +else: + AgentSetSelector = Any # runtime fallback to avoid import cycles + +__all__ = [ + # common + "DataFrame", + "Series", + "Index", + "BoolSeries", + "Mask", + "AgentMask", + "IdsLike", + "ArrayLike", + "KeyBy", + # selectors + "AbstractAgentSetSelector", + "AgentSetSelector", +] + ###----- Time ------### TimeT = float | int diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..fd84a7ac --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,11 @@ +"""Conftest for tests. + +Ensure beartype runtime checking is enabled before importing the package. + +This module sets MESA_FRAMES_RUNTIME_TYPECHECKING=1 at import time so tests that +assert beartype failures at import or construct time behave deterministically. +""" + +import os + +os.environ.setdefault("MESA_FRAMES_RUNTIME_TYPECHECKING", "1") diff --git a/tests/test_agents.py b/tests/test_agents.py deleted file mode 100644 index f43d94f6..00000000 --- a/tests/test_agents.py +++ /dev/null @@ -1,1036 +0,0 @@ -from copy import copy, deepcopy - -import polars as pl -import pytest - -from mesa_frames import AgentSetRegistry, Model -from mesa_frames import AgentSet -from mesa_frames.types_ import AgentMask -from tests.test_agentset import ( - ExampleAgentSet, - ExampleAgentSetNoWealth, - fix1_AgentSet_no_wealth, - fix1_AgentSet, - fix2_AgentSet, - fix3_AgentSet, -) - - -@pytest.fixture -def fix_AgentSetRegistry( - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, -) -> AgentSetRegistry: - model = Model() - agents = AgentSetRegistry(model) - agents.add([fix1_AgentSet, fix2_AgentSet]) - return agents - - -class Test_AgentSetRegistry: - def test___init__(self): - model = Model() - agents = AgentSetRegistry(model) - assert agents.model == model - assert isinstance(agents._agentsets, list) - assert len(agents._agentsets) == 0 - assert isinstance(agents._ids, pl.Series) - assert agents._ids.is_empty() - assert agents._ids.name == "unique_id" - - def test_add( - self, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - ): - model = Model() - agents = AgentSetRegistry(model) - agentset_polars1 = fix1_AgentSet - agentset_polars2 = fix2_AgentSet - - # Test with a single AgentSet - result = agents.add(agentset_polars1, inplace=False) - assert result._agentsets[0] is agentset_polars1 - assert result._ids.to_list() == agentset_polars1._df["unique_id"].to_list() - - # Test with a list of AgentSets - result = agents.add([agentset_polars1, agentset_polars2], inplace=True) - assert result._agentsets[0] is agentset_polars1 - assert result._agentsets[1] is agentset_polars2 - assert ( - result._ids.to_list() - == agentset_polars1._df["unique_id"].to_list() - + agentset_polars2._df["unique_id"].to_list() - ) - - # Test if adding the same AgentSet raises ValueError - with pytest.raises(ValueError): - agents.add(agentset_polars1, inplace=False) - - def test_contains( - self, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - fix3_AgentSet: ExampleAgentSet, - fix_AgentSetRegistry: AgentSetRegistry, - ): - agents = fix_AgentSetRegistry - agentset_polars1 = agents._agentsets[0] - - # Test with an AgentSet - assert agents.contains(agentset_polars1) - assert agents.contains(fix1_AgentSet) - assert agents.contains(fix2_AgentSet) - - # Test with an AgentSet not present - assert not agents.contains(fix3_AgentSet) - - # Test with an iterable of AgentSets - assert agents.contains([agentset_polars1, fix3_AgentSet]).to_list() == [ - True, - False, - ] - - # Test with single id - assert agents.contains(agentset_polars1["unique_id"][0]) - - # Test with a list of ids - assert agents.contains([agentset_polars1["unique_id"][0], 0]).to_list() == [ - True, - False, - ] - - def test_copy(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - agents.test_list = [[1, 2, 3]] - - # Test with deep=False - agents2 = agents.copy(deep=False) - agents2.test_list[0].append(4) - assert agents.test_list[0][-1] == agents2.test_list[0][-1] - assert agents.model == agents2.model - assert agents._agentsets[0] == agents2._agentsets[0] - assert (agents._ids == agents2._ids).all() - - # Test with deep=True - agents2 = fix_AgentSetRegistry.copy(deep=True) - agents2.test_list[0].append(4) - assert agents.test_list[-1] != agents2.test_list[-1] - assert agents.model == agents2.model - assert agents._agentsets[0] != agents2._agentsets[0] - assert (agents._ids == agents2._ids).all() - - def test_discard( - self, fix_AgentSetRegistry: AgentSetRegistry, fix2_AgentSet: ExampleAgentSet - ): - agents = fix_AgentSetRegistry - # Test with a single AgentSet - agentset_polars2 = agents._agentsets[1] - result = agents.discard(agents._agentsets[0], inplace=False) - assert isinstance(result._agentsets[0], ExampleAgentSet) - assert len(result._agentsets) == 1 - - # Test with a list of AgentSets - result = agents.discard(agents._agentsets.copy(), inplace=False) - assert len(result._agentsets) == 0 - - # Test with IDs - ids = [ - agents._agentsets[0]._df["unique_id"][0], - agents._agentsets[1]._df["unique_id"][0], - ] - agentset_polars1 = agents._agentsets[0] - agentset_polars2 = agents._agentsets[1] - result = agents.discard(ids, inplace=False) - assert ( - result._agentsets[0]["unique_id"][0] - == agentset_polars1._df.select("unique_id").row(1)[0] - ) - assert ( - result._agentsets[1].df["unique_id"][0] - == agentset_polars2._df["unique_id"][1] - ) - - # Test if removing an AgentSet not present raises ValueError - result = agents.discard(fix2_AgentSet, inplace=False) - - # Test if removing an ID not present raises KeyError - assert 0 not in agents._ids - result = agents.discard(0, inplace=False) - - def test_do(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - - expected_result_0 = agents._agentsets[0].df["wealth"] - expected_result_0 += 1 - - expected_result_1 = agents._agentsets[1].df["wealth"] - expected_result_1 += 1 - - # Test with no return_results, no mask, inplace - agents.do("add_wealth", 1) - assert ( - agents._agentsets[0].df["wealth"].to_list() == expected_result_0.to_list() - ) - assert ( - agents._agentsets[1].df["wealth"].to_list() == expected_result_1.to_list() - ) - - # Test with return_results=True, no mask, inplace - expected_result_0 = agents._agentsets[0].df["wealth"] - expected_result_0 += 1 - - expected_result_1 = agents._agentsets[1].df["wealth"] - expected_result_1 += 1 - assert agents.do("add_wealth", 1, return_results=True) == { - agents._agentsets[0]: None, - agents._agentsets[1]: None, - } - assert ( - agents._agentsets[0].df["wealth"].to_list() == expected_result_0.to_list() - ) - assert ( - agents._agentsets[1].df["wealth"].to_list() == expected_result_1.to_list() - ) - - # Test with a mask, inplace - mask0 = agents._agentsets[0].df["wealth"] > 10 # No agent should be selected - mask1 = agents._agentsets[1].df["wealth"] > 10 # All agents should be selected - mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} - - expected_result_0 = agents._agentsets[0].df["wealth"] - expected_result_1 = agents._agentsets[1].df["wealth"] - expected_result_1 += 1 - - agents.do("add_wealth", 1, mask=mask_dictionary) - assert ( - agents._agentsets[0].df["wealth"].to_list() == expected_result_0.to_list() - ) - assert ( - agents._agentsets[1].df["wealth"].to_list() == expected_result_1.to_list() - ) - - def test_get( - self, - fix_AgentSetRegistry: AgentSetRegistry, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - fix1_AgentSet_no_wealth: ExampleAgentSetNoWealth, - ): - agents = fix_AgentSetRegistry - - # Test with a single attribute - assert ( - agents.get("wealth")[fix1_AgentSet].to_list() - == fix1_AgentSet._df["wealth"].to_list() - ) - assert ( - agents.get("wealth")[fix2_AgentSet].to_list() - == fix2_AgentSet._df["wealth"].to_list() - ) - - # Test with a list of attributes - result = agents.get(["wealth", "age"]) - assert result[fix1_AgentSet].columns == ["wealth", "age"] - assert ( - result[fix1_AgentSet]["wealth"].to_list() - == fix1_AgentSet._df["wealth"].to_list() - ) - assert ( - result[fix1_AgentSet]["age"].to_list() == fix1_AgentSet._df["age"].to_list() - ) - - assert result[fix2_AgentSet].columns == ["wealth", "age"] - assert ( - result[fix2_AgentSet]["wealth"].to_list() - == fix2_AgentSet._df["wealth"].to_list() - ) - assert ( - result[fix2_AgentSet]["age"].to_list() == fix2_AgentSet._df["age"].to_list() - ) - - # Test with a single attribute and a mask - mask0 = fix1_AgentSet._df["wealth"] > fix1_AgentSet._df["wealth"][0] - mask1 = fix2_AgentSet._df["wealth"] > fix2_AgentSet._df["wealth"][0] - mask_dictionary = {fix1_AgentSet: mask0, fix2_AgentSet: mask1} - result = agents.get("wealth", mask=mask_dictionary) - assert ( - result[fix1_AgentSet].to_list() == fix1_AgentSet._df["wealth"].to_list()[1:] - ) - assert ( - result[fix2_AgentSet].to_list() == fix2_AgentSet._df["wealth"].to_list()[1:] - ) - - # Test heterogeneous agent sets (different columns) - # This tests the fix for the bug where agents_df["column"] would raise - # ColumnNotFoundError when some agent sets didn't have that column. - - # Create a new AgentSetRegistry with heterogeneous agent sets - model = Model() - hetero_agents = AgentSetRegistry(model) - hetero_agents.add([fix1_AgentSet, fix1_AgentSet_no_wealth]) - - # Test 1: Access column that exists in only one agent set - result_wealth = hetero_agents.get("wealth") - assert len(result_wealth) == 1, ( - "Should only return agent sets that have 'wealth'" - ) - assert fix1_AgentSet in result_wealth, ( - "Should include the agent set with wealth" - ) - assert fix1_AgentSet_no_wealth not in result_wealth, ( - "Should not include agent set without wealth" - ) - assert result_wealth[fix1_AgentSet].to_list() == [1, 2, 3, 4] - - # Test 2: Access column that exists in all agent sets - result_age = hetero_agents.get("age") - assert len(result_age) == 2, "Should return both agent sets that have 'age'" - assert fix1_AgentSet in result_age - assert fix1_AgentSet_no_wealth in result_age - assert result_age[fix1_AgentSet].to_list() == [10, 20, 30, 40] - assert result_age[fix1_AgentSet_no_wealth].to_list() == [1, 2, 3, 4] - - # Test 3: Access column that exists in no agent sets - result_nonexistent = hetero_agents.get("nonexistent_column") - assert len(result_nonexistent) == 0, ( - "Should return empty dict for non-existent column" - ) - - # Test 4: Access multiple columns (mixed availability) - result_multi = hetero_agents.get(["wealth", "age"]) - assert len(result_multi) == 1, ( - "Should only include agent sets that have ALL requested columns" - ) - assert fix1_AgentSet in result_multi - assert fix1_AgentSet_no_wealth not in result_multi - assert result_multi[fix1_AgentSet].columns == ["wealth", "age"] - - # Test 5: Access multiple columns where some exist in different sets - result_mixed = hetero_agents.get(["age", "income"]) - assert len(result_mixed) == 1, ( - "Should only include agent set that has both 'age' and 'income'" - ) - assert fix1_AgentSet_no_wealth in result_mixed - assert fix1_AgentSet not in result_mixed - - # Test 6: Test via __getitem__ syntax (the original bug report case) - wealth_via_getitem = hetero_agents["wealth"] - assert len(wealth_via_getitem) == 1 - assert fix1_AgentSet in wealth_via_getitem - assert wealth_via_getitem[fix1_AgentSet].to_list() == [1, 2, 3, 4] - - # Test 7: Test get(None) - should return all columns for all agent sets - result_none = hetero_agents.get(None) - assert len(result_none) == 2, ( - "Should return both agent sets when attr_names=None" - ) - assert fix1_AgentSet in result_none - assert fix1_AgentSet_no_wealth in result_none - - # Verify each agent set returns all its columns (excluding unique_id) - wealth_set_result = result_none[fix1_AgentSet] - assert isinstance(wealth_set_result, pl.DataFrame), ( - "Should return DataFrame when attr_names=None" - ) - expected_wealth_cols = {"wealth", "age"} # unique_id should be excluded - assert set(wealth_set_result.columns) == expected_wealth_cols - - no_wealth_set_result = result_none[fix1_AgentSet_no_wealth] - assert isinstance(no_wealth_set_result, pl.DataFrame), ( - "Should return DataFrame when attr_names=None" - ) - expected_no_wealth_cols = {"income", "age"} # unique_id should be excluded - assert set(no_wealth_set_result.columns) == expected_no_wealth_cols - - def test_remove( - self, - fix_AgentSetRegistry: AgentSetRegistry, - fix3_AgentSet: ExampleAgentSet, - ): - agents = fix_AgentSetRegistry - - # Test with a single AgentSet - agentset_polars = agents._agentsets[1] - result = agents.remove(agents._agentsets[0], inplace=False) - assert isinstance(result._agentsets[0], ExampleAgentSet) - assert len(result._agentsets) == 1 - - # Test with a list of AgentSets - result = agents.remove(agents._agentsets.copy(), inplace=False) - assert len(result._agentsets) == 0 - - # Test with IDs - ids = [ - agents._agentsets[0]._df["unique_id"][0], - agents._agentsets[1]._df["unique_id"][0], - ] - agentset_polars1 = agents._agentsets[0] - agentset_polars2 = agents._agentsets[1] - result = agents.remove(ids, inplace=False) - assert ( - result._agentsets[0]["unique_id"][0] - == agentset_polars1._df.select("unique_id").row(1)[0] - ) - assert ( - result._agentsets[1].df["unique_id"][0] - == agentset_polars2._df["unique_id"][1] - ) - - # Test if removing an AgentSet not present raises ValueError - with pytest.raises(ValueError): - result = agents.remove(fix3_AgentSet, inplace=False) - - # Test if removing an ID not present raises KeyError - assert 0 not in agents._ids - with pytest.raises(KeyError): - result = agents.remove(0, inplace=False) - - def test_select(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - - # Test with default arguments. Should select all agents - selected = agents.select(inplace=False) - active_agents_dict = selected.active_agents - agents_dict = selected.df - assert active_agents_dict.keys() == agents_dict.keys() - # Using assert to compare all DataFrames in the dictionaries - - assert ( - list(active_agents_dict.values())[0].rows() - == list(agents_dict.values())[0].rows() - ) - - assert all( - series.all() - for series in ( - list(active_agents_dict.values())[1] == list(agents_dict.values())[1] - ) - ) - - # Test with a mask - mask0 = pl.Series("mask", [True, False, True, True], dtype=pl.Boolean) - mask1 = pl.Series("mask", [True, False, True, True], dtype=pl.Boolean) - mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} - selected = agents.select(mask_dictionary, inplace=False) - assert ( - selected.active_agents[selected._agentsets[0]]["wealth"].to_list()[0] - == agents._agentsets[0]["wealth"].to_list()[0] - ) - assert ( - selected.active_agents[selected._agentsets[0]]["wealth"].to_list()[-1] - == agents._agentsets[0]["wealth"].to_list()[-1] - ) - - assert ( - selected.active_agents[selected._agentsets[1]]["wealth"].to_list()[0] - == agents._agentsets[1]["wealth"].to_list()[0] - ) - assert ( - selected.active_agents[selected._agentsets[1]]["wealth"].to_list()[-1] - == agents._agentsets[1]["wealth"].to_list()[-1] - ) - - # Test with filter_func - - def filter_func(agentset: AgentSet) -> pl.Series: - return agentset.df["wealth"] > agentset.df["wealth"].to_list()[0] - - selected = agents.select(filter_func=filter_func, inplace=False) - assert ( - selected.active_agents[selected._agentsets[0]]["wealth"].to_list() - == agents._agentsets[0]["wealth"].to_list()[1:] - ) - assert ( - selected.active_agents[selected._agentsets[1]]["wealth"].to_list() - == agents._agentsets[1]["wealth"].to_list()[1:] - ) - - # Test with n - selected = agents.select(n=3, inplace=False) - assert sum(len(df) for df in selected.active_agents.values()) in [2, 3] - - # Test with n, filter_func and mask - selected = agents.select( - mask_dictionary, filter_func=filter_func, n=2, inplace=False - ) - assert any( - el in selected.active_agents[selected._agentsets[0]]["wealth"].to_list() - for el in agents.active_agents[agents._agentsets[0]]["wealth"].to_list()[ - 2:4 - ] - ) - - assert any( - el in selected.active_agents[selected._agentsets[1]]["wealth"].to_list() - for el in agents.active_agents[agents._agentsets[1]]["wealth"].to_list()[ - 2:4 - ] - ) - - def test_set(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - - # Test with a single attribute - result = agents.set("wealth", 0, inplace=False) - assert result._agentsets[0].df["wealth"].to_list() == [0] * len( - agents._agentsets[0] - ) - assert result._agentsets[1].df["wealth"].to_list() == [0] * len( - agents._agentsets[1] - ) - - # Test with a list of attributes - agents.set(["wealth", "age"], 1, inplace=True) - assert agents._agentsets[0].df["wealth"].to_list() == [1] * len( - agents._agentsets[0] - ) - assert agents._agentsets[0].df["age"].to_list() == [1] * len( - agents._agentsets[0] - ) - - # Test with a single attribute and a mask - mask0 = pl.Series( - "mask", [True] + [False] * (len(agents._agentsets[0]) - 1), dtype=pl.Boolean - ) - mask1 = pl.Series( - "mask", [True] + [False] * (len(agents._agentsets[1]) - 1), dtype=pl.Boolean - ) - mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} - result = agents.set("wealth", 0, mask=mask_dictionary, inplace=False) - assert result._agentsets[0].df["wealth"].to_list() == [0] + [1] * ( - len(agents._agentsets[0]) - 1 - ) - assert result._agentsets[1].df["wealth"].to_list() == [0] + [1] * ( - len(agents._agentsets[1]) - 1 - ) - - # Test with a dictionary - agents.set( - {agents._agentsets[0]: {"wealth": 0}, agents._agentsets[1]: {"wealth": 1}}, - inplace=True, - ) - assert agents._agentsets[0].df["wealth"].to_list() == [0] * len( - agents._agentsets[0] - ) - assert agents._agentsets[1].df["wealth"].to_list() == [1] * len( - agents._agentsets[1] - ) - - def test_shuffle(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - for _ in range(100): - original_order_0 = agents._agentsets[0].df["unique_id"].to_list() - original_order_1 = agents._agentsets[1].df["unique_id"].to_list() - agents.shuffle(inplace=True) - if ( - original_order_0 != agents._agentsets[0].df["unique_id"].to_list() - and original_order_1 != agents._agentsets[1].df["unique_id"].to_list() - ): - return - assert False - - def test_sort(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - agents.sort("wealth", ascending=False, inplace=True) - assert pl.Series(agents._agentsets[0].df["wealth"]).is_sorted(descending=True) - assert pl.Series(agents._agentsets[1].df["wealth"]).is_sorted(descending=True) - - def test_step( - self, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - fix_AgentSetRegistry: AgentSetRegistry, - ): - previous_wealth_0 = fix1_AgentSet._df["wealth"].clone() - previous_wealth_1 = fix2_AgentSet._df["wealth"].clone() - - agents = fix_AgentSetRegistry - agents.step() - - assert ( - agents._agentsets[0].df["wealth"].to_list() - == (previous_wealth_0 + 1).to_list() - ) - assert ( - agents._agentsets[1].df["wealth"].to_list() - == (previous_wealth_1 + 1).to_list() - ) - - def test__check_ids_presence( - self, - fix_AgentSetRegistry: AgentSetRegistry, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - ): - agents = fix_AgentSetRegistry.remove(fix2_AgentSet, inplace=False) - agents_different_index = deepcopy(fix2_AgentSet) - result = agents._check_ids_presence([fix1_AgentSet]) - assert result.filter(pl.col("unique_id").is_in(fix1_AgentSet._df["unique_id"]))[ - "present" - ].all() - - assert not result.filter( - pl.col("unique_id").is_in(agents_different_index._df["unique_id"]) - )["present"].any() - - def test__check_agentsets_presence( - self, - fix_AgentSetRegistry: AgentSetRegistry, - fix1_AgentSet: ExampleAgentSet, - fix3_AgentSet: ExampleAgentSet, - ): - agents = fix_AgentSetRegistry - result = agents._check_agentsets_presence([fix1_AgentSet, fix3_AgentSet]) - assert result[0] - assert not result[1] - - def test__get_bool_masks(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - # Test with mask = None - result = agents._get_bool_masks(mask=None) - truth_value = True - for i, mask in enumerate(result.values()): - if isinstance(mask, pl.Expr): - mask = agents._agentsets[i]._df.select(mask).to_series() - truth_value &= mask.all() - assert truth_value - - # Test with mask = "all" - result = agents._get_bool_masks(mask="all") - truth_value = True - for i, mask in enumerate(result.values()): - if isinstance(mask, pl.Expr): - mask = agents._agentsets[i]._df.select(mask).to_series() - truth_value &= mask.all() - assert truth_value - - # Test with mask = "active" - mask0 = ( - agents._agentsets[0].df["wealth"] - > agents._agentsets[0].df["wealth"].to_list()[0] - ) - mask1 = agents._agentsets[1].df["wealth"] > agents._agentsets[1].df["wealth"][0] - mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} - agents.select(mask=mask_dictionary) - result = agents._get_bool_masks(mask="active") - assert result[agents._agentsets[0]].to_list() == mask0.to_list() - assert result[agents._agentsets[1]].to_list() == mask1.to_list() - - # Test with mask = IdsLike - result = agents._get_bool_masks( - mask=[ - agents._agentsets[0]["unique_id"][0], - agents._agentsets[1].df["unique_id"][0], - ] - ) - assert result[agents._agentsets[0]].to_list() == [True] + [False] * ( - len(agents._agentsets[0]) - 1 - ) - assert result[agents._agentsets[1]].to_list() == [True] + [False] * ( - len(agents._agentsets[1]) - 1 - ) - - # Test with mask = dict[AgentSet, AgentMask] - result = agents._get_bool_masks(mask=mask_dictionary) - assert result[agents._agentsets[0]].to_list() == mask0.to_list() - assert result[agents._agentsets[1]].to_list() == mask1.to_list() - - def test__get_obj(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - assert agents._get_obj(inplace=True) is agents - assert agents._get_obj(inplace=False) is not agents - - def test__return_agentsets_list( - self, - fix_AgentSetRegistry: AgentSetRegistry, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - ): - agents = fix_AgentSetRegistry - result = agents._return_agentsets_list(fix1_AgentSet) - assert result == [fix1_AgentSet] - result = agents._return_agentsets_list([fix1_AgentSet, fix2_AgentSet]) - assert result == [fix1_AgentSet, fix2_AgentSet] - - def test___add__( - self, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - ): - model = Model() - agents = AgentSetRegistry(model) - agentset_polars1 = fix1_AgentSet - agentset_polars2 = fix2_AgentSet - - # Test with a single AgentSet - result = agents + agentset_polars1 - assert result._agentsets[0] is agentset_polars1 - assert result._ids.to_list() == agentset_polars1._df["unique_id"].to_list() - - # Test with a single AgentSet same as above - result = agents + agentset_polars2 - assert result._agentsets[0] is agentset_polars2 - assert result._ids.to_list() == agentset_polars2._df["unique_id"].to_list() - - # Test with a list of AgentSets - result = agents + [agentset_polars1, agentset_polars2] - assert result._agentsets[0] is agentset_polars1 - assert result._agentsets[1] is agentset_polars2 - assert ( - result._ids.to_list() - == agentset_polars1._df["unique_id"].to_list() - + agentset_polars2._df["unique_id"].to_list() - ) - - # Test if adding the same AgentSet raises ValueError - with pytest.raises(ValueError): - result + agentset_polars1 - - def test___contains__( - self, fix_AgentSetRegistry: AgentSetRegistry, fix3_AgentSet: ExampleAgentSet - ): - # Test with a single value - agents = fix_AgentSetRegistry - agentset_polars1 = agents._agentsets[0] - - # Test with an AgentSet - assert agentset_polars1 in agents - # Test with an AgentSet not present - assert fix3_AgentSet not in agents - - # Test with single id present - assert agentset_polars1["unique_id"][0] in agents - - # Test with single id not present - assert 0 not in agents - - def test___copy__(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - agents.test_list = [[1, 2, 3]] - - # Test with deep=False - agents2 = copy(agents) - agents2.test_list[0].append(4) - assert agents.test_list[0][-1] == agents2.test_list[0][-1] - assert agents.model == agents2.model - assert agents._agentsets[0] == agents2._agentsets[0] - assert (agents._ids == agents2._ids).all() - - def test___deepcopy__(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - agents.test_list = [[1, 2, 3]] - - agents2 = deepcopy(agents) - agents2.test_list[0].append(4) - assert agents.test_list[-1] != agents2.test_list[-1] - assert agents.model == agents2.model - assert agents._agentsets[0] != agents2._agentsets[0] - assert (agents._ids == agents2._ids).all() - - def test___getattr__(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - assert isinstance(agents.model, Model) - result = agents.wealth - assert ( - result[agents._agentsets[0]].to_list() - == agents._agentsets[0].df["wealth"].to_list() - ) - assert ( - result[agents._agentsets[1]].to_list() - == agents._agentsets[1].df["wealth"].to_list() - ) - - def test___getitem__( - self, - fix_AgentSetRegistry: AgentSetRegistry, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - ): - agents = fix_AgentSetRegistry - - # Test with a single attribute - assert ( - agents["wealth"][fix1_AgentSet].to_list() - == fix1_AgentSet._df["wealth"].to_list() - ) - assert ( - agents["wealth"][fix2_AgentSet].to_list() - == fix2_AgentSet._df["wealth"].to_list() - ) - - # Test with a list of attributes - result = agents[["wealth", "age"]] - assert result[fix1_AgentSet].columns == ["wealth", "age"] - assert ( - result[fix1_AgentSet]["wealth"].to_list() - == fix1_AgentSet._df["wealth"].to_list() - ) - assert ( - result[fix1_AgentSet]["age"].to_list() == fix1_AgentSet._df["age"].to_list() - ) - assert result[fix2_AgentSet].columns == ["wealth", "age"] - assert ( - result[fix2_AgentSet]["wealth"].to_list() - == fix2_AgentSet._df["wealth"].to_list() - ) - assert ( - result[fix2_AgentSet]["age"].to_list() == fix2_AgentSet._df["age"].to_list() - ) - - # Test with a single attribute and a mask - mask0 = fix1_AgentSet._df["wealth"] > fix1_AgentSet._df["wealth"][0] - mask1 = fix2_AgentSet._df["wealth"] > fix2_AgentSet._df["wealth"][0] - mask_dictionary: dict[AgentSet, AgentMask] = { - fix1_AgentSet: mask0, - fix2_AgentSet: mask1, - } - result = agents[mask_dictionary, "wealth"] - assert ( - result[fix1_AgentSet].to_list() == fix1_AgentSet.df["wealth"].to_list()[1:] - ) - assert ( - result[fix2_AgentSet].to_list() == fix2_AgentSet.df["wealth"].to_list()[1:] - ) - - def test___iadd__( - self, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - ): - model = Model() - agents = AgentSetRegistry(model) - agentset_polars1 = fix1_AgentSet - agentset_polars = fix2_AgentSet - - # Test with a single AgentSet - agents_copy = deepcopy(agents) - agents_copy += agentset_polars - assert agents_copy._agentsets[0] is agentset_polars - assert agents_copy._ids.to_list() == agentset_polars._df["unique_id"].to_list() - - # Test with a list of AgentSets - agents_copy = deepcopy(agents) - agents_copy += [agentset_polars1, agentset_polars] - assert agents_copy._agentsets[0] is agentset_polars1 - assert agents_copy._agentsets[1] is agentset_polars - assert ( - agents_copy._ids.to_list() - == agentset_polars1._df["unique_id"].to_list() - + agentset_polars._df["unique_id"].to_list() - ) - - # Test if adding the same AgentSet raises ValueError - with pytest.raises(ValueError): - agents_copy += agentset_polars1 - - def test___iter__(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - len_agentset0 = len(agents._agentsets[0]) - len_agentset1 = len(agents._agentsets[1]) - for i, agent in enumerate(agents): - assert isinstance(agent, dict) - if i < len_agentset0: - assert agent["unique_id"] == agents._agentsets[0].df["unique_id"][i] - else: - assert ( - agent["unique_id"] - == agents._agentsets[1].df["unique_id"][i - len_agentset0] - ) - assert i == len_agentset0 + len_agentset1 - 1 - - def test___isub__( - self, - fix_AgentSetRegistry: AgentSetRegistry, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - ): - # Test with an AgentSet and a DataFrame - agents = fix_AgentSetRegistry - agents -= fix1_AgentSet - assert agents._agentsets[0] == fix2_AgentSet - assert len(agents._agentsets) == 1 - - def test___len__( - self, - fix_AgentSetRegistry: AgentSetRegistry, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - ): - assert len(fix_AgentSetRegistry) == len(fix1_AgentSet) + len(fix2_AgentSet) - - def test___repr__(self, fix_AgentSetRegistry: AgentSetRegistry): - repr(fix_AgentSetRegistry) - - def test___reversed__(self, fix2_AgentSet: AgentSetRegistry): - agents = fix2_AgentSet - reversed_wealth = [] - for agent in reversed(list(agents)): - reversed_wealth.append(agent["wealth"]) - assert reversed_wealth == list(reversed(agents["wealth"])) - - def test___setitem__(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - - # Test with a single attribute - agents["wealth"] = 0 - assert agents._agentsets[0].df["wealth"].to_list() == [0] * len( - agents._agentsets[0] - ) - assert agents._agentsets[1].df["wealth"].to_list() == [0] * len( - agents._agentsets[1] - ) - - # Test with a list of attributes - agents[["wealth", "age"]] = 1 - assert agents._agentsets[0].df["wealth"].to_list() == [1] * len( - agents._agentsets[0] - ) - assert agents._agentsets[0].df["age"].to_list() == [1] * len( - agents._agentsets[0] - ) - - # Test with a single attribute and a mask - mask0 = pl.Series( - "mask", [True] + [False] * (len(agents._agentsets[0]) - 1), dtype=pl.Boolean - ) - mask1 = pl.Series( - "mask", [True] + [False] * (len(agents._agentsets[1]) - 1), dtype=pl.Boolean - ) - mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} - agents[mask_dictionary, "wealth"] = 0 - assert agents._agentsets[0].df["wealth"].to_list() == [0] + [1] * ( - len(agents._agentsets[0]) - 1 - ) - assert agents._agentsets[1].df["wealth"].to_list() == [0] + [1] * ( - len(agents._agentsets[1]) - 1 - ) - - def test___str__(self, fix_AgentSetRegistry: AgentSetRegistry): - str(fix_AgentSetRegistry) - - def test___sub__( - self, - fix_AgentSetRegistry: AgentSetRegistry, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - ): - # Test with an AgentSet and a DataFrame - result = fix_AgentSetRegistry - fix1_AgentSet - assert isinstance(result._agentsets[0], ExampleAgentSet) - assert len(result._agentsets) == 1 - - def test_agents( - self, - fix_AgentSetRegistry: AgentSetRegistry, - fix1_AgentSet: ExampleAgentSet, - fix2_AgentSet: ExampleAgentSet, - ): - assert isinstance(fix_AgentSetRegistry.df, dict) - assert len(fix_AgentSetRegistry.df) == 2 - assert fix_AgentSetRegistry.df[fix1_AgentSet] is fix1_AgentSet._df - assert fix_AgentSetRegistry.df[fix2_AgentSet] is fix2_AgentSet._df - - # Test agents.setter - fix_AgentSetRegistry.df = [fix1_AgentSet, fix2_AgentSet] - assert fix_AgentSetRegistry._agentsets[0] == fix1_AgentSet - assert fix_AgentSetRegistry._agentsets[1] == fix2_AgentSet - - def test_active_agents(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - - # Test with select - mask0 = ( - agents._agentsets[0].df["wealth"] - > agents._agentsets[0].df["wealth"].to_list()[0] - ) - mask1 = ( - agents._agentsets[1].df["wealth"] - > agents._agentsets[1].df["wealth"].to_list()[0] - ) - mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} - - agents1 = agents.select(mask=mask_dictionary, inplace=False) - - result = agents1.active_agents - assert isinstance(result, dict) - assert isinstance(result[agents1._agentsets[0]], pl.DataFrame) - assert isinstance(result[agents1._agentsets[1]], pl.DataFrame) - - assert all( - series.all() - for series in ( - result[agents1._agentsets[0]] == agents1._agentsets[0]._df.filter(mask0) - ) - ) - - assert all( - series.all() - for series in ( - result[agents1._agentsets[1]] == agents1._agentsets[1]._df.filter(mask1) - ) - ) - - # Test with active_agents.setter - agents1.active_agents = mask_dictionary - result = agents1.active_agents - assert isinstance(result, dict) - assert isinstance(result[agents1._agentsets[0]], pl.DataFrame) - assert isinstance(result[agents1._agentsets[1]], pl.DataFrame) - assert all( - series.all() - for series in ( - result[agents1._agentsets[0]] == agents1._agentsets[0]._df.filter(mask0) - ) - ) - assert all( - series.all() - for series in ( - result[agents1._agentsets[1]] == agents1._agentsets[1]._df.filter(mask1) - ) - ) - - def test_agentsets_by_type(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - - result = agents.agentsets_by_type - assert isinstance(result, dict) - assert isinstance(result[ExampleAgentSet], AgentSetRegistry) - - assert ( - result[ExampleAgentSet]._agentsets[0].df.rows() - == agents._agentsets[1].df.rows() - ) - - def test_inactive_agents(self, fix_AgentSetRegistry: AgentSetRegistry): - agents = fix_AgentSetRegistry - - # Test with select - mask0 = ( - agents._agentsets[0].df["wealth"] - > agents._agentsets[0].df["wealth"].to_list()[0] - ) - mask1 = ( - agents._agentsets[1].df["wealth"] - > agents._agentsets[1].df["wealth"].to_list()[0] - ) - mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} - agents1 = agents.select(mask=mask_dictionary, inplace=False) - result = agents1.inactive_agents - assert isinstance(result, dict) - assert isinstance(result[agents1._agentsets[0]], pl.DataFrame) - assert isinstance(result[agents1._agentsets[1]], pl.DataFrame) - assert all( - series.all() - for series in ( - result[agents1._agentsets[0]] - == agents1._agentsets[0].select(mask0, negate=True).active_agents - ) - ) - assert all( - series.all() - for series in ( - result[agents1._agentsets[1]] - == agents1._agentsets[1].select(mask1, negate=True).active_agents - ) - ) diff --git a/tests/test_agentset.py b/tests/test_agentset.py index d475a4fc..c8459a80 100644 --- a/tests/test_agentset.py +++ b/tests/test_agentset.py @@ -260,6 +260,36 @@ def test_select(self, fix1_AgentSet: ExampleAgentSet): selected.active_agents["wealth"].to_list() == agents.df["wealth"].to_list() ) + def test_rename(self, fix1_AgentSet: ExampleAgentSet) -> None: + agents = fix1_AgentSet + reg = agents.model.sets + # Inplace rename returns self and updates registry + old_name = agents.name + result = agents.rename("alpha", inplace=True) + assert result is agents + assert agents.name == "alpha" + assert reg.get("alpha") is agents + assert reg.get(old_name) is None + + # Add a second set and claim the same name via registry first + other = ExampleAgentSet(agents.model) + other["wealth"] = other.starting_wealth + other["age"] = [1, 2, 3, 4] + reg.add(other) + reg.rename(other, "omega") + # Now rename the first to an existing name; should canonicalize to omega_1 + agents.rename("omega", inplace=True) + assert agents.name != "omega" + assert agents.name.startswith("omega_") + assert reg.get(agents.name) is agents + + # Non-inplace: returns a renamed copy of the set + copy_set = agents.rename("beta", inplace=False) + assert copy_set is not agents + assert copy_set.name in ("beta", "beta_1") + # Original remains unchanged + assert agents.name not in ("beta", "beta_1") + # Test with a pl.Series[bool] mask = pl.Series("mask", [True, False, True, True], dtype=pl.Boolean) selected = agents.select(mask, inplace=False) diff --git a/tests/test_agentsetregistry.py b/tests/test_agentsetregistry.py new file mode 100644 index 00000000..07422fb6 --- /dev/null +++ b/tests/test_agentsetregistry.py @@ -0,0 +1,436 @@ +import polars as pl +import pytest +import beartype.roar as bear_roar + +from mesa_frames import AgentSet, AgentSetRegistry, Model + + +class ExampleAgentSetA(AgentSet): + def __init__(self, model: Model): + super().__init__(model) + self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) + self["age"] = pl.Series("age", [10, 20, 30, 40]) + + def add_wealth(self, amount: int) -> None: + self["wealth"] += amount + + def step(self) -> None: + self.add_wealth(1) + + def count(self) -> int: + return len(self) + + +class ExampleAgentSetB(AgentSet): + def __init__(self, model: Model): + super().__init__(model) + self["wealth"] = pl.Series("wealth", [10, 20, 30, 40]) + self["age"] = pl.Series("age", [11, 22, 33, 44]) + + def add_wealth(self, amount: int) -> None: + self["wealth"] += amount + + def step(self) -> None: + self.add_wealth(2) + + def count(self) -> int: + return len(self) + + +@pytest.fixture +def fix_model() -> Model: + return Model() + + +@pytest.fixture +def fix_set_a(fix_model: Model) -> ExampleAgentSetA: + return ExampleAgentSetA(fix_model) + + +@pytest.fixture +def fix_set_b(fix_model: Model) -> ExampleAgentSetB: + return ExampleAgentSetB(fix_model) + + +@pytest.fixture +def fix_registry_with_two( + fix_model: Model, fix_set_a: ExampleAgentSetA, fix_set_b: ExampleAgentSetB +) -> AgentSetRegistry: + reg = AgentSetRegistry(fix_model) + reg.add([fix_set_a, fix_set_b]) + return reg + + +class TestAgentSetRegistry: + # Dunder: __init__ + def test__init__(self): + model = Model() + reg = AgentSetRegistry(model) + assert reg.model is model + assert len(reg) == 0 + assert reg.ids.len() == 0 + + # Public: add + def test_add(self, fix_model: Model) -> None: + reg = AgentSetRegistry(fix_model) + a1 = ExampleAgentSetA(fix_model) + a2 = ExampleAgentSetA(fix_model) + # Add single + reg.add(a1) + assert len(reg) == 1 + assert a1 in reg + # Add list; second should be auto-renamed with suffix + reg.add([a2]) + assert len(reg) == 2 + names = [s.name for s in reg] + assert names[0] == "ExampleAgentSetA" + assert names[1] in ("ExampleAgentSetA_1", "ExampleAgentSetA_2") + # ids concatenated + assert reg.ids.len() == len(a1) + len(a2) + # Duplicate instance rejected + with pytest.raises(ValueError, match="already present in the AgentSetRegistry"): + reg.add([a1]) + # Duplicate unique_id space rejected + a3 = ExampleAgentSetB(fix_model) + a3.df = a1.df + with pytest.raises(ValueError, match="agent IDs are not unique"): + reg.add(a3) + + # Public: contains + def test_contains(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + a_name = next(iter(reg)).name + # Single instance + assert reg.contains(reg[0]) is True + # Single type + assert reg.contains(ExampleAgentSetA) is True + # Single name + assert reg.contains(a_name) is True + # Iterable: instances + assert reg.contains([reg[0], reg[1]]).to_list() == [True, True] + # Iterable: types + types_result = reg.contains([ExampleAgentSetA, ExampleAgentSetB]) + assert types_result.dtype == pl.Boolean + assert types_result.to_list() == [True, True] + # Iterable: names + names = [s.name for s in reg] + assert reg.contains(names).to_list() == [True, True] + # Empty iterable is vacuously true + assert reg.contains([]) is True + # Unsupported element type (rejected by runtime type checking) + with pytest.raises(bear_roar.BeartypeCallHintParamViolation): + reg.contains([object()]) + + # Public: do + def test_do(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + # Inplace operation across both sets + reg.do("add_wealth", 5) + assert reg[0]["wealth"].to_list() == [6, 7, 8, 9] + assert reg[1]["wealth"].to_list() == [15, 25, 35, 45] + # return_results with different key domains + res_by_name = reg.do("count", return_results=True, key_by="name") + assert set(res_by_name.keys()) == {s.name for s in reg} + assert all(v == 4 for v in res_by_name.values()) + res_by_index = reg.do("count", return_results=True, key_by="index") + assert set(res_by_index.keys()) == {0, 1} + res_by_type = reg.do("count", return_results=True, key_by="type") + assert set(res_by_type.keys()) == {ExampleAgentSetA, ExampleAgentSetB} + + # Public: get + def test_get(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + # By index + assert isinstance(reg.get(0), AgentSet) + # By name + name = reg[0].name + assert reg.get(name) is reg[0] + # By type returns list + aset_list = reg.get(ExampleAgentSetA) + assert isinstance(aset_list, list) and all( + isinstance(s, ExampleAgentSetA) for s in aset_list + ) + # Missing returns default None + assert reg.get(9999) is None + # Out-of-range index handled without raising + assert reg.get(10) is None + + # Public: remove + def test_remove(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + total_ids = reg.ids.len() + # By instance + reg.remove(reg[0]) + assert len(reg) == 1 + # By type + reg.add(ExampleAgentSetA(reg.model)) + assert len(reg.get(ExampleAgentSetA)) == 1 + reg.remove(ExampleAgentSetA) + assert all(not isinstance(s, ExampleAgentSetA) for s in reg) + # By name (no error if not present) + reg.remove("nonexistent") + # ids recomputed and not equal to previous total + assert reg.ids.len() != total_ids + + # Public: shuffle + def test_shuffle(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + reg.shuffle(inplace=True) + assert len(reg) == 2 + + # Public: sort + def test_sort(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + reg.sort(by="wealth", ascending=False) + assert reg[0]["wealth"].to_list() == sorted( + reg[0]["wealth"].to_list(), reverse=True + ) + + # Public: rename + def test_rename(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + # Single rename by instance, inplace + a0 = reg[0] + reg.rename(a0, "X") + assert a0.name == "X" + assert reg.get("X") is a0 + + # Rename second to same name should canonicalize + a1 = reg[1] + reg.rename(a1, "X") + assert a1.name != "X" and a1.name.startswith("X_") + assert reg.get(a1.name) is a1 + + # Non-inplace copy + reg2 = reg.rename(a0, "Y", inplace=False) + assert reg2 is not reg + assert reg.get("Y") is None + assert reg2.get("Y") is not None + + # Atomic conflict raise: attempt to rename to existing name + with pytest.raises(ValueError): + reg.rename({a0: a1.name}, on_conflict="raise", mode="atomic") + # Names unchanged + assert reg.get(a1.name) is a1 + + # Best-effort: one ok, one conflicting → only ok applied + unique_name = "Z_unique" + reg.rename( + {a0: unique_name, a1: unique_name}, on_conflict="raise", mode="best_effort" + ) + assert a0.name == unique_name + # a1 stays with its previous (non-unique_name) value + assert a1.name != unique_name + + # Dunder: __getattr__ + def test__getattr__(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + ages = reg.age + assert isinstance(ages, dict) + assert set(ages.keys()) == {s.name for s in reg} + + # Dunder: __iter__ + def test__iter__(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + it = list(iter(reg)) + assert it[0] is reg[0] + assert all(isinstance(s, AgentSet) for s in it) + + # Dunder: __len__ + def test__len__(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + assert len(reg) == 2 + + # Dunder: __repr__ + def test__repr__(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + repr(reg) + + # Dunder: __str__ + def test__str__(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + str(reg) + + # Dunder: __reversed__ + def test__reversed__(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + list(reversed(reg)) + + # Dunder: __setitem__ + def test__setitem__(self, fix_model: Model) -> None: + reg = AgentSetRegistry(fix_model) + a1 = ExampleAgentSetA(fix_model) + a2 = ExampleAgentSetB(fix_model) + reg.add([a1, a2]) + # Assign by index with duplicate name should raise + a_dup = ExampleAgentSetA(fix_model) + a_dup.name = reg[1].name # create name collision + with pytest.raises(ValueError, match="Duplicate agent set name disallowed"): + reg[0] = a_dup + # Assign by name: replace existing slot, authoritative name should be key + new_set = ExampleAgentSetA(fix_model) + reg[reg[1].name] = new_set + assert reg[1] is new_set + assert reg[1].name == reg[1].name + # Assign new name appends + extra = ExampleAgentSetA(fix_model) + reg["extra_set"] = extra + assert reg["extra_set"] is extra + # Model mismatch raises + other_model_set = ExampleAgentSetA(Model()) + with pytest.raises( + TypeError, match="Assigned AgentSet must belong to the same model" + ): + reg[0] = other_model_set + + # Public: keys + def test_keys(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + # keys by name + names = list(reg.keys()) + assert names == [s.name for s in reg] + # keys by index + assert list(reg.keys(key_by="index")) == [0, 1] + # keys by type + assert set(reg.keys(key_by="type")) == {ExampleAgentSetA, ExampleAgentSetB} + # invalid key_by + with pytest.raises(bear_roar.BeartypeCallHintParamViolation): + list(reg.keys(key_by="bad")) + + # Public: items + def test_items(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + items_name = list(reg.items()) + assert [k for k, _ in items_name] == [s.name for s in reg] + items_idx = list(reg.items(key_by="index")) + assert [k for k, _ in items_idx] == [0, 1] + items_type = list(reg.items(key_by="type")) + assert {k for k, _ in items_type} == {ExampleAgentSetA, ExampleAgentSetB} + + # Public: values + def test_values(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + assert list(reg.values())[0] is reg[0] + + # Public: discard + def test_discard(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + original_len = len(reg) + # Missing selector ignored without error + reg.discard("missing_name") + assert len(reg) == original_len + # Remove by instance + reg.discard(reg[0]) + assert len(reg) == original_len - 1 + # Non-inplace returns new copy + reg2 = reg.discard("missing_name", inplace=False) + assert len(reg2) == len(reg) + + # Public: ids (property) + def test_ids(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + assert isinstance(reg.ids, pl.Series) + before = reg.ids.len() + reg.remove(reg[0]) + assert reg.ids.len() < before + + # Dunder: __getitem__ + def test__getitem__(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + # By index + assert reg[0] is next(iter(reg)) + # By name + name0 = reg[0].name + assert reg[name0] is reg[0] + # By type + lst = reg[ExampleAgentSetA] + assert isinstance(lst, list) and all( + isinstance(s, ExampleAgentSetA) for s in lst + ) + # Missing name raises KeyError + with pytest.raises(KeyError): + _ = reg["missing"] + + # Dunder: __contains__ (membership) + def test__contains__(self, fix_registry_with_two: AgentSetRegistry) -> None: + reg = fix_registry_with_two + assert reg[0] in reg + new_set = ExampleAgentSetA(reg.model) + assert new_set not in reg + + # Dunder: __add__ + def test__add__(self, fix_model: Model) -> None: + reg = AgentSetRegistry(fix_model) + a1 = ExampleAgentSetA(fix_model) + a2 = ExampleAgentSetB(fix_model) + reg.add(a1) + reg_new = reg + a2 + # original unchanged, new has two + assert len(reg) == 1 + assert len(reg_new) == 2 + # Presence by type/name (instances are deep-copied) + assert reg_new.contains(ExampleAgentSetA) is True + assert reg_new.contains(ExampleAgentSetB) is True + + # Dunder: __iadd__ + def test__iadd__(self, fix_model: Model) -> None: + reg = AgentSetRegistry(fix_model) + a1 = ExampleAgentSetA(fix_model) + a2 = ExampleAgentSetB(fix_model) + reg += a1 + assert len(reg) == 1 + reg += [a2] + assert len(reg) == 2 + assert reg.contains([a1, a2]).all() + + # Dunder: __sub__ + def test__sub__(self, fix_model: Model) -> None: + reg = AgentSetRegistry(fix_model) + a1 = ExampleAgentSetA(fix_model) + a2 = ExampleAgentSetB(fix_model) + reg.add([a1, a2]) + reg_new = reg - a1 + # original unchanged + assert len(reg) == 2 + # In current implementation, subtraction with instance returns a copy + # without mutation due to deep-copied identity; ensure new object + assert isinstance(reg_new, AgentSetRegistry) and reg_new is not reg + assert len(reg_new) == len(reg) + # subtract list of instances also yields unchanged copy + reg_new2 = reg - [a1, a2] + assert len(reg_new2) == len(reg) + + # Dunder: __isub__ + def test__isub__(self, fix_model: Model) -> None: + reg = AgentSetRegistry(fix_model) + a1 = ExampleAgentSetA(fix_model) + a2 = ExampleAgentSetB(fix_model) + reg.add([a1, a2]) + reg -= a1 + assert len(reg) == 1 and a1 not in reg + reg -= [a2] + assert len(reg) == 0 + + # Public: replace + def test_replace(self, fix_model: Model) -> None: + reg = AgentSetRegistry(fix_model) + a1 = ExampleAgentSetA(fix_model) + a2 = ExampleAgentSetB(fix_model) + a3 = ExampleAgentSetA(fix_model) + reg.add([a1, a2]) + # Replace by index + reg.replace({0: a3}) + assert reg[0] is a3 + # Replace by name (authoritative) + reg.replace({reg[1].name: a2}) + assert reg[1] is a2 + # Atomic aliasing error: same object in two positions + with pytest.raises(ValueError, match="already exists at index"): + reg.replace({0: a2, 1: a2}) + # Model mismatch + with pytest.raises(TypeError, match="must belong to the same model"): + reg.replace({0: ExampleAgentSetA(Model())}) + # Non-atomic: only applies valid keys to copy + reg2 = reg.replace({0: a1}, inplace=False, atomic=False) + assert reg2[0] is a1 + assert reg[0] is not a1 diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index b7407711..b2ac3279 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -164,7 +164,7 @@ def test_collect(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": lambda sets: sets[0]["wealth"], "age": "age", }, ) @@ -223,7 +223,7 @@ def test_collect_step(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": lambda sets: sets[0]["wealth"], "age": "age", }, ) @@ -279,7 +279,7 @@ def test_conditional_collect(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": lambda sets: sets[0]["wealth"], "age": "age", }, ) @@ -361,7 +361,7 @@ def test_flush_local_csv(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": lambda sets: sets[0]["wealth"], "age": "age", }, storage="csv", @@ -437,7 +437,7 @@ def test_flush_local_parquet(self, fix1_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": lambda sets: sets[0]["wealth"], }, storage="parquet", storage_uri=tmpdir, @@ -513,7 +513,7 @@ def test_postgress(self, fix1_model, postgres_uri): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": lambda sets: sets[0]["wealth"], "age": "age", }, storage="postgresql", @@ -562,7 +562,7 @@ def test_batch_memory(self, fix2_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": lambda sets: sets[0]["wealth"], "age": "age", }, ) @@ -707,7 +707,7 @@ def test_batch_save(self, fix2_model): ) }, agent_reporters={ - "wealth": lambda model: model.sets._agentsets[0]["wealth"], + "wealth": lambda sets: sets[0]["wealth"], "age": "age", }, storage="csv", diff --git a/tests/test_grid.py b/tests/test_grid.py index 6d75f3cc..904efdb0 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -12,10 +12,8 @@ def get_unique_ids(model: Model) -> pl.Series: - # return model.get_sets_of_type(model.set_types[0])["unique_id"] - series_list = [ - agent_set["unique_id"].cast(pl.UInt64) for agent_set in model.sets.df.values() - ] + # Collect unique_id across all concrete AgentSets in the registry + series_list = [aset["unique_id"].cast(pl.UInt64) for aset in model.sets] return pl.concat(series_list)