
PointNeXt Source Code Reading (I): Registration Mechanism and Argument Parsing

2024/10/23 21:24:32  Source: https://blog.csdn.net/woyaomaishu2/article/details/140877037



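Contents

  • Preface
  • I. Registration Mechanism
    • 1. The Registry Class
    • 2. Registering Classes
    • 3. Using the Registry
  • II. Argument Parsing
    • 1. Command-Line Parsing
    • 2. Loading and Updating Parameters
    • 3. The Resulting Parameters
  • III. Summary
    • 1. Result
    • 2. Todo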


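Preface

I have been reading part of the PointNeXt source code and am writing down some notes for future reference.

This post has two parts, the registration mechanism and argument parsing. The registration mechanism is the main thing to understand.

All comments and debug output below are based on the following test session.

CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py --cfg cfgs/s3dis/pointnext-s.yaml mode=train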

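I. Registration Mechanism

The registration mechanism is how PointNeXt registers its modules/classes so that a string can be mapped to a module/class. In other words, a parameter string read from a configuration file can be mapped directly to the corresponding module/class, from which an instance is then created. For this part the PointNeXt authors built on the registry mechanism in mmcv.

1. The Registry Class

The mechanism itself is implemented by class Registry. Its key methods and the related build function are: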

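Method: Explanation
  • __init__(): Initializes the object, including the dictionary of registered modules, self._module_dict = dict().
  • get(self, key): Maps a string to a class. The string key is resolved to the class registered in self._module_dict, i.e. self._module_dict[real_key].
  • register_module(self, name=None, force=False, module=None): Registers a module/class. It can be called as a normal method or used as a decorator on a module/class.
  • _register(cls): The wrapper function inside the decorator register_module. It is used in the decorator case, where cls is the class being decorated. It calls _register_module(self, module_class, module_name=None, force=False) to register the class (self._module_dict[name] = module_class) and then returns cls unchanged, with no other processing.
  • build_from_cfg(cfg, registry, default_args=None): Builds a module/class instance from a config dict, i.e. turns a string into an instance. Note that this is a module-level function in registry.py, used as the default build_func, not a method of Registry.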
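To make the roles of these methods concrete, here is a minimal sketch of the same idea. It is a toy written for this note, not the openpoints code: MiniRegistry, ToyEncoder and ToyDecoder are hypothetical names, and scopes, parent/child registries and the deprecated API are all left out.

```python
class MiniRegistry:
    """Toy stand-in for the Registry class described above (hypothetical)."""

    def __init__(self, name):
        self._name = name
        self._module_dict = {}

    def get(self, key):
        # String -> class lookup; returns None when the key is unknown.
        return self._module_dict.get(key)

    def register_module(self, name=None, force=False, module=None):
        def _register(cls):
            key = name or cls.__name__
            if not force and key in self._module_dict:
                raise KeyError(f"{key} is already registered in {self._name}")
            self._module_dict[key] = cls
            return cls  # the class is handed back unchanged

        if module is not None:
            # Used as a normal method: registry.register_module(module=SomeClass)
            return _register(module)
        # Used as a decorator: @registry.register_module()
        return _register


MODELS = MiniRegistry("models")


@MODELS.register_module()
class ToyEncoder:
    pass


class ToyDecoder:
    pass


# Plain function call with an explicit registration name.
MODELS.register_module(name="decoder", module=ToyDecoder)
```

After this runs, MODELS.get("ToyEncoder") returns the ToyEncoder class and MODELS.get("decoder") returns ToyDecoder. Registering the same name twice raises KeyError unless force=True is passed.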

The Registry class is defined in openpoints/utils/registry.py. The listing below adds my comments.

# Acknowledgement: built upon mmcv
import inspect
import warnings
from functools import partial
import copy


class Registry:
    """A registry to map strings to classes.

    Registered object could be built from registry.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(NAME='ResNet'))

    Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given,  ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope
        # self._scope = 'openpoints'

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            scope (str): The inferred scope name.
        """
        # inspect.stack() trace where this function is called, the index-2
        # indicates the frame where `infer_scope()` is called
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        # filename = 'openpoints.models.build'
        split_filename = filename.split('.')  # ['openpoints', 'models', 'build']
        return split_filename[0]  # 'openpoints'

    @staticmethod  # declares a static method
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        # find() returns -1 if key contains no '.', otherwise the index of the first '.'
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        # Maps a string to a class:
        # the string key is resolved to the registered class self._module_dict[real_key]
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class.
        """
        scope, real_key = self.split_scope_key(key)  # key = BaseSeg; scope = None; real_key = BaseSeg
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # goto root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(NAME='mmdet.ResNet'))
        """
        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):  # decorator
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a workaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        # NOTE: `misc` is not imported in this listing; the last test is only
        # reached when name is neither None nor a str.
        if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        # (a plain call to register_module, not the decorator case)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        # This is the decorator's wrapper function.
        # In the decorator case, cls is the class being decorated.
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register  # the decorator returns this wrapper


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (edict): Config dict. It should at least contain the key "NAME".
        registry (:obj:`Registry`): The registry to search the type from.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'NAME' not in cfg:
        if default_args is None or 'NAME' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "NAME", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # if default_args is not None:
    #     cfg = config.merge_new_config(cfg, default_args)

    obj_type = cfg.get('NAME')  # 'BaseSeg'
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)  # <class 'openpoints.models.segmentation.base_seg.BaseSeg'>
        # Look up the class/module in self._module_dict by its name string,
        # i.e. map the string to a class.
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        obj_cfg = copy.deepcopy(cfg)
        if default_args is not None:
            obj_cfg.update(default_args)
        obj_cfg.pop('NAME')  # drop "NAME"; obj_cfg keeps every other entry
        # 'NAME' has done its job of selecting the class obj_cls
        return obj_cls(**obj_cfg)
        # With the variables expanded this is
        # openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg).
        # The decorator @MODELS.register_module() on BaseSeg ran once, when the
        # class was defined at import time, and returned the class unchanged.
        # Nothing is registered here; this call only constructs a BaseSeg instance.
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')

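2. Registering Classes

First, openpoints/models/build.py declares and defines the global Registry object MODELS, referred to below as the registry instance.

Because of Python's import mechanism, MODELS is created as soon as openpoints/models/build.py is first imported, which happens during program start-up, before __main__/main() runs.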

from openpoints.utils import registry
MODELS = registry.Registry('models')
# Create a registry.Registry object (MODELS, the registry instance) as a global variable.
# This runs at import time, before __main__/main(), so MODELS exists from the very start of the program.


def build_model_from_cfg(cfg, **kwargs):
    """
    Build a model, defined by `NAME`.
    Args:
        cfg (eDICT):
    Returns:
        Model: a constructed model specified by NAME.
    """
    return MODELS.build(cfg, **kwargs)

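Also because of the import mechanism, once MODELS exists, every class under openpoints/models whose definition is decorated with @MODELS.register_module() is registered in MODELS as soon as the import machinery executes its module.

For example, the BaseSeg class below is registered in MODELS._module_dict this way.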

"""
Author: PointNeXt
"""
import copy
from typing import List
import torch
import torch.nn as nn
import logging
from ...utils import get_missing_parameters_message, get_unexpected_parameters_message
from ..build import MODELS, build_model_from_cfg
from ..layers import create_linearblock, create_convblock1d


# The class BaseSeg is decorated with MODELS.register_module.
# The decorator is equivalent to BaseSeg = MODELS.register_module()(BaseSeg)
# and runs once, when the class definition is executed, not when BaseSeg() is instantiated.
# This happens at import time: before __main__/main(), but after MODELS has been created.
# So registration is complete by the time main() is called, and the registry can turn strings into classes.
@MODELS.register_module()
class BaseSeg(nn.Module):
    def __init__(self,
                 encoder_args=None,
                 decoder_args=None,
                 cls_args=None,
                 **kwargs):
        super().__init__()
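The timing is easy to get wrong, so here is a small sketch that records when registration happens. Everything in it (registered, events, ToySeg, ToyCls, this simplified register_module) is made up for the illustration and only mirrors the decorator pattern above.

```python
registered = {}   # toy stand-in for MODELS._module_dict
events = []       # records the order in which things happen


def register_module():
    def _register(cls):
        events.append(f"register {cls.__name__}")
        registered[cls.__name__] = cls
        return cls  # the class itself is returned unchanged
    return _register


@register_module()
class ToySeg:
    def __init__(self):
        events.append("construct ToySeg")


# The decorator syntax above is shorthand for this explicit form:
class ToyCls:
    pass


ToyCls = register_module()(ToyCls)

# Snapshot taken before any instance has been created.
events_before = list(events)

# Creating instances does not register anything again.
ToySeg()
ToySeg()
```

Both classes are already in registered before any instance exists, and constructing ToySeg twice adds only "construct ToySeg" entries to events.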

While debugging, inspecting MODELS._module_dict shows that many classes have already been registered.

MODELS._module_dict = {
'PointNetEncoder': <class 'openpoints.models.backbone.pointnet.PointNetEncoder'>, 
'PointPatchEmbed': <class 'openpoints.models.layers.group_embed.PointPatchEmbed'>, 
'P3Embed': <class 'openpoints.models.layers.group_embed.P3Embed'>,
'PointNet2Encoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Encoder'>, 
'PointNet2Decoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Decoder'>, 
'PointNet2PartDecoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2PartDecoder'>, 
'PointNextEncoder': <class 'openpoints.models.backbone.pointnext.PointNextEncoder'>, 
'PointNextDecoder': <class 'openpoints.models.backbone.pointnext.PointNextDecoder'>, 
'PointNextPartDecoder': <class 'openpoints.models.backbone.pointnext.PointNextPartDecoder'>, 
'DGCNN': <class 'openpoints.models.backbone.dgcnn.DGCNN'>, 
'DeepGCN': <class 'openpoints.models.backbone.deepgcn.DeepGCN'>, 
'PointMLPEncoder': <class 'openpoints.models.backbone.pointmlp.PointMLPEncoder'>, 
'PointMLP': <class 'openpoints.models.backbone.pointmlp.PointMLP'>, 
'PointViT': <class 'openpoints.models.backbone.pointvit.PointViT'>, 
'PointViTDecoder': <class 'openpoints.models.backbone.pointvit.PointViTDecoder'>, 
'PointViTPartDecoder': <class 'openpoints.models.backbone.pointvit.PointViTPartDecoder'>, 
'InvPointViT': <class 'openpoints.models.backbone.pointvit_inv.InvPointViT'>, 
'InvPointViTDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTDecoder'>, 
'InvPointViTPartDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTPartDecoder'>, 
'CurveNet': <class 'openpoints.models.backbone.curvenet.CurveNet'>, 
'MVFC': <class 'openpoints.models.backbone.simpleview.MVFC'>, 
'MVModel': <class 'openpoints.models.backbone.simpleview.MVModel'>, 
'BaseSeg': <class 'openpoints.models.segmentation.base_seg.BaseSeg'>, 
'BasePartSeg': <class 'openpoints.models.segmentation.base_seg.BasePartSeg'>, 
'VariableSeg': <class 'openpoints.models.segmentation.base_seg.VariableSeg'>, 
'SegHead': <class 'openpoints.models.segmentation.base_seg.SegHead'>, 
'VariableSegHead': <class 'openpoints.models.segmentation.base_seg.VariableSegHead'>, 
'MultiSegHead': <class 'openpoints.models.segmentation.base_seg.MultiSegHead'>, 
'BaseCls': <class 'openpoints.models.classification.cls_base.BaseCls'>, 
'DistillCls': <class 'openpoints.models.classification.cls_base.DistillCls'>, 
'ClsHead': <class 'openpoints.models.classification.cls_base.ClsHead'>, 
'MaskedTransformerDecoder': <class 'openpoints.models.reconstruction.base_recontruct.MaskedTransformerDecoder'>, 
'FoldingNet': <class 'openpoints.models.reconstruction.base_recontruct.FoldingNet'>, 
'NodeShuffle': <class 'openpoints.models.reconstruction.base_recontruct.NodeShuffle'>, 
'MaskedPointViT': <class 'openpoints.models.reconstruction.maskedpointvit.MaskedPointViT'>, 
'MaskedPoint': <class 'openpoints.models.reconstruction.maskedpoint.MaskedPoint'>, 
'MaskedPointGroup': <class 'openpoints.models.reconstruction.maskedpointgroup.MaskedPointGroup'>
}

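3. Using the Registry

With the registry instance MODELS in place and the classes registered in it, its string-to-class mapping can be used to create class instances conveniently from a .yaml configuration.

The rough call sequence is as follows (participants: examples/segmentation/main(), openpoints/models/build.py, and class Registry in openpoints/utils/registry.py):

  • openpoints/models/build.py creates the instance registry.Registry('models'); Registry.__init__() sets self.build_func = build_from_cfg. The result is the global object MODELS (the registry instance).
  • main() calls build_model_from_cfg(cfg.model).
  • build_model_from_cfg calls MODELS.build(cfg, **kwargs).
  • Registry.build(self, *args, **kwargs) calls build_from_cfg(cfg, registry, default_args=None).
  • build_from_cfg returns obj_cls(**obj_cfg), which here is BaseSeg(**obj_cfg).
  • main() receives the result as model.

Fig 1. Sequence for creating a class object (the deep neural network model) through the registry

The config dict cfg read from the .yaml file contains a NAME entry, and registry.get(*) returns the registered class for that NAME string. Once the class has been found, the NAME entry has served its purpose. The remaining entries are used to configure and construct the concrete deep neural network modules/classes in PointNeXt automatically (not covered in this post).

For the detailed comments, see the function build_from_cfg in registry.py (a module-level function, not a method of Registry), repeated here: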

def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (edict): Config dict. It should at least contain the key "NAME".
        registry (:obj:`Registry`): The registry to search the type from.

    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'NAME' not in cfg:
        if default_args is None or 'NAME' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "NAME", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # if default_args is not None:
    #     cfg = config.merge_new_config(cfg, default_args)

    obj_type = cfg.get('NAME')  # 'BaseSeg'
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)  # <class 'openpoints.models.segmentation.base_seg.BaseSeg'>
        # Look up the class/module in self._module_dict by its name string,
        # i.e. map the string to a class.
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        obj_cfg = copy.deepcopy(cfg)
        if default_args is not None:
            obj_cfg.update(default_args)
        obj_cfg.pop('NAME')  # drop "NAME"; obj_cfg keeps every other entry
        # 'NAME' has done its job of selecting the class obj_cls
        return obj_cls(**obj_cfg)
        # With the variables expanded this is
        # openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg).
        # The decorator @MODELS.register_module() on BaseSeg ran once, when the
        # class was defined at import time, and returned the class unchanged.
        # Nothing is registered here; this call only constructs a BaseSeg instance.
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
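Stripped of its type checks, the function does three things: look up the class by NAME, drop NAME from a copy of the config, and pass the rest as keyword arguments. A minimal sketch under simplifying assumptions follows. TOY_REGISTRY is a plain dict rather than a Registry, and register, ToyHead and toy_build_from_cfg are hypothetical names.

```python
import copy

TOY_REGISTRY = {}  # plain dict standing in for a Registry's _module_dict


def register(cls):
    TOY_REGISTRY[cls.__name__] = cls
    return cls


@register
class ToyHead:
    def __init__(self, num_classes, in_channels=None):
        self.num_classes = num_classes
        self.in_channels = in_channels


def toy_build_from_cfg(cfg, registry):
    if "NAME" not in cfg:
        raise KeyError('cfg must contain the key "NAME"')
    obj_cfg = copy.deepcopy(cfg)      # leave the caller's dict untouched
    name = obj_cfg.pop("NAME")        # NAME only selects the class
    obj_cls = registry.get(name)
    if obj_cls is None:
        raise KeyError(f"{name} is not registered")
    return obj_cls(**obj_cfg)         # remaining entries become keyword arguments


cfg = {"NAME": "ToyHead", "num_classes": 13, "in_channels": 32}
head = toy_build_from_cfg(cfg, TOY_REGISTRY)
```

Because the config is deep-copied before NAME is popped, the original cfg still contains its NAME entry afterwards, as with the copy.deepcopy call in the real function.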

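II. Argument Parsing

The registration mechanism needs string parameters in order to build class instances. Those parameters come from a parsing step that reads the configuration in the .yaml files into the program.

1. Command-Line Parsing

The main program first reads the relevant .yaml files and merges them into the dict variable cfg. The listing below adds my comments.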

if __name__ == "__main__":
    parser = argparse.ArgumentParser('Scene segmentation training/testing')
    # create the parser
    parser.add_argument('--cfg', type=str, required=True, help='config file')
    parser.add_argument('--profile', action='store_true', default=False, help='set to True to profile speed')
    # add the arguments
    args, opts = parser.parse_known_args()
    # CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py --cfg cfgs/s3dis/pointnext-s.yaml  mode=train
    # CUDA_VISIBLE_DEVICES=0,1 is an environment variable and is not parsed by parser
    # args = Namespace(cfg='cfgs/s3dis/pointnext-s.yaml', profile=False)
    # opts = ['mode=train']
    cfg = EasyConfig()
    cfg.load(args.cfg, recursive=True)  # args.cfg = cfgs/s3dis/pointnext-s.yaml
    cfg.update(opts)  # overwrite the default arguments in yml
    # mode=train is merged into the cfg dict
    if cfg.seed is None:
        cfg.seed = np.random.randint(1, 10000)

    # init distributed env first, since logger depends on the dist info.
    cfg.rank, cfg.world_size, cfg.distributed, cfg.mp = dist_utils.get_dist_info(cfg)
    cfg.sync_bn = cfg.world_size > 1  # only one GPU when debugging; several in parallel in a normal run

    # init log dir
    cfg.task_name = args.cfg.split('.')[-2].split('/')[-2]  # task/dataset name, e.g. s3dis, modelnet40_cls
    # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
    # args.cfg.split('.')[-2] = 'cfgs/s3dis/pointnext-s'
    # args.cfg.split('.')[-2].split('/')[-2] = 's3dis'
    cfg.cfg_basename = args.cfg.split('.')[-2].split('/')[-1]  # cfg_basename, e.g. pointnext-xl
    # args.cfg.split('.')[-2].split('/')[-1] = 'pointnext-s'
    tags = [
        cfg.task_name,  # task name (the folder of name under ./cfgs
        cfg.mode,
        cfg.cfg_basename,  # cfg file name
        f'ngpus{cfg.world_size}',
    ]
    # tags = ['s3dis', 'train', 'pointnext-s', 'ngpus1']
    opt_list = []  # for checking experiment configs from logging file
    for i, opt in enumerate(opts):
        if 'rank' not in opt and 'dir' not in opt and 'root' not in opt and 'pretrain' not in opt and 'path' not in opt and 'wandb' not in opt and '/' not in opt:
            opt_list.append(opt)
    cfg.root_dir = os.path.join(cfg.root_dir, cfg.task_name)
    cfg.opts = '-'.join(opt_list)  # join with '-' as the separator

    cfg.is_training = cfg.mode not in ['test', 'testing', 'val', 'eval', 'evaluation']
    if cfg.mode in ['resume', 'val', 'test']:
        resume_exp_directory(cfg, pretrained_path=cfg.pretrained_path)  # needs pretrained_path=XXX on the command line
        cfg.wandb.tags = [cfg.mode]
    else:
        generate_exp_directory(cfg, tags, additional_id=os.environ.get('MASTER_PORT', None))
        cfg.wandb.tags = tags
    os.environ["JOB_LOG_DIR"] = cfg.log_dir
    cfg_path = os.path.join(cfg.run_dir, "cfg.yaml")
    # cfg_path = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC/cfg.yaml'
    with open(cfg_path, 'w') as f:
        yaml.dump(cfg, f, indent=2)  # write cfg to the file f
        os.system('cp %s %s' % (args.cfg, cfg.run_dir))
        # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
        # cfg.run_dir = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'
    cfg.cfg_path = cfg_path

    # wandb config
    cfg.wandb.name = cfg.run_name
    # cfg.run_name = 's3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'

    # multi processing.
    if cfg.mp:
        port = find_free_port()
        cfg.dist_url = f"tcp://localhost:{port}"
        print('using mp spawn for distributed training')
        mp.spawn(main, nprocs=cfg.world_size, args=(cfg,))
    else:
        main(0, cfg)
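The split between args and opts can be reproduced in isolation. The sketch below passes the command-line tokens of the test session to parse_known_args as an explicit list (so it does not depend on sys.argv) and repeats the string splitting used for task_name and cfg_basename. The variable names argv and stem are mine.

```python
import argparse

parser = argparse.ArgumentParser("Scene segmentation training/testing")
parser.add_argument("--cfg", type=str, required=True, help="config file")
parser.add_argument("--profile", action="store_true", default=False,
                    help="set to True to profile speed")

# The tokens after `python examples/segmentation/main.py` in the test session.
argv = ["--cfg", "cfgs/s3dis/pointnext-s.yaml", "mode=train"]

# Known options go into args; everything the parser does not know is returned in opts.
args, opts = parser.parse_known_args(argv)

# Same string manipulation as in the main program.
stem = args.cfg.split(".")[-2]         # 'cfgs/s3dis/pointnext-s'
task_name = stem.split("/")[-2]        # 's3dis'
cfg_basename = stem.split("/")[-1]     # 'pointnext-s'
```

Here opts ends up as ['mode=train'], which is what cfg.update(opts) later merges into the configuration.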

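2. Loading and Updating Parameters

Reading and updating the configuration entries is implemented in class EasyConfig. The listing below is partly annotated.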

class EasyConfig(dict):
    def __getattr__(self, key: str) -> Any:
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value

    def __delattr__(self, key: str) -> None:
        del self[key]

    def load(self, fpath: str, *, recursive: bool = False) -> None:
        """load cfg from yaml

        Args:
            fpath (str): path to the yaml file
            recursive (bool, optional): recursively load its parent default yaml files. Defaults to False.
        """
        if not os.path.exists(fpath):
            raise FileNotFoundError(fpath)
        fpaths = [fpath]
        # 'cfgs/s3dis/pointnext-s.yaml'
        if recursive:  # True
            extension = os.path.splitext(fpath)[1]  # .yaml
            while os.path.dirname(fpath) != fpath:  # loop until dirname() no longer changes the path
                fpath = os.path.dirname(fpath)  # strip the last component, one level per iteration
                fpaths.append(os.path.join(fpath, 'default' + extension))
            # fpaths = ['cfgs/s3dis/pointnext-s.yaml',
            #           'cfgs/s3dis/default.yaml',
            #           'cfgs/default.yaml',
            #           'default.yaml']
        for fpath in reversed(fpaths):  # iterate in reverse, most general file first
            if os.path.exists(fpath):
                with open(fpath) as f:
                    self.update(yaml.safe_load(f))
                    # merge the entries of every .yaml file in fpaths into one dict

    def reload(self, fpath: str, *, recursive: bool = False) -> None:
        self.clear()
        self.load(fpath, recursive=recursive)

    # mutimethod makes python supports function overloading
    @multimethod
    def update(self, other: Dict) -> None:  # turn .yaml items into key:value pairs of the dict
        for key, value in other.items():
            if isinstance(value, dict):
                if key not in self or not isinstance(self[key], EasyConfig):  # nested entry
                    self[key] = EasyConfig()
                # recursively update
                self[key].update(value)
            else:
                self[key] = value

    @multimethod
    def update(self, opts: Union[List, Tuple]) -> None:
        index = 0
        while index < len(opts):
            opt = opts[index]
            if opt.startswith('--'):
                opt = opt[2:]
            if '=' in opt:
                key, value = opt.split('=', 1)
                index += 1
            else:
                key, value = opt, opts[index + 1]
                index += 2
            current = self
            subkeys = key.split('.')
            try:
                value = literal_eval(value)
            except:
                pass
            for subkey in subkeys[:-1]:
                current = current.setdefault(subkey, EasyConfig())
            current[subkeys[-1]] = value

    def dict(self) -> Dict[str, Any]:
        configs = dict()
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                value = value.dict()
            configs[key] = value
        return configs

    def hash(self) -> str:
        buffer = json.dumps(self.dict(), sort_keys=True)
        return hashlib.sha256(buffer.encode()).hexdigest()

    def __str__(self) -> str:
        texts = []
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                seperator = '\n'
            else:
                seperator = ' '
            text = key + ':' + seperator + str(value)
            lines = text.split('\n')
            for k, line in enumerate(lines[1:]):
                lines[k + 1] = (' ' * 2) + line
            texts.extend(lines)
        return '\n'.join(texts)
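Two pieces of this class are worth trying in isolation: the search for default files in parent directories, and the key=value overrides with dotted keys. The sketch below re-implements both on plain dicts. candidate_paths and apply_overrides are hypothetical helpers written for this note, posixpath is used instead of os.path so that the separators are the same on every platform, and only the key=value form of an override is handled.

```python
import posixpath
from ast import literal_eval


def candidate_paths(fpath):
    """List the config file plus a default file for each parent directory."""
    fpaths = [fpath]
    extension = posixpath.splitext(fpath)[1]
    while posixpath.dirname(fpath) != fpath:
        fpath = posixpath.dirname(fpath)
        fpaths.append(posixpath.join(fpath, "default" + extension))
    return fpaths


def apply_overrides(cfg, opts):
    """Apply 'a.b.c=value' style overrides to a nested dict in place."""
    for opt in opts:
        key, value = opt.split("=", 1)
        try:
            value = literal_eval(value)   # '64' -> 64, '0.01' -> 0.01
        except (ValueError, SyntaxError):
            pass                          # not a Python literal: keep the string
        current = cfg
        subkeys = key.split(".")
        for subkey in subkeys[:-1]:
            current = current.setdefault(subkey, {})
        current[subkeys[-1]] = value
    return cfg


paths = candidate_paths("cfgs/s3dis/pointnext-s.yaml")
cfg = apply_overrides(
    {"mode": "test", "model": {"encoder_args": {"width": 32}}},
    ["mode=train", "model.encoder_args.width=64", "lr=0.01"],
)
```

Since load() walks the candidate list in reverse, the most general default file is merged first and the file named on the command line last, so more specific settings overwrite more general ones. The command-line overrides are applied after that.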

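3. The Resulting Parameters

All entries in the .yaml files listed in fpaths = ['cfgs/s3dis/pointnext-s.yaml', 'cfgs/s3dis/default.yaml', 'cfgs/default.yaml', 'default.yaml'] are parsed and written into the dict variable cfg. Only files that exist are read, and default.yaml does not exist here.
The dict cfg obtained after parsing is shown below. Its cfg.model part is what will be used to configure and build the network model (a class instance) automatically.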

dist_url: tcp://localhost:8888
dist_backend: nccl
multiprocessing_distributed: False
ngpus_per_node: 1
world_size: 1
launcher: mp
local_rank: 0
use_gpu: True
seed: 3392
epoch: 0
epochs: 100
ignore_index: None
val_fn: validate
deterministic: False
sync_bn: False
criterion_args:
  NAME: CrossEntropy
  label_smoothing: 0.2
use_mask: False
grad_norm_clip: 10
layer_decay: 0
step_per_update: 1
start_epoch: 1
sched_on_epoch: True
wandb:
  use_wandb: False
  project: PointNeXt-S3DIS
  tags: ['s3dis', 'train', 'pointnext-s', 'ngpus1']
  name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
use_amp: False
use_voting: False
val_freq: 1
resume: False
test: False
finetune: False
mode: train
logname: None
load_path: None
print_freq: 50
save_freq: -1
root_dir: log/s3dis
pretrained_path: None
datatransforms:
  train: ['ChromaticAutoContrast', 'PointsToTensor', 'PointCloudScaling', 'PointCloudXYZAlign', 'PointCloudJitter', 'ChromaticDropGPU', 'ChromaticNormalize']
  val: ['PointsToTensor', 'PointCloudXYZAlign', 'ChromaticNormalize']
  vote: ['ChromaticDropGPU']
  kwargs:
    color_drop: 0.2
    gravity_dim: 2
    scale: [0.9, 1.1]
    angle: [0, 0, 1]
    jitter_sigma: 0.005
    jitter_clip: 0.02
feature_keys: x,heights
dataset:
  common:
    NAME: S3DIS
    data_root: data/S3DIS/s3disfull
    test_area: 5
    voxel_size: 0.04
  train:
    split: train
    voxel_max: 24000
    loop: 30
    presample: False
  val:
    split: val
    voxel_max: None
    presample: True
  test:
    split: test
    voxel_max: None
    presample: False
num_classes: 13
batch_size: 32
val_batch_size: 1
dataloader:
  num_workers: 6
cls_weighed_loss: False
optimizer:
  NAME: adamw
  weight_decay: 0.0001
sched: cosine
warmup_epochs: 0
min_lr: 1e-05
lr: 0.01
log_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
model:
  NAME: BaseSeg
  encoder_args:
    NAME: PointNextEncoder
    blocks: [1, 1, 1, 1, 1]
    strides: [1, 4, 4, 4, 4]
    sa_layers: 2
    sa_use_res: True
    width: 32
    in_channels: 4
    expansion: 4
    radius: 0.1
    nsample: 32
    aggr_args:
      feature_type: dp_fj
      reduction: max
    group_args:
      NAME: ballquery
      normalize_dp: True
    conv_args:
      order: conv-norm-act
    act_args:
      act: relu
    norm_args:
      norm: bn
  decoder_args:
    NAME: PointNextDecoder
  cls_args:
    NAME: SegHead
    num_classes: 13
    in_channels: None
    norm_args:
      norm: bn
  in_channels: 4
rank: 0
distributed: False
mp: False
task_name: s3dis
cfg_basename: pointnext-s
opts: mode=train
is_training: True
run_name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
run_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
exp_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
ckpt_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/checkpoint
log_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu.log
cfg_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/cfg.yaml

III. Summary

1. Result

The part of main() in examples/segmentation/main.py that builds the deep network model (a class instance):

    if cfg.model.get('in_channels', None) is None:
        cfg.model.in_channels = cfg.model.encoder_args.in_channels  # 4
    model = build_model_from_cfg(cfg.model).to(cfg.rank)
    model_size = cal_model_parm_nums(model)
    logging.info(model)
    logging.info('Number of params: %.4f M' % (model_size / 1e6))
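The first two lines rely on the attribute-style access that EasyConfig adds on top of dict. A minimal sketch of that behaviour is shown here, using a hypothetical AttrDict that copies only the __getattr__ and __setattr__ methods from the EasyConfig listing.

```python
class AttrDict(dict):
    """Dict whose keys can also be read and written as attributes (toy version)."""

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so dict methods
        # such as .get() keep working as usual.
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key, value):
        self[key] = value


model_cfg = AttrDict(
    NAME="BaseSeg",
    encoder_args=AttrDict(NAME="PointNextEncoder", in_channels=4),
)

# Same fallback as in main(): inherit in_channels from the encoder if unset.
if model_cfg.get("in_channels", None) is None:
    model_cfg.in_channels = model_cfg.encoder_args.in_channels
```

Attribute assignment writes into the dict itself, which is why in_channels: 4 shows up as an entry of model in the cfg dump.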

The call build_model_from_cfg(cfg.model) ends up executing openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg), which yields the following network structure:

BaseSeg(
  (encoder): PointNextEncoder(
    (encoder): Sequential(
      (0): Sequential(
        (0): SetAbstraction(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(4, 32, kernel_size=(1,), stride=(1,))
            )
          )
        )
      )
      (1): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(35, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (2): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(67, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (3): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(131, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (4): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(259, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
    )
  )
  (decoder): PointNextDecoder(
    (decoder): Sequential(
      (0): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (1): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(192, 64, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (2): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(384, 128, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (3): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(768, 256, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
    )
  )
  (head): SegHead(
    (head): Sequential(
      (0): Sequential(
        (0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
        (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Dropout(p=0.5, inplace=False)
      (2): Sequential(
        (0): Conv1d(32, 13, kernel_size=(1,), stride=(1,))
      )
    )
  )
)

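2. Todo

How is the network structure above configured and constructed automatically? That remains to be learned by reading more of the source code.

Thanks to the authors of the paper and the code for open-sourcing their research!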

