1. TopsGraphBuilder

1.1. constexpr 常量

constexpr int64_t kUnknownDim = -1;
constexpr int64_t kUnknownRank = -2;
简介:
  • kUnknownDim: 动态 dim

  • kUnknownRank: rank 未知

1.2. PrintingFlags

enum PrintingFlags {
  None = 0x00000,
  /// Enable the elision of large elements attributes, by printing a '...'
  /// instead of the element data. Note: The IR generated with this option is
  /// not parsable. largeElementLimit = 16
  ElideLargeElementsAttrs = 0x00001,
  /// Enable printing of debug information. If 'prettyForm' is set to true,
  /// debug information is printed in a more readable 'pretty' form. Note: The
  /// IR generated with 'prettyForm' is not parsable.
  EnableDebugInfo = 0x00010,
  EnableDebugInfoPrettyFormFlag = 0x00100,
  /// Always print operations in the generic form.
  PrintGenericOpForm = 0x01000,
  /// Use local scope when printing the operation. This allows for using the
  /// printer in a more localized and thread-safe setting, but may not
  /// necessarily be identical to what the IR will look like when dumping
  /// the full module.
  UseLocalScope = 0x10000
};
简介:
  • 控制打印 HLIR Module IR 显示形式

1.3. Builder

Builder()

Builder();
简介:
  • Builder 构造函数

返回值:
  • Builder 类对象

CreateInput

builder::Op CreateInput(const builder::Type& type, const char* func_name = "main");
简介:
  • 给函数 func_name 设置输入 op

参数:
  • type: func_name 对应输入的 type

  • func_name: 创建 input op 对应的 func,默认为 main

返回值:
  • 创建的输入 op,shape 和 dtype 对应的类型为 type.

SetOutput

void SetOutput(const std::vector<builder::Op>& outputs, const char* func_name = "main");
简介:
  • 给函数 func_name 设置输出 op

参数:
  • outputs: 要设置为输出的 ops

  • func_name: func 函数,默认为 main

GetInputs

std::vector<builder::Op> GetInputs(const char* func_name = "main");
简介:
  • 获取函数 func_name 所有输入 op

参数:
  • func_name: 要查询的 func 函数,默认为 main

返回值:
  • 当前 func_name 对应的所有输入 op

GetModuleStr

std::shared_ptr<char> GetModuleStr(uint32_t flags = PrintingFlags::None);
简介:
  • 获取当前构造的 HLIR Module 字符串,通过 flags 参数控制 HLIR Module IR 显示形式。

参数:
  • flags: 控制输出 IR 显示形式,支持 enum PrintingFlags 定义的形式。

返回值:
  • shared_ptr 封装的 char,包含 HLIR Module IR.

GetModule

std::shared_ptr<hlir::Module> GetModule();
简介:
  • 获取当前构造的 HLIR Module, 用于 TopsGraphCompiler 编译。

返回值:
  • shared_ptr 封装的 hlir::Module 对象。

AddFunc

void AddFunc(const char* func_name);
简介:
  • 创建一个命名为 func_name 的函数

参数:
  • func_name: 函数名称

SetModuleAttribute

void SetModuleAttribute(const char* attr_name, const Attribute& attr_value);
简介:
  • 给当前 HLIR Module 设置属性。

参数:
  • attr_name: 属性名称

  • attr_value: 属性值

SetModuleAttribute

void SetModuleAttribute(const char* attr_name, const std::vector<Attribute>& attr_value);
简介:
  • 给当前 HLIR Module 设置类型为 array 的属性。

参数:
  • attr_name: 属性名称

  • attr_value: 属性值

GetModuleAttribute

Attribute GetModuleAttribute(const char* attr_name) const;
简介:
  • 获取 HLIR Module 属性名为 attr_name 的属性。

参数:
  • attr_name: 属性名称

返回值:
  • HLIR Module 上挂载的属性名为 attr_name 的属性。

Print

void Print(std::ostream& out, uint32_t flags = PrintingFlags::None) const;
简介:
  • 打印当前 HLIR Module

参数:
  • out: 输出流

  • flags: 控制 HLIR Module IR 显示形式

Dump

void Dump() const;
简介:
  • 在终端打印当前 HLIR Module, flags 选项固定为 PrintingFlags::ElideLargeElementsAttrs, 即 largeElementLimit = 16

1.4. PrimitiveType

static member function

PrimitiveType

Basic data type in C++

NONE()

none(for unkown data type)

PRED()

bool

S8()

int8_t

U8()

uint8_t

S16()

int16_t

U16()

uint16_t

F16()

half(not a standard C++ type)

BF16()

bfloat16(not a standard C++ type)

S32()

int32_t

U32()

uint32_t

F32()

float

S64()

int64_t

U64()

uint64_t

F64()

double

TUPLE()

tuple(a list of above types)

static PrimitiveType NONE();
static PrimitiveType PRED();
static PrimitiveType S8();
static PrimitiveType U8();
static PrimitiveType F16();
static PrimitiveType BF16();
static PrimitiveType S16();
static PrimitiveType U16();
static PrimitiveType F32();
static PrimitiveType S32();
static PrimitiveType U32();
static PrimitiveType S64();
static PrimitiveType U64();
static PrimitiveType F64();
static PrimitiveType TUPLE();
简介:
  • 数据类型创建函数,分别创建不同类型的 PrimitiveType

返回值:
  • PrimitiveType 类对象

比较函数

bool operator==(const PrimitiveType& prim_type) const;
bool operator!=(const PrimitiveType& prim_type) const;
简介:
  • 比较两个 PrimitiveType 是否相等

Print

void Print(std::ostream& out) const;
简介:
  • 打印当前 PrimitiveType

参数:
  • out: 输出流

Dump

void Dump() const;
简介:
  • 在终端上打印当前 PrimitiveType

1.5. Type

Type 封装了 Op 的 shape 和 data type(PrimitiveType).

Type

explicit Type(PrimitiveType primitive_type = builder::PrimitiveType::NONE());
explicit Type(const std::vector<Type*>& types);
Type(const std::vector<int64_t>& shape, const PrimitiveType& primitive_type);
Type(const std::vector<std::vector<int64_t>>& shapes, const std::vector<PrimitiveType>& primitive_types);
简介:
  • 根据不同参数类型,构造对应的 Type 对象。

  • CASE #1: 标量类型

  • CASE #2: 根据已知的 Type 对象,构造 Tuple Type

  • CASE #3: 构造 RankedTensorType

  • CASE #4: 构造 Tuple Type, 要求传入的 shapes 和 primitive_types 大小相等

返回值:
  • Type 对象

比较函数

bool operator==(const Type& type) const;
bool operator!=(const Type& type) const;
简介:
  • 比较两个 Type 对象是否完全一样

IsDynamic

bool IsDynamic() const;
简介:
  • 判断是否为 dynamic shape, 即 shape 包含 -1

IsTuple

bool IsTuple() const;
简介:
  • 判断是否为 tuple type

GetRank

int64_t GetRank() const;
简介:
  • 获取 rank

GetSize

uint64_t GetSize() const;
简介:
  • 获取元素个数

GetShape

std::vector<int64_t> GetShape() const;
简介:
  • 获取 shape

GetPrimitiveType

PrimitiveType GetPrimitiveType() const;
简介:
  • 获取 data type

Note

更多 API 详细描述,请参考第 6 章节。

1.6. Op

GetBuilder

std::shared_ptr<Builder> GetBuilder() const;
简介:
  • 通过当前 Op,获取全局 Builder 对象指针

返回值:
  • 全局 Builder 对象指针

GetType

Type GetType() const;
简介:
  • 获取当前 Op 的 type

SetAttribute

void SetAttribute(const char* attr_name, const Attribute& attr_value, const char* mode = "");
简介:
  • 给当前 Op 设置属性

参数:
  • attr_name: 属性名

  • attr_value: 属性值

  • mode: 未使能,采用默认值即可

SetAttribute

void SetAttribute(const char* attr_name, const std::vector<Attribute>& attr_value);
简介:
  • 给当前 Op 设置 array 属性

参数:
  • attr_name: 属性名

  • attr_value: 属性值

Note

更多 API 详细描述,请参考第 6 章节。

1.7. Attribute

Attribute

Attribute();
explicit Attribute(const char* value);
explicit Attribute(bool value);
explicit Attribute(int32_t value);
explicit Attribute(int64_t value);
explicit Attribute(float value);
explicit Attribute(double value);
Attribute(Type data_type, void* data);
Attribute(Type data_type, std::shared_ptr<void> data_ptr);
简介:
  • 根据不同参数类型,构造对应的 Attribute 对象。

  • CASE #1: 空属性,不包含任何有效信息

  • CASE #2: 字符串类型属性

  • CASE #3: bool 类型属性

  • CASE #4: int32_t 类型属性

  • CASE #5: int64_t 类型属性

  • CASE #6: float 类型属性

  • CASE #7: double 类型属性

  • CASE #8: 自定义类型属性,类型为 data_type, 值为 data

  • CASE #9: 自定义类型属性,类型为 data_type, 值为 data_ptr 共享指针保存的内容

返回值:
  • Attribute 对象

Note

更多 API 详细描述,请参考第 6 章节。