01_linear_regression_sol.py

"""
Simple linear regression example in TensorFlow
This program tries to predict the number of thefts from
the number of fire in the city of Chicago
"""
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import csv
DATA_FILE = 'data/fire_theft.csv'
# Step 1: read data
with open(DATA_FILE, 'r') as f:
data = []
reader = csv.reader(f, delimiter=',')
for i, row in enumerate(reader):
if i == 0:
continue
data.append(row)
n_samples = len(data)
data = np.asarray(data, dtype='float32')
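# Each row of `data` is assumed to hold two numeric values: the number of fires
# and the corresponding number of thefts (the header row above was skipped).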

# Step 2: create placeholders for the input X (number of fires) and the label Y (number of thefts)
X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32, name='Y')

# Step 3: create weight and bias, both initialized to 0
w = tf.Variable(0.0, name='weights')
b = tf.Variable(0.0, name='bias')

# Step 4: build the model that predicts Y
Y_predicted = X * w + b

# Step 5: use the squared error as the loss function
loss = tf.reduce_mean(tf.square(Y - Y_predicted, name='loss'))

# Step 6: use gradient descent with a learning rate of 0.001 to minimize the loss
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
train = optimizer.minimize(loss)
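
# `optimizer.minimize(loss)` builds an op that, on each run, applies one plain
# gradient-descent update to every trainable variable, here:
#   w <- w - 0.001 * d(loss)/dw
#   b <- b - 0.001 * d(loss)/db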

with tf.Session() as sess:
    # Step 7: initialize the necessary variables, in this case w and b
    sess.run(tf.global_variables_initializer())

    # Step 8: train the model
    for i in range(100):  # run 100 passes (epochs) over the data
        total_loss = 0
        for x, y in data:
            # run the train op and fetch the value of the loss for this sample
            _, l = sess.run([train, loss], feed_dict={X: x, Y: y})
            total_loss += l
        print('Epoch {0}: {1}'.format(i, total_loss / n_samples))

    # Step 9: read out the trained values of w and b
    w_value, b_value = sess.run([w, b])

# plot the results
x_data, y_data = data.T[0], data.T[1]
plt.plot(x_data, y_data, 'bo', label='Real data')
plt.plot(x_data, x_data * w_value + b_value, 'r', label='Predicted data')
plt.legend()
plt.show()
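
# A minimal sanity check, assuming `data`, `w_value`, and `b_value` from above are
# still in scope: compare the SGD fit against NumPy's closed-form least-squares line.
slope, intercept = np.polyfit(data.T[0], data.T[1], 1)
print('SGD fit:         w = {0:.4f}, b = {1:.4f}'.format(w_value, b_value))
print('Closed-form fit: w = {0:.4f}, b = {1:.4f}'.format(slope, intercept))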