
Commit 43a9d6f

bwasti authored and facebook-github-bot committed
[TorchScript] Support user defined classes as constants (pytorch#5062)
Summary:
Pull Request resolved: pytorch/glow#5062
Pull Request resolved: pytorch#45556

User defined classes can be used as constants. This is useful when freezing and removing the module from the graph.

Test Plan: waitforsadcastle

Reviewed By: eellison

Differential Revision: D23994974

fbshipit-source-id: 5b4a5c91158aa7f22df39d71f2658afce1d29317
1 parent 3611d26 commit 43a9d6f

17 files changed (+144, -26 lines)
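In practice, this means a module holding a user defined (custom) class attribute can now be frozen, with the attribute embedded in the graph as a constant. A minimal sketch, condensed from the C++ test added below in test/cpp/jit/test_custom_class.cpp (it relies on the MyStackClass fixture from test_custom_class_registrations.h; the wrapper function name is illustrative only):

#include <test/cpp/jit/test_custom_class_registrations.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/custom_class.h>
#include <torch/script.h>

// Sketch: register a custom class as a module attribute and freeze the module.
// After this change, freezing folds the object into the graph as a constant
// instead of failing on the non-module class attribute.
void freeze_module_with_custom_class() {
  torch::jit::script::Module m("m");
  auto obj = torch::make_custom_class<MyStackClass<std::string>>(
      std::vector<std::string>{"foo", "bar"});
  m.register_attribute("s", obj.type(), obj, /*is_parameter=*/false);
  m.define(R"(
    def forward(self):
      return self.s.return_a_tuple()
  )");
  auto frozen = torch::jit::freeze_module(m.clone());
  frozen.run_method("forward"); // the custom class object is now a graph constant
}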

aten/src/ATen/core/ivalue.cpp (+3)

@@ -508,6 +508,9 @@ std::ostream& IValue::repr(
       return out << enum_holder->qualifiedClassName() << "." <<
           enum_holder->name();
     }
+    case IValue::Tag::Object: {
+      TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind(), ". Perhaps you've frozen a module with custom classes?");
+    }
     default:
       TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind());
   }

test/cpp/jit/test_custom_class.cpp (+43)

@@ -1,6 +1,7 @@
 #include <gtest/gtest.h>
 
 #include <test/cpp/jit/test_custom_class_registrations.h>
+#include <torch/csrc/jit/passes/freeze_module.h>
 #include <torch/custom_class.h>
 #include <torch/script.h>
 
@@ -86,5 +87,47 @@ TEST(CustomClassTest, TestDocString) {
       method_doc_string);
 }
 
+TEST(CustomClassTest, Serialization) {
+  script::Module m("m");
+
+  // test make_custom_class API
+  auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
+      std::vector<std::string>{"foo", "bar"});
+  m.register_attribute(
+      "s",
+      custom_class_obj.type(),
+      custom_class_obj,
+      /*is_parameter=*/false);
+  m.define(R"(
+    def forward(self):
+      return self.s.return_a_tuple()
+  )");
+
+  auto test_with_obj = [](script::Module& mod) {
+    auto res = mod.run_method("forward");
+    auto tup = res.toTuple();
+    AT_ASSERT(tup->elements().size() == 2);
+    auto i = tup->elements()[1].toInt();
+    AT_ASSERT(i == 123);
+  };
+
+  auto frozen_m = torch::jit::freeze_module(m.clone());
+
+  test_with_obj(m);
+  test_with_obj(frozen_m);
+
+  std::ostringstream oss;
+  m.save(oss);
+  std::istringstream iss(oss.str());
+  caffe2::serialize::IStreamAdapter adapter{&iss};
+  auto loaded_module = torch::jit::load(iss, torch::kCPU);
+
+  std::ostringstream oss_frozen;
+  frozen_m.save(oss_frozen);
+  std::istringstream iss_frozen(oss_frozen.str());
+  caffe2::serialize::IStreamAdapter adapter_frozen{&iss_frozen};
+  auto loaded_frozen_module = torch::jit::load(iss_frozen, torch::kCPU);
+}
+
 } // namespace jit
 } // namespace torch

test/quantization/test_quantize_jit.py (+3, -3)

@@ -2674,6 +2674,7 @@ def forward(self, x):
         num_quantize_per_tensor = 1 # for output
         for num_quant, num_op in num_op_by_num_quant.items():
             num_quantize_per_tensor += num_op * num_quant
+        num_quantize_per_tensor -= 4 # constant propagation removes some prepacks
         FileCheck().check_count("aten::quantize_per_tensor(", num_quantize_per_tensor, exactly=True) \
                    .run(m1.graph)
 
@@ -2997,8 +2998,7 @@ def forward(self, x):
         m = torch.jit.script(M())
         m = quantize_dynamic_jit(m, {'': float16_dynamic_qconfig})
 
-        FileCheck().check("quantized::linear_prepack_fp16") \
-                   .check_next("quantized::linear_dynamic_fp16") \
+        FileCheck().check("quantized::linear_dynamic_fp16") \
                    .check_not("aten::linear") \
                    .check_not("aten::dequantize") \
                    .check_not("aten::quantize") \
@@ -3076,7 +3076,7 @@ def forward(self, indices1, offsets1, indices2, offsets2):
         m = prepare_jit(m, {'embedding1' : int4_qconfig, 'embedding2' : int8_qconfig})
         m = convert_jit(m)
         FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets") \
-                   .check_next("quantized::embedding_bag_byte_rowwise_offsets") \
+                   .check("quantized::embedding_bag_byte_rowwise_offsets") \
                    .run(m.graph)
         m(*dummy_inputs)

test/test_mobile_optimizer.py (+2, -2)

@@ -340,7 +340,7 @@ def _quant_script_and_optimize(model):
 
         m, m_optim = _quant_script_and_optimize(Standalone())
         FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
-                   .check_count("_jit_pass_hoist_conv_packed_params", 2, exactly=True) \
+                   .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                    .run(m_optim.graph)
         self.assertFalse(hasattr(m_optim, "conv1"))
         self.assertFalse(hasattr(m_optim, "conv2"))
@@ -354,7 +354,7 @@ def _quant_script_and_optimize(model):
 
         m, m_optim = _quant_script_and_optimize(Parent())
         FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
-                   .check_count("_jit_pass_hoist_conv_packed_params", 2, exactly=True) \
+                   .check_count("__torch__.torch.classes.quantized.Conv2dPackedParamsBase = prim::Constant", 2, exactly=True) \
                    .run(m_optim.graph)
         self.assertFalse(hasattr(m_optim, "conv1"))
         self.assertFalse(hasattr(m_optim, "child"))

torch/_C/__init__.pyi.in (+1)

@@ -179,6 +179,7 @@ def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule',
 def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule',
                                         preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
 def _jit_pass_inline(Graph) -> None: ...
+def _jit_pass_constant_propagation(Graph) -> None: ...
 def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
 def _jit_can_fuse_on_cpu() -> _bool: ...
 def _jit_can_fuse_on_gpu() -> _bool: ...

torch/csrc/jit/ir/constants.cpp (+6, -4)

@@ -124,10 +124,9 @@ c10::optional<Value*> tryInsertConstant(
       n->destroy();
       return c10::nullopt;
     };
-  } else if (val.isGenericDict() && insertableIValue(val)) {
-    n->ival_(attr::value, val);
-    n->output()->setType(val.type());
-  } else if (val.isEnum()) {
+  } else if (
+      (val.isGenericDict() && insertableIValue(val)) || (val.isEnum()) ||
+      (val.isObject() && !val.toObjectRef().type()->is_module())) {
     n->ival_(attr::value, val);
     n->output()->setType(val.type());
   } else {
@@ -191,6 +190,9 @@ c10::optional<IValue> toIValue(const Value* v) {
   } else if (type->cast<EnumType>()) {
     const auto& enum_val = node->ival(attr::value);
     return enum_val;
+  } else if (type->cast<ClassType>() && !type->is_module()) {
+    const auto& class_val = node->ival(attr::value);
+    return class_val;
   } else {
     std::stringstream ss;
     ss << "constant literal not supported for: " << type->str();

torch/csrc/jit/ir/ir.cpp (+3)

@@ -133,6 +133,9 @@ static void printAttribute(std::ostream& out, const IValue& ival) {
     } else if (input.isTensorList()) {
       ss << "[<Tensors>]";
       return true;
+    } else if (input.isObject() && !input.type()->is_module()) {
+      ss << "object(" << &input.toObjectRef() << ")";
+      return true;
     }
     return false;
   };

torch/csrc/jit/ir/node_hashing.cpp (+3)

@@ -126,6 +126,9 @@ bool ivaluesEqual(const IValue& a1, const IValue& a2) {
   if (a1.isEnum()) {
     return a1.toEnumHolder() == a2.toEnumHolder();
   }
+  if (a1.isObject()) {
+    return &a1.toObjectRef() == &a2.toObjectRef();
+  }
   TORCH_INTERNAL_ASSERT(false);
 }

torch/csrc/jit/passes/constant_propagation.cpp (+28, -8)

@@ -16,7 +16,9 @@
 namespace torch {
 namespace jit {
 
-c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(const Node* n) {
+c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
+    const Node* n,
+    bool ignore_custom_classes) {
   Stack stack;
   for (auto input : n->inputs()) {
     if (auto ival = toIValue(input)) {
@@ -25,6 +27,7 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(const Node* n) {
       return c10::nullopt;
     }
   }
+
   switch (n->kind()) {
     case prim::ListUnpack: {
       if (stack.back().toList().size() != n->outputs().size()) {
@@ -80,6 +83,12 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(const Node* n) {
         return c10::nullopt;
       }
     }
+    // Weak form of const propagation
+    if (ignore_custom_classes) {
+      if (v.isCustomClass()) {
+        return c10::nullopt;
+      }
+    }
   }
   return stack;
 }
@@ -104,33 +113,40 @@ std::unordered_set<Symbol> skip_list = {
 struct ConstantPropagator {
   // Runs constant propagation with an aliasing db and checks if inputs or
   // outputs might be mutated in the graph
-  static ConstantPropagator WithAliasDb(std::shared_ptr<Graph> graph) {
-    return ConstantPropagator(graph, true);
+  static ConstantPropagator WithAliasDb(
+      std::shared_ptr<Graph> graph,
+      bool ignore_custom_classes) {
+    return ConstantPropagator(std::move(graph), true, ignore_custom_classes);
   }
 
   // Runs constant propagation only on ops that clearly do not have aliased
   // inputs or outputs without computing aliasing information
   static ConstantPropagator NoAliasDb(std::shared_ptr<Graph> graph) {
-    return ConstantPropagator(graph, false);
+    return ConstantPropagator(std::move(graph), false, false);
   }
 
   void run() {
     ConstantPropagation(graph_->block());
   }
 
  private:
-  ConstantPropagator(std::shared_ptr<Graph> graph, bool aliasing_types)
+  ConstantPropagator(
+      std::shared_ptr<Graph> graph,
+      bool aliasing_types,
+      bool ignore_custom_classes)
       : graph_(std::move(graph)) {
     if (aliasing_types) {
       aliasDb_ = torch::make_unique<AliasDb>(graph_);
    } else {
      aliasDb_ = nullptr;
    }
+    ignore_custom_classes_ = ignore_custom_classes;
   }
 
   void propagateNode(Node* n) {
     std::vector<IValue> outputs;
-    if (auto outputs_opt = runNodeIfInputsAreConstant(n)) {
+    if (auto outputs_opt =
+            runNodeIfInputsAreConstant(n, ignore_custom_classes_)) {
       outputs = std::move(outputs_opt.value());
     } else {
       // The op failed to run, so we cannot continue constant-prop for it.
@@ -353,11 +369,15 @@ struct ConstantPropagator {
 
   std::shared_ptr<Graph> graph_;
   std::unique_ptr<AliasDb> aliasDb_;
+  bool ignore_custom_classes_;
 };
 } // anonymous namespace
 
-void ConstantPropagation(std::shared_ptr<Graph>& graph) {
-  ConstantPropagator cp = ConstantPropagator::WithAliasDb(graph);
+void ConstantPropagation(
+    std::shared_ptr<Graph>& graph,
+    bool ignore_custom_classes) {
+  ConstantPropagator cp =
+      ConstantPropagator::WithAliasDb(graph, ignore_custom_classes);
   cp.run();
   EliminateDeadCode(graph);
   GRAPH_DUMP("After ConstantPropagation: ", graph);

torch/csrc/jit/passes/constant_propagation.h (+13, -3)

@@ -5,15 +5,25 @@
 namespace torch {
 namespace jit {
 
-TORCH_API void ConstantPropagation(std::shared_ptr<Graph>& graph);
+// Runs constant propagation on all objects unless ignore_custom_classes is
+// specified as true, in which case user defined classes are skipped. This is
+// useful to prevent early fusion of packing operations, which end up lowering
+// away information about their constructors (e.g. packed::linear_clamp_prepack
+// and prepacked::conv2d_clamp_prepack)
+TORCH_API void ConstantPropagation(
+    std::shared_ptr<Graph>& graph,
+    bool ignore_custom_classes = false);
 
 // runs constant propagation only on ops that have non-aliasing inputs & outputs
 TORCH_API void ConstantPropagationImmutableTypes(std::shared_ptr<Graph>& graph);
 
 // Runs the node if its inputs are constants. Callers of this function must
 // make their own determination if constant prop is appropriate - for example
-// non-deterministic ops or ops with side effects
-TORCH_API c10::optional<Stack> runNodeIfInputsAreConstant(const Node* node);
+// non-deterministic ops or ops with side effects. If ignore_custom_classes is
+// specified, nodes that output user defined classes are not run.
+TORCH_API c10::optional<Stack> runNodeIfInputsAreConstant(
+    const Node* node,
+    bool ignore_custom_classes = false);
 
 } // namespace jit
 } // namespace torch
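To illustrate the new overload from a pass author's point of view, a small hypothetical caller (not part of this commit; only the ConstantPropagation signature comes from the header above):

#include <torch/csrc/jit/passes/constant_propagation.h>

// Hypothetical helper showing the two modes of the pass.
void propagate_constants(std::shared_ptr<torch::jit::Graph>& graph,
                         bool keep_user_class_objects) {
  if (keep_user_class_objects) {
    // Weak form: nodes whose outputs are user defined classes (e.g. the
    // packed-params objects produced by *_prepack ops) are left untouched.
    torch::jit::ConstantPropagation(graph, /*ignore_custom_classes=*/true);
  } else {
    // Full propagation: user defined class outputs may also be folded into
    // prim::Constant values (the default, ignore_custom_classes = false).
    torch::jit::ConstantPropagation(graph);
  }
}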

torch/csrc/jit/passes/freeze_module.cpp (+11, -1)

@@ -82,7 +82,8 @@ class AttributePropagator {
       ClearProfilingInformation(subgraph);
     };
     auto applyOptimizations = [](std::shared_ptr<Graph>& subgraph) {
-      runOptimization(subgraph, /* unroll? */ false);
+      runOptimization(
+          subgraph, /* unroll? */ false, /* const_prop_user_classes? */ false);
     };
 
     for (auto function : preservedMethods_) {
@@ -315,6 +316,15 @@ class AttributePropagator {
         val = overrideGradient(val);
       }
       attr = std::move(dict);
+    } else if (attr.isObject() && !attr.toObjectRef().type()->is_module()) {
+      auto obj_type = attr.type()->expect<ClassType>();
+      auto obj_value = std::move(attr).toObject();
+      auto sub_attributes = obj_type->getAttributes();
+      for (const auto& sub_attr : sub_attributes) {
+        auto sub_attr_val = obj_value->getAttr(sub_attr.getName());
+        sub_attr_val = overrideGradient(sub_attr_val);
+      }
+      return obj_value;
     }
 
     return attr;

torch/csrc/jit/passes/xnnpack_rewrite.cpp (+7)

@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/subgraph_matcher.h>
 #include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
 #include <torch/csrc/jit/passes/fold_conv_bn.h>
 #include <torch/csrc/jit/passes/freeze_module.h>
 #include <torch/csrc/jit/passes/fuse_linear.h>
@@ -334,6 +335,9 @@ void fusePrePackedLinearConvWithClamp(script::Module& module) {
   auto graph = module.get_method("forward").graph();
   fuseReluWithPackedOps(graph);
   fuseHardtanhWithPackedOps(graph);
+
+  // Ignore user defined classes for later passes
+  ConstantPropagation(graph, true);
 }
 
 void FoldPrePackingOps(script::Module& m) {
@@ -348,6 +352,9 @@ void FoldPrePackingOps(script::Module& m) {
           "prepacked::conv2d_transpose_clamp_prepack"));
   };
   PrePackingOpsFolder(m, filter_fn, "prepack_folding");
+  auto graph = m.get_method("forward").graph();
+  // Folding requires a const propagation through user defined classes
+  ConstantPropagation(graph, false);
 }
 
 script::Module optimizeForMobile(
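Read together, the two calls added in this file bracket prepack folding: the weak propagation runs after the clamp fusions so the prepack constructors survive for PrePackingOpsFolder, and the full propagation afterwards folds the resulting packed-params objects into constants. A condensed sketch of that ordering (the wrapper function is hypothetical; the two ConstantPropagation calls mirror the ones above):

// Hypothetical wrapper: condensed view of the mobile optimization ordering.
void mobile_prepack_pipeline_sketch(torch::jit::script::Module& m) {
  auto graph = m.get_method("forward").graph();

  // 1) After relu/hardtanh are fused into the prepacked ops, propagate
  //    constants but skip user defined classes so the *_prepack constructors
  //    keep the information the folding pass still needs.
  torch::jit::ConstantPropagation(graph, /*ignore_custom_classes=*/true);

  // 2) After PrePackingOpsFolder has run, propagate through user defined
  //    classes as well; the packed-params objects become prim::Constant
  //    values (as checked by the updated FileChecks in test_mobile_optimizer.py).
  torch::jit::ConstantPropagation(graph, /*ignore_custom_classes=*/false);
}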

torch/csrc/jit/python/init.cpp (+2, -1)

@@ -436,7 +436,8 @@ void initJITBindings(PyObject* module) {
       })
       .def(
           "_jit_pass_constant_propagation",
-          [](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); })
+          [](std::shared_ptr<Graph>& g) { return ConstantPropagation(g); },
+          py::arg("graph"))
       .def("_jit_pass_erase_shape_information", EraseShapeInformation)
       .def(
           "_jit_pass_create_autodiff_subgraphs",

torch/csrc/jit/runtime/graph_executor.cpp (+11, -2)

@@ -865,7 +865,10 @@ void runNondiffOptimization(
       "After customPostPassses (end of runNondiffOptimization)\n", *graph);
 }
 
-void runOptimization(std::shared_ptr<Graph>& graph, bool unroll) {
+void runOptimization(
+    std::shared_ptr<Graph>& graph,
+    bool unroll,
+    bool const_prop_user_classes) {
   // Basic graph preprocessing to eliminate noise.
   GRAPH_DEBUG(
       "Before EliminateDeadCode (beginning of runOptimization)\n", *graph);
@@ -878,8 +881,14 @@ void runOptimization(std::shared_ptr<Graph>& graph, bool unroll) {
 
   PeepholeOptimize(graph);
   GRAPH_DEBUG("After PeepholeOptimize, before ConstantPropagation\n", *graph);
-  ConstantPropagation(graph);
+
+  if (const_prop_user_classes) {
+    ConstantPropagation(graph);
+  } else {
+    ConstantPropagation(graph, true);
+  }
   GRAPH_DEBUG("After ConstantPropagation, before ConstantPooling\n", *graph);
+
   ConstantPooling(graph);
   GRAPH_DEBUG("After ConstantPooling\n", *graph);

torch/csrc/jit/runtime/graph_executor_impl.h (+4, -1)

@@ -31,7 +31,10 @@ namespace jit {
 
 void packGradient(const Gradient& gradient, Node* dnode);
 bool needsGradient(const std::shared_ptr<const Graph>& graph);
-void runOptimization(std::shared_ptr<Graph>& graph, bool unroll = true);
+void runOptimization(
+    std::shared_ptr<Graph>& graph,
+    bool unroll = true,
+    bool const_prop_user_classes = true);
 void runNondiffOptimization(
     std::shared_ptr<Graph>& graph,
     bool strict_fuser_check = false);
