Skip to content

Commit a29e47a

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
refactor train_pipeline_tests for tracing (#3314)
Summary: Pull Request resolved: #3314 # context * previous work in torchrec train_pipeline to refactoring the utils file * tracing was separated from the utils * here we also move the corresponding tests into a new file. Reviewed By: iamzainhuda Differential Revision: D80882439 fbshipit-source-id: 572a52f651f381b084be369314fce4bead6e853d
1 parent 1c20cfb commit a29e47a

File tree

2 files changed

+119
-90
lines changed

2 files changed

+119
-90
lines changed
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import unittest
11+
from typing import List, Optional
12+
from unittest.mock import MagicMock
13+
14+
import parameterized
15+
16+
import torch
17+
from torch import nn
18+
19+
from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
20+
21+
from torchrec.distributed.train_pipeline.tracing import (
22+
_get_leaf_module_names,
23+
ArgInfo,
24+
ArgInfoStepFactory,
25+
CallArgs,
26+
NodeArgsHelper,
27+
PipelinedPostproc,
28+
Tracer,
29+
)
30+
from torchrec.distributed.types import NullShardedModuleContext, ShardedModule
31+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
32+
33+
34+
class TestNodeArg(unittest.TestCase):
35+
36+
@parameterized.parameterized.expand(
37+
[
38+
(
39+
CallArgs(
40+
args=[],
41+
kwargs={
42+
"id_list_features": ArgInfo(steps=[ArgInfoStepFactory.noop()]),
43+
# Empty attrs to ignore any attr based logic.
44+
"id_score_list_features": ArgInfo(
45+
steps=[ArgInfoStepFactory.noop()]
46+
),
47+
},
48+
),
49+
0,
50+
["id_list_features", "id_score_list_features"],
51+
),
52+
(
53+
CallArgs(
54+
args=[
55+
# Empty attrs to ignore any attr based logic.
56+
ArgInfo(steps=[ArgInfoStepFactory.noop()]),
57+
ArgInfo(steps=[]),
58+
],
59+
kwargs={},
60+
),
61+
2,
62+
[],
63+
),
64+
(
65+
CallArgs(
66+
args=[
67+
# Empty attrs to ignore any attr based logic.
68+
ArgInfo(
69+
steps=[ArgInfoStepFactory.noop()],
70+
)
71+
],
72+
kwargs={"id_score_list_features": ArgInfo(steps=[])},
73+
),
74+
1,
75+
["id_score_list_features"],
76+
),
77+
]
78+
)
79+
def test_build_args_kwargs(
80+
self,
81+
fwd_args: CallArgs,
82+
args_len: int,
83+
kwarges_keys: List[str],
84+
) -> None:
85+
args, kwargs = fwd_args.build_args_kwargs("initial_input")
86+
self.assertEqual(len(args), args_len)
87+
self.assertEqual(list(kwargs.keys()), kwarges_keys)
88+
89+
def test_get_node_args_helper_call_module_kjt(self) -> None:
90+
graph = torch.fx.Graph()
91+
kjt_args = []
92+
93+
kjt_args.append(
94+
torch.fx.Node(graph, "values", "placeholder", "torch.Tensor", (), {})
95+
)
96+
kjt_args.append(
97+
torch.fx.Node(graph, "lengths", "placeholder", "torch.Tensor", (), {})
98+
)
99+
kjt_args.append(
100+
torch.fx.Node(
101+
graph, "weights", "call_module", "PositionWeightedModule", (), {}
102+
)
103+
)
104+
105+
kjt_node = torch.fx.Node(
106+
graph,
107+
"keyed_jagged_tensor",
108+
"call_function",
109+
KeyedJaggedTensor,
110+
tuple(kjt_args),
111+
{},
112+
)
113+
114+
node_args_helper = NodeArgsHelper(MagicMock(), TrainPipelineContext(), False)
115+
116+
_, num_found = node_args_helper.get_node_args(kjt_node)
117+
118+
# Weights is call_module node, so we should only find 2 args unmodified
119+
self.assertEqual(num_found, len(kjt_args) - 1)

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

Lines changed: 0 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,8 @@
1010
import copy
1111
import enum
1212
import unittest
13-
from typing import List
1413
from unittest.mock import MagicMock
1514

16-
import parameterized
17-
1815
import torch
1916

2017
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -260,90 +257,3 @@ def test_restore_from_snapshot(self) -> None:
260257
]
261258
for source_model_type, recipient_model_type in variants:
262259
self._test_restore_from_snapshot(source_model_type, recipient_model_type)
263-
264-
@parameterized.parameterized.expand(
265-
[
266-
(
267-
CallArgs(
268-
args=[],
269-
kwargs={
270-
"id_list_features": ArgInfo(steps=[ArgInfoStepFactory.noop()]),
271-
# Empty attrs to ignore any attr based logic.
272-
"id_score_list_features": ArgInfo(
273-
steps=[ArgInfoStepFactory.noop()]
274-
),
275-
},
276-
),
277-
0,
278-
["id_list_features", "id_score_list_features"],
279-
),
280-
(
281-
CallArgs(
282-
args=[
283-
# Empty attrs to ignore any attr based logic.
284-
ArgInfo(steps=[ArgInfoStepFactory.noop()]),
285-
ArgInfo(steps=[]),
286-
],
287-
kwargs={},
288-
),
289-
2,
290-
[],
291-
),
292-
(
293-
CallArgs(
294-
args=[
295-
# Empty attrs to ignore any attr based logic.
296-
ArgInfo(
297-
steps=[ArgInfoStepFactory.noop()],
298-
)
299-
],
300-
kwargs={"id_score_list_features": ArgInfo(steps=[])},
301-
),
302-
1,
303-
["id_score_list_features"],
304-
),
305-
]
306-
)
307-
def test_build_args_kwargs(
308-
self,
309-
fwd_args: CallArgs,
310-
args_len: int,
311-
kwarges_keys: List[str],
312-
) -> None:
313-
args, kwargs = fwd_args.build_args_kwargs("initial_input")
314-
self.assertEqual(len(args), args_len)
315-
self.assertEqual(list(kwargs.keys()), kwarges_keys)
316-
317-
318-
class TestUtils(unittest.TestCase):
319-
def test_get_node_args_helper_call_module_kjt(self) -> None:
320-
graph = torch.fx.Graph()
321-
kjt_args = []
322-
323-
kjt_args.append(
324-
torch.fx.Node(graph, "values", "placeholder", "torch.Tensor", (), {})
325-
)
326-
kjt_args.append(
327-
torch.fx.Node(graph, "lengths", "placeholder", "torch.Tensor", (), {})
328-
)
329-
kjt_args.append(
330-
torch.fx.Node(
331-
graph, "weights", "call_module", "PositionWeightedModule", (), {}
332-
)
333-
)
334-
335-
kjt_node = torch.fx.Node(
336-
graph,
337-
"keyed_jagged_tensor",
338-
"call_function",
339-
KeyedJaggedTensor,
340-
tuple(kjt_args),
341-
{},
342-
)
343-
344-
node_args_helper = NodeArgsHelper(MagicMock(), TrainPipelineContext(), False)
345-
346-
_, num_found = node_args_helper.get_node_args(kjt_node)
347-
348-
# Weights is call_module node, so we should only find 2 args unmodified
349-
self.assertEqual(num_found, len(kjt_args) - 1)

0 commit comments

Comments
 (0)