According to the documentation: "The flattening order (i.e. the order of elements in the output list) is deterministic, corresponding to a left-to-right depth-first tree traversal." For example, with a plain `dict`:

```python
import jax

d = dict()
d['b'] = dict([('U_amp', 0.5), ('U_phase', 0.1), ('UA_amp', 0.3), ('UA_phase', 0.2)])
d['a'] = dict([('U_amp', 0.6), ('U_phase', 0.2), ('UA_amp', 0.4), ('UA_phase', 0.3)])
d['c'] = dict([('U_amp', 0.7), ('U_phase', 0.3), ('UA_amp', 0.5), ('UA_phase', 0.4)])
jax.tree_util.tree_flatten(d)
```

the result is sorted by keys. However, if I use `OrderedDict`:

```python
from collections import OrderedDict

import jax

d = OrderedDict()
d['b'] = OrderedDict([('U_amp', 0.5), ('U_phase', 0.1), ('UA_amp', 0.3), ('UA_phase', 0.2)])
d['a'] = OrderedDict([('U_amp', 0.6), ('U_phase', 0.2), ('UA_amp', 0.4), ('UA_phase', 0.3)])
d['c'] = OrderedDict([('U_amp', 0.7), ('U_phase', 0.3), ('UA_amp', 0.5), ('UA_phase', 0.4)])
jax.tree_util.tree_flatten(d)
```

the insertion order is preserved. Is this a feature, or just a happy accident that I shouldn't rely on? Thanks!
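For reference, here is a self-contained sketch of the two cases above that prints the leaf order of each, assuming a recent JAX version. The `rows`, `plain` and `ordered` names are hypothetical helpers introduced only to build both variants from the same data. Note that with a plain dict the inner keys are sorted too, and `'UA_amp'` sorts before `'U_amp'` because `'A'` comes before `'_'` in ASCII.

```python
from collections import OrderedDict

import jax

# Hypothetical helper: the question's data, listed in the original insertion order.
rows = {
    'b': [('U_amp', 0.5), ('U_phase', 0.1), ('UA_amp', 0.3), ('UA_phase', 0.2)],
    'a': [('U_amp', 0.6), ('U_phase', 0.2), ('UA_amp', 0.4), ('UA_phase', 0.3)],
    'c': [('U_amp', 0.7), ('U_phase', 0.3), ('UA_amp', 0.5), ('UA_phase', 0.4)],
}
plain = {k: dict(v) for k, v in rows.items()}
ordered = OrderedDict((k, OrderedDict(v)) for k, v in rows.items())

plain_leaves, _ = jax.tree_util.tree_flatten(plain)
ordered_leaves, _ = jax.tree_util.tree_flatten(ordered)

# dict: outer keys a, b, c; inner keys UA_amp, UA_phase, U_amp, U_phase
print(plain_leaves)    # [0.4, 0.3, 0.6, 0.2, 0.3, 0.2, 0.5, 0.1, 0.5, 0.4, 0.7, 0.3]
# OrderedDict: outer keys b, a, c; inner keys in insertion order
print(ordered_leaves)  # [0.5, 0.1, 0.3, 0.2, 0.6, 0.2, 0.4, 0.3, 0.7, 0.3, 0.5, 0.4]
```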
This is working as expected: dicts need a deterministic order during flattening so that equality of the flattened tree structure matches equality of the original object. In the case of `OrderedDict`, that deterministic order is inherent to the object (its insertion order); in the case of a standard `dict`, it comes from sorting the keys. Does that make sense?
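A minimal sketch of the equality argument, assuming a recent JAX version: two plain dicts with the same items compare equal regardless of insertion order, and sorting keys gives them identical leaves and treedefs. `OrderedDict` equality is order-sensitive, so keeping insertion order makes the treedefs differ exactly when the objects differ. The variable names are illustrative only.

```python
from collections import OrderedDict

import jax

# Plain dicts: equal regardless of insertion order...
d1 = {'a': 1, 'b': 2}
d2 = {'b': 2, 'a': 1}
leaves1, tree1 = jax.tree_util.tree_flatten(d1)
leaves2, tree2 = jax.tree_util.tree_flatten(d2)
# ...and because keys are sorted, leaves and treedefs match too.
print(d1 == d2, leaves1 == leaves2, tree1 == tree2)  # True True True

# OrderedDicts: equality depends on order, and so does the flattened form.
o1 = OrderedDict([('a', 1), ('b', 2)])
o2 = OrderedDict([('b', 2), ('a', 1)])
oleaves1, otree1 = jax.tree_util.tree_flatten(o1)
oleaves2, otree2 = jax.tree_util.tree_flatten(o2)
print(o1 == o2, oleaves1 == oleaves2, otree1 == otree2)  # False False False
```

So relying on insertion order is reasonable for `OrderedDict`, while for a plain `dict` you should expect sorted-key order.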