diff --git a/namedtensor/core.py b/namedtensor/core.py index 7486c9f..9674ecc 100644 --- a/namedtensor/core.py +++ b/namedtensor/core.py @@ -35,6 +35,10 @@ def __init__(self, tensor, names, mask=0): "Tensor has %d dim, but only %d names" % (len(self._tensor.shape), len(self._schema._names)) ) + for name in self._schema._names: + assert name.isalnum(), ( + "dim name %s must be alphanumeric" % name + ) @property def dims(self): diff --git a/namedtensor/test_core.py b/namedtensor/test_core.py index bedf8b3..f5e8ef4 100644 --- a/namedtensor/test_core.py +++ b/namedtensor/test_core.py @@ -10,6 +10,17 @@ def make_tensors(sizes, names): return [ntorch.randn(sizes, names=names)] +def test_names(): + base = torch.zeros([10, 2, 50]) + assert ntorch.tensor(base, ("alpha", "beta", "gamma")) + + +@pytest.mark.xfail +def test_old_nonzero_names(): + base = torch.zeros([10, 2]) + assert ntorch.tensor(base, ("elements_dim", "input_dims")) + + def test_shift(): for ntensor in make_tensors((10, 2, 50), ("alpha", "beta", "gamma")): # Split @@ -287,9 +298,9 @@ def test_nonzero(): # only zeros x = ntorch.zeros(10, names=("alpha",)) y = x.nonzero() - assert 0 == y.size("elements_dim") + assert 0 == y.size("elementsdim") assert x.shape == OrderedDict([("alpha", 10)]) - assert y.shape == OrderedDict([("elements_dim", 0), ("input_dims", 1)]) + assert y.shape == OrderedDict([("elementsdim", 0), ("inputdims", 1)]) # `names` length must be 2 y = x.nonzero(names=("a", "b")) @@ -299,9 +310,9 @@ def test_nonzero(): # 1d tensor x = ntorch.tensor([0, 1, 2, 0, 5], names=("dim",)) y = x.nonzero() - assert 3 == y.size("elements_dim") + assert 3 == y.size("elementsdim") assert x.shape == OrderedDict([("dim", 5)]) - assert y.shape == OrderedDict([("elements_dim", 3), ("input_dims", 1)]) + assert y.shape == OrderedDict([("elementsdim", 3), ("inputdims", 1)]) # `names` length must be 2 y = x.nonzero(names=("a", "b")) @@ -319,9 +330,9 @@ def test_nonzero(): names=("alpha", "beta"), ) y = x.nonzero() - 
assert 5 == y.size("elements_dim") + assert 5 == y.size("elementsdim") assert x.shape == OrderedDict([("alpha", 4), ("beta", 4)]) - assert y.shape == OrderedDict([("elements_dim", 5), ("input_dims", 2)]) + assert y.shape == OrderedDict([("elementsdim", 5), ("inputdims", 2)]) # `names` length must be 2 y = x.nonzero(names=("a", "b")) @@ -343,6 +354,93 @@ def test_nonzero_names(): assert 2 == len(y.shape) +def test_multi_index_select(): + + def _check_output(tensor, dims, indices, output): + names = tensor._schema._names + index_names = indices._schema._names + + output_names = output._schema._names + assert len(names) - len(dims) + 1 == len(output_names) + + input_elements = indices.shape[index_names[0]] + output_elements = output.shape[index_names[0]] + assert input_elements == output_elements + + remaining_dims = set(names) - set(dims) + for name in remaining_dims: + assert name in output_names + + output_element_dims = [] + remaining_element_dims = [] + for name in output_names[1:]: + output_element_dims.append(output.shape[name]) + remaining_element_dims.append(tensor.shape[name]) + assert output_element_dims == remaining_element_dims + + # 1d tensor, nonzero test + tensor = ntorch.tensor([0.6, 0.4, 0.0], names=('alpha',)) + indices = tensor.nonzero() + dims = ('alpha',) + selected_values = tensor.multi_index_select(dims, indices) + _check_output(tensor, dims, indices, selected_values) + + # 3d tensor + base = torch.cat([torch.tensor([[[0.6, 0.4, 0.0], + [2.0, 0.0, 1.2]]])] * 4, 0) + tensor = ntorch.tensor(base, names=('alpha', 'beta', 'gamma')) + + # nonzero test + indices = tensor.nonzero() + dims = ('alpha', 'beta', 'gamma') + selected_values = tensor.multi_index_select(dims, indices) + _check_output(tensor, dims, indices, selected_values) + + # one dimension + indices = ntorch.tensor(torch.tensor([[0], [1], [1]]), + names=('elementsdim', 'inputdims')) + dims = ('gamma',) + selected_values = tensor.multi_index_select(dims, indices) + _check_output(tensor, 
dims, indices, selected_values) + + # one dimension + indices = ntorch.tensor([[1], [2]], + names=('elementsdim', 'inputdims')) + dims = ('alpha',) + selected_values = tensor.multi_index_select(dims, indices) + _check_output(tensor, dims, indices, selected_values) + + # two transposed dimensions + indices = ntorch.tensor(torch.tensor([[0, 0], [0, 1]]), + names=('elementsdim', 'inputdims')) + dims = ('gamma', 'beta') + selected_values = tensor.multi_index_select(dims, indices) + _check_output(tensor, dims, indices, selected_values) + + # 4d tensor + base = torch.tensor([[0.6, 0.0, 0.0], + [0.0, 0.4, 0.0], + [0.0, 0.0, 1.2], + [2.0, 0.0, 0.9]]) + base = torch.cat([base.unsqueeze(0)] * 5, 0) + base = torch.cat([base.unsqueeze(0)] * 7, 0) + tensor = ntorch.tensor(base, names=('dim0', 'dim1', 'dim2', 'dim3')) + + # nonzero test + indices = tensor.nonzero() + dims = ('dim0', 'dim1', 'dim2', 'dim3') + selected_values = tensor.multi_index_select(dims, indices) + _check_output(tensor, dims, indices, selected_values) + + indices = ntorch.tensor(indices.values[:, :2], + names=('elements', 'indims')) + dims = ('dim0', 'dim1') + + # two dimensions + selected_values = tensor.multi_index_select(dims, indices) + _check_output(tensor, dims, indices, selected_values) + + # def test_scalar(): # base1 = ntorch.randn(dict(alpha=10, beta=2, gamma=50)) # base2 = base1 + 10 diff --git a/namedtensor/torch_base.py b/namedtensor/torch_base.py index 0dd42ed..784580d 100644 --- a/namedtensor/torch_base.py +++ b/namedtensor/torch_base.py @@ -97,7 +97,7 @@ def masked_select(input, mask, name): return NamedTensor(a1.values.masked_select(b1.values), name) @staticmethod - def nonzero(tensor, names=("elements_dim", "input_dims")): + def nonzero(tensor, names=("elementsdim", "inputdims")): """ Returns a tensor containing the indices of all non-zero elements. 
@@ -106,14 +106,69 @@ def nonzero(tensor, names=("elements_dim", "input_dims")): tensor: NamedTensor names : tuple, optional Names for the output dimensions - default value: ("elements_dim", "input_dims") - default output shape: OrderedDict([("elements_dim", number of non-zero elements), - ("input_dims", input tensor's number of dimensions)]) + default value: ("elementsdim", "inputdims") + default output shape: + OrderedDict([("elementsdim", number of non-zero elements), + ("inputdims", input tensor's number of dimensions)]) """ indices = torch.nonzero(tensor.values) return NamedTensor(tensor=indices, names=names) + @staticmethod + def multi_index_select(tensor, dims, indices): + indices_names = indices._schema._names + index_dim = indices_names[1] + if len(dims) != indices.shape[index_dim]: + raise RuntimeError( + "Size of elements in 'indices' should be %d, got %d" + % (len(dims), indices.shape[index_dim]) + ) + if len(tensor.shape) < len(dims): + raise RuntimeError( + "Size of 'dims' must be <= tensor dims (%d), got %d" + % (len(tensor.shape), len(dims)) + ) + if len(set(dims)) < len(dims): + raise RuntimeError("Tuple 'dims' must contain unique names") + names = tensor._schema._names + for dim in dims: + if dim not in names: + raise RuntimeError("%s is not a dimension name in tensor" % dim) + + values = tensor.values + names = tensor._schema._names + + # find names index in dims + match_dims = [] + for dim in dims: + dim_idx = names.index(dim) + match_dims.append(dim_idx) + + # find remaining tensor dims + remaining_dims = [] + remaining_names = [] + for i, name in enumerate(names): + if i not in match_dims: + remaining_dims.append(i) + remaining_names.append(name) + + # permute tensor values to match dims + permute_idx = match_dims + remaining_dims + values = values.permute(*permute_idx) + + # find values by idx element in indices + tensors = [] + for idx in indices.values: + indexed_value = values[tuple(idx)].unsqueeze(0) + tensors.append(indexed_value) + 
tensors = torch.cat(tensors) + + elements_dim = indices_names[0] + new_names = tuple([elements_dim] + remaining_names) + selected_values = ntorch.tensor(tensors, names=new_names) + return selected_values + @staticmethod + def scatter_(input, dim, index, src, index_dim): indim = dim diff --git a/namedtensor/torch_helpers.py b/namedtensor/torch_helpers.py index 3f263be..fa863e0 100644 --- a/namedtensor/torch_helpers.py +++ b/namedtensor/torch_helpers.py @@ -61,7 +61,13 @@ def masked_select(self, mask, name): return ntorch.masked_select(self, mask, name) - def nonzero(self, names=("elements_dim", "input_dims")): + def multi_index_select(self, dims, indices): + "Index into dims names with the `indices` named tensors." + from .torch_base import ntorch + + return ntorch.multi_index_select(self, dims, indices) + + def nonzero(self, names=("elementsdim", "inputdims")): """ Returns a tensor containing the indices of all non-zero elements. @@ -69,9 +75,10 @@ def nonzero(self, names=("elements_dim", "input_dims")): ---------- names : tuple, optional Names for the output dimensions - default value: ("elements_dim", "input_dims") - default output shape: OrderedDict([("elements_dim", number of non-zero elements), - ("input_dims", input tensor's number of dimensions)]) + default value: ("elementsdim", "inputdims") + default output shape: + OrderedDict([("elementsdim", number of non-zero elements), + ("inputdims", input tensor's number of dimensions)]) """ from .torch_base import ntorch