Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ __pycache__
community_use_cases.md
.vscode/

.idea/
.idea/
*.parquet
157 changes: 155 additions & 2 deletions ratinabox/Agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import matplotlib
from matplotlib import pyplot as plt
import warnings

from typing import Union
import pandas as pd

from ratinabox import utils

Expand Down Expand Up @@ -533,7 +534,159 @@ def initialise_position_and_velocity(self):
if self.Environment.dimensionality == "1D":
self.velocity = np.array([self.speed_mean]) + 1e-8 #to avoid nans
return


def export_history(self,
filename:Union[str,None] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still the best way to type hint? Shouldn't we use filename: str | None = None nowadays?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah! But it depends what python version do wanna support. I think anything below 3.10 might be an issue but I agree that 3.10 is a fair. In the config we have python>=3.7

keys_to_export: Union[str, list[str], None] = None,
save_to_file: bool = False,
**kwargs
)-> pd.DataFrame:
"""Exports the agent history to a csv file at the given filename.Only the parameters saved to history are exported, not the agent parameters.
Args:
filename (str, optional): The name of the file to save the history to. Defaults to "agent_history.csv".
params_to_export (list, optional): A list of parameters to export from the agent. If None, exports all parameters.
Copy link
Collaborator

@TomGeorge1234 TomGeorge1234 Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

think this docstring needs updating (params --> keys, and save_to_file is missing)

"""

if save_to_file and filename is None:
if self.name is None:
filename = "agent_history.csv"
else:
filename = f"agent_{self.name}_history.csv"

# process the history dict for correct format
dict_to_export = {}

dict_to_export["agent_name"] = [self.name] * len(self.history["t"])

if keys_to_export is None:
keys_to_export = list(self.history.keys())
elif isinstance(keys_to_export, str):
keys_to_export = [keys_to_export]

if "t" in keys_to_export:
dict_to_export["t"] = self.history["t"]

if "pos" in keys_to_export:
pos = np.array(self.history["pos"]).astype(np.float32)
if self.Environment.dimensionality == "2D":
dict_to_export["pos_x"] = pos[:, 0]
dict_to_export["pos_y"] = pos[:, 1]
elif self.Environment.dimensionality == "1D":
dict_to_export["pos_x"] = pos[:, 0]
dict_to_export["pos_y"] = np.zeros_like(pos[:, 0])

if "vel" in keys_to_export:
vel = np.array(self.history["vel"]).astype(np.float32)
if self.Environment.dimensionality == "2D":
dict_to_export["vel_x"] = vel[:, 0]
dict_to_export["vel_y"] = vel[:, 1]
elif self.Environment.dimensionality == "1D":
dict_to_export["vel_x"] = vel[:, 0]
dict_to_export["vel_y"] = np.zeros_like(vel[:, 0])

if "rot_vel" in keys_to_export:
if self.Environment.dimensionality == "2D":
rot_vel_ = np.array(self.history["rot_vel"]).astype(np.float32)
dict_to_export["rot_vel"] = rot_vel_

if "head_direction" in keys_to_export:
head_direction = np.array(self.history["head_direction"]).astype(np.float32)
if self.Environment.dimensionality == "2D":
dict_to_export["head_dir_x"] = head_direction[:, 0]
dict_to_export["head_dir_y"] = head_direction[:, 1]
elif self.Environment.dimensionality == "1D":
dict_to_export["head_dir_x"] = head_direction[:, 0]
dict_to_export["head_dir_y"] = np.zeros_like(head_direction[:, 0])

if "distance_travelled" in keys_to_export:
dict_to_export["distance_travelled"] = np.array(self.history["distance_travelled"]).astype(np.float32)


return utils.export_history(
history_dict=dict_to_export,
filename=filename,
save_to_file=save_to_file,
**kwargs
)

def import_history(self,
filename: str):
"""Imports agent history from a CSV or Parquet file that was previously exported using export_history.

Args:
filename (str): path to the file to import (either .csv or .parquet format)
Returns:
None

Raises:
FileNotFoundError: if the specified file doesn't exist
ValueError: if the file format is not supported or data is invalid
"""

# Import the dataframe using utils function
df = utils.import_history(filename)


# Convert DataFrame back to agent history format
n_timesteps = len(df)

if n_timesteps == 0:
print("Warning: No data found in imported file")
return

# Time data
if 't' in df.columns:
self.history["t"].extend(df['t'].tolist())

# Position data - reconstruct from pos_x, pos_y columns as numpy arrays
if 'pos_x' in df.columns:
if self.Environment.dimensionality == "2D" and 'pos_y' in df.columns:
pos_data = [np.array([x, y]) for x, y in zip(df['pos_x'], df['pos_y'])]
elif self.Environment.dimensionality == "1D":
pos_data = [np.array([x]) for x in df['pos_x']]
else:
raise ValueError("Position data format doesn't match environment dimensionality")
self.history["pos"].extend(pos_data)

# Velocity data - reconstruct from vel_x, vel_y columns as numpy arrays
if 'vel_x' in df.columns:
if self.Environment.dimensionality == "2D" and 'vel_y' in df.columns:
vel_data = [np.array([x, y]) for x, y in zip(df['vel_x'], df['vel_y'])]
elif self.Environment.dimensionality == "1D":
vel_data = [np.array([x]) for x in df['vel_x']]
else:
raise ValueError("Velocity data format doesn't match environment dimensionality")
self.history["vel"].extend(vel_data)

# Rotational velocity data (2D only)
if 'rot_vel' in df.columns and self.Environment.dimensionality == "2D":
self.history["rot_vel"].extend(df['rot_vel'].tolist())

# Head direction data - reconstruct from head_dir_x, head_dir_y columns as numpy arrays
if 'head_dir_x' in df.columns:
if self.Environment.dimensionality == "2D" and 'head_dir_y' in df.columns:
head_dir_data = [np.array([x, y]) for x, y in zip(df['head_dir_x'], df['head_dir_y'])]
elif self.Environment.dimensionality == "1D":
head_dir_data = [np.array([x]) for x in df['head_dir_x']]
else:
raise ValueError("Head direction data format doesn't match environment dimensionality")
self.history["head_direction"].extend(head_dir_data)

# Update agent time to match the last imported time
if len(self.history["t"]) > 0:
self.t = self.history["t"][-1]

print(f"Successfully imported {n_timesteps} timesteps of agent history from {filename}")

# Verify agent name matches if available
if 'agent_name' in df.columns:
imported_agent_name = df['agent_name'].iloc[0]
if imported_agent_name != self.name:
print(f"Warning: Imported agent name '{imported_agent_name}' doesn't match current agent name '{self.name}'")




def reset_history(self):
"""Clears the history dataframe, primarily intended for saving memory when running long simulations."""
for key in self.history.keys():
Expand Down
39 changes: 38 additions & 1 deletion ratinabox/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import matplotlib
from matplotlib import pyplot as plt
import shapely
import pandas as pd


import warnings
Expand Down Expand Up @@ -63,6 +64,7 @@ class Environment:
"""

default_params = {
"name": "Environment", # name of the environment
"dimensionality": "2D", # 1D or 2D environment
"boundary_conditions": "solid", # solid vs periodic
"scale": 1, # scale of environment (in metres)
Expand Down Expand Up @@ -90,6 +92,8 @@ def __init__(self, params={}):
self.Agents : List[Agent] = [] # each new Agent will append itself to this list
self.agents_dict = {} # this is a dictionary which allows you to lookup a agent by name

self.name = self.params["name"]

if self.dimensionality == "1D":
self.D = 1
self.extent = np.array([0, self.scale])
Expand Down Expand Up @@ -307,7 +311,6 @@ def add_agent(self, agent: Agent = None):
self.Agents.append(agent)
self.agents_dict[agent.name] = agent


def remove_agent(self, agent: Union[str, Agent] = None):
"""
A function to remove a agent from the Environment.Agents list and the Environment.agents_dict dictionary
Expand Down Expand Up @@ -890,3 +893,37 @@ def apply_boundary_conditions(self, pos):
else: # polygon shaped env, just resample random position
pos = self.sample_positions(n=1, method="random").reshape(-1)
return pos

def export_agents_history(self,
agent_names: Union[str, list[str], None] = None,
keys_to_export: Union[str, list[str],None] = None,
verbose: bool = False,
save_to_file: bool = True):
"""Exports the history of the agents in the environment.
Args:
agent_names (str, list[str]): the name of the agent you want to export the history for. If None, exports all agents.
save (bool): whether to save the history to a file
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would propose we should remove this parameter and always save to file. Can you think of a good reason why a users would want to export the data to a df but then not save? If so, wouldn't they just use the agent.history attribute.

What about this: We write all the logic for export to a dataframe. And have a wrapper on this which then saves this. o users have two separate APIs:
convert_history_to_dataframe() and export_history(). I think it's clearer that way.

verbose (bool): whether to print the history to the console
Returns:
dict: a dictionary of all the default parameters of the class, including those inherited from its parents.
"""
agents = self.Agents
if agent_names is not None:
agents = self.agent_lookup(agent_names)

combined_df = pd.DataFrame()

for agent in agents:
if verbose:
print(f"Exporting history for agent {agent.name}")
df = agent.export_history(keys_to_export=keys_to_export,
save_to_file=save_to_file,
filename_prefix=f"{self.name}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here this should be agent.name? No?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no..as this is looping over the agent objects after looking up the names passed by the user

if df is not None:
if combined_df.empty:
combined_df = df
else:
combined_df = pd.concat([combined_df, df], axis=0, ignore_index=True)

return combined_df

114 changes: 114 additions & 0 deletions ratinabox/Neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from scipy import stats as stats
import warnings
from matplotlib.collections import EllipseCollection
import pandas as pd
from typing import Union

from ratinabox import utils

Expand Down Expand Up @@ -116,6 +118,8 @@ def __init__(self, Agent, params={}):

utils.update_class_params(self, self.params, get_all_defaults=True)
utils.check_params(self, params.keys())

self.name = self.params["name"]

self.firingrate = np.zeros(self.n)
self.noise = np.zeros(self.n)
Expand Down Expand Up @@ -690,7 +694,117 @@ def reset_history(self):
for key in self.history.keys():
self.history[key] = []
return

def export_history(self,
filename:Union[str,None] = None,
keys_to_export: Union[str, list[str],None] = None,
save_to_file:bool = False,
**kwargs
) -> pd.DataFrame:
"""Exports the Neuron history to a csv file at the given filename.Only the parameters saved to history are exported, not the agent parameters.
Args:
filename (str, optional): The name of the file to save the history to. Defaults to "agent_history.csv".
params_to_export (list, optional): A list of parameters to export from the agent. If None, exports all parameters.
"""

if filename is None:
if self.name is None:
filename = "neuron_history.csv"
else:
filename = f"neuron_{self.name}_history.csv"

dict_to_export = self.history.copy()

dict_to_export["agent_name"] = [self.Agent.name] * len(self.history["t"])
dict_to_export["neuron_group"] = [self.name] * len(self.history["t"])

# convert any numpy arrays to lists
for key in dict_to_export.keys():
if type(dict_to_export[key]) is list:
# convert to numpy array first and the to list to ensure all elements are lists
dict_to_export[key] = np.array(dict_to_export[key])

# limit the float to np.float32
if dict_to_export[key].dtype == np.float64:
dict_to_export[key] = dict_to_export[key].astype(np.float32)

elif dict_to_export[key].dtype == np.bool:
dict_to_export[key] = dict_to_export[key].astype(int)


# convert numpy arrays to lists
if type(dict_to_export[key]) is np.ndarray:
dict_to_export[key] = dict_to_export[key].tolist()




return utils.export_history(
history_dict=dict_to_export,
filename=filename,
keys_to_export=keys_to_export,
save_to_file=save_to_file,
**kwargs,
)

def import_history(self,
filename: str):
"""Imports neuron history from a CSV or Parquet file that was previously exported using export_history.

Args:
filename (str): path to the file to import (either .csv or .parquet format)
Returns:
None

Raises:
FileNotFoundError: if the specified file doesn't exist
ValueError: if the file format is not supported or data is invalid
"""

# Import the dataframe using utils function
df = utils.import_history(filename)

# Convert DataFrame back to neuron history format
n_timesteps = len(df)

if n_timesteps == 0:
print("Warning: No data found in imported file")
return

if 'firingrate' not in df.columns or 'spikes' not in df.columns:
raise ValueError("The imported file must contain 'firingrate' and 'spikes' columns.")

# Reset the history
self.reset_history()

# if the filename had csv in it, then convert the firingrate and spikes columns to lists of numpy arrays
if filename.endswith(".csv"):
# convert the firingrate and spikes columns to lists of numpy arrays
# from string
df['firingrate'] = df['firingrate'].apply(lambda x: np.fromstring(x[1:-1], sep=',') if isinstance(x, str) else x)
df['spikes'] = df['spikes'].apply(lambda x: np.fromstring(x[1:-1], sep=',') if isinstance(x, str) else x)

# ensure the columns : firingrate, spikes i nthe correct format
# convert the firingrate into arrays
df['firingrate'] = df['firingrate'].apply(lambda x: np.array(x) if isinstance(x, list) else x)
# convert the spikes into arrays
df['spikes'] = df['spikes'].apply(lambda x: np.array(x) if isinstance(x, list) else x)

df['spikes'] = df['spikes'].apply(lambda x: x.astype(bool) if isinstance(x, np.ndarray) else x)
# Time data
if 't' in df.columns:
self.history["t"].extend(df['t'].tolist())

firingrate_df = [np.array(row) for row in df.firingrate]
self.history["firingrate"] = firingrate_df

spikes_df = [np.array(row) for row in df.spikes]
self.history["spikes"] = spikes_df



print(f"Successfully imported {n_timesteps} timesteps of neuron history from {filename}")

def animate_rate_timeseries(
self,
t_start=None,
Expand Down
2 changes: 1 addition & 1 deletion ratinabox/contribs/ValueNeuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class ValueNeuron(FeedForwardLayer):
"""
r"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mistake?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no linting was complaining with all the backslashes in the docs but I guess it should be another PR to solve it library wide. removing it

Contributer: Tom George [email protected]
The ValueNeuron class defines neuron(s) which learns the "value" of a policy using TD learning. For n > 1 the reward function is assumed to be multidimensional and n value functions (one neuron for each reward function) will be learned under the current policy.
Expand Down
Loading
Loading