Commit 515ec24

Add get_dtype and get_device_type methods for torch_tensor (#251)
* Add dtype and device_type attrs for torch_tensor; implement getters
* Rename get_<rank/shape> as torch_tensor_get_<rank/shape> for consistency
* Make torch_tensor_get_device_index a class method
* Add unit test for torch_tensor_get_device_type on CPU
* Add unit test for torch_tensor_get_device_type on CUDA device
* Add unit test for torch_tensor_get_dtype
* Make use of getters for device type and index
* Alias methods to be less verbose
* Implement get_device_type on C++ side; introduce get_ftorch_device
* Implement get_dtype on C++ side; introduce get_ftorch_dtype
* Drop dtype/device type attributes
1 parent 3b67ae8 commit 515ec24

6 files changed (+408 −152 lines)

src/ctorch.cpp (+73 −16)
```diff
@@ -12,7 +12,8 @@
 // --- Constant expressions
 // =============================================================================
 
-constexpr auto get_dtype(torch_data_t dtype) {
+// Mapping from FTorch torch_data_t to libtorch Dtype
+constexpr auto get_libtorch_dtype(torch_data_t dtype) {
   switch (dtype) {
   case torch_kUInt8:
     std::cerr << "[WARNING]: uint8 not supported in Fortran" << std::endl;
```
```diff
@@ -40,7 +41,37 @@ constexpr auto get_dtype(torch_data_t dtype) {
   }
 }
 
-const auto get_device(torch_device_t device_type, int device_index) {
+// Mapping from libtorch Dtype to FTorch torch_data_t
+torch_data_t get_ftorch_dtype(caffe2::TypeMeta dtype) {
+  if (dtype == torch::kUInt8) {
+    std::cerr << "[WARNING]: uint8 not supported in Fortran" << std::endl;
+    // See https://gcc.gnu.org/onlinedocs/gfortran/ISO_005fFORTRAN_005fENV.html
+    exit(EXIT_FAILURE);
+  } else if (dtype == torch::kInt8) {
+    return torch_kInt8;
+  } else if (dtype == torch::kInt16) {
+    return torch_kInt16;
+  } else if (dtype == torch::kInt32) {
+    return torch_kInt32;
+  } else if (dtype == torch::kInt64) {
+    return torch_kInt64;
+  } else if (dtype == torch::kFloat16) {
+    std::cerr << "[WARNING]: float16 not supported in Fortran" << std::endl;
+    // See https://gcc.gnu.org/onlinedocs/gfortran/ISO_005fFORTRAN_005fENV.html
+    exit(EXIT_FAILURE);
+  } else if (dtype == torch::kFloat32) {
+    return torch_kFloat32;
+  } else if (dtype == torch::kFloat64) {
+    return torch_kFloat64;
+  } else {
+    std::cerr << "[ERROR]: data type " << dtype << " not supported in Fortran"
+              << std::endl;
+    exit(EXIT_FAILURE);
+  }
+}
+
+// Mapping from FTorch device_type_t to libtorch DeviceType
+const auto get_libtorch_device(torch_device_t device_type, int device_index) {
   switch (device_type) {
   case torch_kCPU:
     if (device_index != -1) {
```
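
Taken together, get_libtorch_dtype and get_ftorch_dtype are intended to be mutual inverses over the dtypes Fortran supports. A minimal round-trip sketch, not part of the commit, assuming the two static helpers are visible (e.g. compiled into the same translation unit as ctorch.cpp):

```cpp
#include <cassert>

#include <torch/torch.h>

// Round-trip check: an FTorch dtype mapped into libtorch and then recovered
// from a tensor's TypeMeta should come back unchanged.
void check_dtype_round_trip() {
  torch::Tensor t =
      torch::zeros({2, 3}, torch::dtype(get_libtorch_dtype(torch_kFloat32)));
  // t.dtype() returns a caffe2::TypeMeta, matching get_ftorch_dtype's parameter
  assert(get_ftorch_dtype(t.dtype()) == torch_kFloat32);
}
```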
```diff
@@ -65,6 +96,20 @@ const auto get_device(torch_device_t device_type, int device_index) {
   }
 }
 
+// Mapping from libtorch DeviceType to FTorch device_type_t
+const torch_device_t get_ftorch_device(torch::DeviceType device_type) {
+  switch (device_type) {
+  case torch::kCPU:
+    return torch_kCPU;
+  case torch::kCUDA:
+    return torch_kCUDA;
+  default:
+    std::cerr << "[ERROR]: device type " << device_type << " not implemented in FTorch"
+              << std::endl;
+    exit(EXIT_FAILURE);
+  }
+}
+
 // =============================================================================
 // --- Functions for constructing tensors
 // =============================================================================
```
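
In the same way, get_ftorch_device inverts get_libtorch_device for the two device types FTorch currently recognises. A companion sketch, continuing the previous one under the same assumptions:

```cpp
// A CPU tensor's DeviceType should map back to the FTorch torch_kCPU enum.
void check_device_round_trip() {
  torch::Tensor t = torch::zeros({2}, torch::device(torch::kCPU));
  assert(get_ftorch_device(t.device().type()) == torch_kCPU);
}
```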
```diff
@@ -78,8 +123,8 @@ torch_tensor_t torch_empty(int ndim, const int64_t *shape, torch_data_t dtype,
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::empty(vshape, torch::dtype(get_dtype(dtype)))
-                  .to(get_device(device_type, device_index));
+    *tensor = torch::empty(vshape, torch::dtype(get_libtorch_dtype(dtype)))
+                  .to(get_libtorch_device(device_type, device_index));
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;
```
```diff
@@ -101,8 +146,8 @@ torch_tensor_t torch_zeros(int ndim, const int64_t *shape, torch_data_t dtype,
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::zeros(vshape, torch::dtype(get_dtype(dtype)))
-                  .to(get_device(device_type, device_index));
+    *tensor = torch::zeros(vshape, torch::dtype(get_libtorch_dtype(dtype)))
+                  .to(get_libtorch_device(device_type, device_index));
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;
```
```diff
@@ -124,8 +169,8 @@ torch_tensor_t torch_ones(int ndim, const int64_t *shape, torch_data_t dtype,
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::ones(vshape, torch::dtype(get_dtype(dtype)))
-                  .to(get_device(device_type, device_index));
+    *tensor = torch::ones(vshape, torch::dtype(get_libtorch_dtype(dtype)))
+                  .to(get_libtorch_device(device_type, device_index));
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;
```
```diff
@@ -152,8 +197,9 @@ torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *shape,
     c10::IntArrayRef vshape(shape, ndim);
     c10::IntArrayRef vstrides(strides, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::from_blob(data, vshape, vstrides, torch::dtype(get_dtype(dtype)))
-                  .to(get_device(device_type, device_index));
+    *tensor = torch::from_blob(data, vshape, vstrides,
+                               torch::dtype(get_libtorch_dtype(dtype)))
+                  .to(get_libtorch_device(device_type, device_index));
 
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
```
```diff
@@ -214,11 +260,6 @@ void torch_tensor_print(const torch_tensor_t tensor) {
   std::cout << *t << std::endl;
 }
 
-int torch_tensor_get_device_index(const torch_tensor_t tensor) {
-  auto t = reinterpret_cast<torch::Tensor *>(tensor);
-  return t->device().index();
-}
-
 int torch_tensor_get_rank(const torch_tensor_t tensor) {
   auto t = reinterpret_cast<torch::Tensor *>(tensor);
   return t->sizes().size();
```
```diff
@@ -236,6 +277,21 @@ const long long int *torch_tensor_get_sizes(const torch_tensor_t tensor) {
 }
 #endif
 
+torch_data_t torch_tensor_get_dtype(const torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
+  return get_ftorch_dtype(t->dtype());
+}
+
+torch_device_t torch_tensor_get_device_type(const torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
+  return get_ftorch_device(t->device().type());
+}
+
+int torch_tensor_get_device_index(const torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
+  return t->device().index();
+}
+
 // =====================================================================================
 // --- Functions for deallocating tensors
 // =====================================================================================
```
```diff
@@ -350,7 +406,8 @@ torch_jit_script_module_t torch_jit_load(const char *filename,
   torch::jit::script::Module *module = nullptr;
   try {
     module = new torch::jit::script::Module;
-    *module = torch::jit::load(filename, get_device(device_type, device_index));
+    *module =
+        torch::jit::load(filename, get_libtorch_device(device_type, device_index));
   } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete module;
```

src/ctorch.h (+21 −7)
```diff
@@ -120,13 +120,6 @@ EXPORT_C void *torch_to_blob(const torch_tensor_t tensor, const torch_data_t dty
  */
 EXPORT_C void torch_tensor_print(const torch_tensor_t tensor);
 
-/**
- * Function to determine the device index of a Torch Tensor
- * @param Torch Tensor to determine the device index of
- * @return device index of the Torch Tensor
- */
-EXPORT_C int torch_tensor_get_device_index(const torch_tensor_t tensor);
-
 /**
  * Function to determine the rank of a Torch Tensor
  * @param Torch Tensor to determine the rank of
```
```diff
@@ -145,6 +138,27 @@ EXPORT_C const long int *torch_tensor_get_sizes(const torch_tensor_t tensor);
 EXPORT_C const long long int *torch_tensor_get_sizes(const torch_tensor_t tensor);
 #endif
 
+/**
+ * Function to determine the data type of a Torch Tensor
+ * @param Torch Tensor to determine the data type of
+ * @return data type of the Torch Tensor represented as an enum
+ */
+EXPORT_C torch_data_t torch_tensor_get_dtype(const torch_tensor_t tensor);
+
+/**
+ * Function to determine the device type of a Torch Tensor
+ * @param Torch Tensor to determine the device type of
+ * @return device type of the Torch Tensor represented as an enum
+ */
+EXPORT_C torch_device_t torch_tensor_get_device_type(const torch_tensor_t tensor);
+
+/**
+ * Function to determine the device index of a Torch Tensor
+ * @param Torch Tensor to determine the device index of
+ * @return device index of the Torch Tensor
+ */
+EXPORT_C int torch_tensor_get_device_index(const torch_tensor_t tensor);
+
 // =============================================================================
 // --- Functions for deallocating tensors
 // =============================================================================
```
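
With these declarations in place, a C or C++ caller can interrogate a tensor handle entirely through the C API, without touching libtorch types. A caller-side sketch with a hypothetical helper function, assuming only the getters and enum values shown in this diff:

```cpp
#include <iostream>

#include "ctorch.h"

// Hypothetical helper: report what the new getters reveal about an existing
// tensor handle. How the handle was created is out of scope here.
void describe_tensor(const torch_tensor_t t) {
  if (torch_tensor_get_dtype(t) == torch_kFloat32 &&
      torch_tensor_get_device_type(t) == torch_kCPU) {
    std::cout << "float32 CPU tensor of rank " << torch_tensor_get_rank(t)
              << " on device index " << torch_tensor_get_device_index(t)
              << std::endl;
  }
}
```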
