Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RISC-V: Add vpx_comp_avg_pred_rvv #18

Open
wants to merge 1 commit into
base: riscv64_android_optimization
Choose a base branch
from

Conversation

sunmin89
Copy link
Contributor

@sunmin89 sunmin89 commented Sep 4, 2023

测试平台 D1 Nezha

root@TinaLinux:/# uname -a
Linux TinaLinux 5.4.61 #11 PREEMPT Mon Aug 28 11:36:07 UTC 2023 riscv64 GNU/Linux
root@TinaLinux:/# cat /proc/cpuinfo
processor       : 0
hart            : 0
isa             : rv64imafdcvu
mmu             : sv39

查阅c906用户手册
https://occ-oss-prod.oss-cn-hangzhou.aliyuncs.com/resource//1685946574371/%E7%8E%84%E9%93%81C906R2S1%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C%EF%BC%88occ%EF%BC%89.pdf

指令名称 指令描述 执行延时(LMUL=1)
VSADD.VV 矢量整型加法有符号取饱和指令 3
VAADD.VV 矢量整型加法取平均数指令 3

上表中 VSADD.VV 与 VAADD.VV 的执行延时相同(均为 3),据此推测 vsaddu 和 vaaddu 这两条指令的耗时接近。

为了比对优化前后的性能,我们把 rvv1.0 的 intrinsic 代码改写成 rvv 0.7.1, 然后在现有的硬件平台(D1 Nezha)测试,
由于 D1 不支持 vaaddu(Vector Single-Width Averaging Add),我们把它替换成 vsaddu(Vector Single-Width Saturating Add)。注意:该替换仅用于耗时对比,饱和加法的计算结果与取平均并不相同。

工具链及编译参数

~/bins/Xuantie-900-gcc-elf-newlib-x86_64-V2.2.4/bin/riscv64-unknown-elf-gcc -march=rv64gcv0p7_zfh_xtheadc -O3 test.c -o test

参考
https://github.com/sipeed/TinyMaix/blob/9487854be2ce329b89427d4a5b14ce6136cb15cf/src/arch_rv64v.h#L24

把测试程序推送到开发板,然后记录 10000000 次循环耗时

D:\bins\> adb devices
List of devices attached
20080411        device
adb push test /tmp; adb shell chmod +x /tmp/test ; adb shell time /tmp/test

width > 8
// vpx_comp_avg_pred_rvv(dst,src,20,40,ref,8);// 20.12s
// vpx_comp_avg_pred_c(dst,src,20,40,ref,8);//1m 16.88s

width = 8
// vpx_comp_avg_pred_rvv(dst,src,8,40,ref,8);//14.49s
// vpx_comp_avg_pred_c(dst,src,8,40,ref,8);// 21.37s

width = 4
// vpx_comp_avg_pred_rvv(dst,src,4,40,ref,8);// 3.38s
// vpx_comp_avg_pred_c(dst,src,4,40,ref,8);// 2.45s 

test.c 如下

//#include <assert.h>
//#include "./vpx_dsp_rtcd.h"
//#include <vpx_ports/riscv.h>
#include <time.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <riscv_vector.h>
/*
 * RVV (0.7.1 intrinsics) benchmark variant of vpx_comp_avg_pred.
 *
 * Combines `pred` (rows packed `width` bytes apart) with `ref` (rows
 * `ref_stride` bytes apart) and stores the result in `comp_pred` (rows packed
 * `width` bytes apart).
 *
 * NOTE: the test board does not support vaaddu (averaging add), so vsaddu
 * (saturating add) is used in its place. This variant is therefore only
 * meaningful for timing comparisons; its output is NOT the rounded average.
 *
 * Three paths are selected by `width`:
 *   width > 8  : every row is processed in chunks of `vl` bytes.
 *   width == 8 : one vector covers several rows, gathered from `ref` through
 *                a byte-offset index table.
 *   otherwise  : treated as width == 4; `ref` rows are fetched as 32-bit
 *                words with a strided load.
 *
 * NOTE(review): the two narrow paths loop until `i` reaches exactly 0, so
 * width * height must be a multiple of `vl` - confirm for all callers.
 */
void vpx_comp_avg_pred_rvv(uint8_t *comp_pred, const uint8_t *pred, int width,
                         int height, const uint8_t *ref, int ref_stride) {
  const size_t avl = 16;
  size_t vl;
  if (width > 8){
    int x, y = height;
    vuint8m8_t vr;
    do {
      /* Walk one row; vsetvl trims the last chunk to the bytes remaining. */
      for (x = 0; x < width; x += vl){
        vl =   vsetvl_e8m8(width - x);
        const vuint8m8_t vp =  vle8_v_u8m8(pred + x, vl);
        vr =  vle8_v_u8m8(ref + x, vl);
        //D1 not support vaaddu
        // vr =  vaaddu_vv_u8m1(vp, vr, vl);
        vr =  vsaddu_vv_u8m8(vp, vr,vl);
        vse8_v_u8m8(comp_pred + x, vr, vl);
      }
      comp_pred += width;
      pred += width;
      ref += ref_stride;
    } while (--y);
  } else if (width == 8) {
    int i = width * height;
    size_t k;
    /*
     * Byte offsets into `ref` for the indexed load. vsetvl never returns more
     * than the requested avl (16), so a fixed stack buffer is sufficient; the
     * table does not depend on the loop, so it is built once instead of being
     * heap-allocated (and leaked) on every iteration.
     */
    uint8_t index[16];
    vuint8m1_t vr;
    vl =  vsetvl_e8m1(avl);
    for(k = 0; k < vl; k++){
      if(k < vl / 2){
        /* First half of the lanes: consecutive bytes of the current row. */
        index[k] = (uint8_t)k;
      }else{
        /*
         * Second half: bytes of the next row. The "- 8" matches vl == 16
         * (two 8-byte rows per vector). Offsets are truncated to 8 bits, so
         * ref_stride + 7 must not exceed 255.
         */
        index[k] = (uint8_t)(ref_stride + k - 8);
      }
    }
    const vuint8m1_t vindex =  vle8_v_u8m1(index, vl);
    do {
      const vuint8m1_t vp =  vle8_v_u8m1(pred, vl);
      vr =   vloxei8_v_u8m1(ref, vindex, vl);
      //D1 not support vaaddu
      // vr =  vaaddu_vv_u8m1(vp, vr, vl);
      vr =  vsaddu_vv_u8m1(vp, vr, vl);
      vse8_v_u8m1(comp_pred, vr, vl);
      ref += 2 * ref_stride;
      pred += vl;
      comp_pred += vl;
      i -= vl;
    } while (i);
  } else {
    int i = width * height;
    //assert(width == 4);
    vuint8m1_t vr;
    vl =  vsetvl_e8m1(avl);
    do {
      vuint32m1_t a_u32;
      const vuint8m1_t vp  =  vle8_v_u8m1(pred, vl);
      /*
       * Each 4-byte `ref` row is fetched as one 32-bit element; vl / 4 rows
       * are read per iteration, ref_stride bytes apart.
       * NOTE(review): this reads `ref` through a uint32_t pointer - confirm
       * that alignment is acceptable on the target.
       */
      a_u32 =  vlse32_v_u32m1((const uint32_t*)ref, ref_stride, vl / 4);
      vr =  vreinterpret_v_u32m1_u8m1(a_u32);
      ref += 4 * ref_stride;

      //D1 not support vaaddu
      // vr =  vaaddu_vv_u8m1(vp, vr, vl);
      vr =  vsaddu_vv_u8m1(vp, vr, vl);
      vse8_v_u8m1(comp_pred, vr, vl);

      pred += vl;
      comp_pred += vl;
      i -= vl;
    } while (i);
  }
}

/* Shift down with rounding */
#define ROUND_POWER_OF_TWO(value, n) (((value) + (1 << ((n)-1))) >> (n))

/*
 * Reference C implementation: writes the rounded average of `pred` and `ref`
 * into `comp_pred`, one byte at a time.
 *
 * comp_pred  - destination; rows are `width` bytes apart.
 * pred       - first source; rows are `width` bytes apart.
 * width      - bytes processed per row.
 * height     - number of rows.
 * ref        - second source; rows are `ref_stride` bytes apart.
 * ref_stride - distance in bytes between consecutive `ref` rows.
 */
void vpx_comp_avg_pred_c(uint8_t *comp_pred, const uint8_t *pred, int width,
                         int height, const uint8_t *ref, int ref_stride) {
  for (int rows_left = height; rows_left > 0; --rows_left) {
    int col = 0;
    while (col < width) {
      /* Sum in int so 255 + 255 cannot wrap before the rounding shift. */
      const int sum = pred[col] + ref[col];
      comp_pred[col] = (uint8_t)ROUND_POWER_OF_TWO(sum, 1);
      ++col;
    }
    /* Only `ref` has an independent stride; the other two are packed. */
    ref += ref_stride;
    pred += width;
    comp_pred += width;
  }
}
/*
 * Allocates a buffer of `len` bytes and fills it with the test pattern
 * data[i] = (uint8_t)(i + 1), i.e. 1, 2, ..., 255, 0, 1, ...
 *
 * Returns the new buffer, or NULL if `len` is negative or the allocation
 * fails. The caller owns the buffer and must release it with free().
 */
uint8_t* getDataUint8(const int len){
    uint8_t *data;

    if (len < 0) {
        return NULL;
    }
    data = malloc((size_t)len * sizeof *data);
    if (data == NULL) {
        /* Do not touch the buffer when the allocation failed. */
        return NULL;
    }
    for(int i = 0 ; i < len; i++){
        data[i] = (uint8_t) (i + 1);
    }
    return data;
}

/*
 * Benchmark driver: builds three 200x200 test buffers, runs the selected
 * implementation `loops` times and prints the wall-clock time measured with
 * time(), which has one-second resolution.
 *
 * Exactly one call inside the loop is meant to be active; the commented-out
 * alternatives record the timings measured for the other configurations.
 */
int main(void) {
  const int w = 200, h = 200;
  const int loops = 10000000;
  uint8_t *src = getDataUint8(w * h);
  uint8_t *dst = getDataUint8(w * h);
  uint8_t *ref = getDataUint8(w * h);
  time_t start, end;
  double dif;

  /* Do not run the kernels on buffers that could not be allocated. */
  if (src == NULL || dst == NULL || ref == NULL) {
    fprintf(stderr, "Failed to allocate test buffers.\n");
    free(src);
    free(dst);
    free(ref);
    return EXIT_FAILURE;
  }

  time(&start);

  for (int i = 0; i < loops; i++) {
    // vpx_comp_avg_pred_rvv(dst,src,20,40,ref,8);// 20.12s
    // vpx_comp_avg_pred_c(dst,src,20,40,ref,8);//1m 16.88s

    // vpx_comp_avg_pred_rvv(dst,src,8,40,ref,8);//14.49s
    // vpx_comp_avg_pred_c(dst,src,8,40,ref,8);// 21.37s

    vpx_comp_avg_pred_rvv(dst, src, 4, 40, ref, 8);  // 3.38s
    // vpx_comp_avg_pred_c(dst,src,4,40,ref,8);// 2.45s

    // vpx_comp_avg_pred_rvv(dst,src,8,40,ref,24);//19.09s

    // vpx_comp_avg_pred_rvv(dst,src,20,40,ref,24);// 20.13s
    // vpx_comp_avg_pred_c(dst,src,20,40,ref,24);// 20.13s

    // vpx_comp_avg_pred_rvv(dst,src,4,40,ref,24);// 3.39s
  }

  time(&end);
  dif = difftime(end, start);
  printf ("Elasped time is %.2lf seconds.\n", dif);

  free(src);
  free(dst);
  free(ref);
  return 0;
}

Test environment: D1 Nezha with rvv-0.7.1 enabled.

For comparison, we define the test data, execute 10000000 loops each of
vpx_comp_avg_pred_rvv and vpx_comp_avg_pred_c, and record the elapsed time.

Width           RVV Impl  C Impl
if width > 8    20.12s    1m 16.88s
if width == 8   14.49s    21.37s
if width == 4   3.38s     2.45s
@sunmin89 sunmin89 assigned unicornx and Trumeet and unassigned unicornx and Trumeet Sep 5, 2023
@sunmin89 sunmin89 requested review from unicornx and Trumeet September 5, 2023 01:25
@Trumeet
Copy link
Contributor

Trumeet commented Sep 12, 2023

好像没看懂这个实现唉 ... 能不能稍微解释一下?(?

vpx_comp_avg_pred_c 我看了一下,粗略理解是:本质上有三个 uint8_t 二维数组 comp_pred、pred、ref;对 pred 进行 unit strided load,读取 width * height 个 elements,对 ref 进行 strided load,然后对每个 element 求平均再 round,最后 unit strided store 进 comp_pred。不知道能不能通过 vle vlse vaaddu.vv vse 简单实现?

@sunmin89
Copy link
Contributor Author

好像没看懂这个实现唉 ... 能不能稍微解释一下?(?

vpx_comp_avg_pred_c 我看了一下,粗略理解是:本质上有三个 uint8_t 二维数组 comp_pred、pred、ref;对 pred 进行 unit strided load,读取 width * height 个 elements,对 ref 进行 strided load,然后对每个 element 求平均再 round,最后 unit strided store 进 comp_pred。

我是按照 neon 的逻辑做的,没有仔细看 c 的逻辑

不知道能不能通过 vle vlse vaaddu.vv vse 简单实现?

我确实用了 vle vlse vaaddu.vv vse,但是有三个 if,对应 neon 的优化逻辑, 所以看起来复杂了,目前我没有想到一个办法,把三个 if 合并起来

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants