13. Limitations

13.1. Grid dimension limits

GridDim    GCU300
X          [0, 0xffff]
Y          [0, 0xff]
Z          [0, 0xff]

If you hit a RuntimeError: Kurama Error [TOps]: topsErrorInvalidConfiguration runtime error, check the full log for a message of the following form:

EE: [tid:1034133] efdrv/src/platform/api_execution_control.cpp:187:itopsLaunchKernel_validate Grid Dims params : (0x1(max:0xffff), 0x10(max:0xff), 0x400(max:0xff))

The message above means grid_dim exceeded the limit. A simplified reference fix is given below; if several grid dims overflow, each one is handled the same way as the single-dim case: modify the kernel to wrap a loop around the overflowing grid dimension, turning that grid_dim into a controllable loop stride.

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.

    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)


@triton.jit
def add_kernel_fix_out_of_grid(x_ptr,  # *Pointer* to first input vector.
               y_ptr,  # *Pointer* to second input vector.
               output_ptr,  # *Pointer* to output vector.
               n_elements,  # Size of the vector.
               BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
               MAX_GRID_DIM: tl.constexpr,
               # NOTE: `constexpr` so it can be used as a shape value.
               ):
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    num_tile = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE
    for tile_id in tl.range(pid, num_tile, MAX_GRID_DIM, num_stages=3):
        block_start = tile_id * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        tl.store(output_ptr + offsets, output, mask=mask)

Fixed version: the original kernel body is wrapped in for tile_id in tl.range(pid, num_tile, MAX_GRID_DIM, num_stages=3), and every use of pid is replaced with tile_id, so a single pid now performs grid_dim_origin / MAX_GRID_DIM units of work.

The Triton kernel launches look like this:

def add(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    block_size = 16384
    grid = lambda meta: ((triton.cdiv(n_elements, meta['BLOCK_SIZE'])), )
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size, num_warps=1)
    return output

MAX_GRID_DIM = 12

def add_fix_out_of_grid(x: torch.Tensor, y: torch.Tensor):
    # We need to preallocate the output.
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    block_size = 16384
    grid = lambda meta: (min(MAX_GRID_DIM, (triton.cdiv(n_elements, meta['BLOCK_SIZE']))), )
    add_kernel_fix_out_of_grid[grid](x, y, output, n_elements, BLOCK_SIZE=block_size,
                                     MAX_GRID_DIM=MAX_GRID_DIM, num_warps=2)
    return output

As shown above, the launched grid_dim is now capped at MAX_GRID_DIM. Choose MAX_GRID_DIM according to the number of SIPs and the upper limit of the corresponding dimension; the recommended value is the number of SIPs.
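
A minimal sketch of that choice, assuming the SIP count is known (NUM_SIPS = 12 below is an illustrative value that mirrors the MAX_GRID_DIM used above; querying it from the runtime is not shown) and using the X-dimension limit from the table in 13.1:

NUM_SIPS = 12            # assumption: device-specific, mirrors MAX_GRID_DIM above
GRID_DIM_X_MAX = 0xffff  # X-dimension limit for GCU300 (see 13.1)

def pick_max_grid_dim(num_tiles: int) -> int:
    # One program per SIP when there is enough work, never above the X limit.
    return min(num_tiles, NUM_SIPS, GRID_DIM_X_MAX)

# Usage in a launcher like add_fix_out_of_grid (illustrative):
#   num_tiles = triton.cdiv(n_elements, block_size)
#   max_grid_dim = pick_max_grid_dim(num_tiles)
#   add_kernel_fix_out_of_grid[(max_grid_dim,)](x, y, output, n_elements,
#       BLOCK_SIZE=block_size, MAX_GRID_DIM=max_grid_dim, num_warps=2)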

13.2. BLOCK_SIZE limit

If compilation fails with an error like "!!!out of memory:required is xx ,hareware limis:xx", BLOCK_SIZE is too large: the memory requested by each block exceeds the hardware limit. Reduce BLOCK_SIZE.
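
A minimal sketch of the workaround, reusing add_kernel from 13.1; the 4096 default is an illustrative value, not a documented GCU limit:

def add_small_block(x: torch.Tensor, y: torch.Tensor, block_size: int = 4096):
    # Same launcher as add() in 13.1, but with a smaller BLOCK_SIZE so the
    # per-block memory request stays under the hardware limit; shrink
    # block_size further if the compile error persists.
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size, num_warps=1)
    return output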

13.3. num_stages limit

In the current software stack, the maximum value of num_stages is 3.
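
A minimal sketch, assuming the standard Triton num_stages launch option is honored by this backend; the clamp is a defensive pattern around add_kernel from 13.1, not a documented API:

MAX_NUM_STAGES = 3  # current software-stack limit

def launch_add(x, y, output, block_size, requested_stages=4):
    # Clamp the requested depth so the num_stages launch option never exceeds 3;
    # the same cap applies to tl.range(..., num_stages=...) inside kernels.
    n_elements = output.numel()
    grid = (triton.cdiv(n_elements, block_size),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size,
                     num_warps=1,
                     num_stages=min(requested_stages, MAX_NUM_STAGES))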

13.4. Scenarios where native LLVM optimizations cannot be used

Vector memory layout issue

On GCU, the vector memory layout differs between data types, which causes native LLVM to load data incorrectly when performing truncate operations, for example via the tcle interface. Because the compiler has not adapted the standard LLVM ops to the GCU backend, the Enflame MLIR ecosystem stack can reuse almost none of the open-source LLVM optimization code.

13.5. load/store limits

Implicit broadcast limit (stride=0)

In the example below, kernel_not_support_stride_0 does not support passing 0 as the actual value of stride_am, whereas kernel_support_stride_0 does, because there stride_am is declared as tl.constexpr; a launch sketch follows the two kernels.

@triton.jit
def kernel_not_support_stride_0(a_ptr, b_ptr, M, stride_am, block_shape_m,
                                stride_bm, BLOCK_M: tl.constexpr
):
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr, shape=(M,), strides=(stride_am,),
        offsets=(1,), block_shape=(BLOCK_M,),
        order=(0,))
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr, shape=(block_shape_m,),
        strides=(stride_bm,),
        offsets=(0,), block_shape=(BLOCK_M,), order=(0,))

    a = tl.load(a_block_ptr, boundary_check=(0,), padding_option="zero")
    tl.store(b_block_ptr, a, boundary_check=(0,))


@triton.jit
def kernel_support_stride_0(a_ptr, b_ptr, M, stride_am: tl.constexpr,
                            block_shape_m, stride_bm: tl.constexpr,
                            BLOCK_M: tl.constexpr
):
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr, shape=(M,), strides=(stride_am,),
        offsets=(1,), block_shape=(BLOCK_M,),
        order=(0,))
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr, shape=(block_shape_m,),
        strides=(stride_bm,),
        offsets=(0,), block_shape=(BLOCK_M,), order=(0,))

    a = tl.load(a_block_ptr, boundary_check=(0,), padding_option="zero")
    tl.store(b_block_ptr, a, boundary_check=(0,))
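
A hedged launch sketch for the supported variant, broadcasting one element of a across a block of b by passing stride_am=0; the shapes, BLOCK_M and device placement are illustrative assumptions:

a = torch.randn(128, device='cuda')
b = torch.empty(64, device='cuda')
# With stride_am=0 every element of the block reads the same location in a
# (implicit broadcast); this only compiles because stride_am is a tl.constexpr.
kernel_support_stride_0[(1,)](a, b, 128, 0, 64, 1, BLOCK_M=64)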

Limit on reading the same memory multiple times (larger strides that are not integer multiples of smaller ones)

In the example below, the kernel supports the arguments (stride_m=8, stride_n=4, stride_k=1) but not (stride_m=6, stride_n=4, stride_k=1): since stride_m % stride_n != 0, a single load would read the same memory locations repeatedly, similar to the stride=0 broadcast above. A launch sketch with legal strides follows the kernel.

@triton.jit
def kernel_load_support(
        a_ptr, b_ptr, M, N, K, b_M, b_N, b_K,
        stride_am, stride_an, stride_ak,
        stride_bm, stride_bn, stride_bk,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        BLOCK_P: tl.constexpr
):
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr, shape=(M, N, K), strides=(stride_am, stride_an, stride_ak),
        offsets=(0, 0, 0), block_shape=(BLOCK_M, BLOCK_N, BLOCK_K),
        order=(2, 1, 0))
    a = tl.load(a_block_ptr, boundary_check=(0, 1, 2), padding_option="zero")
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr, shape=(b_M, b_N, b_K),
        strides=(stride_bm, stride_bn, stride_bk),
        offsets=(0, 0, 0), block_shape=(BLOCK_M, BLOCK_N, BLOCK_K),
        order=(2, 1, 0))
    tl.store(b_block_ptr, a, boundary_check=(0, 1, 2))
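
A hedged launch sketch with legal strides (8, 4, 1), as produced by a contiguous tensor of shape (2, 2, 4); the shapes, block sizes and device placement are illustrative assumptions:

a = torch.randn(2, 2, 4, device='cuda')   # contiguous, so strides are (8, 4, 1)
b = torch.empty_like(a)
kernel_load_support[(1,)](a, b, 2, 2, 4, 2, 2, 4,
                          8, 4, 1,   # stride_am, stride_an, stride_ak
                          8, 4, 1,   # stride_bm, stride_bn, stride_bk
                          BLOCK_M=2, BLOCK_N=2, BLOCK_K=4, BLOCK_P=1)
# (stride_m=6, stride_n=4, stride_k=1) would not be supported: 6 % 4 != 0.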

Maximum tensor rank limit

As the example below shows, 4-D block pointers (kernel_load_support_4) are supported, while 5-D ones (kernel_load_not_support_5) are not:

@triton.jit
def kernel_load_support_4(
        a_ptr, b_ptr, M, N, K, P, b_M, b_N, b_K, b_P,
        stride_am, stride_an, stride_ak, stride_ap,
        stride_bm, stride_bn, stride_bk, stride_bp,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        BLOCK_P: tl.constexpr
):
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr, shape=(M, N, K, P),
        strides=(stride_am, stride_an, stride_ak, stride_ap),
        offsets=(0, 0, 0, 0), block_shape=(BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_P),
        order=(3, 2, 1, 0))
    a = tl.load(a_block_ptr, boundary_check=(0, 1, 2, 3), padding_option="zero")
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr, shape=(b_M, b_N, b_K, b_P),
        strides=(stride_bm, stride_bn, stride_bk, stride_bp),
        offsets=(0, 0, 0, 0), block_shape=(BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_P),
        order=(3, 2, 1, 0))
    tl.store(b_block_ptr, a, boundary_check=(0, 1, 2, 3))


@triton.jit
def kernel_load_not_support_5(
        a_ptr, b_ptr, M, N, K, P, Q, b_M, b_N, b_K, b_P, b_Q,
        stride_am, stride_an, stride_ak, stride_ap, stride_aq,
        stride_bm, stride_bn, stride_bk, stride_bp, stride_bq,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        BLOCK_P: tl.constexpr, BLOCK_Q: tl.constexpr
):
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr, shape=(M, N, K, P, Q),
        strides=(stride_am, stride_an, stride_ak, stride_ap, stride_aq),
        offsets=(0, 0, 0, 0, 0),
        block_shape=(BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_P, BLOCK_Q),
        order=(4, 3, 2, 1, 0))
    a = tl.load(a_block_ptr, boundary_check=(0, 1, 2, 3, 4),
                padding_option="zero")
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr, shape=(b_M, b_N, b_K, b_P, b_Q),
        strides=(stride_bm, stride_bn, stride_bk, stride_bp, stride_bq),
        offsets=(0, 0, 0, 0, 0),
        block_shape=(BLOCK_M, BLOCK_N, BLOCK_K, BLOCK_P, BLOCK_Q),
        order=(4, 3, 2, 1, 0))
    tl.store(b_block_ptr, a, boundary_check=(0, 1, 2, 3, 4))