
Commit 02d6f65

Add debug mode and logger

Add a --debug flag to run in debug mode (useful for gdb) and use a logger instead of print.

1 parent e7b2794 commit 02d6f65

File tree

5 files changed: +101 -17 lines


main.py

Lines changed: 21 additions & 10 deletions
@@ -3,6 +3,7 @@
 import argparse
 import os
 import sys
+import math

 import torch
 import torch.optim as optim
@@ -13,8 +14,11 @@
 from model import ActorCritic
 from train import train
 from test import test
+from utils import logger
 import my_optim

+logger = logger.getLogger('main')
+
 # Based on
 # https://github.com/pytorch/examples/tree/master/mnist_hogwild
 # Training settings
@@ -37,13 +41,16 @@
                     help='environment to train on (default: PongDeterministic-v3)')
 parser.add_argument('--no-shared', default=False, metavar='O',
                     help='use an optimizer without shared momentum.')
+parser.add_argument('--max-iters', type=int, default=math.inf,
+                    help='maximum iterations per process.')

+parser.add_argument('--debug', action='store_true', default=False,
+                    help='run in a way that is easier to debug')

 if __name__ == '__main__':
     args = parser.parse_args()

     torch.manual_seed(args.seed)
-
     env = create_atari_env(args.env_name)
     shared_model = ActorCritic(
         env.observation_space.shape[0], env.action_space)
@@ -55,15 +62,19 @@
     optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
     optimizer.share_memory()

-    processes = []
-
-    p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
-    p.start()
-    processes.append(p)
+
+    if not args.debug:
+        processes = []

-    for rank in range(0, args.num_processes):
-        p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
+        p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
         p.start()
         processes.append(p)
-    for p in processes:
-        p.join()
+        for rank in range(0, args.num_processes):
+            p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
+            p.start()
+            processes.append(p)
+        for p in processes:
+            p.join()
+    else:  # debug is enabled
+        # run only one process in main, easier to debug
+        train(0, args, shared_model, optimizer)
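
Note on the debug path above: gdb and pdb by default follow only the process they launch, so code running inside mp.Process children is out of their reach; with --debug the training loop runs in the parent process, where the debugger can see it. A hypothetical invocation (flag values are illustrative):

    gdb --args python main.py --debug --max-iters 100

For Python-level breakpoints, python -m pdb main.py --debug works the same way.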

model.py

Lines changed: 3 additions & 1 deletion
@@ -44,8 +44,11 @@ def __init__(self, num_inputs, action_space):
         self.lstm = nn.LSTMCell(32 * 3 * 3, 256)

         num_outputs = action_space.n
+
         self.critic_linear = nn.Linear(256, 1)
         self.actor_linear = nn.Linear(256, num_outputs)
+        #self.critic_linear = nn.Linear(288, 1)
+        #self.actor_linear = nn.Linear(288, num_outputs)

         self.apply(weights_init)
         self.actor_linear.weight.data = normalized_columns_initializer(
@@ -66,7 +69,6 @@ def forward(self, inputs):
         x = F.elu(self.conv2(x))
         x = F.elu(self.conv3(x))
         x = F.elu(self.conv4(x))
-
         x = x.view(-1, 32 * 3 * 3)
         hx, cx = self.lstm(x, (hx, cx))
         x = hx

test.py

Lines changed: 3 additions & 1 deletion
@@ -11,7 +11,9 @@
 from torchvision import datasets, transforms
 import time
 from collections import deque
+from utils import logger

+logger = logger.getLogger('test')

 def test(rank, args, shared_model):
     torch.manual_seed(args.seed + rank)
@@ -59,7 +61,7 @@ def test(rank, args, shared_model):
             done = True

         if done:
-            print("Time {}, episode reward {}, episode length {}".format(
+            logger.info("Time {}, episode reward {}, episode length {}".format(
                 time.strftime("%Hh %Mm %Ss",
                               time.gmtime(time.time() - start_time)),
                 reward_sum, episode_length))

train.py

Lines changed: 25 additions & 5 deletions
@@ -1,6 +1,8 @@
 import math
 import os
 import sys
+import resource
+import gc

 import torch
 import torch.nn.functional as F
@@ -9,15 +11,16 @@
 from model import ActorCritic
 from torch.autograd import Variable
 from torchvision import datasets, transforms
+from utils import logger

+logger = logger.getLogger('train')

 def ensure_shared_grads(model, shared_model):
     for param, shared_param in zip(model.parameters(), shared_model.parameters()):
         if shared_param.grad is not None:
             return
         shared_param._grad = param.grad

-
 def train(rank, args, shared_model, optimizer=None):
     torch.manual_seed(args.seed + rank)

@@ -36,8 +39,29 @@ def train(rank, args, shared_model, optimizer=None):
     done = True

     episode_length = 0
+
+    iteration = 0
+
     while True:
+
+        values = []
+        log_probs = []
+        rewards = []
+        entropies = []
+
+        if iteration == args.max_iters:
+            logger.info('Max iteration {} reached.'.format(args.max_iters))
+            break
+
+        if iteration % 200 == 0 and rank == 0:
+            mem_used = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
+            mem_used_mb = mem_used / 1024
+            logger.info('Memory usage of one proc: {} MB'.format(mem_used_mb))
+
+
+        iteration += 1
         episode_length += 1
+
         # Sync with the shared model
         model.load_state_dict(shared_model.state_dict())
         if done:
@@ -47,10 +71,6 @@ def train(rank, args, shared_model, optimizer=None):
             cx = Variable(cx.data)
             hx = Variable(hx.data)

-        values = []
-        log_probs = []
-        rewards = []
-        entropies = []

         for step in range(args.num_steps):
             value, logit, (hx, cx) = model(
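
A portability note on the memory check above: resource.getrusage reports ru_maxrss in kilobytes on Linux but in bytes on macOS, so the division by 1024 yields megabytes only on Linux. A minimal platform-aware sketch (not part of the commit; the helper name is hypothetical):

    import resource
    import sys

    def max_rss_mb():
        # peak resident set size of the calling process
        rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        if sys.platform == 'darwin':
            return rss / (1024 * 1024)  # macOS reports bytes
        return rss / 1024               # Linux reports kilobytes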

utils/logger.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+import os
+import logging
+import logging.config
+
+
+LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
+LOGGING = {
+    'version': 1,
+    'disable_existing_loggers': True,
+    'formatters': {
+        'verbose': {
+            'format': "[%(asctime)s] %(levelname)s " \
+                      "[%(threadName)s:%(lineno)s] %(message)s",
+            'datefmt': "%Y-%m-%d %H:%M:%S"
+        },
+        'simple': {
+            'format': '%(levelname)s %(message)s'
+        },
+    },
+    'handlers': {
+        'console': {
+            'level': LOG_LEVEL,
+            'class': 'logging.StreamHandler',
+            'formatter': 'verbose'
+        },
+        'file': {
+            'level': LOG_LEVEL,
+            'class': 'logging.handlers.RotatingFileHandler',
+            'formatter': 'verbose',
+            'filename': 'rl.log',
+            'maxBytes': 10*10**6,
+            'backupCount': 3
+        }
+    },
+    'loggers': {
+        '': {
+            'handlers': ['console', 'file'],
+            'level': LOG_LEVEL,
+        },
+    }
+}
+
+
+logging.config.dictConfig(LOGGING)
+
+def getLogger(name):
+
+    return logging.getLogger(name)
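
For reference, a sketch of how the other modules consume this file; the logger name, line number, and timestamp below are illustrative. Since disable_existing_loggers is True and dictConfig runs at import time, callers should import utils.logger before creating loggers, which main.py, train.py, and test.py all do:

    from utils import logger

    log = logger.getLogger('example')
    log.info('hello')
    # both the console and rl.log (rotated at 10 MB, 3 backups) receive:
    # [2017-04-01 12:00:00] INFO [MainThread:4] hello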
