1+ /* *
2+ * \file codac2_OctaSym_operator.h
3+ * ----------------------------------------------------------------------------
4+ * \date 2025
5+ * \author Simon Rohou
6+ * \copyright Copyright 2025 Codac Team
7+ * \license GNU Lesser General Public License (LGPL)
8+ */
9+
10+ #pragma once
11+
12+ #include " codac2_OctaSym.h"
13+
14+ namespace codac2
15+ {
16+ struct OctaSymOp
17+ {
18+ template <typename X1>
19+ static inline std::string str (const X1& x1)
20+ {
21+ return " sym(" + x1->str () + " )" ;
22+ }
23+
24+ template <typename X1>
25+ static inline std::pair<Index,Index> output_shape ([[maybe_unused]] const X1& s1)
26+ {
27+ return s1->output_shape ();
28+ }
29+
30+ static inline IntervalVector fwd (const OctaSym& s, const IntervalVector& x1)
31+ {
32+ assert ((Index)s.size () == x1.size ());
33+ return s (x1);
34+ }
35+
36+ static inline VectorType fwd_natural (const OctaSym& s, const VectorType& x1)
37+ {
38+ assert ((Index)s.size () == x1.m .size ());
39+ return {
40+ fwd (s, x1.a ),
41+ x1.def_domain
42+ };
43+ }
44+
45+ static inline VectorType fwd_centered (const OctaSym& s, const VectorType& x1)
46+ {
47+ assert ((Index)s.size () == x1.m .size ());
48+
49+ auto da = x1.da ;
50+ for (size_t i = 0 ; i < s.size () ; i++)
51+ da.row (i) = sign (s[i])*x1.da .row (std::abs (s[i])-1 );
52+
53+ return {
54+ fwd (s, x1.m ),
55+ fwd (s, x1.a ),
56+ da,
57+ x1.def_domain
58+ };
59+ }
60+
61+ static inline void bwd (const OctaSym& s, const IntervalVector& y, IntervalVector& x1)
62+ {
63+ assert ((Index)s.size () == y.size () && (Index)s.size () == x1.size ());
64+ x1 &= s.invert ()(y);
65+ }
66+ };
67+
68+
69+ template <>
70+ class AnalyticOperationExpr <OctaSymOp,VectorType,VectorType>
71+ : public AnalyticExpr<VectorType>, public OperationExprBase<AnalyticExpr<VectorType>>
72+ {
73+ public:
74+
75+ AnalyticOperationExpr (const OctaSym& s, const VectorExpr& x1)
76+ : OperationExprBase<AnalyticExpr<VectorType>>(x1), _s(s)
77+ { }
78+
79+ std::shared_ptr<ExprBase> copy () const
80+ {
81+ return std::make_shared<AnalyticOperationExpr<OctaSymOp,VectorType,VectorType>>(*this );
82+ }
83+
84+ void replace_arg (const ExprID& old_arg_id, const std::shared_ptr<ExprBase>& new_expr)
85+ {
86+ return OperationExprBase<AnalyticExpr<VectorType>>::replace_arg (old_arg_id, new_expr);
87+ }
88+
89+ VectorType fwd_eval (ValuesMap& v, Index total_input_size, bool natural_eval) const
90+ {
91+ if (natural_eval)
92+ return AnalyticExpr<VectorType>::init_value (
93+ v, OctaSymOp::fwd_natural (_s, std::get<0 >(this ->_x )->fwd_eval (v, total_input_size, natural_eval)));
94+ else
95+ return AnalyticExpr<VectorType>::init_value (
96+ v, OctaSymOp::fwd_centered (_s, std::get<0 >(this ->_x )->fwd_eval (v, total_input_size, natural_eval)));
97+ }
98+
99+ void bwd_eval (ValuesMap& v) const
100+ {
101+ OctaSymOp::bwd (_s, AnalyticExpr<VectorType>::value (v).a , std::get<0 >(this ->_x )->value (v).a );
102+ std::get<0 >(this ->_x )->bwd_eval (v);
103+ }
104+
105+ std::pair<Index,Index> output_shape () const {
106+ return { _s.size (), 1 };
107+ }
108+
109+ virtual bool belongs_to_args_list (const FunctionArgsList& args) const
110+ {
111+ return std::get<0 >(this ->_x )->belongs_to_args_list (args);
112+ }
113+
114+ std::string str (bool in_parentheses = false ) const
115+ {
116+ std::string s = " S" ; // user cannot (yet) specify a name for the symmetry
117+ return in_parentheses ? " (" + s + " )" : s;
118+ }
119+
120+ virtual bool is_str_leaf () const
121+ {
122+ return true ;
123+ }
124+
125+ protected:
126+
127+ const OctaSym _s;
128+ };
129+ }
0 commit comments