
[MRG] Vectorize recording during integrate#561

Open
ntolley wants to merge 3 commits into jaxleyverse:main from ntolley:record_speedup

Conversation


@ntolley ntolley commented Dec 20, 2024

While testing recording from many states, I noticed serious performance hits during simulation. @jnsbck suggested that this may be due to a for loop over recordings inside the `integrate` call. This PR vectorizes that update so that each unique state is indexed only once.
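The idea can be sketched like this (a hedged illustration with made-up arrays and names, not the actual jaxley code): group recordings by state name, then gather all of a state's recorded indices in one vectorized lookup, instead of doing one scalar indexing operation per recording.

```python
import numpy as np

# Hypothetical per-step state arrays (stand-ins for jaxley's states).
states = {
    "v": np.arange(10.0),
    "i_IonotropicSynapse": np.arange(10.0) * 2.0,
}
# Each recording is a (state name, compartment index) pair.
recordings = [("v", 3), ("v", 7), ("i_IonotropicSynapse", 1), ("v", 0)]

# Loop version: one scalar index per recording (what the for loop does).
looped = np.array([states[name][idx] for name, idx in recordings])

# Vectorized version: one gather per *unique* state.
names = np.array([name for name, _ in recordings])
idxs = np.array([idx for _, idx in recordings])
out = np.empty(len(recordings))
for name in np.unique(names):
    mask = names == name
    out[mask] = states[name][idxs[mask]]  # single fancy-indexing gather

assert np.allclose(out, looped)
```

With many recordings over few unique states, the inner loop shrinks from one iteration per recording to one per state, which is where the speedup comes from.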


ntolley commented Dec 20, 2024

Here's some code to see how recordings impact speed; the comparison is pretty extreme, but it gets the point across:

import jaxley as jx
import time
from jaxley.channels import Na
from jaxley.synapses import IonotropicSynapse
from jaxley.connect import fully_connect
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

cell = jx.Cell()
cell.insert(Na())


net = jx.Network([cell for _ in range(100)])
fully_connect(net, net, IonotropicSynapse())

params = net.get_parameters()

# Small number of recordings (4)
net.delete_recordings()
net.cell(range(2)).record('i_IonotropicSynapse')

start_time = time.time()
v = jx.integrate(net, params=params, t_max=10.0, delta_t=0.025)
simulate_time = time.time() - start_time
print(simulate_time)

# Huge number of recordings (10,000)
net.delete_recordings()
net.record('i_IonotropicSynapse')

start_time = time.time()
v = jx.integrate(net, params=params, t_max=10.0, delta_t=0.025)
simulate_time = time.time() - start_time
print(simulate_time)
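One caveat when timing JAX code like this (a general note, not part of the PR): the first call typically includes JIT compilation time, so wall-clock comparisons are more meaningful after a warm-up call. A generic helper sketch, with a trivial stand-in for `jx.integrate`:

```python
import time

def timed(fn, *args, warmup=1, repeats=3):
    """Time `fn` after discarding warm-up calls (compilation, caches)."""
    for _ in range(warmup):
        fn(*args)
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        times.append(time.perf_counter() - start)
    return min(times)

# Usage with a cheap stand-in; in practice fn would wrap jx.integrate(...).
best = timed(lambda: sum(range(100_000)))
```

Taking the minimum over several repeats also reduces noise from other processes on the machine.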


ntolley commented Dec 20, 2024

And finally, here are my timing results for the two branches. There's still a slight slowdown with more recordings, but it's an order of magnitude faster when recording from 10,000 synapses, so I'd say it's an improvement 😄

Here are the results for a 10 ms simulation:

| | main | record_speedup |
| --- | --- | --- |
| 4 recordings | 0.51 s | 0.50 s |
| 10,000 recordings | 10.9 s | 1.15 s |

@ntolley ntolley changed the title WIP: Vectorize recording during integrate [MRG] Vectorize recording during integrate Dec 20, 2024

ntolley commented Dec 20, 2024

@michaeldeistler @jnsbck since this is mainly a performance boost, I'm not sure how it should be tested. Testing execution time directly seems like it could be very brittle when running tests locally. Unless you have some ideas, perhaps a test isn't necessary?
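One alternative to brittle wall-clock assertions (a hedged sketch with hypothetical function names, not the jaxley test suite): test correctness instead of speed, by checking that a vectorized recording gather matches a plain Python loop on random data.

```python
import numpy as np

def gather_looped(states, rows, cols):
    # One scalar indexing operation per recording.
    return np.array([states[r, c] for r, c in zip(rows, cols)])

def gather_vectorized(states, rows, cols):
    # Single fancy-indexing gather over all recordings at once.
    return states[rows, cols]

def test_recording_gather_equivalence():
    rng = np.random.default_rng(0)
    states = rng.normal(size=(6, 50))       # states x compartments
    rows = rng.integers(0, 6, size=200)     # 200 random recordings
    cols = rng.integers(0, 50, size=200)
    np.testing.assert_allclose(
        gather_vectorized(states, rows, cols),
        gather_looped(states, rows, cols),
    )

test_recording_gather_equivalence()
```

A test like this stays deterministic across machines, while still guarding against regressions in the vectorized path.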


ntolley commented Dec 20, 2024

Unfortunately, that performance hit does scale with simulation time. Here are the results for a 100 ms simulation:

| | main | record_speedup |
| --- | --- | --- |
| 4 recordings | 0.55 s | 0.63 s |
| 10,000 recordings | 16.33 s | 6.85 s |

So there are still some optimizations to be made...

