vLLM竞赛入门案例-含完整代码

↑↑↑关注后"星标"kaggle竞赛宝典
  kaggle竞赛宝典  作者:Chris

 vLLM竞赛入门案例-含完整代码!

简介

本文介绍Kaggle竞赛中使用vLLM进行模型预测的Baseline,对此感兴趣的朋友可以去下面链接中参赛https://www.kaggle.com/competitions/lmsys-chatbot-arena 跟着大神一起进步学习。

案例

%%time
# 安装工具包
!pip uninstall -y torch
!pip install -U --no-index --find-links=/kaggle/input/vllm-whl -U vllm
!pip install -U --upgrade /kaggle/input/vllm-t4-fix/grpcio-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
!pip install -U --upgrade /kaggle/input/vllm-t4-fix/ray-2.11.0-cp310-cp310-manylinux2014_x86_64.whl
import os, math, numpy as np
# 指定gpu
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

1. 加载带有 vLLM 的 34B 量化模型

import vllm

llm = vllm.LLM(
    "/kaggle/input/bagel-v3-343",
    quantization="awq",
    tensor_parallel_size=2
    gpu_memory_utilization=0.95
    trust_remote_code=True,
    dtype="half"
    enforce_eager=True,
    max_model_len=1024,
    #distributed_executor_backend="ray",
)
tokenizer = llm.get_tokenizer()

2.加载测试数据

import pandas as pd
VALIDATE = 128

test = pd.read_csv("/kaggle/input/lmsys-chatbot-arena/test.csv"
if len(test)==3:
    test = pd.read_csv("/kaggle/input/lmsys-chatbot-arena/train.csv")
    test = test.iloc[:VALIDATE]

3.Engineer Prompt

from typing import Any, Dict, List
from transformers import LogitsProcessor
import torch

choices = ["A","B","tie"]

KEEP = []
for x in choices:
    c = tokenizer.encode(x,add_special_tokens=False)[0]
    KEEP.append(c)
print(f"Force predictions to be tokens {KEEP} which are {choices}.")

class DigitLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer):
        self.allowed_ids = KEEP
        
    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        scores[self.allowed_ids] += 100
        return scores
SS = "#"*25 + "\n"
all_prompts = []
for index,row in test.iterrows():
    
    a = " ".join(eval(row.prompt, {"null"""}))
    b = " ".join(eval(row.response_a, {"null"""}))
    c = " ".join(eval(row.response_b, {"null"""}))
    
    prompt = f"{SS}PROMPT: "+a+f"\n\n{SS}RESPONSE A: "+b+f"\n\n{SS}RESPONSE B: "+c+"\n\n"
    
    formatted_sample = sys_prompt + "\n\n" + prompt
    
    all_prompts.append( formatted_sample )

4. 预测

此处使用快速vLLM进行测试推断。我们要求vLLM输出认为应该在第一个标记中被预测的前5个标记的概率。我们还将预测限制在1个标记,以增加推理速度。

根据推断128个训练样本所需的速度,我们可以推断推断25,000个测试样本将需要多长时间。

from time import time
start = time()

logits_processors = [DigitLogitsProcessor(tokenizer)]
responses = llm.generate(
    all_prompts,
    vllm.SamplingParams(
        n=1,  # Number of output sequences to return for each prompt.
        top_p=0.9,  # Float that controls the cumulative probability of the top tokens to consider.
        temperature=0,  # randomness of the sampling
        seed=777# Seed for reprodicibility
        skip_special_tokens=True,  # Whether to skip special tokens in the output.
        max_tokens=1,  # Maximum number of tokens to generate per output sequence.
        logits_processors=logits_processors,
        logprobs = 5
    ),
    use_tqdm = True
)

end = time()
elapsed = (end-start)/60. #minutes
print(f"Inference of {VALIDATE} samples took {elapsed} minutes!")

5.预测概率

results = []
errors = 0

for i,response in enumerate(responses):
    try:
        x = response.outputs[0].logprobs[0]
        logprobs = []
        for k in KEEP:
            if k in x:
                logprobs.append( math.exp(x[k].logprob) )
            else:
                logprobs.append( 0 )
                print(f"bad logits {i}")
        logprobs = np.array( logprobs )
        logprobs /= logprobs.sum()
        results.append( logprobs )
    except:
        #print(f"error {i}")
        results.append( np.array([1/3.1/3.1/3.]) )
        errors += 1
        
print(f"There were {errors} inference errors out of {i+1} inferences")
results = np.vstack(results)


7.结果提交

sub = pd.read_csv("/kaggle/input/lmsys-chatbot-arena/sample_submission.csv")

if len(test)!=VALIDATE:
    sub[["winner_model_a","winner_model_b","winner_tie"]] = results
    
sub.to_csv("submission.csv",index=False)
sub.head()

参考文献

  1. https://www.kaggle.com/code/cdeotte/infer-34b-with-vllm
  2. https://www.kaggle.com/competitions/lmsys-chatbot-arena


相关推荐

  • [开源]Java生态下企业级AIGC项目解决方案,集成AIGC大模型功能
  • 如何在实际项目中优雅运用设计模式?
  • 为什么我们相信英伟达能到 5 万亿 | AGIX 投什么
  • 如何向数学小白解释PCA(主成分分析)算法
  • 原来支持OPC UA的PLC真么牛!!!
  • 2.2K Star精美监控!!!运维用了,在公司横着走
  • Spring Boot集成screw实现数据库文档生成
  • 启动资金5000块,2个人如何在抖音带货100w/月?
  • 解锁转转门店业务灵活性:如何利用MVEL引擎优化结算流程
  • 【云原生|K8S系列】不再迷茫!跟随这份攻略,10分钟了解K8S持久化存储!
  • 最高贴息1000万!杭州14条重磅AI新政发布,每年发2.5亿元算力券
  • 扎克伯格深度专访:中美AI竞争完全错误,美国别想长期领先中国
  • 云巨头大暴走,自研CPU落地200万张!新一轮芯片洗牌开始了
  • 基于大模型 + 知识库的 Code Review 实践
  • AI编程哪家强?一次对比4大编程助手
  • “开源模型是智商税” v.s. “开源AI是前进的道路”
  • 苹果iPhone 15 Pro跑起精简版Win11
  • 成都最硬核的公司,陆奇投了
  • 美团后端日常实习面试,轻松拿捏了!
  • 深入分析 C++ 错误处理:哪种策略的性能最强?