Skip to content

Commit fea6cfb

Browse files
manrajgrovercaisq
authored andcommitted
Neural Network Regression Example (#111)
* Neural Network Regression: Base neural network regression completed
1 parent 48b0c58 commit fea6cfb

File tree

3 files changed

+92
-26
lines changed

3 files changed

+92
-26
lines changed

boston-housing/index.html

+21-4
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,32 @@
2323
textarea {
2424
border: none;
2525
resize: none;
26+
font-size: 20px;
27+
text-align: center;
28+
}
29+
.container {
30+
width: 80%;
31+
margin-left: auto;
32+
margin-right: auto;
33+
text-align: center;
34+
}
35+
#plot {
36+
width: auto;
37+
height: auto;
2638
}
2739
</style>
2840
</head>
2941

3042
<body>
31-
<h3>TensorFlow.js: Train Multivariate Linear Regression with the Core API</h3>
32-
33-
<div id="plot"></div>
34-
<textarea rows="3" cols="60" readonly="true" id="status">Loading data...</textarea>
43+
<div class="container">
44+
<h3>TensorFlow.js: Train Multivariate Regression with the Layers API</h3>
45+
<div id="plot"></div>
46+
<textarea rows="3" cols="60" readonly="true" id="status">Loading data...</textarea>
47+
<div id="buttons">
48+
<button id="simple-mlr">Train Linear Regressor</button>
49+
<button id="nn-mlr">Train Neural Network Regressor</button>
50+
</div>
51+
</div>
3552
<script src="index.js"></script>
3653
</body>
3754

boston-housing/index.js

+49-11
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,58 @@ import * as ui from './ui';
2222

2323
// Some hyperparameters for model training.
2424
const NUM_EPOCHS = 250;
25-
const BATCH_SIZE = 50;
25+
const BATCH_SIZE = 40;
2626
const LEARNING_RATE = 0.01;
2727

2828
const data = new BostonHousingDataset();
29-
data.loadData().then(async () => {
29+
30+
/**
31+
* Builds and returns Linear Regression Model.
32+
*
33+
* @returns {tf.Sequential} The linear regression model.
34+
*/
35+
export const linearRegressionModel = () => {
36+
const model = tf.sequential();
37+
model.add(tf.layers.dense({inputShape: [data.numFeatures], units: 1}));
38+
39+
return model;
40+
};
41+
42+
/**
43+
* Builds and returns Multi Layer Perceptron Regression Model
44+
* with 2 hidden layers, each with 10 units activated by sigmoid.
45+
*
46+
* @returns {tf.Sequential} The multi layer perceptron regression model.
47+
*/
48+
export const multiLayerPerceptronRegressionModel = () => {
49+
const model = tf.sequential();
50+
model.add(tf.layers.dense(
51+
{inputShape: [data.numFeatures], units: 50, activation: 'sigmoid'}));
52+
model.add(tf.layers.dense({units: 50, activation: 'sigmoid'}));
53+
model.add(tf.layers.dense({units: 1}));
54+
55+
return model;
56+
};
57+
58+
/**
59+
* Fetches training and testing data, compiles `model`, trains the model
60+
* using train data and runs model against test data.
61+
*
62+
* @param {tf.Sequential} model Model to be trained.
63+
*/
64+
export const run = async (model) => {
3065
await ui.updateStatus('Getting training and testing data...');
3166
const trainData = data.getTrainData();
3267
const testData = data.getTestData();
3368

34-
await ui.updateStatus('Building model...');
35-
const model = tf.sequential();
36-
model.add(tf.layers.dense({inputShape: [data.numFeatures], units: 1}));
37-
model.compile({
38-
optimizer: tf.train.sgd(LEARNING_RATE),
39-
loss: 'meanSquaredError'
40-
});
69+
await ui.updateStatus('Compiling model...');
70+
71+
model.compile(
72+
{optimizer: tf.train.sgd(LEARNING_RATE), loss: 'meanSquaredError'});
4173

4274
let trainLoss;
4375
let valLoss;
44-
await ui.updateStatus('Training starting...');
76+
await ui.updateStatus('Starting training process...');
4577
await model.fit(trainData.data, trainData.target, {
4678
batchSize: BATCH_SIZE,
4779
epochs: NUM_EPOCHS,
@@ -69,4 +101,10 @@ data.loadData().then(async () => {
69101
`Final train-set loss: ${trainLoss.toFixed(4)}\n` +
70102
`Final validation-set loss: ${valLoss.toFixed(4)}\n` +
71103
`Test-set loss: ${testLoss.toFixed(4)}`);
72-
});
104+
};
105+
106+
document.addEventListener('DOMContentLoaded', async () => {
107+
await data.loadData();
108+
await ui.updateStatus('Data loaded!');
109+
await ui.setup();
110+
}, false);

boston-housing/ui.js

+22-11
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,34 @@
1717

1818
import renderChart from 'vega-embed';
1919

20+
import {linearRegressionModel, multiLayerPerceptronRegressionModel, run} from './index';
21+
2022
const statusElement = document.getElementById('status');
2123
export const updateStatus = (message) => {
2224
statusElement.value = message;
2325
};
2426

25-
const losses = [];
27+
export const setup = async () => {
28+
const trainSimpleLinearRegression = document.getElementById('simple-mlr');
29+
const trainNeuralNetworkLinearRegression = document.getElementById('nn-mlr');
30+
31+
trainSimpleLinearRegression.addEventListener('click', async (e) => {
32+
const model = linearRegressionModel();
33+
losses = [];
34+
await run(model);
35+
}, false);
36+
37+
trainNeuralNetworkLinearRegression.addEventListener('click', async (e) => {
38+
const model = multiLayerPerceptronRegressionModel();
39+
losses = [];
40+
await run(model);
41+
}, false);
42+
};
43+
44+
let losses = [];
2645
export const plotData = async (epoch, trainLoss, valLoss) => {
27-
losses.push({
28-
'epoch': epoch,
29-
'loss': trainLoss,
30-
'split': 'Train Loss'
31-
});
32-
losses.push({
33-
'epoch': epoch,
34-
'loss': valLoss,
35-
'split': 'Validation Loss'
36-
});
46+
losses.push({'epoch': epoch, 'loss': trainLoss, 'split': 'Train Loss'});
47+
losses.push({'epoch': epoch, 'loss': valLoss, 'split': 'Validation Loss'});
3748

3849
const spec = {
3950
'$schema': 'https://vega.github.io/schema/vega-lite/v2.json',

0 commit comments

Comments
 (0)