Handle aliasing of viewed input tensors of varying shapes #2760
base: main
Conversation
|
Fixes #2756
shino16 left a comment
Looks great, thank you!
|
Your test produces the following traces.

Prologue:

# Constructed by Transform for execution (took 10 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def prologue(*args, **kwargs):
# args: "Any"
check_len(args, 3)
# prims.check_len(args, 3)
# kwargs: "Any"
check_len(kwargs, 0)
# prims.check_len(kwargs, 0)
x: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=3, static=CONSTRAINT.CONSTRAINABLE]]" = args[0]
y: "cuda:0 f32[[IntegerProxy name=i2, value=3, static=CONSTRAINT.CONSTRAINABLE]]" = args[1]
z: "cuda:0 f32[[IntegerProxy name=i3, value=3, static=CONSTRAINT.CONSTRAINABLE]]" = args[2]
(i0, i1) = shape(x)
# (i0, i1) = prims.shape(x)
check_tensor_metadata(x, (i0, i1), 'cuda:0', torch.float32, False)
# prims.check_tensor_shape_and_metadata(x, (i0, i1), 'cuda:0', torch.float32, False)
(i2,) = shape(y)
# (i2,) = prims.shape(y)
check_tensor_metadata(y, (i2,), 'cuda:0', torch.float32, False)
# prims.check_tensor_shape_and_metadata(y, (i2,), 'cuda:0', torch.float32, False)
(i3,) = shape(z)
# (i3,) = prims.shape(z)
check_tensor_metadata(z, (i3,), 'cuda:0', torch.float32, False)
# prims.check_tensor_shape_and_metadata(z, (i3,), 'cuda:0', torch.float32, False)
cache_info: "Any" = thunder._get_cache_info()
cache_info_default_dtype: "<class 'torch.dtype'>" = cache_info['default_dtype']
check_literal_like(cache_info_default_dtype, torch.float32)
# prims.check_literal_like(cache_info_default_dtype, torch.float32)
cache_info_default_device: "<class 'torch.device'>" = cache_info['default_device']
check_literal_like(cache_info_default_device, torch.device("cpu"))
# prims.check_literal_like(cache_info_default_device, torch.device("cpu"))
cache_info_is_autocast_enabled: "bool False" = cache_info['is_autocast_enabled']
check_number_type_and_value(cache_info_is_autocast_enabled, False)
# prims.check_number_type_and_value(cache_info_is_autocast_enabled, False)
cache_info_alias_tensor_indices: "str" = cache_info['alias_tensor_indices']
check_string_value(cache_info_alias_tensor_indices, '0,1,2')
# prims.check_string_value(cache_info_alias_tensor_indices, '0,1,2')
cache_info_is_grad_enabled: "bool True" = cache_info['is_grad_enabled']
check_number_type_and_value(cache_info_is_grad_enabled, True)
# prims.check_number_type_and_value(cache_info_is_grad_enabled, True)
cache_info_no_grad_sync: "bool False" = cache_info['no_grad_sync']
check_number_type_and_value(cache_info_no_grad_sync, False)
# prims.check_number_type_and_value(cache_info_no_grad_sync, False)
return ((x, y, z), ())

Computation trace:

# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(x, y, z):
# x: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=3, static=CONSTRAINT.CONSTRAINABLE]]"
# y: "cuda:0 f32[[IntegerProxy name=i2, value=3, static=CONSTRAINT.CONSTRAINABLE]]"
# z: "cuda:0 f32[[IntegerProxy name=i3, value=3, static=CONSTRAINT.CONSTRAINABLE]]"
(t27, t28, t29) = update_aliases((x, y, z))
del x, y, z
# /opt/pytorch/lightning-thunder/tmp/main.py:8: return x + 2, y.add_(z)
t35 = torch.add(t27, 2, alpha=1) # t35: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=3, static=CONSTRAINT.CONSTRAINABLE]]"
# t35 = ltorch.add(t27, 2, alpha=1) # t35: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=3, static=CONSTRAINT.CONSTRAINABLE]]"
# (i0, i1) = prims.shape(t27)
# (i0, i1) = prims.shape(t27)
# i31 = prims.eq(i1, 1) # i31: "bool False"
# i32 = prims.eq(i1, i1) # i32: "bool True"
# i33 = prims.eq(i0, 1) # i33: "bool False"
# i34 = prims.eq(i0, i0) # i34: "bool True"
# (i0, i1) = prims.shape(t27)
# (i0, i1) = prims.shape(t27)
# t35 = prims.add(t27, 2.0) # t35: "cuda:0 f32[[IntegerProxy name=i0, value=2, static=CONSTRAINT.CONSTRAINABLE], [IntegerProxy name=i1, value=3, static=CONSTRAINT.CONSTRAINABLE]]"
(t36, t37, t38) = update_aliases((t27, t28, t29))
del t27, t28, t29
# /opt/pytorch/lightning-thunder/tmp/main.py:8: return x + 2, y.add_(z)
t49 = torch.add(t37, t38, alpha=1) # t49: "cuda:0 f32[[IntegerProxy name=i2, value=3, static=CONSTRAINT.CONSTRAINABLE]]"
# t49 = ltorch.add(t37, t38, alpha=1) # t49: "cuda:0 f32[[IntegerProxy name=i2, value=3, static=CONSTRAINT.CONSTRAINABLE]]"
# (i2,) = prims.shape(t37)
# (i2,) = prims.shape(t37)
# (i3,) = prims.shape(t38)
# (i3,) = prims.shape(t38)
# i42 = prims.eq(i2, 1) # i42: "bool False"
# i43 = prims.eq(i2, i2) # i43: "bool True"
# i44 = prims.eq(i2, 1) # i44: "bool False"
# i45 = prims.eq(i3, 1) # i45: "bool False"
# i46 = prims.eq(i2, i3) # i46: "bool True"
# (i2,) = prims.shape(t37)
# (i2,) = prims.shape(t37)
# (i3,) = prims.shape(t38)
# (i3,) = prims.shape(t38)
# i47 = prims.eq(i3, i2) # i47: "bool True"
# t49 = prims.add(t37, t38) # t49: "cuda:0 f32[[IntegerProxy name=i2, value=3, static=CONSTRAINT.CONSTRAINABLE]]"
[t17] = nvFusion0(t49, t37)
# t17 = prims.copy_(t49, t37, grad_enabled=True) # t17: "cuda:0 f32[[IntegerProxy name=i2, value=3, static=CONSTRAINT.CONSTRAINABLE]]"
del t49, t37
return (t35, t17)

Other than so many [...]
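For context, a minimal sketch of the kind of call that could yield traces like these. The function body matches the source line recorded in the trace; the tensor setup is only an illustrative assumption (three arguments aliasing one CUDA tensor, consistent with alias_tensor_indices = '0,1,2'):

import torch, thunder

def f(x, y, z):
    return x + 2, y.add_(z)

jf = thunder.jit(f)
a = torch.randn(2, 3, device="cuda")
# x is the whole tensor; y and z are row views sharing its storage
out = jf(a, a[0], a[1])
print(thunder.last_traces(jf)[-1])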
|
Thanks for checking! I have another PR in review to clean up the symbolic traces.
mattteochen left a comment
This is great, thank you for addressing this, Beverly!
|
I've just realized this doesn't solve the problem, but I'm not seeing a fix. Consider the following program: [...] This fails on this PR. All the expected [...]
|
Hmm... Looks like both are returning nonsense when jitted, even on the main branch. So the underlying issue is how we are communicating data sharing to nvFuser.
|
I find the code here sketchy, particularly lightning-thunder/thunder/core/update_aliases.py, lines 147 to 153 in 9d24da4.
Compare lightning-thunder/thunder/core/update_aliases.py, lines 195 to 197 in 9d24da4.
This shortcut is no longer justified when [...]. Take your example:

def f(a, _):
    b = a.view(5,5)
    return a.exp_() * b.tanh_()

These are the traces before and after:

# Constructed by Remove context manager prims
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a):
# a: "cuda:0 f32[5, 5]"
# /opt/pytorch/lightning-thunder/tmp/main.py:5: b = a.view(5,5)
b = ltorch.view(a, 5, 5) # b: "cuda:0 f32[5, 5]"
# b = ltorch.reshape(a, (5, 5)) # b: "cuda:0 f32[5, 5]"
# b = prims.shallow_copy(a) # b: "cuda:0 f32[5, 5]"
# /opt/pytorch/lightning-thunder/tmp/main.py:6: return a.exp_() * b.tanh_()
t2 = ltorch.exp_(a) # t2: "cuda:0 f32[5, 5]"
# t1 = ltorch.exp(a) # t1: "cuda:0 f32[5, 5]"
# t1 = prims.exp(a) # t1: "cuda:0 f32[5, 5]"
# t2 = prims.copy_(t1, a, grad_enabled=True) # t2: "cuda:0 f32[5, 5]"
t4 = ltorch.tanh_(b) # t4: "cuda:0 f32[5, 5]"
# t3 = ltorch.tanh(b) # t3: "cuda:0 f32[5, 5]"
# t3 = prims.tanh(b) # t3: "cuda:0 f32[5, 5]"
# t4 = prims.copy_(t3, b, grad_enabled=True) # t4: "cuda:0 f32[5, 5]"
t5 = ltorch.mul(t2, t4) # t5: "cuda:0 f32[5, 5]"
# t5 = prims.mul(t2, t4) # t5: "cuda:0 f32[5, 5]"
return {'output': (t5,), 'flat_args': [a]}

Trace after:
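For reference, a hedged sketch of how one might reproduce and print these traces locally (the CUDA device and the driver code here are my assumptions, not taken verbatim from the report):

import torch, thunder

def f(a, _):
    b = a.view(5, 5)
    return a.exp_() * b.tanh_()

jf = thunder.jit(f)
a = torch.randn(5, 5, device="cuda")
jf(a, a[0])  # the second argument aliases the first, as in the example above
for tr in thunder.last_traces(jf):
    print(tr)  # includes the traces before and after the alias-update pass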
|
shino16 left a comment
This is probably a separate issue.
thunder/core/update_aliases.py (outdated)
    inplace_inputs.add(in_tensor)
    out_tensors = set()
-   for group in view_groups:
+   for group in intermediate_view_groups:
We should iterate over the groups in input_view_groups too; otherwise the following example makes {x, y} and {y, y2} separate view groups.
import torch, thunder

def f(x, y):
    y2 = y.view(-1)
    x.exp_()
    return y2.sin()

jf = thunder.jit(f)
x = torch.randn(3, device="cuda")
x_ref = x.clone().detach()
z = jf(x, x[0])
z_ref = f(x_ref, x_ref[0])
print(thunder.last_traces(jf)[-1])
torch.testing.assert_close(z, z_ref)  # AssertionError
I also thought that we should merge overlapping groups to avoid what you pointed out, but I wasn't able to break the overall trace correctness. I guess that's because I tested on CPU and not on CUDA.
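A tiny sketch of the kind of merging meant here, illustrative only and not code from this PR:

def merge_overlapping(groups):
    # Absorb every existing group that overlaps the incoming one, so
    # transitive chains like {a, b}, {b, c} collapse into {a, b, c}.
    merged = []
    for g in groups:
        g = set(g)
        keep = []
        for m in merged:
            if m & g:
                g |= m
            else:
                keep.append(m)
        keep.append(g)
        merged = keep
    return merged

assert merge_overlapping([{1, 2}, {2, 3}, {4, 5}]) == [{1, 2, 3}, {4, 5}]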
Yes, this is a good point.
I have been able to execute your problematic functions with something like the following (built on top of your branch).

Patch:

diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py
index d2db3e32..62e6c1c2 100644
--- a/thunder/core/update_aliases.py
+++ b/thunder/core/update_aliases.py
@@ -66,6 +66,62 @@ def _can_be_reshaped(arg, arg_to_replace):
arg_to_replace_numel = arg_to_replace.numel
return arg_numel == arg_to_replace_numel
+def _merge_overlapping_groups(groups: list[set]) -> list[set]:
+ """
+ Merge overlapping sets in a list of sets.
+
+ When tensors share storage transitively (e.g., a→b→c), the initial grouping
+ may create overlapping sets like [{a,b}, {b,c}]. This function merges them
+ into [{a,b,c}] to preserve transitive relationships.
+
+ Args:
+ groups: List of sets, potentially with overlaps
+
+ Returns:
+ List of sets with all overlapping groups merged
+
+ Example:
+ >>> _merge_overlapping_groups([{1, 2}, {2, 3}, {4, 5}])
+ [{1, 2, 3}, {4, 5}]
+ """
+ if not groups:
+ return []
+
+ merged = []
+ for group in groups:
+ # Check if this group overlaps with any existing merged group
+ found_overlap = False
+ for existing in merged:
+ if group.intersection(existing):
+ # Merge into existing group
+ existing.update(group)
+ found_overlap = True
+ break
+
+ if not found_overlap:
+ # No overlap found, add as new group
+ merged.append(group.copy())
+
+ # Keep merging until no more overlaps exist (handles transitive overlaps)
+ # Example: [{1,2}, {2,3}, {3,4}] needs multiple passes
+ changed = True
+ while changed:
+ changed = False
+ new_merged = []
+ for group in merged:
+ found_overlap = False
+ for existing in new_merged:
+ if group.intersection(existing):
+ existing.update(group)
+ found_overlap = True
+ changed = True
+ break
+ if not found_overlap:
+ new_merged.append(group)
+ merged = new_merged
+
+ return merged
+
def replace_args_with_alias_map(
computation_trace: Trace,
@@ -150,7 +206,6 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
out_tensors = set(map(variableify, filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_outs)))
if _is_inplace_op(bsym):
inplace_inputs.add(in_tensor)
- out_tensors = set()
for group in intermediate_view_groups:
if in_tensor in group:
group.update(out_tensors)
@@ -158,12 +213,14 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
else:
intermediate_view_groups.append(out_tensors.union({in_tensor}))
- # filter out view groups that don't have any tensors involved in inplace ops
- input_view_groups = [group for group in input_view_groups if len(group.intersection(inplace_inputs)) != 0]
- intermediate_view_groups = [
- group for group in intermediate_view_groups if len(group.intersection(inplace_inputs)) != 0
- ]
+ # Merge overlapping groups first to handle transitive relationships
+ # (e.g., if x aliases y, and y.view() creates y2, then {x,y} and {y,y2} should merge to {x,y,y2})
view_groups = input_view_groups + intermediate_view_groups
+ view_groups = _merge_overlapping_groups(view_groups)
+
+ # Filter out view groups that don't have any tensors involved in inplace ops
+ # This must happen AFTER merging so we don't discard groups that are transitively related
+ view_groups = [group for group in view_groups if len(group.intersection(inplace_inputs)) != 0]
viewed = set(reduce(set.union, view_groups, set()))
encountered = set(reduce(set.union, input_view_groups, set()))
@@ -183,15 +240,29 @@ def insert_alias_updates(computation_trace: Trace, alias_tensor_indices: list[li
bsyms.append(bsym.from_bsym_swap_proxies(swap_map, skip_output=True))
continue
- new_aliases = _get_new_aliases(views_encountered, computation_trace)
+ # For view creation ops, just create the view with swapped inputs
+ # Don't insert update_aliases - the view output will be included in the next update_aliases
+ # (which happens before the inplace op)
+ if _is_view_creation_op(bsym):
+ new_bsym = bsym.from_bsym_swap_proxies(swap_map)
+ bsyms.append(new_bsym)
+ # Mark the output as encountered so it can be included in future update_aliases
+ encountered.update(out_tensors)
+ else:
+ # For inplace ops or ops involving viewed args, insert update_aliases before the op
+ if views_encountered:
+ new_aliases = _get_new_aliases(views_encountered, computation_trace)
- update_bsym, swap_map = _get_update_bsym(views_encountered, swap_map, new_aliases)
- new_bsym = bsym.from_bsym_swap_proxies(swap_map)
- if has_tags(bsym, {BoundSymbolTag.BACKWARD}):
- update_bsym.tags.add(BoundSymbolTag.BACKWARD)
- bsyms.append(update_bsym)
- encountered.update(out_tensors)
- bsyms.append(new_bsym)
+ update_bsym, swap_map = _get_update_bsym(views_encountered, swap_map, new_aliases)
+ new_bsym = bsym.from_bsym_swap_proxies(swap_map)
+ if has_tags(bsym, {BoundSymbolTag.BACKWARD}):
+ update_bsym.tags.add(BoundSymbolTag.BACKWARD)
+ bsyms.append(update_bsym)
+ encountered.update(out_tensors)
+ bsyms.append(new_bsym)
+ else:
+ # No compatible views to update, just process the operation
+ bsyms.append(bsym.from_bsym_swap_proxies(swap_map))
if _is_inplace_op(bsym) and len(out_tensors) == 1 and len(in_tensors) == 1:
# This relies on these being one element sets (ltorch.setitem_ yields no outs).
swap_map = _update_swap_map(swap_map, in_tensors.pop(), unvariableify(out_tensors.pop()))

Test repro:

import torch, thunder

passed = {}
print_traces = False

def f(x, y):
    y2 = y.view(-1)
    x.exp_()
    return y2.sin()

jf = thunder.jit(f)
x = torch.randn(3, device="cuda")
x_ref = x.clone().detach()
z = jf(x, x[0])
z_ref = f(x_ref, x_ref[0])
if print_traces:
    print(thunder.last_traces(jf)[-1])
try:
    torch.testing.assert_close(z, z_ref)
    passed["test_basic_alias"] = True
except Exception as e:
    passed["test_basic_alias"] = False

def f(a, _):
    b = a.view(5,5)
    return a.exp_().view(5,5) * b.tanh_()  # adding view after exp_() due to broadcast issue

jf = thunder.jit(f)
x = torch.randn(25, device="cuda")
x_ref = x.clone().detach()
z = jf(x, x[0])  # b input irrelevant
z_ref = f(x_ref, x_ref[0])
if print_traces:
    print(thunder.last_traces(jf)[-1])
try:
    torch.testing.assert_close(z, z_ref)
    passed["test_alias_view_reshape"] = True
except Exception as e:
    passed["test_alias_view_reshape"] = False

def f(a, _):
    b = a[0,0]
    return a.exp_() * b.tanh_()

jf = thunder.jit(f)
x = torch.randn(4, 4, device="cuda")
x_ref = x.clone().detach()
z = jf(x, x[0])  # b input irrelevant
z_ref = f(x_ref, x_ref[0])
if print_traces:
    print(thunder.last_traces(jf)[-1])
try:
    torch.testing.assert_close(z, z_ref)
    passed["test_indexing_alias"] = True
except Exception as e:
    passed["test_indexing_alias"] = False

print("TEST PASSED!" if all(passed.values()) else "TEST FAILED!")

This was partially Gemini-generated, so take it with a grain of salt. What I tried to add:
|
I'm going to state some definitions for clarity's sake, to make sure that we are on the same page. Consider the following modification of the given example [...]. It produces the post-update_aliases trace [...]. Notice that the alias [...]
I'm confused by this comment. Does the name of the variable suggest something different from what I described above? Do you have a suggestion for a better variable name?
|
@mattteochen I just pushed a commit that incorporates the view groups coming from the inputs into the collection of view groups coming from the intermediate tensors. I understand how defining [...] For the other part of your patch, following the comment [...]
|
@beverlylytle Running the repro from #2760 (comment) fails with [...]. Execution traces are attached in that comment. My understanding is that [...]. However, in this example we already have the view group [...]. Adding mappings like [...]
|
Sorry, I mistakenly re-requested @mattteochen's review.
|
def func(a, b):
    # c is a view of a.
    c = a.view(a.shape)
    # Modify 'c' in place.
    c.add_(1.0)
    # Return b. It should reflect the update to c (via a).
    return b

Thunder last trace: [...]
|
I think what I said here could answer your question in #2760 (comment), but it seems more like a separate issue that exists on [...]
I think it's dealt with in this transform: lightning-thunder/thunder/__init__.py, line 485 in 656656b.
This should appear as the first trace in [...]
|
Consider the two versions of the same function [...]. We would expect these two functions to produce essentially the same trace. However, on the main branch, one version results in a tensor-likes-not-close error when compared to the PyTorch result, whereas the other does produce the same result as PyTorch. Compare the last fusions in each trace. They both represent tanh-copy_-mul, but one is provided two inputs, while the other has only one. This is because in the second version, all instances of [...]. That being said, it is possible that view groups should include more than they do. Do you (@shino16) have an example where an alias, when included in a view group, produces a correct trace, but an incorrect trace when excluded?
I found Thunder (with the patch above) producing incorrect results in a context like:

def func(a, b):
    # c is a view of a
    c = a.view(a.shape)
    # Inplace update on c.
    # Since c is a view of a, and b is a slice of a, this MUST update b.
    c.add_(1.0)
    # Operation on b. Should see the +1.0 update.
    d = b * 2
    return d

if [...]

The input alias should be [...] The goal of that function is to unify [...]

repro:

import torch
import thunder

def main():
    def func(a, b):
        # c is a view of a
        c = a.view(a.shape)
        # Inplace update on c, which affects a, and thus should affect b
        c.add_(1.0)
        # Operation on b. Should see the 1.0.
        # If b is not updated, this might run before the add or use old values.
        d = b * 2
        return d

    device = 'cuda'
    a = torch.zeros(10, device=device)
    b = a[:5]  # b aliases a
    a_ref = a.clone().detach()
    b_ref = a_ref[:5]
    print(f"Input 'b' (should be 0): {b[:3].tolist()}")
    jf = thunder.jit(func)
    try:
        res = jf(a, b)
        res_ref = func(a_ref, b_ref)
        print(f"Thunder result (should be 1.0): {a[:3].tolist()}")
        print(f"PyTorch result: {a_ref[:3].tolist()}")
        print(f"Thunder result (should be 2.0): {res[:3].tolist()}")
        print(f"PyTorch result: {res_ref[:3].tolist()}")
        print("\nLast Trace:")
        print(thunder.last_traces(jf)[-1])
        torch.testing.assert_close(res, res_ref)
        print("\n[PASSED] Slice alias 'b' correctly saw the update!")
    except Exception as e:
        print(f"\n[FAILED] Slice alias check failed.")
        # print(e)
        raise e

if __name__ == "__main__":
    main()
|
import torch
import thunder
import traceback
def main():
    def func(a, b):
        # c is a view of a.
        c = a.view(a.shape)
        # Modify 'c' in place.
        c.add_(1.0)
        # Return b. It should reflect the update to c (via a).
        return b

    # Setup inputs
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    a = torch.zeros(10, device=device)
    b = a.view(a.shape)  # b aliases a
    a_ref = a.clone().detach()
    b_ref = a_ref.view(a_ref.shape)
    print(f"Input 'b' (first 3): {b[:3].tolist()}")
    # Compile with Thunder
    jf = thunder.jit(func)
    try:
        # Run Thunder
        # Note: we return b.
        res = jf(a, b)
        # Run PyTorch Reference
        res_ref = func(a_ref, b_ref)
        # Compare
        print(f"Thunder result (first 3): {res[:3].tolist()}")
        print(f"PyTorch result (first 3): {res_ref[:3].tolist()}")
        print("\nLast Trace:")
        print(thunder.last_traces(jf)[-1])
        torch.testing.assert_close(res, res_ref)
        print("\n[PASSED] Transitive alias 'b' was correctly updated!")
    except Exception as e:
        print(f"\n[FAILED] Transitive alias check failed.")
        print("Likely cause: 'b' did not receive the inplace update from 'c'.")
        # traceback.print_exc()
        raise e

if __name__ == "__main__":
    main()
I think I'm not properly understanding your concern. The script in #2760 (comment) produces exactly the trace that gives the correct result when [...]
|
This PR is exploding. I agree that we should address different subsections separately. I've investigated the proposed patch to address those cases where we mix [...]. This PR doesn't need to handle all the edge cases, I think.
If code doesn't directly contribute to the output of the function that is jitted, Thunder cuts it out. I think this is a good assumption for Thunder to make. We don't want to be spinning wheels on dead code. If [...]
Oh yes, if we assume that we don't cover "dead code", I agree with you. I don't have a concrete example right now, as this came to mind while trying some nested ops working on a shared memory location. I saw this case as something in the middle between dead and non-dead code, due to the shared memory.
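For concreteness, the kind of in-between case in question looks roughly like the earlier repro (a hedged sketch; the aliasing between a and b is assumed to come from the caller):

import torch

def f(a, b):           # assume the caller passes b = a[:5], so b aliases a
    c = a.view(a.shape)
    c.add_(1.0)        # c itself never feeds the return value...
    return b * 2       # ...but the write lands in storage that b also reads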
Sorry @beverlylytle, this test seems to be correct on your branch (this PR). I was testing it with the patch applied, and the [...]. This introduced an incorrect overall result (caused by the skipping policy) on cases of this type. Checking out an earlier commit, we had runtime errors:

RuntimeError: Attempting to reshape a.shape=(4,) to shape=(2,), but a.numel=4 is different from the number of elements in shape, 2
In an off-GitHub discussion, we resolved that when [...]
shino16 left a comment
Let's merge this! The issue with in-place ops on aliased memory exists even on main, so it will be discussed separately in #2766.
|
@KaelanDt Could you take a look?
Consider the following function, provided the input x = torch.randn(2,2) and y = x[0,:]. When the function is called, the in-place changes to y also affect x, as they share storage. If x and y had the same number of elements, we would replace all instances of y in the program body with the output of prims.reshape(x, y.shape). Since they do not have the same number of elements, this doesn't work. Currently this situation is skipped, and, in the case above, no special handling of x occurs. This PR addresses that case by defining view_groups not only with the data collected from view creation operations within the program body, but also with data coming from the input tensors.
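A minimal sketch of the situation described above (the exact function used in the PR's tests may differ; the function body here is only an illustration of the aliasing pattern):

import torch

def f(x, y):
    y.add_(1)      # in-place update through the smaller view...
    return x * 2   # ...must be visible through x, which owns the full storage

x = torch.randn(2, 2)
y = x[0, :]        # y aliases x but has fewer elements, so y cannot be
                   # recovered as prims.reshape(x, y.shape)
out = f(x, y)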