@@ -170,8 +170,8 @@ def quaternion_weights(
170170 quaternions in its last dimension.
171171 quaternion2: A tensor of shape `[A1, ... , An, 4]` storing normalized
172172 quaternions in its last dimension.
173- percent: A `float` or a tensor with a shape broadcastable to the shape `[A1,
174- ... , An]` .
173+ percent: A `float` or tensor with shape broadcastable to the shape of input
174+ vectors .
175175 eps: A `float` used to make operations safe. When left as None, the function
176176 automatically picks the best epsilon based on the dtype and the operation.
177177 name: A name for this op. Defaults to "quaternion_weights".
@@ -198,7 +198,7 @@ def quaternion_weights(
198198 tensor = quaternion2 , tensor_name = "quaternion2" , has_dim_equals = (- 1 , 4 ))
199199 shape .compare_batch_dimensions (
200200 tensors = (quaternion1 , quaternion2 , percent ),
201- last_axes = ( - 2 , - 2 , - 1 ) ,
201+ last_axes = - 1 ,
202202 broadcast_compatible = True ,
203203 tensor_names = ("quaternion1" , "quaternion2" , "percent" ))
204204 quaternion1 = asserts .assert_normalized (quaternion1 )
@@ -266,7 +266,7 @@ def vector_weights(vector1: type_alias.TensorLike,
266266 tensor_names = ("vector1" , "vector2" ))
267267 shape .compare_batch_dimensions (
268268 tensors = (vector1 , vector2 , percent ),
269- last_axes = ( - 2 , - 2 , - 1 ) ,
269+ last_axes = - 1 ,
270270 broadcast_compatible = True ,
271271 tensor_names = ("vector1" , "vector2" , "percent" ))
272272 normalized1 = tf .nn .l2_normalize (vector1 , axis = - 1 )
0 commit comments