Skip to content

Commit 957cba3

Browse files
committed
star_nln
1 parent 610ff06 commit 957cba3

File tree

1 file changed

+394
-0
lines changed

1 file changed

+394
-0
lines changed

star_nln/star_nln.py

+394
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,394 @@
1+
import numpy as np
2+
from scipy.integrate import odeint
3+
import matplotlib.pyplot as plt
4+
import itertools
5+
from scipy.optimize import linprog
6+
import copy
7+
8+
def vanderpol(s,t):
9+
mu = 1
10+
x1, y1, x2, y2, b = s
11+
x1_dot = y1
12+
y1_dot = mu*(1-x1**2)*y1+b*(x2-x1)-x1
13+
x2_dot = y2
14+
y2_dot = mu*(1-x2**2)*y2-b*(x2-x1)-x2
15+
b_dot = 0
16+
return [x1_dot, y1_dot, x2_dot, y2_dot, b_dot]
17+
18+
def laubloomis(x,t):
19+
pass
20+
21+
class Star:
22+
def __init__(
23+
self,
24+
center: np.ndarray,
25+
basis: np.ndarray,
26+
C: np.ndarray,
27+
g: np.ndarray
28+
):
29+
self.center = center
30+
self.basis = basis
31+
self.C = C
32+
self.g = g
33+
34+
def in_star(self, point):
35+
m = self.n_basis()
36+
c = np.zeros(m)
37+
A_eq = self.basis.T
38+
b_eq = point.squeeze() - self.center.squeeze()
39+
A_ub = self.C
40+
b_ub = self.g
41+
bounds = [(-np.inf, np.inf)]*(m)
42+
res = linprog(
43+
c = c,
44+
A_eq = A_eq,
45+
b_eq = b_eq,
46+
A_ub = A_ub,
47+
b_ub = b_ub,
48+
bounds = bounds
49+
)
50+
return res.success
51+
52+
def n_basis(self):
53+
return self.basis.shape[0]
54+
55+
def n_dim(self):
56+
return len(self.center)
57+
58+
def n_pred(self):
59+
return len(self.g)
60+
61+
def visualize(self, fig: plt.axes, dim1, dim2) -> plt.axes:
62+
xs, ys = self.get_verts(dim1, dim2)
63+
#print(verts)
64+
fig.plot(xs, ys)
65+
# plt.show()
66+
return fig
67+
68+
# def normalize(self) -> Star:
69+
# tmp_center = self.center
70+
# tmp_basis = self.basis
71+
# tmp_C = self.C
72+
# tmp_g = self.g
73+
# norm_factor = np.linalg.norm(tmp_basis, axis=1)
74+
# tmp_basis = tmp_basis/norm_factor
75+
76+
@classmethod
77+
def from_rect(cls, rect: np.ndarray):
78+
n = rect.shape[1]
79+
center = (rect[0,:]+rect[1,:])/2.0
80+
basis = np.diag((rect[1,:]-rect[0,:])/2.0)
81+
C = np.zeros((n*2,n))
82+
g = np.ones((n*2,1))
83+
for i in range(n):
84+
C[i*2,i] = -1
85+
C[i*2+1,i] = 1
86+
return cls(center, basis, C, g)
87+
88+
#stanley bak code
89+
def get_verts(self, dim1, dim2):
90+
"""get the vertices of the stateset"""
91+
#TODO: generalize for n dimensional
92+
verts = []
93+
x_pts = []
94+
y_pts = []
95+
extra_dims_ct = len(self.center) - 2
96+
zeros = []
97+
for i in range(0, extra_dims_ct):
98+
zeros.append([0])
99+
# sample the angles from 0 to 2pi, 100 samples
100+
for angle in np.linspace(0, 2*np.pi, 100):
101+
x_component = np.cos(angle)
102+
y_component = np.sin(angle)
103+
#TODO: needs to work for 3d and any dim of non-graphed state
104+
direction = [[x_component], [y_component]]
105+
direction.extend(zeros)
106+
direction = np.array(direction)
107+
#for i in range(0, extra_dims_ct):
108+
# direction.append([0])
109+
#direction.extend(zeros)
110+
111+
pt = self.maximize(direction)
112+
113+
verts.append(pt)
114+
#print(pt)
115+
x_pts.append(pt[0][0])
116+
#print(pt[0][0])
117+
y_pts.append(pt[0][1])
118+
#print(pt[1][0])
119+
x_pts.append(x_pts[0])
120+
y_pts.append(y_pts[0])
121+
return (x_pts, y_pts)
122+
123+
def min_max(self, dim):
124+
n = self.n_dim()
125+
m = self.n_basis()
126+
p = self.n_pred()
127+
# A_ub = np.zeros((n+p, m+m+n))
128+
# b_ub = np.zeros((n+p, 1))
129+
# A_ub[:n,:n] = np.eye(n)
130+
# A_ub[:n,n:n+m] = self.basis.T
131+
# A_ub[n:,m+n:] = self.C
132+
# b_ub[:n,:] = center.reshape((-1,1))
133+
# b_ub[n:,:] = self.g
134+
# # Constraints enforcing equality of alphas
135+
# A_eq = np.zeros((m, n+m+m))
136+
# b_eq = np.zeros((m, 1))
137+
# A_eq[:,n:n+m] = np.eye(m)
138+
# A_eq[:,n+m:n+m+m] = -np.eye(m)
139+
A_eq = np.zeros((n, m+n))
140+
b_eq = np.zeros((n, 1))
141+
A_eq[:,:n] = np.eye(n)
142+
A_eq[:,n:] = self.basis.T
143+
b_eq = self.center.reshape((-1,1))
144+
A_ub = np.zeros((p, n+m))
145+
b_ub = np.zeros((p,1))
146+
A_ub[:,n:] = self.C
147+
b_ub = self.g.reshape((-1,1))
148+
c = np.zeros(n+m)
149+
c[dim] = 1
150+
bounds = [(-np.inf, np.inf)]*(n+m)
151+
res_min = linprog(
152+
c = c,
153+
A_ub = A_ub,
154+
b_ub = b_ub,
155+
A_eq = A_eq,
156+
b_eq = b_eq,
157+
bounds = bounds
158+
)
159+
res_max = linprog(
160+
c = -c,
161+
A_ub = A_ub,
162+
b_ub = b_ub,
163+
A_eq = A_eq,
164+
b_eq = b_eq,
165+
bounds = bounds
166+
)
167+
168+
return (res_min.x[dim], res_max.x[dim])
169+
170+
def maximize(self, opt_direction):
171+
"""return the point that maximizes the direction"""
172+
173+
opt_direction *= -1
174+
175+
# convert opt_direction to domain
176+
#print(self.basis)
177+
domain_direction = opt_direction.T @ self.basis
178+
179+
180+
# use LP to optimize the constraints
181+
res = linprog(c=domain_direction,
182+
A_ub=self.C,
183+
b_ub=self.g,
184+
bounds=(None, None))
185+
186+
# convert point back to range
187+
#print(res.x)
188+
domain_pt = res.x.reshape((res.x.shape[0], 1))
189+
#if domain_direction[0][0] == -1 and domain_direction[0][1] == 0:
190+
# print(domain_pt)
191+
#range_pt = self.center + self.basis @ domain_pt
192+
range_pt = self.center + domain_pt.T @ self.basis
193+
194+
#if domain_direction[0][0] == -1 and domain_direction[0][1] == 0:
195+
# print(self.basis)
196+
# print(range_pt)
197+
# return the point
198+
return range_pt
199+
200+
def find_alpha(star:Star, point:np.ndarray):
201+
point = point.reshape((-1,1))
202+
p = star.n_pred()
203+
n = star.n_dim()
204+
m = star.n_basis()
205+
tmp = np.linalg.det(star.basis)
206+
# if tmp<1e-8:
207+
# tmp = np.random.randint(0,2,(5,5))*2-1
208+
# star.basis = star.basis*tmp
209+
A_eq = np.zeros((n, 2*m))
210+
A_eq[:,:m] = star.basis.T
211+
b_eq = point - star.center.reshape((-1,1))
212+
A_ub = np.zeros((2*m,2*m))
213+
A_ub[:m,:m] = np.eye(m)
214+
A_ub[:m,m:2*m] = -np.eye(m)
215+
A_ub[m:2*m,:m] = -np.eye(m)
216+
A_ub[m:2*m,m:2*m] = -np.eye(m)
217+
b_ub = np.zeros((2*m, 1))
218+
c = np.zeros(2*m)
219+
c[m:] = 1
220+
bounds = [(-1000,1000)]*2*m
221+
res = linprog(
222+
c = c,
223+
A_ub = A_ub,
224+
b_ub = b_ub,
225+
A_eq = A_eq,
226+
b_eq = b_eq,
227+
bounds = bounds,
228+
method= 'highs-ipm'
229+
)
230+
return res.x[:m]
231+
232+
def sample_rect(rect, n=1):
233+
res = np.random.uniform(rect[0], rect[1], (n, len(rect[0])))
234+
return res
235+
236+
if __name__ == "__main__":
237+
init_rect = [
238+
[1.25, 2.35, 1.25, 2.35, 1],
239+
[1.55, 2.45, 1.55, 2.45, 3]
240+
]
241+
242+
init = sample_rect(init_rect, 100)
243+
t = np.arange(0, 7, 0.01)
244+
fig0 = plt.figure(0)
245+
plt.plot([0,0],[1.25,1.55])
246+
fig1 = plt.figure(1)
247+
plt.plot([0,0],[2.35,2.45])
248+
fig2 = plt.figure(2)
249+
plt.plot([0,0],[1.25,1.55])
250+
fig3 = plt.figure(3)
251+
plt.plot([0,0],[2.35,2.45])
252+
traces = []
253+
for i in range(init.shape[0]):
254+
trace = odeint(vanderpol, init[i,:], t)
255+
plt.figure(0)
256+
plt.plot(t, trace[:,0],'r')
257+
plt.figure(1)
258+
plt.plot(t, trace[:,1],'r')
259+
plt.figure(2)
260+
plt.plot(t, trace[:,2],'r')
261+
plt.figure(3)
262+
plt.plot(t, trace[:,3],'r')
263+
traces.append(trace)
264+
265+
inits = itertools.product([1.25,1.55],[2.35,2.45],[1.25,1.55],[2.35,2.45],[1.0,3.0])
266+
267+
# corner_traces = []
268+
# for init in inits:
269+
# trace = odeint(vanderpol, init, t)
270+
# plt.figure(0)
271+
# plt.plot(t, trace[:,0],'b')
272+
# plt.figure(1)
273+
# plt.plot(t, trace[:,1],'b')
274+
# plt.figure(2)
275+
# plt.plot(t, trace[:,2],'b')
276+
# plt.figure(3)
277+
# plt.plot(t, trace[:,3],'b')
278+
# corner_traces.append(trace)
279+
# plt.show()
280+
281+
# Compute star set for each axis assuming linearality
282+
init_star = Star.from_rect(np.array(init_rect))
283+
num_basis = init_star.n_basis()
284+
center = init_star.center
285+
center_trace = odeint(vanderpol, center, t)
286+
basis_trace = []
287+
for i in range(num_basis):
288+
x0 = center + init_star.basis[i,:]
289+
tmp = odeint(vanderpol, x0, t)
290+
basis_trace.append(tmp)
291+
# star_vertices =
292+
star_list = []
293+
for i in range(center_trace.shape[0]):
294+
tmp_center = center_trace[i]
295+
tmp_basis = np.zeros(init_star.basis.shape)
296+
for j in range(len(basis_trace)):
297+
tmp_basis[j,:] = basis_trace[j][i,:] - tmp_center
298+
tmp_C = init_star.C
299+
tmp_g = init_star.g
300+
star_list.append(Star(tmp_center, tmp_basis, tmp_C, tmp_g))
301+
302+
trace0 = []
303+
trace1 = []
304+
trace2 = []
305+
trace3 = []
306+
for i in range(len(star_list)):
307+
star = star_list[i]
308+
trace0.append(star.min_max(0))
309+
trace1.append(star.min_max(1))
310+
trace2.append(star.min_max(2))
311+
trace3.append(star.min_max(3))
312+
trace0 = np.array(trace0)
313+
trace1 = np.array(trace1)
314+
trace2 = np.array(trace2)
315+
trace3 = np.array(trace3)
316+
plt.figure(0)
317+
plt.plot(t, trace0[:,0], 'g')
318+
plt.plot(t, trace0[:,1], 'g')
319+
plt.figure(1)
320+
plt.plot(t, trace1[:,0], 'g')
321+
plt.plot(t, trace1[:,1], 'g')
322+
plt.figure(2)
323+
plt.plot(t, trace2[:,0], 'g')
324+
plt.plot(t, trace2[:,1], 'g')
325+
plt.figure(3)
326+
plt.plot(t, trace3[:,0], 'g')
327+
plt.plot(t, trace3[:,1], 'g')
328+
# plt.show()
329+
330+
# Bloat each stars
331+
# 1) Collect simulation trajectories
332+
# 2) Normalize basis for all the stars (skip)
333+
# 3) At each time step, find alpha for each trajectories
334+
new_star_list = []
335+
for i in range(len(t)):
336+
print(i)
337+
if i==305:
338+
print("stop here")
339+
curr_star = star_list[i]
340+
# curr_star.normalize()
341+
alpha_list = []
342+
for traj in traces:
343+
alpha = find_alpha(curr_star, traj[i,:])
344+
alpha_list.append(alpha)
345+
# 4) Find g such that all the alphas at that time step satisfy Calpha\leq g
346+
g = copy.deepcopy(curr_star.g)
347+
for alpha in alpha_list:
348+
tmp_g = curr_star.C@alpha.reshape((-1,1))
349+
g = np.maximum(g, tmp_g)
350+
new_star = Star(curr_star.center, curr_star.basis, curr_star.C, g)
351+
new_star_list.append(new_star)
352+
353+
trace0 = []
354+
trace1 = []
355+
trace2 = []
356+
trace3 = []
357+
for i in range(len(new_star_list)):
358+
star = new_star_list[i]
359+
trace0.append(star.min_max(0))
360+
trace1.append(star.min_max(1))
361+
trace2.append(star.min_max(2))
362+
trace3.append(star.min_max(3))
363+
trace0 = np.array(trace0)
364+
trace1 = np.array(trace1)
365+
trace2 = np.array(trace2)
366+
trace3 = np.array(trace3)
367+
plt.figure(0)
368+
plt.plot(t, trace0[:,0], 'y')
369+
plt.plot(t, trace0[:,1], 'y')
370+
plt.figure(1)
371+
plt.plot(t, trace1[:,0], 'y')
372+
plt.plot(t, trace1[:,1], 'y')
373+
plt.figure(2)
374+
plt.plot(t, trace2[:,0], 'y')
375+
plt.plot(t, trace2[:,1], 'y')
376+
plt.figure(3)
377+
plt.plot(t, trace3[:,0], 'y')
378+
plt.plot(t, trace3[:,1], 'y')
379+
380+
fig4 = plt.figure(4)
381+
ax4 = fig4.gca()
382+
# for i in range(new_star_list):
383+
star = new_star_list[305]
384+
ax4 = star.visualize(ax4, None, None)
385+
plt.plot(star.center[0], star.center[1], 'g*')
386+
for trace in traces:
387+
plt.plot(trace[305,0], trace[305,1], 'r*')
388+
for i in range(star.basis.shape[0]):
389+
base = star.basis[i,:]*50
390+
x = [star.center[0], star.center[0]+base[0]]
391+
y = [star.center[1], star.center[1]+base[1]]
392+
plt.plot(copy.deepcopy(x),copy.deepcopy(y),'b')
393+
plt.show()
394+
print("here")

0 commit comments

Comments
 (0)