Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/online/main_dmc_offpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"td7": TD7Agent,
"sdac": SDACAgent,
"dpmd": DPMDAgent,
"td3_crtl": Ctrl_TD3_Agent,
}

class OffPolicyTrainer():
Expand Down Expand Up @@ -116,7 +117,8 @@ def train(self):
if self.global_frame < cfg.warmup_frames:
train_metrics = {}
else:
batch, indices = self.buffer.sample(batch_size=cfg.batch_size)
batch_size = cfg.batch_size + getattr(cfg, "aug_batch_size", 0)
batch, indices = self.buffer.sample(batch_size=batch_size)
train_metrics = self.agent.train_step(batch, step=self.global_frame)
if self.use_lap_buffer:
new_priorities = train_metrics.pop("priority")
Expand Down
2 changes: 2 additions & 0 deletions flowrl/agent/online/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .sdac import SDACAgent
from .td3 import TD3Agent
from .td7.td7 import TD7Agent
from .crtl.crtl import Ctrl_TD3_Agent

__all__ = [
"BaseAgent",
Expand All @@ -14,4 +15,5 @@
"SDACAgent",
"DPMDAgent",
"PPOAgent",
"Ctrl_TD3_Agent",
]
Loading