forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUpSample.cpp
37 lines (33 loc) · 1.43 KB
/
UpSample.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// Copyright 2004-present Facebook. All Rights Reserved.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/UpSample.h>
#include <c10/util/irange.h>
#include <c10/util/TypeCast.h>
namespace at {
namespace native {
namespace upsample {
// Resolves the spatial output size of an upsample call from either an explicit
// `output_size` or a list of per-dimension `scale_factors`.
//
// `input_size` is the full input tensor size; every entry after the first two
// is treated as a spatial dimension. Exactly one of `output_size` and
// `scale_factors` must be supplied, and the supplied one must hold one entry
// per spatial dimension; any violation fails a TORCH_CHECK.
//
// Returns a copy of `output_size` when it is given. Otherwise returns, for each
// spatial dimension, the input extent multiplied by the matching scale factor
// and passed through c10::checked_convert<int64_t>.
TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
    c10::IntArrayRef input_size, // Full input tensor size.
    at::OptionalIntArrayRef output_size,
    c10::optional<c10::ArrayRef<double>> scale_factors) {
  // Guard: with neither argument present there is nothing to compute from.
  if (!output_size && !scale_factors) {
    TORCH_CHECK(false, "Must specify exactly one of output_size and scale_factors");
  }

  const auto spatial_dimensions = static_cast<int64_t>(input_size.size()) - 2;

  // Explicit size requested: validate it and hand back a copy.
  if (output_size) {
    TORCH_CHECK(!scale_factors, "Must specify exactly one of output_size and scale_factors");
    TORCH_CHECK(static_cast<int64_t>(output_size->size()) == spatial_dimensions);
    return c10::SmallVector<int64_t, 3>(output_size->begin(), output_size->end());
  }

  // Only `scale_factors` can be set from here on (the guard above and the
  // branch above cover every other combination).
  TORCH_CHECK(static_cast<int64_t>(scale_factors->size()) == spatial_dimensions);
  const c10::ArrayRef<double> scales = *scale_factors;
  c10::SmallVector<int64_t, 3> scaled_sizes;
  for (const auto dim : c10::irange(spatial_dimensions)) {
    // Spatial extents start at index 2 of `input_size`.
    const double scaled_extent = static_cast<double>(input_size[dim + 2]) * scales[dim];
    scaled_sizes.push_back(c10::checked_convert<int64_t>(scaled_extent, "int64_t"));
  }
  return scaled_sizes;
}
} // namespace upsample
} // namespace native
} // namespace at