Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 156c227

Browse files
authored
Move the rest of the math ops to logical (#54)
* migrate addScaledMat and conv2d to logical sampling and improve shader compiler * fix conv2d zero paddig and make the project build * migrate rest of conv shaders to logical * replace zero pad with if * migrate pool ops * removing math.reshape * Merge remote-tracking branch 'origin/master' into reshape * migrate copy op to logical * remove duplicate copy file * move the rest of math ops to logical * address comments
1 parent f3b71bc commit 156c227

22 files changed

+438
-1242
lines changed

.vscode/settings.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414
"editor.insertSpaces": true,
1515
"files.insertFinalNewline": true,
1616
"editor.detectIndentation": false,
17+
"editor.wrappingIndent": "none",
1718
"typescript.tsdk": "node_modules/typescript/lib"
1819
}

src/graph_runner.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ export class GraphRunner {
211211
}
212212

213213
});
214-
setTimeout(() => this.trainNetwork());
214+
requestAnimationFrame(() => this.trainNetwork());
215215
}
216216

217217
infer(
@@ -243,7 +243,7 @@ export class GraphRunner {
243243
this.currentInferenceLoopNumPasses = numPasses;
244244
if (!this.isInferring) {
245245
this.inferencePassesThisRun = 0;
246-
setTimeout(() => this.inferNetwork());
246+
requestAnimationFrame(() => this.inferNetwork());
247247
}
248248
this.isInferring = true;
249249
}

src/math/math.ts

Lines changed: 29 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,10 @@ export abstract class NDArrayMath {
7171
*/
7272
enableDebugMode() {
7373
this.debugMode = true;
74-
console.warn('Debugging mode is ON. The output of every math call will ' +
75-
'be downloaded to CPU and checked for NaNs. ' +
76-
'This significantly impacts performance.');
74+
console.warn(
75+
'Debugging mode is ON. The output of every math call will ' +
76+
'be downloaded to CPU and checked for NaNs. ' +
77+
'This significantly impacts performance.');
7778
}
7879

7980
/**
@@ -97,7 +98,7 @@ export abstract class NDArrayMath {
9798
endScope(result: ScopeResult) {
9899
let arraysToKeep = this.activeScopeNDArraysToKeep;
99100
if (result != null) {
100-
arraysToKeep = arraysToKeep.concat(result as NDArray|NDArray[]);
101+
arraysToKeep = arraysToKeep.concat(result as NDArray | NDArray[]);
101102
}
102103
// Dispose the current scope.
103104
for (let i = 0; i < this.activeScope.length; i++) {
@@ -321,22 +322,15 @@ export abstract class NDArrayMath {
321322
protected abstract cloneInternal<T extends NDArray>(ndarray: T): T;
322323

323324
/**
324-
* Reshapes an NDArray to a new shape. The size of the input NDArray must
325-
* match the size of the requested shape.
326-
* @param ndarray The input NDArray.
327-
* @param newShape The new shape to reshape the NDArray to. Must be the same
328-
* size as the NDArray.
325+
* @deprecated Please call reshape() directly on the ndarray object.
329326
*/
330327
reshape<T1 extends NDArray, T2 extends NDArray>(
331328
ndarray: T1, newShape: number[]): T2 {
332-
util.assert(
333-
ndarray.size === util.sizeFromShape(newShape),
334-
`Error in reshape: old size ${ndarray.size} must match new size ` +
335-
`${util.sizeFromShape(newShape)}.`);
336-
return this.track(this.reshapeInternal<T1, T2>(ndarray, newShape));
329+
console.warn(
330+
'math.reshape() is deprecated. Please call reshape() ' +
331+
'directly on the ndarray object');
332+
return ndarray.reshape(newShape);
337333
}
338-
protected abstract reshapeInternal<T1 extends NDArray, T2 extends NDArray>(
339-
ndarray: T1, newShape: number[]): T2;
340334

341335
/**
342336
* Extracts a slice from a matrix. The operation extraces a slice from input
@@ -1148,7 +1142,8 @@ export abstract class NDArrayMath {
11481142
* @param h Array of previous cell outputs.
11491143
* @return Tuple [nextCellStates, cellOutputs]
11501144
*/
1151-
multiRNNCell(lstmCells: LSTMCell[], data: Array2D, c: Array2D[],
1145+
multiRNNCell(
1146+
lstmCells: LSTMCell[], data: Array2D, c: Array2D[],
11521147
h: Array2D[]): [Array2D[], Array2D[]] {
11531148
util.assert(
11541149
data.shape[0] === 1,
@@ -1187,8 +1182,9 @@ export abstract class NDArrayMath {
11871182
* @param h Previous cell output.
11881183
* @return Tuple [nextCellState, cellOutput]
11891184
*/
1190-
basicLSTMCell(forgetBias: Scalar, lstmKernel: Array2D, lstmBias: Array1D,
1191-
data: Array2D, c: Array2D, h: Array2D): [Array2D, Array2D] {
1185+
basicLSTMCell(
1186+
forgetBias: Scalar, lstmKernel: Array2D, lstmBias: Array1D, data: Array2D,
1187+
c: Array2D, h: Array2D): [Array2D, Array2D] {
11921188
const res = this.scope(() => {
11931189
util.assert(
11941190
data.shape[0] === 1,
@@ -1207,25 +1203,25 @@ export abstract class NDArrayMath {
12071203

12081204
// i = input_gate, j = new_input, f = forget_gate, o = output_gate
12091205
const i = this.slice2D(res, [0, 0], [res.shape[0], res.shape[1] / 4]);
1210-
const j = this.slice2D(res, [0, res.shape[1] / 4 * 1],
1211-
[res.shape[0], res.shape[1] / 4]);
1212-
const f = this.slice2D(res, [0, res.shape[1] / 4 * 2],
1213-
[res.shape[0], res.shape[1] / 4]);
1214-
const o = this.slice2D(res, [0, res.shape[1] / 4 * 3],
1215-
[res.shape[0], res.shape[1] / 4]);
1216-
1217-
const newC = this.add(
1218-
this.multiplyStrict(c,
1219-
this.sigmoid(this.scalarPlusArray(forgetBias, f))),
1220-
this.multiplyStrict(this.sigmoid(i), this.tanh(j))) as Array2D;
1221-
const newH = this.multiplyStrict(
1222-
this.tanh(newC), this.sigmoid(o)) as Array2D;
1206+
const j = this.slice2D(
1207+
res, [0, res.shape[1] / 4 * 1], [res.shape[0], res.shape[1] / 4]);
1208+
const f = this.slice2D(
1209+
res, [0, res.shape[1] / 4 * 2], [res.shape[0], res.shape[1] / 4]);
1210+
const o = this.slice2D(
1211+
res, [0, res.shape[1] / 4 * 3], [res.shape[0], res.shape[1] / 4]);
1212+
1213+
const newC =
1214+
this.add(
1215+
this.multiplyStrict(
1216+
c, this.sigmoid(this.scalarPlusArray(forgetBias, f))),
1217+
this.multiplyStrict(this.sigmoid(i), this.tanh(j))) as Array2D;
1218+
const newH =
1219+
this.multiplyStrict(this.tanh(newC), this.sigmoid(o)) as Array2D;
12231220

12241221
return [newC, newH];
12251222
});
12261223
return [res[0], res[1]];
12271224
}
1228-
12291225
}
12301226

12311227
export enum MatrixOrientation {

src/math/math_cpu.ts

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,6 @@ export class NDArrayMathCPU extends NDArrayMath {
3131
ndarray.shape, {values: new Float32Array(ndarray.getValues())});
3232
}
3333

34-
protected reshapeInternal<T1 extends NDArray, T2 extends NDArray>(
35-
ndarray: T1, newShape: number[]): T2 {
36-
return this.cloneInternal(ndarray).reshape<T2>(newShape);
37-
}
38-
3934
protected slice2DInternal(
4035
input: Array2D, beginRowCol: [number, number],
4136
sizeRowCol: [number, number]): Array2D {

0 commit comments

Comments
 (0)