文章 12
评论 4
浏览 20910
目标检测框架——mmdetection的使用总结

目标检测框架——mmdetection的使用总结

简介

  mdetection 是商汤和香港中文大学联合开源的一个基于 PyTorch 的目标检测工具包,属于香港中文大学多媒体实验室 open-mmlab 项目的一部分,项目地址为 https://github.com/open-mmlab/mmdetection

  mmdetection 使用模块化设计,将一般的目标检测算法分成了几个不同的模块,使用时只需在配置文件中声明各个模块使用的组件名称和相应的参数,就可以像搭积木一样搭建出一个完整的目标检测模型。mmdetection 有着很好的灵活性和扩展性,如果希望在其中添加新的目标检测算法,只需按照 mmdetection 的接口设计增加关键模块的代码并编写相应的配置文件即可,这样实现了不同算法间相同代码的复用,大大减少了实现一个目标检测算法所需编写的代码量。除此之外,mmdetection 大多数运算都在 GPU 上进行,这使得它有着不俗的性能。目前,mmdetection 已经集成了很多经典的目标检测算法如 Fast RCNN,Faster RCNN,Mask RCNN,retinanet,FCOS 等,未来将会添加更多的算法。

  这篇博客主要从我个人使用的感受出发,简要介绍下我对 mmdetection 的代码设计和使用的总结和思考。接下来,我将分模块对 mmdetection 的代码进行一个宏观上的分析。

mmdetection 的基本结构

  为了便于描述,我们基于一个完整的目标检测的模型的构建和训练过程,简要分析一下 mmdetection 对应部分的代码功能。

模型构建

  在进行训练之前,我们需要定义模型的基本结构和运作方式,mmdetection 中定义了一个 Registry 类,这个类类似于 python 中的字典,用来存放定义好的类名和类的引用的键值对。Registry 类提供了一个 register_module 方法,代码如下

def _register_module(self, module_class):
    """Register a module.

    Args:
        module (:obj:`nn.Module`): Module to be registered.
    """
    if not inspect.isclass(module_class):
        raise TypeError('module must be a class, but got {}'.format(
            type(module_class)))
    module_name = module_class.__name__
    if module_name in self._module_dict:
        raise KeyError('{} is already registered in {}'.format(
            module_name, self.name))
    self._module_dict[module_name] = module_class
def register_module(self, cls):
    self._register_module(cls)
    return cls

观察代码不难发现,这个方法本质上是一个类的包装器,通过调用 register_module 方法可以将这个类名的字符串和类的引用作为键值对存入 Registry 对象中,这样我们只需要在配置文件中指定类名,python 就可以根据类名在集合中寻找到类的引用。实际代码实现中,可以使用 python 的装饰器语法将类添加进指定的集合中,例如

@DETECTORS.register_module
class FasterRCNN(TwoStageDetector)

上面的代码表示将 FasterRCNN 类添加进 DETECTORS 对象中,由于装饰器在模块加载时才会调用,因此我们需要在对应模块的__init__.py 文件中将类导入。

  事实上,mmdetection 中的组件绝大多数都以类的形式定义。官方代码中预先定义了几个 Registry 类的对象用来分别存放这些类,列举如下:

  • BACKBONES,对应目标检测模型的主干网络,如常用的 ResNet,ResNext,HRNet 等。
  • NECKS,对主干网络产生的特征图做一些特定的处理,最常见的就是 FPN 多尺度模型。
  • ROI_EXTRACTORS,这个集合中的类一般负责两阶段模型中 ROI pooling(align)阶段提取特征。
  • HEADS,目标检测模型的头部,这些类事实上包含了目标检测过程的主要算法逻辑,包括 bbox(proposal)的产生、回归的 target 的计算、损失函数的计算等。
  • LOSSES,损失函数类,常见的如 smooth L1 loss,focal loss 等。
  • DETECTORS,这个集合中的类可以看作是前面所介绍的组件搭建而成的一个整体的目标检测类,程序运行过程中也是直接通过加载这个类来执行算法的。
  • DATASETS,由于不同数据集的数据标注类型不同,因此当我们需要训练某个新的数据集时(官方代码提供了 MS COCO 和 PASCAL VOC 两个数据集的工具类),需要为这个数据集定义一个单独的类,在这个类中编写代码,并将该类添加进 DATASETS 集合中。
  • PIPELINES,数据增强类的集合,下文会详细介绍。

下面以 faster rcnn 为例介绍代码中的一些细节。

  官方代码提供了一个 TwoStageDetector 类(包含在 DETECTORS 集合中) ,这个类事实上定义了一个完整的一般化的两阶段目标检测算法,里面提供了一般的两阶段算法需要使用的方法,因此这个类也是 mmdetection 中所有两阶段算法的基类(相应的,单阶段算法的基类为 SingleStageDetector)。如果我们需要新增加一个两阶段的算法,可以使新的算法类继承该类,然后可以通过编写新的方法或重写已有的方法来定义新的算法中不同于一般两阶段算法的行为,这样某些繁琐的代码就无需反复编写。对于 TwoStageDetector 类,我们首先看它的__init__方法,如下

def __init__(self,
             backbone,
             neck=None,
             shared_head=None,
             rpn_head=None,
             bbox_roi_extractor=None,
             bbox_head=None,
             mask_roi_extractor=None,
             mask_head=None,
             train_cfg=None,
             test_cfg=None,
             pretrained=None):
    super(TwoStageDetector, self).__init__()
    self.backbone = builder.build_backbone(backbone)

    if neck is not None:
        self.neck = builder.build_neck(neck)

    if shared_head is not None:
        self.shared_head = builder.build_shared_head(shared_head)

    if rpn_head is not None:
        self.rpn_head = builder.build_head(rpn_head)

    if bbox_head is not None:
        self.bbox_roi_extractor = builder.build_roi_extractor(
            bbox_roi_extractor)
        self.bbox_head = builder.build_head(bbox_head)

    if mask_head is not None:
        if mask_roi_extractor is not None:
            self.mask_roi_extractor = builder.build_roi_extractor(
                mask_roi_extractor)
            self.share_roi_extractor = False
        else:
            self.share_roi_extractor = True
            self.mask_roi_extractor = self.bbox_roi_extractor
        self.mask_head = builder.build_head(mask_head)

    self.train_cfg = train_cfg
    self.test_cfg = test_cfg

    self.init_weights(pretrained=pretrained)

构造方法的行为正如前文所述,我们在配置文件中指定参数为各个模块选择相应的组件,如 rpn_head,bbox_head 等,这些参数以函数参数的形式传入构造方法。从代码可以看出,这些组件都是可选的,这意味着在我们不指定组件名称和参数时,程序默认不添加相应模块(例如不指定 mask_head 参数的值,这就表示模型中不添加 mask 分支)。

  除了构造方法,我们还关注它的 forward_train 方法和 simple_test 方法,其部分代码片段如下

def forward_train(self,
                  img,
                  img_meta,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None,
                  gt_masks=None,
                  proposals=None):
    x = self.extract_feat(img)

    losses = dict()

    # RPN forward and loss
    if self.with_rpn:
        rpn_outs = self.rpn_head(x)
        rpn_loss_inputs = rpn_outs + (gt_bboxes, img_meta,
                                      self.train_cfg.rpn)
        rpn_losses = self.rpn_head.loss(
            *rpn_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
        losses.update(rpn_losses)

        proposal_cfg = self.train_cfg.get('rpn_proposal',
                                          self.test_cfg.rpn)
        proposal_inputs = rpn_outs + (img_meta, proposal_cfg)
        proposal_list = self.rpn_head.get_bboxes(*proposal_inputs)
    else:
        proposal_list = proposals

    # assign gts and sample proposals
    if self.with_bbox or self.with_mask:
        bbox_assigner = build_assigner(self.train_cfg.rcnn.assigner)
        bbox_sampler = build_sampler(
            self.train_cfg.rcnn.sampler, context=self)
        num_imgs = img.size(0)
        if gt_bboxes_ignore is None:
            gt_bboxes_ignore = [None for _ in range(num_imgs)]
        sampling_results = []
        for i in range(num_imgs):
            assign_result = bbox_assigner.assign(proposal_list[i],
                                                 gt_bboxes[i],
                                                 gt_bboxes_ignore[i],
                                                 gt_labels[i])
            sampling_result = bbox_sampler.sample(
                assign_result,
                proposal_list[i],
                gt_bboxes[i],
                gt_labels[i],
                feats=[lvl_feat[i][None] for lvl_feat in x])
            sampling_results.append(sampling_result)
     # bbox head forward and loss
     if self.with_bbox:
        ...
     # mask head forward and loss
     if self.with_mask:
        ...

forward_train 方法的主要过程与一般的两阶段目标检测算法一致,**即将 FPN 输出的 feature map 输入 RPN 生成 proposal,然后为这些 proposal 重新指定 target(包括位置大小和类别,代码中的 bbox_assigner 和 bbox_sampler 负责),最后随机 sample 一定数量 proposal(配置文件中指定)送入第二阶段的 ROI pooling 产生最终结果(如果是 Mask RCNN 则还有一个 mask 分支,第二阶段的具体代码这里没有给出)。**前面提到过,HEADS 集合中的组件负责产生 bbox 或 proposal 以及计算损失函数,例如 rpn_head 就通过 get_bboxes 产生 proposal,loss 方法计算损失函数,在添加新的 head 组件时需要注意接口与之保持一致。forward_train 方法主要提供训练时的功能,而测试时则会调用 simple_testaug_test 方法,如果测试时使用了一些数据增强的方法(如 multi-scale test)则需要调用 aug_test 方法,否则使用 simple_test,下面我们主要看一下 simple_test 的代码

def simple_test(self, img, img_meta, proposals=None, rescale=False):
    """Test without augmentation."""
    assert self.with_bbox, "Bbox head must be implemented."

    x = self.extract_feat(img)

    if proposals is None:
        proposal_list = self.simple_test_rpn(x, img_meta,
                                                 self.test_cfg.rpn)
    else:
        proposal_list = proposals

    det_bboxes, det_labels = self.simple_test_bboxes(
        x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale)
    bbox_results = bbox2result(det_bboxes, det_labels,
                                   self.bbox_head.num_classes)

    if not self.with_mask:
        return bbox_results
    else:
        segm_results = self.simple_test_mask(
            x, img_meta, det_bboxes, det_labels, rescale=rescale)
        return bbox_results, segm_results

可以看到 simple_test 方法的行为与训练时基本一致,只是没有计算损失函数的步骤。不过这里值得一提的是, TwoStageDetector 这个类继承了 BaseDetector, RPNTestMixin, BBoxTestMixin,MaskTestMixin 四个类,其中 BaseDetector 类定义了一个最基本的目标检测框架,余下的三个类则分别为测试时 rpn 产生、bbox 和 mask 的产生提供了接口,这些类中只定义了相应的方法,没有构造函数和成员变量(python 中的这种类一般称为 Mixin 类)。换句话说,这些类的功能只能通过继承的方式使用,其目的是为了将一些相对独立的接口分离出来,便于代码的维护。

  事实上,TwoStageDetector 类的行为与 Faster RCNN 和 Mask RCNN 基本一致,因此在定义这两个类时只需继承 TwoStageDetector 类即可(但需要注意构造函数的参数的区别)。当然,前文只是介绍了这几个类中部分关键的代码,实际代码量远远不止于此,篇幅所限这里不再详述。

数据加载

  数据的读取是训练和测试时都必不可少的一环,PyTorch 中已经提供了一些数据加载的工具类,这些类中封装了一些加速的算法(如多进程),使得数据的读取过程更加清晰和高效。mmdetection 则在这个基础上进行了进一步的封装,大大增加了框架的灵活性。mdetection 提供了一个 CustomDataset 类,其中定义了一组基本的数据加载接口,如下

@DATASETS.register_module
class CustomDataset(Dataset):
    def __init__(self,
                 ann_file,
                 pipeline,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True):
        ...
        self.pipeline = Compose(pipeline)
  
    def load_annotations(self, ann_file):
        return mmcv.load(ann_file)

    def load_proposals(self, proposal_file):
        return mmcv.load(proposal_file)

    def get_ann_info(self, idx):
        return self.img_infos[idx]['ann']

    def pre_pipeline(self, results):
        results['img_prefix'] = self.img_prefix
        results['seg_prefix'] = self.seg_prefix
        results['proposal_file'] = self.proposal_file
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []

    def _filter_imgs(self, min_size=32):
        # 略

    def _set_group_flag(self):
        # 略

    def _rand_another(self, idx):
        # 略

    def __getitem__(self, idx):
        # 略

    def prepare_train_img(self, idx):
        img_info = self.img_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        img_info = self.img_infos[idx]
        results = dict(img_info=img_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

CustomDataset 类提供了最基本的数据加载和数据增强功能,当我们为一个特定的数据集定义数据读取类时只需继承这个基类,然后按照这个接口的规范添加或重写方法即可。在数据读取时,除了读取图片和标注信息以外,常常还会会用到很多数据增强的算法(如 random crop、random flip 等),为了灵活性和可扩展性,mmdetection 以 pipeline 的形式将这些操作组织起来。首先我们看 Compose 类构造函数的代码

@PIPELINES.register_module
class Compose(object):
    def __init__(self, transforms):
        assert isinstance(transforms, collections.abc.Sequence)
        self.transforms = []
        for transform in transforms:
            if isinstance(transform, dict):
                transform = build_from_cfg(transform, PIPELINES)
                self.transforms.append(transform)
            elif callable(transform):
                self.transforms.append(transform)
            else:
                raise TypeError('transform must be callable or a dict')

Compose 类的构造函数的参数 transforms 是配置文件中的一个字典对象的列表,列表中的每一个元素都是一个字典对象,由它指定一种操作和它的参数,这些对数据的操作的代码都以类的形式存在于 mmdet/datasets/pipelines 文件夹中,并且按照操作的类型分别存在于四个文件中,列举如下

  • loading.py,这个文件中类封装了读取图片和标注信息的代码。
  • transforms.py,存放数据增强类的代码(训练时的数据增强)。
  • formating.py,这个文件中的类主要功能为对处理好的数据进行格式转换和筛选。
  • test_aug.py,存放测试时用的数据增强类的代码。

这些操作类全部都重载了括号运算符,可以像调用函数一样使用。Compose 类的构造函数依据传入的参数从 PIPELINES 集合中选择对应的操作并构造对象,然后将这些对象存放在 self.transforms 列表中。Compose 类也重载了括号运算符

def __call__(self, data):
    for t in self.transforms:
        data = t(data)
        if data is None:
            return None
    return data

在这个方法里依次调用了列表中的对象对 data 进行处理。我们再回过头看 CustomDataset 类的 prepare_train_img 方法,这个方法中构建了一个字典对象 results 来存放各种信息,self.pipeline 即为前面介绍过的 Compose 类的对象,将 results 作为参数传入就可以依次对数据执行配置文件中指定的操作。这整个过程就好比让数据通过一根管道,并在这个管道中对数据进行特定的操作,故取名为 pipeline。这样做的好处在于,如果我们希望修改对数据的处理方式或者增加更多的数据增强算法,只要在配置文件中做相应的修改而无需改变代码,大大增加了灵活性。看下面一个例子

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]

这是 faster rcnn 配置文件中定义地数据读取操作,表示首先从图片文件和标注文件中读取信息,然后将图片尺寸 resize 至 1333*800,然后依次进行 RandomFlip,Normalize 和 Pad 操作,并且以默认的方式组织这些数据,最后只选择 results 字典中'img', 'gt_bboxes', 'gt_labels'这几个信息,忽略其他信息。

训练和测试

  mmdetection 中的训练和测试过程都有分布式和非分布式两种类型,并分别提供了 dist_train.shdist_test.sh 两个脚本来启动分布式训练或测试,具体使用方法可以在官方代码的 readme 中找到。本节内容并不关注训练或测试的细节,只是重点介绍一下 mmdetection 的 hook 机制。

  mmdetection 的很多算法都依赖于 mmcv 库,mmcv 是 open-mmlab 所开源的一个计算机视觉工具库,它主要包含了两部分的内容,一部分是与 deep learning framework 无关的一些工具函数,另一部分则是为 PyTorch 所编写的一套训练工具,可以大大减少用户需要写的代码量,同时让整个流程的定制变得容易。mmcv 提供了一个 Runner 类对整个训练过程进行了一层封装,借助 Runner 类我们可以很方便地定制整个训练过程,首先看 Runner 类构造函数代码

def __init__(self,
             model,
             batch_processor,
             optimizer=None,
             work_dir=None,
             log_level=logging.INFO,
             logger=None):
    assert callable(batch_processor)
    self.model = model
    if optimizer is not None:
        self.optimizer = self.init_optimizer(optimizer)
    else:
        self.optimizer = None
    self.batch_processor = batch_processor

    # create work_dir
    if mmcv.is_str(work_dir):
        self.work_dir = osp.abspath(work_dir)
        mmcv.mkdir_or_exist(self.work_dir)
    elif work_dir is None:
        self.work_dir = None
    else:
        raise TypeError('"work_dir" must be a str or None')

    # get model name from the model class
    if hasattr(self.model, 'module'):
        self._model_name = self.model.module.__class__.__name__
    else:
        self._model_name = self.model.__class__.__name__

    self._rank, self._world_size = get_dist_info()
    self.timestamp = get_time_str()
    if logger is None:
        self.logger = self.init_logger(work_dir, log_level)
    else:
        self.logger = logger
    self.log_buffer = LogBuffer()

    self.mode = None
    self._hooks = []
    self._epoch = 0
    self._iter = 0
    self._inner_iter = 0
    self._max_epochs = 0
    self._max_iters = 0

其中参数 model 表示我们构建的目标检测模型对象,optimizer 表示优化器(定义学习率和学习率衰减方式),work_dir 为本次训练的日志文件和模型文件存储的路径,log_level 为日志等级,logger 表示日志对象,batch_processor 则是一个函数对象的引用,函数中定义了我们希望对每个 batch 数据的训练方式,mmdetection 中的定义如下

def parse_losses(losses):
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(loss_name))

    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)

    log_vars['loss'] = loss
    for loss_name, loss_value in log_vars.items():
        # reduce loss when distributed training
        if dist.is_initialized():
            loss_value = loss_value.data.clone()
            dist.all_reduce(loss_value.div_(dist.get_world_size()))
        log_vars[loss_name] = loss_value.item()

    return loss, log_vars

def batch_processor(model, data, train_mode):
    losses = model(**data)
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs

上面代码的行为包括将一个 batch 的数据对象(字典)data 传入 model 中,然后获得本次迭代训练的损失函数,最后将这些损失函数整理成一个字典返回,Runner 类通过调用这个函数来获得每次迭代的损失函数。依照构造函数构建了一个 Runner 对象后,与本次训练有关的所有信息便都被封装在了这个对象中,如果要启动训练过程只需调用该对象的 run 方法,这个方法的具体行为与即将介绍的 hook 机制密不可分。

  mmcv 中定义了一个 Hook 类,将整个训练和验证的过程划分成了几个阶段,看下面的代码

class Hook(object):

    def before_run(self, runner):
        pass

    def after_run(self, runner):
        pass

    def before_epoch(self, runner):
        pass

    def after_epoch(self, runner):
        pass

    def before_iter(self, runner):
        pass

    def after_iter(self, runner):
        pass

    def before_train_epoch(self, runner):
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        self.before_iter(runner)

    def before_val_iter(self, runner):
        self.before_iter(runner)

    def after_train_iter(self, runner):
        self.after_iter(runner)

    def after_val_iter(self, runner):
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        return runner.inner_iter + 1 == len(runner.data_loader)

观察 Hook 类中的方法不难发现,Hook 类将训练过程中我们可能采取额外操作(如调整学习率,存储模型和日志文件,打印训练信息等)的时间点分为开始训练前、一个 iteration 前、一个 iteration 后、一个 epoch 前、一个 epoch 后、每 n 个 iteration 后、每 n 个 epoch 后,这些时间点又分为 train 和 validate 过程(在基类中默认两个过程采取相同的操作)。Hook 类的定义类似于一个抽象类,仅仅定义了一组接口而没有具体实现,这意味着我们必须通过继承的方式来使用。如果希望在某几个时间点采取一些特定的操作,需要定义一个新的类并继承 Hook 类,然后重写各个时间点对应的方法,最后调用 Runner 对象的 register_hook 方法在对象中注册这个 hook。Runner 类中维护了一个存放 hook 对象的列表 self._hooks,在每个时间点会通过 call_hook 方法依次调用列表中所有 hook 对象对应的接口以执行相关操作,call_hook 方法定义为

def call_hook(self, fn_name):
    for hook in self._hooks:
        getattr(hook, fn_name)(self)

其中 fn_name 是一个字符串对象,表示希望执行的方法名,这里利用了 python 的内建函数 getattr 来获得 hook 对象中同名方法的引用。为了便于理解这个过程,我们以 mmcv 中的 LrUpdaterHook 类为例简要分析一下 hook 对象的行为。LrUpdaterHook 类主要封装了一些对学习率的修改操作,看下面的代码

class LrUpdaterHook(Hook):

    def __init__(self,
                 by_epoch=True,
                 warmup=None,
                 warmup_iters=0,
                 warmup_ratio=0.1,
                 **kwargs):
        # validate the "warmup" argument
        if warmup is not None:
            if warmup not in ['constant', 'linear', 'exp']:
                raise ValueError(
                    '"{}" is not a supported type for warming up, valid types'
                    ' are "constant" and "linear"'.format(warmup))
        if warmup is not None:
            assert warmup_iters > 0, \
                '"warmup_iters" must be a positive integer'
            assert 0 < warmup_ratio <= 1.0, \
                '"warmup_ratio" must be in range (0,1]'

        self.by_epoch = by_epoch
        self.warmup = warmup
        self.warmup_iters = warmup_iters
        self.warmup_ratio = warmup_ratio

        self.base_lr = []  # initial lr for all param groups
        self.regular_lr = []  # expected lr if no warming up is performed

    def _set_lr(self, runner, lr_groups):
        for param_group, lr in zip(runner.optimizer.param_groups, lr_groups):
            param_group['lr'] = lr

    def get_lr(self, runner, base_lr):
        raise NotImplementedError

    def get_regular_lr(self, runner):
        return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]

    def get_warmup_lr(self, cur_iters):
        if self.warmup == 'constant':
            warmup_lr = [_lr * self.warmup_ratio for _lr in self.regular_lr]
        elif self.warmup == 'linear':
            k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
            warmup_lr = [_lr * (1 - k) for _lr in self.regular_lr]
        elif self.warmup == 'exp':
            k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
            warmup_lr = [_lr * k for _lr in self.regular_lr]
        return warmup_lr

    def before_run(self, runner):
        # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
        # it will be set according to the optimizer params
        for group in runner.optimizer.param_groups:
            group.setdefault('initial_lr', group['lr'])
        self.base_lr = [
            group['initial_lr'] for group in runner.optimizer.param_groups
        ]

    def before_train_epoch(self, runner):
        if not self.by_epoch:
            return
        self.regular_lr = self.get_regular_lr(runner)
        self._set_lr(runner, self.regular_lr)

    def before_train_iter(self, runner):
        cur_iter = runner.iter
        if not self.by_epoch:
            self.regular_lr = self.get_regular_lr(runner)
            if self.warmup is None or cur_iter >= self.warmup_iters:
                self._set_lr(runner, self.regular_lr)
            else:
                warmup_lr = self.get_warmup_lr(cur_iter)
                self._set_lr(runner, warmup_lr)
        elif self.by_epoch:
            if self.warmup is None or cur_iter > self.warmup_iters:
                return
            elif cur_iter == self.warmup_iters:
                self._set_lr(runner, self.regular_lr)
            else:
                warmup_lr = self.get_warmup_lr(cur_iter)
                self._set_lr(runner, warmup_lr)

这个类重写了 before_run、before_train_epoch、before_train_iter 方法,其构造函数的参数 by_epoch 如果为 True 则表明我们以 epoch 为单位计量训练进程,否则以 iteration 为单位。warmup 参数为字符串,指定了 warmup 算法中学习率的变化方式,warmup_iters 和 warmup_ratio 分别指定了 warmup 的 iteration 数和增长比例。从代码中可以看出,在训练开始前,LrUpdaterHook 对象首先会设置 Runner 对象中所维护的优化器的各项参数,然后在每个 iteration 和 epoch 开始前检查学习率和 iteration(epoch)的值,然后计算下一次迭代过程的学习率的值并修改 Runner 中的学习率。需要注意的是,LrUpdaterHook 类并未实现 get_lr 方法,要使用 LrUpdaterHook 类必须通过继承的方式并给出 get_lr 方法的实现。换句话说,LrUpdaterHook 类仅提供了在相应时间修改学习率的代码,至于学习率的衰减方式则应该根据需要自行设置。Hook 机制的好处在于,当我们需要在某些时间点添加一组特定的操作时,只需要编写相应的 hook 类将操作封装并调用 Runner 对象的 register_hook 方法注册即可,这使得整个训练的过程变得更容易定制。

  现在我们回过头来看 Runner 类的 run 方法,看下面的代码

def run(self, data_loaders, workflow, max_epochs, **kwargs):
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)

    self._max_epochs = max_epochs
    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run')

    while self.epoch < max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                       'runner has no method named "{}" to run an epoch'.
                        format(mode))
                epoch_runner = getattr(self, mode)
            elif callable(mode):  # custom train()
                epoch_runner = mode
            else:
                raise TypeError('mode in workflow must be a str or '
                                'callable function, not {}'.format(
                                    type(mode)))
            for _ in range(epochs):
                if mode == 'train' and self.epoch >= max_epochs:
                    return
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')

其中 data_loaders 表示数据加载的对象,max_epochs 表示训练的 epoch 数,workflow 是一个列表对象,需要我们在配置文件中指定,表示在每一个 epoch 中需要采取的行为,例如

workflow = [('train', 1)]

表示在一个 epoch 中调用 Runner 的 train 方法训练一个 epoch,train 方法的定义如下

def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(data_loader)
    self.call_hook('before_train_epoch')
    for i, data_batch in enumerate(data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=True, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('batch_processor() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1

不难看出,train 方法定义的就是训练的过程。run 方法中的 while 循环表示的就是一个完整的训练过程,故而在这个循环的前后分别执行了 self.call_hook('before_run')和 self.call_hook('after_run'),而 train 方法中的 for 循环定义了一个 epoch 训练的过程,并且每次循环都表示一次 iteration,因此在整个循环前后分别执行了 self.call_hook('before_train_epoch')和 self.call_hook('after_train_epoch'),在每次迭代前后又分别执行 self.call_hook('before_train_iter')和 self.call_hook('after_train_iter')。

总结

  作为一个全面的目标检测框架,mmdetection 有数万行的代码量级,它所提供的功能也远远不止前文所讲述的那些,本文也只是集中于它的几个关键模块做了一些简单的介绍,即便如此很多代码中的细节限于篇幅也只能一笔带过。总的来说,mmdetection 主要包含了以下几方面的内容:

  • 目标检测中的一些基础组件,例如常用的网络模型的组件(如 FPN),常用的损失函数(如 focal loss),以及一些其他的必要算法(如 anchor 的计算,使用 c++ 和 cuda 实现的 nms、soft nms 等算法)。
  • 在整个架构的基础上集成了许多经典的目标检测算法,开箱即用。
  • 提供了一套完整的训练和测试流程,包括目标检测中的常用指标 AP、AR 的计算工具。
  • 一些额外的功能(如格式转换、flops 的计算等)。

关于这个框架更多的细节,可以通过官方的 READEME 和具体的代码获得,阅读本文也许能从宏观上对这个框架的整体架构有一个较为清晰的认识,但具体的使用和修改还需要亲自动手实践。

写在最后

这篇博客主要基于我个人使用的经验撰写,很多地方难免管窥蠡测,若有错误和遗漏之处还请指出。最后,由衷地感谢每一位阅读到这里的读者,共勉!


标题:目标检测框架——mmdetection的使用总结
作者:coollwd
地址:http://coollwd.top/articles/2019/12/27/1577427352768.html

Everything that kills me makes me feel alive

取消