14. 移植开源Triton算子到TritonGCU¶
14.1. 移植步骤¶
在Triton算子python文件的开头,加入
import triton_gcu.triton:import triton import triton.language as tl import triton_gcu.triton
在Triton算子python文件的开头,加入
from torch_gcu import transfer_to_gcu:import torch from torch_gcu import transfer_to_gcu
按照注意事项-a-name移植开源Triton算子到gcu-注意事项a 列出的条目,对Triton算子进行必要的修改。
完成移植。
14.2. 注意事项¶
数据类型
int64/uint64对 ``int64`` / ``uint64`` 计算有强需求的Triton算子,无法移植到gcu300。
对 ``int64`` / ``uint64`` 计算没有强需求的Triton算子,在gcu300上可以改用 ``int32`` / ``uint32`` 数据类型:
在host代码里,pytorch数据类型
torch.int64/torch.uint64,替换成torch.int32/torch.uint32。在kernel代码里,Triton数据类型
tl.int64/tl.uint64,替换成tl.int32/tl.uint32。备注:pytorch创建int类型tensor的函数(例如,
torch.randint()),默认使用数据类型torch.int64。因此,在调用这类函数的时候,必须显性地指定数据类型(例如,torch.randint(..., dtype=torch.int32))。
数据类型
float64对 ``float64`` 计算有强需求的Triton算子,无法移植到gcu300。
对 ``float64`` 计算没有强需求的Triton算子,在gcu300上可以改用 ``float32`` 数据类型:
在host代码里,pytorch数据类型
torch.float64,替换成torch.float32。在kernel代码里,Triton数据类型
tl.float64,替换成tl.float32。备注:pytorch创建float类型tensor的函数(例如,
torch.rand()),默认使用数据类型torch.float64。因此,在调用这类函数的时候,必须显性地指定数据类型(例如,torch.randint(..., dtype=torch.float32))。
数据类型
float8对 ``float8`` 计算有强需求的Triton算子,无法移植到gcu300。
对 ``float8`` 计算没有强需求的Triton算子,在gcu300上可以改用 ``bfloat16`` 数据类型:
在host代码里,pytorch数据类型
torch.float8_e5m2/torch.float8_e4m3fn/torch.float8_e5m2fnuz/torch.float8_e4m3fnuz,替换成torch.bfloat16。在kernel代码里,Triton数据类型
tl.float8e5/tl.float8e5b16/tl.float8e4nv/tl.float8e4b8/tl.float8e4b15,替换成tl.bfloat16。