Skip to content

Commit 12c6823

Browse files
WenheLIdsmilkov
authored andcommitted
Add webworker example (#322)
This PR adds an example that trains an rnn network under webworker. In this way, the whole training process will not block UI.
1 parent c402eff commit 12c6823

File tree

8 files changed

+7503
-1
lines changed

8 files changed

+7503
-1
lines changed

addition-rnn-webworker/.babelrc

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"presets": [
3+
[
4+
"env",
5+
{
6+
"esmodules": false,
7+
"targets": {
8+
"browsers": [
9+
"> 3%"
10+
]
11+
}
12+
}
13+
]
14+
],
15+
"plugins": [
16+
"transform-runtime"
17+
]
18+
}

addition-rnn-webworker/README.md

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# TensorFlow.js Example: Addition RNN in Webworker
2+
3+
This example uses an RNN to compute (in a worker thread) the addition of two integers by doing
4+
string => string translation. Obviously it's not the best way to add two
5+
numbers, but it makes a fun example. In this way, we can do long-running computation without blocking the UI thread.
6+
7+
Note: This example is based on the addition-rnn [example](https://github.com/tensorflow/tfjs-examples/tree/master/addition-rnn) in this repo, which is based on the original Keras python code [here](https://github.com/keras-team/keras/blob/master/examples/addition_rnn.py)
8+
9+
[See this example live!](https://storage.googleapis.com/tfjs-examples/addition-rnn/dist/index.html)

addition-rnn-webworker/index.html

+157
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
<!--
2+
Copyright 2019 Google LLC. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
-->
17+
18+
<!doctype html>
19+
20+
<head>
21+
<meta charset="UTF-8">
22+
<meta name="viewport" content="width=device-width, initial-scale=1">
23+
<link rel="stylesheet" href="../shared/tfjs-examples.css" />
24+
</head>
25+
26+
<style>
27+
.setting {
28+
padding: 6px;
29+
}
30+
31+
#trainModel {
32+
margin-top: 12px;
33+
}
34+
35+
.setting-label {
36+
display: inline-block;
37+
width: 12em;
38+
}
39+
40+
.answer-correct {
41+
color: green;
42+
}
43+
44+
.answer-wrong {
45+
color: red;
46+
}
47+
</style>
48+
49+
<body>
50+
<div class='tfjs-example-container centered-container'>
51+
<section class='title-area'>
52+
<h1>TensorFlow.js: Addition RNN in web worker</h1>
53+
<p class='subtitle'>Train a model in web worker</p>
54+
</section>
55+
<section>
56+
<p class='section-head'>Description</p>
57+
<p>
58+
This example trains a <a href="https://en.wikipedia.org/wiki/Recurrent_neural_network">Recurrent Neural Network</a>
59+
in <a href="https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Using_web_workers">web worker</a>
60+
to do addition without explicitly defining the addition operator. Instead
61+
we feed it examples of sums and let it learn from that.
62+
</p>
63+
<p>
64+
In this way, we can do long-running computation without actually blocking the UI rendering thread.
65+
</p>
66+
<p>
67+
Given a <span class='in-type'>string</span> like
68+
<span class='in-example'>3 + 4</span>, it will learn to output
69+
a <span class='out-type'>number</span>
70+
like <span class=out-example>7</span>.
71+
</p>
72+
<p>
73+
This example generates its own training data programatically.
74+
</p>
75+
</section>
76+
77+
<div>
78+
<section>
79+
<p class='section-head'>Instructions</p>
80+
<p>
81+
Click the "Train Model" to start the model training button. You can edit the
82+
parameters used to train the model as well.
83+
</p>
84+
</section>
85+
86+
<div class="controls with-rows">
87+
<div class="settings">
88+
<div class="setting">
89+
<span class="setting-label">Digits:</span>
90+
<input id="digits" value="2"></input>
91+
</div>
92+
<div class="setting">
93+
<span class="setting-label">Training Size:</span>
94+
<input id="trainingSize" value="5000"></input>
95+
</div>
96+
<div class="setting">
97+
<span class="setting-label">RNN Type:</span>
98+
<select id="rnnType">
99+
<option value="SimpleRNN">SimpleRNN</option>
100+
<option value="GRU">GRU</option>
101+
<option value="LSTM">LSTM</option>
102+
</select>
103+
</div>
104+
<div class="setting">
105+
<span class="setting-label">RNN Layers:</span>
106+
<input id="rnnLayers" value="1"></input>
107+
</div>
108+
<div class="setting">
109+
<span class="setting-label">RNN Hidden Layer Size:</span>
110+
<input id="rnnLayerSize" value="128"></input>
111+
</div>
112+
<div class="setting">
113+
<span class="setting-label">Batch Size:</span>
114+
<input id="batchSize" value="128"></input>
115+
</div>
116+
<div class="setting">
117+
<span class="setting-label">Train Iterations:</span>
118+
<input id="trainIterations" value="100"></input>
119+
</div>
120+
<div class="setting">
121+
<span class="setting-label"># of test examples:</span>
122+
<input id="numTestExamples" value="20"></input>
123+
</div>
124+
</div>
125+
126+
<div>
127+
<span>
128+
<button class="btn-primary" id="trainModel">Train Model</button>
129+
</span>
130+
</div>
131+
</div>
132+
133+
134+
<section>
135+
<p class='section-head'>Training Progress</p>
136+
<p id="trainStatus"></p>
137+
<div class='with-cols'>
138+
<div id="lossChart"></div>
139+
<div id="accuracyChart"></div>
140+
</div>
141+
</section>
142+
143+
<section>
144+
<p class='section-head'>Test Examples</p>
145+
<p id="testExamples"></p>
146+
</section>
147+
148+
149+
150+
151+
</div>
152+
</div>
153+
154+
</body>
155+
<script src="./index.js">
156+
157+
</script>

addition-rnn-webworker/index.js

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
/**
19+
* Addition RNN example.
20+
*
21+
* Based on tfjs example:
22+
* https://github.com/tensorflow/tfjs-examples/tree/master/addition-rnn
23+
*/
24+
25+
import * as tfvis from '@tensorflow/tfjs-vis';
26+
const worker = new Worker('./worker.js');
27+
28+
async function runAdditionRNNDemo() {
29+
document.getElementById('trainModel').addEventListener('click', async () => {
30+
const digits = +(document.getElementById('digits')).value;
31+
const trainingSize = +(document.getElementById('trainingSize')).value;
32+
const rnnTypeSelect = document.getElementById('rnnType');
33+
const rnnType =
34+
rnnTypeSelect.options[rnnTypeSelect.selectedIndex].getAttribute(
35+
'value');
36+
const layers = +(document.getElementById('rnnLayers')).value;
37+
const hiddenSize = +(document.getElementById('rnnLayerSize')).value;
38+
const batchSize = +(document.getElementById('batchSize')).value;
39+
const trainIterations = +(document.getElementById('trainIterations')).value;
40+
const numTestExamples = +(document.getElementById('numTestExamples')).value;
41+
42+
// Do some checks on the user-specified parameters.
43+
const status = document.getElementById('trainStatus');
44+
if (digits < 1 || digits > 5) {
45+
status.textContent = 'digits must be >= 1 and <= 5';
46+
return;
47+
}
48+
const trainingSizeLimit = Math.pow(Math.pow(10, digits), 2);
49+
if (trainingSize > trainingSizeLimit) {
50+
status.textContent =
51+
`With digits = ${digits}, you cannot have more than ` +
52+
`${trainingSizeLimit} examples`;
53+
return;
54+
}
55+
worker.postMessage({ digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples });
56+
worker.addEventListener('message', (e) => {
57+
if (e.data.isPredict) {
58+
const { i, iterations, modelFitTime, lossValues, accuracyValues } = e.data;
59+
document.getElementById('trainStatus').textContent =
60+
`Iteration ${i + 1} of ${iterations}: ` +
61+
`Time per iteration: ${modelFitTime.toFixed(3)} (seconds)`;
62+
const lossContainer = document.getElementById('lossChart');
63+
tfvis.render.linechart(
64+
lossContainer, { values: lossValues, series: ['train', 'validation'] },
65+
{
66+
width: 420,
67+
height: 300,
68+
xLabel: 'epoch',
69+
yLabel: 'loss',
70+
});
71+
72+
const accuracyContainer = document.getElementById('accuracyChart');
73+
tfvis.render.linechart(
74+
accuracyContainer,
75+
{ values: accuracyValues, series: ['train', 'validation'] }, {
76+
width: 420,
77+
height: 300,
78+
xLabel: 'epoch',
79+
yLabel: 'accuracy',
80+
});
81+
} else {
82+
const { isCorrect, examples } = e.data;
83+
const examplesDiv = document.getElementById('testExamples');
84+
const examplesContent = examples.map(
85+
(example, i) =>
86+
`<div class="${
87+
isCorrect[i] ? 'answer-correct' : 'answer-wrong'}">` +
88+
`${example}` +
89+
`</div>`);
90+
91+
examplesDiv.innerHTML = examplesContent.join('\n');
92+
}
93+
});
94+
});
95+
}
96+
97+
runAdditionRNNDemo();

addition-rnn-webworker/package.json

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
{
2+
"name": "tfjs-examples-addition-rnn",
3+
"version": "0.1.0",
4+
"description": "",
5+
"main": "index.js",
6+
"license": "Apache-2.0",
7+
"private": true,
8+
"engines": {
9+
"node": ">=8.9.0"
10+
},
11+
"dependencies": {
12+
"@tensorflow/tfjs": "^1.2.6",
13+
"@tensorflow/tfjs-vis": "^1.1.0"
14+
},
15+
"scripts": {
16+
"watch": "cross-env NODE_ENV=development parcel index.html --no-hmr --open",
17+
"build": "cross-env NODE_ENV=production parcel build index.html --no-minify --public-url ./",
18+
"link-local": "yalc link"
19+
},
20+
"devDependencies": {
21+
"babel-core": "^6.26.3",
22+
"babel-plugin-transform-runtime": "~6.23.0",
23+
"babel-polyfill": "~6.26.0",
24+
"babel-preset-env": "~1.6.1",
25+
"clang-format": "~1.2.2",
26+
"cross-env": "^5.1.6",
27+
"parcel-bundler": "~1.10.3",
28+
"yalc": "~1.0.0-pre.22"
29+
}
30+
}

0 commit comments

Comments
 (0)