欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 美景 > 六、分布式嵌入

六、分布式嵌入

2025/4/19 17:56:58 来源:https://blog.csdn.net/Lyg970112/article/details/147217402  浏览:    关键词:六、分布式嵌入

六、分布式嵌入


文章目录

  • 六、分布式嵌入
  • 前言
  • 一、先要配置torch.distributed环境
  • 二、Distributed Embeddings
    • 2.1 EmbeddingBagCollectionSharder
    • 2.2 ShardedEmbeddingBagCollection
  • 三、Planner
  • 总结


前言

  • 我们已经使用了TorchRec的主模块:EmbeddedBagCollection。我们在上一节研究了它是如何工作的,以及数据在TorchRec中是如何表示的。然而,我们还没有探索TorchRec的主要部分之一,即分布式嵌入

一、先要配置torch.distributed环境

  • EmbeddingBagCollectionSharder 依赖于 PyTorch 的分布式通信库(torch.distributed)来管理跨进程/GPU 的分片和通信。

首先初始化分布式环境

import torch.distributed as dist# 初始化进程组
dist.init_process_group(backend="nccl",          # GPU 推荐 NCCL 后端, CPU就是 glooinit_method="env://",    # 从环境变量读取节点信息rank=rank,               # 当前进程的全局唯一标识(从 0 开始)world_size=world_size,   # 总进程数(总 GPU 数)
)pg = dist.GroupMember.WORLD

设置环境变量(多节点训练时必须)

import torch.distributed as dist# 初始化进程组
# 在每个节点上设置以下环境变量
export MASTER_ADDR="主节点IP"   # 如 "192.168.1.1"
export MASTER_PORT="66666"     # 任意未占用端口
export WORLD_SIZE=4            # 总 GPU 数
export RANK=0                  # 当前节点的全局 rank

二、Distributed Embeddings

  • 先回顾一下我们上一节的EmbeddingBagCollection module

代码演示:

print(ebc)
"""
EmbeddingBagCollection((embedding_bags): ModuleDict((product_table): EmbeddingBag(4096, 64, mode='sum')(user_table): EmbeddingBag(4096, 64, mode='sum'))
)
"""

2.1 EmbeddingBagCollectionSharder

  • 策略制定者 ,决定如何分片。
  • 决定如何将 EmbeddingBagCollection 的嵌入表(Embedding Tables)分布到多个 GPU/节点。
    核心功能 :根据配置(如 ShardingType)生成分片计划(Sharding Plan

代码演示:

from torchrec.distributed.embedding_types import ShardingType
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder# 定义分片器:指定分片策略(如按表分片)
sharder = EmbeddingBagCollectionSharder(sharding_type=ShardingType.TABLE_WISE.value,  # 每个表分配到一个 GPUkernel_type=EmbeddingComputeKernel.FUSED.value,  # 使用 fused 优化
)
  • 关键参数
    • sharding_type:分片策略,如:
      • TABLE_WISE:整个表放在一个 GPU。
      • ROW_WISE:按行分片到多个 GPU。
      • COLUMN_WISE:按列分片(适用于超大表)。
    • kernel_type:计算内核类型(如 FUSED 优化显存)

2.2 ShardedEmbeddingBagCollection

  • 策略执行者 ,实际管理分片后的嵌入表
  • 根据 EmbeddingBagCollectionSharder 生成的分片计划,实际管理分布在多设备上的嵌入表。
  • 核心功能 :在分布式环境中执行前向传播、梯度聚合和参数更新

代码演示:

from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection# 根据分片器生成分片后的模块
sharded_ebc = ShardedEmbeddingBagCollection(module=ebc,        # 原始 EmbeddingBagCollectionsharder=sharder,   # 分片策略device=device,     # 目标设备(如 GPU:0)
)

三、Planner

  • 它可以帮助我们确定最佳的分片配置。
  • Planner能够根据嵌入表的数量和GPU的数量来确定最佳配置。事实证明,这很难手动完成,工程师必须考虑大量因素来确保最佳的分片计划。
  • TorchRec在提供的这个Planner,可以帮助我们:
    • 评估硬件的内存限制
    • 将基于存储器获取的计算估计为嵌入查找
    • 解决数据特定因素
    • 考虑其他硬件细节,如带宽,以生成最佳分片计划

演示代码:

from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology# 初始化Planner
planner = EmbeddingShardingPlanner(topology=Topology(  # 硬件拓扑信息world_size=4,  # 总 GPU 数compute_device="cuda",local_world_size=2,  # 单机 GPU 数batch_size=1024,  ),constraints={  # 可选约束(如强制某些表使用特定策略)"user_id": ParameterConstraints(sharding_types=[ShardingType.TABLE_WISE]),},
)# 生成分片计划
plan = planner.collective_plan(ebc, [sharder], pg)# 分片后的模型
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollectionsharded_ebc = ShardedEmbeddingBagCollection(module=ebc,sharder=sharder,device=torch.device("cuda:0"),plan=plan,  # 应用自动生成的分片计划
)

总结

  • TorchRec中的分布式嵌入以及训练设置。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词