@@ -1340,7 +1340,49 @@ def sort(x, axis=-1):
1340
1340
1341
1341
1342
1342
def split (x , indices_or_sections , axis = 0 ):
1343
- raise NotImplementedError ("`split` is not supported with openvino backend" )
1343
+ x = get_ov_output (x )
1344
+ axis_tensor = ov_opset .constant (axis , dtype = Type .i32 ).output (0 )
1345
+
1346
+ shape_tensor = ov_opset .shape_of (x )
1347
+ axis_i32 = ov_opset .constant ([axis ], dtype = Type .i32 )
1348
+ dim_at_axis_tensor = ov_opset .gather (
1349
+ shape_tensor , axis_i32 , ov_opset .constant (0 , dtype = Type .i32 )
1350
+ )
1351
+
1352
+ if isinstance (indices_or_sections , int ):
1353
+ num_splits = indices_or_sections
1354
+ splits = ov_opset .split (x , axis_tensor , num_splits = num_splits )
1355
+ result = []
1356
+ for i in range (num_splits ):
1357
+ result .append (OpenVINOKerasTensor (splits .output (i )))
1358
+ return result
1359
+
1360
+ if isinstance (indices_or_sections , (list , tuple , np .ndarray )):
1361
+ indices = list (indices_or_sections )
1362
+ split_lengths = []
1363
+ split_lengths .append (indices [0 ])
1364
+ for i in range (1 , len (indices )):
1365
+ split_lengths .append (indices [i ] - indices [i - 1 ])
1366
+
1367
+ last_index_tensor = ov_opset .constant (indices [- 1 ], dtype = Type .i64 )
1368
+ remaining_length_tensor = ov_opset .subtract (
1369
+ dim_at_axis_tensor , last_index_tensor
1370
+ )
1371
+
1372
+ length_parts = []
1373
+ length_parts .append (ov_opset .constant (split_lengths , dtype = Type .i64 ))
1374
+ length_parts .append (remaining_length_tensor )
1375
+ length_tensor = ov_opset .concat (length_parts , axis = 0 )
1376
+
1377
+ splits = ov_opset .variadic_split (x , axis_tensor , length_tensor )
1378
+ result = []
1379
+ for i in range (len (split_lengths ) + 1 ):
1380
+ result .append (OpenVINOKerasTensor (splits .output (i )))
1381
+ return result
1382
+
1383
+ raise TypeError (
1384
+ f"unsupported type of indices_or_sections: { type (indices_or_sections )} "
1385
+ )
1344
1386
1345
1387
1346
1388
def stack (x , axis = 0 ):
0 commit comments