-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
46 lines (37 loc) · 1.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
'''
Author: Hongliang Lu, [email protected]
Date: 2024-06-27 13:43:03
LastEditTime: 2024-06-27 13:58:37
FilePath: /DQN for Stock Trading/main.py
Description: Train and evaluate a DQN stock-trading agent on the closing prices in StockData/600967.SS.csv.
Organization: College of Engineering, Peking University.
'''
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

from dqn_agent import Agent
from model import QNetwork
from StockExchange import StockExchange
STATE_SIZE = 12  # state-space size, shared by the Agent and the StockExchange environment
EPISODE_COUNT = 100  # number of training episodes
data_dir = "StockData"
filename = "600967.SS.csv"
file_dir = os.path.join(data_dir, filename)

# Encodings tried, in order, when reading the CSV. ISO-8859-1 maps every byte
# to a character, so it never raises UnicodeDecodeError and serves as the
# catch-all last resort.
CSV_ENCODINGS = ("utf-8", "gbk", "iso-8859-1")


def _read_csv_any_encoding(path, encodings=CSV_ENCODINGS):
    """Read *path* with pandas, trying each encoding in turn.

    Args:
        path: Path to the CSV file.
        encodings: Sequence of encoding names to try, in order.

    Returns:
        The parsed ``pandas.DataFrame`` from the first encoding that decodes.

    Raises:
        ValueError: If *encodings* is empty.
        UnicodeDecodeError: If the last encoding also fails to decode.
    """
    if not encodings:
        raise ValueError("at least one encoding is required")
    *fallbacks, last = encodings
    for encoding in fallbacks:
        try:
            return pd.read_csv(path, encoding=encoding)
        except UnicodeDecodeError:
            continue  # try the next candidate encoding
    # Final candidate: let any decoding error propagate to the caller.
    return pd.read_csv(path, encoding=last)


# Read the CSV file with pandas.
df = _read_csv_any_encoding(file_dir)

# Extract closing prices. pandas parses placeholders such as "null" as NaN by
# default; drop them so NaN prices are not fed into training.
stockData = list(df['Close'].dropna().values)
if not stockData:
    raise ValueError(f"no closing prices found in {file_dir}")

# Initialise the agent. action_size=3 presumably means hold/buy/sell; confirm
# against dqn_agent / StockExchange.
agent = Agent(state_size=STATE_SIZE, action_size=3)

# Build the trading environment, then train, evaluate and plot.
stock_agent = StockExchange(stockData, agent, STATE_SIZE)
stock_agent.train(episodes=EPISODE_COUNT, filename=filename)
stock_agent.test()
stock_agent.plot_result()