1
+ import numpy as np
2
+ from scipy .integrate import odeint
3
+ import matplotlib .pyplot as plt
4
+ import itertools
5
+ from scipy .optimize import linprog
6
+ import copy
7
+
8
+ def vanderpol (s ,t ):
9
+ mu = 1
10
+ x1 , y1 , x2 , y2 , b = s
11
+ x1_dot = y1
12
+ y1_dot = mu * (1 - x1 ** 2 )* y1 + b * (x2 - x1 )- x1
13
+ x2_dot = y2
14
+ y2_dot = mu * (1 - x2 ** 2 )* y2 - b * (x2 - x1 )- x2
15
+ b_dot = 0
16
+ return [x1_dot , y1_dot , x2_dot , y2_dot , b_dot ]
17
+
18
+ def laubloomis (x ,t ):
19
+ pass
20
+
21
+ class Star :
22
+ def __init__ (
23
+ self ,
24
+ center : np .ndarray ,
25
+ basis : np .ndarray ,
26
+ C : np .ndarray ,
27
+ g : np .ndarray
28
+ ):
29
+ self .center = center
30
+ self .basis = basis
31
+ self .C = C
32
+ self .g = g
33
+
34
+ def in_star (self , point ):
35
+ m = self .n_basis ()
36
+ c = np .zeros (m )
37
+ A_eq = self .basis .T
38
+ b_eq = point .squeeze () - self .center .squeeze ()
39
+ A_ub = self .C
40
+ b_ub = self .g
41
+ bounds = [(- np .inf , np .inf )]* (m )
42
+ res = linprog (
43
+ c = c ,
44
+ A_eq = A_eq ,
45
+ b_eq = b_eq ,
46
+ A_ub = A_ub ,
47
+ b_ub = b_ub ,
48
+ bounds = bounds
49
+ )
50
+ return res .success
51
+
52
+ def n_basis (self ):
53
+ return self .basis .shape [0 ]
54
+
55
+ def n_dim (self ):
56
+ return len (self .center )
57
+
58
+ def n_pred (self ):
59
+ return len (self .g )
60
+
61
+ def visualize (self , fig : plt .axes , dim1 , dim2 ) -> plt .axes :
62
+ xs , ys = self .get_verts (dim1 , dim2 )
63
+ #print(verts)
64
+ fig .plot (xs , ys )
65
+ # plt.show()
66
+ return fig
67
+
68
+ # def normalize(self) -> Star:
69
+ # tmp_center = self.center
70
+ # tmp_basis = self.basis
71
+ # tmp_C = self.C
72
+ # tmp_g = self.g
73
+ # norm_factor = np.linalg.norm(tmp_basis, axis=1)
74
+ # tmp_basis = tmp_basis/norm_factor
75
+
76
+ @classmethod
77
+ def from_rect (cls , rect : np .ndarray ):
78
+ n = rect .shape [1 ]
79
+ center = (rect [0 ,:]+ rect [1 ,:])/ 2.0
80
+ basis = np .diag ((rect [1 ,:]- rect [0 ,:])/ 2.0 )
81
+ C = np .zeros ((n * 2 ,n ))
82
+ g = np .ones ((n * 2 ,1 ))
83
+ for i in range (n ):
84
+ C [i * 2 ,i ] = - 1
85
+ C [i * 2 + 1 ,i ] = 1
86
+ return cls (center , basis , C , g )
87
+
88
+ #stanley bak code
89
+ def get_verts (self , dim1 , dim2 ):
90
+ """get the vertices of the stateset"""
91
+ #TODO: generalize for n dimensional
92
+ verts = []
93
+ x_pts = []
94
+ y_pts = []
95
+ extra_dims_ct = len (self .center ) - 2
96
+ zeros = []
97
+ for i in range (0 , extra_dims_ct ):
98
+ zeros .append ([0 ])
99
+ # sample the angles from 0 to 2pi, 100 samples
100
+ for angle in np .linspace (0 , 2 * np .pi , 100 ):
101
+ x_component = np .cos (angle )
102
+ y_component = np .sin (angle )
103
+ #TODO: needs to work for 3d and any dim of non-graphed state
104
+ direction = [[x_component ], [y_component ]]
105
+ direction .extend (zeros )
106
+ direction = np .array (direction )
107
+ #for i in range(0, extra_dims_ct):
108
+ # direction.append([0])
109
+ #direction.extend(zeros)
110
+
111
+ pt = self .maximize (direction )
112
+
113
+ verts .append (pt )
114
+ #print(pt)
115
+ x_pts .append (pt [0 ][0 ])
116
+ #print(pt[0][0])
117
+ y_pts .append (pt [0 ][1 ])
118
+ #print(pt[1][0])
119
+ x_pts .append (x_pts [0 ])
120
+ y_pts .append (y_pts [0 ])
121
+ return (x_pts , y_pts )
122
+
123
+ def min_max (self , dim ):
124
+ n = self .n_dim ()
125
+ m = self .n_basis ()
126
+ p = self .n_pred ()
127
+ # A_ub = np.zeros((n+p, m+m+n))
128
+ # b_ub = np.zeros((n+p, 1))
129
+ # A_ub[:n,:n] = np.eye(n)
130
+ # A_ub[:n,n:n+m] = self.basis.T
131
+ # A_ub[n:,m+n:] = self.C
132
+ # b_ub[:n,:] = center.reshape((-1,1))
133
+ # b_ub[n:,:] = self.g
134
+ # # Constraints enforcing equality of alphas
135
+ # A_eq = np.zeros((m, n+m+m))
136
+ # b_eq = np.zeros((m, 1))
137
+ # A_eq[:,n:n+m] = np.eye(m)
138
+ # A_eq[:,n+m:n+m+m] = -np.eye(m)
139
+ A_eq = np .zeros ((n , m + n ))
140
+ b_eq = np .zeros ((n , 1 ))
141
+ A_eq [:,:n ] = np .eye (n )
142
+ A_eq [:,n :] = self .basis .T
143
+ b_eq = self .center .reshape ((- 1 ,1 ))
144
+ A_ub = np .zeros ((p , n + m ))
145
+ b_ub = np .zeros ((p ,1 ))
146
+ A_ub [:,n :] = self .C
147
+ b_ub = self .g .reshape ((- 1 ,1 ))
148
+ c = np .zeros (n + m )
149
+ c [dim ] = 1
150
+ bounds = [(- np .inf , np .inf )]* (n + m )
151
+ res_min = linprog (
152
+ c = c ,
153
+ A_ub = A_ub ,
154
+ b_ub = b_ub ,
155
+ A_eq = A_eq ,
156
+ b_eq = b_eq ,
157
+ bounds = bounds
158
+ )
159
+ res_max = linprog (
160
+ c = - c ,
161
+ A_ub = A_ub ,
162
+ b_ub = b_ub ,
163
+ A_eq = A_eq ,
164
+ b_eq = b_eq ,
165
+ bounds = bounds
166
+ )
167
+
168
+ return (res_min .x [dim ], res_max .x [dim ])
169
+
170
+ def maximize (self , opt_direction ):
171
+ """return the point that maximizes the direction"""
172
+
173
+ opt_direction *= - 1
174
+
175
+ # convert opt_direction to domain
176
+ #print(self.basis)
177
+ domain_direction = opt_direction .T @ self .basis
178
+
179
+
180
+ # use LP to optimize the constraints
181
+ res = linprog (c = domain_direction ,
182
+ A_ub = self .C ,
183
+ b_ub = self .g ,
184
+ bounds = (None , None ))
185
+
186
+ # convert point back to range
187
+ #print(res.x)
188
+ domain_pt = res .x .reshape ((res .x .shape [0 ], 1 ))
189
+ #if domain_direction[0][0] == -1 and domain_direction[0][1] == 0:
190
+ # print(domain_pt)
191
+ #range_pt = self.center + self.basis @ domain_pt
192
+ range_pt = self .center + domain_pt .T @ self .basis
193
+
194
+ #if domain_direction[0][0] == -1 and domain_direction[0][1] == 0:
195
+ # print(self.basis)
196
+ # print(range_pt)
197
+ # return the point
198
+ return range_pt
199
+
200
+ def find_alpha (star :Star , point :np .ndarray ):
201
+ point = point .reshape ((- 1 ,1 ))
202
+ p = star .n_pred ()
203
+ n = star .n_dim ()
204
+ m = star .n_basis ()
205
+ tmp = np .linalg .det (star .basis )
206
+ # if tmp<1e-8:
207
+ # tmp = np.random.randint(0,2,(5,5))*2-1
208
+ # star.basis = star.basis*tmp
209
+ A_eq = np .zeros ((n , 2 * m ))
210
+ A_eq [:,:m ] = star .basis .T
211
+ b_eq = point - star .center .reshape ((- 1 ,1 ))
212
+ A_ub = np .zeros ((2 * m ,2 * m ))
213
+ A_ub [:m ,:m ] = np .eye (m )
214
+ A_ub [:m ,m :2 * m ] = - np .eye (m )
215
+ A_ub [m :2 * m ,:m ] = - np .eye (m )
216
+ A_ub [m :2 * m ,m :2 * m ] = - np .eye (m )
217
+ b_ub = np .zeros ((2 * m , 1 ))
218
+ c = np .zeros (2 * m )
219
+ c [m :] = 1
220
+ bounds = [(- 1000 ,1000 )]* 2 * m
221
+ res = linprog (
222
+ c = c ,
223
+ A_ub = A_ub ,
224
+ b_ub = b_ub ,
225
+ A_eq = A_eq ,
226
+ b_eq = b_eq ,
227
+ bounds = bounds ,
228
+ method = 'highs-ipm'
229
+ )
230
+ return res .x [:m ]
231
+
232
+ def sample_rect (rect , n = 1 ):
233
+ res = np .random .uniform (rect [0 ], rect [1 ], (n , len (rect [0 ])))
234
+ return res
235
+
236
+ if __name__ == "__main__" :
237
+ init_rect = [
238
+ [1.25 , 2.35 , 1.25 , 2.35 , 1 ],
239
+ [1.55 , 2.45 , 1.55 , 2.45 , 3 ]
240
+ ]
241
+
242
+ init = sample_rect (init_rect , 100 )
243
+ t = np .arange (0 , 7 , 0.01 )
244
+ fig0 = plt .figure (0 )
245
+ plt .plot ([0 ,0 ],[1.25 ,1.55 ])
246
+ fig1 = plt .figure (1 )
247
+ plt .plot ([0 ,0 ],[2.35 ,2.45 ])
248
+ fig2 = plt .figure (2 )
249
+ plt .plot ([0 ,0 ],[1.25 ,1.55 ])
250
+ fig3 = plt .figure (3 )
251
+ plt .plot ([0 ,0 ],[2.35 ,2.45 ])
252
+ traces = []
253
+ for i in range (init .shape [0 ]):
254
+ trace = odeint (vanderpol , init [i ,:], t )
255
+ plt .figure (0 )
256
+ plt .plot (t , trace [:,0 ],'r' )
257
+ plt .figure (1 )
258
+ plt .plot (t , trace [:,1 ],'r' )
259
+ plt .figure (2 )
260
+ plt .plot (t , trace [:,2 ],'r' )
261
+ plt .figure (3 )
262
+ plt .plot (t , trace [:,3 ],'r' )
263
+ traces .append (trace )
264
+
265
+ inits = itertools .product ([1.25 ,1.55 ],[2.35 ,2.45 ],[1.25 ,1.55 ],[2.35 ,2.45 ],[1.0 ,3.0 ])
266
+
267
+ # corner_traces = []
268
+ # for init in inits:
269
+ # trace = odeint(vanderpol, init, t)
270
+ # plt.figure(0)
271
+ # plt.plot(t, trace[:,0],'b')
272
+ # plt.figure(1)
273
+ # plt.plot(t, trace[:,1],'b')
274
+ # plt.figure(2)
275
+ # plt.plot(t, trace[:,2],'b')
276
+ # plt.figure(3)
277
+ # plt.plot(t, trace[:,3],'b')
278
+ # corner_traces.append(trace)
279
+ # plt.show()
280
+
281
+ # Compute star set for each axis assuming linearality
282
+ init_star = Star .from_rect (np .array (init_rect ))
283
+ num_basis = init_star .n_basis ()
284
+ center = init_star .center
285
+ center_trace = odeint (vanderpol , center , t )
286
+ basis_trace = []
287
+ for i in range (num_basis ):
288
+ x0 = center + init_star .basis [i ,:]
289
+ tmp = odeint (vanderpol , x0 , t )
290
+ basis_trace .append (tmp )
291
+ # star_vertices =
292
+ star_list = []
293
+ for i in range (center_trace .shape [0 ]):
294
+ tmp_center = center_trace [i ]
295
+ tmp_basis = np .zeros (init_star .basis .shape )
296
+ for j in range (len (basis_trace )):
297
+ tmp_basis [j ,:] = basis_trace [j ][i ,:] - tmp_center
298
+ tmp_C = init_star .C
299
+ tmp_g = init_star .g
300
+ star_list .append (Star (tmp_center , tmp_basis , tmp_C , tmp_g ))
301
+
302
+ trace0 = []
303
+ trace1 = []
304
+ trace2 = []
305
+ trace3 = []
306
+ for i in range (len (star_list )):
307
+ star = star_list [i ]
308
+ trace0 .append (star .min_max (0 ))
309
+ trace1 .append (star .min_max (1 ))
310
+ trace2 .append (star .min_max (2 ))
311
+ trace3 .append (star .min_max (3 ))
312
+ trace0 = np .array (trace0 )
313
+ trace1 = np .array (trace1 )
314
+ trace2 = np .array (trace2 )
315
+ trace3 = np .array (trace3 )
316
+ plt .figure (0 )
317
+ plt .plot (t , trace0 [:,0 ], 'g' )
318
+ plt .plot (t , trace0 [:,1 ], 'g' )
319
+ plt .figure (1 )
320
+ plt .plot (t , trace1 [:,0 ], 'g' )
321
+ plt .plot (t , trace1 [:,1 ], 'g' )
322
+ plt .figure (2 )
323
+ plt .plot (t , trace2 [:,0 ], 'g' )
324
+ plt .plot (t , trace2 [:,1 ], 'g' )
325
+ plt .figure (3 )
326
+ plt .plot (t , trace3 [:,0 ], 'g' )
327
+ plt .plot (t , trace3 [:,1 ], 'g' )
328
+ # plt.show()
329
+
330
+ # Bloat each stars
331
+ # 1) Collect simulation trajectories
332
+ # 2) Normalize basis for all the stars (skip)
333
+ # 3) At each time step, find alpha for each trajectories
334
+ new_star_list = []
335
+ for i in range (len (t )):
336
+ print (i )
337
+ if i == 305 :
338
+ print ("stop here" )
339
+ curr_star = star_list [i ]
340
+ # curr_star.normalize()
341
+ alpha_list = []
342
+ for traj in traces :
343
+ alpha = find_alpha (curr_star , traj [i ,:])
344
+ alpha_list .append (alpha )
345
+ # 4) Find g such that all the alphas at that time step satisfy Calpha\leq g
346
+ g = copy .deepcopy (curr_star .g )
347
+ for alpha in alpha_list :
348
+ tmp_g = curr_star .C @alpha .reshape ((- 1 ,1 ))
349
+ g = np .maximum (g , tmp_g )
350
+ new_star = Star (curr_star .center , curr_star .basis , curr_star .C , g )
351
+ new_star_list .append (new_star )
352
+
353
+ trace0 = []
354
+ trace1 = []
355
+ trace2 = []
356
+ trace3 = []
357
+ for i in range (len (new_star_list )):
358
+ star = new_star_list [i ]
359
+ trace0 .append (star .min_max (0 ))
360
+ trace1 .append (star .min_max (1 ))
361
+ trace2 .append (star .min_max (2 ))
362
+ trace3 .append (star .min_max (3 ))
363
+ trace0 = np .array (trace0 )
364
+ trace1 = np .array (trace1 )
365
+ trace2 = np .array (trace2 )
366
+ trace3 = np .array (trace3 )
367
+ plt .figure (0 )
368
+ plt .plot (t , trace0 [:,0 ], 'y' )
369
+ plt .plot (t , trace0 [:,1 ], 'y' )
370
+ plt .figure (1 )
371
+ plt .plot (t , trace1 [:,0 ], 'y' )
372
+ plt .plot (t , trace1 [:,1 ], 'y' )
373
+ plt .figure (2 )
374
+ plt .plot (t , trace2 [:,0 ], 'y' )
375
+ plt .plot (t , trace2 [:,1 ], 'y' )
376
+ plt .figure (3 )
377
+ plt .plot (t , trace3 [:,0 ], 'y' )
378
+ plt .plot (t , trace3 [:,1 ], 'y' )
379
+
380
+ fig4 = plt .figure (4 )
381
+ ax4 = fig4 .gca ()
382
+ # for i in range(new_star_list):
383
+ star = new_star_list [305 ]
384
+ ax4 = star .visualize (ax4 , None , None )
385
+ plt .plot (star .center [0 ], star .center [1 ], 'g*' )
386
+ for trace in traces :
387
+ plt .plot (trace [305 ,0 ], trace [305 ,1 ], 'r*' )
388
+ for i in range (star .basis .shape [0 ]):
389
+ base = star .basis [i ,:]* 50
390
+ x = [star .center [0 ], star .center [0 ]+ base [0 ]]
391
+ y = [star .center [1 ], star .center [1 ]+ base [1 ]]
392
+ plt .plot (copy .deepcopy (x ),copy .deepcopy (y ),'b' )
393
+ plt .show ()
394
+ print ("here" )
0 commit comments