欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 建筑 > 【复现DeepSeek-R1之Open R1实战】系列6:GRPO源码逐行深度解析(上)

【复现DeepSeek-R1之Open R1实战】系列6:GRPO源码逐行深度解析(上)

2025/2/22 16:58:04 来源:https://blog.csdn.net/sinat_16020825/article/details/145716399  浏览:    关键词:【复现DeepSeek-R1之Open R1实战】系列6:GRPO源码逐行深度解析(上)

目录

  • 4 GRPO源码分析
    • 4.1 数据类 `GRPOScriptArguments`
    • 4.2 系统提示字符串 `SYSTEM_PROMPT`
    • 4.3 奖励函数
      • 4.3.1 accuracy_reward函数
      • 4.3.2 verify函数
      • 4.3.3 format_reward函数
    • 4.4 将数据集格式化为对话形式
    • 4.5 初始化GRPO Trainer


【复现DeepSeek-R1之Open R1实战】系列3:SFT和GRPO源码逐行深度解析(上)
【复现DeepSeek-R1之Open R1实战】系列5:SFT和GRPO源码逐行深度解析(中)

4 GRPO源码分析

前面两篇博文已经详细介绍了一些基础知识和SFT源码,本文继续解读GRPO源码。与SFT源码差不多的部分,我们就不展开细说了,这里只解析GRPO独特的部分。

4.1 数据类 GRPOScriptArguments

该类使用了 Python 的 dataclass 装饰器,这是一种简化类定义的方式,特别是对于那些主要用来存储数据的类。它继承自 ScriptArguments 类。

  • reward_funcs: 这是一个列表,包含了一系列可能的奖励函数名称,默认值为 ["accuracy", "format"]。这些奖励函数可能是用于评估模型性能的不同标准。

    reward_funcs: list[str] = field(default_factory=lambda: ["accuracy", "format"],metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"},
    )
    
  • cosine_min_value_wrongcosine_max_value_wrong: 分别表示错误答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.0-0.5

  • cosine_min_value_correctcosine_max_value_correct: 分别表示正确答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.51.0

  • cosine_max_len: 表示余弦相似度尺度的最大长度,默认值为 1000

  • repetition_n_grams: 表示用于重复惩罚奖励的n-gram数量,默认值为 3

  • repetition_max_penalty: 表示重复惩罚奖励的最大负值,默认值为 -1.0

每个字段都使用了 field() 函数来定义其默认值和元数据(如帮助信息)。这有助于工具和库更好地理解和处理这些字段,例如生成命令行解析器时。

4.2 系统提示字符串 SYSTEM_PROMPT

SYSTEM_PROMPT = ("A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant ""first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning ""process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., ""<think> reasoning process here </think><answer> answer here </answer>"
)

字符串描述了一个对话场景,用户先提问,助手首先思考推理过程,然后提供答案。推理过程和答案分别用 <think><answer> 标签包裹,这种格式化有助于区分和识别不同的部分,和DeepSeek-R1的思考过程格式一致。

4.3 奖励函数

奖励函数的定义如下,GRPO默认用到了accuracy_reward和format_reward这两个函数。

# Get reward functionsREWARD_FUNCS_REGISTRY = {"accuracy": accuracy_reward,"format": format_reward,"reasoning_steps": reasoning_steps_reward,"cosine": get_cosine_scaled_reward(min_value_wrong=script_args.cosine_min_value_wrong,max_value_wrong=script_args.cosine_max_value_wrong,min_value_correct=script_args.cosine_min_value_correct,max_value_correct=script_args.cosine_max_value_correct,max_len=script_args.cosine_max_len,),"repetition_penalty": get_repetition_penalty_reward(ngram_size=script_args.repetition_n_grams,max_penalty=script_args.repetition_max_penalty,),"length": len_reward,}reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

这段代码定义了一个奖励函数注册表 REWARD_FUNCS_REGISTRY,并根据用户提供的配置动态生成一个奖励函数列表 reward_funcs。每个奖励函数用于评估模型输出的不同方面,如准确性、格式、推理步骤等。

  1. 注册表定义
  • accuracy: 使用 accuracy_reward 函数评估模型输出的准确性。
  • format: 使用 format_reward 函数评估模型输出的格式。
  • reasoning_steps: 使用 reasoning_steps_reward 函数评估模型输出的推理步骤。
  • cosine: 使用 get_cosine_scaled_reward 函数计算余弦相似度奖励,参数包括:
    • min_value_wrong: 错误情况下的最小值。
    • max_value_wrong: 错误情况下的最大值。
    • min_value_correct: 正确情况下的最小值。
    • max_value_correct: 正确情况下的最大值。
    • max_len: 最大长度。
  • repetition_penalty: 使用 get_repetition_penalty_reward 函数计算重复惩罚奖励,参数包括:
    • ngram_size: n-gram 的大小。
    • max_penalty: 最大惩罚值。
  • length: 使用 len_reward 函数评估模型输出的长度。
  1. 动态生成奖励函数列表
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
  • 根据 script_args.reward_funcs 中指定的奖励函数名称,从 REWARD_FUNCS_REGISTRY 中获取相应的奖励函数,并生成一个列表 reward_funcs

4.3.1 accuracy_reward函数

该函数用于计算模型生成的补全与真实答案之间的准确性奖励。它通过解析和验证生成的内容与真实答案来确定奖励值。

def accuracy_reward(completions, solution, **kwargs):"""Reward function that checks if the completion is the same as the ground truth."""contents = [completion[0]["content"] for completion in completions]rewards = []for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)if len(gold_parsed) != 0:# We require the answer to be provided in correct latex (no malformed operators)answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)# Reward 1 if the content is the same as the ground truth, 0 otherwisereward = float(verify(answer_parsed, gold_parsed))else:# If the gold solution is not parseable, we reward 1 to skip this examplereward = 1.0print("Failed to parse gold solution: ", sol)rewards.append(reward)return rewards
  • completions (list): 包含多个补全结果的列表,每个补全结果是一个包含内容的字典列表。
  • solution (list): 真实答案的列表。
  • kwargs: 其他可选参数(在本函数中未使用)。
  1. 提取补全内容

    contents = [completion[0]["content"] for completion in completions]
    
    • completions 列表中提取每个补全的第一个内容(假设每个补全是单个元素的列表),形成一个新的 contents 列表。
  2. 初始化奖励列表

    rewards = []
    
  3. 遍历每个补全和对应的真实答案

    for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)
    
    • 使用 zip 函数将 contentssolution 配对。
    • 对于每一对补全内容和真实答案,首先解析真实答案 sol,使用 parse 函数提取其中的信息。
  4. 处理解析结果

    if len(gold_parsed) != 0:answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)
    
    • 如果解析得到的真实答案 gold_parsed 非空,则继续解析生成的补全内容 content
    • 使用 LatexExtractionConfigNormalizationConfig 进行详细配置,确保解析过程中考虑了各种格式要求(如方程、单位等)。
  5. 计算奖励

    reward = float(verify(answer_parsed, gold_parsed))
    
    • 使用 verify 函数比较生成的补全解析结果和真实答案的解析结果。
    • 如果两者匹配,则返回 1.0,否则返回 0.0
  6. 处理无法解析的情况

    else:reward = 1.0print("Failed to parse gold solution: ", sol)
    
    • 如果真实答案无法解析,则默认给予奖励 1.0 并打印警告信息。
  7. 添加奖励到列表

    rewards.append(reward)
    
  8. 返回所有奖励

    return rewards
    

4.3.2 verify函数

该函数用于验证目标表达式是否与参考表达式匹配,它通过多种比较策略来处理不同的数学对象(如数字、表达式、集合、矩阵等),并提供灵活的配置选项以适应不同的需求。

def verify(gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, float_rounding: int=6,numeric_precision: int=15,strict: bool=True,timeout_seconds: int=3
) -> bool:
  • gold: 参考或正确的表达式,可以是单个 SymPy 表达式(BasicMatrixBase)、字符串或这些类型的列表。
  • target: 需要验证的表达式,类型同 gold
  • float_rounding: 浮点数舍入的小数位数,默认为 6。
  • numeric_precision: 数值比较时考虑的小数位数,默认为 15。
  • strict: 是否启用严格比较模式,默认为 True
    • 在严格模式下:变量很重要,集合不可与元组比较。
    • 在非严格模式下:变量按位置匹配,集合可与元组比较。
  • timeout_seconds: 单次比较操作的最大超时时间(秒),默认为 3 秒。
  1. 定义内部比较函数 compare_single_extraction

    @timeout(timeout_seconds=timeout_seconds)
    def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool:...
    
    • 使用装饰器 @timeout 设置超时保护,默认超时时间为 timeout_seconds
    • 比较两个表达式:
      • 如果两者都是 SymPy 表达式(BasicMatrixBase),则调用 sympy_expr_eq 进行比较。
      • 如果两者都是字符串,则进行简单的字符串比较。
  2. 定义包装函数 compare_single_extraction_wrapper

    def compare_single_extraction_wrapper(g, t):try:return compare_single_extraction(g, t)except Exception as e:logger.exception(f"Error comparing {g} and {t}")return False
    
    • 包装 compare_single_extraction,捕获并记录任何异常,返回 False 以避免程序中断。
  3. 处理输入列表

    if not isinstance(gold, list):gold = [gold]
    if not isinstance(target, list):target = [target]
    
    • 如果 goldtarget 不是列表,则将其转换为单元素列表,以便统一处理。
  4. 组合所有可能的比较

    return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target))
    
    • 使用 itertools.product 生成所有可能的 goldtarget 组合。
    • 对每个组合调用 compare_single_extraction_wrapper,如果任意一对匹配成功,则返回 True

4.3.3 format_reward函数

函数用于检查给定的完成文本是否符合特定的格式,它验证完成文本是否包含 <think><answer> 标签,并且这两个标签的内容是非空的。

def format_reward(completions, **kwargs):"""Reward function that checks if the completion has a specific format."""pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"completion_contents = [completion[0]["content"] for completion in completions]matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]return [1.0 if match else 0.0 for match in matches]
  • completions: 这是一个列表,其中每个元素都是一个包含完成内容的对象(通常是字典)。假设每个完成对象的第一个元素包含一个键 "content",其值是需要检查的文本。
  • kwargs: 其他关键字参数,这里没有使用,但可以为未来的扩展提供灵活性。
  1. 正则表达式模式定义

    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    
    • 这个正则表达式用于匹配字符串是否以 <think> 开始,紧接着是任意字符(非贪婪匹配),然后是 </think>,接着可能有任意数量的空白字符(包括换行符),最后是以 <answer> 开始并以 </answer> 结束。
    • .*? 是非贪婪匹配,确保尽可能少地匹配字符。
    • \s* 匹配零个或多个空白字符(包括换行符)。
    • re.DOTALL | re.MULTILINE 标志允许点号 . 匹配所有字符(包括换行符),并且使多行文本中的每一行都可以独立匹配。
  2. 提取完成内容

    completion_contents = [completion[0]["content"] for completion in completions]
    
    • 这里通过列表推导式从 completions 列表中提取每个完成对象的第一个元素的 "content" 字段,形成一个新的列表 completion_contents
  3. 匹配正则表达式

    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    
    • 使用 re.match 函数对 completion_contents 中的每个内容应用正则表达式模式。
    • matches 列表将包含 re.Match 对象(如果匹配成功)或 None(如果匹配失败)。
  4. 生成奖励分数

    return [1.0 if match else 0.0 for match in matches]
    
    • 最后一步是根据匹配结果生成奖励分数。如果匹配成功(即 match 不是 None),则返回 1.0;否则返回 0.0

示例代码:

completions = [[{"content": "<think>This is reasoning.</think><answer>This is answer.</answer>"}],[{"content": "<think>This is reasoning.</think>"}],[{"content": "<answer>This is answer.</answer>"}],[{"content": "This does not match."}]
]reward_scores = format_reward(completions)
print(reward_scores)  # 输出: [1.0, 0.0, 0.0, 0.0]

在这个例子中:

  • 第一个完成内容完全匹配正则表达式,因此得分为 1.0
  • 后三个完成内容不符合要求,因此得分均为 0.0

4.4 将数据集格式化为对话形式

# Format into conversationdef make_conversation(example):return {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},{"role": "user", "content": example["problem"]},],}dataset = dataset.map(make_conversation)for split in dataset:if "messages" in dataset[split].column_names:dataset[split] = dataset[split].remove_columns("messages")

将一个数据集中的每个示例转换为对话格式,并确保数据集中没有多余的列(如 messages)。

  • 输入example 是一个字典,包含单个数据样本的信息,其中 "problem" 键对应的值是用户的问题或任务描述。
  • 输出:返回一个新的字典,包含一个 "prompt" 键,其值是一个对话列表:
    • 第一条消息是系统消息,内容由 SYSTEM_PROMPT 定义。
    • 第二条消息是用户消息,内容是 example["problem"]
  • dataset.map(make_conversation):使用 map 方法将 make_conversation 函数应用到数据集的每个示例上,生成新的对话格式。
  • 移除多余列:遍历数据集的每个拆分(split),如果存在 "messages" 列,则将其移除。

4.5 初始化GRPO Trainer

trainer = GRPOTrainer(model=model_args.model_name_or_path,reward_funcs=reward_funcs,args=training_args,train_dataset=dataset[script_args.dataset_train_split],eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,peft_config=get_peft_config(model_args),callbacks=get_callbacks(training_args, model_args),)

篇幅有限,训练部分的代码我们放到下一篇博文详细解读!

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词