Skip to content

Commit

Permalink
Add regression vs classification comparison sample from blog post
Browse files Browse the repository at this point in the history
  • Loading branch information
sararob committed Feb 22, 2018
1 parent 3c5e06a commit f3656de
Show file tree
Hide file tree
Showing 6 changed files with 6,821 additions and 0 deletions.
344 changes: 344 additions & 0 deletions extras/regression-classification/classification.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,344 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TensorFlow version: 1.5.0\n"
]
}
],
"source": [
"import os\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"# Check that we have correct TensorFlow version installed\n",
"tf_version = tf.__version__\n",
"print(\"TensorFlow version: {}\".format(tf_version))\n",
"assert \"1.5\" <= tf_version, \"TensorFlow r1.5 or later is needed\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"tf.logging.set_verbosity(tf.logging.INFO)\n",
"\n",
"train_file = \"classify-train.csv\"\n",
"test_file = \"classify-test.csv\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"numerical_feature_names = [\n",
" 'PctUnder18',\n",
" 'PctOver65',\n",
" 'PctFemale',\n",
" 'PctWhite',\n",
" 'PctBachelors',\n",
" 'PctDem',\n",
" 'PctGop'\n",
"]\n",
"\n",
"feature_columns = [tf.feature_column.numeric_column(k) for k in numerical_feature_names]\n",
"\n",
"def my_input_fn(file_path, repeat_count=200):\n",
" def decode_csv(line):\n",
" parsed_line = tf.decode_csv(line, [[0.],[0.],[0.],[0.],[0.],[0.],[0.],[0.]])\n",
" label = parsed_line[-1] # Last element is the label\n",
" features = parsed_line[:-1] # Everything but last elements are the features\n",
" d = dict(zip(numerical_feature_names, features)), label\n",
" return d\n",
"\n",
" dataset = (tf.data.TextLineDataset(file_path) # Read text file\n",
" .map(decode_csv)) # Transform each elem by applying decode_csv fn\n",
" dataset = dataset.shuffle(buffer_size=256)\n",
" dataset = dataset.repeat(repeat_count) # Repeats dataset this # times\n",
" dataset = dataset.batch(8) # Batch size to use\n",
" return dataset"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"WARNING:tensorflow:Using temporary folder as model directory: /var/folders/1h/g9jk9_kx67d6g0_gyfnvk1n4008m_k/T/tmpk44nbcf2\n",
"INFO:tensorflow:Using config: {'_model_dir': '/var/folders/1h/g9jk9_kx67d6g0_gyfnvk1n4008m_k/T/tmpk44nbcf2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x120a0afd0>, '_task_type': 'worker', '_task_id': 0, '_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Saving checkpoints for 1 into /var/folders/1h/g9jk9_kx67d6g0_gyfnvk1n4008m_k/T/tmpk44nbcf2/model.ckpt.\n",
"INFO:tensorflow:loss = 5.5451775, step = 1\n",
"INFO:tensorflow:global_step/sec: 570.865\n",
"INFO:tensorflow:loss = 0.29437244, step = 101 (0.176 sec)\n",
"INFO:tensorflow:global_step/sec: 844.545\n",
"INFO:tensorflow:loss = 0.2726082, step = 201 (0.118 sec)\n",
"INFO:tensorflow:global_step/sec: 891.569\n",
"INFO:tensorflow:loss = 3.9231257, step = 301 (0.112 sec)\n",
"INFO:tensorflow:global_step/sec: 728.251\n",
"INFO:tensorflow:loss = 0.51435304, step = 401 (0.137 sec)\n",
"INFO:tensorflow:global_step/sec: 749.616\n",
"INFO:tensorflow:loss = 3.1930523, step = 501 (0.133 sec)\n",
"INFO:tensorflow:global_step/sec: 874.407\n",
"INFO:tensorflow:loss = 1.5969841, step = 601 (0.114 sec)\n",
"INFO:tensorflow:global_step/sec: 819.868\n",
"INFO:tensorflow:loss = 0.35173976, step = 701 (0.122 sec)\n",
"INFO:tensorflow:global_step/sec: 764.753\n",
"INFO:tensorflow:loss = 3.9147635, step = 801 (0.131 sec)\n",
"INFO:tensorflow:global_step/sec: 813.749\n",
"INFO:tensorflow:loss = 2.1076026, step = 901 (0.123 sec)\n",
"INFO:tensorflow:global_step/sec: 860.919\n",
"INFO:tensorflow:loss = 2.8540645, step = 1001 (0.116 sec)\n",
"INFO:tensorflow:global_step/sec: 840.817\n",
"INFO:tensorflow:loss = 0.58483046, step = 1101 (0.119 sec)\n",
"INFO:tensorflow:global_step/sec: 851.179\n",
"INFO:tensorflow:loss = 1.315016, step = 1201 (0.117 sec)\n",
"INFO:tensorflow:global_step/sec: 882.614\n",
"INFO:tensorflow:loss = 1.2451642, step = 1301 (0.113 sec)\n",
"INFO:tensorflow:global_step/sec: 827.021\n",
"INFO:tensorflow:loss = 2.5050027, step = 1401 (0.121 sec)\n",
"INFO:tensorflow:global_step/sec: 755.812\n",
"INFO:tensorflow:loss = 2.521564, step = 1501 (0.132 sec)\n",
"INFO:tensorflow:global_step/sec: 796.628\n",
"INFO:tensorflow:loss = 0.8701288, step = 1601 (0.125 sec)\n",
"INFO:tensorflow:global_step/sec: 849.192\n",
"INFO:tensorflow:loss = 0.68960416, step = 1701 (0.118 sec)\n",
"INFO:tensorflow:global_step/sec: 812.176\n",
"INFO:tensorflow:loss = 1.3478332, step = 1801 (0.123 sec)\n",
"INFO:tensorflow:global_step/sec: 832.119\n",
"INFO:tensorflow:loss = 2.1637692, step = 1901 (0.120 sec)\n",
"INFO:tensorflow:global_step/sec: 815.56\n",
"INFO:tensorflow:loss = 0.41932052, step = 2001 (0.123 sec)\n",
"INFO:tensorflow:global_step/sec: 799.355\n",
"INFO:tensorflow:loss = 2.4380984, step = 2101 (0.125 sec)\n",
"INFO:tensorflow:global_step/sec: 810.78\n",
"INFO:tensorflow:loss = 0.1496887, step = 2201 (0.123 sec)\n",
"INFO:tensorflow:global_step/sec: 852.31\n",
"INFO:tensorflow:loss = 1.1216372, step = 2301 (0.117 sec)\n",
"INFO:tensorflow:global_step/sec: 794.951\n",
"INFO:tensorflow:loss = 0.6908283, step = 2401 (0.126 sec)\n",
"INFO:tensorflow:global_step/sec: 835.584\n",
"INFO:tensorflow:loss = 1.3392761, step = 2501 (0.120 sec)\n",
"INFO:tensorflow:global_step/sec: 786.533\n",
"INFO:tensorflow:loss = 0.7310655, step = 2601 (0.127 sec)\n",
"INFO:tensorflow:global_step/sec: 848.776\n",
"INFO:tensorflow:loss = 0.8060344, step = 2701 (0.118 sec)\n",
"INFO:tensorflow:global_step/sec: 837.765\n",
"INFO:tensorflow:loss = 0.13385934, step = 2801 (0.119 sec)\n",
"INFO:tensorflow:global_step/sec: 782.093\n",
"INFO:tensorflow:loss = 0.22636321, step = 2901 (0.128 sec)\n",
"INFO:tensorflow:global_step/sec: 849.018\n",
"INFO:tensorflow:loss = 0.8238576, step = 3001 (0.118 sec)\n",
"INFO:tensorflow:global_step/sec: 796.554\n",
"INFO:tensorflow:loss = 2.2260704, step = 3101 (0.125 sec)\n",
"INFO:tensorflow:global_step/sec: 736.251\n",
"INFO:tensorflow:loss = 2.3282278, step = 3201 (0.136 sec)\n",
"INFO:tensorflow:global_step/sec: 776.194\n",
"INFO:tensorflow:loss = 0.28495857, step = 3301 (0.129 sec)\n",
"INFO:tensorflow:global_step/sec: 768.958\n",
"INFO:tensorflow:loss = 3.2070298, step = 3401 (0.130 sec)\n",
"INFO:tensorflow:global_step/sec: 832.881\n",
"INFO:tensorflow:loss = 1.5284773, step = 3501 (0.120 sec)\n",
"INFO:tensorflow:global_step/sec: 798.871\n",
"INFO:tensorflow:loss = 0.3559987, step = 3601 (0.125 sec)\n",
"INFO:tensorflow:global_step/sec: 829.696\n",
"INFO:tensorflow:loss = 1.2014079, step = 3701 (0.121 sec)\n",
"INFO:tensorflow:global_step/sec: 804.724\n",
"INFO:tensorflow:loss = 2.766306, step = 3801 (0.124 sec)\n",
"INFO:tensorflow:global_step/sec: 800.264\n",
"INFO:tensorflow:loss = 1.0233307, step = 3901 (0.125 sec)\n",
"INFO:tensorflow:global_step/sec: 828.555\n",
"INFO:tensorflow:loss = 1.0765572, step = 4001 (0.121 sec)\n",
"INFO:tensorflow:global_step/sec: 834.161\n",
"INFO:tensorflow:loss = 0.65881246, step = 4101 (0.120 sec)\n",
"INFO:tensorflow:global_step/sec: 810.123\n",
"INFO:tensorflow:loss = 1.2197807, step = 4201 (0.123 sec)\n",
"INFO:tensorflow:global_step/sec: 826.337\n",
"INFO:tensorflow:loss = 2.2669864, step = 4301 (0.121 sec)\n",
"INFO:tensorflow:global_step/sec: 715.687\n",
"INFO:tensorflow:loss = 1.821552, step = 4401 (0.140 sec)\n",
"INFO:tensorflow:global_step/sec: 769.864\n",
"INFO:tensorflow:loss = 2.4519372, step = 4501 (0.130 sec)\n",
"INFO:tensorflow:global_step/sec: 741.284\n",
"INFO:tensorflow:loss = 1.1228801, step = 4601 (0.135 sec)\n",
"INFO:tensorflow:global_step/sec: 785.533\n",
"INFO:tensorflow:loss = 1.2760054, step = 4701 (0.127 sec)\n",
"INFO:tensorflow:global_step/sec: 814.087\n",
"INFO:tensorflow:loss = 1.1107497, step = 4801 (0.123 sec)\n",
"INFO:tensorflow:global_step/sec: 830.842\n",
"INFO:tensorflow:loss = 0.5630184, step = 4901 (0.120 sec)\n",
"INFO:tensorflow:global_step/sec: 832.863\n",
"INFO:tensorflow:loss = 1.6846699, step = 5001 (0.120 sec)\n",
"INFO:tensorflow:global_step/sec: 836.812\n",
"INFO:tensorflow:loss = 0.897588, step = 5101 (0.119 sec)\n",
"INFO:tensorflow:global_step/sec: 807.005\n",
"INFO:tensorflow:loss = 1.0144025, step = 5201 (0.124 sec)\n",
"INFO:tensorflow:global_step/sec: 928.506\n",
"INFO:tensorflow:loss = 0.19305001, step = 5301 (0.108 sec)\n",
"INFO:tensorflow:global_step/sec: 763.539\n",
"INFO:tensorflow:loss = 0.6202412, step = 5401 (0.131 sec)\n",
"INFO:tensorflow:global_step/sec: 838.088\n",
"INFO:tensorflow:loss = 1.4105525, step = 5501 (0.119 sec)\n",
"INFO:tensorflow:global_step/sec: 860.675\n",
"INFO:tensorflow:loss = 0.6199516, step = 5601 (0.116 sec)\n",
"INFO:tensorflow:global_step/sec: 796.527\n",
"INFO:tensorflow:loss = 0.034956243, step = 5701 (0.126 sec)\n",
"INFO:tensorflow:global_step/sec: 796.019\n",
"INFO:tensorflow:loss = 0.9802186, step = 5801 (0.126 sec)\n",
"INFO:tensorflow:global_step/sec: 804.448\n",
"INFO:tensorflow:loss = 1.5616608, step = 5901 (0.124 sec)\n",
"INFO:tensorflow:global_step/sec: 811.477\n",
"INFO:tensorflow:loss = 0.26409894, step = 6001 (0.123 sec)\n",
"INFO:tensorflow:global_step/sec: 826.153\n",
"INFO:tensorflow:loss = 0.759922, step = 6101 (0.121 sec)\n",
"INFO:tensorflow:global_step/sec: 819.021\n",
"INFO:tensorflow:loss = 0.23053291, step = 6201 (0.122 sec)\n",
"INFO:tensorflow:Saving checkpoints for 6250 into /var/folders/1h/g9jk9_kx67d6g0_gyfnvk1n4008m_k/T/tmpk44nbcf2/model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 0.61512256.\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.estimator.canned.linear.LinearClassifier at 0x120a15eb8>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"classifier = tf.estimator.LinearClassifier(feature_columns=feature_columns)\n",
"\n",
"# Run training for 20 epochs (20 times through our entire dataset)\n",
"# You can experiment with this value for your own dataset\n",
"classifier.train(\n",
" input_fn=lambda: my_input_fn(train_file, 20))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Starting evaluation at 2018-02-22-18:06:33\n",
"INFO:tensorflow:Restoring parameters from /var/folders/1h/g9jk9_kx67d6g0_gyfnvk1n4008m_k/T/tmpk44nbcf2/model.ckpt-6250\n",
"INFO:tensorflow:Finished evaluation at 2018-02-22-18:06:34\n",
"INFO:tensorflow:Saving dict for global step 6250: accuracy = 0.9705882, accuracy_baseline = 0.8888889, auc = 0.99298507, auc_precision_recall = 0.9434167, average_loss = 0.1134358, global_step = 6250, label/mean = 0.11111111, loss = 0.90159357, prediction/mean = 0.14078975\n",
"accuracy: 0.9705882\n",
"accuracy_baseline: 0.8888889\n",
"auc: 0.99298507\n",
"auc_precision_recall: 0.9434167\n",
"average_loss: 0.1134358\n",
"global_step: 6250\n",
"label/mean: 0.11111111\n",
"loss: 0.90159357\n",
"prediction/mean: 0.14078975\n"
]
}
],
"source": [
"results = classifier.evaluate(input_fn=lambda: my_input_fn(test_file, 1))\n",
"\n",
"for key in sorted(results):\n",
" print('%s: %s' % (key, results[key]))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Generate predictions on 3 counties\n",
"prediction_input = {\n",
" 'PctUnder18': [23.9, 25.7, 10.6],\n",
" 'PctOver65': [17.6,24.7,15.8],\n",
" 'PctFemale': [50.0,48.5,53.5],\n",
" 'PctWhite':[0.965, 0.97, 0.75],\n",
" 'PctBachelors':[12.7, 17.0, 49.8],\n",
" 'PctDem': [0.3227832512315271, 0.09475032010243278, 0.6346801346801347],\n",
" 'PctGop': [0.6545566502463054, 0.8911651728553138, 0.3468013468013468]\n",
"}\n",
"\n",
"def test_input_fn():\n",
" dataset = tf.data.Dataset.from_tensors(prediction_input)\n",
" return dataset\n",
"\n",
"# Predict all our prediction_input\n",
"pred_results = classifier.predict(input_fn=test_input_fn)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Restoring parameters from /var/folders/1h/g9jk9_kx67d6g0_gyfnvk1n4008m_k/T/tmpk44nbcf2/model.ckpt-6250\n",
"[0.9905778 0.00942222]\n",
"[9.9971288e-01 2.8714165e-04]\n",
"[0.04889648 0.95110345]\n"
]
}
],
"source": [
"# Actual values for the raw prediction data:\n",
"# 1) 23% Clinton (class of 0 for Trump)\n",
"# 2) 5% Clinton (class of 0)\n",
"# 3) 69% Clinton (class of 1)\n",
"\n",
"# Iterate over predictions on raw data\n",
"for pred in enumerate(pred_results):\n",
" print(pred[1]['probabilities'])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit f3656de

Please sign in to comment.