[ENH] Add gradient computation control and improve memory management (#31)
### TL;DR
Added memory optimization features to improve performance when evaluating large models.
### What changed?
- Added a `grads` parameter to the backend tensor configuration to control PyTorch gradient computation
- Implemented proper context management for PyTorch's `no_grad()` mode
- Added garbage collection calls during chunked evaluation to prevent memory buildup
- Optimized memory usage in the evaluation kernel by immediately deleting tensors after use
- Improved error handling in the evaluation function
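The changes above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: the function name `evaluate_chunked` and its signature are hypothetical, standing in for the real evaluation kernel. It shows the pattern of gating gradient tracking on a `grads` flag, deleting tensors immediately after use, and forcing garbage collection between chunks:

```python
import gc
import torch

def evaluate_chunked(model, chunks, grads=False):
    # Hypothetical sketch of the chunked evaluation loop.
    # `grads` mirrors the new backend tensor configuration flag:
    # when False, PyTorch's no_grad() mode disables autograd bookkeeping.
    ctx = torch.enable_grad() if grads else torch.no_grad()
    results = []
    with ctx:
        for chunk in chunks:
            out = model(chunk)
            # Detach and keep only what we need, then drop the
            # reference immediately so the tensor can be freed.
            results.append(out.detach())
            del out
            # Strategic garbage collection between chunks prevents
            # memory buildup across a long evaluation run.
            gc.collect()
    return torch.cat(results)
```

With `grads=False`, no computation graph is retained, which is where most of the memory savings come from on large models.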
### How to test?
1. Test with large models that previously caused memory issues
2. Compare memory usage before and after these changes
3. Verify that model evaluation still produces correct results
4. Test with both gradient computation enabled and disabled
### Why make this change?
These optimizations address memory leaks and excessive memory usage during model evaluation, particularly for large models. By properly managing PyTorch's gradient computation and adding strategic garbage collection, we can significantly reduce the memory footprint without sacrificing performance. Deleting tensors immediately after use prevents memory buildup when evaluating large datasets.