[Suggestion] Explicit sync.set & sync.wait are lowered to intrinsics, but is TSYNC a better target? #647

@learning-chip

Description

Summary

pto.sync.set & pto.sync.wait are lowered directly to the intrinsics ffts_cross_core_sync & wait_flag_dev. But TSync also supports cross-core sync via its IsCrossCore branch:

        // Wait side of TSync (cross-core branch):
        if constexpr (IsCrossCore) {
            PTO_STATIC_ASSERT(CrossCoreId != 0xff,
                              "Fix: The cross-core id must be assigned by user when the event is a cross-core event.");
            wait_flag_dev(CrossCoreId);
        } else {
            // ... intra-core wait ...
        }

        // Set side of TSync (cross-core branch):
        if constexpr (IsCrossCore) {
            PTO_STATIC_ASSERT(CrossCoreId != 0xff,
                              "Fix: The cross-core id must be assigned by user when the event is a cross-core event.");
            ffts_cross_core_sync(srcPipe, getFFTSMsg(FFTS_MODE_VAL, CrossCoreId));
        } else {
            // ... intra-core set ...
        }

Would pto.sync.set & pto.sync.wait be better lowered to TSYNC, to stay at the PTO abstraction level?

Caveat: the cross-core TSync path is currently untested in the pto-isa repo.
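
For concreteness, here is a rough sketch of the proposed mapping. The TSYNC entry points and template parameters below are illustrative assumptions, not the actual pto-isa API; only the current intrinsic calls are taken from the generated code in the full example:

// Current lowering (as emitted in the full example below):
//   pto.sync.set  <PIPE_FIX>,  0  ->  ffts_cross_core_sync(PIPE_FIX, getFFTSMsg(FFTS_MODE_VAL, 0));
//   pto.sync.wait <PIPE_MTE3>, 1  ->  wait_flag_dev(1);
//
// Hypothetical TSYNC-based lowering (names and shape assumed for illustration):
//   pto.sync.set  <PIPE_FIX>,  0  ->  a TSYNC set with IsCrossCore = true, CrossCoreId = 0, srcPipe = PIPE_FIX
//   pto.sync.wait <PIPE_MTE3>, 1  ->  a TSYNC wait with IsCrossCore = true, CrossCoreId = 1
//
// Both would then route through the IsCrossCore branches quoted above, keeping
// the emitted code at the PTO abstraction level instead of raw intrinsics.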

Full example

module {
  func.func @run_scan_kernel(%arg0: !pto.ptr<f32>, %arg1: !pto.ptr<f32>, %arg2: !pto.ptr<f32>, %arg3: i32, %arg4: memref<256xui64>) {
    pto.set_ffts %arg4 : memref<256xui64>
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c63 = arith.constant 63 : index
    %c64 = arith.constant 64 : index
    %c4096 = arith.constant 4096 : index
    %cst = arith.constant 0.000000e+00 : f32
    %c0_i64 = arith.constant 0 : i64
    %0 = arith.index_cast %arg3 : i32 to index
    %1 = arith.divsi %0, %c4096 : index
    pto.section.cube {
      %2 = pto.make_tensor_view %arg0, shape = [%c64, %c64], strides = [%c64, %c64] {layout = #pto.layout<nd>} : !pto.tensor_view<?x?xf32>
      %3 = pto.make_tensor_view %arg2, shape = [%c64, %c64], strides = [%c64, %c64] {layout = #pto.layout<nd>} : !pto.tensor_view<?x?xf32>
      %4 = pto.make_tensor_view %arg1, shape = [%c64, %c64], strides = [%c64, %c64] {layout = #pto.layout<nd>} : !pto.tensor_view<?x?xf32>
      %5 = pto.alloc_tile : !pto.tile_buf<mat, 64x64xf32, blayout=col_major, slayout=row_major>
      %6 = pto.alloc_tile : !pto.tile_buf<right, 64x64xf32, slayout=col_major>
      %7 = pto.alloc_tile : !pto.tile_buf<mat, 64x64xf32, blayout=col_major, slayout=row_major>
      %8 = pto.alloc_tile : !pto.tile_buf<left, 64x64xf32, slayout=row_major>
      %9 = pto.alloc_tile : !pto.tile_buf<acc, 64x64xf32, blayout=col_major, slayout=row_major>
      %10 = pto.partition_view %3, offsets = [%c0, %c0], sizes = [%c64, %c64] : !pto.tensor_view<?x?xf32>
      pto.tload ins(%10 : !pto.partition_tensor_view<64x64xf32>) outs(%5 : !pto.tile_buf<mat, 64x64xf32, blayout=col_major, slayout=row_major>)
      pto.record_event[#pto.sync_op_type<TLOAD>, #pto.sync_op_type<TMOV_M2L>, <EVENT_ID0>]
      pto.wait_event[#pto.sync_op_type<TLOAD>, #pto.sync_op_type<TMOV_M2L>, <EVENT_ID0>]
      pto.tmov ins(%5 : !pto.tile_buf<mat, 64x64xf32, blayout=col_major, slayout=row_major>) outs(%6 : !pto.tile_buf<right, 64x64xf32, slayout=col_major>)
      pto.record_event[#pto.sync_op_type<TMOV_M2L>, #pto.sync_op_type<TMATMUL>, <EVENT_ID0>]
      pto.wait_event[#pto.sync_op_type<TMOV_M2L>, #pto.sync_op_type<TMATMUL>, <EVENT_ID0>]
      scf.for %arg5 = %c0 to %1 step %c1 {
        %11 = arith.muli %arg5, %c64 : index
        %12 = pto.partition_view %2, offsets = [%11, %c0], sizes = [%c64, %c64] : !pto.tensor_view<?x?xf32>
        %13 = pto.partition_view %4, offsets = [%11, %c0], sizes = [%c64, %c64] : !pto.tensor_view<?x?xf32>
        pto.tload ins(%12 : !pto.partition_tensor_view<64x64xf32>) outs(%7 : !pto.tile_buf<mat, 64x64xf32, blayout=col_major, slayout=row_major>)
        pto.record_event[#pto.sync_op_type<TLOAD>, #pto.sync_op_type<TMOV_M2L>, <EVENT_ID1>]
        pto.wait_event[#pto.sync_op_type<TLOAD>, #pto.sync_op_type<TMOV_M2L>, <EVENT_ID1>]
        pto.tmov ins(%7 : !pto.tile_buf<mat, 64x64xf32, blayout=col_major, slayout=row_major>) outs(%8 : !pto.tile_buf<left, 64x64xf32, slayout=row_major>)
        pto.record_event[#pto.sync_op_type<TMOV_M2L>, #pto.sync_op_type<TMATMUL>, <EVENT_ID1>]
        pto.wait_event[#pto.sync_op_type<TMOV_M2L>, #pto.sync_op_type<TMATMUL>, <EVENT_ID1>]
        pto.tmatmul ins(%8, %6 : !pto.tile_buf<left, 64x64xf32, slayout=row_major>, !pto.tile_buf<right, 64x64xf32, slayout=col_major>) outs(%9 : !pto.tile_buf<acc, 64x64xf32, blayout=col_major, slayout=row_major>)
        pto.record_event[#pto.sync_op_type<TMATMUL>, #pto.sync_op_type<TSTORE_ACC>, <EVENT_ID1>]
        pto.wait_event[#pto.sync_op_type<TMATMUL>, #pto.sync_op_type<TSTORE_ACC>, <EVENT_ID1>]
        pto.tstore ins(%9 : !pto.tile_buf<acc, 64x64xf32, blayout=col_major, slayout=row_major>) outs(%13 : !pto.partition_tensor_view<64x64xf32>)
        pto.record_event[#pto.sync_op_type<TSTORE_ACC>, #pto.sync_op_type<TLOAD>, <EVENT_ID2>]
        pto.wait_event[#pto.sync_op_type<TSTORE_ACC>, #pto.sync_op_type<TLOAD>, <EVENT_ID2>]
        pto.sync.set <PIPE_FIX>, 0
        pto.sync.wait <PIPE_MTE3>, 1
      }
    }
    pto.section.vector {
      %2 = arith.divsi %0, %c64 : index
      %3 = pto.make_tensor_view %arg1, shape = [%2, %c64], strides = [%c64, %c1] : !pto.tensor_view<?x?xf32>
      %4 = pto.alloc_tile : !pto.tile_buf<vec, 1x64xf32>
      %5 = pto.alloc_tile : !pto.tile_buf<vec, 1x8xf32>
      pto.tsetval ins(%c0, %cst : index, f32) outs(%5 : !pto.tile_buf<vec, 1x8xf32>)
      scf.for %arg5 = %c0 to %1 step %c1 {
        pto.sync.wait <PIPE_FIX>, 0
        %6 = pto.get_subblock_idx
        %7 = arith.cmpi eq, %6, %c0_i64 : i64
        scf.if %7 {
          %8 = arith.muli %arg5, %c64 : index
          scf.for %arg6 = %c0 to %c64 step %c1 {
            %9 = arith.addi %8, %arg6 : index
            %10 = pto.partition_view %3, offsets = [%9, %c0], sizes = [%c1, %c64] : !pto.tensor_view<?x?xf32>
            pto.tload ins(%10 : !pto.partition_tensor_view<1x64xf32>) outs(%4 : !pto.tile_buf<vec, 1x64xf32>)
            pto.record_event[#pto.sync_op_type<TLOAD>, #pto.sync_op_type<TVEC>, <EVENT_ID2>]
            pto.wait_event[#pto.sync_op_type<TLOAD>, #pto.sync_op_type<TVEC>, <EVENT_ID2>]
            %11 = pto.tgetval ins(%5, %c0 : !pto.tile_buf<vec, 1x8xf32>, index) outs : f32
            pto.tadds ins(%4, %11 : !pto.tile_buf<vec, 1x64xf32>, f32) outs(%4 : !pto.tile_buf<vec, 1x64xf32>)
            pto.barrier <PIPE_ALL>
            %12 = pto.tgetval ins(%4, %c63 : !pto.tile_buf<vec, 1x64xf32>, index) outs : f32
            pto.tsetval ins(%c0, %12 : index, f32) outs(%5 : !pto.tile_buf<vec, 1x8xf32>)
            pto.record_event[#pto.sync_op_type<TVEC>, #pto.sync_op_type<TSTORE_VEC>, <EVENT_ID2>]
            pto.wait_event[#pto.sync_op_type<TVEC>, #pto.sync_op_type<TSTORE_VEC>, <EVENT_ID2>]
            pto.tstore ins(%4 : !pto.tile_buf<vec, 1x64xf32>) outs(%10 : !pto.partition_tensor_view<1x64xf32>)
            pto.record_event[#pto.sync_op_type<TSTORE_VEC>, #pto.sync_op_type<TLOAD>, <EVENT_ID3>]
            pto.wait_event[#pto.sync_op_type<TSTORE_VEC>, #pto.sync_op_type<TLOAD>, <EVENT_ID3>]
          }
          pto.record_event[#pto.sync_op_type<TLOAD>, #pto.sync_op_type<TVEC>, <EVENT_ID3>]
          pto.wait_event[#pto.sync_op_type<TLOAD>, #pto.sync_op_type<TVEC>, <EVENT_ID3>]
        }
        pto.sync.set <PIPE_MTE3>, 1
      }
    }
    return
  }
}

which generates:

#include "pto/pto-inst.hpp"
using namespace pto;

enum class PTOAutoSyncTailMode : int {
  kBarrierAll = 0,
  kSetWaitMte3ToSEvent0 = 1,
};

static AICORE inline void ptoas_auto_sync_tail(
    PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) {
  switch (mode) {
  case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0:
    set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
    wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0);
    break;
  case PTOAutoSyncTailMode::kBarrierAll:
  default:
    pipe_barrier(PIPE_ALL);
    break;
  }
}

__global__ AICORE void run_scan_kernel(__gm__ float* v1, __gm__ float* v2, __gm__ float* v3, int32_t v4, __gm__ uint64_t* v5) {
  unsigned v6 = 0;
  const int32_t v7 = 0;
  const int32_t v8 = 1;
  const int32_t v9 = 63;
  const int32_t v10 = 64;
  const int32_t v11 = 4096;
  const float v12 = 0.0f;
  const int64_t v13 = 0;
  const int64_t v14 = 16384;
  const int64_t v15 = 32;
  using T = float;
  size_t v16 = (size_t) v8;
  size_t v17 = (size_t) v7;
  uint64_t v18 = (uint64_t) v5;
  set_ffts_base_addr(v18);
  size_t v19 = (size_t) (v4 / v11);

  #if defined(__DAV_CUBE__)
  Tile<TileType::Mat, float, 64, 64, BLayout::ColMajor, 64, 64, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v20;
  TASSIGN(v20, v13);
  Tile<TileType::Right, float, 64, 64, BLayout::RowMajor, 64, 64, SLayout::ColMajor, 512, PadValue::Null, CompactMode::Null> v21;
  TASSIGN(v21, v13);
  Tile<TileType::Mat, float, 64, 64, BLayout::ColMajor, 64, 64, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v22;
  TASSIGN(v22, v14);
  Tile<TileType::Left, float, 64, 64, BLayout::RowMajor, 64, 64, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v23;
  TASSIGN(v23, v13);
  Tile<TileType::Acc, float, 64, 64, BLayout::ColMajor, 64, 64, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> v24;
  TASSIGN(v24, v13);
  pto::Shape<1, 1, 1, 64, 64> v25 = pto::Shape<1, 1, 1, 64, 64>();
  pto::Stride<4096, 4096, 4096, 64, 64> v26 = pto::Stride<4096, 4096, 4096, 64, 64>();
  GlobalTensor<float, pto::Shape<1, 1, 1, 64, 64>, pto::Stride<4096, 4096, 4096, 64, 64>, pto::Layout::ND> v27 = GlobalTensor<float, pto::Shape<1, 1, 1, 64, 64>, pto::Stride<4096, 4096, 4096, 64, 64>, pto::Layout::ND>(v3 + (v6 + v6 * (unsigned) v10 + v6 * (unsigned) v10), v25, v26);
  TLOAD(v20, v27);
  set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
  wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0);
  TMOV(v21, v20);
  set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
  wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0);
  for (size_t v28 = v17; v28 < v19; v28 += v16) {
    int32_t v29 = (int32_t) ((uint32_t) ((int32_t) v28) * (uint32_t) v10);
    pto::Shape<1, 1, 1, 64, 64> v30 = pto::Shape<1, 1, 1, 64, 64>();
    pto::Stride<4096, 4096, 4096, 64, 64> v31 = pto::Stride<4096, 4096, 4096, 64, 64>();
    GlobalTensor<float, pto::Shape<1, 1, 1, 64, 64>, pto::Stride<4096, 4096, 4096, 64, 64>, pto::Layout::ND> v32 = GlobalTensor<float, pto::Shape<1, 1, 1, 64, 64>, pto::Stride<4096, 4096, 4096, 64, 64>, pto::Layout::ND>(v1 + (v6 + (unsigned) v29 * (unsigned) v10 + v6 * (unsigned) v10), v30, v31);
    pto::Shape<1, 1, 1, 64, 64> v33 = pto::Shape<1, 1, 1, 64, 64>();
    pto::Stride<4096, 4096, 4096, 64, 64> v34 = pto::Stride<4096, 4096, 4096, 64, 64>();
    GlobalTensor<float, pto::Shape<1, 1, 1, 64, 64>, pto::Stride<4096, 4096, 4096, 64, 64>, pto::Layout::ND> v35 = GlobalTensor<float, pto::Shape<1, 1, 1, 64, 64>, pto::Stride<4096, 4096, 4096, 64, 64>, pto::Layout::ND>(v2 + (v6 + (unsigned) v29 * (unsigned) v10 + v6 * (unsigned) v10), v33, v34);
    TLOAD(v22, v32);
    set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
    wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1);
    TMOV(v23, v22);
    set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
    wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1);
    TMATMUL(v24, v23, v21);
    set_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
    wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1);
    TSTORE(v35, v24);
    set_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID2);
    wait_flag(PIPE_FIX, PIPE_MTE2, EVENT_ID2);
    uint16_t v36 = getFFTSMsg(FFTS_MODE_VAL, v7);
    ffts_cross_core_sync(PIPE_FIX, v36);
    wait_flag_dev(1);
  }
  #endif // __DAV_CUBE__


  #if defined(__DAV_VEC__)
  set_mask_norm();
  set_vector_mask(-1, -1);
  int32_t v37 = v4 / v10;
  Tile<TileType::Vec, float, 1, 64, BLayout::RowMajor, 1, 64, SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null> v38;
  TASSIGN(v38, v15);
  Tile<TileType::Vec, float, 1, 8, BLayout::RowMajor, 1, 8, SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null> v39;
  TASSIGN(v39, v13);
  v39.SetValue(v7, v12);
  for (size_t v40 = v17; v40 < v19; v40 += v16) {
    wait_flag_dev(0);
    int64_t v41 = get_subblockid();
    if ((int64_t) v41 == v13) {
      for (size_t v42 = v17; v42 < ((size_t) v10); v42 += v16) {
        pto::Shape<1, 1, 1, 1, 64> v43 = pto::Shape<1, 1, 1, 1, 64>();
        pto::Stride<64, 64, 64, 64, 1> v44 = pto::Stride<64, 64, 64, 64, 1>();
        GlobalTensor<float, pto::Shape<1, 1, 1, 1, 64>, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND> v45 = GlobalTensor<float, pto::Shape<1, 1, 1, 1, 64>, pto::Stride<64, 64, 64, 64, 1>, pto::Layout::ND>(v2 + (v6 + (unsigned) ((int32_t) (uint32_t) ((int32_t) (uint32_t) ((int32_t) v40) * (uint32_t) v10) + (uint32_t) ((int32_t) v42)) * (unsigned) v10 + v6 * (unsigned) v8), v43, v44);
        TLOAD(v38, v45);
        set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2);
        wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2);
        float v46 = v39.GetValue(v7);
        TADDS(v38, v38, v46);
        pipe_barrier(PIPE_ALL);
        float v47 = v38.GetValue(v9);
        v39.SetValue(v7, v47);
        set_flag(PIPE_V, PIPE_MTE3, EVENT_ID2);
        wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID2);
        TSTORE(v45, v38);
        set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID3);
        wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID3);
      };
      set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3);
      wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3);
    };
    uint16_t v48 = getFFTSMsg(FFTS_MODE_VAL, v8);
    ffts_cross_core_sync(PIPE_MTE3, v48);
  }
  #endif // __DAV_VEC__

  return;
}

Generated by huawei-csl/pto-dsl#113
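
For reference, the cross-core handshake in the generated kernel reduces to the following producer/consumer pattern, one exchange per loop iteration (extracted verbatim from the code above); these are exactly the call sites a TSYNC-based lowering would replace:

// Cube section, end of each loop iteration:
ffts_cross_core_sync(PIPE_FIX, getFFTSMsg(FFTS_MODE_VAL, 0));  // signal cross-core flag 0 to the vector core
wait_flag_dev(1);                                              // block until the vector core signals flag 1

// Vector section, each loop iteration:
wait_flag_dev(0);                                              // block until the cube core signals flag 0
ffts_cross_core_sync(PIPE_MTE3, getFFTSMsg(FFTS_MODE_VAL, 1)); // signal cross-core flag 1 to the cube core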

Motivation / use case

For the explicit manual sync mode.
