Skip to content

Commit 66d3aaa

Browse files
authored
Forks iris example into with and without tfjs-data versions (#203)
* Splits iris example into two versions, one making use of tfjs-data `model.fitDataset` * tfjs-data version could be simpler, pending a bug resolution * Adds comment about simplification depending on issue #1071 data disposal issue
1 parent 3c6557b commit 66d3aaa

29 files changed

+15085
-83
lines changed

iris-fitDataset/.babelrc

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"presets": [
3+
[
4+
"env",
5+
{
6+
"esmodules": false,
7+
"targets": {
8+
"browsers": [
9+
"> 3%"
10+
]
11+
}
12+
}
13+
]
14+
],
15+
"plugins": [
16+
"transform-runtime"
17+
]
18+
}

iris-fitDataset/README.md

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# TensorFlow.js Example: Iris Classification using tfjs-data APIs
2+
3+
This example is identical to the
4+
[iris](https://github.com/tensorflow/tfjs-examples/tree/master/iris) example,
5+
except that it uses the [tfjs-data](https://github.com/tensorflow/tfjs-data)
6+
APIs.
7+
8+
Please see the
9+
[iris](https://github.com/tensorflow/tfjs-examples/tree/master/iris) example for
10+
details and how to run.

iris-fitDataset/build-resources.sh

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env bash
2+
3+
# Copyright 2018 Google LLC. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
# =============================================================================
17+
18+
# Builds resources for the Iris demo.
19+
# Note this is not necessary to run the demo, because we already provide hosted
20+
# pre-built resources.
21+
# Usage example: do this from the 'iris' directory:
22+
# ./build-resources.sh
23+
24+
set -e
25+
26+
DEMO_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
27+
28+
TRAIN_EPOCHS=100
29+
while true; do
30+
if [[ "$1" == "--epochs" ]]; then
31+
TRAIN_EPOCHS=$2
32+
shift 2
33+
elif [[ -z "$1" ]]; then
34+
break
35+
else
36+
echo "ERROR: Unrecognized argument: $1"
37+
exit 1
38+
fi
39+
done
40+
41+
RESOURCES_ROOT="${DEMO_DIR}/dist/resources"
42+
rm -rf "${RESOURCES_ROOT}"
43+
mkdir -p "${RESOURCES_ROOT}"
44+
45+
# Run Python script to generate the pretrained model and weights files.
46+
# Make sure you install the tensorflowjs pip package first.
47+
48+
python "${DEMO_DIR}/python/iris.py" \
49+
--epochs "${TRAIN_EPOCHS}" \
50+
--artifacts_dir "${RESOURCES_ROOT}"
51+
52+
cd ${DEMO_DIR}
53+
yarn
54+
yarn build
55+
56+
echo
57+
echo "-----------------------------------------------------------"
58+
echo "Resources written to ${RESOURCES_ROOT}."
59+
echo "You can now run the demo with 'yarn watch'."
60+
echo "-----------------------------------------------------------"
61+
echo

iris-fitDataset/data.js

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs';
19+
20+
export const IRIS_CLASSES =
21+
['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'];
22+
export const IRIS_NUM_CLASSES = IRIS_CLASSES.length;
23+
export const IRIS_NUM_FEATURES = 4;
24+
25+
// Iris flowers data. Source:
26+
// https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
27+
export const IRIS_RAW_DATA = [
28+
[5.1, 3.5, 1.4, 0.2, 0], [4.9, 3.0, 1.4, 0.2, 0], [4.7, 3.2, 1.3, 0.2, 0],
29+
[4.6, 3.1, 1.5, 0.2, 0], [5.0, 3.6, 1.4, 0.2, 0], [5.4, 3.9, 1.7, 0.4, 0],
30+
[4.6, 3.4, 1.4, 0.3, 0], [5.0, 3.4, 1.5, 0.2, 0], [4.4, 2.9, 1.4, 0.2, 0],
31+
[4.9, 3.1, 1.5, 0.1, 0], [5.4, 3.7, 1.5, 0.2, 0], [4.8, 3.4, 1.6, 0.2, 0],
32+
[4.8, 3.0, 1.4, 0.1, 0], [4.3, 3.0, 1.1, 0.1, 0], [5.8, 4.0, 1.2, 0.2, 0],
33+
[5.7, 4.4, 1.5, 0.4, 0], [5.4, 3.9, 1.3, 0.4, 0], [5.1, 3.5, 1.4, 0.3, 0],
34+
[5.7, 3.8, 1.7, 0.3, 0], [5.1, 3.8, 1.5, 0.3, 0], [5.4, 3.4, 1.7, 0.2, 0],
35+
[5.1, 3.7, 1.5, 0.4, 0], [4.6, 3.6, 1.0, 0.2, 0], [5.1, 3.3, 1.7, 0.5, 0],
36+
[4.8, 3.4, 1.9, 0.2, 0], [5.0, 3.0, 1.6, 0.2, 0], [5.0, 3.4, 1.6, 0.4, 0],
37+
[5.2, 3.5, 1.5, 0.2, 0], [5.2, 3.4, 1.4, 0.2, 0], [4.7, 3.2, 1.6, 0.2, 0],
38+
[4.8, 3.1, 1.6, 0.2, 0], [5.4, 3.4, 1.5, 0.4, 0], [5.2, 4.1, 1.5, 0.1, 0],
39+
[5.5, 4.2, 1.4, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [5.0, 3.2, 1.2, 0.2, 0],
40+
[5.5, 3.5, 1.3, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [4.4, 3.0, 1.3, 0.2, 0],
41+
[5.1, 3.4, 1.5, 0.2, 0], [5.0, 3.5, 1.3, 0.3, 0], [4.5, 2.3, 1.3, 0.3, 0],
42+
[4.4, 3.2, 1.3, 0.2, 0], [5.0, 3.5, 1.6, 0.6, 0], [5.1, 3.8, 1.9, 0.4, 0],
43+
[4.8, 3.0, 1.4, 0.3, 0], [5.1, 3.8, 1.6, 0.2, 0], [4.6, 3.2, 1.4, 0.2, 0],
44+
[5.3, 3.7, 1.5, 0.2, 0], [5.0, 3.3, 1.4, 0.2, 0], [7.0, 3.2, 4.7, 1.4, 1],
45+
[6.4, 3.2, 4.5, 1.5, 1], [6.9, 3.1, 4.9, 1.5, 1], [5.5, 2.3, 4.0, 1.3, 1],
46+
[6.5, 2.8, 4.6, 1.5, 1], [5.7, 2.8, 4.5, 1.3, 1], [6.3, 3.3, 4.7, 1.6, 1],
47+
[4.9, 2.4, 3.3, 1.0, 1], [6.6, 2.9, 4.6, 1.3, 1], [5.2, 2.7, 3.9, 1.4, 1],
48+
[5.0, 2.0, 3.5, 1.0, 1], [5.9, 3.0, 4.2, 1.5, 1], [6.0, 2.2, 4.0, 1.0, 1],
49+
[6.1, 2.9, 4.7, 1.4, 1], [5.6, 2.9, 3.6, 1.3, 1], [6.7, 3.1, 4.4, 1.4, 1],
50+
[5.6, 3.0, 4.5, 1.5, 1], [5.8, 2.7, 4.1, 1.0, 1], [6.2, 2.2, 4.5, 1.5, 1],
51+
[5.6, 2.5, 3.9, 1.1, 1], [5.9, 3.2, 4.8, 1.8, 1], [6.1, 2.8, 4.0, 1.3, 1],
52+
[6.3, 2.5, 4.9, 1.5, 1], [6.1, 2.8, 4.7, 1.2, 1], [6.4, 2.9, 4.3, 1.3, 1],
53+
[6.6, 3.0, 4.4, 1.4, 1], [6.8, 2.8, 4.8, 1.4, 1], [6.7, 3.0, 5.0, 1.7, 1],
54+
[6.0, 2.9, 4.5, 1.5, 1], [5.7, 2.6, 3.5, 1.0, 1], [5.5, 2.4, 3.8, 1.1, 1],
55+
[5.5, 2.4, 3.7, 1.0, 1], [5.8, 2.7, 3.9, 1.2, 1], [6.0, 2.7, 5.1, 1.6, 1],
56+
[5.4, 3.0, 4.5, 1.5, 1], [6.0, 3.4, 4.5, 1.6, 1], [6.7, 3.1, 4.7, 1.5, 1],
57+
[6.3, 2.3, 4.4, 1.3, 1], [5.6, 3.0, 4.1, 1.3, 1], [5.5, 2.5, 4.0, 1.3, 1],
58+
[5.5, 2.6, 4.4, 1.2, 1], [6.1, 3.0, 4.6, 1.4, 1], [5.8, 2.6, 4.0, 1.2, 1],
59+
[5.0, 2.3, 3.3, 1.0, 1], [5.6, 2.7, 4.2, 1.3, 1], [5.7, 3.0, 4.2, 1.2, 1],
60+
[5.7, 2.9, 4.2, 1.3, 1], [6.2, 2.9, 4.3, 1.3, 1], [5.1, 2.5, 3.0, 1.1, 1],
61+
[5.7, 2.8, 4.1, 1.3, 1], [6.3, 3.3, 6.0, 2.5, 2], [5.8, 2.7, 5.1, 1.9, 2],
62+
[7.1, 3.0, 5.9, 2.1, 2], [6.3, 2.9, 5.6, 1.8, 2], [6.5, 3.0, 5.8, 2.2, 2],
63+
[7.6, 3.0, 6.6, 2.1, 2], [4.9, 2.5, 4.5, 1.7, 2], [7.3, 2.9, 6.3, 1.8, 2],
64+
[6.7, 2.5, 5.8, 1.8, 2], [7.2, 3.6, 6.1, 2.5, 2], [6.5, 3.2, 5.1, 2.0, 2],
65+
[6.4, 2.7, 5.3, 1.9, 2], [6.8, 3.0, 5.5, 2.1, 2], [5.7, 2.5, 5.0, 2.0, 2],
66+
[5.8, 2.8, 5.1, 2.4, 2], [6.4, 3.2, 5.3, 2.3, 2], [6.5, 3.0, 5.5, 1.8, 2],
67+
[7.7, 3.8, 6.7, 2.2, 2], [7.7, 2.6, 6.9, 2.3, 2], [6.0, 2.2, 5.0, 1.5, 2],
68+
[6.9, 3.2, 5.7, 2.3, 2], [5.6, 2.8, 4.9, 2.0, 2], [7.7, 2.8, 6.7, 2.0, 2],
69+
[6.3, 2.7, 4.9, 1.8, 2], [6.7, 3.3, 5.7, 2.1, 2], [7.2, 3.2, 6.0, 1.8, 2],
70+
[6.2, 2.8, 4.8, 1.8, 2], [6.1, 3.0, 4.9, 1.8, 2], [6.4, 2.8, 5.6, 2.1, 2],
71+
[7.2, 3.0, 5.8, 1.6, 2], [7.4, 2.8, 6.1, 1.9, 2], [7.9, 3.8, 6.4, 2.0, 2],
72+
[6.4, 2.8, 5.6, 2.2, 2], [6.3, 2.8, 5.1, 1.5, 2], [6.1, 2.6, 5.6, 1.4, 2],
73+
[7.7, 3.0, 6.1, 2.3, 2], [6.3, 3.4, 5.6, 2.4, 2], [6.4, 3.1, 5.5, 1.8, 2],
74+
[6.0, 3.0, 4.8, 1.8, 2], [6.9, 3.1, 5.4, 2.1, 2], [6.7, 3.1, 5.6, 2.4, 2],
75+
[6.9, 3.1, 5.1, 2.3, 2], [5.8, 2.7, 5.1, 1.9, 2], [6.8, 3.2, 5.9, 2.3, 2],
76+
[6.7, 3.3, 5.7, 2.5, 2], [6.7, 3.0, 5.2, 2.3, 2], [6.3, 2.5, 5.0, 1.9, 2],
77+
[6.5, 3.0, 5.2, 2.0, 2], [6.2, 3.4, 5.4, 2.3, 2], [5.9, 3.0, 5.1, 1.8, 2],
78+
];
79+
80+
/**
81+
* Converts an integer into its one-hot representation and returns
82+
* the data as a JS Array.
83+
*/
84+
export function flatOneHot(idx) {
85+
// TODO(bileschi): Remove 'Array.from' from here once tf.data supports typed
86+
// arrays https://github.com/tensorflow/tfjs/issues/1041
87+
// TODO(bileschi): Remove '.dataSync()' from here once tf.data supports
88+
// datasets built from tensors.
89+
// https://github.com/tensorflow/tfjs/issues/1046
90+
return Array.from(tf.oneHot([idx], 3).dataSync());
91+
}
92+
93+
/**
94+
* Obtains Iris data, split into training and test sets and with the label
95+
* converted into one-hot format.
96+
*
97+
* @param testSplit Fraction of the data at the end to split as test data: a
98+
* number between 0 and 1.
99+
*
100+
* @param returns A list of two datasets, [trainingData, testingData].
101+
* The datasets represent a shuffled partition of the raw IRIS data.
102+
* Elements of the yielded data will consist of [Features, Labels].
103+
* - Features as a rank-1 `Tensor` of length-4 of numbers.
104+
* - Labels as a rank-1 `Tensor` in one-hot format.
105+
*/
106+
export async function getIrisData(testSplit) {
107+
// TODO(bileschi): Update shuffle etc. to use the tf.data API calls once
108+
// it is possible to cache the results for performance and train-test split
109+
// stability across epochs. Once caching is available, perform batching first
110+
// and then map the preprocessing functions across the batches.
111+
// https://github.com/tensorflow/tfjs/issues/1025
112+
113+
// Shuffle a copy of the raw data.
114+
const shuffled = IRIS_RAW_DATA.slice();
115+
tf.util.shuffle(shuffled);
116+
// Split the data into training and testing portions.
117+
const numTestExamples = Math.round(IRIS_RAW_DATA.length * testSplit);
118+
const numTrainExamples = IRIS_RAW_DATA.length - numTestExamples;
119+
const train = shuffled.slice(0, numTrainExamples);
120+
const test = shuffled.slice(numTrainExamples);
121+
// Split the data into into X & y and apply feature mapping transformations
122+
const trainX = tf.data.array(train.map(r => r.slice(0, 4)));
123+
const testX = tf.data.array(test.map(r => r.slice(0, 4)));
124+
// TODO(we should be able to just directly use tensors built from oneHot here
125+
// instead of converting to tensor and back using datasync & Array.from.
126+
// This causes an internal disposal error however.
127+
// See https://github.com/tensorflow/tfjs/issues/1071
128+
//
129+
// const trainY = tf.data.array(train.map(r => tf.oneHot([r[4]], 3)));
130+
// const testY = tf.data.array(test.map(r => tf.oneHot([r[4]], 3)));
131+
const trainY = tf.data.array(train.map(r => flatOneHot(r[4])));
132+
const testY = tf.data.array(test.map(r => flatOneHot(r[4])));
133+
// Recombine the X and y portions of the data.
134+
const trainDataset = tf.data.zip([trainX, trainY]);
135+
const testDataset = tf.data.zip([testX, testY])
136+
return [trainDataset, testDataset];
137+
}

0 commit comments

Comments
 (0)