[MRG] Vectorize recording during `integrate` #561
ntolley wants to merge 3 commits into jaxleyverse:main
|
Here's some code to see how recordings impact speed. The comparison is pretty extreme, but it gets the point across:

```python
import time

import jaxley as jx
from jax import config
from jaxley.channels import Na
from jaxley.connect import fully_connect
from jaxley.synapses import IonotropicSynapse

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

# Build a fully connected network of 100 identical cells.
cell = jx.Cell()
cell.insert(Na())
net = jx.Network([cell for _ in range(100)])
fully_connect(net, net, IonotropicSynapse())
params = net.get_parameters()

# Small number of recordings (4)
net.delete_recordings()
net.cell(range(2)).record('i_IonotropicSynapse')
start_time = time.time()
v = jx.integrate(net, params=params, t_max=10.0, delta_t=0.025)
simulate_time = time.time() - start_time
print(simulate_time)

# Huge number of recordings (10,000)
net.delete_recordings()
net.record('i_IonotropicSynapse')
start_time = time.time()
v = jx.integrate(net, params=params, t_max=10.0, delta_t=0.025)
simulate_time = time.time() - start_time
print(simulate_time)
```
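A caveat on wall-clock timings like the ones above: JAX compiles functions on first use and dispatches work asynchronously, so a single `time.time()` measurement can mix compilation and execution. Below is a minimal sketch of a helper that reports the first call and the best repeat call separately. `time_call` and the toy `step` function are hypothetical names for illustration and are not part of jaxley. Whether `jx.integrate` reuses a compiled function across calls is something I have not checked.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args, repeats=3):
    """Return (first_call_seconds, best_repeat_seconds) for fn(*args).

    Hypothetical helper: the first call typically includes JIT compilation,
    later calls mostly measure execution. block_until_ready waits for the
    asynchronously dispatched result before the clock is stopped.
    """
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first = time.perf_counter() - start

    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        later.append(time.perf_counter() - start)
    return first, min(later)


# Toy stand-in for a simulation step (not jaxley code).
@jax.jit
def step(x):
    return jnp.tanh(x) * 2.0 + 1.0


first, best = time_call(step, jnp.ones((1000,)))
print(f"first call: {first:.4f} s, best repeat: {best:.4f} s")
```

The same helper could wrap the benchmark, for example `time_call(lambda: jx.integrate(net, params=params, t_max=10.0, delta_t=0.025))`, to see how much of each number is one-off overhead.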
|
And finally, here are my timing results for the two branches. There's still a slight slowdown with more recordings, but it's an order of magnitude faster when recording from 10,000 synapses, so I'd say it's an improvement 😄 Here are the results for a 10 ms simulation:
|
@michaeldeistler @jnsbck since this is mainly a performance boost I'm not sure how it should be tested. I feel like testing the execution time directly could be very brittle for running tests locally. Unless you have some ideas, perhaps it isn't necessary? |
|
Unfortunately, that performance hit does scale with simulation time. Here are the results for a 100 ms simulation:
So there are still some optimizations to be made... |
While testing recording from a very large number of states, I noticed serious performance hits during simulation. @jnsbck suggested that this may be due to a for loop over recordings that runs during the `integrate` call. This is an attempt to vectorize that update so that each unique state is indexed only once.
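To illustrate the idea, here is a minimal sketch in JAX. It is not the code in this PR. The state layout (a dict of 1D arrays) and the names `record_loop`, `build_plan`, and `record_vectorized` are assumptions made for the example. The loop version issues one indexing operation per recording. The vectorized version gathers from each unique state once and then restores the original recording order with a precomputed permutation.

```python
import jax.numpy as jnp
import numpy as np


def record_loop(states, rec_names, rec_inds):
    """Reference version: one indexing op per recording."""
    return jnp.stack([states[n][i] for n, i in zip(rec_names, rec_inds)])


def build_plan(rec_names, rec_inds):
    """Precompute, once and outside the time loop, an index array per unique
    state and the permutation that restores the original recording order."""
    rec_names = np.asarray(rec_names)
    rec_inds = np.asarray(rec_inds)
    unique_names = list(dict.fromkeys(rec_names.tolist()))  # first-seen order
    gather_inds, positions = {}, []
    for name in unique_names:
        mask = rec_names == name
        gather_inds[name] = jnp.asarray(rec_inds[mask])
        positions.append(np.flatnonzero(mask))
    # argsort of a permutation is its inverse: grouped order -> original order
    order = jnp.asarray(np.argsort(np.concatenate(positions)))
    return gather_inds, order


def record_vectorized(states, gather_inds, order):
    """One gather per unique state, then a single reordering gather."""
    grouped = jnp.concatenate(
        [states[name][inds] for name, inds in gather_inds.items()]
    )
    return grouped[order]


# Toy example: two states, four recordings in interleaved order.
states = {"v": jnp.arange(5.0), "i_syn": 10.0 + jnp.arange(4.0)}
rec_names = ["v", "i_syn", "v", "i_syn"]
rec_inds = [3, 0, 1, 2]

gather_inds, order = build_plan(rec_names, rec_inds)
looped = record_loop(states, rec_names, rec_inds)
vectorized = record_vectorized(states, gather_inds, order)
print(looped, vectorized)
```

With 10,000 recordings of a single state, the loop version traces 10,000 indexing operations per time step, while the vectorized version traces two gathers. Comparing the two outputs for equality, as in the toy example, is also a way to test a change like this without asserting on execution time.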