Move minDepth and maxDepth to estimation config (tensorflow#1014)
ahmedsabie authored May 10, 2022
1 parent 4cc7c20 commit cb34be9
Showing 10 changed files with 67 additions and 61 deletions.
12 changes: 6 additions & 6 deletions depth-estimation/README.md
@@ -33,17 +33,17 @@ For example:

```javascript
const model = depthEstimation.SupportedModels.ARPortraitDepth;
-const estimatorConfig = {
-  minDepth: 0,
-  maxDepth: 1,
-}
-const estimator = await depthEstimation.createEstimator(model, estimatorConfig);
+const estimator = await depthEstimation.createEstimator(model);
```

Next, you can use the estimator to estimate depth.

```javascript
-const depthMap = await estimator.estimateDepth(image);
+const estimationConfig = {
+  minDepth: 0,
+  maxDepth: 1,
+}
+const depthMap = await estimator.estimateDepth(image, estimationConfig);
```

The returned depth map contains depth values for each pixel in the image.
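The README change above moves `minDepth` and `maxDepth` from `createEstimator` to each `estimateDepth` call. As a mental model of what the two bounds do, here is a clamp-and-rescale sketch inferred from the option descriptions in this commit — the `normalizeDepth` helper is illustrative only and not part of the package:

```javascript
// Hypothetical sketch of the minDepth/maxDepth contract described in this
// commit: raw depth values are rescaled so minDepth maps to 0 and maxDepth
// maps to 1, and values outside that range clamp to the nearest bound.
function normalizeDepth(rawDepth, minDepth, maxDepth) {
  if (minDepth > maxDepth) {
    throw new Error('minDepth must be <= maxDepth.');
  }
  const scaled = (rawDepth - minDepth) / (maxDepth - minDepth);
  return Math.min(1, Math.max(0, scaled));
}

// With the README's minDepth: 0 and maxDepth: 1:
console.log(normalizeDepth(0.5, 0, 1));   // 0.5
console.log(normalizeDepth(-0.25, 0, 1)); // 0 (below minDepth, clamped)
console.log(normalizeDepth(3, 0, 1));     // 1 (above maxDepth, clamped)
```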
2 changes: 1 addition & 1 deletion depth-estimation/package.json
@@ -1,6 +1,6 @@
{
"name": "@tensorflow-models/depth-estimation",
-  "version": "0.0.1",
+  "version": "0.0.2",
"description": "Pretrained depth model",
"main": "dist/index.js",
"jsnext:main": "dist/depth-estimation.esm.js",
12 changes: 6 additions & 6 deletions depth-estimation/src/ar_portrait_depth/README.md
@@ -62,12 +62,6 @@ Pass in `depthEstimation.SupportedModels.ARPortraitDepth` from the

`estimatorConfig` is an object that defines ARPortraitDepth specific configurations for `ARPortraitDepthModelConfig`:

-* *minDepth*: The minimum depth value for the model to map to 0. Any smaller
-depth values will also get mapped to 0.
-
-* *maxDepth*: The maximum depth value for the model to map to 1. Any larger
-depth values will also get mapped to 1.

* *segmentationModelUrl*: An optional string that specifies custom url of
the segmenter model. This is useful for area/countries that don't have access to the model hosted on tf.hub. It also accepts `io.IOHandler` which can be used with
[tfjs-react-native](https://github.com/tensorflow/tfjs/tree/master/tfjs-react-native)
@@ -97,6 +91,12 @@ options, you can pass in a second `estimationConfig` parameter.

`estimationConfig` is an object that defines ARPortraitDepth specific configurations for `ARPortraitDepthEstimationConfig`:

+* *minDepth*: The minimum depth value for the model to map to 0. Any smaller
+depth values will also get mapped to 0.
+
+* *maxDepth*: The maximum depth value for the model to map to 1. Any larger
+depth values will also get mapped to 1.

* *flipHorizontal*: Optional. Defaults to false. When image data comes from camera, the result has to flip horizontally.

The following code snippet demonstrates how to run the model inference:
26 changes: 15 additions & 11 deletions depth-estimation/src/ar_portrait_depth/ar_portrait_depth_test.ts
@@ -114,13 +114,13 @@ describeWithFlags('ARPortraitDepth', ALL_ENVS, () => {

// Note: this makes a network request for model assets.
const estimator = await depthEstimation.createEstimator(
-        depthEstimation.SupportedModels.ARPortraitDepth,
-        {minDepth: 0, maxDepth: 1});
+        depthEstimation.SupportedModels.ARPortraitDepth);
const input: tf.Tensor3D = tf.zeros([128, 128, 3]);

const beforeTensors = tf.memory().numTensors;

-    const depthMap = await estimator.estimateDepth(input);
+    const depthMap =
+        await estimator.estimateDepth(input, {minDepth: 0, maxDepth: 1});

(await depthMap.toTensor()).dispose();
expect(tf.memory().numTensors).toEqual(beforeTensors);
@@ -133,21 +133,25 @@ describeWithFlags('ARPortraitDepth', ALL_ENVS, () => {

it('throws error when minDepth is not set.', async (done) => {
try {
-      await depthEstimation.createEstimator(
+      const estimator = await depthEstimation.createEstimator(
          depthEstimation.SupportedModels.ARPortraitDepth);
+      const input: tf.Tensor3D = tf.zeros([128, 128, 3]);
+      await estimator.estimateDepth(input);
done.fail('Loading without minDepth succeeded unexpectedly.');
} catch (e) {
expect(e.message).toEqual(
-          'A model config with minDepth and maxDepth set must be provided.');
+          'An estimation config with ' +
+          'minDepth and maxDepth set must be provided.');
done();
}
});

it('throws error when minDepth is greater than maxDepth.', async (done) => {
try {
-      await depthEstimation.createEstimator(
-          depthEstimation.SupportedModels.ARPortraitDepth,
-          {minDepth: 1, maxDepth: 0.99});
+      const estimator = await depthEstimation.createEstimator(
+          depthEstimation.SupportedModels.ARPortraitDepth);
+      const input: tf.Tensor3D = tf.zeros([128, 128, 3]);
+      await estimator.estimateDepth(input, {minDepth: 1, maxDepth: 0.99});
done.fail(
'Loading with minDepth greater than maxDepth ' +
'succeeded unexpectedly.');
@@ -187,12 +191,12 @@ describeWithFlags('ARPortraitDepth static image ', BROWSER_ENVS, () => {
// Get actual depth values.
// Note: this makes a network request for model assets.
estimator = await depthEstimation.createEstimator(
-        depthEstimation.SupportedModels.ARPortraitDepth,
-        {minDepth: 0.2, maxDepth: 0.9});
+        depthEstimation.SupportedModels.ARPortraitDepth);

const beforeTensors = tf.memory().numTensors;

-    const result = await estimator.estimateDepth(image);
+    const result =
+        await estimator.estimateDepth(image, {minDepth: 0.2, maxDepth: 0.9});
const actualDepthValues = await result.toTensor();
const coloredDepthValues =
actualDepthValues.arraySync().flat().map(value => turboPlus(value));
12 changes: 8 additions & 4 deletions depth-estimation/src/ar_portrait_depth/constants.ts
@@ -14,12 +14,16 @@
* limitations under the License.
* =============================================================================
*/
-import {ARPortraitDepthEstimationConfig} from './types';
+import {ARPortraitDepthModelConfig} from './types';

export const DEFAULT_AR_PORTRAIT_DEPTH_MODEL_URL =
'https://tfhub.dev/tensorflow/tfjs-model/ar_portrait_depth/1';

-export const DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG:
-    ARPortraitDepthEstimationConfig = {
-      flipHorizontal: false,
+export const DEFAULT_AR_PORTRAIT_DEPTH_MODEL_CONFIG:
+    ARPortraitDepthModelConfig = {
+      depthModelUrl: DEFAULT_AR_PORTRAIT_DEPTH_MODEL_URL,
    };
+
+export const DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG = {
+  flipHorizontal: false,
+};
16 changes: 10 additions & 6 deletions depth-estimation/src/ar_portrait_depth/estimator.ts
@@ -56,8 +56,7 @@ const PORTRAIT_WIDTH = 192;
class ARPortraitDepthEstimator implements DepthEstimator {
constructor(
private readonly segmenter: bodySegmentation.BodySegmenter,
-      private readonly estimatorModel: tfconv.GraphModel,
-      private readonly minDepth: number, private readonly maxDepth: number) {}
+      private readonly estimatorModel: tfconv.GraphModel) {}

/**
* Estimates depth for an image or video frame.
@@ -69,8 +68,14 @@ class ARPortraitDepthEstimator implements DepthEstimator {
* image to feed through the network.
*
* @param config Optional.
+   * minDepth: The minimum depth value for the model to map to 0. Any
+   * smaller depth values will also get mapped to 0.
+   *
+   * maxDepth: The maximum depth value for the model to map to 1. Any
+   * larger depth values will also get mapped to 1.
+   *
   * flipHorizontal: Optional. Default to false. When image data comes
   * from camera, the result has to flip horizontally.
*
* @return `DepthMap`.
*/
@@ -128,7 +133,7 @@ class ARPortraitDepthEstimator implements DepthEstimator {

// Normalize to user requirements.
const depthTransform =
-        transformValueRange(this.minDepth, this.maxDepth, 0, 1);
+        transformValueRange(config.minDepth, config.maxDepth, 0, 1);

// depth4D is roughly in [0,2] range, so half the scale factor to put it
// in [0,1] range.
@@ -188,6 +193,5 @@ export async function load(modelConfig: ARPortraitDepthModelConfig):
bodySegmentation.SupportedModels.MediaPipeSelfieSegmentation,
{runtime: 'tfjs', modelUrl: config.segmentationModelUrl});

-  return new ARPortraitDepthEstimator(
-      segmenter, depthModel, config.minDepth, config.maxDepth);
+  return new ARPortraitDepthEstimator(segmenter, depthModel);
}
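The estimator normalizes depth via `transformValueRange(config.minDepth, config.maxDepth, 0, 1)`. A standalone sketch of what such a range-mapping helper plausibly computes — the scale/offset return shape here is an assumption for illustration, not the actual tfjs-models shared utility:

```javascript
// Hypothetical re-implementation of a transformValueRange-style helper:
// returns the scale and offset that map [fromMin, fromMax] onto [toMin, toMax].
function transformValueRange(fromMin, fromMax, toMin, toMax) {
  const fromRange = fromMax - fromMin;
  if (fromRange === 0) {
    throw new Error('Original min and max are both ' + fromMin);
  }
  const scale = (toMax - toMin) / fromRange;
  const offset = toMin - fromMin * scale;
  return {scale, offset};
}

// Mapping depth values in [minDepth, maxDepth] = [0.25, 0.75] onto [0, 1]:
const {scale, offset} = transformValueRange(0.25, 0.75, 0, 1);
console.log(0.25 * scale + offset); // 0 (a value at minDepth lands on 0)
console.log(0.75 * scale + offset); // 1 (a value at maxDepth lands on 1)
```

This is also why the surrounding diff halves the scale factor: the raw depth output is roughly in [0, 2], so the computed scale is divided by two before being applied.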
25 changes: 13 additions & 12 deletions depth-estimation/src/ar_portrait_depth/estimator_utils.ts
@@ -15,34 +15,35 @@
* =============================================================================
*/

-import {DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG, DEFAULT_AR_PORTRAIT_DEPTH_MODEL_URL} from './constants';
+import {DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG, DEFAULT_AR_PORTRAIT_DEPTH_MODEL_CONFIG} from './constants';
import {ARPortraitDepthEstimationConfig, ARPortraitDepthModelConfig} from './types';

export function validateModelConfig(modelConfig: ARPortraitDepthModelConfig):
ARPortraitDepthModelConfig {
-  if (modelConfig == null || modelConfig.minDepth == null ||
-      modelConfig.maxDepth == null) {
-    throw new Error(
-        `A model config with minDepth and maxDepth set must be provided.`);
-  }
-
-  if (modelConfig.minDepth > modelConfig.maxDepth) {
-    throw new Error('minDepth must be <= maxDepth.');
+  if (modelConfig == null) {
+    return {...DEFAULT_AR_PORTRAIT_DEPTH_MODEL_CONFIG};
  }

const config = {...modelConfig};

if (config.depthModelUrl == null) {
-    config.depthModelUrl = DEFAULT_AR_PORTRAIT_DEPTH_MODEL_URL;
+    config.depthModelUrl = DEFAULT_AR_PORTRAIT_DEPTH_MODEL_CONFIG.depthModelUrl;
}

return config;
}

export function validateEstimationConfig(
estimationConfig: ARPortraitDepthEstimationConfig) {
-  if (estimationConfig == null) {
-    return {...DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG};
+  if (estimationConfig == null || estimationConfig.minDepth == null ||
+      estimationConfig.maxDepth == null) {
+    throw new Error(
+        'An estimation config with ' +
+        'minDepth and maxDepth set must be provided.');
  }
+
+  if (estimationConfig.minDepth > estimationConfig.maxDepth) {
+    throw new Error('minDepth must be <= maxDepth.');
+  }

const config = {...estimationConfig};
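The new `validateEstimationConfig` logic above is small enough to exercise standalone. A minimal sketch mirroring the diff — the `DEFAULTS` stub and the final defaults merge are assumptions for illustration; the real function spreads `DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG` from `constants.ts`:

```javascript
// Stub standing in for DEFAULT_AR_PORTRAIT_DEPTH_ESTIMATION_CONFIG.
const DEFAULTS = {flipHorizontal: false};

// Mirrors the new validateEstimationConfig: minDepth and maxDepth are now
// required on every estimateDepth call, and minDepth may not exceed maxDepth.
function validateEstimationConfig(estimationConfig) {
  if (estimationConfig == null || estimationConfig.minDepth == null ||
      estimationConfig.maxDepth == null) {
    throw new Error(
        'An estimation config with minDepth and maxDepth set must be provided.');
  }
  if (estimationConfig.minDepth > estimationConfig.maxDepth) {
    throw new Error('minDepth must be <= maxDepth.');
  }
  // Assumed defaults merge; the real code copies the config and fills
  // missing optional fields from the defaults object.
  return {...DEFAULTS, ...estimationConfig};
}

console.log(validateEstimationConfig({minDepth: 0, maxDepth: 1}));
// { flipHorizontal: false, minDepth: 0, maxDepth: 1 }
```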
6 changes: 0 additions & 6 deletions depth-estimation/src/ar_portrait_depth/types.ts
@@ -22,12 +22,6 @@ import {EstimationConfig, ModelConfig} from '../types';
/**
* Model parameters for ARPortraitDepth.
*
- * `minDepth`: The minimum depth value for the model to map to 0. Any smaller
- * depth values will also get mapped to 0.
- *
- * `maxDepth`: The maximum depth value for the model to map to 1. Any larger
- * depth values will also get mapped to 1.
- *
* `segmentationModelUrl`: Optional. An optional string that specifies custom
* url of the selfie segmentation model. This is useful for area/countries that
* don't have access to the model hosted on tf.hub.
15 changes: 7 additions & 8 deletions depth-estimation/src/types.ts
@@ -24,26 +24,25 @@ export enum SupportedModels {

/**
* Common config to create the depth estimator.
+ */
+export interface ModelConfig {}
+
+/**
+ * Common config for the `estimateDepth` method.
*
* `minDepth`: The minimum depth value for the model to map to 0. Any smaller
* depth values will also get mapped to 0.
*
* `maxDepth`: The maximum depth value for the model to map to 1. Any larger
* depth values will also get mapped to 1.
- */
-export interface ModelConfig {
-  minDepth: number;
-  maxDepth: number;
-}
-
-/**
- * Common config for the `estimateDepth` method.
*
* `flipHorizontal`: Optional. Default to false. In some cases, the image is
* mirrored, e.g., video stream from camera, flipHorizontal will flip the
* keypoints horizontally.
*/
export interface EstimationConfig {
+  minDepth: number;
+  maxDepth: number;
flipHorizontal?: boolean;
}

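With this change, `EstimationConfig` requires both depth bounds while `flipHorizontal` stays optional. A small runtime shape check in plain JavaScript — `isEstimationConfig` is illustrative only; the package relies on TypeScript's compile-time checking instead:

```javascript
// Illustrative runtime check matching the new EstimationConfig interface:
// minDepth and maxDepth are required numbers, flipHorizontal is optional.
function isEstimationConfig(c) {
  return c != null && typeof c.minDepth === 'number' &&
      typeof c.maxDepth === 'number' &&
      (c.flipHorizontal === undefined || typeof c.flipHorizontal === 'boolean');
}

console.log(isEstimationConfig({minDepth: 0, maxDepth: 1})); // true
console.log(isEstimationConfig({flipHorizontal: true}));     // false: bounds now required
```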
2 changes: 1 addition & 1 deletion depth-estimation/src/version.ts
@@ -1,5 +1,5 @@
/** @license See the LICENSE file. */

// This code is auto-generated, do not modify this file!
-const version = '0.0.1';
+const version = '0.0.2';
export {version};
