20
20
21
21
import metaStrategy as mt
22
22
23
- from stable_baselines3 import PPO
23
+ from stable_baselines3 import PPO , DQN
24
24
25
25
import numpy as np
26
26
from enum import Enum
27
27
28
28
import pandas as pd
29
+ import pandas_ta as ta
29
30
30
31
from custom_indicators import BollingerBandsBandwitch
31
32
@@ -40,10 +41,14 @@ class AiStableBaselinesModel(mt.MetaStrategy):
40
41
41
42
params = (
42
43
('model' , "" ), # Model name
43
- ('tradeSize' , 5000 ),
44
+ ('tradeSize' , 10.0 ),
45
+ ('use_ATR_SL' , True ),
44
46
('atrperiod' , 14 ), # ATR Period (standard)
45
- ('atrdist_SL' , 3 ), # ATR distance for stop price
46
- ('atrdist_TP' , 5 ), # ATR distance for take profit price
47
+ ('atrdist_SL' , 3.0 ), # ATR distance for stop price
48
+ ('atrdist_TP' , 5.0 ), # ATR distance for take profit price
49
+ ('use_Fixed_SL' , False ),
50
+ ('fixed_SL' , 50.0 ), # Fixed distance Stop Loss
51
+ ('fixed_TP' , 100.0 ), # Fixed distance Take Profit
47
52
)
48
53
49
54
def notify_order (self , order ):
@@ -60,7 +65,7 @@ def __init__(self, *argv):
60
65
super ().__init__ (argv [0 ])
61
66
62
67
# Ichi indicator
63
- self .ichimoku = bt .ind .Ichimoku ()
68
+ self .ichimoku = bt .ind .Ichimoku (self . data )
64
69
65
70
# To set the stop price
66
71
self .atr = bt .indicators .ATR (self .data , period = self .p .atrperiod )
@@ -86,71 +91,78 @@ def __init__(self, *argv):
86
91
87
92
self .macd = bt .ind .MACDHisto (self .data )
88
93
94
+ self .normalization_bounds_df_min = self .loadNormalizationBoundsCsv ( "C:/perso/AI/AI_Framework/prepared_data/normalization_bounds_min.csv" ).transpose ()
95
+ self .normalization_bounds_df_max = self .loadNormalizationBoundsCsv ( "C:/perso/AI/AI_Framework/prepared_data/normalization_bounds_max.csv" ).transpose ()
96
+
89
97
pass
90
98
91
99
def start (self ):
92
100
self .order = None # sentinel to avoid operrations on pending order
93
101
94
102
# Load the model
95
103
self .model = PPO .load (self .p .model )
104
+ #self.model = DQN.load(self.p.model)
96
105
pass
97
106
98
107
def next (self ):
99
108
100
109
self .obseravation = self .next_observation ()
101
110
111
+ # Normalize observation
112
+ #self.normalizeObservations()
113
+
102
114
# Do nothing if a parameter is not valid yet (basically wait for all idicators to be loaded)
115
+
103
116
if pd .isna (self .obseravation ).any ():
104
117
print ("Waiting indicators" )
105
118
return
106
-
119
+
107
120
# Prepare data for Model
108
121
action , _states = self .model .predict (self .obseravation ) # deterministic=True
109
122
110
123
if not self .position : # not in the market
111
124
125
+ # TP & SL calculation
126
+ loss_dist = self .atr [0 ] * self .p .atrdist_SL if self .p .use_ATR_SL else self .p .fixed_SL
127
+ profit_dist = self .atr [0 ] * self .p .atrdist_TP if self .p .use_ATR_SL else self .p .fixed_TP
128
+
112
129
if action == Action .SELL .value :
113
130
self .order = self .sell (size = self .p .tradeSize )
114
- ldist = self .atr [0 ] * self .p .atrdist_SL
115
- self .lstop = self .data .close [0 ] + ldist
116
- pdist = self .atr [0 ] * self .p .atrdist_TP
117
- self .take_profit = self .data .close [0 ] - pdist
131
+ self .lstop = self .data .close [0 ] + loss_dist
132
+ self .take_profit = self .data .close [0 ] - profit_dist
118
133
119
134
elif action == Action .BUY .value :
120
135
self .order = self .buy (size = self .p .tradeSize )
121
- ldist = self .atr [0 ] * self .p .atrdist_SL
122
- self .lstop = self .data .close [0 ] - ldist
123
- pdist = self .atr [0 ] * self .p .atrdist_TP
124
- self .take_profit = self .data .close [0 ] + pdist
136
+ self .lstop = self .data .close [0 ] - loss_dist
137
+ self .take_profit = self .data .close [0 ] + profit_dist
125
138
126
139
else : # in the market
127
140
pclose = self .data .close [0 ]
128
- pstop = self .lstop # seems to be the bug
129
141
130
- if (not ((pstop < pclose < self .take_profit )| (pstop > pclose > self .take_profit ))):
142
+ if (not ((self . lstop < pclose < self .take_profit )| (self . lstop > pclose > self .take_profit ))):
131
143
self .close () # Close position
132
144
133
145
pass
134
146
135
147
# Here you have to transform self object price and indicators into a np.array input for AI Model
136
148
# How you do it depend on your AI Model inputs
137
149
# Strategy is in the data preparation for AI :D
138
- def next_observation (self ):
139
-
140
- # https://stackoverflow.com/questions/53979199/tensorflow-keras-returning-multiple-predictions-while-expecting-one
150
+ # https://stackoverflow.com/questions/53979199/tensorflow-keras-returning-multiple-predictions-while-expecting-one
141
151
142
- # Ichimoku
143
- #inputs = [ self.ichimoku.tenkan[0], self.ichimoku.kijun[0], self.ichimoku.senkou[0], self.ichimoku.senkou_lead[0], self.ichimoku.chikou[0] ]
152
+ def next_observation (self ):
144
153
145
154
# OHLCV
146
155
inputs = [ self .data .open [0 ],self .data .high [0 ],self .data .low [0 ],self .data .close [0 ],self .data .volume [0 ] ]
147
156
148
- # Stochastic
149
- inputs = inputs + [self .stochastic . percK [0 ], self .stochastic . percD [0 ]]
157
+ # Ichimoku
158
+ inputs = inputs + [ self .ichimoku . senkou_span_a [0 ], self .ichimoku . senkou_span_b [0 ], self . ichimoku . tenkan_sen [ 0 ], self . ichimoku . kijun_sen [ 0 ], self . ichimoku . chikou_span [ 0 ] ]
150
159
151
160
# Rsi
152
161
inputs = inputs + [self .rsi .rsi [0 ]]
153
162
163
+ # Stochastic
164
+ inputs = inputs + [self .stochastic .percK [0 ], self .stochastic .percD [0 ]]
165
+
154
166
# bbands
155
167
inputs = inputs + [self .bbands .bot [0 ]] # BBL
156
168
inputs = inputs + [self .bbands .mid [0 ]] # BBM
@@ -180,4 +192,87 @@ def next_observation(self):
180
192
181
193
return np .array (inputs )
182
194
183
- pass
195
+ pass
196
+
197
+
198
+ def normalizeObservations (self ):
199
+
200
+ MIN_BITCOIN_VALUE = 10_000.0
201
+ MAX_BITCOIN_VALUE = 80_000.0
202
+
203
+ self .obseravation_normalized = np .empty (len (self .obseravation ))
204
+ try :
205
+ # Normalize data
206
+ for index , value in enumerate (self .obseravation ):
207
+ if value < 100.0 :
208
+ self .obseravation_normalized [index ] = (value - 0.0 ) / (100.0 - 0.0 )
209
+ else :
210
+ self .obseravation_normalized [index ] = (value - MIN_BITCOIN_VALUE ) / (MAX_BITCOIN_VALUE - MIN_BITCOIN_VALUE )
211
+
212
+ #self.obseravation_normalized = (self.obseravation-self.normalization_bounds_df_min) / (self.normalization_bounds_df_max-self.normalization_bounds_df_min)
213
+
214
+
215
+ except ValueError as err :
216
+ return None , "ValueError error:" + str (err )
217
+ except AttributeError as err :
218
+ return None , "AttributeError error:" + str (err )
219
+ except IndexError as err :
220
+ return None , "IndexError error:" + str (err )
221
+ except :
222
+ aie = 1
223
+
224
+ pass
225
+
226
+ def loadNormalizationBoundsCsv (self , filePath ):
227
+
228
+ # Try importing data file
229
+ # We should code a widget that ask for options as : separators, date format, and so on...
230
+ try :
231
+
232
+ # Python contains
233
+ if pd .__version__ < '2.0.0' :
234
+ df = pd .read_csv (filePath ,
235
+ sep = ";" ,
236
+ parse_dates = None ,
237
+ date_parser = lambda x : pd .to_datetime (x , format = "" ),
238
+ skiprows = 0 ,
239
+ header = None ,
240
+ index_col = 0 )
241
+ else :
242
+ df = pd .read_csv (filePath ,
243
+ sep = ";" ,
244
+ parse_dates = None ,
245
+ date_format = "" ,
246
+ skiprows = 0 ,
247
+ header = None ,
248
+ index_col = 0 )
249
+
250
+ return df
251
+
252
+ except ValueError as err :
253
+ return None , "ValueError error:" + str (err )
254
+ except AttributeError as err :
255
+ return None , "AttributeError error:" + str (err )
256
+ except IndexError as err :
257
+ return None , "IndexError error:" + str (err )
258
+
259
+ pass
260
+
261
+
262
+ # https://stackoverflow.com/questions/53321608/extract-dataframe-from-pandas-datafeed-in-backtrader
263
+ def __bt_to_pandas__ (self , btdata , len ):
264
+ get = lambda mydata : mydata .get (ago = 0 , size = len )
265
+
266
+ fields = {
267
+ 'open' : get (btdata .open ),
268
+ 'high' : get (btdata .high ),
269
+ 'low' : get (btdata .low ),
270
+ 'close' : get (btdata .close ),
271
+ 'volume' : get (btdata .volume )
272
+ }
273
+ time = [btdata .num2date (x ) for x in get (btdata .datetime )]
274
+
275
+ return pd .DataFrame (data = fields , index = time )
276
+
277
+ pass
278
+
0 commit comments