Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add batch normalization for 5d #8268

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion tfjs-core/src/gradients/FusedBatchNorm_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ export const fusedBatchNormGradConfig: GradConfig = {
return reshape(
mul(mul(dy,
tile(
reshape(oneOverSqrtVariance, [1, 1, 1, mean.shape[0]]),
reshape(oneOverSqrtVariance,
[1, 1, 1, 1, mean.shape[0]]),
tileShape)),
scaleValue),
x.shape);
Expand Down
8 changes: 4 additions & 4 deletions tfjs-core/src/ops/batchnorm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import {ENGINE} from '../engine';
import {FusedBatchNorm, FusedBatchNormAttrs, FusedBatchNormInputs} from '../kernel_names';
import {NamedAttrMap} from '../kernel_registry';
import {Tensor, Tensor1D, Tensor4D} from '../tensor';
import {Tensor, Tensor1D, Tensor5D} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {Rank, TensorLike} from '../types';
import * as util from '../util';

import {xAs4D} from './batchnorm_util';
import {xAs5D} from './batchnorm_util';
import {op} from './operation';
import {reshape} from './reshape';

Expand Down Expand Up @@ -88,10 +88,10 @@ function batchNorm_<R extends Rank>(
() => 'Batch normalization gradient requires mean and scale to have ' +
'equal ranks.');

const x4D: Tensor4D = xAs4D($x);
const x5D: Tensor5D = xAs5D($x);

const inputs: FusedBatchNormInputs = {
x: x4D,
x: x5D,
scale: $scale,
offset: $offset,
mean: $mean,
Expand Down
79 changes: 79 additions & 0 deletions tfjs-core/src/ops/batchnorm5d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/**
* @license
* Copyright 2020 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import {Tensor1D, Tensor5D} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';
import * as util from '../util';

import {batchNorm} from './batchnorm';
import {op} from './operation';

/**
* Batch normalization, strictly for 5D. For the more relaxed version, see
* `tf.batchNorm`.
*
* @param x The input Tensor.
* @param mean A mean Tensor.
* @param variance A variance Tensor.
* @param offset An offset Tensor.
* @param scale A scale Tensor.
* @param varianceEpsilon A small float number to avoid dividing by 0.
*/
function batchNorm5d_(
x: Tensor5D|TensorLike, mean: Tensor5D|Tensor1D|TensorLike,
variance: Tensor5D|Tensor1D|TensorLike,
offset?: Tensor5D|Tensor1D|TensorLike, scale?: Tensor5D|Tensor1D|TensorLike,
varianceEpsilon?: number): Tensor5D {
const $x = convertToTensor(x, 'x', 'batchNorm');
const $mean = convertToTensor(mean, 'mean', 'batchNorm');
const $variance = convertToTensor(variance, 'variance', 'batchNorm');
let $scale: Tensor5D|Tensor1D;
if (scale != null) {
$scale = convertToTensor(scale, 'scale', 'batchNorm');
}
let $offset: Tensor5D|Tensor1D;
if (offset != null) {
$offset = convertToTensor(offset, 'offset', 'batchNorm');
}
util.assert(
$x.rank === 5,
() => `Error in batchNorm5D: x must be rank 5 but got rank ` +
`${$x.rank}.`);
util.assert(
$mean.rank === 5 || $mean.rank === 1,
() => `Error in batchNorm5D: mean must be rank 5 or rank 1 but ` +
`got rank ${$mean.rank}.`);
util.assert(
$variance.rank === 5 || $variance.rank === 1,
() => `Error in batchNorm5D: variance must be rank 5 or rank 1 ` +
`but got rank ${$variance.rank}.`);
if ($scale != null) {
util.assert(
$scale.rank === 5 || $scale.rank === 1,
() => `Error in batchNorm5D: scale must be rank 5 or rank 1 ` +
`but got rank ${$scale.rank}.`);
}
if ($offset != null) {
util.assert(
$offset.rank === 5 || $offset.rank === 1,
() => `Error in batchNorm5D: offset must be rank 5 or rank 1 ` +
`but got rank ${$offset.rank}.`);
}
return batchNorm($x, $mean, $variance, $offset, $scale, varianceEpsilon);
}

export const batchNorm5d = op({batchNorm5d_});
218 changes: 218 additions & 0 deletions tfjs-core/src/ops/batchnorm_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,224 @@ import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {expectArraysClose} from '../test_util';

describe('batchNorm5D', () => {
it('simple batchnorm5D, no offset or scale, 2x1x1x1x2', async () => {
const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
const meanT = tf.tensor1d([1, 2]);
const varianceT = tf.tensor1d([2, 3]);
const varianceEpsilon = .001;

const result = tf.batchNorm5d(
xT, meanT, varianceT, undefined, undefined, varianceEpsilon);

const x = await xT.array();
const mean = await meanT.array();
const variance = await varianceT.array();
expectArraysClose(await result.data(), [
(x[0][0][0][0][0] - mean[0]) * 1 /
Math.sqrt(variance[0] + varianceEpsilon),
(x[0][0][0][0][1] - mean[1]) * 1 /
Math.sqrt(variance[1] + varianceEpsilon),
(x[1][0][0][0][0] - mean[0]) * 1 /
Math.sqrt(variance[0] + varianceEpsilon),
(x[1][0][0][0][1] - mean[1]) * 1 /
Math.sqrt(variance[1] + varianceEpsilon)
]);
});

it('simple batchnorm5D, no offset, 2x1x1x1x2', async () => {
const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
const meanT = tf.tensor1d([1, 2]);
const varianceT = tf.tensor1d([2, 3]);
const scaleT = tf.tensor1d([4, 5]);
const varianceEpsilon = .001;

const result = tf.batchNorm5d(
xT, meanT, varianceT, undefined, scaleT, varianceEpsilon);
const x = await xT.buffer();
const mean = await meanT.buffer();
const variance = await varianceT.buffer();
const scale = await scaleT.buffer();

expectArraysClose(await result.data(), [
(x.get(0, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
Math.sqrt(variance.get(0) + varianceEpsilon),
(x.get(0, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
Math.sqrt(variance.get(1) + varianceEpsilon),
(x.get(1, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
Math.sqrt(variance.get(0) + varianceEpsilon),
(x.get(1, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
Math.sqrt(variance.get(1) + varianceEpsilon)
]);
});

it('simple batchnorm5D, no scale, 2x1x1x1x2', async () => {
const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
const meanT = tf.tensor1d([1, 2]);
const varianceT = tf.tensor1d([2, 3]);
const offsetT = tf.tensor1d([4, 5]);

const varianceEpsilon = .001;

const result = tf.batchNorm5d(
xT, meanT, varianceT, offsetT, undefined, varianceEpsilon);
const x = await xT.buffer();
const mean = await meanT.buffer();
const variance = await varianceT.buffer();
const offset = await offsetT.buffer();

expectArraysClose(await result.data(), [
offset.get(0) +
(x.get(0, 0, 0, 0, 0) - mean.get(0)) * 1 /
Math.sqrt(variance.get(0) + varianceEpsilon),
offset.get(1) +
(x.get(0, 0, 0, 0, 1) - mean.get(1)) * 1 /
Math.sqrt(variance.get(1) + varianceEpsilon),
offset.get(0) +
(x.get(1, 0, 0, 0, 0) - mean.get(0)) * 1 /
Math.sqrt(variance.get(0) + varianceEpsilon),
offset.get(1) +
(x.get(1, 0, 0, 0, 1) - mean.get(1)) * 1 /
Math.sqrt(variance.get(1) + varianceEpsilon)
]);
});

it('simple batchnorm5D, 2x1x1x1x2', async () => {
const xT = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
const meanT = tf.tensor1d([1, 2]);
const varianceT = tf.tensor1d([2, 3]);
const offsetT = tf.tensor1d([3, 4]);
const scaleT = tf.tensor1d([4, 5]);

const varianceEpsilon = .001;

const result =
tf.batchNorm5d(xT, meanT, varianceT, offsetT, scaleT, varianceEpsilon);
const x = await xT.buffer();
const mean = await meanT.buffer();
const variance = await varianceT.buffer();
const scale = await scaleT.buffer();
const offset = await offsetT.buffer();

expectArraysClose(await result.data(), [
offset.get(0) +
(x.get(0, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
Math.sqrt(variance.get(0) + varianceEpsilon),
offset.get(1) +
(x.get(0, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
Math.sqrt(variance.get(1) + varianceEpsilon),
offset.get(0) +
(x.get(1, 0, 0, 0, 0) - mean.get(0)) * scale.get(0) /
Math.sqrt(variance.get(0) + varianceEpsilon),
offset.get(1) +
(x.get(1, 0, 0, 0, 1) - mean.get(1)) * scale.get(1) /
Math.sqrt(variance.get(1) + varianceEpsilon)
]);
});

it('accepts a tensor-like object', async () => {
const x = [[[[[2, 4]]]], [[[[9, 23]]]]]; // 2x1x1x1x2
const mean = [1, 2];
const variance = [2, 3];
const offset = [3, 4];
const scale = [4, 5];

const varianceEpsilon = .001;

const result =
tf.batchNorm5d(x, mean, variance, offset, scale, varianceEpsilon);

expectArraysClose(await result.data(), [
offset[0] +
(x[0][0][0][0][0] - mean[0]) * scale[0] /
Math.sqrt(variance[0] + varianceEpsilon),
offset[1] +
(x[0][0][0][0][1] - mean[1]) * scale[1] /
Math.sqrt(variance[1] + varianceEpsilon),
offset[0] +
(x[1][0][0][0][0] - mean[0]) * scale[0] /
Math.sqrt(variance[0] + varianceEpsilon),
offset[1] +
(x[1][0][0][0][1] - mean[1]) * scale[1] /
Math.sqrt(variance[1] + varianceEpsilon)
]);
});

it('simple batchnorm5D gradients, 2x1x1x1x2', async () => {
const x = tf.tensor5d([2, 4, 9, 23], [2, 1, 1, 1, 2]);
const mean = tf.tensor1d([1, 2]);
const variance = tf.tensor1d([2, 3]);
const offset = tf.tensor1d([3, 4]);
const scale = tf.tensor1d([2, 5]);

const varianceEpsilon = .001;

const dy = tf.tensor5d([-1, -1, -1, -1], [2, 1, 1, 1, 2]);
const gradX = tf.grad(
(x: tf.Tensor5D) => tf.batchNorm5d(
x, mean, variance, offset, scale, varianceEpsilon))(x, dy);
expectArraysClose(await gradX.data(), [-1.414, -2.887, -1.414, -2.887]);
expect(gradX.shape).toEqual([2, 1, 1, 1, 2]);
const gradMean = tf.grad(
(mean: tf.Tensor1D) => tf.batchNorm5d(
x, mean, variance, offset, scale, varianceEpsilon))(mean, dy);
expectArraysClose(await gradMean.data(), [2.828, 5.773]);
expect(gradMean.shape).toEqual([2]);
const gradVariance = tf.grad(
(variance: tf.Tensor1D) => tf.batchNorm5d(
x, mean, variance, offset, scale, varianceEpsilon))(variance, dy);
expectArraysClose(await gradVariance.data(), [3.180, 11.060]);
expect(gradVariance.shape).toEqual([2]);
const gradOffset = tf.grad(
(offset: tf.Tensor1D) => tf.batchNorm5d(
x, mean, variance, offset, scale, varianceEpsilon))(offset, dy);
expectArraysClose(await gradOffset.data(), await dy.sum([0, 1, 2]).data());
expect(gradOffset.shape).toEqual([2]);
const gradScale = tf.grad(
(scale: tf.Tensor1D) => tf.batchNorm5d(
x, mean, variance, offset, scale, varianceEpsilon))(scale, dy);
expectArraysClose(await gradScale.data(), [-6.362, -13.277]);
expect(gradScale.shape).toEqual([2]);
});

it('batchnorm5D gradients, same shapes in x, mean and variance', async () => {
const x = tf.tensor5d([10, 20, 30, 40], [2, 1, 1, 1, 2]);
const mean = tf.tensor5d([0, 5, 10, 15], [2, 1, 1, 1, 2]);
const variance = tf.tensor5d([2, 4, 6, 8], [2, 1, 1, 1, 2]);
const scale = tf.tensor5d([2, 5, 2, 5], [2, 1, 1, 1, 2]);
const offset = tf.tensor5d([0, 0, 0, 0], [2, 1, 1, 1, 2]);

const varianceEpsilon = .001;

const dy = tf.tensor5d([-1, -1, -1, -1], [2, 1, 1, 1, 2]);
const gradX = tf.grad(
(x: tf.Tensor5D) => tf.batchNorm5d(
x, mean, variance, offset, scale, varianceEpsilon))(x, dy);
expectArraysClose(await gradX.data(), [-1.414, -2.500, -0.816, -1.768]);
expect(gradX.shape).toEqual([2, 1, 1, 1, 2]);
const gradMean = tf.grad(
(mean: tf.Tensor5D) => tf.batchNorm5d(
x, mean, variance, offset, scale, varianceEpsilon))(mean, dy);
expectArraysClose(await gradMean.data(), [1.414, 2.500, 0.816, 1.768]);
expect(gradMean.shape).toEqual([2, 1, 1, 1, 2]);
const gradVariance = tf.grad(
(variance: tf.Tensor5D) => tf.batchNorm5d(
x, mean, variance, offset, scale, varianceEpsilon))(variance, dy);
expectArraysClose(await gradVariance.data(), [3.533, 4.686, 1.360, 2.762]);
expect(gradVariance.shape).toEqual([2, 1, 1, 1, 2]);
const gradOffset = tf.grad(
(offset: tf.Tensor5D) => tf.batchNorm5d(
x, mean, variance, offset, scale, varianceEpsilon))(offset, dy);
expectArraysClose(await gradOffset.data(), await dy.data());
expect(gradOffset.shape).toEqual([2, 1, 1, 1, 2]);
const gradScale = tf.grad(
(scale: tf.Tensor5D) => tf.batchNorm5d(
x, mean, variance, offset, scale, varianceEpsilon))(scale, dy);
expectArraysClose(await gradScale.data(), [-7.069, -7.499, -8.164, -8.838]);
expect(gradScale.shape).toEqual([2, 1, 1, 1, 2]);
});
});

describeWithFlags('batchNorm4D', ALL_ENVS, () => {
it('simple batchnorm4D, no offset or scale, 2x1x1x2', async () => {
const xT = tf.tensor4d([2, 4, 9, 23], [2, 1, 1, 2]);
Expand Down
18 changes: 10 additions & 8 deletions tfjs-core/src/ops/batchnorm_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,23 @@
* limitations under the License.
* =============================================================================
*/
import {Tensor, Tensor4D} from '../tensor';
import {Tensor, Tensor5D} from '../tensor';
import {Rank} from '../types';
import {reshape} from './reshape';

export function xAs4D<R extends Rank>(x: Tensor<R>) {
let x4D: Tensor4D;
export function xAs5D<R extends Rank>(x: Tensor<R>) {
let x5D: Tensor5D;
if (x.rank === 0 || x.rank === 1) {
x4D = reshape(x, [1, 1, 1, x.size]);
x5D = reshape(x, [1, 1, 1, 1, x.size]);
} else if (x.rank === 2) {
x4D = reshape(x, [1, 1, x.shape[0], x.shape[1]]);
x5D = reshape(x, [1, 1, 1, x.shape[0], x.shape[1]]);
} else if (x.rank === 3) {
x4D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2]]);
x5D = reshape(x, [1, 1, x.shape[0], x.shape[1], x.shape[2]]);
} else if (x.rank === 4) {
x5D = reshape(x, [1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]]);
} else {
x4D = x as Tensor4D;
x5D = x as Tensor5D;
}

return x4D;
return x5D;
}
1 change: 1 addition & 0 deletions tfjs-core/src/ops/ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export {batchNorm} from './batchnorm';
export {batchNorm2d} from './batchnorm2d';
export {batchNorm3d} from './batchnorm3d';
export {batchNorm4d} from './batchnorm4d';
export {batchNorm5d} from './batchnorm5d';
export {bincount} from './bincount';
export {bitwiseAnd} from './bitwise_and';
export {broadcastArgs} from './broadcast_args';
Expand Down