@@ -28,17 +28,27 @@ namespace tensorflow {
2828
2929typedef Eigen::ThreadPoolDevice CPUDevice;
3030typedef Eigen::GpuDevice GPUDevice;
31+ #ifdef TENSORFLOW_USE_SYCL
32+ typedef Eigen::SyclDevice SYCLDevice;
33+ #endif // TENSORFLOW_USE_SYCL
3134
3235// Partial specialization for a CPUDevice, that uses the Eigen implementation
3336// from SoftmaxEigenImpl.
3437namespace functor {
35- template <typename T>
36- struct SoftmaxFunctor <CPUDevice, T> {
37- void operator ()(const CPUDevice & d, typename TTypes<T>::ConstMatrix logits,
38+ template <typename Device, typename T>
39+ struct SoftmaxFunctorBase {
40+ void operator ()(const Device & d, typename TTypes<T>::ConstMatrix logits,
3841 typename TTypes<T>::Matrix softmax, const bool log) {
39- SoftmaxEigenImpl<CPUDevice , T>::Compute (d, logits, softmax, log);
42+ SoftmaxEigenImpl<Device , T>::Compute (d, logits, softmax, log);
4043 }
4144};
45+ template <typename T>
46+ struct SoftmaxFunctor <CPUDevice, T> : SoftmaxFunctorBase<CPUDevice, T> {};
47+
48+ #ifdef TENSORFLOW_USE_SYCL
49+ template <typename T>
50+ struct SoftmaxFunctor <SYCLDevice, T> : SoftmaxFunctorBase<SYCLDevice, T> {};
51+ #endif // TENSORFLOW_USE_SYCL
4252} // namespace functor
4353
4454#define REGISTER_CPU (T ) \
@@ -76,4 +86,10 @@ REGISTER_KERNEL_BUILDER(
7686 SoftmaxOp<GPUDevice, float>);
7787#endif // GOOGLE_CUDA
7888
89+ #ifdef TENSORFLOW_USE_SYCL
90+ REGISTER_KERNEL_BUILDER (
91+ Name (" Softmax" ).Device(DEVICE_SYCL).TypeConstraint<float>(" T" ),
92+ SoftmaxOp<SYCLDevice, float>);
93+ #endif // TENSORFLOW_USE_SYCL
94+
7995} // namespace tensorflow
0 commit comments