diff --git a/PredictStockPricesRNN.py b/PredictStockPricesRNN.py index 215c050..1d23ec5 100644 --- a/PredictStockPricesRNN.py +++ b/PredictStockPricesRNN.py @@ -55,16 +55,24 @@ def train_rnn(X_train, X_test): # function to evaluate RNN def evaluate_rnn(model, X_test): + """Evaluate a trained RNN model using mean squared error.""" + # reshape test data to match training shape + X_test = X_test.reshape((X_test.shape[0], 1, 1)) + # make predictions with RNN predictions = model.predict(X_test) # calculate mean squared error - mse = np.mean((predictions - X_test)**2) + mse = np.mean((predictions - X_test) ** 2) return mse # function to make predictions with RNN def predict_with_rnn(model, X_test): + """Generate predictions from a trained RNN model.""" + # reshape test data to match training shape + X_test = X_test.reshape((X_test.shape[0], 1, 1)) + # make predictions with RNN predictions = model.predict(X_test)