define data attribution for AnalogContext #717
Conversation
@Zhaoxian-Wu can you please update and sync the branch with master? We are gonna proceed with the review of it. Thank you for the enhancement!
Hi @PabloCarmona, I synced the latest update; feel free to leave any comments or start a discussion.
Many thanks for the contribution, @Zhaoxian-Wu!
However, I don't quite see the need for these low-level changes, and I am not convinced that all parts of the code base will actually still work. Considering all the various cases for the RPUCuda part of the code is a bit tricky, since the weight is stored in the C++ part and only a copy is returned to the user when get_weights is called. This is in part by design, as we do not encourage users to fiddle with the analog weights, which are not accessible directly in reality. While some tiles will work with your formulation, the ctx.data will be out-of-sync for others, and this would be more confusing. The user should not be changing anything related to analog_ctx.
If you just want the size of the analog_ctx parameter to reflect the shape of the analog tile, why not simply add a "shape" property or custom size method to the analog context that returns the weight size rather than storing the full weights, which would then become out-of-sync with the actual trained weights?
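A minimal sketch of what such a size/shape override could look like. This is not the actual aihwkit implementation: the context class, the stored analog_tile reference, and the out_size/in_size attributes on the tile are illustrative assumptions.

```python
from typing import Optional

import torch
from torch.nn import Parameter


class ShapeAwareContext(Parameter):
    """Illustrative context that reports the tile's logical weight shape."""

    def __new__(cls, analog_tile) -> "ShapeAwareContext":
        # Keep the usual dummy scalar storage; only the reported shape changes.
        return super().__new__(cls, torch.empty(0), requires_grad=True)

    def __init__(self, analog_tile) -> None:
        super().__init__()
        self.analog_tile = analog_tile

    @property
    def shape(self) -> torch.Size:
        # Derived from the tile, so it cannot go out of sync with the weights.
        return torch.Size([self.analog_tile.out_size, self.analog_tile.in_size])

    def size(self, dim: Optional[int] = None):
        return self.shape if dim is None else self.shape[dim]


class _FakeTile:
    """Stand-in exposing only the tile dimensions (hypothetical attributes)."""

    def __init__(self, out_size: int, in_size: int) -> None:
        self.out_size = out_size
        self.in_size = in_size


ctx = ShapeAwareContext(_FakeTile(out_size=4, in_size=6))
print(ctx.size())  # torch.Size([4, 6]) without holding a weight copy
```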
# Recreate the tile.
self.tile = self._recreate_simulator_tile(x_size, d_size, self.rpu_config)
self.analog_ctx.data = self.tile.get_weights()
What is the logic here? Note that this will only be a copy of the current weights. So if you update the weights (using RPUCuda), the analog_ctx.data will not be synchronized correctly with the actual weight. Of course the size of the weight will not change, but isn't it more confusing to maintain two different versions of the weight which are not synced?
Absolutely, there could be an out-of-sync concern here. Therefore, I also changed the definition of self.tile.get_weights(). Now the tile returns the original weight instead of a detached tensor. Since the data here and the actual weight tensor are the same object in essence, there is no sync issue.
def get_weights(self) -> Tensor:
    """Get the tile weights.

    Returns:
        a tuple where the first item is the ``[out_size, in_size]`` weight
        matrix; and the second item is either the ``[out_size]`` bias vector
        or ``None`` if the tile is set not to use bias.
    """
    return self.weight.data
bias = from_numpy(array(bias))

self.bias.data[:] = bias[:].clone().detach().to(self.get_dtype()).to(self.bias.device)
self.bias.data.copy_(bias)
Does this allow for setting the bias with the data type defined in the tile? While it is correct that torch defines the data type of a layer solely by the data type of the weight tensor, I found it more convenient to handle all the specialized code with a dtype property on the tile level (as the tile is essentially the "analog tensor"). Do you suggest that the d_type should be removed from the tile and instead be determined by the ctx.data tensor dtype?
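For reference on the dtype part of this question: in plain PyTorch, Tensor.copy_ casts the source to the destination's dtype and device, so an in-place copy keeps whatever dtype the destination tensor was created with. A quick stand-alone check (not aihwkit-specific):

```python
import torch

bias = torch.zeros(4, dtype=torch.float16)         # dtype as defined on the tile side
new_values = torch.arange(4, dtype=torch.float64)  # user-provided values in float64

bias.copy_(new_values)  # copy_ casts implicitly to the destination dtype/device
print(bias.dtype)       # torch.float16
print(bias)             # tensor([0., 1., 2., 3.], dtype=torch.float16)
```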
or ``None`` if the tile is set not to use bias.
"""
return self.weight.data.detach().cpu()
return self.weight.data
Note that the current convention is that get_weights always returns CPU weights. If you want to change this, the RPUCuda get_weights calls need to change as well, as they produce CPU weights by default. Moreover, get_weights will always produce a copy (without the backward trace) by design, to avoid implicit operations that cannot be done with analog weights. Of course, hardware-aware training is a special case, but for that we have a separate tile.
Hi @maljoras-sony, thanks for your careful review and comments.
The motivation for the enhancement mainly comes from the need for convenient weight access, but the dummy tensor in the analog context gets in the way. It may be a solution to define a series of attributions like size one by one, yet several other Tensor operations are affected as well. One example of the issue:

from aihwkit.nn import AnalogLinear

model = AnalogLinear(6, 4, False)
analog_ctx = next(model.parameters())
print(analog_ctx.size())     # torch.Size([])
print(analog_ctx.nonzero())  # tensor([], size=(1, 0), dtype=torch.int64)
print(analog_ctx > 10)       # False, expected a boolean array
print(analog_ctx.norm())     # tensor(1., grad_fn=<LinalgVectorNormBackward0>)

Therefore, I believe it would be straightforward to adopt a PyTorch-style convention, i.e., a data attribute that reflects the actual weights.
I completely understand your concern. Therefore, I use a relatively safe approach to do that. In the tile module, the context data is bound to the tile weights:

self.analog_ctx.data = self.tile.get_weights()

Note that I also make sure that get_weights returns the original weight tensor:

def get_weights(self) -> Tensor:
    """Get the tile weights.

    Returns:
        a tuple where the first item is the ``[out_size, in_size]`` weight
        matrix; and the second item is either the ``[out_size]`` bias vector
        or ``None`` if the tile is set not to use bias.
    """
    return self.weight.data

In this way, even when the weight is updated, analog_ctx.data remains synchronized with the underlying weights.
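A plain-PyTorch illustration of the aliasing argument above, using a toy stand-in rather than the actual aihwkit tile: if get_weights() returns the underlying tensor instead of a detached CPU copy, any in-place update of the weight is visible through the alias held in analog_ctx.data. This only holds for torch-based tiles that keep the weight as a tensor, which is exactly the RPUCuda caveat raised in the review.

```python
import torch
from torch import nn


class _ToyTile(nn.Module):
    """Minimal stand-in for a torch-based simulator tile."""

    def __init__(self, out_size: int, in_size: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(out_size, in_size))

    def get_weights(self) -> torch.Tensor:
        # Return the original tensor (the convention proposed here), not a copy.
        return self.weight.data


tile = _ToyTile(4, 6)
ctx_data = tile.get_weights()  # plays the role of analog_ctx.data

tile.weight.data.add_(1.0)     # in-place weight update on the tile
print(torch.equal(ctx_data, tile.weight.data))  # True: same storage, stays in sync
```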
I found there are still issues with my solution in the RPUCuda part, as you correctly pointed out. I may need some time to fix it. I will also consider your concern that the user should not have access to the actual weight unless necessary. However, I still want to retain the flexibility of accessing the weight via AnalogContext. My preliminary idea is to
Related issues
NA
Description
In the current version, the data attribute is replaced by a dummy tensor, leading to confusing behavior (related code). For example, when accessing the size of the analog context (which is a subclass of PyTorch Parameter and Tensor), the output is an empty size, while we expect it to output the correct size.
Details
To ensure all related components work correctly, I adapted a series of code changes. Furthermore, I adopted some conventions to make the style uniform across the library, following the PyTorch style.
1. AnalogContext is a valid torch.nn.Parameter
To support the correct behavior, I bind the weight matrix from the associated analog tile by using self.data = analog_tile.tile.get_weights(). I understand that this feature introduces an additional degree of freedom for programmers and that it differs from the realistic weight-reading mechanism, so I added comments encouraging users to read and write the weights in the "realistic" way.
2. [convention] get_weights and set_weights in place
In the TileModuleArray class, we adopt the convention that all reading and writing of weights are done in-place.
In get_weights, we don't convert the Tensor back to a numpy array automatically or move it back to the CPU, i.e., we do return self.weight.data instead of return self.weight.data.detach().cpu() (see the short illustration after this section), since movement across devices is supposed to be done explicitly by the user to avoid confusion.
In set_weights, we do self.bias.data.copy_(bias) instead of self.bias.data[:] = bias[:].clone().detach().to(self.get_dtype()).to(self.bias.device).
3. [convention] Run-time computation of device and is_cuda
For an analog tile class (any subclass of SimulatorTileWrapper), the device and is_cuda attributes are in essence those of its analog_ctx attribute. In case someone replaces the tile in analog_ctx without remembering to update them, I introduce two property functions to ensure consistency (a minimal sketch is given below).
Let me know if I missed anything from the library design perspective.