# Methods added by the ratinabox/Agent.py hunk. They are shown dedented to
# column 0 because SOURCE carries no usable indentation; inside the file they
# live at method indentation in the agent class.


def export_history(self,
                   filename: Union[str, None] = None,
                   keys_to_export: "Union[str, list[str], None]" = None,
                   save_to_file: bool = False,
                   **kwargs
                   ) -> pd.DataFrame:
    """Flatten the agent history into a table and optionally write it to disk.

    Only quantities stored in ``self.history`` are exported, not the agent
    parameters. Vector quantities are split into ``*_x`` / ``*_y`` columns;
    in a 1D environment the ``*_y`` column is filled with zeros. History
    keys other than ``t``, ``pos``, ``vel``, ``rot_vel``, ``head_direction``
    and ``distance_travelled`` are ignored.

    Args:
        filename (str, optional): file to save to. When saving and no name is
            given it defaults to ``agent_<name>_history.csv`` (or
            ``agent_history.csv`` if the agent has no name). The extension is
            replaced by ``utils.export_history`` to match the chosen format.
        keys_to_export (str | list[str], optional): history keys to export.
            If None, every key present in ``self.history`` is considered.
        save_to_file (bool): whether to write the table to disk.
        **kwargs: forwarded to ``utils.export_history`` (for example
            ``filename_prefix`` or ``format``).

    Returns:
        pd.DataFrame: one row per recorded timestep, with an ``agent_name``
        column followed by the requested quantities.
    """
    # NOTE: the keys_to_export annotation is quoted on purpose. A bare
    # ``list[str]`` is evaluated when the function is defined and raises
    # TypeError on Python < 3.9 (setup.cfg declares python_requires >= 3.7),
    # unless the module enables postponed annotations.
    n_steps = len(self.history["t"])

    if save_to_file and filename is None:
        if self.name is None:
            filename = "agent_history.csv"
        else:
            filename = f"agent_{self.name}_history.csv"

    if keys_to_export is None:
        keys_to_export = list(self.history.keys())
    elif isinstance(keys_to_export, str):
        keys_to_export = [keys_to_export]

    dimensionality = self.Environment.dimensionality
    dict_to_export = {"agent_name": [self.name] * n_steps}

    def add_xy_columns(key, prefix):
        """Split the vector history `key` into `<prefix>_x` / `<prefix>_y`."""
        if key not in keys_to_export:
            return
        values = np.array(self.history[key]).astype(np.float32)
        if dimensionality not in ("1D", "2D"):
            return
        if values.size == 0:
            # An empty history gives a shape-(0,) array; make it 2-D so the
            # column slicing below yields empty columns instead of raising.
            values = values.reshape(0, 2 if dimensionality == "2D" else 1)
        dict_to_export[f"{prefix}_x"] = values[:, 0]
        if dimensionality == "2D":
            dict_to_export[f"{prefix}_y"] = values[:, 1]
        else:
            dict_to_export[f"{prefix}_y"] = np.zeros_like(values[:, 0])

    # The call order below fixes the column order of the exported table.
    if "t" in keys_to_export:
        dict_to_export["t"] = self.history["t"]

    add_xy_columns("pos", "pos")
    add_xy_columns("vel", "vel")

    # Rotational velocity is only exported for 2D environments.
    if "rot_vel" in keys_to_export and dimensionality == "2D":
        dict_to_export["rot_vel"] = np.array(
            self.history["rot_vel"]).astype(np.float32)

    add_xy_columns("head_direction", "head_dir")

    if "distance_travelled" in keys_to_export:
        dict_to_export["distance_travelled"] = np.array(
            self.history["distance_travelled"]).astype(np.float32)

    return utils.export_history(
        history_dict=dict_to_export,
        filename=filename,
        save_to_file=save_to_file,
        **kwargs
    )


def import_history(self,
                   filename: str):
    """Append agent history from a file written by ``export_history``.

    The imported rows are appended to the existing history lists (the
    history is not cleared first) and ``self.t`` is set to the last time in
    the history. All columns are converted and validated before anything is
    written, so a failed import leaves the history untouched.

    Args:
        filename (str): path to the file to import (.csv or .parquet).

    Returns:
        None

    Raises:
        FileNotFoundError: if the specified file doesn't exist.
        ValueError: if the file format is not supported, or a vector column
            does not match the environment dimensionality.
    """
    df = utils.import_history(filename)

    n_timesteps = len(df)
    if n_timesteps == 0:
        print("Warning: No data found in imported file")
        return

    dimensionality = self.Environment.dimensionality
    staged = {}  # history key -> values to append, filled before any mutation

    if "t" in df.columns:
        staged["t"] = df["t"].tolist()

    def stage_xy(history_key, prefix, label):
        """Rebuild per-timestep vectors from `<prefix>_x` / `<prefix>_y`."""
        x_col, y_col = f"{prefix}_x", f"{prefix}_y"
        if x_col not in df.columns:
            return
        if dimensionality == "2D" and y_col in df.columns:
            staged[history_key] = [
                np.array([x, y]) for x, y in zip(df[x_col], df[y_col])
            ]
        elif dimensionality == "1D":
            staged[history_key] = [np.array([x]) for x in df[x_col]]
        else:
            raise ValueError(
                f"{label} data format doesn't match environment dimensionality")

    stage_xy("pos", "pos", "Position")
    stage_xy("vel", "vel", "Velocity")

    # Rotational velocity only exists for 2D environments.
    if "rot_vel" in df.columns and dimensionality == "2D":
        staged["rot_vel"] = df["rot_vel"].tolist()

    stage_xy("head_direction", "head_dir", "Head direction")

    # export_history writes this column, so restore it too; otherwise the
    # history lists end up with different lengths after a round trip.
    # NOTE(review): only restored when the key is already tracked - confirm
    # that every agent history carries "distance_travelled".
    if "distance_travelled" in df.columns and "distance_travelled" in self.history:
        staged["distance_travelled"] = df["distance_travelled"].tolist()

    # Resolve every target list first so that a missing key also fails
    # before any list has been extended.
    targets = {key: self.history[key] for key in staged}
    for key, values in staged.items():
        targets[key].extend(values)

    # Update agent time to match the last imported time.
    if len(self.history["t"]) > 0:
        self.t = self.history["t"][-1]

    print(f"Successfully imported {n_timesteps} timesteps of agent history from {filename}")

    # Warn (but do not fail) when the file was exported by another agent.
    if "agent_name" in df.columns:
        imported_agent_name = df["agent_name"].iloc[0]
        if imported_agent_name != self.name:
            print(f"Warning: Imported agent name '{imported_agent_name}' doesn't match current agent name '{self.name}'")
# Methods added by the ratinabox/Environment.py and ratinabox/Neurons.py
# hunks. They are shown dedented to column 0 because SOURCE carries no usable
# indentation; inside the files they live at method indentation.


# --- ratinabox/Environment.py ------------------------------------------------

def export_agents_history(self,
                          agent_names: "Union[str, list[str], None]" = None,
                          keys_to_export: "Union[str, list[str], None]" = None,
                          verbose: bool = False,
                          save_to_file: bool = True) -> pd.DataFrame:
    """Export the history of the agents in the environment as one table.

    Each agent's ``export_history`` is called with this environment's name
    as ``filename_prefix``, and the resulting tables are stacked row-wise.

    Args:
        agent_names (str | list[str], optional): name(s) of the agents to
            export, resolved through ``self.agent_lookup``. If None, every
            agent in ``self.Agents`` is exported.
        keys_to_export (str | list[str], optional): history keys passed on
            to each agent's ``export_history``. If None, all are exported.
        verbose (bool): whether to print a line per exported agent.
        save_to_file (bool): whether each agent also writes its own file.

    Returns:
        pd.DataFrame: the concatenated histories (empty if there is nothing
        to export).
    """
    # NOTE: the annotations are quoted on purpose. A bare ``list[str]`` is
    # evaluated at definition time and raises TypeError on Python < 3.9.
    agents = self.Agents if agent_names is None else self.agent_lookup(agent_names)
    # NOTE(review): agent_lookup is not visible here; tolerate it returning
    # a single agent rather than a list - confirm against its definition.
    if not isinstance(agents, (list, tuple)):
        agents = [agents]

    frames = []
    last_frame = None
    for agent in agents:
        if verbose:
            print(f"Exporting history for agent {agent.name}")
        df = agent.export_history(keys_to_export=keys_to_export,
                                  save_to_file=save_to_file,
                                  filename_prefix=f"{self.name}")
        if df is None:
            continue
        last_frame = df
        if not df.empty:
            frames.append(df)

    if not frames:
        # Nothing with rows: hand back the last (empty) table so its columns
        # survive, or a bare DataFrame when no agent produced one.
        return last_frame if last_frame is not None else pd.DataFrame()
    if len(frames) == 1:
        return frames[0]
    # Concatenate once instead of growing the table inside the loop.
    return pd.concat(frames, axis=0, ignore_index=True)


# --- ratinabox/Neurons.py ----------------------------------------------------

def export_history(self,
                   filename: Union[str, None] = None,
                   keys_to_export: "Union[str, list[str], None]" = None,
                   save_to_file: bool = False,
                   **kwargs
                   ) -> pd.DataFrame:
    """Export the neuron history as a table and optionally write it to disk.

    Only quantities stored in ``self.history`` are exported, plus two
    bookkeeping columns: ``agent_name`` and ``neuron_group``. Float64 data is
    reduced to float32 precision and boolean data is stored as integers;
    every column is handed on as plain Python lists.

    Args:
        filename (str, optional): file to save to. Defaults to
            ``neuron_<name>_history.csv`` (or ``neuron_history.csv`` if the
            group has no name). The extension is replaced by
            ``utils.export_history`` to match the chosen format.
        keys_to_export (str | list[str], optional): columns to export. If
            None, all columns are exported.
        save_to_file (bool): whether to write the table to disk.
        **kwargs: forwarded to ``utils.export_history``.

    Returns:
        pd.DataFrame: one row per recorded timestep.
    """
    if filename is None:
        if self.name is None:
            filename = "neuron_history.csv"
        else:
            filename = f"neuron_{self.name}_history.csv"

    n_steps = len(self.history["t"])
    dict_to_export = self.history.copy()
    dict_to_export["agent_name"] = [self.Agent.name] * n_steps
    dict_to_export["neuron_group"] = [self.name] * n_steps

    # Only values are reassigned below, never keys, so iterating is safe.
    for key, values in dict_to_export.items():
        if isinstance(values, list):
            values = np.array(values)
            if values.dtype == np.float64:
                values = values.astype(np.float32)
            # np.bool_ is the numpy boolean scalar type. The bare ``np.bool``
            # alias was removed in NumPy 1.24 and raises AttributeError.
            elif values.dtype == np.bool_:
                values = values.astype(int)
        if isinstance(values, np.ndarray):
            values = values.tolist()
        dict_to_export[key] = values

    return utils.export_history(
        history_dict=dict_to_export,
        filename=filename,
        keys_to_export=keys_to_export,
        save_to_file=save_to_file,
        **kwargs,
    )


def _history_cell_to_array(cell, as_bool=False):
    """Convert one stored firing-rate / spike cell back into a numpy array."""
    if isinstance(cell, str):
        # CSV stores each per-timestep list as text such as "[0.1, 0.2]";
        # strip the brackets and parse the comma-separated numbers.
        cell = np.fromstring(cell[1:-1], sep=",")
    values = np.array(cell)
    return values.astype(bool) if as_bool else values


def import_history(self,
                   filename: str):
    """Replace the neuron history with data written by ``export_history``.

    Unlike a plain append, the existing history is cleared (via
    ``self.reset_history``) before ``t``, ``firingrate`` and ``spikes`` are
    restored. The file is parsed in full before the reset, so a malformed
    file does not wipe the current history.

    Args:
        filename (str): path to the file to import (.csv or .parquet).

    Returns:
        None

    Raises:
        FileNotFoundError: if the specified file doesn't exist.
        ValueError: if the file format is not supported or the
            ``firingrate`` / ``spikes`` columns are missing.
    """
    df = utils.import_history(filename)

    n_timesteps = len(df)
    if n_timesteps == 0:
        print("Warning: No data found in imported file")
        return

    if 'firingrate' not in df.columns or 'spikes' not in df.columns:
        raise ValueError("The imported file must contain 'firingrate' and 'spikes' columns.")

    # Text cells are parsed whenever they occur rather than keying off a
    # case-sensitive ".csv" suffix, since the reader accepts any letter case.
    firingrate = [_history_cell_to_array(cell) for cell in df["firingrate"]]
    spikes = [_history_cell_to_array(cell, as_bool=True) for cell in df["spikes"]]

    self.reset_history()

    if 't' in df.columns:
        self.history["t"].extend(df['t'].tolist())
    self.history["firingrate"] = firingrate
    self.history["spikes"] = spikes

    print(f"Successfully imported {n_timesteps} timesteps of neuron history from {filename}")
def export_history(history_dict: dict,
                   filename: Union[str, None] = None,
                   keys_to_export: Union[str, List[str], None] = None,
                   filename_prefix: Union[str, None] = None,
                   save_to_file: bool = True,
                   format: str = "parquet",
                   **kwargs
                   ) -> pd.DataFrame:
    """Build a DataFrame from a history dictionary and optionally save it.

    Args:
        history_dict (dict): history to export, mapping column name to a
            sequence of per-timestep values. Required.
        filename (str, optional): where to save the table. Any extension is
            replaced by the one matching ``format``. Defaults to
            ``history.<format>``.
        keys_to_export (str | list[str], optional): keys of ``history_dict``
            to export. If None, all keys are exported.
        filename_prefix (str, optional): text put in front of the file's base
            name (an underscore is added if missing); the directory part of
            ``filename`` is kept.
        save_to_file (bool): whether to write the table to disk.
        format (str): "parquet" or "csv" (case-insensitive). Defaults to
            "parquet".
        **kwargs: accepted so callers can pass extra options through;
            currently unused.

    Returns:
        pd.DataFrame: the exported table.

    Raises:
        ValueError: if ``history_dict`` is None, ``format`` is unsupported,
            or a requested key is not in ``history_dict``.
    """
    if history_dict is None:
        raise ValueError("history_dict must be provided")

    # `format` shadows the builtin but is part of the public signature, so
    # work with a local alias instead of renaming the parameter.
    fmt = format.lower()
    if fmt not in ("parquet", "csv"):
        raise ValueError("format must be either 'parquet' or 'csv'")

    if filename is None:
        filename = f"history.{fmt}"
    else:
        # splitext only strips the final extension, so dots elsewhere in the
        # path ("./out/run.csv", "data/v1.2/run.csv") are left intact.
        filename = os.path.splitext(filename)[0] + f".{fmt}"

    if filename_prefix is not None:
        if not filename_prefix.endswith("_"):
            filename_prefix += "_"
        directory, base_name = os.path.split(filename)
        # os.path.join ignores an empty directory component.
        filename = os.path.join(directory, f"{filename_prefix}{base_name}")

    if keys_to_export is None:
        keys_to_export = list(history_dict)
    else:
        if isinstance(keys_to_export, str):
            keys_to_export = [keys_to_export]
        for param in keys_to_export:
            if param not in history_dict:
                raise ValueError(f"Parameter {param} not found in history keys_to_export.")

    df = pd.DataFrame({key: history_dict[key] for key in keys_to_export})

    if save_to_file:
        if fmt == "parquet":
            df.to_parquet(filename, index=False)
        else:
            df.to_csv(filename, index=False)
        print(f"History exported to {filename}")
    return df


def import_history(filename: str) -> pd.DataFrame:
    """Read history data from a CSV or Parquet file.

    Args:
        filename (str): path to the file to import; the format is chosen
            from its extension, compared case-insensitively.

    Returns:
        pd.DataFrame: the imported history data.

    Raises:
        FileNotFoundError: if the specified path doesn't exist.
        ValueError: if the extension is neither .csv nor .parquet.
    """
    if not os.path.exists(filename):
        raise FileNotFoundError(f"File not found: {filename}")

    ext = os.path.splitext(filename)[1].lower()

    if ext == ".csv":
        df = pd.read_csv(filename)
        print(f"History imported from CSV: {filename}")
        return df
    if ext == ".parquet":
        df = pd.read_parquet(filename)
        print(f"History imported from Parquet: {filename}")
        return df
    raise ValueError(f"Unsupported file format: {ext}. Only .csv and .parquet are supported.")