diff --git a/.gitignore b/.gitignore index da2008b..f9d55f1 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,6 @@ serow_layout.json:Zone.Identifier /serow_ros2/build/ /serow_ros2/install/ /serow_ros2/log/ +/python/datasets/* +/python/models/* +/python/*_test.npz \ No newline at end of file diff --git a/config/a1.json b/config/a1.json index d93d609..5a84ce6 100644 --- a/config/a1.json +++ b/config/a1.json @@ -40,6 +40,8 @@ 0.0, -1.0 ], + "imu_outlier_detection": true, + "use_imu_orientation": true, "calibrate_initial_imu_bias": true, "max_imu_calibration_cycles": 100, "bias_gyro": [ diff --git a/config/anymal_b.json b/config/anymal_b.json index 222e245..70d31be 100644 --- a/config/anymal_b.json +++ b/config/anymal_b.json @@ -38,6 +38,8 @@ 0.0, -1.0 ], + "imu_outlier_detection": true, + "use_imu_orientation": true, "calibrate_initial_imu_bias": false, "max_imu_calibration_cycles": 500, "bias_gyro": [ diff --git a/config/estimation.json b/config/estimation.json index aaf1a5d..95e53d2 100644 --- a/config/estimation.json +++ b/config/estimation.json @@ -68,6 +68,10 @@ 0.0, 1.0 ], + // boolean: whether or not to use the IMU outlier detection during the filter step + "imu_outlier_detection": true, + // boolean: whether or not to use the IMU orientation during the ContactEKF update step + "use_imu_orientation": true, // boolean: whether or not to estimate initial values for the IMU gyro/accelerometer biases "calibrate_initial_imu_bias": false, // integer: number of IMU measurements to use for estimating the IMU gyro/accelerometer biases diff --git a/config/go1.json b/config/go1.json index a6eb93e..1035180 100644 --- a/config/go1.json +++ b/config/go1.json @@ -40,6 +40,8 @@ 0.0, 1.0 ], + "imu_outlier_detection": true, + "use_imu_orientation": true, "calibrate_initial_imu_bias": true, "max_imu_calibration_cycles": 100, "bias_gyro": [ diff --git a/config/go2.json b/config/go2.json index 184383c..ebccdaa 100644 --- a/config/go2.json +++ b/config/go2.json @@ -40,6 +40,8 @@ 0.0, -1.0 ], + "imu_outlier_detection": true, + "use_imu_orientation": true, "calibrate_initial_imu_bias": true, "max_imu_calibration_cycles": 1000, "bias_gyro": [ diff --git a/config/go2_pytest.json b/config/go2_pytest.json new file mode 100644 index 0000000..277192b --- /dev/null +++ b/config/go2_pytest.json @@ -0,0 +1,250 @@ +{ + "robot_name": "go2", + "base_frame": "base", + "point_feet": true, + "foot_frames": { + "0": "FL_foot", + "1": "FR_foot", + "2": "RL_foot", + "3": "RR_foot" + }, + "model_path": "go2.urdf", + "g": 9.81, + "joint_rate": 500.0, + "estimate_joint_velocity": true, + "joint_cutoff_frequency": 15.0, + "joint_position_variance": 0.0005, + "angular_momentum_cutoff_frequency": 5.0, + "tau_0": 0.1, + "tau_1": 0.0, + "imu_rate": 500.0, + "R_base_to_gyro": [ + 1.0, + 0.0, + 0.0, + 0.0, + -1.0, + 0.0, + 0.0, + 0.0, + -1.0 + ], + "R_base_to_acc": [ + 1.0, + 0.0, + 0.0, + 0.0, + -1.0, + 0.0, + 0.0, + 0.0, + -1.0 + ], + "imu_outlier_detection": false, + "use_imu_orientation": false, + "calibrate_initial_imu_bias": false, + "max_imu_calibration_cycles": 1000, + "bias_gyro": [ + 0.0, + 0.0, + 0.0 + ], + "bias_acc": [ + 0.0, + 0.0, + 0.0 + ], + "gyro_cutoff_frequency": 5.0, + "force_torque_rate": 500.0, + "R_foot_to_force": { + "0": [ + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0 + ], + "1": [ + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0 + ], + "2": [ + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0 + ], + "3": [ + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0 
+ ] + }, + "R_foot_to_torque": null, + "attitude_estimator_proportional_gain": 0.1, + "attitude_estimator_integral_gain": 0.0, + "estimate_contact_status": true, + "high_threshold": 4.0, + "low_threshold": 2.0, + "median_window": 13, + "outlier_detection": false, + "convergence_cycles": 0, + "imu_angular_velocity_covariance": [ + 5e-5, + 5e-5, + 5e-5 + ], + "imu_angular_velocity_bias_covariance": [ + 1e-8, + 1e-8, + 1e-8 + ], + "imu_linear_acceleration_covariance": [ + 5e-3, + 5e-3, + 5e-3 + ], + "imu_linear_acceleration_bias_covariance": [ + 1e-8, + 1e-8, + 1e-8 + ], + "base_orientation_covariance": [ + 1e-3, + 1e-3, + 1e-3 + ], + "contact_position_covariance": [ + 5e-6, + 5e-6, + 5e-5 + ], + "contact_orientation_covariance": null, + "contact_position_slip_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "contact_orientation_slip_covariance": null, + "com_position_process_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "com_linear_velocity_process_covariance": [ + 1e-4, + 1e-4, + 1e-4 + ], + "external_forces_process_covariance": [ + 1e-1, + 1e-1, + 1e-1 + ], + "com_position_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "com_linear_acceleration_covariance": [ + 1e-2, + 1e-2, + 1e-2 + ], + "initial_base_position_covariance": [ + 1e-4, + 1e-4, + 1e-4 + ], + "initial_base_orientation_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "initial_base_linear_velocity_covariance": [ + 1e-3, + 1e-3, + 1e-3 + ], + "initial_contact_position_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "initial_contact_orientation_covariance": null, + "initial_imu_linear_acceleration_bias_covariance": [ + 1e-4, + 1e-4, + 1e-4 + ], + "initial_imu_angular_velocity_bias_covariance": [ + 1e-4, + 1e-4, + 1e-4 + ], + "initial_com_position_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "initial_com_linear_velocity_covariance": [ + 1e-3, + 1e-3, + 1e-3 + ], + "initial_external_forces_covariance": [ + 1.0, + 1.0, + 1.0 + ], + "T_base_to_odom": null, + "enable_terrain_estimation": false, + "terrain_estimator": "fast", + "minimum_terrain_height_variance": 1e-3, + "maximum_contact_points": 4, + "maximum_recenter_distance": 0.35, + "minimum_contact_probability": 0.15, + "T_base_to_ground_truth": [ + 1.0, + 0.0, + 0.0, + -0.03270961, + 0.0, + 1.0, + 0.0, + -0.00222456, + 0.0, + 0.0, + 1.0, + -0.01544948, + 0.0, + 0.0, + 0.0, + 1.0 + ], + "log_data": false, + "log_measurements": false, + "log_dir": "/tmp/serow" +} diff --git a/config/go2_rl.json b/config/go2_rl.json new file mode 100644 index 0000000..3d00f11 --- /dev/null +++ b/config/go2_rl.json @@ -0,0 +1,250 @@ +{ + "robot_name": "go2", + "base_frame": "base", + "point_feet": true, + "foot_frames": { + "0": "FL_foot", + "1": "FR_foot", + "2": "RL_foot", + "3": "RR_foot" + }, + "model_path": "go2.urdf", + "g": 9.81, + "joint_rate": 500.0, + "estimate_joint_velocity": true, + "joint_cutoff_frequency": 15.0, + "joint_position_variance": 0.0005, + "angular_momentum_cutoff_frequency": 5.0, + "tau_0": 0.1, + "tau_1": 0.0, + "imu_rate": 500.0, + "R_base_to_gyro": [ + 1.0, + 0.0, + 0.0, + 0.0, + -1.0, + 0.0, + 0.0, + 0.0, + -1.0 + ], + "R_base_to_acc": [ + 1.0, + 0.0, + 0.0, + 0.0, + -1.0, + 0.0, + 0.0, + 0.0, + -1.0 + ], + "imu_outlier_detection": false, + "use_imu_orientation": false, + "calibrate_initial_imu_bias": false, + "max_imu_calibration_cycles": 1000, + "bias_gyro": [ + 0.0, + 0.0, + 0.0 + ], + "bias_acc": [ + 0.0, + 0.0, + 0.0 + ], + "gyro_cutoff_frequency": 5.0, + "force_torque_rate": 500.0, + "R_foot_to_force": { + "0": [ + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0 + ], 
+ "1": [ + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0 + ], + "2": [ + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0 + ], + "3": [ + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0 + ] + }, + "R_foot_to_torque": null, + "attitude_estimator_proportional_gain": 0.1, + "attitude_estimator_integral_gain": 0.0, + "estimate_contact_status": true, + "high_threshold": 4.0, + "low_threshold": 2.0, + "median_window": 13, + "outlier_detection": false, + "convergence_cycles": 0, + "imu_angular_velocity_covariance": [ + 5e-5, + 5e-5, + 5e-5 + ], + "imu_angular_velocity_bias_covariance": [ + 1e-8, + 1e-8, + 1e-8 + ], + "imu_linear_acceleration_covariance": [ + 5e-3, + 5e-3, + 5e-3 + ], + "imu_linear_acceleration_bias_covariance": [ + 1e-8, + 1e-8, + 1e-8 + ], + "base_orientation_covariance": [ + 1e-3, + 1e-3, + 1e-3 + ], + "contact_position_covariance": [ + 5e-6, + 5e-6, + 5e-5 + ], + "contact_orientation_covariance": null, + "contact_position_slip_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "contact_orientation_slip_covariance": null, + "com_position_process_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "com_linear_velocity_process_covariance": [ + 1e-4, + 1e-4, + 1e-4 + ], + "external_forces_process_covariance": [ + 1e-1, + 1e-1, + 1e-1 + ], + "com_position_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "com_linear_acceleration_covariance": [ + 1e-2, + 1e-2, + 1e-2 + ], + "initial_base_position_covariance": [ + 1e-4, + 1e-4, + 1e-4 + ], + "initial_base_orientation_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "initial_base_linear_velocity_covariance": [ + 1e-3, + 1e-3, + 1e-3 + ], + "initial_contact_position_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "initial_contact_orientation_covariance": null, + "initial_imu_linear_acceleration_bias_covariance": [ + 1e-4, + 1e-4, + 1e-4 + ], + "initial_imu_angular_velocity_bias_covariance": [ + 1e-4, + 1e-4, + 1e-4 + ], + "initial_com_position_covariance": [ + 1e-6, + 1e-6, + 1e-6 + ], + "initial_com_linear_velocity_covariance": [ + 1e-3, + 1e-3, + 1e-3 + ], + "initial_external_forces_covariance": [ + 1.0, + 1.0, + 1.0 + ], + "T_base_to_odom": null, + "enable_terrain_estimation": false, + "terrain_estimator": "fast", + "minimum_terrain_height_variance": 1e-3, + "maximum_contact_points": 4, + "maximum_recenter_distance": 0.35, + "minimum_contact_probability": 0.15, + "T_base_to_ground_truth": [ + 1.0, + 0.0, + 0.0, + -0.03270961, + 0.0, + 1.0, + 0.0, + -0.00222456, + 0.0, + 0.0, + 1.0, + -0.01544948, + 0.0, + 0.0, + 0.0, + 1.0 + ], + "log_data": true, + "log_measurements": true, + "log_dir": "/tmp/" +} diff --git a/config/h1.json b/config/h1.json index e2beba0..fb37797 100644 --- a/config/h1.json +++ b/config/h1.json @@ -38,6 +38,8 @@ 0.0, 1.0 ], + "imu_outlier_detection": true, + "use_imu_orientation": true, "calibrate_initial_imu_bias": false, "max_imu_calibration_cycles": 300, "bias_acc": [ diff --git a/config/nao.json b/config/nao.json index a2745e4..c3f5b94 100644 --- a/config/nao.json +++ b/config/nao.json @@ -38,6 +38,8 @@ 0.0, 1.0 ], + "imu_outlier_detection": true, + "use_imu_orientation": true, "calibrate_initial_imu_bias": false, "max_imu_calibration_cycles": 1000, "bias_acc": [ diff --git a/config/valkyrie.json b/config/valkyrie.json index 8e9dc8f..a759893 100644 --- a/config/valkyrie.json +++ b/config/valkyrie.json @@ -38,6 +38,8 @@ 0.0, -1.0 ], + "imu_outlier_detection": true, + "use_imu_orientation": true, "calibrate_initial_imu_bias": false, "max_imu_calibration_cycles": 300, "bias_acc": [ diff --git 
a/core/src/ContactEKF.cpp b/core/src/ContactEKF.cpp index 158fdae..cf5516f 100644 --- a/core/src/ContactEKF.cpp +++ b/core/src/ContactEKF.cpp @@ -17,7 +17,7 @@ namespace serow { void ContactEKF::init(const BaseState& state, std::set contacts_frame, bool point_feet, - double g, double imu_rate, bool outlier_detection) { + double g, double imu_rate, bool outlier_detection, bool use_imu_orientation, bool verbose) { num_leg_end_effectors_ = contacts_frame.size(); contacts_frame_ = std::move(contacts_frame); g_ = Eigen::Vector3d(0.0, 0.0, -g); @@ -115,6 +115,15 @@ void ContactEKF::init(const BaseState& state, std::set contacts_fra } last_imu_timestamp_.reset(); + + // Clear the action covariance gain matrix + clearAction(); + + use_imu_orientation_ = use_imu_orientation; + verbose_ = verbose; + if (verbose) { + std::cout << "[SEROW/ContactEKF]: Initialized successfully" << std::endl; + } } void ContactEKF::setState(const BaseState& state) { @@ -147,6 +156,9 @@ void ContactEKF::setState(const BaseState& state) { } } last_imu_timestamp_ = state.timestamp; + + // Clear the action covariance gain matrix + clearAction(); } std::tuple ContactEKF::computePredictionJacobians( @@ -190,9 +202,11 @@ void ContactEKF::predict(BaseState& state, const ImuMeasurement& imu, return; } if (dt < nominal_dt_ / 2) { - std::cout << "[SEROW/ContactEKF]: Predict step sample time is abnormal " << dt - << " while the nominal sample time is " << nominal_dt_ << " setting to nominal" - << std::endl; + if (verbose_) { + std::cout << "[SEROW/ContactEKF]: Predict step sample time is abnormal " << dt + << " while the nominal sample time is " << nominal_dt_ + << " setting to nominal" << std::endl; + } dt = nominal_dt_; } @@ -227,6 +241,13 @@ void ContactEKF::predict(BaseState& state, const ImuMeasurement& imu, computeDiscreteDynamics(state, dt, imu.angular_velocity, imu.linear_acceleration, kin.contacts_status, kin.contacts_position, kin.contacts_orientation); last_imu_timestamp_ = imu.timestamp; + + // Clear the action covariance gain matrix + clearAction(); + + // Clear the innovation and update buffers + contact_position_innovation_.clear(); + contact_orientation_innovation_.clear(); } void ContactEKF::computeDiscreteDynamics( @@ -293,6 +314,11 @@ void ContactEKF::updateWithContactPosition(BaseState& state, const std::string& const Eigen::Matrix3d R_base_transpose = R_base.transpose(); cp_noise += position_cov; + const double action = contact_position_action_cov_gain_.at(cf); + if (action > 0.0) { + cp_noise *= action; + } + // If the terrain estimator is in the loop reduce the effect that kinematics has in the // contact height update if (terrain_estimator) { @@ -366,6 +392,7 @@ void ContactEKF::updateWithContactPosition(BaseState& state, const std::string& P_ = std::move(P_i); state = std::move(updated_state_i); } + contact_position_innovation_[cf] = std::make_pair(z, s + 1e-8 * Eigen::Matrix3d::Identity()); } void ContactEKF::updateWithContactOrientation(BaseState& state, const std::string& cf, @@ -376,6 +403,11 @@ void ContactEKF::updateWithContactOrientation(BaseState& state, const std::strin return; } co_noise += orientation_cov; + // Check if the action covariance gain matrix is not the zero matrix + const double action = contact_orientation_action_cov_gain_.at(cf); + if (action > 0.0) { + co_noise *= action; + } const int num_iter = 5; Eigen::MatrixXd H(3, num_states_); @@ -406,6 +438,7 @@ void ContactEKF::updateWithContactOrientation(BaseState& state, const std::strin } P_ = (I_ - K * H) * P_; } + 
contact_orientation_innovation_[cf] = std::make_pair(z, s + 1e-8 * Eigen::Matrix3d::Identity()); } void ContactEKF::updateWithContacts( @@ -615,7 +648,9 @@ void ContactEKF::update(BaseState& state, const ImuMeasurement& imu, } // Update the state with the absolute IMU orientation - updateWithIMUOrientation(state, imu.orientation, imu.orientation_cov); + if (use_imu_orientation_) { + updateWithIMUOrientation(state, imu.orientation, imu.orientation_cov); + } // Update the state with the relative to base contacts updateWithContacts(state, kin.contacts_position, kin.contacts_position_noise, @@ -667,4 +702,53 @@ void ContactEKF::updateWithIMUOrientation(BaseState& state, updateState(state, dx, P_); } +void ContactEKF::setAction(const std::string& cf, const Eigen::VectorXd& action) { + if (action.size() == 1) { + contact_position_action_cov_gain_.at(cf) = action(0); + } else if (action.size() == 2) { + contact_position_action_cov_gain_.at(cf) = action(0); + if (!point_feet_) { + contact_orientation_action_cov_gain_.at(cf) = action(1); + } + } else { + throw std::invalid_argument("Action vector must have 1 or 2 elements"); + } +} + +bool ContactEKF::getContactPositionInnovation(const std::string& contact_frame, + Eigen::Vector3d& innovation, + Eigen::Matrix3d& covariance) const { + if (contact_position_innovation_.find(contact_frame) != contact_position_innovation_.end()) { + innovation = contact_position_innovation_.at(contact_frame).first; + covariance = contact_position_innovation_.at(contact_frame).second; + return true; + } + return false; +} + +bool ContactEKF::getContactOrientationInnovation(const std::string& contact_frame, + Eigen::Vector3d& innovation, + Eigen::Matrix3d& covariance) const { + if (point_feet_) { + return false; + } + + if (contact_orientation_innovation_.find(contact_frame) != + contact_orientation_innovation_.end()) { + innovation = contact_orientation_innovation_.at(contact_frame).first; + covariance = contact_orientation_innovation_.at(contact_frame).second; + return true; + } + return false; +} + +void ContactEKF::clearAction() { + for (const auto& cf : contacts_frame_) { + contact_position_action_cov_gain_[cf] = 0.0; + if (!point_feet_) { + contact_orientation_action_cov_gain_[cf] = 0.0; + } + } +} + } // namespace serow diff --git a/core/src/ContactEKF.hpp b/core/src/ContactEKF.hpp index a8be98a..58ffc09 100644 --- a/core/src/ContactEKF.hpp +++ b/core/src/ContactEKF.hpp @@ -58,9 +58,12 @@ class ContactEKF { * @param g Acceleration due to gravity. * @param imu_rate IMU update rate. * @param outlier_detection Flag indicating if outlier detection mechanisms should be enabled. + * @param use_imu_orientation Flag indicating if IMU orientation is used during the update step. + * @param verbose Flag indicating if verbose output should be enabled. */ void init(const BaseState& state, std::set contacts_frame, bool point_feet, - double g, double imu_rate, bool outlier_detection = false); + double g, double imu_rate, bool outlier_detection = false, bool use_imu_orientation = false, + bool verbose = false); /** * @brief Predicts the robot's state forward based on IMU and kinematic measurements. * @param state Current state of the robot. 
@@ -111,6 +114,37 @@ class ContactEKF {
     void updateWithIMUOrientation(BaseState& state, const Eigen::Quaterniond& imu_orientation,
                                   const Eigen::Matrix3d& imu_orientation_cov);
+    /**
+     * @brief Sets the action (measurement-noise covariance gain) for a contact frame
+     * @param cf Contact frame name
+     * @param action Action vector with 1 element (position gain) or 2 elements (position and
+     * orientation gains)
+     */
+    void setAction(const std::string& cf, const Eigen::VectorXd& action);
+
+    /**
+     * @brief Resets the action covariance gains of all contact frames to zero
+     */
+    void clearAction();
+
+    /**
+     * @brief Gets the contact position innovation
+     * @param contact_frame Contact frame name
+     * @param innovation Contact position innovation
+     * @param covariance Contact position innovation covariance
+     * @return True if an innovation is available for the contact frame, false otherwise
+     */
+    bool getContactPositionInnovation(const std::string& contact_frame, Eigen::Vector3d& innovation,
+                                      Eigen::Matrix3d& covariance) const;
+
+    /**
+     * @brief Gets the contact orientation innovation
+     * @param contact_frame Contact frame name
+     * @param innovation Contact orientation innovation
+     * @param covariance Contact orientation innovation covariance
+     * @return True if an innovation is available for the contact frame, false otherwise
+     * (always false for point-feet robots)
+     */
+    bool getContactOrientationInnovation(const std::string& contact_frame,
+                                         Eigen::Vector3d& innovation,
+                                         Eigen::Matrix3d& covariance) const;
+
 private:
     int num_states_{};  ///< Number of state variables.
     int num_inputs_{};  ///< Number of input variables.
@@ -149,6 +183,16 @@ class ContactEKF {
     OutlierDetector contact_outlier_detector;  ///< Outlier detector instance.
 
+    /// Per-contact gains that scale the contact measurement noise in the next update
+    std::map<std::string, double> contact_position_action_cov_gain_;
+    std::map<std::string, double> contact_orientation_action_cov_gain_;
+
+    /// Latest contact innovations and their covariances, keyed by contact frame
+    std::map<std::string, std::pair<Eigen::Vector3d, Eigen::Matrix3d>> contact_position_innovation_;
+    std::map<std::string, std::pair<Eigen::Vector3d, Eigen::Matrix3d>>
+        contact_orientation_innovation_;
+
+    bool verbose_{};  ///< Flag indicating if verbose output is enabled.
+    bool use_imu_orientation_{};  ///< Flag indicating if IMU orientation is used during the update step.
+
     /**
      * @brief Computes discrete dynamics for the prediction step of the EKF.
      * @param state Current state of the robot.
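The action/innovation hooks above are exposed through the Serow wrapper and its pybind11 bindings added later in this diff (set_action, base_estimator_update_with_contact_position, get_contact_position_innovation). A minimal sketch of one filter cycle driven through those bindings, assuming the go2 test log shipped as python/go2_test.zip has been extracted to go2_test.npz:

import numpy as np
import serow

# Load one sample from the logged dataset (keys as in env_validation.py)
dataset = np.load("go2_test.npz", allow_pickle=True)
imu, joints, ft = dataset["imu"][0], dataset["joints"][0], dataset["ft"][0]

estimator = serow.Serow()
estimator.initialize("go2_rl.json")

# Pre-process measurements, then run the prediction step
imu, kin, ft = estimator.process_measurements(imu, joints, ft, None)
estimator.base_estimator_predict_step(imu, kin)

for cf in kin.contacts_status.keys():
    # A positive 1-element action scales the contact-position measurement noise;
    # 0.0 keeps the nominal covariance (actions are cleared on every predict step).
    estimator.set_action(cf, np.array([10.0], dtype=np.float32))
    estimator.base_estimator_update_with_contact_position(cf, kin)

    # Read back the stored innovation and its covariance, e.g. for a NIS consistency check
    ok, innovation, S = estimator.get_contact_position_innovation(cf)
    if ok:
        nis = float(innovation @ np.linalg.inv(S) @ innovation)
        print(cf, nis)

estimator.base_estimator_finish_update(imu, kin)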
diff --git a/core/src/Serow.cpp b/core/src/Serow.cpp index d68d961..4fb6054 100644 --- a/core/src/Serow.cpp +++ b/core/src/Serow.cpp @@ -159,6 +159,12 @@ bool Serow::initialize(const std::string& config_file) { if (!checkConfigParam("point_feet", params_.point_feet)) return false; + if (!checkConfigParam("use_imu_orientation", params_.use_imu_orientation)) + return false; + + if (!checkConfigParam("imu_outlier_detection", params_.imu_outlier_detection)) + return false; + if (!checkConfigParam("imu_rate", params_.imu_rate)) return false; @@ -1219,12 +1225,14 @@ bool Serow::filter(ImuMeasurement imu, std::map j timers_["joint-estimation"].stop(); // Check if the IMU measurements are valid with the Median Absolute Deviation (MAD) - timers_["imu-outlier-detection"].start(); - bool is_imu_outlier = isImuMeasurementOutlier(imu); - timers_["imu-outlier-detection"].stop(); - if (is_imu_outlier) { - timers_["total-time"].stop(); - return false; + if (params_.imu_outlier_detection) { + timers_["imu-outlier-detection"].start(); + bool is_imu_outlier = isImuMeasurementOutlier(imu); + timers_["imu-outlier-detection"].stop(); + if (is_imu_outlier) { + timers_["total-time"].stop(); + return false; + } } // Estimate the base frame attitude and initial IMU biases @@ -1474,7 +1482,7 @@ void Serow::reset() { // Initialize the base and CoM estimators base_estimator_.init(state_.base_state_, state_.getContactsFrame(), state_.isPointFeet(), - params_.g, params_.imu_rate, params_.outlier_detection); + params_.g, params_.imu_rate, params_.outlier_detection, params_.use_imu_orientation); com_estimator_.init(state_.centroidal_state_, state_.getMass(), params_.g, params_.force_torque_rate); @@ -1591,4 +1599,172 @@ void Serow::logTimings() { }); } +// RL-specific functions +std::tuple> +Serow::processMeasurements( + ImuMeasurement imu, std::map joints, + std::optional> force_torque, + std::optional> contacts_probability) { + // Check if foot frames exist on the F/T measurement + std::map ft; + if (force_torque.has_value()) { + for (const auto& frame : state_.contacts_frame_) { + if (force_torque.value().count(frame) == 0) { + throw std::runtime_error("Foot frame <" + frame + + "> does not exist in the force measurements"); + } + } + // Force-torque measurements are valid and ready to be consumed + ft = std::move(force_torque.value()); + } + + // Update the joint state estimate + runJointsEstimator(state_, joints); + + // Estimate the base frame attitude and initial IMU biases + runImuEstimator(state_, imu); + + // Update the kinematic structure + KinematicMeasurement kin = runForwardKinematics(state_); + // filter() uses timestamp_ instead which is not set here + kin.timestamp = joints.begin()->second.timestamp; + + // Estimate the contact state + if (!ft.empty()) { + runContactEstimator(state_, ft, kin, contacts_probability); + } + + // Compute the leg odometry and update the kinematic measurement accordingly + computeLegOdometry(state_, imu, kin); + + // Return the measurements + return std::make_tuple(imu, kin, ft); +} + +void Serow::baseEstimatorPredictStep(const ImuMeasurement& imu, const KinematicMeasurement& kin) { + // Initialize terrain estimator if needed + if (params_.enable_terrain_estimation && !terrain_estimator_) { + float terrain_height = 0.0; + int i = 0; + + for (const auto& [cf, cp] : state_.contact_state_.contacts_status) { + if (cp) { + i++; + terrain_height += state_.base_state_.contacts_position.at(cf).z(); + } + } + + if (i > 0) { + terrain_height /= i; + } + + // Initialize terrain elevation 
mapper + if (params_.terrain_estimator_type == "naive") { + terrain_estimator_ = std::make_shared(); + } else if (params_.terrain_estimator_type == "fast") { + terrain_estimator_ = std::make_shared(); + } else { + throw std::runtime_error("Invalid terrain estimator type: " + + params_.terrain_estimator_type); + } + terrain_estimator_->initializeLocalMap(terrain_height, 1e4, + params_.minimum_terrain_height_variance); + terrain_estimator_->recenter({static_cast(state_.base_state_.base_position.x()), + static_cast(state_.base_state_.base_position.y())}); + } + + // Call the base estimator predict step + state_.base_state_.timestamp = imu.timestamp; + base_estimator_.predict(state_.base_state_, imu, kin); +} + +void Serow::baseEstimatorUpdateWithContactPosition(const std::string& cf, + const KinematicMeasurement& kin) { + state_.base_state_.timestamp = kin.timestamp; + const bool cs = kin.contacts_status.at(cf); + const Eigen::Vector3d& cp = kin.contacts_position.at(cf); + const Eigen::Matrix3d& cp_noise = kin.contacts_position_noise.at(cf); + const Eigen::Matrix3d& position_cov = kin.position_cov; + base_estimator_.updateWithContactPosition(state_.base_state_, cf, cs, cp, cp_noise, + position_cov, terrain_estimator_); +} + +void Serow::baseEstimatorUpdateWithImuOrientation(const ImuMeasurement& imu) { + state_.base_state_.timestamp = imu.timestamp; + base_estimator_.updateWithIMUOrientation(state_.base_state_, imu.orientation, + imu.orientation_cov); +} + +void Serow::baseEstimatorFinishUpdate(const ImuMeasurement& imu, const KinematicMeasurement& kin) { + const Eigen::Isometry3d base_pose = state_.getBasePose(); + // Estimate base angular velocity and linear acceleration + const Eigen::Vector3d base_angular_velocity = + imu.angular_velocity - state_.getImuAngularVelocityBias(); + const Eigen::Vector3d base_linear_acceleration = + base_pose.linear() * (imu.linear_acceleration - state_.getImuLinearAccelerationBias()) - + Eigen::Vector3d(0.0, 0.0, params_.g); + if (!gyro_derivative_estimator) { + gyro_derivative_estimator = std::make_unique( + "Gyro Derivative", params_.imu_rate, params_.gyro_cutoff_frequency, 3); + if (state_.isInitialized()) { + const Eigen::Matrix3d R_base_to_world = base_pose.linear().transpose(); + gyro_derivative_estimator->setState( + R_base_to_world * state_.base_state_.base_angular_velocity, + R_base_to_world * state_.base_state_.base_angular_acceleration); + } else { + gyro_derivative_estimator->setState(base_angular_velocity, Eigen::Vector3d::Zero()); + } + } + const Eigen::Vector3d base_angular_acceleration = + gyro_derivative_estimator->filter(base_angular_velocity, imu.timestamp); + state_.base_state_.base_angular_velocity = base_pose.linear() * base_angular_velocity; + state_.base_state_.base_angular_acceleration = base_pose.linear() * base_angular_acceleration; + state_.base_state_.base_linear_acceleration = base_linear_acceleration; + + // Update feet pose/velocity in world frame + for (const auto& frame : state_.getContactsFrame()) { + // Cache calculations + const Eigen::Vector3d& base_foot_pos = kin.base_to_foot_positions.at(frame); + const Eigen::Vector3d transformed_pos = base_pose.linear() * base_foot_pos; + + state_.base_state_.feet_position[frame].noalias() = base_pose * base_foot_pos; + state_.base_state_.feet_orientation[frame] = Eigen::Quaterniond( + base_pose.linear() * kin.base_to_foot_orientations.at(frame).toRotationMatrix()); + + state_.base_state_.feet_linear_velocity[frame].noalias() = + state_.base_state_.base_linear_velocity + + 
state_.base_state_.base_angular_velocity.cross(transformed_pos) + + base_pose.linear() * kin.base_to_foot_linear_velocities.at(frame); + + state_.base_state_.feet_angular_velocity[frame].noalias() = + state_.base_state_.base_angular_velocity + + base_pose.linear() * kin.base_to_foot_angular_velocities.at(frame); + } + + // Update all frame transformations + updateFrameTree(state_); + + // Check if state has converged + if (!state_.is_valid_ && cycle_++ > params_.convergence_cycles) { + state_.is_valid_ = true; + } +} + +bool Serow::setAction(const std::string& cf, const Eigen::VectorXd& action) { + base_estimator_.setAction(cf, action); + return true; +} + +bool Serow::getContactPositionInnovation(const std::string& contact_frame, + Eigen::Vector3d& innovation, + Eigen::Matrix3d& covariance) const { + return base_estimator_.getContactPositionInnovation(contact_frame, innovation, covariance); +} + +bool Serow::getContactOrientationInnovation(const std::string& contact_frame, + Eigen::Vector3d& innovation, + Eigen::Matrix3d& covariance) const { + return base_estimator_.getContactOrientationInnovation(contact_frame, innovation, covariance); +} + } // namespace serow diff --git a/core/src/Serow.hpp b/core/src/Serow.hpp index cca4342..dd94f1c 100644 --- a/core/src/Serow.hpp +++ b/core/src/Serow.hpp @@ -102,6 +102,59 @@ class Serow { /// @param state the state to set void setState(const State& state); + /// @brief Processes the measurements and returns the IMU, kinematic, and force/torque + /// measurements + /// @param imu IMU measurement + /// @param joints joint measurements + /// @param force_torque force/torque measurements + /// @param contacts_probability contact probabilities + /// @return tuple containing the IMU, kinematic, and force/torque measurements + std::tuple> + processMeasurements( + ImuMeasurement imu, std::map joints, + std::optional> force_torque, + std::optional> contacts_probability); + + /// @brief Runs the base estimator predict step + /// @param imu IMU measurement + /// @param kin kinematic measurements + void baseEstimatorPredictStep(const ImuMeasurement& imu, const KinematicMeasurement& kin); + + /// @brief Runs the base estimator update step with IMU orientation + /// @param imu IMU measurement + void baseEstimatorUpdateWithImuOrientation(const ImuMeasurement& imu); + + /// @brief Runs the base estimator update step with contact position + /// @param cf contact frame name + /// @param kin kinematic measurements + void baseEstimatorUpdateWithContactPosition(const std::string& cf, + const KinematicMeasurement& kin); + + /// @brief Runs the base estimator finish update step + /// @param imu IMU measurement + /// @param kin kinematic measurements + void baseEstimatorFinishUpdate(const ImuMeasurement& imu, const KinematicMeasurement& kin); + + /// @brief Sets the action for the base estimator + /// @param cf contact frame name + /// @param action action + bool setAction(const std::string& cf, const Eigen::VectorXd& action); + + /// @brief Gets the contact position innovation + /// @param contact_frame contact frame name + /// @param innovation contact position innovation + /// @param covariance contact position covariance + bool getContactPositionInnovation(const std::string& contact_frame, Eigen::Vector3d& innovation, + Eigen::Matrix3d& covariance) const; + + /// @brief Gets the contact orientation innovation + /// @param contact_frame contact frame name + /// @param innovation contact orientation innovation + /// @param covariance contact orientation covariance + bool 
getContactOrientationInnovation(const std::string& contact_frame, + Eigen::Vector3d& innovation, + Eigen::Matrix3d& covariance) const; + private: struct Params { /// @brief name of the robot @@ -263,6 +316,10 @@ class Serow { std::set contacts_frame{}; /// @brief whether or not the robot has point feet bool point_feet{false}; + /// @brief whether or not to use the IMU orientation during the ContactEKF update step + bool use_imu_orientation{false}; + /// @brief whether or not to use the IMU outlier detection during the filter step + bool imu_outlier_detection{false}; }; /// @brief SEROW's configuration @@ -320,13 +377,13 @@ class Serow { /// @brief Timestamp of the estimated state double timestamp_{}; /// @brief Timestamp of the last IMU measurement - double last_imu_timestamp_{}; + double last_imu_timestamp_{-1.0}; /// @brief Timestamp of the last joint measurement - double last_joint_timestamp_{}; + double last_joint_timestamp_{-1.0}; /// @brief Timestamp of the last force/torque measurement - double last_ft_timestamp_{}; + double last_ft_timestamp_{-1.0}; /// @brief Timestamp of the last odometry measurement - double last_odom_timestamp_{}; + double last_odom_timestamp_{-1.0}; /// @brief Logs the measurements /// @param imu IMU measurement diff --git a/evaluation/mujoco_test/serow_viz.py b/evaluation/mujoco_test/serow_viz.py index 5ad6789..00330de 100644 --- a/evaluation/mujoco_test/serow_viz.py +++ b/evaluation/mujoco_test/serow_viz.py @@ -123,20 +123,24 @@ def compute_ATE_pos(gt_pos, est_x, est_y, est_z): def compute_ATE_rot(gt_rot, est_rot_w, est_rot_x, est_rot_y, est_rot_z): est_rot = np.column_stack((est_rot_w, est_rot_x, est_rot_y, est_rot_z)) rotation_errors = np.zeros((gt_rot.shape[0])) + for i in range(gt_rot.shape[0]): q_gt = gt_rot[i] q_est = est_rot[i] - + + # Ensure quaternions are on the same hemisphere + if np.dot(q_gt, q_est) < 0: + q_est = -q_est + q_gt_conj = np.array( [q_gt[0], -q_gt[1], -q_gt[2], -q_gt[3]] - ) # Conjugate of q_gt + ) q_rel = quaternion_multiply(q_gt_conj, q_est) rotation_errors[i] = 2 * np.arccos(np.clip(q_rel[0], -1.0, 1.0)) - + ate_rot = np.sqrt(np.mean(rotation_errors**2)) return ate_rot - def quaternion_multiply(q1, q2): """ Multiply two quaternions q1 and q2. 
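For reference, the hemisphere check added to compute_ATE_rot matters because q and -q describe the same rotation: without flipping the estimate onto the ground-truth hemisphere, a rotation error of about one degree can be reported as nearly 2*pi. A small self-contained illustration (numpy only; the Hamilton product below mirrors the w-x-y-z convention of quaternion_multiply in serow_viz.py):

import numpy as np

def quaternion_multiply(q1, q2):
    # Hamilton product in w-x-y-z convention
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2
    return np.array([
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
    ])

q_gt = np.array([1.0, 0.0, 0.0, 0.0])          # identity
q_est = -np.array([0.9999, 0.01, 0.0, 0.0])    # ~1 degree rotation, opposite hemisphere
q_est = q_est / np.linalg.norm(q_est)

q_gt_conj = np.array([q_gt[0], -q_gt[1], -q_gt[2], -q_gt[3]])
q_rel = quaternion_multiply(q_gt_conj, q_est)
angle_naive = 2 * np.arccos(np.clip(q_rel[0], -1.0, 1.0))   # ~6.26 rad (wrong)

if np.dot(q_gt, q_est) < 0:                    # flip onto the same hemisphere first
    q_est = -q_est
q_rel = quaternion_multiply(q_gt_conj, q_est)
angle_fixed = 2 * np.arccos(np.clip(q_rel[0], -1.0, 1.0))   # ~0.02 rad (correct)
print(angle_naive, angle_fixed)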
diff --git a/python/go2_test.zip b/python/go2_test.zip
new file mode 100644
index 0000000..9005000
Binary files /dev/null and b/python/go2_test.zip differ
diff --git a/python/requirements.txt b/python/requirements.txt
index 11d79a3..5845bf4 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -3,3 +3,5 @@ matplotlib>=3.7.0,<3.8.0
 pybind11>=2.11.1 # Required for serow
 mcap>=1.2.2
 flatbuffers>=25.2.10
+gymnasium>=0.29.0
+stable-baselines3>=2.0.0
diff --git a/python/serow/env.py b/python/serow/env.py
new file mode 100644
index 0000000..e5bacdf
--- /dev/null
+++ b/python/serow/env.py
@@ -0,0 +1,425 @@
+import serow
+import numpy as np
+import gymnasium as gym
+import copy
+from utils import (
+    quaternion_to_rotation_matrix,
+    logMap,
+    sync_and_align_data,
+    plot_trajectories,
+)
+
+
+class SerowEnv(gym.Env):
+    def __init__(
+        self,
+        contact_frame,
+        robot,
+        joint_state,
+        base_state,
+        contact_state,
+        action_dim,
+        state_dim,
+        imu_data,
+        joint_data,
+        ft_data,
+        gt_data,
+        history_size=20,
+    ):
+        super(SerowEnv, self).__init__()
+
+        # Environment parameters
+        self.robot = robot
+        self.contact_frame = contact_frame
+        self.serow_framework = serow.Serow()
+        self.serow_framework.initialize(f"{robot}_rl.json")
+        self.initial_state = self.serow_framework.get_state(allow_invalid=True)
+        self.initial_state.set_joint_state(joint_state)
+        self.initial_state.set_base_state(base_state)
+        self.initial_state.set_contact_state(contact_state)
+        self.serow_framework.set_state(self.initial_state)
+        self.contact_frames = [cf for cf in contact_state.contacts_status.keys()]
+        self.action_dim = action_dim
+        self.state_dim = state_dim
+
+        # Action space - discrete choices for measurement noise scaling
+        self.discrete_actions = np.array(
+            [
+                1e-5,
+                1e-4,
+                1e-3,
+                1e-2,
+                1e-1,
+                1.0,
+                5.0,
+                10.0,
+                1000.0,
+            ],
+            dtype=np.float32,
+        )
+
+        self.action_space = gym.spaces.Discrete(len(self.discrete_actions))
+
+        # Observation space
+        self.observation_space = gym.spaces.Box(
+            low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32
+        )
+
+        # Training data
+        max_steps = len(imu_data)
+        self.raw_imu_data = imu_data[:max_steps]
+        self.raw_joint_data = joint_data[:max_steps]
+        self.raw_ft_data = ft_data[:max_steps]
+        self.gt_data = gt_data[:max_steps]
+        self.valid_prediction = False
+        self.max_steps = max_steps
+        self.history_size = history_size
+        # Build an independent history buffer per contact frame so that frames
+        # do not alias (and append to) the same underlying list
+        self.measurement_history = {
+            cf: [np.zeros(3, dtype=np.float32) for _ in range(self.history_size)]
+            for cf in self.contact_frames
+        }
+        self.action_history = {
+            cf: [
+                np.zeros((self.action_dim,), dtype=np.float32)
+                for _ in range(self.history_size)
+            ]
+            for cf in self.contact_frames
+        }
+        self.previous_action = {cf: None for cf in self.contact_frames}
+
+        # Compute the baseline rewards, imu data, and kinematics
+        (
+            _,
+            _,
+            _,
+            _,
+            _,
+            self.baseline_rewards,
+            self.imu_data,
+            self.kinematics,
+        ) = self.evaluate(
+            model=None,
+            stats=None,
+            plot=False,
+            sync=False,
+        )
+        self.reset()
+
+    def _compute_reward(self, cf, state, gt, action):
+        reward = 0.0
+        done = False
+        success = False
+
+        # Position error
+        position_error = np.linalg.norm(state.get_base_position() - gt.position, 1)
+
+        # Orientation error
+        orientation_error = np.linalg.norm(
+            logMap(
+                quaternion_to_rotation_matrix(gt.orientation).transpose()
+                @ quaternion_to_rotation_matrix(state.get_base_orientation()),
+            ),
+            1,
+        )
+
+        # Compute the innovation and its covariance S
+        success, innovation, covariance = (
+            self.serow_framework.get_contact_position_innovation(cf)
+        )
+
+        max_position_error = 3.0
+        max_nis = 10.0
+        if success:
+            nis = innovation @ np.linalg.inv(covariance) @ innovation.T
+            nis = np.clip(nis, 0, max_nis)
+            position_reward = np.exp(-5.0 * position_error)
+            innovation_reward = np.exp(-500.0 * nis)
+            orientation_reward = np.exp(-10.0 * orientation_error)
+            action_penalty = 0.0
+            if self.previous_action[cf] is not None:
+                action_penalty = abs(action.item() - self.previous_action[cf].item())
+
+            reward = (
+                0.1 * innovation_reward.item()
+                + 0.5 * position_reward.item()
+                + 0.4 * orientation_reward.item()
+                - 0.01 * action_penalty
+            )
+            if hasattr(self, "baseline_rewards"):
+                reward = reward - self.baseline_rewards[self.step_count][cf]
+
+        done = position_error > max_position_error
+        return reward, done, success
+
+    def _get_observation(self, cf, state, kin):
+        if not kin.contacts_status[cf] or state.get_contact_position(cf) is None:
+            return np.zeros((self.state_dim,))
+
+        R_base = quaternion_to_rotation_matrix(state.get_base_orientation()).transpose()
+        local_pos = R_base @ (
+            state.get_contact_position(cf) - state.get_base_position()
+        )
+        local_kin_pos = kin.contacts_position[cf]
+        innovation = local_kin_pos - local_pos
+        R = (kin.contacts_position_noise[cf] + kin.position_cov).flatten()
+
+        # Ensure histories are properly sized before computing observation
+        while len(self.measurement_history[cf]) > self.history_size:
+            self.measurement_history[cf].pop(0)
+        while len(self.action_history[cf]) > self.history_size:
+            self.action_history[cf].pop(0)
+
+        measurement_history = np.array(self.measurement_history[cf]).flatten()
+        action_history = np.array(self.action_history[cf]).flatten()
+
+        obs = np.concatenate(
+            [
+                innovation,
+                R,
+                state.get_base_linear_velocity(),
+                state.get_base_orientation(),
+                measurement_history,
+                action_history,
+            ],
+            axis=0,
+        ).astype(np.float32)
+
+        return obs
+
+    def reset(self, seed=None, options=None):
+        super().reset(seed=seed)
+
+        self.serow_framework.reset()
+        self.serow_framework.set_state(self.initial_state)
+        self.step_count = 0
+        self.valid_prediction = False
+        # Re-create independent history buffers for each contact frame
+        self.measurement_history = {
+            cf: [np.zeros(3, dtype=np.float32) for _ in range(self.history_size)]
+            for cf in self.contact_frames
+        }
+        self.action_history = {
+            cf: [
+                np.zeros((self.action_dim,), dtype=np.float32)
+                for _ in range(self.history_size)
+            ]
+            for cf in self.contact_frames
+        }
+        self.previous_action = {cf: None for cf in self.contact_frames}
+        obs = np.zeros((self.state_dim,))
+        return obs, {}
+
+    def get_observation_for_action(self):
+        """Get the observation that should be used for action computation."""
+        self.valid_prediction = False
+        obs = np.zeros((self.state_dim,))
+        # Run prediction step with current control input
+        imu = self.imu_data[self.step_count]
+        kin = self.kinematics[self.step_count]
+        next_kin = self.kinematics[self.step_count + 1]
+        prior_state = self.predict_step(imu, kin)
+
+        if (
+            kin.contacts_status[self.contact_frame]
+            and next_kin.contacts_status[self.contact_frame]
+            and prior_state.get_contact_position(self.contact_frame) is not None
+        ):
+            self.valid_prediction = True
+            # Get the observation that the policy should use
+            obs = self._get_observation(self.contact_frame, prior_state, kin)
+        else:
+            self.valid_prediction = False
+
+        return obs
+
+    def step(self, action):
+        reward = 0.0
+        done = False
+        truncated = False
+        valid = self.valid_prediction
+        obs =
np.zeros((self.state_dim,)) + imu = self.imu_data[self.step_count] + kin = self.kinematics[self.step_count] + next_kin = self.kinematics[self.step_count + 1] + # Map the action to the discrete action + action = np.array([self.discrete_actions[action]], dtype=np.float32) + + if valid: + post_state = self.update_step( + self.contact_frame, + kin, + action, + ) + + # Compute the reward + reward, done, success = self._compute_reward( + self.contact_frame, + post_state, + self.gt_data[self.step_count], + action, + ) + + # Save the action and measurement + self.previous_action[self.contact_frame] = action + self.action_history[self.contact_frame].append(action) + self.measurement_history[self.contact_frame].append( + np.abs(kin.contacts_position[self.contact_frame]) + ) + + # Get the observation + obs = self._get_observation(self.contact_frame, post_state, next_kin) + if not success or np.all(obs == np.zeros((self.state_dim,))): + valid = False + + info = {"step_count": self.step_count, "reward": reward, "valid": valid} + + for cf in self.contact_frames: + if cf == self.contact_frame and valid: + continue + self.update_step(cf, kin, np.array([0.0], dtype=np.float32)) + + self.serow_framework.base_estimator_finish_update(imu, kin) + self.step_count += 1 + + truncated = self.step_count == self.max_steps - 1 + if truncated: + done = True + + return obs, float(reward), bool(done), bool(truncated), info + + def predict_step(self, imu, kin): + # Predict the base state + self.serow_framework.base_estimator_predict_step(imu, kin) + + # Get the state + state = self.serow_framework.get_state(allow_invalid=True) + return state + + def update_step(self, cf, kin, action): + # Set the action + self.serow_framework.set_action(cf, action) + + # Run the update step with the contact position + self.serow_framework.base_estimator_update_with_contact_position(cf, kin) + + # Get the state + state = self.serow_framework.get_state(allow_invalid=True) + return state + + def render(self, mode="human"): + if mode == "human": + print(f"Step: {self.step_count}") + + def evaluate(self, model=None, stats=None, plot=True, sync=True): + # After training, evaluate the policy + self.reset() + + # Run SEROW + timestamps = [] + base_positions = [] + base_orientations = [] + gt_positions = [] + gt_orientations = [] + gt_timestamps = [] + rewards = [] + kinematics = [] + imu_data = [] + for _ in range(self.max_steps): + # Run prediction step with current control input + imu = copy.copy(self.raw_imu_data[self.step_count]) + joint = copy.copy(self.raw_joint_data[self.step_count]) + ft = copy.copy(self.raw_ft_data[self.step_count]) + imu, kin, ft = self.serow_framework.process_measurements( + imu, joint, ft, None + ) + imu_data.append(imu) + kinematics.append(kin) + + # Run the predict step + post_state = self.predict_step(imu, kin) + # Run the update step with the contact positions + reward = {cf: 0.0 for cf in self.contact_frames} + for cf in self.contact_frames: + action = np.array([0.0], dtype=np.float32) + obs = np.zeros((self.state_dim,), dtype=np.float32) + if model is not None: + obs = self._get_observation(cf, post_state, kin) + if not np.all(obs == np.zeros((self.state_dim,))): + if stats is not None: + obs = np.array( + (obs - np.array(stats["obs_mean"])) + / np.sqrt(np.array(stats["obs_var"])), + dtype=np.float32, + ) + action, _ = model.predict(obs, deterministic=True) + action = np.array( + [self.discrete_actions[action.item()]], dtype=np.float32 + ) + post_state = self.update_step(cf, kin, action) + reward[cf] = 
self._compute_reward( + cf, post_state, self.gt_data[self.step_count], action + )[0] + if model and not np.all(obs == np.zeros((self.state_dim,))): + self.previous_action[cf] = action + self.action_history[cf].append(action) + self.measurement_history[cf].append( + np.abs(kin.contacts_position[cf]) + ) + + self.serow_framework.base_estimator_finish_update(imu, kin) + + # Save the data + timestamps.append(self.raw_imu_data[self.step_count].timestamp) + gt_positions.append(self.gt_data[self.step_count].position) + gt_orientations.append(self.gt_data[self.step_count].orientation) + gt_timestamps.append(self.gt_data[self.step_count].timestamp) + base_positions.append(post_state.get_base_position()) + base_orientations.append(post_state.get_base_orientation()) + rewards.append(reward) + + # Progress to the next sample + self.step_count += 1 + + # Convert to numpy arrays + timestamps = np.array(timestamps) + base_positions = np.array(base_positions) + base_orientations = np.array(base_orientations) + gt_positions = np.array(gt_positions) + gt_orientations = np.array(gt_orientations) + + # Sync and align the data + if sync: + ( + timestamps, + base_positions, + base_orientations, + gt_positions, + gt_orientations, + ) = sync_and_align_data( + timestamps, + base_positions, + base_orientations, + gt_timestamps, + gt_positions, + gt_orientations, + align=True, + ) + + # Plot the trajectories + if plot: + plot_trajectories( + timestamps, + base_positions, + base_orientations, + gt_positions, + gt_orientations, + ) + return ( + timestamps, + base_positions, + base_orientations, + gt_positions, + gt_orientations, + rewards, + imu_data, + kinematics, + ) diff --git a/python/serow/env_validation.py b/python/serow/env_validation.py index 860663a..a27c506 100644 --- a/python/serow/env_validation.py +++ b/python/serow/env_validation.py @@ -2,15 +2,28 @@ import serow import matplotlib.pyplot as plt import unittest +import zipfile +import os -def run_serow(dataset, robot, start_idx=0): +def run_serow(dataset, robot, start_idx = 0): serow_framework = serow.Serow() - serow_framework.initialize(f"{robot}.json") + serow_framework.initialize(f"{robot}_pytest.json") + if (start_idx > 0): + initial_state = serow_framework.get_state(allow_invalid=True) + initial_state.set_base_state(dataset["base_states"][start_idx]) + initial_state.set_joint_state(dataset["joint_states"][start_idx]) + initial_state.set_contact_state(dataset["contact_states"][start_idx]) + serow_framework.set_state(initial_state) + + idx = 0 base_positions = [] base_orientations = [] for imu, joint, ft in zip(dataset["imu"], dataset["joints"], dataset["ft"]): + if idx < start_idx: + idx += 1 + continue status = serow_framework.filter(imu, joint, ft, None) if status: state = serow_framework.get_state(allow_invalid=True) @@ -19,14 +32,43 @@ def run_serow(dataset, robot, start_idx=0): return np.array(base_positions), np.array(base_orientations) +def run_serow_per_step(dataset, robot, start_idx = 0): + serow_framework = serow.Serow() + serow_framework.initialize(f"{robot}_pytest.json") + + if (start_idx > 0): + initial_state = serow_framework.get_state(allow_invalid=True) + initial_state.set_base_state(dataset["base_states"][start_idx]) + initial_state.set_joint_state(dataset["joint_states"][start_idx]) + initial_state.set_contact_state(dataset["contact_states"][start_idx]) + serow_framework.set_state(initial_state) + + idx = 0 + base_positions = [] + base_orientations = [] + for imu, joint, ft in zip(dataset["imu"], dataset["joints"], dataset["ft"]): + 
if idx < start_idx: + idx += 1 + continue + imu, kin, ft = serow_framework.process_measurements(imu, joint, ft, None) + serow_framework.base_estimator_predict_step(imu, kin) + for cf in kin.contacts_status.keys(): + serow_framework.set_action(cf, np.array([1.0], dtype=np.float32)) + serow_framework.base_estimator_update_with_contact_position(cf, kin) + serow_framework.base_estimator_finish_update(imu, kin) + state = serow_framework.get_state(allow_invalid=True) + base_positions.append(state.get_base_position()) + base_orientations.append(state.get_base_orientation()) + return np.array(base_positions), np.array(base_orientations) + -def run_serow_playback(dataset, start_idx=0): +def run_serow_playback(dataset, start_idx = 0): + idx = 0 base_positions = [] base_orientations = [] - i = 0 for bs in dataset["base_states"]: - if i <= start_idx: - i += 1 + if idx < start_idx: + idx += 1 continue base_positions.append(bs.base_position) base_orientations.append(bs.base_orientation) @@ -37,55 +79,82 @@ def run_serow_playback(dataset, start_idx=0): class TestSerow(unittest.TestCase): def setUp(self): self.robot = "go2" - self.dataset = np.load(f"{self.robot}_log.npz", allow_pickle=True) + # Unzip the dataset if the unzipped file does not exist + if not os.path.exists(f"{self.robot}_test.npz"): + with zipfile.ZipFile(f"{self.robot}_test.zip", 'r') as zip_ref: + zip_ref.extractall() + self.dataset = np.load(f"{self.robot}_test.npz", allow_pickle=True) print(f"Length of dataset base states: {len(self.dataset['base_states'])}") print(f"Length of dataset joint states: {len(self.dataset['joints'])}") print(f"Length of dataset imu: {len(self.dataset['imu'])}") print(f"Length of dataset ft: {len(self.dataset['ft'])}") + print(f"Length of dataset base pose ground truth: {len(self.dataset['base_pose_ground_truth'])}") def test_serow(self): - base_positions, base_orientations = run_serow(self.dataset, self.robot) + start_idx = 25 + base_positions, base_orientations = run_serow(self.dataset, self.robot, start_idx) actual_base_positions, actual_base_orientations = run_serow_playback( - self.dataset + self.dataset, start_idx + ) + base_positions_per_step, base_orientations_per_step = run_serow_per_step( + self.dataset, self.robot, start_idx ) print(f"Base positions: {len(base_positions)}") print(f"Actual base positions: {len(actual_base_positions)}") print(f"Base orientations: {len(base_orientations)}") print(f"Actual base orientations: {len(actual_base_orientations)}") + print(f"Base positions per step: {len(base_positions_per_step)}") + print(f"Base orientations per step: {len(base_orientations_per_step)}") assert len(base_positions) == len(actual_base_positions) assert len(base_orientations) == len(actual_base_orientations) + assert len(base_positions_per_step) == len(actual_base_positions) + assert len(base_orientations_per_step) == len(actual_base_orientations) position_error = base_positions - actual_base_positions orientation_error = base_orientations - actual_base_orientations + position_error_per_step = base_positions_per_step - actual_base_positions + orientation_error_per_step = base_orientations_per_step - actual_base_orientations + per_step_vs_filter_position_error = base_positions_per_step - base_positions + per_step_vs_filter_orientation_error = base_orientations_per_step - base_orientations + print(f"Actual Position error: {position_error.sum()}, {position_error.max()}") - print( - f"Actual Orientation error: {orientation_error.sum()}, {orientation_error.max()}" - ) + print(f"Actual Orientation 
error: {orientation_error.sum()}, {orientation_error.max()}") + print(f"Actual Position error per step: {position_error_per_step.sum()}, {position_error_per_step.max()}") + print(f"Actual Orientation error per step: {orientation_error_per_step.sum()}, {orientation_error_per_step.max()}") + print(f"per_step() vs filter() position error: {per_step_vs_filter_position_error.sum()}, {per_step_vs_filter_position_error.max()}") + print(f"per_step() vs filter() orientation error: {per_step_vs_filter_orientation_error.sum()}, {per_step_vs_filter_orientation_error.max()}") # Plot the base position and orientation fig, axs = plt.subplots(2, 1) axs[0].plot(actual_base_positions[:, 0], label="x Actual") axs[0].plot(base_positions[:, 0], label="x SEROW") + axs[0].plot(base_positions_per_step[:, 0], label="x SEROW Per Step") axs[0].plot(actual_base_positions[:, 1], label="y Actual") axs[0].plot(base_positions[:, 1], label="y SEROW") + axs[0].plot(base_positions_per_step[:, 1], label="y SEROW Per Step") axs[0].plot(actual_base_positions[:, 2], label="z Actual") axs[0].plot(base_positions[:, 2], label="z SEROW") + axs[0].plot(base_positions_per_step[:, 2], label="z SEROW Per Step") axs[1].plot(actual_base_orientations[:, 0], label="qw Actual") axs[1].plot(base_orientations[:, 0], label="qw SEROW") + axs[1].plot(base_orientations_per_step[:, 0], label="qw SEROW Per Step") axs[1].plot(actual_base_orientations[:, 1], label="qx Actual") axs[1].plot(base_orientations[:, 1], label="qx SEROW") + axs[1].plot(base_orientations_per_step[:, 1], label="qx SEROW Per Step") axs[1].plot(actual_base_orientations[:, 2], label="qy Actual") axs[1].plot(base_orientations[:, 2], label="qy SEROW") + axs[1].plot(base_orientations_per_step[:, 2], label="qy SEROW Per Step") axs[1].plot(actual_base_orientations[:, 3], label="qz Actual") axs[1].plot(base_orientations[:, 3], label="qz SEROW") + axs[1].plot(base_orientations_per_step[:, 3], label="qz SEROW Per Step") axs[0].set_title("Base Position") axs[1].set_title("Base Orientation") diff --git a/python/serow/example.py b/python/serow/example.py index 54be6ab..8d5b7ae 100644 --- a/python/serow/example.py +++ b/python/serow/example.py @@ -56,8 +56,8 @@ def main(): g = 9.81 # Gravity constant imu_rate = 1000.0 # IMU update rate in Hz outlier_detection = True # Enable outlier detection - - ekf.init(state, contacts_frame, point_feet, g, imu_rate, outlier_detection) + use_imu_orientation = True # Use the IMU orientation during the ContactEKF update step + ekf.init(state, contacts_frame, point_feet, g, imu_rate, outlier_detection, use_imu_orientation) # Create IMU measurement imu = ImuMeasurement() diff --git a/python/serow/generate_log.py b/python/serow/generate_log.py index bcaa857..5834651 100644 --- a/python/serow/generate_log.py +++ b/python/serow/generate_log.py @@ -1,4 +1,6 @@ import numpy as np +from scipy import signal +import matplotlib.pyplot as plt from read_mcap import ( read_base_states, @@ -9,6 +11,11 @@ read_base_pose_ground_truth, read_joint_states, ) +from utils import ( + BaseVelocityGroundTruth, + logMap, + quaternion_to_rotation_matrix, +) def generate_log(robot, mcap_path): @@ -25,6 +32,72 @@ def generate_log(robot, mcap_path): contact_states = read_contact_states(mcap_path + "/serow_proprioception.mcap") joint_states = read_joint_states(mcap_path + "/serow_proprioception.mcap") + # Numerically compute the base velocity + base_velocity_ground_truth = [] + gt_linear_velocity = [] + gt_angular_velocity = [] + gt_timestamps = [] + gt_prev = None + for i, gt in 
enumerate(base_pose_ground_truth): + if i == 0: + w = np.zeros(3) + v = np.zeros(3) + gt_prev = gt + else: + dt = gt.timestamp - gt_prev.timestamp + R = quaternion_to_rotation_matrix(gt.orientation) + R_prev = quaternion_to_rotation_matrix(gt_prev.orientation) + w = R @ logMap(R_prev.transpose() @ R) / dt + v = (gt.position - gt_prev.position) / dt + gt_prev = gt + + gt_timestamps.append(gt.timestamp) + gt_linear_velocity.append(v) + gt_angular_velocity.append(w) + + gt_linear_velocity = np.array(gt_linear_velocity) + gt_angular_velocity = np.array(gt_angular_velocity) + smooth_gt_linear_velocity = np.zeros_like(gt_linear_velocity) + smooth_gt_angular_velocity = np.zeros_like(gt_angular_velocity) + window_size = 31 + polyorder = 3 + for j in range(3): + smooth_gt_linear_velocity[:, j] = signal.savgol_filter( + gt_linear_velocity[:, j], window_size, polyorder, mode="nearest" + ) + smooth_gt_angular_velocity[:, j] = signal.savgol_filter( + gt_angular_velocity[:, j], window_size, polyorder, mode="nearest" + ) + smooth_gt_linear_velocity = list(smooth_gt_linear_velocity) + smooth_gt_angular_velocity = list(smooth_gt_angular_velocity) + + # Plot + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8)) + + # Plot smooth linear velocity + ax1.plot(gt_timestamps, smooth_gt_linear_velocity) + ax1.set_title("Smooth Linear Velocity") + ax1.set_ylabel("Velocity (m/s)") + ax1.legend(["X", "Y", "Z"]) + ax1.grid(True) + + # Plot smooth angular velocity + ax2.plot(gt_timestamps, smooth_gt_angular_velocity) + ax2.set_title("Smooth Angular Velocity") + ax2.set_xlabel("Time (s)") + ax2.set_ylabel("Angular Velocity (rad/s)") + ax2.legend(["X", "Y", "Z"]) + ax2.grid(True) + plt.tight_layout() + plt.show() + + # Create BaseVelocityGroundTruth objects with smoothed data + for timestamp, v, w in zip( + gt_timestamps, smooth_gt_linear_velocity, smooth_gt_angular_velocity + ): + gt_vel = BaseVelocityGroundTruth(timestamp, v, w) + base_velocity_ground_truth.append(gt_vel) + dataset = { "imu": imu_measurements, "joints": joint_measurements, @@ -33,6 +106,7 @@ def generate_log(robot, mcap_path): "contact_states": contact_states, "joint_states": joint_states, "base_pose_ground_truth": base_pose_ground_truth, + "base_velocity_ground_truth": base_velocity_ground_truth, } # Save dataset to a numpy file @@ -46,6 +120,7 @@ def generate_log(robot, mcap_path): contact_states=dataset["contact_states"], joint_states=dataset["joint_states"], base_pose_ground_truth=dataset["base_pose_ground_truth"], + base_velocity_ground_truth=dataset["base_velocity_ground_truth"], ) diff --git a/python/serow/inference.py b/python/serow/inference.py new file mode 100644 index 0000000..26bf6af --- /dev/null +++ b/python/serow/inference.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 + +import numpy as np +import json +import onnx +import onnx.numpy_helper +import os +import onnxruntime as ort + +from env import SerowEnv +from train import PreStepDQN + + +class ONNXInference: + def __init__(self, robot, path): + # Initialize ONNX Runtime sessions + model_path = f"{path}/{robot}_dqn.onnx" + + print(f"Loading ONNX model from: {model_path}") + print(f"File exists: {os.path.exists(model_path)}") + + self.session = ort.InferenceSession( + model_path, + providers=["CPUExecutionProvider"], + ) + + print("Session created successfully") + print(f"Session providers: {self.session.get_providers()}") + print(f"Available providers: {ort.get_available_providers()}") + + # Get input names + inputs = self.session.get_inputs() + outputs = self.session.get_outputs() + 
+ print(f"Number of inputs: {len(inputs)}") + print(f"Number of outputs: {len(outputs)}") + + if len(inputs) > 0: + self.input_name = inputs[0].name + print(f"Input names: {self.input_name}") + self.state_dim = inputs[0].shape[1] + else: + raise ValueError("No inputs found in ONNX model") + + if len(outputs) > 0: + print(f"Output 0 shape: {outputs[0].shape}") + print(f"Output 0 name: {outputs[0].name}") + + # Handle dynamic shapes - if shape is a list with only one element, + # it means the output dimension is dynamic and we need to infer it + if len(outputs[0].shape) == 1 and outputs[0].shape[0] == "batch_size": + # This is a dynamic shape, we'll need to infer the actual + # dimension. For now, let's use a default value or try to get + # it from the model + print( + "Warning: Dynamic output shape detected, " + "using default action dimension" + ) + self.action_dim = 1 # Default action dimension + else: + self.action_dim = outputs[0].shape[1] + else: + raise ValueError("No outputs found in ONNX model") + + # Action space - discrete choices for measurement noise scaling + self.discrete_actions = np.array( + [ + 1e-5, + 1e-4, + 1e-3, + 1e-2, + 1e-1, + 1.0, + 5.0, + 10.0, + 1000.0, + ], + dtype=np.float32, + ) + + print(f"Initialized ONNX inference for {robot} with dqn model") + print(f"State dimension: {self.state_dim}") + print(f"Action dimension: {self.action_dim}") + + def forward(self, observation, deterministic=True): + # Prepare input + observation = np.array(observation, dtype=np.float32).reshape(1, -1) + output = self.session.run(None, {self.input_name: observation}) + + # DQN outputs Q-values, we need to select the action with highest Q-value + q_values = output[0] + action = np.argmax(q_values, axis=1) + return action, q_values + + def predict(self, observation, deterministic=True): + """ + Predict action given observation. + Matches the interface expected by SerowEnv.evaluate(). + Returns action and value + """ + return self.forward(observation, deterministic=deterministic) + + +def get_onnx_weights_biases(onnx_model_path): + """ + Loads an ONNX model from its file path and extracts its initializers + (weights and biases). + + Args: + onnx_model_path (str): Path to the ONNX model file. + + Returns: + dict: A dictionary where keys are the names of the initializers + (weights/biases) and values are their NumPy array + representations. + """ + # Load the ONNX model from the file path + onnx_model = onnx.load(onnx_model_path) + weights_biases = {} + for initializer in onnx_model.graph.initializer: + # Initializers are TensorProto objects, use numpy_helper to convert + # to numpy array + np_array = onnx.numpy_helper.to_array(initializer) + weights_biases[initializer.name] = np_array + return weights_biases + + +def compare_onnx_dqn_predictions(agent_onnx, agent_dqn, state_dim): + """ + Compares the predictions of an ONNX model with a DQN model. + + Args: + agent_onnx (ONNXInference): ONNX model to compare with. + agent_dqn (PreStepDQN): DQN model to compare with. + state_dim (int): Dimension of the state space. 
+ """ + + # Generate a few random observations + for i in range(100): + # Generate a random observation + obs = np.random.randn(1, state_dim).astype(np.float32) + + # Get the DQN model prediction + dqn_action, _ = agent_dqn.predict(obs, deterministic=True) + dqn_action = np.array( + [agent_onnx.discrete_actions[dqn_action.item()]], dtype=np.float32 + ) + + # Get the ONNX model prediction + onnx_action, _ = agent_onnx.predict(obs, deterministic=True) + onnx_action = np.array( + [agent_onnx.discrete_actions[onnx_action.item()]], dtype=np.float32 + ) + + # Compare the actions + assert np.allclose(dqn_action, onnx_action, atol=1e-4) + print("ONNX and DQN action predictions match") + + +if __name__ == "__main__": + # Initialize ONNX inference + robot = "go2" + device = "cpu" + model_dir = "models" + + # Read the data + test_dataset = np.load(f"{robot}_log.npz", allow_pickle=True) + try: + stats = json.load(open(f"{model_dir}/{robot}_stats.json")) + except FileNotFoundError: + stats = None + + # Get contacts frame from the first measurement + contact_states = test_dataset["contact_states"] + contacts_frame = list(contact_states[0].contacts_status.keys()) + history_size = 100 + state_dim = 3 + 9 + 3 + 4 + 3 * history_size + 1 * history_size + action_dim = 1 # Based on the action vector used in ContactEKF.setAction() + + # Load the saved DQN model + try: + agent_dqn = PreStepDQN.load(f"models/{robot}_dqn") + print("Loaded DQN model successfully") + except Exception as e: + print(f"Could not load DQN model: {e}") + agent_dqn = None + + # Load the ONNX model + agent_onnx = ONNXInference(robot, path="models") + + # Compare the ONNX model predictions with the DQN model predictions + if agent_dqn is not None: + compare_onnx_dqn_predictions(agent_onnx, agent_dqn, state_dim) + + test_env = SerowEnv( + contacts_frame[0], + robot, + test_dataset["joint_states"][0], + test_dataset["base_states"][0], + test_dataset["contact_states"][0], + action_dim, + state_dim, + test_dataset["imu"], + test_dataset["joints"], + test_dataset["ft"], + test_dataset["base_pose_ground_truth"], + history_size, + ) + + # Use the loaded DQN model for evaluation + if agent_dqn is not None: + ( + dqn_timestamps, + dqn_base_positions, + dqn_base_orientations, + dqn_gt_positions, + dqn_gt_orientations, + dqn_rewards, + _, + _, + ) = test_env.evaluate(agent_dqn, stats, plot=True) + + # Use the ONNX model for evaluation + ( + onnx_timestamps, + onnx_base_positions, + onnx_base_orientations, + onnx_gt_positions, + onnx_gt_orientations, + onnx_rewards, + _, + _, + ) = test_env.evaluate(agent_onnx, stats, plot=True) + + # These must be equal if DQN model was loaded + if agent_dqn is not None: + assert np.allclose(dqn_timestamps, onnx_timestamps, atol=1e-3) + assert np.allclose(dqn_base_positions, onnx_base_positions, atol=1e-3) + assert np.allclose(dqn_base_orientations, onnx_base_orientations, atol=1e-3) + print("All tests passed") + else: + print("DQN model not loaded, skipping comparison tests") diff --git a/python/serow/rl.py b/python/serow/rl.py new file mode 100644 index 0000000..017d806 --- /dev/null +++ b/python/serow/rl.py @@ -0,0 +1,723 @@ +import numpy as np +import gymnasium as gym +import pandas as pd +import matplotlib.pyplot as plt +import torch.nn as nn + +from gymnasium import spaces +from stable_baselines3 import PPO +from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.callbacks import BaseCallback, CallbackList +from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize + 
+ +def linear_schedule(initial_value, final_value): + """Linear learning rate schedule.""" + + def schedule(progress_remaining): + return final_value + progress_remaining * (initial_value - final_value) + + return schedule + + +def compute_rolling_average(data, window_size): + """Helper to compute rolling average, padding the start.""" + if len(data) == 0: + return [] + series = pd.Series(data) + # Use .rolling().mean() with min_periods to start from the first data point + rolling_avg = series.rolling(window=window_size, min_periods=1).mean() + return rolling_avg.tolist() + + +class AutoPreStepWrapper(gym.Wrapper): + """A wrapper that automatically handles pre-step logic.""" + + def __init__(self, env): + super().__init__(env) + self.env = env + + def get_observation_for_action(self): + """Delegate to the wrapped environment's get_observation_for_action method""" + return self.env.get_observation_for_action() + + def _get_observation(self): + """Override to automatically call get_observation_for_action""" + return self.env.get_observation_for_action() + + def step(self, action): + return self.env.step(action) + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + + +class ValidSampleCallback(BaseCallback): + """Callback to track and handle valid/invalid samples during training.""" + + def __init__(self, verbose=0): + super(ValidSampleCallback, self).__init__(verbose) + self.valid_samples_count = 0 + self.total_samples_count = 0 + self.invalid_samples_count = 0 + + def _on_step(self) -> bool: + # Count valid vs invalid samples + infos = self.locals.get("infos", []) + for info in infos: + self.total_samples_count += 1 + if info.get("valid", True): + self.valid_samples_count += 1 + else: + self.invalid_samples_count += 1 + + # Log statistics periodically + if self.total_samples_count % 100 == 0: + valid_ratio = ( + self.valid_samples_count / self.total_samples_count + if self.total_samples_count > 0 + else 0.0 + ) + if self.verbose > 0: + print( + f"Sample validity: {valid_ratio:.2%} valid " + f"({self.valid_samples_count}/{self.total_samples_count})" + ) + + return True + + +class PreStepPPO(PPO): + """Custom PPO model that handles pre-step logic during training and evaluation.""" + + def collect_rollouts( + self, + env, + callback, + rollout_buffer, + n_rollout_steps: int, + ): + """Override collect_rollouts to use get_observation_for_action during training""" + # Store original step method to restore later + original_step_methods = [] + + # Create wrapper environments that use get_observation_for_action + if hasattr(env, "envs"): + # Vectorized environment + for i, single_env in enumerate(env.envs): + # Store original step method + original_step_methods.append(single_env.step) + + # Create a wrapper that intercepts step calls + def make_step_wrapper(env, original_step): + def step_wrapper(action): + # Call get_observation_for_action before the step + env.get_observation_for_action() + return original_step(action) + + return step_wrapper + + single_env.step = make_step_wrapper( + single_env, original_step_methods[-1] + ) + else: + # Single environment + original_step_methods.append(env.step) + + def step_wrapper(action): + # Call get_observation_for_action before the step + env.get_observation_for_action() + return original_step_methods[0](action) + + env.step = step_wrapper + + try: + # Call the parent collect_rollouts method + result = super().collect_rollouts( + env, callback, rollout_buffer, n_rollout_steps + ) + + finally: + # Restore original step methods + if hasattr(env, 
"envs"): + for i, single_env in enumerate(env.envs): + if i < len(original_step_methods): + single_env.step = original_step_methods[i] + else: + if len(original_step_methods) > 0: + env.step = original_step_methods[0] + + return result + + def forward(self, obs, deterministic=False): + return self.policy.forward(obs, deterministic) + + +class KalmanFilterEnv(gym.Env): + """Custom Gym environment integrating a Kalman filter.""" + + def __init__( + self, + measurement, + u, + gt, + min_action, + max_action, + process_noise=0.1, + measurement_noise=0.1, + ): + super(KalmanFilterEnv, self).__init__() + + # Environment parameters + self.measurement = measurement + self.gt = gt + self.u = u + self.process_noise = process_noise + self.measurement_noise = measurement_noise + self.max_steps = len(measurement) - 1 + + # Kalman filter parameters + self.dt = 0.001 # time step + + # State: [position, velocity] + self.state_dim = 2 + self.measurement_dim = 1 # we only measure position + self.history_length = 100 + self.measurement_history = [] + self.measurement_noise_history = [] + self.prev_action = 0.0 + + # Initialize Kalman filter matrices + self.F = np.array([[1, self.dt], [0, 1]]) # State transition matrix + self.H = np.array([[1, 0]]) # Measurement matrix + self.B = np.array([0.0, self.dt]) + + self.Q = np.array([[self.process_noise**2]]) + self.R = np.array([[self.measurement_noise**2]]) # Measurement noise + + # Action space + self.action_space = spaces.Box( + low=min_action, high=max_action, shape=(1,), dtype=np.float32 + ) + + # Observation space: [position, velocity, position covariance, velocity covariance, measurement noise, innovation] + self.observation_space = spaces.Box( + low=-np.inf, + high=np.inf, + shape=(5 + self.history_length * 2,), + dtype=np.float32, + ) + + self.reset() + + def reset(self, seed=None, options=None): + super().reset(seed=seed) + + # Reset Kalman filter state + self.x = np.array([0.0, 0.0]) # Initial state [position, velocity] + self.P = np.eye(self.state_dim) * np.random.uniform( + 0.1, 5.0 + ) # Initial covariance + self.step_count = 0 + self.reward = 0 + + self.measurement_history = [0.0] * self.history_length + self.measurement_noise_history = [float(self.R[0, 0])] * self.history_length + self.prev_action = 0.0 + + # Get initial observation + obs = self._get_observation() + return obs, {} + + def get_observation_for_action(self): + """Get the observation that should be used for action computation.""" + # Run prediction step with current control input + if self.step_count < len(self.u): + self.predict(self.u[self.step_count]) + + # Get the observation that the policy should use + obs = self._get_observation() + return obs + + def step(self, action): + next_state, y, S = self.update(self.measurement[self.step_count], action) + + position_error = abs(next_state[0] - self.gt[self.step_count]) + position_reward = -position_error / 10.0 + + # Innovation consistency reward (clipped) + nis = float(y @ np.linalg.inv(S) @ y.T) + nis = np.clip(nis, 0, 10.0) + innovation_reward = -nis / 10.0 + + # Small action penalty to encourage smoothness + action_penalty = abs(action[0] - self.prev_action) + + self.reward = ( + 4.0 * position_reward + 0.5 * innovation_reward - 0.005 * action_penalty + ) + + # Check termination conditions + terminated = position_error > 10.0 + truncated = self.step_count == self.max_steps - 1 + + info = { + "nis": nis, + "position_error": position_error, + "step_count": self.step_count, + "reward": self.reward, + "valid": ( + True if 
np.random.rand() > 0.1 else False + ), # Make 10% of samples invalid + } + + self.step_count += 1 + self.prev_action = action[0] + + # Get the final observation for the next step + obs = self._get_observation() + return obs, self.reward, bool(terminated), bool(truncated), info + + def predict(self, control_input): + """Kalman filter prediction step""" + # Control input matrix (acceleration affects position and velocity) + u = np.array([control_input]) + + # Predict state + self.x = self.F @ self.x + self.B * u + + # Predict covariance + self.P = self.F @ self.P @ self.F.T + self.Q + + def update(self, measurement, action): + """Kalman filter update step""" + # Update step + R = float(action[0]) + + # Innovation + y = measurement - self.H @ self.x + + # Innovation covariance + S = self.H @ self.P @ self.H.T + R + + # Kalman gain + K = self.P @ self.H.T @ np.linalg.inv(S) + + # Updated state estimate + self.x += K @ y + + # Updated covariance + self.P = self.P - K @ self.H @ self.P + + self.measurement_history.append(measurement) + self.measurement_noise_history.append(self.R[0, 0]) + while len(self.measurement_history) > self.history_length: + self.measurement_history.pop(0) + self.measurement_noise_history.pop(0) + + return self.x, y, S + + def _get_observation(self): + """Get current observation for the agent""" + + measurement_history = np.array(self.measurement_history).flatten() + measurement_noise_history = np.array(self.measurement_noise_history).flatten() + obs = np.concatenate( + [ + [self.x[1]], # current velocity + [self.P[0, 0]], # position covariance + [self.P[1, 1]], # velocity covariance + [self.R[0, 0]], # measurement noise + measurement_history, + measurement_noise_history, + [self.measurement[self.step_count] - self.x[0]], # innovation + ], + axis=0, + dtype=np.float32, + ) + + return obs + + def render(self, mode="human"): + if mode == "human": + print( + f"Step: {self.step_count}, Position: {self.x[0]}, " + f"Target: {self.gt[self.step_count]}, Reward: {self.reward}" + ) + + +class TrainingCallback(BaseCallback): + """Custom callback for monitoring training progress""" + + def __init__(self, verbose=0): + super(TrainingCallback, self).__init__(verbose) + self.episode_rewards = [] + self.episode_lengths = [] + self.step_rewards = [] # Track rewards per step as fallback + self.episode_reward_sum = 0 + self.episode_length = 0 + self.last_dones = None + self.total_steps = 0 + + def _on_step(self) -> bool: + self.total_steps += 1 + + # Get dones and truncated to detect episode completions + dones = self.locals.get("dones", [False] * len(self.locals.get("infos", []))) + truncated = self.locals.get( + "truncated", [False] * len(self.locals.get("infos", [])) + ) + + # Only accumulate episode rewards for valid steps + valid_steps = 0 + total_reward = 0.0 + episode_completed = False + if len(self.locals["infos"]) > 0: + for i, info in enumerate(self.locals["infos"]): + if info["valid"]: + reward = info["reward"] + valid_steps += 1 + total_reward += reward + self.step_rewards.append(reward) + # Only add to episode if there were valid steps + if valid_steps > 0: + avg_valid_reward = total_reward / valid_steps + self.episode_reward_sum += avg_valid_reward + self.episode_length += 1 + + # Check if any episode completed (either done or truncated) + for _, (done, trunc) in enumerate(zip(dones, truncated)): + if done or trunc: + episode_completed = True + break + + if episode_completed and self.episode_length > 0: + self.episode_rewards.append(self.episode_reward_sum) + 
self.episode_lengths.append(self.episode_length) + if self.verbose > 0: + print( + f"Episode completed: reward={self.episode_reward_sum:.3f}, " + f"length={self.episode_length}" + ) + + # Reset for next episode + self.episode_reward_sum = 0 + self.episode_length = 0 + + return True + + +def generate_dataset( + n_points=1000, + t_max=1.0, + measurement_noise_std=0.1, + control_noise_std=0.05, + seed=None, +): + """Generate a random dataset for Kalman filter training.""" + if seed is not None: + np.random.seed(seed) + + # Time vector + t = np.linspace(0, t_max, n_points) + + # Generate random parameters for different trajectories + # This ensures each dataset has a different ground truth signal + freq1 = np.random.uniform(1.0, 3.0) # Random frequency for primary + freq2 = np.random.uniform(2.0, 6.0) # Random frequency for secondary + freq3 = np.random.uniform(3.0, 8.0) # Random frequency for tertiary + + amp1 = np.random.uniform(1.0, 3.0) # Random amplitude for primary + amp2 = np.random.uniform(0.2, 1.0) # Random amplitude for secondary + amp3 = np.random.uniform(0.1, 0.5) # Random amplitude for tertiary + + quad_coeff = np.random.uniform(0.1, 0.8) # Random quadratic coefficient + phase1 = np.random.uniform(0, 2 * np.pi) # Random phase shifts + phase2 = np.random.uniform(0, 2 * np.pi) + phase3 = np.random.uniform(0, 2 * np.pi) + + # Generate a smooth ground truth signal using random parameters + ground_truth = ( + amp1 * np.sin(2 * np.pi * freq1 * t + phase1) # Primary oscillation + + amp2 * np.sin(2 * np.pi * freq2 * t + phase2) # Secondary + + quad_coeff * t**2 # Quadratic trend + + amp3 * np.cos(2 * np.pi * freq3 * t + phase3) # Additional complexity + ) + + # Compute the second derivative (acceleration) analytically + true_acceleration = ( + -4 + * np.pi**2 + * freq1**2 + * amp1 + * np.sin(2 * np.pi * freq1 * t + phase1) # Second derivative of primary + - -4 + * np.pi**2 + * freq2**2 + * amp2 + * np.sin(2 * np.pi * freq2 * t + phase2) # Second derivative of secondary + + 2 * quad_coeff # Second derivative of quadratic term + + -4 + * np.pi**2 + * freq3**2 + * amp3 + * np.cos(2 * np.pi * freq3 * t + phase3) # Second derivative of cosine + ) + + # Add varying measurement noise (zero-mean Gaussian with varying std) + time_varying_std = measurement_noise_std * (1 + 0.5 * t / t_max) + measurement_noise = np.random.normal(0, time_varying_std, n_points) + measurement = ground_truth + measurement_noise + + # Add varying control noise (zero-mean Gaussian with varying std) + control_varying_std = control_noise_std * (1 + 0.2 * np.abs(true_acceleration) / np.max(np.abs(true_acceleration))) + + control_noise = np.random.normal(0, control_varying_std, n_points) + control = true_acceleration + control_noise + + return measurement, control, ground_truth + + +def visualize_dataset(measurement, control, ground_truth, save_plot=False): + """Visualize the generated dataset.""" + fig, axes = plt.subplots(2, 1, figsize=(12, 8)) + + # Plot position signals + t = np.linspace(0, 1, len(ground_truth)) + axes[0].plot(t, ground_truth, "b-", linewidth=2, label="Ground Truth") + axes[0].plot(t, measurement, "r.", markersize=1, alpha=0.6, label="Measurement") + axes[0].set_xlabel("Time") + axes[0].set_ylabel("Position") + axes[0].set_title("Position Signal") + axes[0].legend() + axes[0].grid(True, alpha=0.3) + + # Plot control signals + axes[1].plot(t, control, "g-", linewidth=1, label="Control (Acceleration)") + axes[1].set_xlabel("Time") + axes[1].set_ylabel("Acceleration") + axes[1].set_title("Control Signal") + 
axes[1].legend() + axes[1].grid(True, alpha=0.3) + + plt.tight_layout() + + if save_plot: + plt.savefig("generated_dataset.png", dpi=300, bbox_inches="tight") + + plt.show() + + +def main(): + # Number of parallel environments + n_envs = 8 # You can adjust this based on your CPU cores + min_action = 1e-3 + max_action = 1e3 + measurement_noise_std = 0.1 + control_noise_std = 0.25 + scale = max_action + gen_scale = min_action + + # Generate random datasets + datasets = [] + for i in range(n_envs + 1): + measurement_signal, acceleration_signal, position_signal = generate_dataset( + n_points=1000, + t_max=1.0, + measurement_noise_std=gen_scale * measurement_noise_std, + control_noise_std=control_noise_std, + seed=42 + i, # Different seed for each dataset + ) + dataset = { + "control": acceleration_signal, + "measurement": measurement_signal, + "ground_truth": position_signal, + } + datasets.append(dataset) + + # Print some statistics to verify datasets are different + print( + f"Dataset {i}: GT range [{position_signal.min():.3f}, " + f"{position_signal.max():.3f}], GT std: {position_signal.std():.3f}" + ) + + # Only visualize the first dataset to avoid too many plots + # visualize_dataset(measurement_signal, acceleration_signal, position_signal) + + # Create vectorized environment + def make_env(dataset_idx): + """Helper function to create a single environment with specific dataset""" + dataset = datasets[dataset_idx] + base_env = KalmanFilterEnv( + measurement=dataset["measurement"], + u=dataset["control"], + gt=dataset["ground_truth"], + min_action=min_action, + max_action=max_action, + measurement_noise=scale * measurement_noise_std, + process_noise=control_noise_std, + ) + # Wrap with AutoPreStepWrapper to automatically use pre-step logic + return AutoPreStepWrapper(base_env) + + # Create vectorized environment with different datasets for each environment + env = DummyVecEnv([lambda i=i: make_env(i) for i in range(n_envs)]) + env = VecNormalize(env, norm_obs=True, norm_reward=False) + + # For testing, create a single environment using the last dataset + test_env = AutoPreStepWrapper( + KalmanFilterEnv( + measurement=datasets[-1]["measurement"], + u=datasets[-1]["control"], + gt=datasets[-1]["ground_truth"], + min_action=min_action, + max_action=max_action, + measurement_noise=scale * measurement_noise_std, + process_noise=control_noise_std, + ) + ) + + baseline_env = KalmanFilterEnv( + measurement=datasets[-1]["measurement"], + u=datasets[-1]["control"], + gt=datasets[-1]["ground_truth"], + min_action=min_action, + max_action=max_action, + measurement_noise=scale * measurement_noise_std, + process_noise=control_noise_std, + ) + + # Check environment + check_env(test_env) + print("Environment check passed!") + print(f"Training with {n_envs} parallel environments") + + model = PreStepPPO( + "MlpPolicy", + env, + device="cpu", + verbose=1, + learning_rate=linear_schedule(3e-4, 1e-5), + n_steps=512, + batch_size=128, + n_epochs=5, + gamma=0.99, + gae_lambda=0.95, + target_kl=0.035, + clip_range=0.2, + policy_kwargs=dict( + net_arch=dict(pi=[512, 512, 256, 128], vf=[512, 512, 256, 128]), + activation_fn=nn.Tanh, + ortho_init=True, + log_std_init=-1.0, + ), + ) + + # Create callbacks + training_callback = TrainingCallback(verbose=1) + valid_sample_callback = ValidSampleCallback(verbose=1) + + # Combine callbacks + callback = CallbackList([training_callback, valid_sample_callback]) + + # Train the model + print("Starting training...") + model.learn(total_timesteps=50000, callback=callback) + + 
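+    # VecNormalize keeps running mean/variance statistics of the observations;
+    # extract them so the test loop below can apply the same normalization
+    # before calling model.predict().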
stats = None + try: + # Extract the observation normalization statistics + stats = { + "obs_mean": env.obs_rms.mean, + "obs_var": env.obs_rms.var, + "obs_count": env.obs_rms.count, + } + print("Observation normalization stats:") + print(f" Mean: {stats['obs_mean']}") + print(f" Variance: {stats['obs_var']}") + print(f" Count: {stats['obs_count']}") + except Exception as e: + print(f"Error extracting observation normalization stats: {e}") + + # Save the model + model.save("kalman_ppo_model") + print("Model saved!") + + # Test the trained model + print("\nTesting trained model...") + obs, _ = test_env.reset() + baseline_env.reset() + + episode_rewards = [] + positions = [] + positions_baseline = [] + + for step in range(len(test_env.env.gt)): + # The model will automatically use get_observation_for_action + obs = test_env.get_observation_for_action() + if stats is not None: + obs = (obs - stats["obs_mean"]) / np.sqrt(stats["obs_var"]) + action, _ = model.predict(obs, deterministic=True) + print(f"step {step} action: {action}") + obs, reward, terminated, truncated, _ = test_env.step(action) + + # Run the baseline + baseline_env.get_observation_for_action() + baseline_env.step(baseline_env.R.flatten()) + if len(episode_rewards) == 0: + episode_rewards.append(reward) + else: + episode_rewards.append(episode_rewards[-1] + reward) + positions.append(test_env.env.x[0]) + positions_baseline.append(baseline_env.x[0]) + if step % 20 == 0: + test_env.render() + + if terminated or truncated: + break + + print(f"\nTest episode reward: {episode_rewards[-1]:.2f}") + + # Plot results + plt.figure(figsize=(12, 4)) + plt.subplot(1, 2, 1) + plt.plot(positions, color="b", label="agent") + plt.plot(test_env.env.gt, color="r", linestyle="--", label="gt") + plt.plot(positions_baseline, color="g", linestyle="--", label="baseline") + plt.xlabel("Time Step") + plt.ylabel("Position") + plt.title("Agent Position Over Time") + plt.legend() + plt.grid(True) + plt.subplot(1, 2, 2) + + # Debug information + print("Training callback stats:") + print(f"Episode rewards collected: {len(training_callback.episode_rewards)}") + print(f"Step rewards collected: {len(training_callback.step_rewards)}") + valid_count = valid_sample_callback.valid_samples_count + total_count = valid_sample_callback.total_samples_count + valid_ratio = valid_count / total_count + print(f"Valid sample ratio: {valid_count}/{total_count} ({valid_ratio:.2%})") + + # Plot the step rewards + step_rewards = training_callback.step_rewards + step_rewards_avg = compute_rolling_average(step_rewards, 100) + plt.plot(step_rewards_avg, label="Average Rewards", alpha=1.0, color="blue") + plt.plot( + step_rewards, + label="Rewards", + alpha=0.35, + color="lightblue", + ) + plt.xlabel("Samples") + plt.ylabel("Normalized Rewards") + plt.title("Training Progress") + plt.legend() + plt.grid(True) + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + main() diff --git a/python/serow/rl_discrete.py b/python/serow/rl_discrete.py new file mode 100644 index 0000000..2d51740 --- /dev/null +++ b/python/serow/rl_discrete.py @@ -0,0 +1,747 @@ +import numpy as np +import gymnasium as gym +import pandas as pd +import matplotlib.pyplot as plt +import torch.nn as nn +import torch + +from gymnasium import spaces +from stable_baselines3 import DQN +from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize +from 
stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.type_aliases import RolloutReturn + + +def linear_schedule(initial_value, final_value): + """Linear learning rate schedule.""" + + def schedule(progress_remaining): + return final_value + progress_remaining * (initial_value - final_value) + + return schedule + + +def compute_rolling_average(data, window_size): + """Helper to compute rolling average, padding the start.""" + if len(data) == 0: + return [] + series = pd.Series(data) + # Use .rolling().mean() with min_periods to start from the first data point + rolling_avg = series.rolling(window=window_size, min_periods=1).mean() + return rolling_avg.tolist() + + +class PreStepDQN(DQN): + def collect_rollouts( + self, + env, + callback, + train_freq, + replay_buffer: ReplayBuffer, + action_noise=None, + learning_starts: int = 0, + log_interval=None, + ) -> RolloutReturn: + """ + Custom rollout collection: + - Always get observation from env.get_observation_for_action() + - Only store transitions where info['valid'] == True + """ + # Switch to eval mode to avoid dropout/batchnorm training + self.policy.set_training_mode(False) + n_steps = 0 + total_rewards = [] + completed_episodes = 0 + + # Reset buffer for new rollout + callback.on_rollout_start() + + while n_steps < train_freq[0]: + # 1. Get obs for action selection + obs_for_action = [] + if hasattr(env, "envs"): + # Vectorized env + for e in env.envs: + obs_for_action.append(e.get_observation_for_action()) + obs_for_action = np.array(obs_for_action) + else: + obs_for_action = np.array([env.get_observation_for_action()]) + + self._last_obs = obs_for_action + # 2. Predict action + actions, buffer_actions = self._sample_action( + learning_starts, action_noise, env.num_envs + ) + + # 3. Step environment + new_obs, rewards, dones, infos = env.step(actions) + + # 4. Nullify invalid samples + for idx, info in enumerate(infos): + if not info.get("valid", True): + rewards[idx] = np.nan + new_obs[idx] = np.zeros_like(new_obs[idx]) + self._last_obs[idx] = np.zeros_like(self._last_obs[idx]) + buffer_actions[idx] = np.zeros_like(buffer_actions[idx]) + + replay_buffer.add( + self._last_obs, + new_obs, + buffer_actions, + rewards, + dones, + infos, + ) + self._update_info_buffer(infos, dones) + + # 5. Update counters + n_steps += 1 + self.num_timesteps += env.num_envs + total_rewards.extend(rewards) + + # Count completed episodes + completed_episodes += sum(dones) + + # 6. 
Handle episode ends + callback.update_locals(locals()) + if not callback.on_step(): + return RolloutReturn( + episode_timesteps=n_steps, + n_episodes=completed_episodes, + continue_training=False, + ) + + callback.on_rollout_end() + + return RolloutReturn( + episode_timesteps=n_steps, + n_episodes=completed_episodes, + continue_training=True, + ) + + def train(self, gradient_steps: int, batch_size: int = 100) -> None: + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update learning rate according to schedule + self._update_learning_rate(self.policy.optimizer) + losses = [] + for _ in range(gradient_steps): + # Sample replay buffer + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] + # Filter out invalid samples + valid_mask = ~torch.isnan(replay_data.rewards.flatten()) + num_valid = valid_mask.sum() + + # Skip if too few valid samples (less than 25% of batch) + min_valid_samples = max(1, batch_size // 4) + if num_valid < min_valid_samples: + self.logger.record("train/skipped_batches", 1, exclude="tensorboard") + continue + + # Create filtered data instead of modifying the original object + filtered_observations = replay_data.observations[valid_mask] + filtered_next_observations = replay_data.next_observations[valid_mask] + filtered_actions = replay_data.actions[valid_mask] + filtered_rewards = replay_data.rewards[valid_mask] + filtered_dones = replay_data.dones[valid_mask] + filtered_discounts = ( + replay_data.discounts[valid_mask] + if replay_data.discounts is not None + else None + ) + + # For n-step replay, discount factor is gamma**n_steps (when no early termination) + discounts = ( + filtered_discounts if filtered_discounts is not None else self.gamma + ) + + with torch.no_grad(): + # Compute the next Q-values using the target network + next_q_values = self.q_net_target(filtered_next_observations) + # Follow greedy policy: use the one with the highest value + next_q_values, _ = next_q_values.max(dim=1) + # Avoid potential broadcast issue + next_q_values = next_q_values.reshape(-1, 1) + # 1-step TD target + target_q_values = ( + filtered_rewards + (1 - filtered_dones) * discounts * next_q_values + ) + + # Get current Q-values estimates + current_q_values = self.q_net(filtered_observations) + + # Retrieve the q-values for the actions from the replay buffer + current_q_values = torch.gather( + current_q_values, dim=1, index=filtered_actions.long() + ) + + # Compute Huber loss (less sensitive to outliers) + loss = torch.nn.functional.smooth_l1_loss(current_q_values, target_q_values) + losses.append(loss.item()) + + # Optimize the policy + self.policy.optimizer.zero_grad() + loss.backward() + # Clip gradient norm + torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + # Increase update counter + self._n_updates += gradient_steps + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/loss", np.mean(losses)) + self.logger.dump(step=self.num_timesteps) + + def forward(self, obs, deterministic=False): + return self.policy.forward(obs, deterministic) + + +class KalmanFilterEnv(gym.Env): + """Custom Gym environment integrating a Kalman filter.""" + + def __init__( + self, + measurement, + u, + gt, + process_noise=0.1, + measurement_noise=0.1, + ): + super(KalmanFilterEnv, self).__init__() + + # Environment parameters + self.measurement = measurement + self.gt = gt + self.u 
= u + self.process_noise = process_noise + self.measurement_noise = measurement_noise + self.max_steps = len(measurement) - 1 + + # Kalman filter parameters + self.dt = 0.001 # time step + + # State: [position, velocity] + self.state_dim = 2 + self.measurement_dim = 1 # we only measure position + self.history_length = 100 + self.measurement_history = [] + self.measurement_noise_history = [] + self.prev_action = 0.0 + + # Initialize Kalman filter matrices + self.F = np.array([[1, self.dt], [0, 1]]) # State transition matrix + self.H = np.array([[1, 0]]) # Measurement matrix + self.B = np.array([0.0, self.dt]) + + self.Q = np.array([[self.process_noise**2]]) + self.R = np.array([[self.measurement_noise**2]]) # Measurement noise + + # Action space - discrete choices for measurement noise scaling + self.discrete_actions = np.array([1e-4, 1e-3, 1e-2, 1e-1, 1, 10, 100]) + self.action_space = spaces.Discrete(len(self.discrete_actions)) + + # Observation space: [position, velocity, position covariance, + # velocity covariance, measurement noise, innovation] + self.observation_space = spaces.Box( + low=-np.inf, + high=np.inf, + shape=(5 + self.history_length * 2,), + dtype=np.float32, + ) + + self.reset() + + def get_available_actions(self): + """Return the available discrete action values.""" + return self.discrete_actions.copy() + + def reset(self, seed=None, options=None): + super().reset(seed=seed) + + # Reset Kalman filter state + self.x = np.array([0.0, 0.0]) # Initial state [position, velocity] + self.P = np.eye(self.state_dim) * np.random.uniform( + 0.1, 5.0 + ) # Initial covariance + self.step_count = 0 + self.reward = 0 + + self.measurement_history = [0.0] * self.history_length + self.measurement_noise_history = [float(self.R[0, 0])] * self.history_length + self.prev_action = 0.0 + + # Get initial observation + obs = self._get_observation() + return obs, {} + + def get_observation_for_action(self): + """Get the observation that should be used for action computation.""" + # Run prediction step with current control input + if self.step_count < len(self.u): + self.predict(self.u[self.step_count]) + + # Get the observation that the policy should use + obs = self._get_observation() + return obs + + def step(self, action): + next_state, y, S = self.update(self.measurement[self.step_count], action) + + position_error = abs(next_state[0] - self.gt[self.step_count]) + position_reward = -position_error / 10.0 + + # Innovation consistency reward (clipped) + nis = float(y @ np.linalg.inv(S) @ y.T) + nis = np.clip(nis, 0, 10.0) + innovation_reward = -nis / 10.0 + + # Small action penalty to encourage smoothness + current_action_value = self.discrete_actions[action] + action_penalty = abs(current_action_value - self.prev_action) + + self.reward = ( + 4.0 * position_reward + 0.5 * innovation_reward - 0.005 * action_penalty + ) + + # Check termination conditions + terminated = position_error > 10.0 or self.step_count == self.max_steps - 1 + truncated = False + + info = { + "nis": nis, + "position_error": position_error, + "step_count": self.step_count, + "reward": self.reward, + "valid": ( + True if np.random.rand() > 0.5 else False + ), # Make 50% of samples invalid + } + + self.step_count += 1 + self.prev_action = current_action_value + + # Get the final observation for the next step + obs = self._get_observation() + return obs, self.reward, bool(terminated), bool(truncated), info + + def predict(self, control_input): + """Kalman filter prediction step""" + # Control input matrix (acceleration affects 
position and velocity) + u = np.array([control_input]) + + # Predict state + self.x = self.F @ self.x + self.B * u + + # Predict covariance + self.P = self.F @ self.P @ self.F.T + self.Q + + def update(self, measurement, action): + """Kalman filter update step""" + # Update step - convert discrete action index to actual scaling value + action_scaling = self.discrete_actions[action] + R = self.R * action_scaling + + # Innovation + y = measurement - self.H @ self.x + + # Innovation covariance + S = self.H @ self.P @ self.H.T + R + + # Kalman gain + K = self.P @ self.H.T @ np.linalg.inv(S) + + # Updated state estimate + self.x += K @ y + + # Updated covariance + self.P = self.P - K @ self.H @ self.P + + self.measurement_history.append(measurement) + self.measurement_noise_history.append(self.R[0, 0]) + while len(self.measurement_history) > self.history_length: + self.measurement_history.pop(0) + self.measurement_noise_history.pop(0) + + return self.x, y, S + + def _get_observation(self): + """Get current observation for the agent""" + + measurement_history = np.array(self.measurement_history).flatten() + measurement_noise_history = np.array(self.measurement_noise_history).flatten() + obs = np.concatenate( + [ + [self.x[1]], # current velocity + [self.P[0, 0]], # position covariance + [self.P[1, 1]], # velocity covariance + [self.R[0, 0]], # measurement noise + measurement_history, + measurement_noise_history, + [self.measurement[self.step_count] - self.x[0]], # innovation + ], + axis=0, + dtype=np.float32, + ) + + return obs + + def render(self, mode="human"): + if mode == "human": + print( + f"Step: {self.step_count}, Position: {self.x[0]}, " + f"Target: {self.gt[self.step_count]}, Reward: {self.reward}" + ) + + +class TrainingCallback(BaseCallback): + """Custom callback for monitoring training progress""" + + def __init__(self, verbose=0): + super(TrainingCallback, self).__init__(verbose) + self.episode_rewards = [] + self.episode_lengths = [] + self.step_rewards = [] # Track rewards per step as fallback + self.episode_reward_sum = 0 + self.episode_length = 0 + self.last_dones = None + self.total_steps = 0 + + def _on_step(self) -> bool: + self.total_steps += 1 + + # Get dones and truncated to detect episode completions + dones = self.locals.get("dones", [False] * len(self.locals.get("infos", []))) + truncated = self.locals.get( + "truncated", [False] * len(self.locals.get("infos", [])) + ) + + # Since invalid samples are now filtered out, all samples are valid + if len(self.locals["infos"]) > 0: + for i, info in enumerate(self.locals["infos"]): + reward = info["reward"] + self.step_rewards.append(reward) + self.episode_reward_sum += reward + self.episode_length += 1 + + # Check if any episode completed (either done or truncated) + episode_completed = False + for _, (done, trunc) in enumerate(zip(dones, truncated)): + if done or trunc: + episode_completed = True + break + + if episode_completed and self.episode_length > 0: + self.episode_rewards.append(self.episode_reward_sum) + self.episode_lengths.append(self.episode_length) + if self.verbose > 0: + print( + f"Episode completed: reward={self.episode_reward_sum:.3f}, " + f"length={self.episode_length}" + ) + + # Reset for next episode + self.episode_reward_sum = 0 + self.episode_length = 0 + + return True + + +def generate_dataset( + n_points=1000, + t_max=1.0, + measurement_noise_std=0.1, + control_noise_std=0.05, + seed=None, +): + """Generate a random dataset for Kalman filter training.""" + if seed is not None: + np.random.seed(seed) + 
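+    # Frequencies, amplitudes, and phases are drawn at random below, so each seed
+    # yields a distinct but reproducible ground-truth trajectory.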
+ # Time vector + t = np.linspace(0, t_max, n_points) + + # Generate random parameters for different trajectories + # This ensures each dataset has a different ground truth signal + freq1 = np.random.uniform(1.0, 3.0) # Random frequency for primary + freq2 = np.random.uniform(2.0, 6.0) # Random frequency for secondary + freq3 = np.random.uniform(3.0, 8.0) # Random frequency for tertiary + + amp1 = np.random.uniform(1.0, 3.0) # Random amplitude for primary + amp2 = np.random.uniform(0.2, 1.0) # Random amplitude for secondary + amp3 = np.random.uniform(0.1, 0.5) # Random amplitude for tertiary + + quad_coeff = np.random.uniform(0.1, 0.8) # Random quadratic coefficient + phase1 = np.random.uniform(0, 2 * np.pi) # Random phase shifts + phase2 = np.random.uniform(0, 2 * np.pi) + phase3 = np.random.uniform(0, 2 * np.pi) + + # Generate a smooth ground truth signal using random parameters + ground_truth = ( + amp1 * np.sin(2 * np.pi * freq1 * t + phase1) # Primary oscillation + + amp2 * np.sin(2 * np.pi * freq2 * t + phase2) # Secondary + + quad_coeff * t**2 # Quadratic trend + + amp3 * np.cos(2 * np.pi * freq3 * t + phase3) # Additional complexity + ) + + # Compute the second derivative (acceleration) analytically + true_acceleration = ( + -4 + * np.pi**2 + * freq1**2 + * amp1 + * np.sin(2 * np.pi * freq1 * t + phase1) # Second derivative of primary + - -4 + * np.pi**2 + * freq2**2 + * amp2 + * np.sin(2 * np.pi * freq2 * t + phase2) # Second derivative of secondary + + 2 * quad_coeff # Second derivative of quadratic term + + -4 + * np.pi**2 + * freq3**2 + * amp3 + * np.cos(2 * np.pi * freq3 * t + phase3) # Second derivative of cosine + ) + + # Add varying measurement noise (zero-mean Gaussian with varying std) + time_varying_std = measurement_noise_std * (1 + 0.5 * t / t_max) + measurement_noise = np.random.normal(0, time_varying_std, n_points) + measurement = ground_truth + measurement_noise + + # Add varying control noise (zero-mean Gaussian with varying std) + control_varying_std = control_noise_std * ( + 1 + 0.2 * np.abs(true_acceleration) / np.max(np.abs(true_acceleration)) + ) + + control_noise = np.random.normal(0, control_varying_std, n_points) + control = true_acceleration + control_noise + + return measurement, control, ground_truth + + +def visualize_dataset(measurement, control, ground_truth, save_plot=False): + """Visualize the generated dataset.""" + fig, axes = plt.subplots(2, 1, figsize=(12, 8)) + + # Plot position signals + t = np.linspace(0, 1, len(ground_truth)) + axes[0].plot(t, ground_truth, "b-", linewidth=2, label="Ground Truth") + axes[0].plot(t, measurement, "r.", markersize=1, alpha=0.6, label="Measurement") + axes[0].set_xlabel("Time") + axes[0].set_ylabel("Position") + axes[0].set_title("Position Signal") + axes[0].legend() + axes[0].grid(True, alpha=0.3) + + # Plot control signals + axes[1].plot(t, control, "g-", linewidth=1, label="Control (Acceleration)") + axes[1].set_xlabel("Time") + axes[1].set_ylabel("Acceleration") + axes[1].set_title("Control Signal") + axes[1].legend() + axes[1].grid(True, alpha=0.3) + + plt.tight_layout() + + if save_plot: + plt.savefig("generated_dataset.png", dpi=300, bbox_inches="tight") + + plt.show() + + +def main(): + # Number of parallel environments + n_envs = 8 + measurement_noise_std = 0.1 + control_noise_std = 0.25 + scale = 1e3 + gen_scale = 1e-3 + + # Generate random datasets + datasets = [] + for i in range(n_envs + 1): + measurement_signal, acceleration_signal, position_signal = generate_dataset( + n_points=1000, + t_max=1.0, 
+ measurement_noise_std=gen_scale * measurement_noise_std, + control_noise_std=control_noise_std, + seed=42 + i, # Different seed for each dataset + ) + dataset = { + "control": acceleration_signal, + "measurement": measurement_signal, + "ground_truth": position_signal, + } + datasets.append(dataset) + + # Print some statistics to verify datasets are different + print( + f"Dataset {i}: GT range [{position_signal.min():.3f}, " + f"{position_signal.max():.3f}], GT std: {position_signal.std():.3f}" + ) + + # Only visualize the first dataset to avoid too many plots + # visualize_dataset(measurement_signal, acceleration_signal, position_signal) + + # Create vectorized environment + def make_env(dataset_idx): + """Helper function to create a single environment with specific dataset""" + dataset = datasets[dataset_idx] + base_env = KalmanFilterEnv( + measurement=dataset["measurement"], + u=dataset["control"], + gt=dataset["ground_truth"], + measurement_noise=scale * measurement_noise_std, + process_noise=control_noise_std, + ) + return base_env + + # Create vectorized environment with different datasets for each environment + env = DummyVecEnv([lambda i=i: make_env(i) for i in range(n_envs)]) + env = VecNormalize(env, norm_obs=True, norm_reward=False) + + # For testing, create a single environment using the last dataset + test_env = KalmanFilterEnv( + measurement=datasets[-1]["measurement"], + u=datasets[-1]["control"], + gt=datasets[-1]["ground_truth"], + measurement_noise=scale * measurement_noise_std, + process_noise=control_noise_std, + ) + + baseline_env = KalmanFilterEnv( + measurement=datasets[-1]["measurement"], + u=datasets[-1]["control"], + gt=datasets[-1]["ground_truth"], + measurement_noise=scale * measurement_noise_std, + process_noise=control_noise_std, + ) + + # Check environment + check_env(test_env) + print("Environment check passed!") + print(f"Training with {n_envs} parallel environments") + + model = PreStepDQN( + "MlpPolicy", + env, + device="cpu", + verbose=1, + learning_rate=linear_schedule(3e-4, 1e-5), + buffer_size=1000000, + learning_starts=1000, + batch_size=64, + gamma=0.99, + train_freq=(4, "step"), + gradient_steps=1, + target_update_interval=1000, + exploration_fraction=0.1, + exploration_initial_eps=1.0, + exploration_final_eps=0.05, + policy_kwargs=dict( + net_arch=[512, 512, 256, 128], + activation_fn=nn.ReLU, + ), + ) + + # Train the model + print("Starting DQN training...") + training_callback = TrainingCallback() + model.learn(total_timesteps=100000, callback=training_callback) + + stats = None + try: + # Extract the observation normalization statistics + stats = { + "obs_mean": env.obs_rms.mean, + "obs_var": env.obs_rms.var, + "obs_count": env.obs_rms.count, + } + print("Observation normalization stats:") + print(f" Mean: {stats['obs_mean']}") + print(f" Variance: {stats['obs_var']}") + print(f" Count: {stats['obs_count']}") + except Exception as e: + print(f"Error extracting observation normalization stats: {e}") + + # Save the model + # model.save("kalman_dqn_model") + # print("Model saved!") + + # Test the trained model + print("\nTesting trained model...") + obs, _ = test_env.reset() + baseline_env.reset() + + episode_rewards = [] + positions = [] + positions_baseline = [] + + for step in range(len(test_env.gt)): + # Call get_observation_for_action manually for testing + obs = test_env.get_observation_for_action() + if stats is not None: + obs = (obs - stats["obs_mean"]) / np.sqrt(stats["obs_var"]) + action, _ = model.predict(obs, deterministic=True) + 
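+        # predict() returns an index into discrete_actions; the environment maps it
+        # to the actual noise scaling inside step()/update().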
print(f"step {step} action: {test_env.discrete_actions[action]}") + obs, reward, terminated, truncated, _ = test_env.step(action) + + # Run the baseline + baseline_env.get_observation_for_action() + baseline_env.step(np.where(test_env.discrete_actions == 1.0)[0][0]) + if len(episode_rewards) == 0: + episode_rewards.append(reward) + else: + episode_rewards.append(episode_rewards[-1] + reward) + positions.append(test_env.x[0]) + positions_baseline.append(baseline_env.x[0]) + if step % 20 == 0: + test_env.render() + + if terminated or truncated: + break + + print(f"\nTest episode reward: {episode_rewards[-1]:.2f}") + + # Plot results + plt.figure(figsize=(12, 4)) + plt.subplot(1, 2, 1) + plt.plot(positions, color="b", label="agent") + plt.plot(test_env.gt, color="r", linestyle="--", label="gt") + plt.plot(positions_baseline, color="g", linestyle="--", label="baseline") + plt.xlabel("Time Step") + plt.ylabel("Position") + plt.title("Agent Position Over Time") + plt.legend() + plt.grid(True) + plt.subplot(1, 2, 2) + + # Debug information + print("Training callback stats:") + print(f"Episode rewards collected: {len(training_callback.episode_rewards)}") + print(f"Step rewards collected: {len(training_callback.step_rewards)}") + + # Plot the step rewards + step_rewards = training_callback.step_rewards + step_rewards_avg = compute_rolling_average(step_rewards, 100) + plt.plot(step_rewards_avg, label="Average Rewards", alpha=1.0, color="blue") + plt.plot( + step_rewards, + label="Rewards", + alpha=0.35, + color="lightblue", + ) + plt.xlabel("Samples") + plt.ylabel("Normalized Rewards") + plt.title("Training Progress") + plt.legend() + plt.grid(True) + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + main() diff --git a/python/serow/serow_bindings.cpp b/python/serow/serow_bindings.cpp index ac39c9e..a6d9783 100644 --- a/python/serow/serow_bindings.cpp +++ b/python/serow/serow_bindings.cpp @@ -643,7 +643,7 @@ PYBIND11_MODULE(serow, m) { .def(py::init<>(), "Default constructor") .def("init", &serow::ContactEKF::init, py::arg("state"), py::arg("contacts_frame"), py::arg("point_feet"), py::arg("g"), py::arg("imu_rate"), - py::arg("outlier_detection") = false, + py::arg("outlier_detection") = false, py::arg("use_imu_orientation") = false, py::arg("verbose") = false, "Initializes the EKF with the initial robot state and parameters") .def("predict", &serow::ContactEKF::predict, py::arg("state"), py::arg("imu"), py::arg("kin"), @@ -720,7 +720,67 @@ PYBIND11_MODULE(serow, m) { .def("get_state", &serow::Serow::getState, py::arg("allow_invalid") = false, "Gets the complete state of the robot") .def("is_initialized", &serow::Serow::isInitialized, "Returns true if SEROW is initialized") - .def("set_state", &serow::Serow::setState, py::arg("state"), "Sets the state of the robot"); + .def("set_state", &serow::Serow::setState, py::arg("state"), "Sets the state of the robot") + .def( + "get_contact_position_innovation", + [](serow::Serow& self, const std::string& contact_frame) { + Eigen::Vector3d innovation = Eigen::Vector3d::Zero(); + Eigen::Matrix3d covariance = Eigen::Matrix3d::Zero(); + bool success = + self.getContactPositionInnovation(contact_frame, innovation, covariance); + return std::make_tuple(success, innovation, covariance); + }, + py::arg("contact_frame"), + "Returns the contact position innovation and covariance for a given contact frame") + .def( + "get_contact_orientation_innovation", + [](serow::Serow& self, const std::string& contact_frame) { + Eigen::Vector3d innovation = 
Eigen::Vector3d::Zero(); + Eigen::Matrix3d covariance = Eigen::Matrix3d::Zero(); + bool success = + self.getContactOrientationInnovation(contact_frame, innovation, covariance); + return std::make_tuple(success, innovation, covariance); + }, + py::arg("contact_frame"), + "Returns the contact orientation innovation and covariance for a given contact frame") + .def( + "process_measurements", + [](serow::Serow& self, const serow::ImuMeasurement& imu, + const std::map& joints, + py::object force_torque, py::object contacts_probability) { + std::optional> ft_opt; + if (!force_torque.is_none()) { + ft_opt = + force_torque.cast>(); + } + + std::optional> contact_prob_opt; + if (!contacts_probability.is_none()) { + contact_prob_opt = + contacts_probability + .cast>(); + } + + return self.processMeasurements(imu, joints, ft_opt, contact_prob_opt); + }, + py::arg("imu"), py::arg("joints"), py::arg("force_torque") = py::none(), + py::arg("contacts_probability") = py::none(), + "Processes the measurements and returns a tuple of IMU, kinematic, and force-torque " + "measurements") + .def("base_estimator_predict_step", &serow::Serow::baseEstimatorPredictStep, py::arg("imu"), + py::arg("kin"), "Runs the base estimator's predict step") + .def("base_estimator_update_with_contact_position", + &serow::Serow::baseEstimatorUpdateWithContactPosition, py::arg("contact_frame"), + py::arg("kin"), "Runs the base estimator's update step with contact position") + .def("base_estimator_finish_update", &serow::Serow::baseEstimatorFinishUpdate, + py::arg("imu"), py::arg("kin"), + "Concludes the base estimator's update step with the IMU measurement") + .def("base_estimator_update_with_imu_orientation", + &serow::Serow::baseEstimatorUpdateWithImuOrientation, py::arg("imu"), + "Runs the base estimator's update step with the IMU orientation") + .def("reset", &serow::Serow::reset, "Resets the state of SEROW") + .def("set_action", &serow::Serow::setAction, py::arg("cf"), py::arg("action"), + "Sets the action of the robot"); // Binding for CentroidalState py::class_(m, "CentroidalState", diff --git a/python/serow/train.py b/python/serow/train.py new file mode 100644 index 0000000..3e9574d --- /dev/null +++ b/python/serow/train.py @@ -0,0 +1,532 @@ +import numpy as np +import gymnasium as gym +import matplotlib + +matplotlib.use("TkAgg") # Use Agg backend for non-GUI environments +import matplotlib.pyplot as plt +import json +import os +import torch +import torch.nn as nn +import pandas as pd + +from env import SerowEnv +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3 import DQN +from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize +from stable_baselines3.common.callbacks import CallbackList +from utils import export_dqn_to_onnx +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.type_aliases import RolloutReturn + + +def compute_rolling_average(data, window_size): + """Helper to compute rolling average, padding the start.""" + if data.size == 0: + return [] + series = pd.Series(data) + # Use .rolling().mean() with min_periods to start from the first data point + rolling_avg = series.rolling(window=window_size, min_periods=1).mean() + return rolling_avg.tolist() + + +def linear_schedule(initial_value, final_value): + """Linear learning rate schedule.""" + + def schedule(progress_remaining): + return final_value + progress_remaining * (initial_value - final_value) + + return schedule + + +class ValidSampleCallback(BaseCallback): + 
"""Callback to track and handle valid/invalid samples during training.""" + + def __init__(self, verbose=0): + super(ValidSampleCallback, self).__init__(verbose) + self.valid_samples_count = 0 + self.total_samples_count = 0 + self.invalid_samples_count = 0 + + def _on_step(self) -> bool: + # Count valid vs invalid samples + infos = self.locals.get("infos", []) + for info in infos: + self.total_samples_count += 1 + if info.get("valid", True): + self.valid_samples_count += 1 + else: + self.invalid_samples_count += 1 + + # Log statistics periodically + if self.total_samples_count % 100 == 0: + valid_ratio = ( + self.valid_samples_count / self.total_samples_count + if self.total_samples_count > 0 + else 0.0 + ) + if self.verbose > 0: + print( + f"Sample validity: {valid_ratio:.2%} valid " + f"({self.valid_samples_count}/{self.total_samples_count})" + ) + + return True + + +class TrainingCallback(BaseCallback): + """Custom callback for monitoring training progress""" + + def __init__(self, verbose=0): + super(TrainingCallback, self).__init__(verbose) + self.episode_rewards = [] + self.episode_lengths = [] + self.step_rewards = [] # Track rewards per step as fallback + self.episode_reward_sum = 0 + self.episode_length = 0 + self.last_dones = None + self.total_steps = 0 + + def _on_step(self) -> bool: + self.total_steps += 1 + # Get dones and truncated to detect episode completions + infos = self.locals.get("infos", []) + dones = self.locals.get("dones", [False] * len(infos)) + truncated = self.locals.get("truncated", [False] * len(infos)) + + # Only accumulate episode rewards for valid steps + valid_steps = 0 + total_reward = 0.0 + episode_completed = False + if len(self.locals["infos"]) > 0: + for i, info in enumerate(self.locals["infos"]): + if info["valid"]: + reward = info["reward"] + valid_steps += 1 + total_reward += reward + self.step_rewards.append(reward) + + # Only add to episode if there were valid steps + if valid_steps > 0: + avg_valid_reward = total_reward / valid_steps + self.episode_reward_sum += avg_valid_reward + self.episode_length += 1 + + # Check if any episode completed (either done or truncated) + for _, (done, trunc) in enumerate(zip(dones, truncated)): + if done or trunc: + episode_completed = True + break + + if episode_completed and self.episode_length > 0: + self.episode_rewards.append(self.episode_reward_sum) + self.episode_lengths.append(self.episode_length) + if self.verbose > 0: + reward_str = f"reward={self.episode_reward_sum:.3f}" + length_str = f"length={self.episode_length}" + print(f"Episode completed: {reward_str}, {length_str}") + + # Reset for next episode + self.episode_reward_sum = 0 + self.episode_length = 0 + + return True + + +class PreStepDQN(DQN): + def collect_rollouts( + self, + env, + callback, + train_freq, + replay_buffer: ReplayBuffer, + action_noise=None, + learning_starts: int = 0, + log_interval=None, + ) -> RolloutReturn: + """ + Custom rollout collection: + - Always get observation from env.get_observation_for_action() + - Only use valid transitions where info['valid'] == True + """ + # Switch to eval mode to avoid dropout/batchnorm training + self.policy.set_training_mode(False) + n_steps = 0 + total_rewards = [] + completed_episodes = 0 + + # Reset buffer for new rollout + callback.on_rollout_start() + + while n_steps < train_freq[0]: + # 1. 
Get obs for action selection + obs_for_action = [] + if hasattr(env, "envs"): + # Vectorized env + for e in env.envs: + obs_for_action.append(e.get_observation_for_action()) + obs_for_action = np.array(obs_for_action) + else: + obs_for_action = np.array([env.get_observation_for_action()]) + + self._last_obs = obs_for_action + + # 2. Predict action + actions, buffer_actions = self._sample_action( + learning_starts, action_noise, env.num_envs + ) + + # 3. Step environment + new_obs, rewards, dones, infos = env.step(actions) + + # 4. Nullify invalid samples - Set rewards to nan so we can filter them out in train() + for idx, info in enumerate(infos): + if not info.get("valid", True): + rewards[idx] = np.nan + new_obs[idx] = np.zeros_like(new_obs[idx]) + self._last_obs[idx] = np.zeros_like(self._last_obs[idx]) + buffer_actions[idx] = np.zeros_like(buffer_actions[idx]) + + replay_buffer.add( + self._last_obs, + new_obs, + buffer_actions, + rewards, + dones, + infos, + ) + self._update_info_buffer(infos, dones) + + # 5. Update counters + n_steps += 1 + self.num_timesteps += env.num_envs + total_rewards.extend(rewards) + + # Count completed episodes + completed_episodes += sum(dones) + + # 6. Handle episode ends + callback.update_locals(locals()) + if not callback.on_step(): + return RolloutReturn( + episode_timesteps=n_steps, + n_episodes=completed_episodes, + continue_training=False, + ) + + callback.on_rollout_end() + + return RolloutReturn( + episode_timesteps=n_steps, + n_episodes=completed_episodes, + continue_training=True, + ) + + def train(self, gradient_steps: int, batch_size: int = 100) -> None: + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + + # Update learning rate according to schedule + self._update_learning_rate(self.policy.optimizer) + + losses = [] + for _ in range(gradient_steps): + # Sample replay buffer + replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr] + # Filter out invalid samples + valid_mask = ~torch.isnan(replay_data.rewards.flatten()) + num_valid = valid_mask.sum() + + # Skip if too few valid samples (less than 25% of batch) + min_valid_samples = max(1, batch_size // 4) + if num_valid < min_valid_samples: + self.logger.record("train/skipped_batches", 1, exclude="tensorboard") + continue + + # Create filtered data instead of modifying the original object + filtered_observations = replay_data.observations[valid_mask] + filtered_next_observations = replay_data.next_observations[valid_mask] + filtered_actions = replay_data.actions[valid_mask] + filtered_rewards = replay_data.rewards[valid_mask] + filtered_dones = replay_data.dones[valid_mask] + filtered_discounts = ( + replay_data.discounts[valid_mask] + if replay_data.discounts is not None + else None + ) + + # For n-step replay, discount factor is gamma**n_steps (when no early termination) + discounts = ( + filtered_discounts if filtered_discounts is not None else self.gamma + ) + + with torch.no_grad(): + # Compute the next Q-values using the target network + next_q_values = self.q_net_target(filtered_next_observations) + # Follow greedy policy: use the one with the highest value + next_q_values, _ = next_q_values.max(dim=1) + # Avoid potential broadcast issue + next_q_values = next_q_values.reshape(-1, 1) + # 1-step TD target + target_q_values = ( + filtered_rewards + (1 - filtered_dones) * discounts * next_q_values + ) + + # Get current Q-values estimates + current_q_values = 
self.q_net(filtered_observations) + + # Retrieve the q-values for the actions from the replay buffer + current_q_values = torch.gather( + current_q_values, dim=1, index=filtered_actions.long() + ) + + # Compute Huber loss (less sensitive to outliers) + loss = torch.nn.functional.smooth_l1_loss(current_q_values, target_q_values) + losses.append(loss.item()) + + # Optimize the policy + self.policy.optimizer.zero_grad() + loss.backward() + + # Clip gradient norm + torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() + + # Increase update counter + self._n_updates += gradient_steps + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + self.logger.record("train/loss", np.mean(losses)) + self.logger.dump(step=self.num_timesteps) + + def forward(self, obs, deterministic=False): + return self.policy.forward(obs, deterministic) + + +if __name__ == "__main__": + # Load and preprocess the data + robot = "go2" + n_envs = 4 + n_contacts = 3 + total_samples = 100000 + device = "cpu" + history_size = 100 + datasets = [] + for i in range(n_envs): + dataset = np.load(f"datasets/{robot}_log_{i}.npz", allow_pickle=True) + datasets.append(dataset) + + test_dataset = np.load(f"{robot}_log.npz", allow_pickle=True) + contact_states = test_dataset["contact_states"] + contact_frame = list(contact_states[0].contacts_status.keys()) + print(f"Contact frames: {contact_frame}") + + state_dim = 3 + 9 + 3 + 4 + 3 * history_size + 1 * history_size + print(f"State dimension: {state_dim}") + action_dim = 1 # Based on the action vector used in ContactEKF.setAction() + + # Create vectorized environment + def make_env(i, j): + """Helper function to create a single environment with specific dataset""" + ds = datasets[i] + base_env = SerowEnv( + contact_frame[ + np.random.randint(0, len(contact_frame)) + ], # random choice of contact frame + robot, + ds["joint_states"][0], + ds["base_states"][0], + ds["contact_states"][0], + action_dim, + state_dim, + ds["imu"], + ds["joints"], + ds["ft"], + ds["base_pose_ground_truth"], + history_size, + ) + return base_env + + test_env = SerowEnv( + contact_frame[0], + robot, + test_dataset["joint_states"][0], + test_dataset["base_states"][0], + test_dataset["contact_states"][0], + action_dim, + state_dim, + test_dataset["imu"], + test_dataset["joints"], + test_dataset["ft"], + test_dataset["base_pose_ground_truth"], + history_size, + ) + + # Create vectorized environment with different datasets for each environment + env = DummyVecEnv( + [ + lambda i=i, j=j: make_env(i, j) + for i in range(n_envs) + for j in range(n_contacts) + ] + ) + + # Add normalization for observations and rewards + env = VecNormalize(env, norm_obs=True, norm_reward=False) + + lr_schedule = linear_schedule(5e-4, 1e-5) + model = PreStepDQN( + "MlpPolicy", + env, + device="cpu", + verbose=1, + learning_rate=lr_schedule, + buffer_size=1000000, + learning_starts=5000, + batch_size=128, + gamma=0.99, + train_freq=(8, "step"), + gradient_steps=4, + target_update_interval=2000, + exploration_fraction=0.2, + exploration_initial_eps=0.9, + exploration_final_eps=0.02, + policy_kwargs=dict( + net_arch=[1024, 1024, 512, 256, 128], + activation_fn=nn.ReLU, + ), + max_grad_norm=10.0, + tau=0.005, + ) + + # Create callbacks + training_callback = TrainingCallback(verbose=1) + valid_sample_callback = ValidSampleCallback(verbose=1) + callback = CallbackList( + [ + training_callback, + valid_sample_callback, + ] + ) + + # Train the model + print(f"Training 
with {n_envs * n_contacts} parallel environments") + print("Starting training...") + model.learn(total_timesteps=total_samples, callback=callback) + print("Training completed") + + stats = None + try: + # Extract the observation normalization statistics + obs_mean = env.obs_rms.mean + obs_var = env.obs_rms.var + obs_count = env.obs_rms.count + stats = { + "obs_mean": obs_mean, + "obs_var": obs_var, + "obs_count": obs_count, + } + print("Observation normalization stats:") + print(f" Mean: {stats['obs_mean']}") + print(f" Variance: {stats['obs_var']}") + print(f" Count: {stats['obs_count']}") + + # Convert numpy arrays to lists for JSON serialization + json_stats = { + "obs_mean": stats["obs_mean"].tolist(), + "obs_var": stats["obs_var"].tolist(), + "obs_count": int(stats["obs_count"]), + } + stats_file = f"models/{robot}_stats.json" + with open(stats_file, "w") as f: + json.dump(json_stats, f, indent=2) + except Exception as e: + print(f"Error saving stats: {e}") + + # Check if the models directory exists, if not create it + if not os.path.exists("models"): + os.makedirs("models") + model.save(f"models/{robot}_dqn") + + try: + # Create a wrapper class to match the expected interface for export_dqn_to_onnx + class DQNModelWrapper: + def __init__(self, dqn_model, device): + self.device = device + # For DQN, the policy is accessed via model.policy + self.policy = dqn_model.policy + self.name = "DQN" + + # Validate that the policy has the expected structure + if not hasattr(self.policy, "q_net"): + raise AttributeError("DQN policy must have a 'q_net' attribute") + + # Check if we have the expected policy type + policy_type = type(self.policy).__name__ + if "DQNPolicy" not in policy_type: + print(f"Warning: Expected DQNPolicy, got {policy_type}") + + # Validate the q_net structure + q_net_type = type(self.policy.q_net).__name__ + print(f"Q-network type: {q_net_type}") + + # Check if q_net has the expected forward method + if not hasattr(self.policy.q_net, "forward"): + raise AttributeError("Q-network must have a 'forward' method") + + # Create the wrapper - pass the entire model, not just model.policy + model_wrapper = DQNModelWrapper(model, device) + + # Define parameters for ONNX export + export_params = { + "state_dim": state_dim, + "action_dim": action_dim, + } + + # Export to ONNX + print("Exporting model to ONNX...") + print(f"Policy type: {type(model_wrapper.policy)}") + print(f"Q-net type: {type(model_wrapper.policy.q_net)}") + export_dqn_to_onnx(model_wrapper, robot, export_params, "models") + print("ONNX export completed successfully!") + + except Exception as e: + print(f"Error exporting model to ONNX: {e}") + import traceback + + traceback.print_exc() + + # Debug information + print("Training callback stats:") + episode_count = len(training_callback.episode_rewards) + print(f"Episode rewards collected: {episode_count}") + print(f"Step rewards collected: {len(training_callback.step_rewards)}") + valid_count = valid_sample_callback.valid_samples_count + total_count = valid_sample_callback.total_samples_count + valid_ratio = valid_count / total_count + ratio_str = f"Valid sample ratio: {valid_count}/{total_count} ({valid_ratio:.2%})" + print(ratio_str) + + # Plot the step rewards + step_rewards = np.array(training_callback.step_rewards) + # Normalize step rewards to 0-1 + step_rewards_norm = (step_rewards - np.min(step_rewards)) / ( + np.max(step_rewards) - np.min(step_rewards) + ).tolist() + step_rewards_avg = compute_rolling_average(step_rewards_norm, 100) + plt.plot(step_rewards_avg, 
label="Average Rewards", alpha=1.0, color="blue") + plt.plot( + step_rewards_norm, + label="Rewards", + alpha=0.35, + color="lightblue", + ) + plt.xlabel("Samples") + plt.ylabel("Normalized Rewards") + plt.title("Training Progress") + plt.legend() + plt.grid(True) + plt.tight_layout() + plt.show() + + test_env.evaluate(model, stats) diff --git a/python/serow/utils.py b/python/serow/utils.py index 5651d89..3923458 100644 --- a/python/serow/utils.py +++ b/python/serow/utils.py @@ -2,8 +2,9 @@ import numpy as np import matplotlib.pyplot as plt -from mpl_toolkits.mplot3d import Axes3D import os +import torch +from mpl_toolkits.mplot3d import Axes3D def rotation_matrix_to_quaternion(R): @@ -512,3 +513,72 @@ def plot_contact_forces_and_torques(contact_states): plt.tight_layout() plt.show() + + +def export_model_to_onnx(agent, robot, params, path): + """Export the trained models to ONNX format""" + os.makedirs(path, exist_ok=True) + + # Set models to evaluation mode to disable dropout + agent.policy.eval() + + # Create dummy input and ensure no gradients + dummy_observation = torch.randn(1, params["state_dim"]).to(agent.device) + + with torch.no_grad(): + torch.onnx.export( + agent.policy, + dummy_observation, + f"{path}/{robot}_ppo.onnx", + export_params=True, + opset_version=11, + do_constant_folding=True, + input_names=["observation"], + output_names=["action", "value"], + dynamic_axes={ + "observation": {0: "batch_size"}, + "action": {0: "batch_size"}, + "value": {0: "batch_size"}, + }, + verbose=False, + ) + + +def export_dqn_to_onnx(agent, robot, params, path): + """Export the trained DQN model to ONNX format""" + os.makedirs(path, exist_ok=True) + + # For DQN, we need to export the q_net (the actual neural network) + # Set the q_net to evaluation mode to disable dropout + agent.policy.q_net.eval() + + # Create dummy input and ensure no gradients + # Use the same device as the model + device = next(agent.policy.q_net.parameters()).device + dummy_observation = torch.randn(1, params["state_dim"]).to(device) + + with torch.no_grad(): + torch.onnx.export( + agent.policy.q_net, # Export the q_net, not the policy wrapper + dummy_observation, + f"{path}/{robot}_dqn.onnx", + export_params=True, + opset_version=11, + do_constant_folding=True, + input_names=["observation"], + output_names=["q_values"], # DQN outputs Q-values, not actions + dynamic_axes={ + "observation": {0: "batch_size"}, + "q_values": {0: "batch_size"}, + }, + verbose=False, + ) + + print(f"ONNX model exported to: {path}/{robot}_dqn.onnx") + + +class BaseVelocityGroundTruth: + def __init__(self, timestamp, linear_velocity, angular_velocity): + self.timestamp = timestamp + self.linear_velocity = linear_velocity + self.angular_velocity = angular_velocity
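As a rough usage sketch (not part of the patch): the exported ONNX Q-network and the saved normalization statistics could be consumed at inference time roughly as follows. The model/stats file names and the "observation" input name follow the export code above; the normalization formula and constants (epsilon 1e-8, observation clipping at ±10) are the stable-baselines3 VecNormalize defaults, and onnxruntime is assumed to be available.

# Hypothetical inference sketch: load the exported Q-network, normalize an
# observation the way VecNormalize would, and take the greedy action.
import json

import numpy as np
import onnxruntime as ort

with open("models/go2_stats.json") as f:
    stats = json.load(f)
mean = np.asarray(stats["obs_mean"], dtype=np.float32)
var = np.asarray(stats["obs_var"], dtype=np.float32)

session = ort.InferenceSession("models/go2_dqn.onnx")
obs = np.zeros((1, mean.shape[0]), dtype=np.float32)  # placeholder observation
obs_norm = np.clip((obs - mean) / np.sqrt(var + 1e-8), -10.0, 10.0)
q_values = session.run(None, {"observation": obs_norm.astype(np.float32)})[0]
action = int(np.argmax(q_values, axis=1)[0])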