@@ -47,10 +47,10 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
4747 # input sequence should be "data", "starts", "ends", "axes", "steps"
4848 attr = {}
4949 data = self .convert_to_input (kwargs .pop ("data" ))
50- starts = self .convert_to_input (kwargs .pop ("starts" ))
51- ends = self .convert_to_input (kwargs .pop ("ends" ))
52- axes = self .convert_to_input (kwargs .pop ("axes" , None ), is_optional = True )
53- steps = self .convert_to_input (kwargs .pop ("steps" , None ), is_optional = True )
50+ starts = self .convert_to_input (kwargs .pop ("starts" ), dtype = np . int64 )
51+ ends = self .convert_to_input (kwargs .pop ("ends" ), dtype = np . int64 )
52+ axes = self .convert_to_input (kwargs .pop ("axes" , None ), is_optional = True , dtype = np . int64 )
53+ steps = self .convert_to_input (kwargs .pop ("steps" , None ), is_optional = True , dtype = np . int64 )
5454 inputs = [data , starts , ends , axes , steps ]
5555
5656 # pro-process inputs and attr
@@ -78,7 +78,7 @@ def make_slice(self, kwargs, name=None, shapes=None, dtypes=None):
7878 return self .graph .make_node (op_type = "Slice" , inputs = inputs , attr = attr , name = name ,
7979 outputs = outputs , shapes = shapes , dtypes = dtypes ).output [0 ]
8080
81- def convert_to_input (self , tensor , is_optional = False ):
81+ def convert_to_input (self , tensor , is_optional = False , dtype = None ):
8282 """in ONNX, input shold come from node, so it must be a string"""
8383 if is_optional and tensor is None :
8484 return None
@@ -87,7 +87,7 @@ def convert_to_input(self, tensor, is_optional=False):
8787
8888 res = tensor
8989 if isinstance (tensor , list ):
90- res = self .graph .make_const (utils .make_name ("const_slice" ), np .array (tensor )).output [0 ]
90+ res = self .graph .make_const (utils .make_name ("const_slice" ), np .array (tensor , dtype )).output [0 ]
9191
9292 utils .make_sure (isinstance (res , str ), "input is a dynamic input, so a str is needed" )
9393
0 commit comments