@@ -9,48 +9,44 @@ namespace conversion {
9
9
namespace converters {
10
10
namespace impl {
11
11
12
-
13
- auto bitwisenot TORCHTRT_UNUSED =
14
- RegisterNodeConversionPatterns ()
15
- .pattern({" aten::bitwise_not(Tensor self) -> Tensor" ,
16
- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
17
- auto in = args[0 ].ITensorOrFreeze (ctx);
18
- nvinfer1::ILayer* out;
19
-
20
- if (in->getType () == nvinfer1::DataType::kINT32 ) {
21
- // Integer case, using ~x = -x - 1
22
- auto neg_one = torch::tensor ({-1 }, util::TRTDataTypeToScalarType (in->getType ()));
23
- auto neg_one_const = tensor_to_const (ctx, neg_one);
24
- auto neg = add_elementwise (
25
- ctx,
26
- nvinfer1::ElementWiseOperation::kPROD ,
27
- in,
28
- neg_one_const,
29
- util::node_info (n) + std::string (" _Negation" ));
30
- TORCHTRT_CHECK (neg, " Unable to create prod layer from node: " << *n);
31
- out = add_elementwise (
32
- ctx,
33
- nvinfer1::ElementWiseOperation::kSUM ,
34
- neg->getOutput (0 ),
35
- neg_one_const,
36
- util::node_info (n) + std::string (" _SubOne" ));
37
- TORCHTRT_CHECK (out, " Unable to create sum layer from node: " << *n);
38
- } else if (in->getType () == nvinfer1::DataType::kBOOL ) {
39
- // Boolean case
40
- out = ctx->net ->addUnary (*in, nvinfer1::UnaryOperation::kNOT );
41
- TORCHTRT_CHECK (out, " Unable to create logical not layer from node: " << *n);
42
- } else {
43
- LOG_ERROR (" Input tensor must be 32 bit integer or boolean" );
44
- return false ;
45
- }
46
-
47
- out->setName (util::node_info (n).c_str ());
48
- auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out->getOutput (0 ));
49
- LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
50
-
51
- return true ;
52
- }});
53
-
12
+ auto bitwise_not_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
13
+ {" aten::bitwise_not(Tensor self) -> Tensor" , [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
14
+ auto in = args[0 ].ITensorOrFreeze (ctx);
15
+ nvinfer1::ILayer* out;
16
+
17
+ if (in->getType () == nvinfer1::DataType::kINT32 ) {
18
+ // Integer case, using ~x = -x - 1
19
+ auto neg_one = torch::tensor ({-1 }, util::TRTDataTypeToScalarType (in->getType ()));
20
+ auto neg_one_const = tensor_to_const (ctx, neg_one);
21
+ auto neg = add_elementwise (
22
+ ctx,
23
+ nvinfer1::ElementWiseOperation::kPROD ,
24
+ in,
25
+ neg_one_const,
26
+ util::node_info (n) + std::string (" _Negation" ));
27
+ TORCHTRT_CHECK (neg, " Unable to create prod layer from node: " << *n);
28
+ out = add_elementwise (
29
+ ctx,
30
+ nvinfer1::ElementWiseOperation::kSUM ,
31
+ neg->getOutput (0 ),
32
+ neg_one_const,
33
+ util::node_info (n) + std::string (" _SubOne" ));
34
+ TORCHTRT_CHECK (out, " Unable to create sum layer from node: " << *n);
35
+ } else if (in->getType () == nvinfer1::DataType::kBOOL ) {
36
+ // Boolean case
37
+ out = ctx->net ->addUnary (*in, nvinfer1::UnaryOperation::kNOT );
38
+ TORCHTRT_CHECK (out, " Unable to create logical not layer from node: " << *n);
39
+ } else {
40
+ LOG_ERROR (" Input tensor must be 32 bit integer or boolean" );
41
+ return false ;
42
+ }
43
+
44
+ out->setName (util::node_info (n).c_str ());
45
+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out->getOutput (0 ));
46
+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
47
+
48
+ return true ;
49
+ }});
54
50
55
51
} // namespace impl
56
52
} // namespace converters
0 commit comments