14. 移植开源Triton算子到TritonGCU

14.1. 移植步骤

  1. 在Triton算子python文件的开头,加入 import triton_gcu.triton

    import triton
    import triton.language as tl
    import triton_gcu.triton
    
  2. 在Triton算子python文件的开头,加入 from torch_gcu import transfer_to_gcu

    import torch
    from torch_gcu import transfer_to_gcu
    
  3. 按照注意事项-a-name移植开源Triton算子到gcu-注意事项a 列出的条目,对Triton算子进行必要的修改。

  4. 完成移植。

14.2. 注意事项

  1. 数据类型 int64 / uint64

    • 对 ``int64`` / ``uint64`` 计算有强需求的Triton算子,无法移植到gcu300

    • 对 ``int64`` / ``uint64`` 计算没有强需求的Triton算子,在gcu300上可以改用 ``int32`` / ``uint32`` 数据类型

      • 在host代码里,pytorch数据类型 torch.int64 / torch.uint64,替换成 torch.int32 / torch.uint32

      • 在kernel代码里,Triton数据类型 tl.int64 / tl.uint64,替换成 tl.int32 / tl.uint32

      • 备注:pytorch创建int类型tensor的函数(例如, torch.randint()),默认使用数据类型 torch.int64。因此,在调用这类函数的时候,必须显性地指定数据类型(例如,torch.randint(..., dtype=torch.int32))。

  2. 数据类型 float64

    • 对 ``float64`` 计算有强需求的Triton算子,无法移植到gcu300

    • 对 ``float64`` 计算没有强需求的Triton算子,在gcu300上可以改用 ``float32`` 数据类型

      • 在host代码里,pytorch数据类型 torch.float64,替换成 torch.float32

      • 在kernel代码里,Triton数据类型 tl.float64,替换成 tl.float32

      • 备注:pytorch创建float类型tensor的函数(例如, torch.rand()),默认使用数据类型 torch.float64。因此,在调用这类函数的时候,必须显性地指定数据类型(例如,torch.randint(..., dtype=torch.float32))。

  3. 数据类型 float8

    • 对 ``float8`` 计算有强需求的Triton算子,无法移植到gcu300

    • 对 ``float8`` 计算没有强需求的Triton算子,在gcu300上可以改用 ``bfloat16`` 数据类型

      1. 在host代码里,pytorch数据类型 torch.float8_e5m2 / torch.float8_e4m3fn / torch.float8_e5m2fnuz / torch.float8_e4m3fnuz,替换成 torch.bfloat16

      2. 在kernel代码里,Triton数据类型 tl.float8e5 / tl.float8e5b16 / tl.float8e4nv / tl.float8e4b8 / tl.float8e4b15,替换成 tl.bfloat16