remove torch_scatter dependency
Linux-cpp-lisp committed Nov 3, 2021
1 parent 0e6fc7e commit 4fb4f67
Showing 8 changed files with 124 additions and 22 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -11,4 +11,7 @@ Most recent change on the bottom.
### Added
- Option to mask out values, with correct counting, using NaN

### Removed
- Dependency on `torch_scatter`

## [0.1.0] - 2021-05-28
4 changes: 2 additions & 2 deletions README.md
@@ -9,7 +9,7 @@ Notable features:
- Arbitrary sample shapes beyond single scalars
- Reduction over arbitrary dimensions of each sample
- "Batched"/"binned" reduction into multiple running tallies using a per-sample bin index.
This can be useful, for example, in accumulating statistics over samples by some kind of "type" index or for accumulating statistics per-graph in a `pytorch_geometric`-like [batching scheme](https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html). (This feature uses and is similar to [`torch_scatter`](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html).)
This can be useful, for example, in accumulating statistics over samples by some kind of "type" index or for accumulating statistics per-graph in a `pytorch_geometric`-like [batching scheme](https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html). (This feature is similar to [`torch_scatter`](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html).)
- Option to ignore NaN values with correct sample counting.

**Note:** the implementation currently heavily uses in-place operations for performance and memory efficiency. This probably doesn't play nice with the autograd engine; it is currently likely the wrong library for accumulating running statistics you want to backward through. (See [TorchMetrics](https://torchmetrics.readthedocs.io/en/latest/) for a possible alternative.)
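
For readers unfamiliar with the binned-reduction feature referenced above, a minimal usage sketch follows (not part of this commit). The `RunningStats`, `Reduction`, `accumulate_batch`, `accumulate_by`, and `current_result` names are taken from the tests changed below; the `Reduction.MEAN` member and the default for `reduce_dims` are assumptions.

```python
# Minimal sketch of binned reduction using a per-sample bin index.
# Assumes Reduction.MEAN exists and reduce_dims may be left at its default.
import torch
from torch_runstats import RunningStats, Reduction

rs = RunningStats(dim=(3,), reduction=Reduction.MEAN)
data = torch.randn(10, 3)             # 10 samples, each of shape (3,)
bins = torch.randint(0, 2, (10,))     # per-sample bin index in {0, 1}
rs.accumulate_batch(data, accumulate_by=bins)
print(rs.current_result())            # one running mean per bin
```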
@@ -18,7 +18,7 @@ For more information, please see [the docs](https://pytorch-runstats.readthedocs

## Install

`torch_runstats` requires PyTorch and [`torch_scatter`](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html), but neither is specified in `install_requires` for `pip` since both require manual installation for correct CUDA compatability.
`torch_runstats` requires PyTorch.

The library can be installed from PyPI:
```bash
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -29,7 +29,7 @@
# ones.
extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_rtd_theme"]

autodoc_mock_imports = ["torch", "torch_scatter"]
autodoc_mock_imports = ["torch"]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -16,7 +16,7 @@ pytorch_runstats

* Arbitrary sample shapes beyond single scalars
* Reduction over arbitrary dimensions of each sample
* "Batched"/"binned" reduction into multiple running tallies using a per-sample bin index. This can be useful, for example, in accumulating statistics over samples by some kind of "type" index or for accumulating statistics per-graph in a ``pytorch_geometric``-like `batching scheme <https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html>`_ . (This feature uses and is similar to `torch_scatter <https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html>`_ .)
* "Batched"/"binned" reduction into multiple running tallies using a per-sample bin index. This can be useful, for example, in accumulating statistics over samples by some kind of "type" index or for accumulating statistics per-graph in a ``pytorch_geometric``-like `batching scheme <https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html>`_ . (This feature is similar to `torch_scatter <https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html>`_ .)
* Option to ignore NaN values with correct sample counting

.. note::
28 changes: 19 additions & 9 deletions tests/test_stats.py
@@ -6,7 +6,6 @@
import random

import torch
from torch_scatter import scatter

from torch_runstats import RunningStats, Reduction

@@ -53,7 +52,7 @@ def current_result(self):

if len(average) < self._n_bins:
N_to_add = self._n_bins - len(average)
average = torch.cat((average, torch.zeros((N_to_add,)+average.shape[1:])))
average = torch.cat((average, torch.zeros((N_to_add,) + average.shape[1:])))

return average

@@ -78,8 +77,12 @@ def current_result(self):
def test_runstats(dim, reduce_dims, nan_attrs, reduction, do_accumulate_by, allclose):

n_batchs = (random.randint(1, 4), random.randint(1, 4))
truth_obj = StatsTruth(dim=dim, reduction=reduction, reduce_dims=reduce_dims, ignore_nan=nan_attrs)
runstats = RunningStats(dim=dim, reduction=reduction, reduce_dims=reduce_dims, ignore_nan=nan_attrs)
truth_obj = StatsTruth(
dim=dim, reduction=reduction, reduce_dims=reduce_dims, ignore_nan=nan_attrs
)
runstats = RunningStats(
dim=dim, reduction=reduction, reduce_dims=reduce_dims, ignore_nan=nan_attrs
)

for n_batch in n_batchs:
for _ in range(n_batch):
@@ -104,6 +107,7 @@ def test_runstats(dim, reduce_dims, nan_attrs, reduction, do_accumulate_by, allc
truth_obj.reset(reset_n_bins=True)
runstats.reset(reset_n_bins=True)


@pytest.mark.parametrize("do_accumulate_by", [True, False])
@pytest.mark.parametrize("nan_attrs", [True, False])
def test_batching(do_accumulate_by, nan_attrs, allclose):
@@ -115,29 +119,35 @@ def test_batching(do_accumulate_by, nan_attrs, allclose):

# generate reference data
data = torch.randn((n_samples,) + dim)
accumulate_by = torch.randint(0, 5, size=(data.shape[0],)) if do_accumulate_by else None
accumulate_by = (
torch.randint(0, 5, size=(data.shape[0],)) if do_accumulate_by else None
)
if nan_attrs:
ids = torch.randperm(n_samples)[:10]
for idx in ids:
data.view(-1)[idx] = float("NaN")

# compute ground truth
truth_obj = StatsTruth(dim=dim, reduction=reduction, reduce_dims=reduce_dims, ignore_nan=nan_attrs)
truth_obj = StatsTruth(
dim=dim, reduction=reduction, reduce_dims=reduce_dims, ignore_nan=nan_attrs
)
truth_obj.accumulate_batch(data, accumulate_by=accumulate_by)
truth = truth_obj.current_result()
del truth_obj

runstats = RunningStats(dim=dim, reduction=reduction, reduce_dims=reduce_dims, ignore_nan=nan_attrs)
runstats = RunningStats(
dim=dim, reduction=reduction, reduce_dims=reduce_dims, ignore_nan=nan_attrs
)

for stride in [1, 3, 5, 7, 13, 100]:
n_batch = n_samples // stride
if n_batch*stride < n_samples:
if n_batch * stride < n_samples:
n_batch += 1
count = 0
for idx in range(n_batch):

loid = count
hiid = count+stride
hiid = count + stride
hiid = n_samples if hiid > n_samples else hiid
count += stride

3 changes: 2 additions & 1 deletion torch_runstats/__init__.py
@@ -1,3 +1,4 @@
from ._runstats import RunningStats, Reduction
from . import scatter

__all__ = ["Reduction", "RunningStats"]
__all__ = ["Reduction", "RunningStats", "scatter"]
22 changes: 14 additions & 8 deletions torch_runstats/_runstats.py
@@ -5,7 +5,7 @@
import numbers

import torch
from torch_scatter import scatter
from .scatter import scatter


def _prod(x):
@@ -161,15 +161,15 @@ def batch_result(

else:

new_sum = new_sum.sum(dim=(0, ), keepdim=True)
new_sum = new_sum.sum(dim=(0,), keepdim=True)
# since all types are 0, the first dimension should be 1
N = (
torch.as_tensor([batch.shape[0]], dtype=torch.long, device=device)
* self._reduction_factor
)

if len(N.shape) < len(new_sum.shape):
N = N.reshape(N.shape+(1,)*(len(new_sum.shape)-len(N.shape)))
N = N.reshape(N.shape + (1,) * (len(new_sum.shape) - len(N.shape)))

else:

@@ -195,7 +195,9 @@ def batch_result(
# reduce along the first (batch) dimension using accumulate_by
new_sum = scatter(new_sum, accumulate_by, dim=0)

N = torch.bincount(accumulate_by).reshape((-1,)+(1,)*(len(new_sum.shape)-1))
N = torch.bincount(accumulate_by).reshape(
(-1,) + (1,) * (len(new_sum.shape) - 1)
)

# Each sample is now a reduction over _reduction_factor samples
N *= self._reduction_factor
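
As an illustrative aside (not in the diff), this is how `torch.bincount` produces the per-bin sample counts that broadcast against the scattered sums above:

```python
# Hypothetical values, for illustration only.
import torch

accumulate_by = torch.tensor([0, 0, 1, 2, 2, 2])
N = torch.bincount(accumulate_by)   # tensor([2, 1, 3]): samples per bin
N = N.reshape((-1, 1))              # broadcastable against per-bin sums of shape (3, D)
```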
@@ -237,10 +239,15 @@ def accumulate_batch(

# time to expand
self._state = torch.cat(
(self._state, self._state.new_zeros((N_to_add,) + self._state.shape[1:])),
(
self._state,
self._state.new_zeros((N_to_add,) + self._state.shape[1:]),
),
dim=0,
)
self._n = torch.cat((self._n, self._n.new_zeros((N_to_add, )+self._n.shape[1:])), dim=0)
self._n = torch.cat(
(self._n, self._n.new_zeros((N_to_add,) + self._n.shape[1:])), dim=0
)

# assert self._state.shape == (self._n_bins + N_to_add,) + self._dim
self._n_bins += N_to_add
@@ -252,7 +259,6 @@ def accumulate_batch(
)
N = torch.cat((N, N.new_zeros((-N_to_add,) + N.shape[1:])), dim=0)


self._state += (new_sum - N * self._state) / (self._n + N)
self._n += N
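
The state update above is the standard incremental-mean identity, restated here for clarity (not in the source), with $n$ the running per-bin count, $N$ the incoming per-bin count, and $s$ the incoming per-bin sum:

$$
\bar{x}_{\text{new}} = \bar{x}_{\text{old}} + \frac{s - N\,\bar{x}_{\text{old}}}{n + N} = \frac{n\,\bar{x}_{\text{old}} + s}{n + N}
$$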

@@ -274,7 +280,7 @@ def reset(self, reset_n_bins: bool = False) -> None:
self._n.fill_(0)
else:
self._n_bins = 1
self._n = torch.zeros((self._n_bins,)+self._dim, dtype=torch.long)
self._n = torch.zeros((self._n_bins,) + self._dim, dtype=torch.long)
self._state = torch.zeros((self._n_bins,) + self._dim)

def to(self, device=None, dtype=None) -> None:
82 changes: 82 additions & 0 deletions torch_runstats/scatter.py
@@ -0,0 +1,82 @@
"""basic scatter operations from torch_scatter
Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency.
"""

from typing import Optional

import torch


def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
if dim < 0:
dim = other.dim() + dim
if src.dim() == 1:
for _ in range(0, dim):
src = src.unsqueeze(0)
for _ in range(src.dim(), other.dim()):
src = src.unsqueeze(-1)
src = src.expand_as(other)
return src


@torch.jit.script
def scatter(
src: torch.Tensor,
index: torch.Tensor,
dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
) -> torch.Tensor:
index = _broadcast(index, src, dim)
if out is None:
size = list(src.size())
if dim_size is not None:
size[dim] = dim_size
elif index.numel() == 0:
size[dim] = 0
else:
size[dim] = int(index.max()) + 1
out = torch.zeros(size, dtype=src.dtype, device=src.device)
return out.scatter_add_(dim, index, src)
else:
return out.scatter_add_(dim, index, src)


@torch.jit.script
def scatter_std(
src: torch.Tensor,
index: torch.Tensor,
dim: int = -1,
out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None,
unbiased: bool = True,
) -> torch.Tensor:

if out is not None:
dim_size = out.size(dim)

if dim < 0:
dim = src.dim() + dim

count_dim = dim
if index.dim() <= dim:
count_dim = index.dim() - 1

ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
count = scatter(ones, index, count_dim, dim_size=dim_size)

index = _broadcast(index, src, dim)
tmp = scatter(src, index, dim, dim_size=dim_size)
count = _broadcast(count, tmp, dim).clamp(1)
mean = tmp.div(count)

var = src - mean.gather(dim, index)
var = var * var
out = scatter(var, index, dim, out, dim_size)

if unbiased:
count = count.sub(1).clamp_(1)
out = out.div(count + 1e-6).sqrt()

return out
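
A quick sanity-check sketch for the vendored functions above (illustrative, not part of the commit; the expected outputs are hand-computed):

```python
import torch
from torch_runstats.scatter import scatter, scatter_std

# Sum-scatter rows of src into bins given by index along dim 0.
src = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
index = torch.tensor([0, 1, 0])
print(scatter(src, index, dim=0))
# tensor([[6., 8.], [3., 4.]]): rows 0 and 2 summed into bin 0, row 1 into bin 1

# Unbiased standard deviation within a single bin.
vals = torch.tensor([1.0, 2.0, 4.0])
idx = torch.zeros(3, dtype=torch.long)
print(scatter_std(vals, idx, dim=0))
# approximately tensor([1.5275]), matching vals.std() for the single bin
```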
