六、分布式嵌入
文章目录
- 六、分布式嵌入
- 前言
- 一、先要配置torch.distributed环境
- 二、Distributed Embeddings
- 2.1 EmbeddingBagCollectionSharder
- 2.2 ShardedEmbeddingBagCollection
- 三、Planner
- 总结
前言
- 我们已经使用了TorchRec的主模块:EmbeddedBagCollection。我们在上一节研究了它是如何工作的,以及数据在TorchRec中是如何表示的。然而,我们还没有探索TorchRec的主要部分之一,即分布式嵌入
一、先要配置torch.distributed环境
- EmbeddingBagCollectionSharder 依赖于 PyTorch 的分布式通信库(torch.distributed)来管理跨进程/GPU 的分片和通信。
首先初始化分布式环境
import torch.distributed as dist# 初始化进程组
dist.init_process_group(backend="nccl", # GPU 推荐 NCCL 后端, CPU就是 glooinit_method="env://", # 从环境变量读取节点信息rank=rank, # 当前进程的全局唯一标识(从 0 开始)world_size=world_size, # 总进程数(总 GPU 数)
)pg = dist.GroupMember.WORLD
设置环境变量(多节点训练时必须)
import torch.distributed as dist# 初始化进程组
# 在每个节点上设置以下环境变量
export MASTER_ADDR="主节点IP" # 如 "192.168.1.1"
export MASTER_PORT="66666" # 任意未占用端口
export WORLD_SIZE=4 # 总 GPU 数
export RANK=0 # 当前节点的全局 rank
二、Distributed Embeddings
- 先回顾一下我们上一节的EmbeddingBagCollection module
代码演示:
print(ebc)
"""
EmbeddingBagCollection((embedding_bags): ModuleDict((product_table): EmbeddingBag(4096, 64, mode='sum')(user_table): EmbeddingBag(4096, 64, mode='sum'))
)
"""
2.1 EmbeddingBagCollectionSharder
- 策略制定者 ,决定如何分片。
- 决定如何将 EmbeddingBagCollection 的嵌入表(Embedding Tables)分布到多个 GPU/节点。
核心功能 :根据配置(如 ShardingType)生成分片计划(Sharding Plan)
代码演示:
from torchrec.distributed.embedding_types import ShardingType
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder# 定义分片器:指定分片策略(如按表分片)
sharder = EmbeddingBagCollectionSharder(sharding_type=ShardingType.TABLE_WISE.value, # 每个表分配到一个 GPUkernel_type=EmbeddingComputeKernel.FUSED.value, # 使用 fused 优化
)
- 关键参数
- sharding_type:分片策略,如:
- TABLE_WISE:整个表放在一个 GPU。
- ROW_WISE:按行分片到多个 GPU。
- COLUMN_WISE:按列分片(适用于超大表)。
- kernel_type:计算内核类型(如 FUSED 优化显存)
- sharding_type:分片策略,如:
2.2 ShardedEmbeddingBagCollection
- 策略执行者 ,实际管理分片后的嵌入表
- 根据 EmbeddingBagCollectionSharder 生成的分片计划,实际管理分布在多设备上的嵌入表。
- 核心功能 :在分布式环境中执行前向传播、梯度聚合和参数更新
代码演示:
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection# 根据分片器生成分片后的模块
sharded_ebc = ShardedEmbeddingBagCollection(module=ebc, # 原始 EmbeddingBagCollectionsharder=sharder, # 分片策略device=device, # 目标设备(如 GPU:0)
)
三、Planner
- 它可以帮助我们确定最佳的分片配置。
- Planner能够根据嵌入表的数量和GPU的数量来确定最佳配置。事实证明,这很难手动完成,工程师必须考虑大量因素来确保最佳的分片计划。
- TorchRec在提供的这个Planner,可以帮助我们:
- 评估硬件的内存限制
- 将基于存储器获取的计算估计为嵌入查找
- 解决数据特定因素
- 考虑其他硬件细节,如带宽,以生成最佳分片计划
演示代码:
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology# 初始化Planner
planner = EmbeddingShardingPlanner(topology=Topology( # 硬件拓扑信息world_size=4, # 总 GPU 数compute_device="cuda",local_world_size=2, # 单机 GPU 数batch_size=1024, ),constraints={ # 可选约束(如强制某些表使用特定策略)"user_id": ParameterConstraints(sharding_types=[ShardingType.TABLE_WISE]),},
)# 生成分片计划
plan = planner.collective_plan(ebc, [sharder], pg)# 分片后的模型
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollectionsharded_ebc = ShardedEmbeddingBagCollection(module=ebc,sharder=sharder,device=torch.device("cuda:0"),plan=plan, # 应用自动生成的分片计划
)
总结
- TorchRec中的分布式嵌入以及训练设置。