7. Sample Code

Note: If needed, users can download and use the open-source dataset from Hugging Face on their own. Enflame makes no commitment regarding this open-source dataset; all consequences and risks arising from its use are borne by the user.
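The calibration and PPL scripts below load the wikitext dataset through the datasets library. For reference, a minimal sketch of pre-downloading it (assumes the datasets package is installed and Hugging Face is reachable; the dataset and config names match those used in the scripts below):

# Optional helper (not part of topscompressor): pre-download the wikitext-2-raw-v1
# splits used by the sample scripts ('validation' for calibration, 'test' for PPL).
from datasets import load_dataset

for split in ('validation', 'test'):
    ds = load_dataset('wikitext', 'wikitext-2-raw-v1', split=split)
    print(f"{split}: {len(ds)} rows cached locally")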

7.1. Quantization Examples

Quantization procedure

Quantization sample code: basic_quant.py

import argparse

from topscompressor.quantization.quantize import quantize, save_quantized_model, accelerate_gptq_pack_model
from topscompressor.quantization.config import QuantConfig

def main(args):
    if args.method in ['gptq', 'w8a16']:
        accelerate_gptq_pack_model()

    quant_config = QuantConfig.create_config(args.method)
    calib_data_name = 'wikitext'
    calib_data_config = {
        'name': 'wikitext-2-raw-v1',
        'split': 'validation',
    }
    model = quantize(
        args.model_name_or_path,
        quant_config,
        calib_data_name,
        calib_data_load_fn_kwargs=calib_data_config,
        calib_data_max_len=512,
        n_samples=args.nsamples,
        device=args.device
    )

    save_quantized_model(model, quant_config, args.save_dir)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="quantize model sample.")
    parser.add_argument("--model_name_or_path",
                        type=str,
                        required=True,)
    parser.add_argument("--method",
                        type=str,
                        choices=['awq', 'gptq', 'w8a16'],
                        required=True,)
    parser.add_argument("--save_dir",
                        type=str,
                        required=True,)
    parser.add_argument("--device",
                        type=str,
                        choices=['gcu', 'cuda'],
                        default='gcu',)
    parser.add_argument("--nsamples", type=int, default=128)
    args = parser.parse_args()
    main(args)

How to run:

# --method can be awq, gptq, or w8a16
python3 basic_quant.py --model_name_or_path model_path --method w8a16 --save_dir save_path
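After quantization completes, the saved model can be loaded back for a quick smoke test. A minimal sketch, assuming the directory written by save_quantized_model can be passed directly to AutoQuantForCausalML.from_quantized (as the PPL script below does via --quantized_model_dir), and that the tokenizer is loaded from the original model directory; quant_model.model is the underlying transformers model, as in the evaluation code:

# quick_infer.py -- hypothetical smoke test; see assumptions above
import torch
from transformers import AutoTokenizer
from topscompressor.quantization.model.auto import AutoQuantForCausalML

model_path = "model_path"   # original model directory (for the tokenizer)
save_path = "save_path"     # output directory of basic_quant.py

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
quant_model = AutoQuantForCausalML.from_quantized(save_path, device_map="auto")

# quant_model.model is the wrapped transformers causal LM, as used in basic_eval_ppl.py
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(quant_model.model.device)
with torch.no_grad():
    out = quant_model.model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))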

PPL evaluation of the quantized model

Sample code for quantized-model PPL evaluation: basic_eval_ppl.py

Note: The following procedure requires a GPU for testing.

import argparse

import torch
import torch.nn as nn

from topscompressor.quantization.model.auto import AutoQuantForCausalML

def get_wikitext2(tokenizer_path):
    from datasets import load_dataset
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    from transformers import AutoTokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False, trust_remote_code=True)
    except:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True, trust_remote_code=True)
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
    return testenc

@torch.no_grad()
def eval(model, testenv, seqlen=2048):
    print("Evaluating ... ")

    use_cache = model.config.use_cache
    model.config.use_cache = False

    testenv = testenv.input_ids
    nsamples = testenv.numel() // seqlen

    nlls = []
    for i in range(nsamples):
        batch = testenv[:, (i*seqlen):((i+1)*seqlen)].cuda()
        lm_logits = model(batch).logits

        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenv[:, (i * seqlen):((i + 1) * seqlen)][:, 1:].cuda()
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        neg_log_likelihood = loss.float() * seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))
    print(ppl.item())

    model.config.use_cache = use_cache

def main(args):
    testenc = get_wikitext2(args.tokenizer_path)
    # AutoQuantForCausalML is used here so that quantized models can run inference; for an original transformers model, simply load it with the transformers library instead
    quant_model = AutoQuantForCausalML.from_quantized(args.quantized_model_dir, device_map="auto")
    eval(quant_model.model, testenc)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--tokenizer_path",
                        type=str,
                        required=True,)
    parser.add_argument("--quantized_model_dir",
                        type=str,
                        required=True,)
    args = parser.parse_args()
    main(args)

How to run:

python3 basic_eval_ppl.py --tokenizer_path tokenizer_path --quantized_model_dir quant_model_path
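As the inline comment in basic_eval_ppl.py notes, the same evaluation routine also works on the original (unquantized) transformers model, which gives a baseline PPL to compare against. A minimal sketch, assuming it is saved next to basic_eval_ppl.py so its get_wikitext2 and eval helpers can be imported:

# baseline_eval_ppl.py -- hypothetical FP16/BF16 baseline; see assumptions above
import argparse

from transformers import AutoModelForCausalLM

from basic_eval_ppl import get_wikitext2, eval  # reuse the helpers defined above

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Evaluate the ppl of the original model.")
    parser.add_argument("--model_name_or_path", type=str, required=True)
    args = parser.parse_args()

    testenc = get_wikitext2(args.model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
                                                 torch_dtype="auto",
                                                 device_map="auto",
                                                 trust_remote_code=True)
    eval(model, testenc)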

INT8 KV cache quantization procedure

Quantization sample code: int8_kv_quant.py

import argparse

from topscompressor.quantization.quantize import quantize, save_kvcache_params, accelerate_gptq_pack_model
from topscompressor.quantization.config import QuantConfig

def main(args):
    accelerate_gptq_pack_model()
    kv_config = QuantConfig.create_config('int8_kv')
    kv_config.sym_quant_kv = args.sym_quant_kv
    # the input model may also be an already-quantized model
    kv_config.is_quantized_model = not args.is_pretrained_model
    calib_data_name = 'wikitext'
    calib_data_config = {
        'name': 'wikitext-2-raw-v1',
        'split': 'validation',
    }
    model = quantize(
        args.model_name_or_path,
        kv_config,
        calib_data_name,
        calib_data_load_fn_kwargs=calib_data_config,
        n_samples=args.nsamples,
        device=args.device
    )

    save_kvcache_params(model, kv_config, args.save_dir)

if __name__ == '__main__':    
    parser = argparse.ArgumentParser(description="quantize model sample.")
    parser.add_argument("--model_name_or_path",
                        type=str,
                        required=True,)
    parser.add_argument("--save_dir",
                        type=str,
                        required=True,)
    parser.add_argument('--sym_quant_kv', action='store_true')
    parser.add_argument('--is_pretrained_model', action='store_true')
    parser.add_argument("--device",
                        type=str,
                        choices=['gcu', 'cuda'],
                        default='gcu',)
    parser.add_argument("--nsamples", type=int, default=128)
    args = parser.parse_args()
    main(args)

How to run:

python3 int8_kv_quant.py --model_name_or_path model_path --sym_quant_kv --save_dir save_path --device cuda

PPL evaluation of the INT8 KV cache model

Sample code for INT8 KV cache PPL evaluation: int8_kv_eval_ppl.py

Note: The following procedure requires a GPU for testing.

import argparse

import torch
import torch.nn as nn
from transformers import AutoConfig
from topscompressor.quantization.model.auto import AutoQuantForCausalML
from topscompressor.quantization.quantizer.kv_cache.auto_kv import replace_kv_linears

def get_wikitext2(tokenizer_path):
    from datasets import load_dataset
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    
    from transformers import AutoTokenizer
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            use_fast=False,
            trust_remote_code=True
        )
    except:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path,
            use_fast=True,
            trust_remote_code=True
        )
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
    return testenc

@torch.no_grad()
def eval(model, testenv, seqlen=2048):
    print("Evaluating ... ")

    use_cache = model.config.use_cache
    model.config.use_cache = False

    testenv = testenv.input_ids
    nsamples = testenv.numel() // seqlen

    nlls = []
    for i in range(nsamples):
        print(f"eval sample at {i}")
        batch = testenv[:, (i*seqlen):((i+1)*seqlen)].cuda()
        lm_logits = model(batch).logits

        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenv[:, (i * seqlen):((i + 1) * seqlen)][:, 1:].cuda()
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        neg_log_likelihood = loss.float() * seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))
    print(f"ppl is: {ppl.item()}")

    model.config.use_cache = use_cache

def main(args):
    testenc = get_wikitext2(args.tokenizer_path)
    seqlen = 2048
    model_config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
    if args.is_pretrained_model:
        from transformers import AutoModelForCausalLM
        tgt_model = AutoModelForCausalLM.from_pretrained(args.model_dir,
                                                         torch_dtype=model_config.torch_dtype,
                                                         device_map="auto",
                                                         trust_remote_code=True)

    else:
        quant_model = AutoQuantForCausalML.from_quantized(args.model_dir, 
                                                          device_map="auto")
        if hasattr(quant_model.config, "max_position_embeddings"):
            seqlen = quant_model.config.max_position_embeddings
        tgt_model = quant_model.model

    int8_kv_cache = False
    if args.int8_kv_file != "":
        replace_kv_linears(tgt_model, args.int8_kv_file)
        int8_kv_cache = True

    eval(tgt_model, testenc, seqlen=seqlen)

    if int8_kv_cache and hasattr(tgt_model, "kv_quantizer"):
        tgt_model.kv_quantizer.restore_rope_hook()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Evaluate the ppl of the model.")
    parser.add_argument("--tokenizer_path",
                        type=str,
                        required=True,)
    parser.add_argument("--model_dir",
                        type=str,
                        required=True,)
    parser.add_argument('--int8_kv_file',
                        type=str,
                        default="",)
    parser.add_argument('--is_pretrained_model', action='store_true')
    args = parser.parse_args()
    main(args)

How to run:

python3 int8_kv_eval_ppl.py --tokenizer_path tokenizer_path --model_dir model_path --int8_kv_file int8_kv_file --is_pretrained_model
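For comparison, the script only applies the INT8 KV cache parameters when --int8_kv_file is non-empty, so the same quantized model can also be evaluated without it to obtain a weights-only reference PPL. A usage sketch with placeholder paths:

# PPL of an already-quantized model with INT8 KV cache parameters applied
python3 int8_kv_eval_ppl.py --tokenizer_path tokenizer_path --model_dir quant_model_path --int8_kv_file int8_kv_file

# PPL of the same quantized model without INT8 KV cache, as a reference point
python3 int8_kv_eval_ppl.py --tokenizer_path tokenizer_path --model_dir quant_model_path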