Skip to content

Commit

Permalink
Added sys.path manipulation to make imports from ex2 easier
Browse files Browse the repository at this point in the history
  • Loading branch information
jtlowery committed Mar 15, 2017
1 parent a373b3f commit dba2be6
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
7 changes: 4 additions & 3 deletions ex3/ex3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np
from matplotlib import use
use('TkAgg')
import sys
sys.path.append('../ex2/')

from oneVsAll import oneVsAll
from predictOneVsAll import predictOneVsAll
Expand All @@ -13,7 +15,7 @@
#
# This file contains code that helps you get started on the
# linear exercise. You will need to complete the following functions
# in this exericse:
# in this exercise:
#
# lrCostFunction.m (logistic regression cost function)
# oneVsAll.m
Expand All @@ -37,7 +39,7 @@
# Load Training Data
print('Loading and Visualizing Data ...')

data = scipy.io.loadmat('ex3data1.mat') # training data stored in arrays X, y
data = scipy.io.loadmat('ex3data1.mat') # training data stored in arrays X, y
X = data['X']
y = data['y']
m, _ = X.shape
Expand Down Expand Up @@ -68,7 +70,6 @@
# After ...

pred = predictOneVsAll(all_theta, X)

accuracy = np.mean(np.double(pred == np.squeeze(y))) * 100
print('\nTraining Set Accuracy: %f\n' % accuracy)

2 changes: 2 additions & 0 deletions ex3/ex3_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('../ex2/')

from displayData import displayData
from predict import predict
Expand Down
4 changes: 3 additions & 1 deletion ex3/submit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
import sys
sys.path.append('../ex2/')

from Submission import Submission
from Submission import sprintf
Expand Down Expand Up @@ -44,7 +46,7 @@ def output(part_id):
t2 = np.cos(np.array(range(1, 40, 2)).reshape(5, 4).T)

fname = srcs[part_id - 1].rsplit('.', 1)[0]
mod = __import__(fname, fromlist=[fname], level=1)
mod = __import__(fname, fromlist=[fname], level=0)
func = getattr(mod, fname)

if part_id == 1:
Expand Down

0 comments on commit dba2be6

Please sign in to comment.