Add SparseLatticedTensor #484
base: dev-new-engine
Conversation
```python
def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor:
    t1_, t2_ = prepare_for_elementwise_op(t1, t2)
    t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.strides)
    all_dims = list(range(t1_.ndim))
    return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims)
```
This is problematic: we do not divide by a sparse tensor, or at least not by a non-dense one. The line `t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.strides)` essentially applies the pointwise reciprocal to t2, and that is valid if and only if t2 is dense (for a non-dense t2, the implicit zeros would have to become 1/0). Still, if t2 is dense (and even if it is an SST), we are good, so I am not sure how to solve this.
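A small illustration of the issue (hypothetical values; `physical` stands for the stored entries of a sparse t2): inverting the physical tensor only inverts the stored entries, while a true pointwise reciprocal would also have to turn the implicit zeros into inf.

```python
import torch

physical = torch.tensor([2.0, 4.0])  # stored (structurally nonzero) entries
print(1.0 / physical)                # tensor([0.5000, 0.2500]): stored entries are fine
print(1.0 / torch.tensor(0.0))       # tensor(inf): what every implicit zero would become
```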
```python
def prepare_for_elementwise_op(
    t1: Tensor | int | float, t2: Tensor | int | float
) -> tuple[StructuredSparseTensor, StructuredSparseTensor]:
```
There is something fishy about the responsibility of this function. It seems to mix two concerns: tricking einsum into handling non-SST Tensors, and handling non-tensor inputs. The thing is, we don't want to use einsum when we have non-tensor inputs (it is too much machinery).
Also, I believe that the einsum implementations of most of the operations taking an int or float are actually wrong. For instance, mul is `..., ... -> ...`, but the numbers of dimensions might not match if one operand is a scalar, in which case you would want to multiply the physical tensor by the scalar directly.
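A minimal sketch of the scalar fast path (hypothetical `mul_Tensor`; it assumes only the `StructuredSparseTensor(physical, strides)` constructor used elsewhere in this PR):

```python
from torch import Tensor


def mul_Tensor(t1: StructuredSparseTensor, t2: Tensor | int | float) -> Tensor:
    if isinstance(t2, (int, float)):
        # A scalar acts independently on each stored entry, so it can be
        # applied to the physical tensor directly, bypassing einsum.
        return StructuredSparseTensor(t1.physical * t2, t1.strides)
    # Tensor operands with matching numbers of dimensions can still go
    # through einsum.
    ...
```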
src/torchjd/sparse/linalg.py (outdated)
```python
def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None:
    """
    Solve A X = B where A, B and X have integer dtype.
    Return X if such a matrix exists and otherwise None.
    """
    A_ = A.to(torch.float64)
    B_ = B.to(torch.float64)

    try:
        X = torch.linalg.solve(A_, B_)
    except RuntimeError:
        return None

    X_rounded = X.round()
    if not torch.all(torch.isclose(X, X_rounded, atol=tol)):
        return None

    # TODO: Verify that the round operation cannot fail
    return X_rounded.to(torch.int64)
```
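A quick usage sketch (hypothetical values), showing the round trip through float64 and back to int64:

```python
import torch

A = torch.tensor([[2, 0], [0, 3]])
B = torch.tensor([[4], [9]])
X = solve_int(A, B)  # tensor([[2], [3]]), since A @ X == B exactly
```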
In the short term this seems good enough, but in the long term I think we'd want either to use another library to solve integer linear equations (SymPy, maybe others) or to implement our own solver in torch.
The reason is that computations with floating-point numbers may be wrong for large stride values, and it's probably faster to solve the system with the extra knowledge that it's integer.
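For reference, a sketch of an exact-arithmetic variant, assuming SymPy is acceptable as a dependency (the name `solve_int_exact` is made up here):

```python
import torch
from sympy import Matrix
from torch import Tensor


def solve_int_exact(A: Tensor, B: Tensor) -> Tensor | None:
    try:
        # Matrix.solve works over exact rationals, so no tolerance or
        # rounding step is needed.
        X = Matrix(A.tolist()).solve(Matrix(B.tolist()))
    except ValueError:
        # Raised when A is singular or the system has no unique solution.
        return None
    if not all(entry.is_integer for entry in X):
        return None
    return torch.tensor([[int(x) for x in row] for row in X.tolist()], dtype=torch.int64)
```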
We'll probably end up coding that in C.
* Note that in Python, `x % 0` raises ZeroDivisionError. The implementation of mod_c matches this behavior when t2 is the zero vector.
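For reference, the Python behavior being matched:

```python
try:
    5 % 0
except ZeroDivisionError as e:
    print(type(e).__name__)  # ZeroDivisionError
```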
…prove documentation.
* It was unused and I think it will be replaced by functions that find divisors of the basis
…nctions and concatenate)
* Otherwise, when specifying dim=0, the dispatcher calls the function without the dim argument (this is very weird, but it seems to be necessary)
Co-authored-by: Pierre Quinton <[email protected]>
TODO:
In the `hnf_decomposition` algorithm, with a matrix of shape `[5, 7]`, rank `3`, and values in the range `[-50, 51[`, there was overflow (see test prior to ba9bf21).