9. 编程接口

本章节在当前版本主要以Triton官方原生接口为基础,更新了在TritonGCU有潜在变化的接口描述。

9.1. 内存/指针操作

load

triton.language.load(pointer, mask=None, other=None, boundary_check=(),
                     padding_option='', cache_modifier='', eviction_policy='',
                     volatile=False)

limit to gcu

store

triton.language.store(pointer, value, mask=None, boundary_check=(),
                      cache_modifier='', eviction_policy='')

limit to gcu

make_block_ptr

triton.language.make_block_ptr(base: tensor, shape, strides, offsets,
                               block_shape, order)

limit to gcu

9.2. 数学操作

abs

triton.language.abs(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0

cdiv

triton.language.cdiv(x, div)
  • 精度 n/a

ceil

triton.language.ceil(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0

clamp

triton.language.clamp(x, min, max,
    propagate_nan: ~triton.language.core.constexpr = <PROPAGATE_NAN.NONE: 0>)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0

cos

triton.language.cos(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.68e+7 非规格化输入的当前精度 (max_diff: ulp): 0

div_rn

triton.language.div_rn(x, y)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.68e+7 非规格化输入的当前精度 (max_diff: ulp): 2

erf

triton.language.erf(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.43e+7 非规格化输入的当前精度 (max_diff: ulp): 9.47e+6

exp

triton.language.exp(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 3.94e+3 非规格化输入的当前精度 (max_diff: ulp): 0

exp2

triton.language.exp2(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 4.19e+6 非规格化输入的当前精度 (max_diff: ulp): 0

fdiv

triton.language.fdiv(x, y, ieee_rounding=False)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.68e+7 非规格化输入的当前精度 (max_diff: ulp): 2

floor

triton.language.floor(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0

fma

triton.language.fma(x, y, z)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0

log

triton.language.log(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 2.05e+3 非规格化输入的当前精度 (max_diff: ulp): 3.4e+38

log2

triton.language.log2(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.48e+3 非规格化输入的当前精度 (max_diff: ulp): 3.4e+38

maximum

triton.language.maximum(x, y,
    propagate_nan: ~triton.language.core.constexpr = <PROPAGATE_NAN.NONE: 0>)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0

minimum

triton.language.minimum(x, y,
    propagate_nan: ~triton.language.core.constexpr = <PROPAGATE_NAN.NONE: 0>)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0

rsqrt

triton.language.rsqrt(x)

计算元素 x 的逆平方根。

  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1 非规格化输入的当前精度 (max_diff: ulp): 3.09e+26

sigmoid

triton.language.sigmoid(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 4.5e+0 非规格化输入的当前精度 (max_diff: ulp): 3.94e+3

sin

triton.language.sin(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.68e+7 非规格化输入的当前精度 (max_diff: ulp): 0

softmax

triton.language.softmax(x, ieee_rounding=False)
  • 精度 n/a

sqrt

triton.language.sqrt(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 2 非规格化输入的当前精度 (max_diff: ulp): 1.68e+7

sqrt_rn

triton.language.sqrt_rn(x)
  • 精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 2 非规格化输入的当前精度 (max_diff: ulp): 1.68e+7

umulhi

triton.language.umulhi(x, y)
  • 精度 n/a

9.3. 规约操作

argmax

triton.language.argmax(input, axis, tie_break_left=True, keep_dims=False)
  返回沿指定 axis 的 input 张量中所有元素的最大索引
Parameters:
  input (Tensor) – 输入值
  axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
  keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度
  tie_break_left (bool) – 如果为 true,在出现平局的情况下(即多个元素具有相同的最大索引值),对于非 NaN 的值返回最左边的索引

argmin

triton.language.argmin(input, axis, tie_break_left=True, keep_dims=False)
  返回沿指定 axis 的 input 张量中所有元素的最小索引
Parameters:
  input (Tensor) – 输入值
  axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
  keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度
  tie_break_left (bool) – 如果为 true,在出现平局的情况下(即多个元素具有相同的最小索引值),对于非 NaN 的值返回最左边的索引

max

triton.language.max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False)
  返回沿指定 axis 轴上 input 张量中所有元素的最大值
Parameters:
  input (Tensor) – 输入值
  axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
  keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度
  return_indices (bool) – 如果为 true,则返回对应最大值的索引
  return_indices_tie_break_left (bool) – 如果为 true,在出现平局的情况下(即多个元素具有相同的最大值),对于非 NaN 的值返回最左边的索引

min

triton.language.min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False)
  返回沿指定 axis 轴上 input 张量中所有元素的最小值
Parameters:
  input (Tensor) – 输入值
  axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
  keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度
  return_indices (bool) – 如果为 true,则返回对应最小值的索引
  return_indices_tie_break_left (bool) – 如果为 true,在出现平局的情况下(即多个元素具有相同的最小值),对于非 NaN 的值返回最左边的索引

reduce

triton.language.reduce(input, axis, combine_fn, keep_dims=False)
  将 combine_fn 应用于沿指定 axis 轴上 input 张量中的所有元素
Parameters:
  input (Tensor) – 输入张量,或张量的元组
  axis (int | None) – 要进行归约操作的维度。如果为 None,则归约所有维度
  combine_fn (Callable) – 1 个用于组合 2 组标量张量的函数(必须使用 @triton.jit 标记)
  keep_dims (bool) – 如果为 true,保留长度为 1 的归约维度

sum

triton.language.sum(input, axis=None, keep_dims=False, dtype: constexpr | None = None)
  返回 input 张量中,沿指定 axis 的所有元素的总和
Parameters:
  input (Tensor) – 输入值
  axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
  keep_dims (bool) – 如果为 true,保留长度为 1 的归约维度
  dtype (tl.dtype) – 张量返回值所需的数据类型。 如果指定,则在执行 sum 操作之前将输入张量进行强制转换,这对于防止数据溢出非常有用。 如果未指定,则 integer 和 bool 类型将隐式向上转换为 tl.int32 ,float 类型将向上转换为tl.float32。

xor_sum

triton.language.xor_sum(input, axis=None, keep_dims=False)
  沿指定 axis 的 input 张量中所有元素的异或和
Parameters:
  input (Tensor) – 输入值
  axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
  keep_dims (bool) – 如果为 true,保留长度为 1 的归约维度

9.4. 扫描/排序操作

associative_scan

triton.language.associative_scan(input, axis, combine_fn, reverse=False)
  沿指定 axis 将 combine_fn 应用于 input 张量的每个元素携带的值,并更新携带的值
Parameters:
  input (Tensor) – 输入张量,或张量的元组
  axis (int) – 要进行归约操作的维度
  combine_fn (Callable) – 1 个用于组合 2 组标量张量的函数(必须使用 @triton.jit 标记)
  reverse (bool) – 是否沿着轴进行反向关联扫描

cumprod

triton.language.cumprod(input, axis=0, reverse=False)
  返回沿指定 axis 的 input 张量中所有元素的累积乘积
Parameters:
  input (Tensor) – 输入值
  axis (int) – 应进行扫描的维度

cumsum

triton.language.cumsum(input, axis=0, reverse=False)
  返回沿指定 axis 的 input 张量中所有元素的累积和
Parameters:
  input (Tensor) – 输入值
  axis (int) – 应进行扫描的维度

histogram

triton.language.histogram(input, num_bins)
# computes an histogram based on input tensor with num_bins bins,
# the bins have a width of 1 and start at 0.

统计每个bin内的元素数量,bin外的元素按照torch的处理不进行统计。

限制:

input 必须为 1D int8/int16/int32。

支持的最大shape(int32):M = 131072, N = 8192

GCU300

GCU400

支持情况

支持

暂时使用gcu300的兼容实现

最大shape(i32)

M = 262144, N = 262144(不保证在gcu400, gcu500上兼容)

M = 131072, N = 8192

数据类型为i32,当N > 65536时,本算法的精度高于torch.histc的精度。

sort

triton.language.sort(x, dim: constexpr | None = None, descending: constexpr = constexpr[0])
  沿着指定维度对张量进行排序。
Parameters:
  x (Tensor) – 要排序的输入张量。
  dim (int, optional) – 用于对张量进行排序的维度。如果为 None,则沿张量的最后一个维度进行排序。目前仅支持按最后一个维度排序
  descending (bool, optional) – 如果设置为 True,则张量按降序排序。如果设置为 False,则张量按升序排序。