@@ -16,13 +16,24 @@ namespace refactor::kernel {
1616#ifndef USE_BANG
1717 return nullptr ;
1818#endif
19- return std::make_unique<K>(decltype (info) {
20- inputs[1 ].get ().dataType ,
21- inputs[0 ].get ().shape ,
22- inputs[1 ].get ().shape ,
23- inputs[2 ].get ().shape ,
24- outputs[0 ].get ().shape ,
25- });
19+ std::vector<int > cDim (inputs[0 ].get ().shape .begin (), inputs[0 ].get ().shape .end ()),
20+ xDim (inputs[1 ].get ().shape .begin (), inputs[1 ].get ().shape .end ()),
21+ yDim (inputs[2 ].get ().shape .begin (), inputs[2 ].get ().shape .end ()),
22+ ansDim (outputs[0 ].get ().shape .begin (), outputs[0 ].get ().shape .end ());
23+ if (ansDim.size () == 0 ) {
24+ ansDim.push_back (1 );
25+ }
26+ if (xDim.size () == 0 ) {
27+ xDim.push_back (1 );
28+ }
29+ if (yDim.size () == 0 ) {
30+ yDim.push_back (1 );
31+ }
32+ if (cDim.size () == 0 ) {
33+ cDim.push_back (1 );
34+ }
35+ return std::make_unique<K>(decltype (info){
36+ inputs[1 ].get ().dataType , cDim, xDim, yDim, ansDim});
2637 }
2738 auto K::typeId () noexcept -> size_t {
2839 static uint8_t ID = 1 ;
@@ -44,11 +55,10 @@ namespace refactor::kernel {
4455
4556 struct Descriptors {
4657 cnnlTensorDescriptor_t cond, x, y, ans;
47- bool f32 ;
4858
49- explicit Descriptors (decltype ( f32 ) f32_ )
59+ explicit Descriptors ()
5060 : cond(nullptr ), x(nullptr ), y(nullptr ),
51- ans(nullptr ), f32(f32_) {
61+ ans(nullptr ) {
5262 CNNL_ASSERT (cnnlCreateTensorDescriptor (&cond));
5363 CNNL_ASSERT (cnnlCreateTensorDescriptor (&x));
5464 CNNL_ASSERT (cnnlCreateTensorDescriptor (&y));
@@ -64,49 +74,35 @@ namespace refactor::kernel {
6474 Descriptors (const Descriptors &) = delete ;
6575 Descriptors (Descriptors &&) = delete ;
6676 };
67- auto d = std::make_shared<Descriptors>(info.dataType != DT::F64);
68-
69- std::vector<int > cDim (info.condDim .begin (), info.condDim .end ()),
70- xDim (info.thenDim .begin (), info.thenDim .end ()),
71- yDim (info.elseDim .begin (), info.elseDim .end ()),
72- ansDim (info.outputDim .begin (), info.outputDim .end ());
73-
74- auto rightAlign = [](std::vector<int > &dim, uint32_t targetLength) {
75- if (dim.size () < targetLength) {
76- dim.insert (dim.begin (), targetLength - dim.size (), 1 );
77- }
78- };
79- if (ansDim.size () == 0 ) {
80- ansDim.push_back (1 );
81- }
82- rightAlign (cDim, ansDim.size ());
83- rightAlign (xDim, ansDim.size ());
84- rightAlign (yDim, ansDim.size ());
85-
86- CNNL_ASSERT (cnnlSetTensorDescriptor (d->cond , CNNL_LAYOUT_NCHW, cnnlDataTypeConvert (DT::Bool), cDim.size (), cDim.data ()));
87- CNNL_ASSERT (cnnlSetTensorDescriptor (d->x , CNNL_LAYOUT_NCHW, cnnlDataTypeConvert (info.dataType ), xDim.size (), xDim.data ()));
88- CNNL_ASSERT (cnnlSetTensorDescriptor (d->y , CNNL_LAYOUT_NCHW, cnnlDataTypeConvert (info.dataType ), yDim.size (), yDim.data ()));
89- CNNL_ASSERT (cnnlSetTensorDescriptor (d->ans , CNNL_LAYOUT_NCHW, cnnlDataTypeConvert (info.dataType ), ansDim.size (), ansDim.data ()));
77+ auto d = std::make_shared<Descriptors>();
78+
79+ CNNL_ASSERT (cnnlSetTensorDescriptor (
80+ d->cond , CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert (DT::Bool),
81+ info.condDim .size (), info.condDim .data ()));
82+ CNNL_ASSERT (cnnlSetTensorDescriptor (
83+ d->x , CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert (info.dataType ),
84+ info.thenDim .size (), info.thenDim .data ()));
85+ CNNL_ASSERT (cnnlSetTensorDescriptor (
86+ d->y , CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert (info.dataType ),
87+ info.elseDim .size (), info.elseDim .data ()));
88+ CNNL_ASSERT (cnnlSetTensorDescriptor (
89+ d->ans , CNNL_LAYOUT_ARRAY, cnnlDataTypeConvert (info.dataType ),
90+ info.outputDim .size (), info.outputDim .data ()));
9091
9192 auto handle = res.fetchOrStore <CnnlContext>()->handle ;
9293 size_t workspaceSize;
9394 CNNL_ASSERT (cnnlGetSelectV2WorkspaceSize (handle, d->cond , d->x , d->y , &workspaceSize));
9495
9596 res.fetchOrStore <CnnlContext>();
9697 auto routine = [d = std::move (d), workspaceSize](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
97- // fetch cnnl handle from resources
98- auto handle = res.fetchOrStore <CnnlContext>()->handle ;
99- auto cond = inputs[0 ],
100- x = inputs[1 ],
101- y = inputs[2 ];
102- auto ans = outputs[0 ];
10398
10499 CNNL_ASSERT (cnnlSelectV2 (
105- handle, d->cond , cond, d->x , x,
106- d->y , y, workspace, workspaceSize,
107- d->ans , ans));
100+ res.fetchOrStore <CnnlContext>()->handle ,
101+ d->cond , inputs[0 ], d->x , inputs[1 ],
102+ d->y , inputs[2 ], workspace, workspaceSize,
103+ d->ans , outputs[0 ]));
108104
109- cnrtQueueSync ( res.fetchOrStore <CnnlContext>()->queue );
105+ res.fetchOrStore <CnnlContext>()->queueSync ( );
110106 };
111107
112108 return {std::move (routine), workspaceSize};
0 commit comments