 // --- Constant expressions
 // =============================================================================

-constexpr auto get_dtype(torch_data_t dtype) {
+// Mapping from FTorch torch_data_t to libtorch Dtype
+constexpr auto get_libtorch_dtype(torch_data_t dtype) {
   switch (dtype) {
   case torch_kUInt8:
     std::cerr << "[WARNING]: uint8 not supported in Fortran" << std::endl;
@@ -40,7 +41,37 @@ constexpr auto get_dtype(torch_data_t dtype) {
   }
 }

-const auto get_device(torch_device_t device_type, int device_index) {
+// Mapping from libtorch Dtype to FTorch torch_data_t
+torch_data_t get_ftorch_dtype(caffe2::TypeMeta dtype) {
+  if (dtype == torch::kUInt8) {
+    std::cerr << "[WARNING]: uint8 not supported in Fortran" << std::endl;
+    // See https://gcc.gnu.org/onlinedocs/gfortran/ISO_005fFORTRAN_005fENV.html
+    exit(EXIT_FAILURE);
+  } else if (dtype == torch::kInt8) {
+    return torch_kInt8;
+  } else if (dtype == torch::kInt16) {
+    return torch_kInt16;
+  } else if (dtype == torch::kInt32) {
+    return torch_kInt32;
+  } else if (dtype == torch::kInt64) {
+    return torch_kInt64;
+  } else if (dtype == torch::kFloat16) {
+    std::cerr << "[WARNING]: float16 not supported in Fortran" << std::endl;
+    // See https://gcc.gnu.org/onlinedocs/gfortran/ISO_005fFORTRAN_005fENV.html
+    exit(EXIT_FAILURE);
+  } else if (dtype == torch::kFloat32) {
+    return torch_kFloat32;
+  } else if (dtype == torch::kFloat64) {
+    return torch_kFloat64;
+  } else {
+    std::cerr << "[ERROR]: data type " << dtype << " not supported in Fortran"
+              << std::endl;
+    exit(EXIT_FAILURE);
+  }
+}
+
+// Mapping from FTorch device_type_t to libtorch DeviceType
+const auto get_libtorch_device(torch_device_t device_type, int device_index) {
   switch (device_type) {
   case torch_kCPU:
     if (device_index != -1) {
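
The two dtype mappings are intended to be mutual inverses on the types Fortran supports. A minimal sanity check, assuming both functions are in scope in the same translation unit; the check_dtype_round_trip helper is a hypothetical sketch, not part of this change:

#include <cassert>

// Hypothetical sketch: every Fortran-supported dtype should survive a round
// trip through libtorch and back. caffe2::TypeMeta::fromScalarType converts
// the ScalarType returned by get_libtorch_dtype into the TypeMeta that
// get_ftorch_dtype expects.
void check_dtype_round_trip() {
  const torch_data_t supported[] = {torch_kInt8,    torch_kInt16,
                                    torch_kInt32,   torch_kInt64,
                                    torch_kFloat32, torch_kFloat64};
  for (const auto dtype : supported) {
    assert(get_ftorch_dtype(caffe2::TypeMeta::fromScalarType(
               get_libtorch_dtype(dtype))) == dtype);
  }
}
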
@@ -65,6 +96,20 @@ const auto get_device(torch_device_t device_type, int device_index) {
   }
 }

+// Mapping from libtorch DeviceType to FTorch device_type_t
+const torch_device_t get_ftorch_device(torch::DeviceType device_type) {
+  switch (device_type) {
+  case torch::kCPU:
+    return torch_kCPU;
+  case torch::kCUDA:
+    return torch_kCUDA;
+  default:
+    std::cerr << "[ERROR]: device type " << device_type << " not implemented in FTorch"
+              << std::endl;
+    exit(EXIT_FAILURE);
+  }
+}
+
 // =============================================================================
 // --- Functions for constructing tensors
 // =============================================================================
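
The reverse device mapping is deliberately partial: only CPU and CUDA exist on the FTorch side, and anything else exits with an error. A hypothetical spot check, assuming (as the call sites below suggest) that get_libtorch_device returns a torch::Device:

#include <cassert>

// Hypothetical sketch: the two recognised device types round-trip;
// -1 is the conventional "no index" value for CPU in this API.
void check_device_round_trip() {
  assert(get_ftorch_device(torch::kCPU) == torch_kCPU);
  assert(get_ftorch_device(torch::kCUDA) == torch_kCUDA);
  assert(get_ftorch_device(get_libtorch_device(torch_kCPU, -1).type()) ==
         torch_kCPU);
}
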
@@ -78,8 +123,8 @@ torch_tensor_t torch_empty(int ndim, const int64_t *shape, torch_data_t dtype,
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::empty(vshape, torch::dtype(get_dtype(dtype)))
-                  .to(get_device(device_type, device_index));
+    *tensor = torch::empty(vshape, torch::dtype(get_libtorch_dtype(dtype)))
+                  .to(get_libtorch_device(device_type, device_index));
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;
@@ -101,8 +146,8 @@ torch_tensor_t torch_zeros(int ndim, const int64_t *shape, torch_data_t dtype,
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::zeros(vshape, torch::dtype(get_dtype(dtype)))
-                  .to(get_device(device_type, device_index));
+    *tensor = torch::zeros(vshape, torch::dtype(get_libtorch_dtype(dtype)))
+                  .to(get_libtorch_device(device_type, device_index));
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;
@@ -124,8 +169,8 @@ torch_tensor_t torch_ones(int ndim, const int64_t *shape, torch_data_t dtype,
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::ones(vshape, torch::dtype(get_dtype(dtype)))
-                  .to(get_device(device_type, device_index));
+    *tensor = torch::ones(vshape, torch::dtype(get_libtorch_dtype(dtype)))
+                  .to(get_libtorch_device(device_type, device_index));
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;
@@ -152,8 +197,9 @@ torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *shape,
     c10::IntArrayRef vshape(shape, ndim);
     c10::IntArrayRef vstrides(strides, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::from_blob(data, vshape, vstrides, torch::dtype(get_dtype(dtype)))
-                  .to(get_device(device_type, device_index));
+    *tensor = torch::from_blob(data, vshape, vstrides,
+                               torch::dtype(get_libtorch_dtype(dtype)))
+                  .to(get_libtorch_device(device_type, device_index));

   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
@@ -214,11 +260,6 @@ void torch_tensor_print(const torch_tensor_t tensor) {
   std::cout << *t << std::endl;
 }

-int torch_tensor_get_device_index(const torch_tensor_t tensor) {
-  auto t = reinterpret_cast<torch::Tensor *>(tensor);
-  return t->device().index();
-}
-
 int torch_tensor_get_rank(const torch_tensor_t tensor) {
   auto t = reinterpret_cast<torch::Tensor *>(tensor);
   return t->sizes().size();
@@ -236,6 +277,21 @@ const long long int *torch_tensor_get_sizes(const torch_tensor_t tensor) {
 }
 #endif

+torch_data_t torch_tensor_get_dtype(const torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
+  return get_ftorch_dtype(t->dtype());
+}
+
+torch_device_t torch_tensor_get_device_type(const torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
+  return get_ftorch_device(t->device().type());
+}
+
+int torch_tensor_get_device_index(const torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
+  return t->device().index();
+}
+
 // =====================================================================================
 // --- Functions for deallocating tensors
 // =====================================================================================
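
Taken together, the three accessors let a caller recover a tensor's metadata through the C API. A hypothetical caller-side sketch; the torch_zeros signature is taken from the hunk header above, with its trailing device arguments assumed to be device_type and device_index:

#include <cstdint>
#include <iostream>

void describe_tensor() {
  const int64_t shape[2] = {3, 4};
  // Assumed trailing arguments: device_type, device_index (-1 for CPU).
  torch_tensor_t t = torch_zeros(2, shape, torch_kFloat32, torch_kCPU, -1);

  // The three accessors added in this change:
  torch_data_t dtype = torch_tensor_get_dtype(t);          // torch_kFloat32
  torch_device_t device = torch_tensor_get_device_type(t); // torch_kCPU
  int index = torch_tensor_get_device_index(t);            // -1 on CPU

  std::cout << "dtype=" << dtype << " device=" << device << " index=" << index
            << std::endl;
  // The tensor should be freed via the deallocation functions below.
}
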
@@ -350,7 +406,8 @@ torch_jit_script_module_t torch_jit_load(const char *filename,
   torch::jit::script::Module *module = nullptr;
   try {
     module = new torch::jit::script::Module;
-    *module = torch::jit::load(filename, get_device(device_type, device_index));
+    *module =
+        torch::jit::load(filename, get_libtorch_device(device_type, device_index));
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete module;