Skip to content

Commit 96cdec0

Browse files
authored
[jena-weather] Update example to tfjs 1.0 (#255)
- Update the usage of `tf.data.generator()` as per the breaking API change of tfjs 1.0. - Update the usage of `tfvis.render.linechart()` and `tfvis.render.scatterplot()` as per the breaking API changes of tfjs-vis 1.0. - Add tensorboard support for Node.js-based RNN traininga, along with the documentation of that in README.md. - Simplify the callback logic for `Model.fitDataset`. Fixes tensorflow/tfjs#1234
1 parent 4a1f0cf commit 96cdec0

File tree

7 files changed

+211
-318
lines changed

7 files changed

+211
-318
lines changed

jena-weather/README.md

+34-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ TensorFlow.js
2828
training-set and validation-set losses at the end of batches and epochs of
2929
model training.
3030

31-
## Training RNNs
31+
## Training RNNs in Node.js
3232

3333
This example shows how to predict temperature using a few different types of
3434
models, including linear regressors, multilayer perceptrons, and recurrent
@@ -63,4 +63,36 @@ yarn
6363
yarn train-rnn --modelType baseline
6464
```
6565

66-
The training code is in the file [train-rnn.js](./train-rnn.js).
66+
### Monitoring Node.js Training in TensorBoard
67+
68+
The Node.js-based training script allows you to log the loss values from the
69+
model to TensorBoard. Relative to printing loss values to the console, which
70+
the training script performs by default, logging to tensorboard has the
71+
following advantanges:
72+
73+
1. Persistence of the loss values, so you can have a copy of the training
74+
history available even if the system crashes in the middle of the training
75+
for some reason, while logs in consoles a more ephemeral.
76+
2. Visualizing the loss values as curves makes the trends easier to see.
77+
3. You will be able to monitor the training from a remote machine by accessing
78+
the TensorBoard HTTP server.
79+
80+
To do this in this example, add the flag --logDir to the yarn train command,
81+
followed by the directory to which you want the logs to be written, e.g.,
82+
83+
```sh
84+
yarn train-rnn --gpu --logDir /tmp/jena-weather-logs-1
85+
```
86+
87+
Then install tensorboard and start it by pointing it to the log directory:
88+
89+
```sh
90+
# Skip this step if you have already installed tensorboard.
91+
pip install tensorboard
92+
93+
tensorboard --logdir /tmp/jena-weather-logs-1
94+
```
95+
96+
tensorboard will print an HTTP URL in the terminal. Open your browser and
97+
navigate to the URL to view the loss curves in the Scalar dashboard of
98+
TensorBoard.

jena-weather/data.js

+57-57
Original file line numberDiff line numberDiff line change
@@ -275,69 +275,69 @@ export class JenaWeatherData {
275275
let startIndex = minIndex + lookBack;
276276
const lookBackSlices = Math.floor(lookBack / step);
277277

278-
function nextBatchFn() {
279-
const rowIndices = [];
280-
let done = false; // Indicates whether the dataset has ended.
281-
if (shuffle) {
282-
// If `shuffle` is `true`, start from randomly chosen rows.
283-
const range = maxIndex - (minIndex + lookBack);
284-
for (let i = 0; i < batchSize; ++i) {
285-
const row = minIndex + lookBack + Math.floor(Math.random() * range);
286-
rowIndices.push(row);
287-
}
288-
} else {
289-
// If `shuffle` is `false`, the starting row indices will be sequential.
290-
let r = startIndex;
291-
for (; r < startIndex + batchSize && r < maxIndex; ++r) {
292-
rowIndices.push(r);
293-
}
294-
if (r >= maxIndex) {
295-
done = true;
278+
return {
279+
next: () => {
280+
const rowIndices = [];
281+
let done = false; // Indicates whether the dataset has ended.
282+
if (shuffle) {
283+
// If `shuffle` is `true`, start from randomly chosen rows.
284+
const range = maxIndex - (minIndex + lookBack);
285+
for (let i = 0; i < batchSize; ++i) {
286+
const row = minIndex + lookBack + Math.floor(Math.random() * range);
287+
rowIndices.push(row);
288+
}
289+
} else {
290+
// If `shuffle` is `false`, the starting row indices will be sequential.
291+
let r = startIndex;
292+
for (; r < startIndex + batchSize && r < maxIndex; ++r) {
293+
rowIndices.push(r);
294+
}
295+
if (r >= maxIndex) {
296+
done = true;
297+
}
296298
}
297-
}
298299

299-
const numExamples = rowIndices.length;
300-
startIndex += numExamples;
300+
const numExamples = rowIndices.length;
301+
startIndex += numExamples;
301302

302-
const featureLength =
303-
includeDateTime ? this.numColumns + 2 : this.numColumns;
304-
const samples = tf.buffer([numExamples, lookBackSlices, featureLength]);
305-
const targets = tf.buffer([numExamples, 1]);
306-
// Iterate over examples. Each example contains a number of rows.
307-
for (let j = 0; j < numExamples; ++j) {
308-
const rowIndex = rowIndices[j];
309-
let exampleRow = 0;
310-
// Iterate over rows in the example.
311-
for (let r = rowIndex - lookBack; r < rowIndex; r += step) {
312-
let exampleCol = 0;
313-
// Iterate over features in the row.
314-
for (let n = 0; n < featureLength; ++n) {
315-
let value;
316-
if (n < this.numColumns) {
317-
value = normalize ? this.normalizedData[r][n] : this.data[r][n];
318-
} else if (n === this.numColumns) {
319-
// Normalized day-of-the-year feature.
320-
value = this.normalizedDayOfYear[r];
321-
} else {
322-
// Normalized time-of-the-day feature.
323-
value = this.normalizedTimeOfDay[r];
303+
const featureLength =
304+
includeDateTime ? this.numColumns + 2 : this.numColumns;
305+
const samples = tf.buffer([numExamples, lookBackSlices, featureLength]);
306+
const targets = tf.buffer([numExamples, 1]);
307+
// Iterate over examples. Each example contains a number of rows.
308+
for (let j = 0; j < numExamples; ++j) {
309+
const rowIndex = rowIndices[j];
310+
let exampleRow = 0;
311+
// Iterate over rows in the example.
312+
for (let r = rowIndex - lookBack; r < rowIndex; r += step) {
313+
let exampleCol = 0;
314+
// Iterate over features in the row.
315+
for (let n = 0; n < featureLength; ++n) {
316+
let value;
317+
if (n < this.numColumns) {
318+
value = normalize ? this.normalizedData[r][n] : this.data[r][n];
319+
} else if (n === this.numColumns) {
320+
// Normalized day-of-the-year feature.
321+
value = this.normalizedDayOfYear[r];
322+
} else {
323+
// Normalized time-of-the-day feature.
324+
value = this.normalizedTimeOfDay[r];
325+
}
326+
samples.set(value, j, exampleRow, exampleCol++);
324327
}
325-
samples.set(value, j, exampleRow, exampleCol++);
326-
}
327328

328-
const value = normalize ?
329-
this.normalizedData[r + delay][this.tempCol] :
330-
this.data[r + delay][this.tempCol];
331-
targets.set(value, j, 0);
332-
exampleRow++;
329+
const value = normalize ?
330+
this.normalizedData[r + delay][this.tempCol] :
331+
this.data[r + delay][this.tempCol];
332+
targets.set(value, j, 0);
333+
exampleRow++;
334+
}
333335
}
336+
return {
337+
value: {xs: samples.toTensor(), ys: targets.toTensor()},
338+
done
339+
};
334340
}
335-
return {
336-
value: [samples.toTensor(), targets.toTensor()],
337-
done
338-
};
339-
}
340-
341-
return nextBatchFn.bind(this);
341+
};
342342
}
343343
}

jena-weather/index.js

+4-6
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
/**
1919
* Weather Prediction Example.
20-
*
20+
*
2121
* - Visualizes data using tfjs-vis.
2222
* - Trains simple models (linear regressor and MLPs) and visualizes the
2323
* training processes.
@@ -102,7 +102,7 @@ function makeTimeSeriesChart(
102102
}
103103
// NOTE(cais): On a Linux workstation running latest Chrome, the length
104104
// limit seems to be around 120k.
105-
tfvis.render.linechart({values, series: series}, chartConatiner, {
105+
tfvis.render.linechart(chartConatiner, {values, series: series}, {
106106
width: chartConatiner.offsetWidth * 0.95,
107107
height: chartConatiner.offsetWidth * 0.3,
108108
xLabel: 'Time',
@@ -141,7 +141,7 @@ function makeTimeSeriesScatterPlot(series1, series2, timeSpan, normalize) {
141141
}
142142
const series = [`${seriesLabel1} - ${seriesLabel2}`];
143143

144-
tfvis.render.scatterplot({values, series}, dataChartContainer, {
144+
tfvis.render.scatterplot(dataChartContainer, {values, series}, {
145145
width: dataChartContainer.offsetWidth * 0.7,
146146
height: dataChartContainer.offsetWidth * 0.5,
147147
xLabel: seriesLabel1,
@@ -160,7 +160,6 @@ trainModelButton.addEventListener('click', async () => {
160160
const batchSize = 128;
161161
const normalize = true;
162162
const includeDateTime = includeDateTimeSelect.checked;
163-
164163
const modelType = modelTypeSelect.value;
165164

166165
console.log('Creating model...');
@@ -177,10 +176,9 @@ trainModelButton.addEventListener('click', async () => {
177176

178177
console.log('Starting model training...');
179178
const epochs = +epochsInput.value;
180-
const displayEvery = 100;
181179
await trainModel(
182180
model, jenaWeatherData, normalize, includeDateTime,
183-
lookBack, step, delay, batchSize, epochs, displayEvery,
181+
lookBack, step, delay, batchSize, epochs,
184182
tfvis.show.fitCallbacks(trainingSurface, ['loss', 'val_loss'], {
185183
callbacks: ['onBatchEnd', 'onEpochEnd']
186184
}));

jena-weather/models.js

+17-57
Original file line numberDiff line numberDiff line change
@@ -210,68 +210,28 @@ export function buildModel(modelType, numTimeSteps, numFeatures) {
210210
* for.
211211
* @param {number} batchSize batchSize for training.
212212
* @param {number} epochs Number of training epochs.
213-
* @param {number} displayEvery Log info to console every _ batches.
214-
* @param {number} customCallbacks Optional callback args to invoke at the
215-
* end of every epoch. Can optionally have `onBatchEnd` and `onEpochEnd`
216-
* fields.
213+
* @param {tf.Callback | tf.CustomCallbackArgs} customCallback Optional callback
214+
* to invoke at the end of every epoch. Can optionally have `onBatchEnd` and
215+
* `onEpochEnd` fields.
217216
*/
218217
export async function trainModel(
219218
model, jenaWeatherData, normalize, includeDateTime, lookBack, step, delay,
220-
batchSize, epochs, displayEvery = 100, customCallbacks) {
221-
const shuffle = true;
219+
batchSize, epochs, customCallback) {
220+
const trainShuffle = true;
221+
const trainDataset = tf.data.generator(
222+
() => jenaWeatherData.getNextBatchFunction(
223+
trainShuffle, lookBack, delay, batchSize, step, TRAIN_MIN_ROW,
224+
TRAIN_MAX_ROW, normalize, includeDateTime)).prefetch(8);
225+
const evalShuffle = false;
226+
const valDataset = tf.data.generator(
227+
() => jenaWeatherData.getNextBatchFunction(
228+
evalShuffle, lookBack, delay, batchSize, step, VAL_MIN_ROW,
229+
VAL_MAX_ROW, normalize, includeDateTime));
222230

223-
const trainNextBatchFn = jenaWeatherData.getNextBatchFunction(
224-
shuffle, lookBack, delay, batchSize, step, TRAIN_MIN_ROW, TRAIN_MAX_ROW,
225-
normalize, includeDateTime);
226-
const trainDataset = tf.data.generator(trainNextBatchFn).prefetch(8);
227-
228-
const batchesPerEpoch = 500;
229-
let t0;
230-
let currentEpoch;
231231
await model.fitDataset(trainDataset, {
232-
batchesPerEpoch,
232+
batchesPerEpoch: 500,
233233
epochs,
234-
callbacks: {
235-
onEpochBegin: async (epoch) => {
236-
currentEpoch = epoch;
237-
t0 = tf.util.now();
238-
},
239-
onBatchEnd: async (batch, logs) => {
240-
if ((batch + 1) % displayEvery === 0) {
241-
const t = tf.util.now();
242-
const millisPerBatch = (t - t0) / (batch + 1);
243-
console.log(
244-
`epoch ${currentEpoch + 1}/${epochs} ` +
245-
`batch ${batch + 1}/${batchesPerEpoch}: ` +
246-
`loss=${logs.loss.toFixed(6)} ` +
247-
`(${millisPerBatch.toFixed(1)} ms/batch)`);
248-
if (customCallbacks && customCallbacks.onBatchEnd) {
249-
customCallbacks.onBatchEnd(batch, logs);
250-
}
251-
}
252-
},
253-
onEpochEnd: async (epoch, logs) => {
254-
const valNextBatchFn = jenaWeatherData.getNextBatchFunction(
255-
false, lookBack, delay, batchSize, step, VAL_MIN_ROW, VAL_MAX_ROW,
256-
normalize, includeDateTime);
257-
const valDataset = tf.data.generator(valNextBatchFn);
258-
console.log(`epoch ${epoch + 1}/${epochs}: Performing validation...`);
259-
// TODO(cais): Remove the second arg (empty object), when the bug is
260-
// fixed:
261-
// https://github.com/tensorflow/tfjs/issues/1096
262-
const evalOut = await model.evaluateDataset(valDataset, {});
263-
logs.val_loss = (await evalOut.data())[0];
264-
tf.dispose(evalOut);
265-
console.log(
266-
`epoch ${epoch + 1}/${epochs}: ` +
267-
`trainLoss=${logs.loss.toFixed(6)}; ` +
268-
`valLoss=${logs.val_loss.toFixed(6)}`);
269-
if (customCallbacks && customCallbacks.onEpochEnd) {
270-
customCallbacks.onEpochEnd(epoch, logs);
271-
}
272-
}
273-
}
234+
callbacks: customCallback,
235+
validationData: valDataset
274236
});
275-
276-
return model;
277237
}

jena-weather/package.json

+5-5
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,18 @@
99
"node": ">=8.9.0"
1010
},
1111
"dependencies": {
12-
"@tensorflow/tfjs": "^0.15.3",
13-
"@tensorflow/tfjs-vis": "0.4.2"
12+
"@tensorflow/tfjs": "^1.0.2",
13+
"@tensorflow/tfjs-vis": "^1.0.3"
1414
},
1515
"scripts": {
1616
"watch": "cross-env NODE_ENV=development parcel index.html --no-hmr --open",
1717
"build": "cross-env NODE_ENV=production parcel build index.html --no-minify --public-url ./",
1818
"link-local": "yalc link",
19-
"train-rnn": "babel-node train-rnn.js"
19+
"train-rnn": "babel-node --max_old_space_size=4096 train-rnn.js"
2020
},
2121
"devDependencies": {
22-
"@tensorflow/tfjs-node": "0.3.1",
23-
"@tensorflow/tfjs-node-gpu": "0.3.1",
22+
"@tensorflow/tfjs-node": "^1.0.2",
23+
"@tensorflow/tfjs-node-gpu": "^1.0.2",
2424
"argparse": "^1.0.10",
2525
"babel-cli": "^6.26.0",
2626
"babel-core": "^6.26.3",

0 commit comments

Comments
 (0)