24 | 24 | from __future__ import division |
25 | 25 | from __future__ import print_function |
26 | 26 |
| 27 | +from typing import Any, Callable, Dict, List, Tuple |
| 28 | + |
27 | 29 | import numpy as np |
28 | 30 | import tensorflow as tf |
29 | 31 |
30 | 32 | from tensorflow_graphics.geometry.convolution import utils as conv_utils |
31 | 33 | from tensorflow_graphics.geometry.representation.mesh import utils as mesh_utils |
32 | 34 | from tensorflow_graphics.util import shape |
| 35 | +from tensorflow_graphics.util import type_alias |
33 | 36 |
34 | 37 | DEFAULT_IO_PARAMS = { |
35 | 38 | 'batch_size': 8, |
42 | 45 | } |
43 | 46 |
44 | 47 |
45 | | -def adjacency_from_edges(edges, weights, num_edges, num_vertices): |
| 48 | +def adjacency_from_edges( |
| 49 | + edges: type_alias.TensorLike, |
| 50 | + weights: type_alias.TensorLike, |
| 51 | + num_edges: type_alias.TensorLike, |
| 52 | + num_vertices: type_alias.TensorLike) -> tf.SparseTensor: |
46 | 53 | """Returns a batched sparse 1-ring adj tensor from edge list tensor. |
47 | 54 |
48 | 55 | Args: |
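Note on the hunk above: adjacency_from_edges now types every input as type_alias.TensorLike and pins the return to tf.SparseTensor. A minimal usage sketch, assuming padded [batch, max_edges, 2] edge indices, per-edge weights, and per-sample edge and vertex counts (these shapes are inferred from context, since the full argument docs fall outside this hunk):

    import numpy as np

    # Hypothetical single-sample batch: one triangle over 3 vertices (shapes assumed).
    edges = np.array([[[0, 1], [1, 2], [2, 0]]], dtype=np.int32)   # [1, 3, 2]
    weights = np.array([[0.5, 0.5, 0.5]], dtype=np.float32)        # [1, 3]
    num_edges = np.array([3], dtype=np.int32)
    num_vertices = np.array([3], dtype=np.int32)

    adjacency = adjacency_from_edges(edges, weights, num_edges, num_vertices)
    # adjacency is a tf.SparseTensor holding one 3x3 1-ring adjacency block.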
@@ -103,7 +110,9 @@ def adjacency_from_edges(edges, weights, num_edges, num_vertices): |
103 | 110 | return adjacency |
104 | 111 |
105 | 112 |
106 | | -def get_weighted_edges(faces, self_edges=True): |
| 113 | +def get_weighted_edges( |
| 114 | + faces: np.ndarray, |
| 115 | + self_edges: bool = True) -> Tuple[np.ndarray, np.ndarray]: |
107 | 116 | r"""Gets unique edges and degree weights from a triangular mesh. |
108 | 117 |
109 | 118 | The shorthands used below are: |
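Note on the hunk above: get_weighted_edges now declares numpy arrays in and an (edges, weights) numpy tuple out. A small sketch of a call, assuming faces is an [F, 3] integer array of triangle vertex indices (the argument list itself is truncated in this view):

    import numpy as np

    # Hypothetical two-triangle mesh sharing the edge (0, 2).
    faces = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int32)

    edges, weights = get_weighted_edges(faces)            # self-loops kept by default
    edges_open, weights_open = get_weighted_edges(faces, self_edges=False)
    # Per the docstring, weights are degree-based weights for each unique edge.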
@@ -136,12 +145,12 @@ def get_weighted_edges(faces, self_edges=True): |
136 | 145 | return edges, weights |
137 | 146 |
138 | 147 |
139 | | -def _tfrecords_to_dataset(tfrecords, |
140 | | - parallel_threads, |
141 | | - shuffle, |
142 | | - repeat, |
143 | | - sloppy, |
144 | | - max_readers=16): |
| 148 | +def _tfrecords_to_dataset(tfrecords: List[str], |
| 149 | + parallel_threads: int, |
| 150 | + shuffle: bool, |
| 151 | + repeat: bool, |
| 152 | + sloppy: bool, |
| 153 | + max_readers: int = 16) -> tf.data.TFRecordDataset: |
145 | 154 | """Creates a TFRecordsDataset that iterates over filenames in parallel. |
146 | 155 |
147 | 156 | Args: |
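Note on the hunk above: _tfrecords_to_dataset is a private helper, so the sketch below is not a call into it; it only illustrates the kind of parallel, optionally sloppy filename interleaving that the annotated signature suggests, written against current tf.data APIs. The real body is outside this diff and may instead rely on tf.data.experimental.parallel_interleave:

    import tensorflow as tf

    def _interleave_tfrecords_sketch(tfrecords, parallel_threads, shuffle, repeat,
                                     sloppy, max_readers=16):
      """Rough stand-in for a filename-interleaving TFRecord reader."""
      files = tf.data.Dataset.from_tensor_slices(tfrecords)
      if shuffle:
        files = files.shuffle(buffer_size=len(tfrecords))
      if repeat:
        files = files.repeat()
      return files.interleave(
          tf.data.TFRecordDataset,
          cycle_length=min(max_readers, len(tfrecords)),
          num_parallel_calls=parallel_threads,
          deterministic=not sloppy)  # sloppy reads trade element order for speed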
@@ -244,7 +253,9 @@ def _parse_mesh_data(mesh_data, mean_center=True): |
244 | 253 | return mesh_data |
245 | 254 |
246 | 255 |
247 | | -def create_dataset_from_tfrecords(tfrecords, params): |
| 256 | +def create_dataset_from_tfrecords( |
| 257 | + tfrecords: List[str], |
| 258 | + params: Dict[str, Any]) -> tf.data.Dataset: |
248 | 259 | """Creates a mesh dataset given a list of tf records filenames. |
249 | 260 |
250 | 261 | Args: |
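Note on the hunk above: create_dataset_from_tfrecords pairs the file list with a params dict keyed like DEFAULT_IO_PARAMS. A hedged usage sketch; the filename is a placeholder and only 'batch_size' is among the defaults visible in this diff:

    params = dict(DEFAULT_IO_PARAMS)   # start from the module defaults
    params['batch_size'] = 4           # override whichever keys you need

    # 'meshes-train.tfrecords' is a hypothetical filename.
    dataset = create_dataset_from_tfrecords(['meshes-train.tfrecords'], params)
    for batch in dataset.take(1):
      print(batch)  # element structure comes from the parsing code, not shown here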
@@ -309,7 +320,10 @@ def _set_default_if_none(param, param_dict, default_val): |
309 | 320 | drop_remainder=is_training) |
310 | 321 |
311 | 322 |
312 | | -def create_input_from_dataset(dataset_fn, files, io_params): |
| 323 | +def create_input_from_dataset( |
| 324 | + dataset_fn: Callable[..., Any], |
| 325 | + files: List[str], |
| 326 | + io_params: Dict[str, Any]) -> Tuple[Dict[str, Any], tf.Tensor]: |
313 | 327 | """Creates input function given dataset generator and input files. |
314 | 328 |
315 | 329 | Args: |
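Note on the hunk above: create_input_from_dataset composes a dataset builder with a file list and io_params and, per its new annotation, returns a (features dict, labels tensor) pair, which matches the tf.estimator input_fn contract. A minimal wiring sketch, assuming train_files and io_params are defined elsewhere (both are placeholders):

    def train_input_fn():
      # Hypothetical glue: reuse the dataset builder defined in this module.
      return create_input_from_dataset(create_dataset_from_tfrecords,
                                       train_files, io_params)

    # estimator.train(input_fn=train_input_fn) would then consume the returned
    # (features, labels) pair.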