From dba2be6b22b802a872c6c22b2ddf1f55597ce6b1 Mon Sep 17 00:00:00 2001 From: Joel Lowery Date: Wed, 15 Mar 2017 02:35:55 -0500 Subject: [PATCH] Added sys.path manipulation to make imports from ex2 easier --- ex3/ex3.py | 7 ++++--- ex3/ex3_nn.py | 2 ++ ex3/submit.py | 4 +++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/ex3/ex3.py b/ex3/ex3.py index b303bd9..947f36e 100644 --- a/ex3/ex3.py +++ b/ex3/ex3.py @@ -2,6 +2,8 @@ import numpy as np from matplotlib import use use('TkAgg') +import sys +sys.path.append('../ex2/') from oneVsAll import oneVsAll from predictOneVsAll import predictOneVsAll @@ -13,7 +15,7 @@ # # This file contains code that helps you get started on the # linear exercise. You will need to complete the following functions -# in this exericse: +# in this exercise: # # lrCostFunction.m (logistic regression cost function) # oneVsAll.m @@ -37,7 +39,7 @@ # Load Training Data print('Loading and Visualizing Data ...') -data = scipy.io.loadmat('ex3data1.mat') # training data stored in arrays X, y +data = scipy.io.loadmat('ex3data1.mat') # training data stored in arrays X, y X = data['X'] y = data['y'] m, _ = X.shape @@ -68,7 +70,6 @@ # After ... pred = predictOneVsAll(all_theta, X) - accuracy = np.mean(np.double(pred == np.squeeze(y))) * 100 print('\nTraining Set Accuracy: %f\n' % accuracy) diff --git a/ex3/ex3_nn.py b/ex3/ex3_nn.py index 34ea9a4..5574b1d 100644 --- a/ex3/ex3_nn.py +++ b/ex3/ex3_nn.py @@ -3,6 +3,8 @@ import scipy.io import numpy as np import matplotlib.pyplot as plt +import sys +sys.path.append('../ex2/') from displayData import displayData from predict import predict diff --git a/ex3/submit.py b/ex3/submit.py index 98ff7b2..c3deea0 100644 --- a/ex3/submit.py +++ b/ex3/submit.py @@ -1,4 +1,6 @@ import numpy as np +import sys +sys.path.append('../ex2/') from Submission import Submission from Submission import sprintf @@ -44,7 +46,7 @@ def output(part_id): t2 = np.cos(np.array(range(1, 40, 2)).reshape(5, 4).T) fname = srcs[part_id - 1].rsplit('.', 1)[0] - mod = __import__(fname, fromlist=[fname], level=1) + mod = __import__(fname, fromlist=[fname], level=0) func = getattr(mod, fname) if part_id == 1: