-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.cpp
More file actions
69 lines (55 loc) · 2.56 KB
/
main.cpp
File metadata and controls
69 lines (55 loc) · 2.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <iostream>
#include <iomanip>
#include "RL_Lib.h"
using namespace RLLib;
/// Demo driver for the RL library: trains a tabular Q-learning agent on a
/// small GridWorld with an epsilon-greedy policy and prints the cumulative
/// reward every few episodes.
/// @return 0 on success.
int main() {
    // Hyperparameters gathered in one place so they are easy to tune.
    constexpr double kAlpha = 0.1;       // learning rate
    constexpr double kGamma = 0.99;      // discount factor
    constexpr double kEpsilon = 0.1;     // epsilon-greedy exploration rate
    constexpr int kNumEpisodes = 100;    // training episodes
    constexpr int kMaxSteps = 50;        // per-episode step cap (bounds a non-terminating episode)
    constexpr int kReportInterval = 10;  // print progress every N episodes

    // Note: '\n' instead of std::endl throughout — identical output bytes,
    // but avoids a stream flush on every line.
    std::cout << "========================================" << '\n';
    std::cout << "RL Library v" << RL_LIB_VERSION << '\n';
    std::cout << "========================================" << '\n' << '\n';

    // 1. Initialize Environment
    std::cout << "[1] Initializing GridWorld Environment..." << '\n';
    // Constructor args (5, 5, 4, 4): presumably width, height, goal x, goal y
    // — the goal accessor below echoes (4, 4); confirm against GridWorld's ctor.
    GridWorld env(5, 5, 4, 4);
    std::cout << " Grid Size: " << env.getWidth() << "x" << env.getHeight() << '\n';
    std::cout << " Total States: " << env.getNumStates() << '\n';
    std::cout << " Goal Position: (" << env.getGoalPosition().first << ", "
              << env.getGoalPosition().second << ")" << '\n' << '\n';

    // 2. Initialize Agent (Q-Learner)
    std::cout << "[2] Initializing Q-Learning Agent..." << '\n';
    QLearner agent(kAlpha, kGamma, kEpsilon);
    std::cout << " Learning Rate (alpha): " << agent.alpha << '\n';
    std::cout << " Discount Factor (gamma): " << agent.gamma << '\n';
    std::cout << " Exploration Rate (epsilon): " << agent.epsilon << '\n' << '\n';

    // 3. Training Loop
    std::cout << "[3] Starting Training..." << '\n';
    for (int episode = 0; episode < kNumEpisodes; ++episode) {
        env.reset();
        int state = env.getState();
        double episode_reward = 0.0;

        for (int step = 0; step < kMaxSteps; ++step) {
            // Choose action using epsilon-greedy strategy.
            const int action = agent.chooseAction(state, GridWorld::NUM_ACTIONS);

            // Execute action, observe reward and successor state.
            const double reward = env.step(action);
            const int next_state = env.getState();
            episode_reward += reward;

            // Q-value update from the (s, a, r, s') transition.
            agent.update(state, action, reward, next_state, GridWorld::NUM_ACTIONS);
            state = next_state;

            // Stop the episode once a terminal state is reached.
            if (env.isTerminal()) {
                break;
            }
        }

        // Periodic progress report.
        if ((episode + 1) % kReportInterval == 0) {
            std::cout << " Episode " << std::setw(3) << (episode + 1)
                      << " | Reward: " << std::fixed << std::setprecision(2)
                      << episode_reward << '\n';
        }
    }

    std::cout << "\n[4] Training Complete!" << '\n';
    std::cout << "========================================" << '\n';
    return 0;
}