Skip to content

Commit 4a0d17b

Browse files
swolchok authored and
facebook-github-bot committed
[PyTorch][codemod] Replace immediately-dereferenced expect calls w/expectRef (pytorch#50228)
Summary: Pull Request resolved: pytorch#50228 `fastmod -m 'expect(<((at|c10)::)?\w+Type>\(\)\s*)->' 'expectRef${1}.'` Presuming it builds, this is a safe change: the result of `expect()` wasn't being saved anywhere, so we didn't need it, so we can take a reference instead of a new `shared_ptr`. ghstack-source-id: 119782961 Test Plan: CI Reviewed By: SplitInfinity Differential Revision: D25837374 fbshipit-source-id: 86757b70b1520e3dbaa141001e7976400cdd3b08
1 parent c6cb632 commit 4a0d17b

27 files changed

+67
-65
lines changed

aten/src/ATen/core/ivalue.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ std::ostream& printMaybeAnnotatedList(
417417
std::ostream& out,
418418
const IValue& the_list,
419419
IValueFormatter formatter) {
420-
auto list_elem_type = the_list.type()->expect<ListType>()->getElementType();
420+
auto list_elem_type = the_list.type()->expectRef<ListType>().getElementType();
421421
if (the_list.toListRef().size() == 0 ||
422422
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
423423
out << "annotate(" << the_list.type()->annotation_str() << ", ";

aten/src/ATen/core/type.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2) {
252252

253253
// Handle non-container types which do not subtype each other and unify
254254
if (t1->kind() == TensorType::Kind && t2->kind() == TensorType::Kind) {
255-
return t1->expect<TensorType>()->merge(*t2->expect<TensorType>());
255+
return t1->expectRef<TensorType>().merge(*t2->expect<TensorType>());
256256
}
257257

258258
if (t1->isSubtypeOf(NoneType::get()) && !t2->isSubtypeOf(NoneType::get())) {
@@ -1317,7 +1317,7 @@ size_t ClassType::addAttribute(
13171317
TORCH_CHECK(
13181318
(type->kind() == TensorType::Kind) ||
13191319
(type->kind() == OptionalType::Kind &&
1320-
type->expect<OptionalType>()->getElementType()->kind() ==
1320+
type->expectRef<OptionalType>().getElementType()->kind() ==
13211321
TensorType::Kind) ||
13221322
(type->kind() == NoneType::Kind),
13231323
"Expecting parameter or buffer to have either None, Tensor or Optional[Tensor] type, but got: ",

test/cpp/jit/test_jit_type.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ TEST(JitTypeTest, UnifyTypes) {
1818
TORCH_INTERNAL_ASSERT(!tensor->isSubtypeOf(opt_bool_tensor));
1919
auto unified = unifyTypes(opt_bool_tensor, tensor);
2020
TORCH_INTERNAL_ASSERT(unified);
21-
auto elem = (*unified)->expect<OptionalType>()->getElementType();
21+
auto elem = (*unified)->expectRef<OptionalType>().getElementType();
2222
TORCH_INTERNAL_ASSERT(elem->isSubtypeOf(TensorType::get()));
2323

2424
auto opt_tuple_none_int = OptionalType::create(

test/cpp/jit/test_misc.cpp

+10-10
Original file line numberDiff line numberDiff line change
@@ -504,18 +504,18 @@ TEST(SchemaParserTest, NestedArrays) {
504504
ASSERT_TRUE(IntType::get()->isSubtypeOf(s.arguments()
505505
.at(0)
506506
.type()
507-
->expect<ListType>()
508-
->getElementType()
509-
->expect<ListType>()
510-
->getElementType()));
507+
->expectRef<ListType>()
508+
.getElementType()
509+
->expectRef<ListType>()
510+
.getElementType()));
511511
auto s2 = parseSchema("at::what(int[][] foo) -> ()");
512512
ASSERT_TRUE(IntType::get()->isSubtypeOf(s2.arguments()
513513
.at(0)
514514
.type()
515-
->expect<ListType>()
516-
->getElementType()
517-
->expect<ListType>()
518-
->getElementType()));
515+
->expectRef<ListType>()
516+
.getElementType()
517+
->expectRef<ListType>()
518+
.getElementType()));
519519
}
520520

521521
TEST(SchemaParserTest, NamedReturns) {
@@ -531,7 +531,7 @@ TEST(SchemaParserTest, Futures) {
531531
// futures
532532
auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
533533
ASSERT_TRUE(IntType::get()->isSubtypeOf(
534-
s4.arguments().at(0).type()->expect<FutureType>()->getElementType()));
534+
s4.arguments().at(0).type()->expectRef<FutureType>().getElementType()));
535535
}
536536

537537
TEST(SchemaParserTest, AnnotatedAliasSets) {
@@ -1751,7 +1751,7 @@ TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
17511751
});
17521752
ASSERT_NE(guard, nodes.end());
17531753
ASSERT_EQ(
1754-
guard->input()->type()->expect<TensorType>()->sizes().size(),
1754+
guard->input()->type()->expectRef<TensorType>().sizes().size(),
17551755
c10::nullopt);
17561756
checkShape(*guard, {2, 3}, false);
17571757
auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };

torch/csrc/jit/codegen/cuda/partition.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ static c10::optional<c10::Device> getDevice(const Value* value) {
2020
// not tensor type, return false as the op is not outputing scalar.
2121
return c10::nullopt;
2222
}
23-
return value->type()->expect<TensorType>()->device();
23+
return value->type()->expectRef<TensorType>().device();
2424
}
2525

2626
static c10::optional<c10::Device> getDevice(const Node* node) {

torch/csrc/jit/codegen/fuser/codegen.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ static std::string variableType(const std::shared_ptr<c10::Type>& t) {
9191
return "double";
9292
} else if (t->kind() == TypeKind::BoolType) {
9393
return "bool";
94-
} else if (auto scalar_type = t->expect<TensorType>()->scalarType()) {
94+
} else if (auto scalar_type = t->expectRef<TensorType>().scalarType()) {
9595
return calcScalarTypeName(*scalar_type);
9696
}
9797
// something went wrong with the type analysis during shape propagation
@@ -118,7 +118,7 @@ static std::string typeCastedValueName(
118118
} else if (t->kind() == TypeKind::NoneType) {
119119
// Support None value for optional arguments like memory format
120120
return vn;
121-
} else if (auto scalar_type = t->expect<TensorType>()->scalarType()) {
121+
} else if (auto scalar_type = t->expectRef<TensorType>().scalarType()) {
122122
if (*scalar_type != outtype) {
123123
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
124124
}
@@ -261,7 +261,7 @@ static std::string encodeRHS(const Node* n) {
261261
} else {
262262
size_t i = 0;
263263

264-
auto outtype = n->output()->type()->expect<TensorType>()->scalarType();
264+
auto outtype = n->output()->type()->expectRef<TensorType>().scalarType();
265265
TORCH_INTERNAL_ASSERT(outtype);
266266

267267
for (auto in : n->inputs()) {

torch/csrc/jit/codegen/fuser/compiler.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ std::shared_ptr<FusedKernel> compileKernel(
260260
sizes.at(o->node()->i(attr::dim)) *= o->node()->inputs().size();
261261
}
262262

263-
auto scalar_type = o->type()->expect<TensorType>()->scalarType();
263+
auto scalar_type = o->type()->expectRef<TensorType>().scalarType();
264264
TORCH_INTERNAL_ASSERT(scalar_type);
265265
auto type = TensorType::createContiguous(*scalar_type, device, sizes);
266266
output_desc.emplace_back(type);

torch/csrc/jit/frontend/ir_emitter.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -2976,7 +2976,7 @@ struct to_ir {
29762976
return std::make_shared<SimpleValue>(
29772977
graph
29782978
->insertNode(graph->createList(
2979-
type->expect<ListType>()->getElementType(), {}))
2979+
type->expectRef<ListType>().getElementType(), {}))
29802980
->output());
29812981
}
29822982
// list(iter) desugars to [_elem for _elem in iter]
@@ -3376,7 +3376,7 @@ struct to_ir {
33763376
TypePtr elem_type = TensorType::get();
33773377
if (type_hint) {
33783378
if (type_hint->kind() == TypeKind::ListType) {
3379-
elem_type = type_hint->expect<ListType>()->getElementType();
3379+
elem_type = type_hint->expectRef<ListType>().getElementType();
33803380
} else {
33813381
// If the type hint was not a List[T] throw an error
33823382
throw ErrorReport(tree)

torch/csrc/jit/frontend/schema_matching.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ Value* tryConvertToType(
7272
if (convertibleToList(value->type(), unwrapOptional(concrete_type))) {
7373
auto unpacked = createTupleUnpack(value);
7474
auto elem_type =
75-
unwrapOptional(concrete_type)->expect<ListType>()->getElementType();
75+
unwrapOptional(concrete_type)->expectRef<ListType>().getElementType();
7676
value = graph.insertNode(graph.createList(elem_type, unpacked))->output();
7777
}
7878

@@ -340,8 +340,9 @@ static c10::optional<MatchedSchema> tryMatchSchema(
340340
// The actual cannot already be a list
341341
if (actual_type->kind() != TypeKind::ListType &&
342342
!convertibleToList(actual_type, unwrapOptional(arg.type()))) {
343-
auto formal_type =
344-
unwrapOptional(arg.type())->expect<ListType>()->getElementType();
343+
auto formal_type = unwrapOptional(arg.type())
344+
->expectRef<ListType>()
345+
.getElementType();
345346

346347
Value* list = tryCreateList(
347348
formal_type,

torch/csrc/jit/ir/alias_analysis.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class MutableTypePtrHelper {
6363
}
6464
case TypeKind::TupleType: {
6565
std::vector<TypePtr> mutable_types;
66-
for (const auto& elem : type->expect<TupleType>()->elements()) {
66+
for (const auto& elem : type->expectRef<TupleType>().elements()) {
6767
if (auto mut_elem = getMutableType(elem)) {
6868
mutable_types.push_back(*mut_elem);
6969
}
@@ -511,7 +511,7 @@ void AliasDb::analyzeImpl(Node* node) {
511511
case prim::GetAttr:
512512
if (isFrozen_ && node->kind() == prim::GetAttr) {
513513
auto& ty = node->input()->type();
514-
if (ty->expect<ClassType>()->is_module()) {
514+
if (ty->expectRef<ClassType>().is_module()) {
515515
return analyzeCreator(node);
516516
}
517517
}

torch/csrc/jit/passes/clear_undefinedness.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ void clearUndefinedness(Value* o) {
1010
o->setType(TensorType::get());
1111
} else if (
1212
o->type()->kind() == ListType::Kind &&
13-
o->type()->expect<ListType>()->getElementType()->kind() ==
13+
o->type()->expectRef<ListType>().getElementType()->kind() ==
1414
TensorType::Kind) {
1515
o->setType(ListType::create(TensorType::get()));
1616
}

torch/csrc/jit/passes/decompose_ops.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ bool isDecomposableNorm(Node* normalize_op) {
3939
if (!input->type()->isSubtypeOf(TensorType::get())) {
4040
return false;
4141
}
42-
auto device = input->type()->expect<TensorType>()->device();
42+
auto device = input->type()->expectRef<TensorType>().device();
4343
// As of now, we do the decomposition for batchnorm/layernorm on GPU device
4444
// only
4545
if (!device || (*device).is_cpu()) {

torch/csrc/jit/passes/freeze_module.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class AttributePropagator {
156156
Module& attrModule,
157157
std::shared_ptr<Graph>& graph) {
158158
if (!input->type()->cast<InterfaceType>() &&
159-
!input->type()->expect<ClassType>()->is_module()) {
159+
!input->type()->expectRef<ClassType>().is_module()) {
160160
return false;
161161
}
162162

@@ -425,7 +425,7 @@ class AttributePropagator {
425425
if (!findConstantAttr(input, name, attrModule, graph)) {
426426
GRAPH_DEBUG(
427427
input->type()->cast<InterfaceType>() ||
428-
input->type()->expect<ClassType>()->is_module()
428+
input->type()->expectRef<ClassType>().is_module()
429429
? "attribute: " + name + " is mutable."
430430
: "");
431431
continue;

torch/csrc/jit/passes/graph_fuser.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ struct GraphFuser {
177177
if (!v->type()->isSubtypeOf(TensorType::get())) {
178178
return true;
179179
}
180-
auto device = v->type()->expect<TensorType>()->device();
180+
auto device = v->type()->expectRef<TensorType>().device();
181181
if (!device) {
182182
return !strict_fuser_check;
183183
}

torch/csrc/jit/passes/graph_rewrite_helper.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace jit {
99
namespace graph_rewrite_helper {
1010

1111
std::string getFuncName(Value* func_value) {
12-
auto func = func_value->type()->expect<FunctionType>()->function();
12+
auto func = func_value->type()->expectRef<FunctionType>().function();
1313
const auto& qname = func->qualname();
1414
const auto& name = qname.qualifiedName();
1515
auto rdot_idx = name.rfind('.');

torch/csrc/jit/passes/guard_elimination.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ struct GuardElimination {
242242
size_t i = 0;
243243
for (auto input : n->inputs()) {
244244
if ((input->node()->kind() == prim::Guard &&
245-
!input->type()->expect<TensorType>()->isSummarized()) ||
245+
!input->type()->expectRef<TensorType>().isSummarized()) ||
246246
input->node()->kind() == prim::Constant ||
247247
(allow_numbers && input->type()->isSubtypeOf(NumberType::get())) ||
248248
except.count(i) != 0) {
@@ -377,7 +377,7 @@ struct GuardElimination {
377377
case aten::conv3d:
378378
return checkInputs(n, std::unordered_set<size_t>{2, 6}, false);
379379
case aten::slice:
380-
return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
380+
return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
381381
// check that the dimension argument is constant
382382
n->input(1)->node()->kind() == prim::Constant &&
383383
// the start offset is constant
@@ -389,7 +389,7 @@ struct GuardElimination {
389389
case aten::max_pool1d:
390390
case aten::max_pool2d:
391391
case aten::max_pool3d:
392-
return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
392+
return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
393393
// check that the kernel size is constant
394394
n->input(1)->node()->kind() == prim::Constant &&
395395
// check that the stride is constant
@@ -402,7 +402,7 @@ struct GuardElimination {
402402
n->input(5)->node()->kind() == prim::Constant;
403403
case aten::unsqueeze:
404404
// check that the dimension argument is constant
405-
return !n->input(0)->type()->expect<TensorType>()->isSummarized() &&
405+
return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
406406
n->input(1)->node()->kind() == prim::Constant;
407407
case aten::cat:
408408
// check that the dimension argument is constant
@@ -427,8 +427,8 @@ struct GuardElimination {
427427
// aten::size is effectively a constant
428428
if (asize->input()
429429
->type()
430-
->expect<TensorType>()
431-
->sizes()
430+
->expectRef<TensorType>()
431+
.sizes()
432432
.concrete_sizes()) {
433433
return true;
434434
}

torch/csrc/jit/passes/onnx/peephole.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,14 @@ void fuseBroadcast(Block* b) {
138138
// Not all broadcasts are supported by ONNX broadcast.
139139
c10::optional<size_t> axis = fusibleExpandTo(
140140
unexpanded_input->type()
141-
->expect<TensorType>()
142-
->sizes()
141+
->expectRef<TensorType>()
142+
.sizes()
143143
.concrete_sizes()
144144
.value(), // from
145145
n->output()
146146
->type()
147-
->expect<TensorType>()
148-
->sizes()
147+
->expectRef<TensorType>()
148+
.sizes()
149149
.concrete_sizes()
150150
.value()); // to
151151
if (axis == c10::nullopt)

torch/csrc/jit/passes/peephole.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -425,11 +425,11 @@ void FuseAddMM(Block* block) {
425425

426426
// Attempts to find a matrix with a defined scalar type to type as
427427
auto* type_as_mat = mat1;
428-
if (!type_as_mat->type()->expect<TensorType>()->scalarType()) {
428+
if (!type_as_mat->type()->expectRef<TensorType>().scalarType()) {
429429
type_as_mat = mat2;
430430
}
431431
auto mat_scalar_type =
432-
type_as_mat->type()->expect<TensorType>()->scalarType();
432+
type_as_mat->type()->expectRef<TensorType>().scalarType();
433433

434434
// we can't use type_as if we don't know the target type (mm), the
435435
// bias needs to be coerced to

torch/csrc/jit/passes/quantization/helper.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ bool useQuantizable(const Use& use, QuantType quant_type) {
532532

533533
std::shared_ptr<Graph> getCallFunctionGraph(Node* n) {
534534
auto* func_node = n->input(0)->node();
535-
auto func = func_node->output()->type()->expect<FunctionType>()->function();
535+
auto func = func_node->output()->type()->expectRef<FunctionType>().function();
536536
TORCH_CHECK(
537537
func->isGraphFunction(), "Quantization only works for graph function");
538538
return func->graph();

torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ void UpdateDifferentiableGraphRequiresGrad(
2020
n->ty_(
2121
attr::profiled_type,
2222
n->ty(attr::profiled_type)
23-
->expect<TensorType>()
24-
->withRequiresGrad(new_requires_grad));
23+
->expectRef<TensorType>()
24+
.withRequiresGrad(new_requires_grad));
2525
}
2626
for (Block* b : n->blocks()) {
2727
UpdateDifferentiableGraphRequiresGrad(b, new_requires_grad);

torch/csrc/jit/python/pybind_utils.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
7878
return static_cast<int64_t>(stream->cdata);
7979
}
8080
case TypeKind::ListType: {
81-
const auto& elem_type = type->expect<ListType>()->getElementType();
81+
const auto& elem_type = type->expectRef<ListType>().getElementType();
8282
switch (elem_type->kind()) {
8383
// allows single int/float to be broadcasted to a fixed size list
8484
case TypeKind::IntType:
@@ -127,7 +127,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional<int32_t> N) {
127127
// return an IValue() to denote a NoneType
128128
return {};
129129
}
130-
return toIValue(obj, type->expect<OptionalType>()->getElementType());
130+
return toIValue(obj, type->expectRef<OptionalType>().getElementType());
131131
}
132132
case TypeKind::ClassType: {
133133
auto classType = type->expect<ClassType>();

torch/csrc/jit/python/python_custom_class.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ void initPythonCustomClassBindings(PyObject* module) {
3131
py::class_<ScriptClass>(m, "ScriptClass")
3232
.def("__call__", &ScriptClass::__call__)
3333
.def_property_readonly("__doc__", [](const ScriptClass& self) {
34-
return self.class_type_.type_->expect<ClassType>()->doc_string();
34+
return self.class_type_.type_->expectRef<ClassType>().doc_string();
3535
});
3636

3737
// This function returns a ScriptClass that wraps the constructor

0 commit comments

Comments
 (0)