@@ -58,8 +58,8 @@ infiniStatus_t infiniopRoPE(
5858 size_t workspace_size,
5959 void *t,
6060 const void *pos_ids,
61- const void *sin_table,
62- const void *cos_table,
61+ const float *sin_table,
62+ const float *cos_table,
6363 void *stream
6464);
6565```
@@ -114,15 +114,15 @@ infiniStatus_t infiniopCreateRoPEDescriptor(
114114 张量必须为三维:` (seq_len, num_head, head_dim) ` 。最后一维数据必须连续,即步长为1,且长度` (head_dim) ` 为2的倍数;
115115- ` pos_ids ` - { dI | ` (seq_len) ` | (~ ) }:
116116 位置信息张量描述。张量必须为一维连续张量,长度为 ` seq_len ` 。用户需自行保证位置数据中所有数值小于 ` max_seq_len ` ;
117- - ` sin_table ` - { dT | ` (max_seq_len, head_dim/2) ` | (~ ) }:
117+ - ` sin_table ` - { float | ` (max_seq_len, head_dim/2) ` | (~ ) }:
118118 sin 值表的张量描述,二维连续张量,形状为 ` (max_seq_len, head_dim/2) ` ;
119- - ` cos_table ` - { dT | ` (max_seq_len, head_dim/2) ` | (~ ) }:
119+ - ` cos_table ` - { float | ` (max_seq_len, head_dim/2) ` | (~ ) }:
120120 cos 值表的张量描述,要求与 sin 表相同;
121121
122122参数限制:
123123
124- - ` dT ` : (` Float16 ` , ` Float32 ` ) 之一;
125- - ` dI ` : (` Int16 ` , ` Int32 ` , ` Uint16 ` , ` Uint32 ` ) 之一;
124+ - ` dT ` : (` Float16 ` , ` Float32 ` , ` Float64 ` ) 之一;
125+ - ` dI ` : (` Uint8 ` , ` Uint16 ` , ` Uint32 ` , ` Uint64 ` ) 之一;
126126
127127<div style =" background-color : lightblue ; padding : 1px ;" > 返回值:</div >
128128
0 commit comments