Skip to content

Commit

Permalink
add learning of event function example
Browse files Browse the repository at this point in the history
  • Loading branch information
rtqichen committed Dec 27, 2022
1 parent 19a8c70 commit 08007fb
Show file tree
Hide file tree
Showing 3 changed files with 339 additions and 22 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=ode

The solve is terminated at an event time `t` and state `y` when an element of `event_fn(t, y)` is equal to zero. Multiple outputs from `event_fn` can be used to specify multiple event functions, of which the first to trigger will terminate the solve.

Both the event time and final state are returned from `odeint_event`, and can be differentiated. Gradients will be backpropagated through the event function.
Both the event time and final state are returned from `odeint_event`, and can be differentiated. Gradients will be backpropagated through the event function. **NOTE**: parameters for the event function must be in the state itself to obtain gradients.

The numerical precision for the event time is determined by the `atol` argument.

See example of simulating and differentiating through a bouncing ball in [`examples/bouncing_ball.py`](./examples/bouncing_ball.py).
See example of simulating and differentiating through a bouncing ball in [`examples/bouncing_ball.py`](./examples/bouncing_ball.py). See example code for learning a simple event function in [`examples/learn_physics.py`](./examples/learn_physics.py).

<p align="center">
<img align="middle" src="./assets/bouncing_ball.png" alt="Bouncing Ball" width="500" height="250" />
Expand Down
73 changes: 53 additions & 20 deletions examples/bouncing_ball.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


class BouncingBallExample(nn.Module):

def __init__(self, radius=0.2, gravity=9.8, adjoint=False):
super().__init__()
self.gravity = nn.Parameter(torch.as_tensor([gravity]))
Expand All @@ -39,9 +38,11 @@ def get_initial_state(self):
return self.t0, state

def state_update(self, state):
""" Updates state based on an event (collision)."""
"""Updates state based on an event (collision)."""
pos, vel, log_radius = state
pos = pos + 1e-7 # need to add a small eps so as not to trigger the event function immediately.
pos = (
pos + 1e-7
) # need to add a small eps so as not to trigger the event function immediately.
vel = -vel * (1 - self.absorption)
return (pos, vel, log_radius)

Expand All @@ -52,7 +53,16 @@ def get_collision_times(self, nbounces=1):
t0, state = self.get_initial_state()

for i in range(nbounces):
event_t, solution = odeint_event(self, state, t0, event_fn=self.event_fn, reverse_time=False, atol=1e-8, rtol=1e-8, odeint_interface=self.odeint)
event_t, solution = odeint_event(
self,
state,
t0,
event_fn=self.event_fn,
reverse_time=False,
atol=1e-8,
rtol=1e-8,
odeint_interface=self.odeint,
)
event_times.append(event_t)

state = self.state_update(tuple(s[-1] for s in solution))
Expand All @@ -69,18 +79,25 @@ def simulate(self, nbounces=1):
velocity = [state[1][None]]
times = [t0.reshape(-1)]
for event_t in event_times:
tt = torch.linspace(float(t0), float(event_t), int((float(event_t) - float(t0)) * 50))[1:-1]
tt = torch.linspace(
float(t0), float(event_t), int((float(event_t) - float(t0)) * 50)
)[1:-1]
tt = torch.cat([t0.reshape(-1), tt, event_t.reshape(-1)])
solution = odeint(self, state, tt, atol=1e-8, rtol=1e-8)

trajectory.append(solution[0])
velocity.append(solution[1])
times.append(tt)
trajectory.append(solution[0][1:])
velocity.append(solution[1][1:])
times.append(tt[1:])

state = self.state_update(tuple(s[-1] for s in solution))
t0 = event_t

return torch.cat(times), torch.cat(trajectory, dim=0).reshape(-1), torch.cat(velocity, dim=0).reshape(-1), event_times
return (
torch.cat(times),
torch.cat(trajectory, dim=0).reshape(-1),
torch.cat(velocity, dim=0).reshape(-1),
event_times,
)


def gradcheck(nbounces):
Expand Down Expand Up @@ -124,7 +141,9 @@ def gradcheck(nbounces):
fd = fd_grads[var]
if torch.norm(analytical - fd) > 1e-4:
success = False
print(f"Got analytical grad {analytical.item()} for {var} param but finite difference is {fd.item()}")
print(
f"Got analytical grad {analytical.item()} for {var} param but finite difference is {fd.item()}"
)

if not success:
raise Exception("Gradient check failed.")
Expand Down Expand Up @@ -152,10 +171,20 @@ def gradcheck(nbounces):

# Event locations.
for event_t in event_times:
plt.plot(event_t, 0.0, color="C0", marker="o", markersize=7, fillstyle='none', linestyle="")

vel, = plt.plot(times, velocity, color="C1", alpha=0.7, linestyle="--", linewidth=2.0)
pos, = plt.plot(times, trajectory, color="C0", linewidth=2.0)
plt.plot(
event_t,
0.0,
color="C0",
marker="o",
markersize=7,
fillstyle="none",
linestyle="",
)

(vel,) = plt.plot(
times, velocity, color="C1", alpha=0.7, linestyle="--", linewidth=2.0
)
(pos,) = plt.plot(times, trajectory, color="C0", linewidth=2.0)

plt.hlines(0, 0, 100)
plt.xlim([times[0], times[-1]])
Expand All @@ -164,16 +193,20 @@ def gradcheck(nbounces):
plt.xlabel("Time", fontsize=13)
plt.legend([pos, vel], ["Position", "Velocity"], fontsize=16)

plt.gca().xaxis.set_tick_params(direction='in', which='both') # The bottom will maintain the default of 'out'
plt.gca().yaxis.set_tick_params(direction='in', which='both') # The bottom will maintain the default of 'out'
plt.gca().xaxis.set_tick_params(
direction="in", which="both"
) # The bottom will maintain the default of 'out'
plt.gca().yaxis.set_tick_params(
direction="in", which="both"
) # The bottom will maintain the default of 'out'

# Hide the right and top spines
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines["right"].set_visible(False)
plt.gca().spines["top"].set_visible(False)

# Only show ticks on the left and bottom spines
plt.gca().yaxis.set_ticks_position('left')
plt.gca().xaxis.set_ticks_position('bottom')
plt.gca().yaxis.set_ticks_position("left")
plt.gca().xaxis.set_ticks_position("bottom")

plt.tight_layout()
plt.savefig("bouncing_ball.png")
Loading

0 comments on commit 08007fb

Please sign in to comment.