@@ -26,9 +26,10 @@ namespace refactor::kernel {
             // !a.dataType.isFloat() ||
             !ARTHIMETIC.contains(op) ||
             // At least one of a,b should have the same shape as c
-            (a.shape != c.shape && b.shape != c.shape) ||
+            (a.shape != c.shape && b.shape != c.shape)
             // Sub only supports brocasting b
-            (a.shape != c.shape && op == Op::Sub)) {
+            // (a.shape != c.shape && op == Op::Sub)
+        ) {
             return nullptr;
         }
 
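The hunk above relaxes the kernel's admission check: the builder still rejects ops outside the ARTHIMETIC set, but the Sub-specific clause that forbade broadcasting a is commented out, so either operand may now broadcast as long as at least one of a and b matches c's shape. Below is a minimal, self-contained sketch of the resulting predicate; Op, Tensor, and the ARTHIMETIC set here are simplified stand-ins for the repo's real declarations, not its actual types.

#include <set>
#include <vector>

// Simplified stand-ins for the repo's types (hypothetical).
enum class Op { Add, Sub, Mul, Div };
struct Tensor { std::vector<int> shape; };
static const std::set<Op> ARTHIMETIC{Op::Add, Op::Sub, Op::Mul, Op::Div};

// True when the builder should bail out (return nullptr): the op is not
// arithmetic, or neither input shape matches the output shape.
bool shouldReject(Op op, Tensor const &a, Tensor const &b, Tensor const &c) {
    return !ARTHIMETIC.contains(op) ||
           // At least one of a, b should have the same shape as c
           (a.shape != c.shape && b.shape != c.shape);
}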
@@ -122,18 +123,13 @@ namespace refactor::kernel {
 
         auto handle = res.fetchOrStore<CnnlContext>()->handle;
         size_t workspaceSize;
-        if (aDims != cDims) {
-            CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->bDesc,
-                                                   d->aDesc, d->cDesc,
-                                                   &workspaceSize));
-        } else {
-            CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->aDesc,
+        CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->aDesc,
                                                d->bDesc, d->cDesc,
                                                &workspaceSize));
-        }
+
 
         res.fetchOrStore<CnnlContext>();
-        auto routine = [swap = aDims != cDims, d,
+        auto routine = [d = std::move(d),
                         workspaceSize, cnnlLogicOP,
                         op = this->opType](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
             auto handle = res.fetchOrStore<CnnlContext>()->handle;
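Two things change in this hunk. First, with operand swapping gone, the workspace query always passes the descriptors in (a, b, c) order. Second, the routine lambda takes ownership of the descriptor bundle with an init-capture, d = std::move(d), instead of copying d alongside a swap flag. A compilable sketch of that capture pattern follows, with Descriptors and Routine as illustrative stand-ins for the repo's types.

#include <cstddef>
#include <functional>
#include <memory>

struct Descriptors { /* cnnl tensor/op descriptors ... */ };

// Stand-in for the repo's routine type (hypothetical signature).
using Routine = std::function<void(void *workspace, size_t workspaceSize)>;

Routine makeRoutine(std::shared_ptr<Descriptors> d) {
    // Init-capture by move: the lambda owns the bundle for its whole
    // lifetime, so nothing dangles once the builder's scope ends, and no
    // extra reference-count bump is paid on construction.
    return [d = std::move(d)](void *workspace, size_t workspaceSize) {
        // ... launch cnnlOpTensor / cnnlLogicOp with d's descriptors ...
    };
}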
@@ -151,20 +147,11 @@ namespace refactor::kernel {
                 beta = d->f32
                            ? factor<fp32_t>(0)
                            : factor<fp64_t>(0);
-
-                if (swap) {
-                    CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
-                                             &alphaB, d->bDesc, b,
-                                             &alphaA, d->aDesc, a,
-                                             workspace, workspaceSize,
-                                             &beta, d->cDesc, c));
-                } else {
                 CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
                                          &alphaA, d->aDesc, a,
                                          &alphaB, d->bDesc, b,
                                          workspace, workspaceSize,
                                          &beta, d->cDesc, c));
-                }
             } else if (op == SimpleBinaryType::Div) {
                 CNNL_ASSERT(cnnlDiv_v2(handle,
                                        CNNL_COMPUTATION_HIGH_PRECISION,
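The routine body is simplified the same way: the swap branch that reordered (b, a) for broadcast cases is gone, and every arithmetic op issues one cnnlOpTensor call in (a, b, c) order. By analogy with cuDNN's cudnnOpTensor, the call computes roughly c = op(alphaA * a, alphaB * b) + beta * c, which is why the code above zeroes beta; treat that formula as an assumption and defer to the CNNL docs. A hedged sketch of the call pattern, using plain float scalars in place of the repo's dtype-dispatched factor<>() helpers and CNNL_ASSERT as the repo's status-check macro:

#include <cnnl.h>

// Sketch only: argument order mirrors the call in the diff above; the
// c = op(alphaA*a, alphaB*b) + beta*c semantics are assumed, not verified.
void runOpTensor(cnnlHandle_t handle, cnnlOpTensorDescriptor_t opDesc,
                 cnnlTensorDescriptor_t aDesc, void const *a,
                 cnnlTensorDescriptor_t bDesc, void const *b,
                 cnnlTensorDescriptor_t cDesc, void *c,
                 void *workspace, size_t workspaceSize) {
    float alphaA = 1.f,  // scale for operand a
          alphaB = 1.f,  // scale for operand b
          beta = 0.f;    // 0 discards c's previous contents
    CNNL_ASSERT(cnnlOpTensor(handle, opDesc,
                             &alphaA, aDesc, a,
                             &alphaB, bDesc, b,
                             workspace, workspaceSize,
                             &beta, cDesc, c));
}

Keeping the operand order fixed and expressing the operation through the op descriptor, rather than reordering a and b, appears to be what lets the guard in the first hunk drop its Sub-only broadcast restriction.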