forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTestOps.cpp
128 lines (108 loc) · 3.95 KB
/
TestOps.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// Copyright 2004-present Facebook. All Rights Reserved.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/FunctionalInverses.h>
#include <ATen/ScalarOps.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_test_ambiguous_defaults_native.h>
#include <ATen/ops/_test_autograd_multiple_dispatch_native.h>
#include <ATen/ops/_test_autograd_multiple_dispatch_view_native.h>
#include <ATen/ops/_test_check_tensor_native.h>
#include <ATen/ops/_test_optional_filled_intlist_native.h>
#include <ATen/ops/_test_optional_floatlist_native.h>
#include <ATen/ops/_test_optional_intlist_native.h>
#include <ATen/ops/_test_string_default_native.h>
#include <ATen/ops/_test_warn_in_autograd_native.h>
#include <ATen/ops/empty_like.h>
#endif
#include <c10/util/irange.h>
namespace at {
namespace native {
/// If addends is nullopt, return values.
/// Else, return a new tensor containing the elementwise sums.
///
/// @param values  1-D tensor; must hold `int` elements (enforced by the
///                `accessor<int,1>()` call below, which checks the dtype).
/// @param addends optional list of integers, one per element of `values`.
/// @return `values` unchanged when `addends` is nullopt, otherwise a new
///         tensor with `out[i] = values[i] + addends[i]`.
Tensor _test_optional_intlist(
    const Tensor& values,
    at::OptionalIntArrayRef addends) {
  if (!addends) {
    return values;
  }
  TORCH_CHECK(values.dim() == 1, "values must be 1-D, got dim ", values.dim());
  // Fail upfront with a clear message instead of tripping the generic
  // bounds check inside the loop when the lengths disagree.
  TORCH_CHECK(
      static_cast<int64_t>(addends->size()) == values.size(0),
      "addends must have the same length as values");
  Tensor output = at::empty_like(values);
  auto inp = values.accessor<int, 1>();
  auto out = output.accessor<int, 1>();
  for (const auto i : c10::irange(values.size(0))) {
    // Length verified above, so unchecked indexing is safe here.
    out[i] = inp[i] + (*addends)[i];
  }
  return output;
}
/// If addends is nullopt, return values.
/// Else, return a new tensor containing the elementwise sums.
///
/// @param values  1-D tensor; must hold `float` elements (enforced by the
///                `accessor<float,1>()` call below, which checks the dtype).
/// @param addends optional list of doubles, one per element of `values`.
/// @return `values` unchanged when `addends` is nullopt, otherwise a new
///         tensor with `out[i] = values[i] + addends[i]`.
Tensor _test_optional_floatlist(
    const Tensor& values,
    c10::optional<ArrayRef<double>> addends) {
  if (!addends) {
    return values;
  }
  TORCH_CHECK(values.dim() == 1, "values must be 1-D, got dim ", values.dim());
  // Fail upfront with a clear message instead of tripping the generic
  // bounds check inside the loop when the lengths disagree.
  TORCH_CHECK(
      static_cast<int64_t>(addends->size()) == values.size(0),
      "addends must have the same length as values");
  Tensor output = at::empty_like(values);
  auto inp = values.accessor<float, 1>();
  auto out = output.accessor<float, 1>();
  for (const auto i : c10::irange(values.size(0))) {
    // Length verified above, so unchecked indexing is safe here.
    out[i] = inp[i] + (*addends)[i];
  }
  return output;
}
// Verifies that escape sequences in schema default strings survive codegen
// intact (commas inside defaults are still known to be broken).
//
// Both `a` and `b` are expected to arrive as the three characters  " ' \
// which the schema spells with escapes. `dummy` is passed through untouched.
Tensor _test_string_default(const Tensor& dummy, c10::string_view a, c10::string_view b) {
  const c10::string_view kExpected = "\"'\\";
  TORCH_CHECK(a == kExpected, "Default A failed");
  TORCH_CHECK(b == kExpected, "Default B failed");
  return dummy;
}
// Test that overloads with ambiguity created by defaulted parameters work.
// The operator declared first should have priority always
// Overload a
// Expects both defaults to resolve to 1 (this overload's schema defaults);
// returns a scalar tensor holding 1 so callers can tell which overload ran.
// NOTE: TORCH_CHECK without a message embeds the condition text in the
// error, so the exact spelling of `a == 1` / `b == 1` is load-bearing.
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, int64_t b) {
  TORCH_CHECK(a == 1);
  TORCH_CHECK(b == 1);
  return c10::scalar_to_tensor(1);
}
// Overload b
// Expects the defaults 2 / "2" (this overload's schema defaults); returns a
// scalar tensor holding 2 so callers can tell which overload ran.
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, c10::string_view b) {
  TORCH_CHECK(a == 2);
  TORCH_CHECK(b == "2");
  return c10::scalar_to_tensor(2);
}
// Forward pass is just a clone; the warning this op tests for presumably
// comes from its autograd registration elsewhere — TODO confirm against
// the derivative registration for this op.
Tensor _test_warn_in_autograd(const Tensor &self) {
  return self.clone();
}
// Test registration of per-dispatch-key derivatives in derivatives.yaml.
// See derivatives.yaml for dummy registrations.
// Forward pass is intentionally trivial (clone); the behavior under test
// lives entirely in the derivative registrations.
Tensor _test_autograd_multiple_dispatch_fullcoverage(const Tensor &self) {
  return self.clone();
}
// Forward pass is a plain clone; `b` is unused here and presumably only
// influences dispatch/derivative selection — confirm against the schema
// and derivatives.yaml.
Tensor _test_autograd_multiple_dispatch_ntonly(const Tensor &self, bool b) {
  return self.clone();
}
// Test derivative dispatch registration for view_copy ops
// Returns a flattened (1-D) view sharing storage with `self`; being a view
// op is what exercises the view_copy derivative machinery.
Tensor _test_autograd_multiple_dispatch_view(const Tensor &self) {
  return self.view(-1);
}
// Exercises TORCH_CHECK_TENSOR_ALL: raises with the given message unless
// the check on `self` passes, otherwise returns a clone of `self`.
Tensor _test_check_tensor(const Tensor& self) {
  TORCH_CHECK_TENSOR_ALL(self, "Test message for TORCH_CHECK_TENSOR_ALL");
  return self.clone();
}
} // namespace native
namespace functionalization {
// view_copy ops must have a functional inverse registered
// This stub exists only to satisfy that requirement for the test view op
// above; it must never actually run, hence the unconditional assert.
// The unused parameters are dictated by the required inverse signature.
Tensor FunctionalInverses::_test_autograd_multiple_dispatch_view_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) {
  TORCH_INTERNAL_ASSERT(false,
    "Attempted to call _test_autograd_multiple_dispatch_view_copy_inverse() during the functionalization pass. ",
    "This function is for testing only and should never be called.");
  // Unreachable; present only so the function is well-formed.
  return Tensor();
}
} // namespace functionalization
} // namespace at