@@ -71,9 +71,10 @@ export abstract class NDArrayMath {
71
71
*/
72
72
enableDebugMode ( ) {
73
73
this . debugMode = true ;
74
- console . warn ( 'Debugging mode is ON. The output of every math call will ' +
75
- 'be downloaded to CPU and checked for NaNs. ' +
76
- 'This significantly impacts performance.' ) ;
74
+ console . warn (
75
+ 'Debugging mode is ON. The output of every math call will ' +
76
+ 'be downloaded to CPU and checked for NaNs. ' +
77
+ 'This significantly impacts performance.' ) ;
77
78
}
78
79
79
80
/**
@@ -97,7 +98,7 @@ export abstract class NDArrayMath {
97
98
endScope ( result : ScopeResult ) {
98
99
let arraysToKeep = this . activeScopeNDArraysToKeep ;
99
100
if ( result != null ) {
100
- arraysToKeep = arraysToKeep . concat ( result as NDArray | NDArray [ ] ) ;
101
+ arraysToKeep = arraysToKeep . concat ( result as NDArray | NDArray [ ] ) ;
101
102
}
102
103
// Dispose the current scope.
103
104
for ( let i = 0 ; i < this . activeScope . length ; i ++ ) {
@@ -321,22 +322,15 @@ export abstract class NDArrayMath {
321
322
protected abstract cloneInternal < T extends NDArray > ( ndarray : T ) : T ;
322
323
323
324
/**
324
- * Reshapes an NDArray to a new shape. The size of the input NDArray must
325
- * match the size of the requested shape.
326
- * @param ndarray The input NDArray.
327
- * @param newShape The new shape to reshape the NDArray to. Must be the same
328
- * size as the NDArray.
325
+ * @deprecated Please call reshape() directly on the ndarray object.
329
326
*/
330
327
reshape < T1 extends NDArray , T2 extends NDArray > (
331
328
ndarray : T1 , newShape : number [ ] ) : T2 {
332
- util . assert (
333
- ndarray . size === util . sizeFromShape ( newShape ) ,
334
- `Error in reshape: old size ${ ndarray . size } must match new size ` +
335
- `${ util . sizeFromShape ( newShape ) } .` ) ;
336
- return this . track ( this . reshapeInternal < T1 , T2 > ( ndarray , newShape ) ) ;
329
+ console . warn (
330
+ 'math.reshape() is deprecated. Please call reshape() ' +
331
+ 'directly on the ndarray object' ) ;
332
+ return ndarray . reshape ( newShape ) ;
337
333
}
338
- protected abstract reshapeInternal < T1 extends NDArray , T2 extends NDArray > (
339
- ndarray : T1 , newShape : number [ ] ) : T2 ;
340
334
341
335
/**
342
336
* Extracts a slice from a matrix. The operation extraces a slice from input
@@ -1148,7 +1142,8 @@ export abstract class NDArrayMath {
1148
1142
* @param h Array of previous cell outputs.
1149
1143
* @return Tuple [nextCellStates, cellOutputs]
1150
1144
*/
1151
- multiRNNCell ( lstmCells : LSTMCell [ ] , data : Array2D , c : Array2D [ ] ,
1145
+ multiRNNCell (
1146
+ lstmCells : LSTMCell [ ] , data : Array2D , c : Array2D [ ] ,
1152
1147
h : Array2D [ ] ) : [ Array2D [ ] , Array2D [ ] ] {
1153
1148
util . assert (
1154
1149
data . shape [ 0 ] === 1 ,
@@ -1187,8 +1182,9 @@ export abstract class NDArrayMath {
1187
1182
* @param h Previous cell output.
1188
1183
* @return Tuple [nextCellState, cellOutput]
1189
1184
*/
1190
- basicLSTMCell ( forgetBias : Scalar , lstmKernel : Array2D , lstmBias : Array1D ,
1191
- data : Array2D , c : Array2D , h : Array2D ) : [ Array2D , Array2D ] {
1185
+ basicLSTMCell (
1186
+ forgetBias : Scalar , lstmKernel : Array2D , lstmBias : Array1D , data : Array2D ,
1187
+ c : Array2D , h : Array2D ) : [ Array2D , Array2D ] {
1192
1188
const res = this . scope ( ( ) => {
1193
1189
util . assert (
1194
1190
data . shape [ 0 ] === 1 ,
@@ -1207,25 +1203,25 @@ export abstract class NDArrayMath {
1207
1203
1208
1204
// i = input_gate, j = new_input, f = forget_gate, o = output_gate
1209
1205
const i = this . slice2D ( res , [ 0 , 0 ] , [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1210
- const j = this . slice2D ( res , [ 0 , res . shape [ 1 ] / 4 * 1 ] ,
1211
- [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1212
- const f = this . slice2D ( res , [ 0 , res . shape [ 1 ] / 4 * 2 ] ,
1213
- [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1214
- const o = this . slice2D ( res , [ 0 , res . shape [ 1 ] / 4 * 3 ] ,
1215
- [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1216
-
1217
- const newC = this . add (
1218
- this . multiplyStrict ( c ,
1219
- this . sigmoid ( this . scalarPlusArray ( forgetBias , f ) ) ) ,
1220
- this . multiplyStrict ( this . sigmoid ( i ) , this . tanh ( j ) ) ) as Array2D ;
1221
- const newH = this . multiplyStrict (
1222
- this . tanh ( newC ) , this . sigmoid ( o ) ) as Array2D ;
1206
+ const j = this . slice2D (
1207
+ res , [ 0 , res . shape [ 1 ] / 4 * 1 ] , [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1208
+ const f = this . slice2D (
1209
+ res , [ 0 , res . shape [ 1 ] / 4 * 2 ] , [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1210
+ const o = this . slice2D (
1211
+ res , [ 0 , res . shape [ 1 ] / 4 * 3 ] , [ res . shape [ 0 ] , res . shape [ 1 ] / 4 ] ) ;
1212
+
1213
+ const newC =
1214
+ this . add (
1215
+ this . multiplyStrict (
1216
+ c , this . sigmoid ( this . scalarPlusArray ( forgetBias , f ) ) ) ,
1217
+ this . multiplyStrict ( this . sigmoid ( i ) , this . tanh ( j ) ) ) as Array2D ;
1218
+ const newH =
1219
+ this . multiplyStrict ( this . tanh ( newC ) , this . sigmoid ( o ) ) as Array2D ;
1223
1220
1224
1221
return [ newC , newH ] ;
1225
1222
} ) ;
1226
1223
return [ res [ 0 ] , res [ 1 ] ] ;
1227
1224
}
1228
-
1229
1225
}
1230
1226
1231
1227
export enum MatrixOrientation {
0 commit comments