21 changes: 18 additions & 3 deletions config.json
@@ -31,22 +31,37 @@
"max_bars_comment": "Максимальная длина эпизода в барах",
"reward_scaling": 1.0,
"reward_scaling_comment": "Множитель награды",
"risk_fraction": 0.01,
"risk_fraction": 0.0034473841356203925,
"risk_fraction_comment": "Доля капитала под риск на сделку",
"max_alloc_per_trade": 0.3,
"max_alloc_per_trade": 0.4206833982135664,
"max_alloc_per_trade_comment": "Макс. доля капитала в одной сделке",
"min_notional": 1.0,
"min_notional_comment": "Минимальная сумма открытия ордера",
"penalize_no_trade_steps": true,
"penalize_no_trade_steps_comment": "Штрафовать за бездействие",
"no_trade_penalty": 100.0,
"no_trade_penalty": 16.322930372713817,
"no_trade_penalty_comment": "Размер штрафа за бездействие",
"consecutive_no_trade_allowed": 10,
"consecutive_no_trade_allowed_comment": "Допустимое число шагов без действий",
"train_timesteps": 500000,
"train_timesteps_comment": "Число шагов обучения",
"learn_timesteps": 10000,
"learn_timesteps_comment": "Шаги обучения на одну пару",
"learning_rate": 1.1736906626201254e-05,
"gamma": 0.9797937313436991,
"batch_size": 512,
"n_steps": 512,
"ent_coef": 0.023636801086752424,
"clip_range": 0.29121737990065244,
"gae_lambda": 0.828550617451298,
"vf_coef": 0.7392096696563757,
"max_grad_norm": 0.7781116322105397,
"target_kl": 0.041564999642783235,
"n_epochs": 10,
"policy_net_arch": "256-128",
"force_new_model": true,
"per_symbol_timesteps": 10000,
"checkpoint_freq": 20000,
"logging_level": "INFO",
"logging_level_comment": "Уровень логирования",
"data_top_n": 100,
60 changes: 51 additions & 9 deletions train_rl.py → train_rl_all_pairs.py
@@ -39,6 +39,14 @@

 from env.hourly_trading_env import HourlyTradingEnv  # your environment

+def _parse_arch(s: str):
+    """'256-128' -> dict(pi=[256,128], vf=[256,128]) for SB3 PPO."""
+    try:
+        parts = [int(x) for x in str(s).strip().split("-") if x]
+        return dict(pi=parts, vf=parts)
+    except Exception:
+        return dict(pi=[256, 128], vf=[256, 128])
+

 # ----------------- CONFIG -----------------
 with open("config.json", "r") as f:
@@ -54,6 +62,24 @@
 for d in (LOG_DIR, TB_DIR, CHECKPOINT_DIR):
     os.makedirs(d, exist_ok=True)

+# ---------- PPO hyperparams (from config.json) ----------
+BEST_HPS = {
+    "learning_rate": float(config.get("learning_rate", 1.1736906626201254e-05)),
+    "gamma": float(config.get("gamma", 0.9797937313436991)),
+    "batch_size": int(config.get("batch_size", 512)),
+    "n_steps": int(config.get("n_steps", 512)),
+    "ent_coef": float(config.get("ent_coef", 0.023636801086752424)),
+    "clip_range": float(config.get("clip_range", 0.29121737990065244)),
+    "gae_lambda": float(config.get("gae_lambda", 0.828550617451298)),
+    "vf_coef": float(config.get("vf_coef", 0.7392096696563757)),
+    "max_grad_norm": float(config.get("max_grad_norm", 0.7781116322105397)),
+    "target_kl": float(config.get("target_kl", 0.041564999642783235)),
+    "n_epochs": int(config.get("n_epochs", 10)),
+}
+NET_ARCH_STR = str(config.get("policy_net_arch", "256-128"))
+POLICY_KWARGS = {"net_arch": _parse_arch(NET_ARCH_STR)}
+FORCE_NEW_MODEL = bool(config.get("force_new_model", True))  # True is safe when the feature set changes
+
 PER_SYMBOL_TIMESTEPS = int(config.get("per_symbol_timesteps", config.get("learn_timesteps", 10_000)))
 SLEEP_BETWEEN = float(config.get("sleep_between_symbols_sec", 2))
 CHECKPOINT_FREQ = int(config.get("checkpoint_freq", 20_000))
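CHECKPOINT_FREQ and CHECKPOINT_DIR are read here, but the code that consumes them falls outside the visible hunks. A minimal sketch of the usual Stable-Baselines3 wiring, assuming the training loop attaches a CheckpointCallback to model.learn (the name_prefix is illustrative, not taken from the diff):

    from stable_baselines3.common.callbacks import CheckpointCallback

    checkpoint_cb = CheckpointCallback(
        save_freq=CHECKPOINT_FREQ,   # save every CHECKPOINT_FREQ environment steps
        save_path=CHECKPOINT_DIR,    # directory created above via os.makedirs
        name_prefix="ppo_trading",   # illustrative prefix
    )
    # model.learn(total_timesteps=PER_SYMBOL_TIMESTEPS, callback=checkpoint_cb)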
@@ -395,17 +421,33 @@ def main():
     # first pair: used to initialize the model/logger
     first_df = load_one_symbol_csv(CSV_PATH, symbols[0])
     vec_env = make_env_from_df(first_df, training=True)
-    if os.path.isfile(VECNORM_PATH):
-        vec_env = VecNormalize.load(VECNORM_PATH, vec_env)
-        vec_env.training = True
+    if os.path.isfile(VECNORM_PATH) and not FORCE_NEW_MODEL:
+        try:
+            vec_env = VecNormalize.load(VECNORM_PATH, vec_env)
+            vec_env.training = True
+        except Exception as e:
+            print(f"Failed to load VecNormalize ({e}). Starting from scratch.")

     # create or load the model, without console spam
-    if os.path.isfile(MODEL_PATH):
-        print("Loading saved model…")
-        model = PPO.load(MODEL_PATH, env=vec_env, verbose=0, tensorboard_log=TB_DIR)
-    else:
-        print("Creating a new model…")
-        model = PPO(policy="MlpPolicy", env=vec_env, verbose=0, tensorboard_log=TB_DIR)
+    model = None
+    if os.path.isfile(MODEL_PATH) and not FORCE_NEW_MODEL:
+        try:
+            print("Loading saved model…")
+            model = PPO.load(MODEL_PATH, env=vec_env, verbose=0, tensorboard_log=TB_DIR)
+        except Exception as e:
+            print(f"Failed to load the model ({e}). Creating a new one…")
+    if model is None:
+        print("Creating a new model (with the best hyperparameters)…")
+        model = PPO(
+            policy="MlpPolicy",
+            env=vec_env,
+            verbose=0,
+            tensorboard_log=TB_DIR,
+            policy_kwargs=POLICY_KWARGS,
+            **BEST_HPS,
+        )
+    print(f"[PPO] policy_net_arch={NET_ARCH_STR} | force_new={FORCE_NEW_MODEL}")
+    print(f"[PPO] HParams: {BEST_HPS}")

     # tidy logger: CSV + TensorBoard
     run_name = time.strftime("%Y%m%d-%H%M%S")
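The hunk is truncated right after run_name. A minimal sketch of how a combined CSV + TensorBoard logger is typically attached in Stable-Baselines3 (configure and set_logger are standard SB3 calls; the exact log path built from run_name is an assumption):

    from stable_baselines3.common.logger import configure

    # Assumed continuation: route training metrics to LOG_DIR/<run_name>
    # as both CSV rows and TensorBoard event files.
    new_logger = configure(os.path.join(LOG_DIR, run_name), ["csv", "tensorboard"])
    model.set_logger(new_logger)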