3838 `C`: The number of feature channels.
3939"""
4040
41+ from typing import Optional
42+
4143import tensorflow as tf
4244from tensorflow_graphics .util import export_api
4345
@@ -62,13 +64,15 @@ def __init__(self, channels, momentum):
6264 self .channels = channels
6365 self .momentum = momentum
6466
65- def build (self , input_shape ):
67+ def build (self , input_shape : tf . Tensor ):
6668 """Builds the layer with a specified input_shape."""
6769 self .conv = tf .keras .layers .Conv2D (
6870 self .channels , (1 , 1 ), input_shape = input_shape )
6971 self .bn = tf .keras .layers .BatchNormalization (momentum = self .momentum )
7072
71- def call (self , inputs , training = None ): # pylint: disable=arguments-differ
73+ def call (self ,
74+ inputs : tf .Tensor ,
75+ training : Optional [bool ] = None ) -> tf .Tensor : # pylint: disable=arguments-differ
7276 """Executes the convolution.
7377
7478 Args:
@@ -96,12 +100,14 @@ def __init__(self, channels, momentum):
96100 self .momentum = momentum
97101 self .channels = channels
98102
99- def build (self , input_shape ):
103+ def build (self , input_shape : tf . Tensor ):
100104 """Builds the layer with a specified input_shape."""
101105 self .dense = tf .keras .layers .Dense (self .channels , input_shape = input_shape )
102106 self .bn = tf .keras .layers .BatchNormalization (momentum = self .momentum )
103107
104- def call (self , inputs , training = None ): # pylint: disable=arguments-differ
108+ def call (self ,
109+ inputs : tf .Tensor ,
110+ training : Optional [bool ] = None ) -> tf .Tensor : # pylint: disable=arguments-differ
105111 """Executes the convolution.
106112
107113 Args:
@@ -125,7 +131,7 @@ class VanillaEncoder(tf.keras.layers.Layer):
125131 https://github.com/charlesq34/pointnet/blob/master/models/pointnet_cls_basic.py
126132 """
127133
128- def __init__ (self , momentum = .5 ):
134+ def __init__ (self , momentum : float = .5 ):
129135 """Constructs a VanillaEncoder keras layer.
130136
131137 Args:
@@ -138,7 +144,9 @@ def __init__(self, momentum=.5):
138144 self .conv4 = PointNetConv2Layer (128 , momentum )
139145 self .conv5 = PointNetConv2Layer (1024 , momentum )
140146
141- def call (self , inputs , training = None ): # pylint: disable=arguments-differ
147+ def call (self ,
148+ inputs : tf .Tensor ,
149+ training : Optional [bool ] = None ) -> tf .Tensor : # pylint: disable=arguments-differ
142150 """Computes the PointNet features.
143151
144152 Args:
@@ -166,7 +174,10 @@ class ClassificationHead(tf.keras.layers.Layer):
166174 logits of the num_classes classes.
167175 """
168176
169- def __init__ (self , num_classes = 40 , momentum = 0.5 , dropout_rate = 0.3 ):
177+ def __init__ (self ,
178+ num_classes : int = 40 ,
179+ momentum : float = 0.5 ,
180+ dropout_rate : float = 0.3 ):
170181 """Constructor.
171182
172183 Args:
@@ -180,7 +191,9 @@ def __init__(self, num_classes=40, momentum=0.5, dropout_rate=0.3):
180191 self .dropout = tf .keras .layers .Dropout (dropout_rate )
181192 self .dense3 = tf .keras .layers .Dense (num_classes , activation = "linear" )
182193
183- def call (self , inputs , training = None ): # pylint: disable=arguments-differ
194+ def call (self ,
195+ inputs : tf .Tensor ,
196+ training : Optional [bool ] = None ) -> tf .Tensor : # pylint: disable=arguments-differ
184197 """Computes the classifiation logits given features (note: without softmax).
185198
186199 Args:
@@ -199,7 +212,10 @@ def call(self, inputs, training=None): # pylint: disable=arguments-differ
199212class PointNetVanillaClassifier (tf .keras .layers .Layer ):
200213 """The PointNet 'Vanilla' classifier (i.e. without spatial transformer)."""
201214
202- def __init__ (self , num_classes = 40 , momentum = .5 , dropout_rate = .3 ):
215+ def __init__ (self ,
216+ num_classes : int = 40 ,
217+ momentum : float = .5 ,
218+ dropout_rate : float = .3 ):
203219 """Constructor.
204220
205221 Args:
@@ -212,7 +228,9 @@ def __init__(self, num_classes=40, momentum=.5, dropout_rate=.3):
212228 self .classifier = ClassificationHead (
213229 num_classes = num_classes , momentum = momentum , dropout_rate = dropout_rate )
214230
215- def call (self , points , training = None ): # pylint: disable=arguments-differ
231+ def call (self ,
232+ points : tf .Tensor ,
233+ training : Optional [bool ] = None ) -> tf .Tensor : # pylint: disable=arguments-differ
216234 """Computes the classifiation logits of a point set.
217235
218236 Args:
@@ -227,7 +245,8 @@ def call(self, points, training=None): # pylint: disable=arguments-differ
227245 return logits
228246
229247 @staticmethod
230- def loss (labels , logits ):
248+ def loss (labels : tf .Tensor ,
249+ logits : tf .Tensor ) -> tf .Tensor :
231250 """The classification model training loss.
232251
233252 Note:
@@ -236,6 +255,9 @@ def loss(labels, logits):
236255 Args:
237256 labels: a tensor with shape `[B,]`
238257 logits: a tensor with shape `[B,num_classes]`
258+
259+ Returns:
260+ A tensor with the same shape as labels and of the same type as logits.
239261 """
240262 cross_entropy = tf .nn .sparse_softmax_cross_entropy_with_logits
241263 residual = cross_entropy (labels , logits )
0 commit comments