@@ -83,28 +83,70 @@ def cat(*sparse_tensors):
8383 >>> sout2 = ME.cat(sin, sin2, sout) # Can concatenate multiple sparse tensors
8484
8585 """
86- for s in sparse_tensors :
87- assert isinstance (s , SparseTensor ), "Inputs must be sparse tensors."
88- coordinate_manager = sparse_tensors [0 ].coordinate_manager
89- coordinate_map_key = sparse_tensors [0 ].coordinate_map_key
90- for s in sparse_tensors :
91- assert (
92- coordinate_manager == s .coordinate_manager
93- ), COORDINATE_MANAGER_DIFFERENT_ERROR
94- assert coordinate_map_key == s .coordinate_map_key , (
95- COORDINATE_KEY_DIFFERENT_ERROR
96- + str (coordinate_map_key )
97- + " != "
98- + str (s .coordinate_map_key )
86+ assert (
87+ len (sparse_tensors ) > 1
88+ ), f"Invalid number of inputs. The input must be at least two len(sparse_tensors) > 1"
89+
90+ if isinstance (sparse_tensors [0 ], SparseTensor ):
91+ device = sparse_tensors [0 ].device
92+ coordinate_manager = sparse_tensors [0 ].coordinate_manager
93+ coordinate_map_key = sparse_tensors [0 ].coordinate_map_key
94+ for s in sparse_tensors :
95+ assert isinstance (
96+ s , SparseTensor
97+ ), "Inputs must be either SparseTensors or TensorFields."
98+ assert (
99+ device == s .device
100+ ), f"Device must be the same. { device } != { s .device } "
101+ assert (
102+ coordinate_manager == s .coordinate_manager
103+ ), COORDINATE_MANAGER_DIFFERENT_ERROR
104+ assert coordinate_map_key == s .coordinate_map_key , (
105+ COORDINATE_KEY_DIFFERENT_ERROR
106+ + str (coordinate_map_key )
107+ + " != "
108+ + str (s .coordinate_map_key )
109+ )
110+ tens = []
111+ for s in sparse_tensors :
112+ tens .append (s .F )
113+ return SparseTensor (
114+ torch .cat (tens , dim = 1 ),
115+ coordinate_map_key = coordinate_map_key ,
116+ coordinate_manager = coordinate_manager ,
117+ )
118+ elif isinstance (sparse_tensors [0 ], TensorField ):
119+ device = sparse_tensors [0 ].device
120+ coordinate_manager = sparse_tensors [0 ].coordinate_manager
121+ coordinate_field_map_key = sparse_tensors [0 ].coordinate_field_map_key
122+ for s in sparse_tensors :
123+ assert isinstance (
124+ s , TensorField
125+ ), "Inputs must be either SparseTensors or TensorFields."
126+ assert (
127+ device == s .device
128+ ), f"Device must be the same. { device } != { s .device } "
129+ assert (
130+ coordinate_manager == s .coordinate_manager
131+ ), COORDINATE_MANAGER_DIFFERENT_ERROR
132+ assert coordinate_field_map_key == s .coordinate_field_map_key , (
133+ COORDINATE_KEY_DIFFERENT_ERROR
134+ + str (coordinate_field_map_key )
135+ + " != "
136+ + str (s .coordinate_field_map_key )
137+ )
138+ tens = []
139+ for s in sparse_tensors :
140+ tens .append (s .F )
141+ return TensorField (
142+ torch .cat (tens , dim = 1 ),
143+ coordinate_field_map_key = coordinate_field_map_key ,
144+ coordinate_manager = coordinate_manager ,
145+ )
146+ else :
147+ raise ValueError (
148+ "Invalid data type. The input must be either a list of sparse tensors or a list of tensor fields."
99149 )
100- tens = []
101- for s in sparse_tensors :
102- tens .append (s .F )
103- return SparseTensor (
104- torch .cat (tens , dim = 1 ),
105- coordinate_map_key = coordinate_map_key ,
106- coordinate_manager = coordinate_manager ,
107- )
108150
109151
110152def dense_coordinates (shape : Union [list , torch .Size ]):
@@ -131,7 +173,7 @@ def dense_coordinates(shape: Union[list, torch.Size]):
131173 for s in np .meshgrid (
132174 np .linspace (0 , B - 1 , B ),
133175 * (np .linspace (0 , s - 1 , s ) for s in size [2 :]),
134- indexing = "ij"
176+ indexing = "ij" ,
135177 )
136178 ],
137179 1 ,
0 commit comments