Title: PointNeXt Source Code Reading (I): Registry Mechanism and Argument Parsing
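Table of Contents
- Preface
- I. Registry Mechanism
- 1. The Registry Class
- 2. Registering Classes
- 3. Using the Registry
- II. Argument Parsing
- 1. Command-Line Parsing
- 2. Loading and Updating Arguments
- 3. The Resulting Arguments
- III. Summary
- 1. Result
- 2. Todo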
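Preface
I have read part of the PointNeXt source code and am writing it down here for future reference.
This post has two parts, the registry mechanism and argument parsing. The key part to understand is the registry mechanism.
All annotations and debugging output below are based on the following test session.
CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py --cfg cfgs/s3dis/pointnext-s.yaml mode=train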
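I. Registry Mechanism
The registry mechanism is how PointNeXt registers modules/classes so that a string can be mapped to a module/class. In other words, the program can read a parameter string from a config file and use it directly to obtain an instance of the corresponding module/class. The PointNeXt authors based this part on the registry mechanism in mmcv.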
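1. The Registry Class
The registry mechanism itself is implemented by the class Registry. Its key methods (together with the module-level function build_from_cfg, which Registry uses as its default build function) are:
Method | Description |
---|---|
__init__() | Initializes the registry, including the dictionary of registered modules self._module_dict = dict() |
get(self, key) | Maps a string to a class: the string key is looked up among the classes registered in self._module_dict, returning self._module_dict[real_key] |
register_module(self, name=None, force=False, module=None) | Registers a module/class; it can be called directly or used as a decorator on a module/class |
_register(cls) | The inner wrapper function returned by the decorator register_module, used in the decorator case; its argument cls is the class being decorated. It first calls _register_module(self, module_class, module_name=None, force=False) to register the class (self._module_dict[name] = module_class) and then returns the class definition cls unchanged |
build_from_cfg(cfg, registry, default_args=None) | Module-level function (not a method of Registry) that builds a module/class instance from a config dict, i.e. creates an instance from a name string |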
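To make the table concrete, here is a minimal, self-contained sketch of the same idea. It is not the openpoints code: MiniRegistry, ToyEncoder and ToyHead are hypothetical names, and scopes, parent/children registries and build_func are left out. It only shows the two ways register_module can be used (as a decorator and as a plain call) and how get maps a string back to the class.

```python
import inspect
from typing import Dict, Optional, Type


class MiniRegistry:
    """Hypothetical, stripped-down sketch of the Registry idea: name string -> class."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._module_dict: Dict[str, Type] = {}

    def get(self, key: str) -> Optional[Type]:
        # String -> registered class (None if the name is unknown).
        return self._module_dict.get(key)

    def register_module(self, name: Optional[str] = None, force: bool = False,
                        module: Optional[Type] = None):
        def _register(cls: Type) -> Type:
            if not inspect.isclass(cls):
                raise TypeError(f'module must be a class, but got {type(cls)}')
            key = name if name is not None else cls.__name__
            if not force and key in self._module_dict:
                raise KeyError(f'{key} is already registered in {self.name}')
            self._module_dict[key] = cls
            return cls  # the class itself is returned unchanged

        if module is not None:   # plain call: reg.register_module(module=SomeClass)
            return _register(module)
        return _register         # decorator: @reg.register_module()


MODELS = MiniRegistry('models')


@MODELS.register_module()      # decorator form, key = class name
class ToyEncoder:
    def __init__(self, width: int = 32) -> None:
        self.width = width


class ToyHead:
    pass


MODELS.register_module(module=ToyHead)         # plain-call form
MODELS.register_module(name='Alias')(ToyHead)  # decorator form with a custom key
```

As in the original, registering the same key twice raises KeyError unless force=True.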
openpoints/utils/registry.py defines the Registry class; my annotated version follows.
# Acknowledgement: built upon mmcv
import inspect
import warnings
from functools import partial
import copy


class Registry:
    """A registry to map strings to classes.
    Registered object could be built from registry.
    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(NAME='ResNet'))
    Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
    advanced useage.
    Args:
        name (str): Registry name.
        build_func(func, optional): Build function to construct instance from
            Registry, func:`build_from_cfg` is used if neither ``parent`` or
            ``build_func`` is specified. If ``parent`` is specified and
            ``build_func`` is not given, ``build_func`` will be inherited
            from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the name of
            the package where class is defined, e.g. mmdet, mmcls, mmseg.
            Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        self._scope = self.infer_scope() if scope is None else scope
        # self._scope = 'openpoints'

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        if build_func is None:
            if parent is not None:
                self.build_func = parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent
        else:
            self.parent = None

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.
        The name of the package where registry is defined will be returned.
        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.
        Returns:
            scope (str): The inferred scope name.
        """
        # inspect.stack() trace where this function is called, the index-2
        # indicates the frame where `infer_scope()` is called
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        # filename = 'openpoints.models.build'
        split_filename = filename.split('.')  # ['openpoints', 'models', 'build']
        return split_filename[0]  # 'openpoints'

    @staticmethod  # declares a static method (takes neither self nor cls)
    def split_scope_key(key):
        """Split scope and key.
        The first scope will be split from key.
        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'
        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        split_index = key.find('.')  # -1 if key contains no '.', otherwise the index of the first '.'
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.
        Args:
            key (str): The class name in string format.
        Returns:
            class: The corresponding class.
        """
        # Maps a string to a class: the string key is looked up among the
        # classes registered in self._module_dict, i.e. self._module_dict[real_key]
        scope, real_key = self.split_scope_key(key)  # key = 'BaseSeg'; scope = None; real_key = 'BaseSeg'
        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                return self._module_dict[real_key]
        else:
            # get from self._children
            if scope in self._children:
                return self._children[scope].get(real_key)
            else:
                # goto root
                parent = self.parent
                while parent.parent is not None:
                    parent = parent.parent
                return parent.get(key)

    def build(self, *args, **kwargs):
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.
        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.
        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(NAME='mmdet.ResNet'))
        """
        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')
        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):  # decorator
        """Register a module.
        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.
        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)
        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time
        # (`misc` is not imported in this listing; thanks to short-circuit evaluation
        #  it is only reached when name is neither None nor a str)
        if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f' of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        # (ordinary call of register_module, not the decorator case)
        if module is not None:
            self._register_module(module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        # _register is the decorator's wrapper function;
        # in the decorator case, cls is the class being decorated
        def _register(cls):
            self._register_module(module_class=cls, module_name=name, force=force)
            return cls

        return _register  # the decorator returns this wrapper function


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.
    Args:
        cfg (edict): Config dict. It should at least contain the key "NAME".
        registry (:obj:`Registry`): The registry to search the type from.
    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'NAME' not in cfg:
        if default_args is None or 'NAME' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "NAME", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # if default_args is not None:
    #     cfg = config.merge_new_config(cfg, default_args)

    obj_type = cfg.get('NAME')  # 'BaseSeg'
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)  # <class 'openpoints.models.segmentation.base_seg.BaseSeg'>
        # look up the class/module in self._module_dict by its name string,
        # i.e. the string-to-class mapping
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        obj_cfg = copy.deepcopy(cfg)
        if default_args is not None:
            obj_cfg.update(default_args)
        obj_cfg.pop('NAME')  # drop "NAME"; obj_cfg keeps every other entry
        # 'NAME' has already been mapped to the class obj_cls.
        # Unpacking the entries executes
        # openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg).
        # The @MODELS.register_module() decorator on BaseSeg already ran once, when the
        # class was defined at import time; instantiating it here registers nothing again.
        return obj_cls(**obj_cfg)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
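2. Registering Classes
First, openpoints/models/build.py declares and defines the global Registry object MODELS, referred to as the registry.
Because module-level code runs when a module is imported, MODELS is created at program start-up, before the `if __name__ == "__main__":` block and main() run.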
from openpoints.utils import registry

MODELS = registry.Registry('models')
# Create a registry.Registry object (MODELS, "the registry") as a global variable.
# It is created when this module is imported, i.e. before __main__/main() runs,
# so the registry MODELS exists from the very start of the program.


def build_model_from_cfg(cfg, **kwargs):
    """
    Build a model, defined by `NAME`.
    Args:
        cfg (eDICT):
    Returns:
        Model: a constructed model specified by NAME.
    """
    return MODELS.build(cfg, **kwargs)
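The import mechanism also drives registration: once MODELS exists, every class under openpoints/models that is decorated with @MODELS.register_module() is registered in MODELS as soon as its module is imported and the class statement executes.
For example, the BaseSeg class below is registered into MODELS._module_dict in this way, before main() runs.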
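This timing can be checked with plain Python. The sketch below uses hypothetical names (register_module, BaseSegToy, call_log) and a module-level dict instead of a Registry; it records every time the decorator body runs. The body runs exactly once, when the class statement executes, and creating instances later does not register anything again.

```python
from typing import Dict, List, Type

registered: Dict[str, Type] = {}
call_log: List[str] = []  # one entry per execution of the decorator body


def register_module():
    # Same shape as Registry.register_module(): returns the wrapper _register.
    def _register(cls: Type) -> Type:
        call_log.append(cls.__name__)
        registered[cls.__name__] = cls
        return cls
    return _register


@register_module()  # executed right here, when the class statement runs
class BaseSegToy:
    def __init__(self, **kwargs) -> None:
        self.kwargs = kwargs


# The decorator line above is equivalent to:
#     BaseSegToy = register_module()(BaseSegToy)

a = BaseSegToy(num_classes=13)  # instantiation: no registration happens
b = BaseSegToy(num_classes=13)
```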
"""
Author: PointNeXt
"""
import copy
from typing import List
import torch
import torch.nn as nn
import logging
from ...utils import get_missing_parameters_message, get_unexpected_parameters_message
from ..build import MODELS, build_model_from_cfg
from ..layers import create_linearblock, create_convblock1d


# BaseSeg is decorated with MODELS.register_module().
# The decorator line is equivalent to BaseSeg = MODELS.register_module()(BaseSeg):
# it runs once, when this class statement executes at import time, and registers the
# class itself (not an instance); creating BaseSeg(...) objects later does not call
# register_module again.
# This happens after the registry MODELS is created but before __main__/main() runs,
# so by the time main() is called the registry can already map the string to the class.
@MODELS.register_module()
class BaseSeg(nn.Module):
    def __init__(self,
                 encoder_args=None,
                 decoder_args=None,
                 cls_args=None,
                 **kwargs):
        super().__init__()
While debugging, inspecting MODELS._module_dict shows that many classes have already been registered:
MODELS._module_dict = {
'PointNetEncoder': <class 'openpoints.models.backbone.pointnet.PointNetEncoder'>,
'PointPatchEmbed': <class 'openpoints.models.layers.group_embed.PointPatchEmbed'>,
'P3Embed': <class 'openpoints.models.layers.group_embed.P3Embed'>,
'PointNet2Encoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Encoder'>,
'PointNet2Decoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2Decoder'>,
'PointNet2PartDecoder': <class 'openpoints.models.backbone.pointnetv2.PointNet2PartDecoder'>,
'PointNextEncoder': <class 'openpoints.models.backbone.pointnext.PointNextEncoder'>,
'PointNextDecoder': <class 'openpoints.models.backbone.pointnext.PointNextDecoder'>,
'PointNextPartDecoder': <class 'openpoints.models.backbone.pointnext.PointNextPartDecoder'>,
'DGCNN': <class 'openpoints.models.backbone.dgcnn.DGCNN'>,
'DeepGCN': <class 'openpoints.models.backbone.deepgcn.DeepGCN'>,
'PointMLPEncoder': <class 'openpoints.models.backbone.pointmlp.PointMLPEncoder'>,
'PointMLP': <class 'openpoints.models.backbone.pointmlp.PointMLP'>,
'PointViT': <class 'openpoints.models.backbone.pointvit.PointViT'>,
'PointViTDecoder': <class 'openpoints.models.backbone.pointvit.PointViTDecoder'>,
'PointViTPartDecoder': <class 'openpoints.models.backbone.pointvit.PointViTPartDecoder'>,
'InvPointViT': <class 'openpoints.models.backbone.pointvit_inv.InvPointViT'>,
'InvPointViTDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTDecoder'>,
'InvPointViTPartDecoder': <class 'openpoints.models.backbone.pointvit_inv.InvPointViTPartDecoder'>,
'CurveNet': <class 'openpoints.models.backbone.curvenet.CurveNet'>,
'MVFC': <class 'openpoints.models.backbone.simpleview.MVFC'>,
'MVModel': <class 'openpoints.models.backbone.simpleview.MVModel'>,
'BaseSeg': <class 'openpoints.models.segmentation.base_seg.BaseSeg'>,
'BasePartSeg': <class 'openpoints.models.segmentation.base_seg.BasePartSeg'>,
'VariableSeg': <class 'openpoints.models.segmentation.base_seg.VariableSeg'>,
'SegHead': <class 'openpoints.models.segmentation.base_seg.SegHead'>,
'VariableSegHead': <class 'openpoints.models.segmentation.base_seg.VariableSegHead'>,
'MultiSegHead': <class 'openpoints.models.segmentation.base_seg.MultiSegHead'>,
'BaseCls': <class 'openpoints.models.classification.cls_base.BaseCls'>,
'DistillCls': <class 'openpoints.models.classification.cls_base.DistillCls'>,
'ClsHead': <class 'openpoints.models.classification.cls_base.ClsHead'>,
'MaskedTransformerDecoder': <class 'openpoints.models.reconstruction.base_recontruct.MaskedTransformerDecoder'>,
'FoldingNet': <class 'openpoints.models.reconstruction.base_recontruct.FoldingNet'>,
'NodeShuffle': <class 'openpoints.models.reconstruction.base_recontruct.NodeShuffle'>,
'MaskedPointViT': <class 'openpoints.models.reconstruction.maskedpointvit.MaskedPointViT'>,
'MaskedPoint': <class 'openpoints.models.reconstruction.maskedpoint.MaskedPoint'>,
'MaskedPointGroup': <class 'openpoints.models.reconstruction.maskedpointgroup.MaskedPointGroup'>
}
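3. Using the Registry
With the registry MODELS in place and the classes registered in it, its string-to-class mapping can be used to create class instances conveniently from a .yaml configuration file.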
[Figure: rough call sequence from the .yaml configuration to the model instance]
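The config dict cfg read from the .yaml files contains a NAME entry, and registry.get(*) returns the registered class corresponding to the NAME string. Once the class has been found, the NAME entry has served its purpose; the remaining entries are used to configure and construct the concrete deep neural network modules/classes in PointNeXt automatically (not covered in this post).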
For the detailed comments, see build_from_cfg (a module-level function in openpoints/utils/registry.py, not a method of Registry), repeated below:
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.
    Args:
        cfg (edict): Config dict. It should at least contain the key "NAME".
        registry (:obj:`Registry`): The registry to search the type from.
    Returns:
        object: The constructed object.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'NAME' not in cfg:
        if default_args is None or 'NAME' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "NAME", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # if default_args is not None:
    #     cfg = config.merge_new_config(cfg, default_args)

    obj_type = cfg.get('NAME')  # 'BaseSeg'
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)  # <class 'openpoints.models.segmentation.base_seg.BaseSeg'>
        # look up the class/module in self._module_dict by its name string,
        # i.e. the string-to-class mapping
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')
    try:
        obj_cfg = copy.deepcopy(cfg)
        if default_args is not None:
            obj_cfg.update(default_args)
        obj_cfg.pop('NAME')  # drop "NAME"; obj_cfg keeps every other entry
        # 'NAME' has already been mapped to the class obj_cls.
        # Unpacking the entries executes
        # openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg).
        # The @MODELS.register_module() decorator on BaseSeg already ran once, when the
        # class was defined at import time; instantiating it here registers nothing again.
        return obj_cls(**obj_cfg)
    except Exception as e:
        # Normal TypeError does not print class name.
        raise type(e)(f'{obj_cls.__name__}: {e}')
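II. Argument Parsing
The registry needs string parameters to build class instances, and those parameters come from parsing the configuration in the .yaml files into the program.
1. Command-Line Parsing
The main program first reads the relevant .yaml files and merges them into the dict cfg. Annotated: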
if __name__ == "__main__":
    parser = argparse.ArgumentParser('Scene segmentation training/testing')  # create the parser
    # add arguments
    parser.add_argument('--cfg', type=str, required=True, help='config file')
    parser.add_argument('--profile', action='store_true', default=False, help='set to True to profile speed')
    args, opts = parser.parse_known_args()
    # CUDA_VISIBLE_DEVICES=0,1 python examples/segmentation/main.py --cfg cfgs/s3dis/pointnext-s.yaml mode=train
    # CUDA_VISIBLE_DEVICES=0,1 is an environment variable and is not parsed by the parser
    # args = Namespace(cfg='cfgs/s3dis/pointnext-s.yaml', profile=False)
    # opts = ['mode=train']
    cfg = EasyConfig()
    cfg.load(args.cfg, recursive=True)  # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
    cfg.update(opts)  # overwrite the default arguments in yml; here mode=train is merged into cfg
    if cfg.seed is None:
        cfg.seed = np.random.randint(1, 10000)

    # init distributed env first, since logger depends on the dist info.
    cfg.rank, cfg.world_size, cfg.distributed, cfg.mp = dist_utils.get_dist_info(cfg)
    cfg.sync_bn = cfg.world_size > 1  # the debug run uses a single GPU; a normal run can use several in parallel

    # init log dir
    cfg.task_name = args.cfg.split('.')[-2].split('/')[-2]  # task/dataset name, e.g. s3dis, modelnet40_cls
    # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
    # args.cfg.split('.')[-2] = 'cfgs/s3dis/pointnext-s'
    # args.cfg.split('.')[-2].split('/')[-2] = 's3dis'
    cfg.cfg_basename = args.cfg.split('.')[-2].split('/')[-1]  # cfg_basename, e.g. pointnext-xl
    # args.cfg.split('.')[-2].split('/')[-1] = 'pointnext-s'
    tags = [
        cfg.task_name,  # task name (the folder of name under ./cfgs
        cfg.mode,
        cfg.cfg_basename,  # cfg file name
        f'ngpus{cfg.world_size}',
    ]
    # tags = ['s3dis', 'train', 'pointnext-s', 'ngpus1']
    opt_list = []  # for checking experiment configs from logging file
    for i, opt in enumerate(opts):
        if 'rank' not in opt and 'dir' not in opt and 'root' not in opt and 'pretrain' not in opt and 'path' not in opt and 'wandb' not in opt and '/' not in opt:
            opt_list.append(opt)
    cfg.root_dir = os.path.join(cfg.root_dir, cfg.task_name)
    cfg.opts = '-'.join(opt_list)  # join with '-' as the separator

    cfg.is_training = cfg.mode not in ['test', 'testing', 'val', 'eval', 'evaluation']
    if cfg.mode in ['resume', 'val', 'test']:
        resume_exp_directory(cfg, pretrained_path=cfg.pretrained_path)  # requires pretrained_path=XXX on the command line
        cfg.wandb.tags = [cfg.mode]
    else:
        generate_exp_directory(cfg, tags, additional_id=os.environ.get('MASTER_PORT', None))
        cfg.wandb.tags = tags
    os.environ["JOB_LOG_DIR"] = cfg.log_dir
    cfg_path = os.path.join(cfg.run_dir, "cfg.yaml")
    # cfg_path = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC/cfg.yaml'
    with open(cfg_path, 'w') as f:
        yaml.dump(cfg, f, indent=2)  # write cfg to the file f
        os.system('cp %s %s' % (args.cfg, cfg.run_dir))
        # args.cfg = 'cfgs/s3dis/pointnext-s.yaml'
        # cfg.run_dir = 'log/s3dis/s3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'
    cfg.cfg_path = cfg_path

    # wandb config
    cfg.wandb.name = cfg.run_name
    # cfg.run_name = 's3dis-train-pointnext-s-ngpus1-20240730-092203-hQtDgCBNbQaYpAMwLVn9TC'

    # multi processing.
    if cfg.mp:
        port = find_free_port()
        cfg.dist_url = f"tcp://localhost:{port}"
        print('using mp spawn for distributed training')
        mp.spawn(main, nprocs=cfg.world_size, args=(cfg,))
    else:
        main(0, cfg)
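The split between args and opts comes from argparse.ArgumentParser.parse_known_args, which returns the recognized options as a Namespace and everything it does not recognize as a list. The sketch below feeds the test session's command line to the same two add_argument calls (as an explicit list instead of sys.argv) and repeats the string slicing used for task_name and cfg_basename. The slicing assumes the config path contains exactly one '.' before the extension; a hypothetical name such as pointnext-s.v2.yaml would break it.

```python
import argparse

parser = argparse.ArgumentParser('Scene segmentation training/testing')
parser.add_argument('--cfg', type=str, required=True, help='config file')
parser.add_argument('--profile', action='store_true', default=False)

# Same as: python examples/segmentation/main.py --cfg cfgs/s3dis/pointnext-s.yaml mode=train
# (CUDA_VISIBLE_DEVICES never reaches argv: the shell turns it into an environment variable.)
argv = ['--cfg', 'cfgs/s3dis/pointnext-s.yaml', 'mode=train']
args, opts = parser.parse_known_args(argv)  # unknown tokens are returned, not rejected

stem = args.cfg.split('.')[-2]       # 'cfgs/s3dis/pointnext-s'
task_name = stem.split('/')[-2]      # 's3dis'
cfg_basename = stem.split('/')[-1]   # 'pointnext-s'
```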
2. Loading and Updating Arguments
Reading and updating the config entries is implemented in the class EasyConfig; partially annotated below.
import os
import json
import hashlib
from ast import literal_eval
from typing import Any, Dict, List, Tuple, Union

import yaml
from multimethod import multimethod


class EasyConfig(dict):
    def __getattr__(self, key: str) -> Any:
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value

    def __delattr__(self, key: str) -> None:
        del self[key]

    def load(self, fpath: str, *, recursive: bool = False) -> None:
        """load cfg from yaml
        Args:
            fpath (str): path to the yaml file
            recursive (bool, optional): recursily load its parent defaul yaml files. Defaults to False.
        """
        if not os.path.exists(fpath):
            raise FileNotFoundError(fpath)
        fpaths = [fpath]  # ['cfgs/s3dis/pointnext-s.yaml']
        if recursive:  # True
            extension = os.path.splitext(fpath)[1]  # '.yaml'
            while os.path.dirname(fpath) != fpath:  # loop until dirname() no longer changes the path
                fpath = os.path.dirname(fpath)  # strip one path level per iteration
                fpaths.append(os.path.join(fpath, 'default' + extension))
            # fpaths = ['cfgs/s3dis/pointnext-s.yaml',
            #           'cfgs/s3dis/default.yaml',
            #           'cfgs/default.yaml',
            #           'default.yaml']
        for fpath in reversed(fpaths):  # reverse order: the most generic default.yaml first
            if os.path.exists(fpath):
                with open(fpath) as f:
                    # merge the entries of every existing .yaml file in fpaths into this one dict;
                    # files loaded later override entries from earlier ones
                    self.update(yaml.safe_load(f))

    def reload(self, fpath: str, *, recursive: bool = False) -> None:
        self.clear()
        self.load(fpath, recursive=recursive)

    # mutimethod makes python supports function overloading
    @multimethod
    def update(self, other: Dict) -> None:  # .yaml items become key: value pairs of this dict
        for key, value in other.items():
            if isinstance(value, dict):
                if key not in self or not isinstance(self[key], EasyConfig):  # nested entry
                    self[key] = EasyConfig()
                # recursively update
                self[key].update(value)
            else:
                self[key] = value

    @multimethod
    def update(self, opts: Union[List, Tuple]) -> None:
        index = 0
        while index < len(opts):
            opt = opts[index]
            if opt.startswith('--'):
                opt = opt[2:]
            if '=' in opt:
                key, value = opt.split('=', 1)
                index += 1
            else:
                key, value = opt, opts[index + 1]
                index += 2
            current = self
            subkeys = key.split('.')
            try:
                value = literal_eval(value)
            except:
                pass
            for subkey in subkeys[:-1]:
                current = current.setdefault(subkey, EasyConfig())
            current[subkeys[-1]] = value

    def dict(self) -> Dict[str, Any]:
        configs = dict()
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                value = value.dict()
            configs[key] = value
        return configs

    def hash(self) -> str:
        buffer = json.dumps(self.dict(), sort_keys=True)
        return hashlib.sha256(buffer.encode()).hexdigest()

    def __str__(self) -> str:
        texts = []
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                seperator = '\n'
            else:
                seperator = ' '
            text = key + ':' + seperator + str(value)
            lines = text.split('\n')
            for k, line in enumerate(lines[1:]):
                lines[k + 1] = (' ' * 2) + line
            texts.extend(lines)
        return '\n'.join(texts)
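The update(opts) overload is what turns command-line opts such as mode=train into config entries. Below is a dependency-free sketch of that path only: AttrDict is a hypothetical stand-in for EasyConfig (attribute access via __getattr__/__setattr__), and update_from_opts is a hypothetical method that mirrors the loop in update(opts) without multimethod. It catches (ValueError, SyntaxError) where the original uses a bare except. Dotted keys create nested levels, and ast.literal_eval turns numeric/boolean strings into Python values while plain words such as train stay strings.

```python
from ast import literal_eval
from typing import Any, List


class AttrDict(dict):
    """Hypothetical minimal stand-in for EasyConfig: a dict with attribute access."""

    def __getattr__(self, key: str) -> Any:
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value

    def update_from_opts(self, opts: List[str]) -> None:
        index = 0
        while index < len(opts):
            opt = opts[index]
            if opt.startswith('--'):
                opt = opt[2:]
            if '=' in opt:                   # 'key=value' in one token
                key, value = opt.split('=', 1)
                index += 1
            else:                            # '--key value' in two tokens
                key, value = opt, opts[index + 1]
                index += 2
            try:
                value = literal_eval(value)  # '0.01' -> 0.01, 'False' -> False
            except (ValueError, SyntaxError):
                pass                         # plain strings such as 'train' stay str
            current = self
            subkeys = key.split('.')
            for subkey in subkeys[:-1]:      # walk (or create) the nested levels
                current = current.setdefault(subkey, AttrDict())
            current[subkeys[-1]] = value


cfg = AttrDict(mode='val', model=AttrDict(encoder_args=AttrDict(width=32)))
cfg.update_from_opts(['mode=train', 'model.encoder_args.width=64',
                      '--lr', '0.01', 'wandb.use_wandb=False'])
```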
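3. The Resulting Arguments
All entries of the .yaml files listed in fpaths = ['cfgs/s3dis/pointnext-s.yaml', 'cfgs/s3dis/default.yaml', 'cfgs/default.yaml', 'default.yaml'] are parsed and written into the dict cfg (only files that exist are read; default.yaml does not exist here).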
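The file list and the override order can be reproduced without touching the disk. In the sketch below, candidate_cfg_paths repeats the dirname loop from EasyConfig.load (using posixpath so the result does not depend on the OS), and deep_update follows the same rule as update(dict): nested dicts are merged, other values are overwritten. The file contents are hypothetical in-memory dicts standing in for yaml.safe_load results. Because the list is iterated in reverse, the most generic default.yaml is applied first and the file passed via --cfg last, so its values win.

```python
import posixpath
from typing import Any, Dict, List


def candidate_cfg_paths(fpath: str) -> List[str]:
    # Same loop as EasyConfig.load(recursive=True): one default.yaml per parent level.
    fpaths = [fpath]
    extension = posixpath.splitext(fpath)[1]
    while posixpath.dirname(fpath) != fpath:
        fpath = posixpath.dirname(fpath)
        fpaths.append(posixpath.join(fpath, 'default' + extension))
    return fpaths


def deep_update(dst: Dict[str, Any], src: Dict[str, Any]) -> Dict[str, Any]:
    # Same rule as EasyConfig.update(dict): dicts merge recursively, other values overwrite.
    for key, value in src.items():
        if isinstance(value, dict):
            if not isinstance(dst.get(key), dict):
                dst[key] = {}
            deep_update(dst[key], value)
        else:
            dst[key] = value
    return dst


fpaths = candidate_cfg_paths('cfgs/s3dis/pointnext-s.yaml')

# Hypothetical file contents (stand-ins for yaml.safe_load results).
fake_files = {
    'cfgs/default.yaml': {'epochs': 100, 'model': {'in_channels': 3}},
    'cfgs/s3dis/default.yaml': {'num_classes': 13, 'model': {'in_channels': 4}},
    'cfgs/s3dis/pointnext-s.yaml': {'model': {'NAME': 'BaseSeg'}},
}

cfg: Dict[str, Any] = {}
for path in reversed(fpaths):   # most generic file first, the --cfg file last
    if path in fake_files:      # stands in for os.path.exists(path)
        deep_update(cfg, fake_files[path])
```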
After parsing, the dict cfg looks as follows; its cfg.model part is used to configure and build the network model (implemented as classes) automatically.
dist_url: tcp://localhost:8888
dist_backend: nccl
multiprocessing_distributed: False
ngpus_per_node: 1
world_size: 1
launcher: mp
local_rank: 0
use_gpu: True
seed: 3392
epoch: 0
epochs: 100
ignore_index: None
val_fn: validate
deterministic: False
sync_bn: False
criterion_args:
  NAME: CrossEntropy
  label_smoothing: 0.2
use_mask: False
grad_norm_clip: 10
layer_decay: 0
step_per_update: 1
start_epoch: 1
sched_on_epoch: True
wandb:
  use_wandb: False
  project: PointNeXt-S3DIS
  tags: ['s3dis', 'train', 'pointnext-s', 'ngpus1']
  name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
use_amp: False
use_voting: False
val_freq: 1
resume: False
test: False
finetune: False
mode: train
logname: None
load_path: None
print_freq: 50
save_freq: -1
root_dir: log/s3dis
pretrained_path: None
datatransforms:
  train: ['ChromaticAutoContrast', 'PointsToTensor', 'PointCloudScaling', 'PointCloudXYZAlign', 'PointCloudJitter', 'ChromaticDropGPU', 'ChromaticNormalize']
  val: ['PointsToTensor', 'PointCloudXYZAlign', 'ChromaticNormalize']
  vote: ['ChromaticDropGPU']
  kwargs:
    color_drop: 0.2
    gravity_dim: 2
    scale: [0.9, 1.1]
    angle: [0, 0, 1]
    jitter_sigma: 0.005
    jitter_clip: 0.02
feature_keys: x,heights
dataset:
  common:
    NAME: S3DIS
    data_root: data/S3DIS/s3disfull
    test_area: 5
    voxel_size: 0.04
  train:
    split: train
    voxel_max: 24000
    loop: 30
    presample: False
  val:
    split: val
    voxel_max: None
    presample: True
  test:
    split: test
    voxel_max: None
    presample: False
num_classes: 13
batch_size: 32
val_batch_size: 1
dataloader:
  num_workers: 6
cls_weighed_loss: False
optimizer:
  NAME: adamw
  weight_decay: 0.0001
sched: cosine
warmup_epochs: 0
min_lr: 1e-05
lr: 0.01
log_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
model:
  NAME: BaseSeg
  encoder_args:
    NAME: PointNextEncoder
    blocks: [1, 1, 1, 1, 1]
    strides: [1, 4, 4, 4, 4]
    sa_layers: 2
    sa_use_res: True
    width: 32
    in_channels: 4
    expansion: 4
    radius: 0.1
    nsample: 32
    aggr_args:
      feature_type: dp_fj
      reduction: max
    group_args:
      NAME: ballquery
      normalize_dp: True
    conv_args:
      order: conv-norm-act
    act_args:
      act: relu
    norm_args:
      norm: bn
  decoder_args:
    NAME: PointNextDecoder
  cls_args:
    NAME: SegHead
    num_classes: 13
    in_channels: None
    norm_args:
      norm: bn
  in_channels: 4
rank: 0
distributed: False
mp: False
task_name: s3dis
cfg_basename: pointnext-s
opts: mode=train
is_training: True
run_name: s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
run_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
exp_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu
ckpt_dir: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/checkpoint
log_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu.log
cfg_path: log/s3dis/s3dis-train-pointnext-s-ngpus1-20240802-170230-MzNANkH27yLSBYpmpmF5tu/cfg.yaml
III. Summary
1. Result
The part of main() in examples/segmentation/main.py that builds the deep network model (implemented as classes):
if cfg.model.get('in_channels', None) is None:
    cfg.model.in_channels = cfg.model.encoder_args.in_channels  # 4
model = build_model_from_cfg(cfg.model).to(cfg.rank)
model_size = cal_model_parm_nums(model)
logging.info(model)
logging.info('Number of params: %.4f M' % (model_size / 1e6))
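The call build_model_from_cfg(cfg.model) goes through MODELS.build(cfg.model), i.e. build_from_cfg(cfg.model, registry=MODELS), and finally executes openpoints.models.segmentation.base_seg.BaseSeg(**obj_cfg), which yields the following network structure: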
BaseSeg(
  (encoder): PointNextEncoder(
    (encoder): Sequential(
      (0): Sequential(
        (0): SetAbstraction(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(4, 32, kernel_size=(1,), stride=(1,))
            )
          )
        )
      )
      (1): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(35, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (2): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(64, 128, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(67, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (3): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(131, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
      (4): Sequential(
        (0): SetAbstraction(
          (skipconv): Sequential(
            (0): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
          )
          (act): ReLU(inplace=True)
          (convs): Sequential(
            (0): Sequential(
              (0): Conv2d(259, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (grouper): QueryAndGroup()
        )
      )
    )
  )
  (decoder): PointNextDecoder(
    (decoder): Sequential(
      (0): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(96, 32, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (1): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(192, 64, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (2): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(384, 128, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
      (3): Sequential(
        (0): FeaturePropogation(
          (convs): Sequential(
            (0): Sequential(
              (0): Conv1d(768, 256, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
            (1): Sequential(
              (0): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
              (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): ReLU(inplace=True)
            )
          )
        )
      )
    )
  )
  (head): SegHead(
    (head): Sequential(
      (0): Sequential(
        (0): Conv1d(32, 32, kernel_size=(1,), stride=(1,), bias=False)
        (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Dropout(p=0.5, inplace=False)
      (2): Sequential(
        (0): Conv1d(32, 13, kernel_size=(1,), stride=(1,))
      )
    )
  )
)
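2. Todo
How is the network structure above configured and constructed automatically? That is left for further reading of the source code.
Thanks to the authors of the paper and the code for open-sourcing their work!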