9. 编程接口¶
本章节在当前版本主要以Triton官方原生接口为基础,更新了在TritonGCU有潜在变化的接口描述。
9.1. 内存/指针操作¶
load¶
triton.language.load(pointer, mask=None, other=None, boundary_check=(),
padding_option='', cache_modifier='', eviction_policy='',
volatile=False)
limit to gcu
不支持动态stride传入0值,如必须传入0值,需要对该参数添加常量约束条件tl.constexpr。见示例:隐式broadcast限制(stride=0)
多维tensor动态stride的情况下:stride从大到小必须是倍数关系(即一次load中内存的每个地址只能被访问1次)。见示例:同一内存多次读取限制(stride从大到小不是整数倍)
tensor的维度不能超过4维。见示例:tensor最大维度限制
store¶
triton.language.store(pointer, value, mask=None, boundary_check=(),
cache_modifier='', eviction_policy='')
limit to gcu
不支持动态stride传入0值,如必须传入0值,需要对该参数添加常量约束条件tl.constexpr。见示例:隐式broadcast限制(stride=0)
多维tensor动态stride的情况下:stride从大到小必须是倍数关系。见示例:同一内存多次读取限制(stride从大到小不是整数倍)
tensor的维度不能超过4维。见示例:tensor最大维度限制
make_block_ptr¶
triton.language.make_block_ptr(base: tensor, shape, strides, offsets,
block_shape, order)
limit to gcu
shape描述tensor的维度不能超过4维。见示例:tensor最大维度限制
多维tensor动态stride的情况下:stride从大到小必须是倍数关系。见示例:同一内存多次读取限制(stride从大到小不是整数倍)
strides里的值不能全是0
9.2. 数学操作¶
abs¶
triton.language.abs(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0
cdiv¶
triton.language.cdiv(x, div)
精度 n/a
ceil¶
triton.language.ceil(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0
clamp¶
triton.language.clamp(x, min, max,
propagate_nan: ~triton.language.core.constexpr = <PROPAGATE_NAN.NONE: 0>)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0
cos¶
triton.language.cos(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.68e+7 非规格化输入的当前精度 (max_diff: ulp): 0
div_rn¶
triton.language.div_rn(x, y)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.68e+7 非规格化输入的当前精度 (max_diff: ulp): 2
erf¶
triton.language.erf(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.43e+7 非规格化输入的当前精度 (max_diff: ulp): 9.47e+6
exp¶
triton.language.exp(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 3.94e+3 非规格化输入的当前精度 (max_diff: ulp): 0
exp2¶
triton.language.exp2(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 4.19e+6 非规格化输入的当前精度 (max_diff: ulp): 0
fdiv¶
triton.language.fdiv(x, y, ieee_rounding=False)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.68e+7 非规格化输入的当前精度 (max_diff: ulp): 2
floor¶
triton.language.floor(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0
fma¶
triton.language.fma(x, y, z)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0
log¶
triton.language.log(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 2.05e+3 非规格化输入的当前精度 (max_diff: ulp): 3.4e+38
log2¶
triton.language.log2(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.48e+3 非规格化输入的当前精度 (max_diff: ulp): 3.4e+38
maximum¶
triton.language.maximum(x, y,
propagate_nan: ~triton.language.core.constexpr = <PROPAGATE_NAN.NONE: 0>)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0
minimum¶
triton.language.minimum(x, y,
propagate_nan: ~triton.language.core.constexpr = <PROPAGATE_NAN.NONE: 0>)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 0 非规格化输入的当前精度 (max_diff: ulp): 0
rsqrt¶
triton.language.rsqrt(x)
计算元素 x 的逆平方根。
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1 非规格化输入的当前精度 (max_diff: ulp): 3.09e+26
sigmoid¶
triton.language.sigmoid(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 4.5e+0 非规格化输入的当前精度 (max_diff: ulp): 3.94e+3
sin¶
triton.language.sin(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 1.68e+7 非规格化输入的当前精度 (max_diff: ulp): 0
softmax¶
triton.language.softmax(x, ieee_rounding=False)
精度 n/a
sqrt¶
triton.language.sqrt(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 2 非规格化输入的当前精度 (max_diff: ulp): 1.68e+7
sqrt_rn¶
triton.language.sqrt_rn(x)
精度 单精度(float32)结果: 规格化输入的当前精度 (max_diff: ulp): 2 非规格化输入的当前精度 (max_diff: ulp): 1.68e+7
umulhi¶
triton.language.umulhi(x, y)
精度 n/a
9.3. 规约操作¶
argmax¶
triton.language.argmax(input, axis, tie_break_left=True, keep_dims=False)
返回沿指定 axis 的 input 张量中所有元素的最大索引
Parameters:
input (Tensor) – 输入值
axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度
tie_break_left (bool) – 如果为 true,在出现平局的情况下(即多个元素具有相同的最大索引值),对于非 NaN 的值返回最左边的索引
argmin¶
triton.language.argmin(input, axis, tie_break_left=True, keep_dims=False)
返回沿指定 axis 的 input 张量中所有元素的最小索引
Parameters:
input (Tensor) – 输入值
axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度
tie_break_left (bool) – 如果为 true,在出现平局的情况下(即多个元素具有相同的最小索引值),对于非 NaN 的值返回最左边的索引
max¶
triton.language.max(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False)
返回沿指定 axis 轴上 input 张量中所有元素的最大值
Parameters:
input (Tensor) – 输入值
axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度
return_indices (bool) – 如果为 true,则返回对应最大值的索引
return_indices_tie_break_left (bool) – 如果为 true,在出现平局的情况下(即多个元素具有相同的最大值),对于非 NaN 的值返回最左边的索引
min¶
triton.language.min(input, axis=None, return_indices=False, return_indices_tie_break_left=True, keep_dims=False)
返回沿指定 axis 轴上 input 张量中所有元素的最小值
Parameters:
input (Tensor) – 输入值
axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
keep_dims (bool) – 如果为 true,则保留长度为 1 的归约维度
return_indices (bool) – 如果为 true,则返回对应最小值的索引
return_indices_tie_break_left (bool) – 如果为 true,在出现平局的情况下(即多个元素具有相同的最小值),对于非 NaN 的值返回最左边的索引
reduce¶
triton.language.reduce(input, axis, combine_fn, keep_dims=False)
将 combine_fn 应用于沿指定 axis 轴上 input 张量中的所有元素
Parameters:
input (Tensor) – 输入张量,或张量的元组
axis (int | None) – 要进行归约操作的维度。如果为 None,则归约所有维度
combine_fn (Callable) – 1 个用于组合 2 组标量张量的函数(必须使用 @triton.jit 标记)
keep_dims (bool) – 如果为 true,保留长度为 1 的归约维度
sum¶
triton.language.sum(input, axis=None, keep_dims=False, dtype: constexpr | None = None)
返回 input 张量中,沿指定 axis 的所有元素的总和
Parameters:
input (Tensor) – 输入值
axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
keep_dims (bool) – 如果为 true,保留长度为 1 的归约维度
dtype (tl.dtype) – 张量返回值所需的数据类型。 如果指定,则在执行 sum 操作之前将输入张量进行强制转换,这对于防止数据溢出非常有用。 如果未指定,则 integer 和 bool 类型将隐式向上转换为 tl.int32 ,float 类型将向上转换为tl.float32。
xor_sum¶
triton.language.xor_sum(input, axis=None, keep_dims=False)
沿指定 axis 的 input 张量中所有元素的异或和
Parameters:
input (Tensor) – 输入值
axis (int) – 要进行归约操作的维度。如果为 None,则归约所有维度
keep_dims (bool) – 如果为 true,保留长度为 1 的归约维度
9.4. 扫描/排序操作¶
associative_scan¶
triton.language.associative_scan(input, axis, combine_fn, reverse=False)
沿指定 axis 将 combine_fn 应用于 input 张量的每个元素携带的值,并更新携带的值
Parameters:
input (Tensor) – 输入张量,或张量的元组
axis (int) – 要进行归约操作的维度
combine_fn (Callable) – 1 个用于组合 2 组标量张量的函数(必须使用 @triton.jit 标记)
reverse (bool) – 是否沿着轴进行反向关联扫描
cumprod¶
triton.language.cumprod(input, axis=0, reverse=False)
返回沿指定 axis 的 input 张量中所有元素的累积乘积
Parameters:
input (Tensor) – 输入值
axis (int) – 应进行扫描的维度
cumsum¶
triton.language.cumsum(input, axis=0, reverse=False)
返回沿指定 axis 的 input 张量中所有元素的累积和
Parameters:
input (Tensor) – 输入值
axis (int) – 应进行扫描的维度
histogram¶
triton.language.histogram(input, num_bins)
# computes an histogram based on input tensor with num_bins bins,
# the bins have a width of 1 and start at 0.
统计每个bin内的元素数量,bin外的元素按照torch的处理不进行统计。
限制:
input 必须为 1D int8/int16/int32。
支持的最大shape(int32):M = 131072, N = 8192
GCU300 |
GCU400 |
|
---|---|---|
支持情况 |
支持 |
暂时使用gcu300的兼容实现 |
最大shape(i32) |
M = 262144, N = 262144(不保证在gcu400, gcu500上兼容) |
M = 131072, N = 8192 |
数据类型为i32,当N > 65536时,本算法的精度高于torch.histc的精度。
sort¶
triton.language.sort(x, dim: constexpr | None = None, descending: constexpr = constexpr[0])
沿着指定维度对张量进行排序。
Parameters:
x (Tensor) – 要排序的输入张量。
dim (int, optional) – 用于对张量进行排序的维度。如果为 None,则沿张量的最后一个维度进行排序。目前仅支持按最后一个维度排序
descending (bool, optional) – 如果设置为 True,则张量按降序排序。如果设置为 False,则张量按升序排序。