Shortcuts

mmocr.models.textdet.heads.base 源代码

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple, Union

import torch
from mmengine.model import BaseModule
from torch import Tensor

from mmocr.registry import MODELS
from mmocr.utils.typing import DetSampleList


[文档]@MODELS.register_module() class BaseTextDetHead(BaseModule): """Base head for text detection, build the loss and postprocessor. 1. The ``init_weights`` method is used to initialize head's model parameters. After detector initialization, ``init_weights`` is triggered when ``detector.init_weights()`` is called externally. 2. The ``loss`` method is used to calculate the loss of head, which includes two steps: (1) the head model performs forward propagation to obtain the feature maps (2) The ``module_loss`` method is called based on the feature maps to calculate the loss. .. code:: text loss(): forward() -> module_loss() 3. The ``predict`` method is used to predict detection results, which includes two steps: (1) the head model performs forward propagation to obtain the feature maps (2) The ``postprocessor`` method is called based on the feature maps to predict detection results including post-processing. .. code:: text predict(): forward() -> postprocessor() 4. The ``loss_and_predict`` method is used to return loss and detection results at the same time. It will call head's ``forward``, ``module_loss`` and ``postprocessor`` methods in order. .. code:: text loss_and_predict(): forward() -> module_loss() -> postprocessor() Args: loss (dict): Config to build loss. postprocessor (dict): Config to build postprocessor. init_cfg (dict or list[dict], optional): Initialization configs. Defaults to None. """ def __init__(self, module_loss: Dict, postprocessor: Dict, init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None: super().__init__(init_cfg=init_cfg) assert isinstance(module_loss, dict) assert isinstance(postprocessor, dict) self.module_loss = MODELS.build(module_loss) self.postprocessor = MODELS.build(postprocessor)
[文档] def loss(self, x: Tuple[Tensor], data_samples: DetSampleList) -> dict: """Perform forward propagation and loss calculation of the detection head on the features of the upstream network. Args: x (tuple[Tensor]): Features from the upstream network, each is a 4D-tensor. data_samples (List[:obj:`DetDataSample`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. Returns: dict: A dictionary of loss components. """ outs = self(x, data_samples) losses = self.module_loss(outs, data_samples) return losses
[文档] def loss_and_predict(self, x: Tuple[Tensor], data_samples: DetSampleList ) -> Tuple[dict, DetSampleList]: """Perform forward propagation of the head, then calculate loss and predictions from the features and data samples. Args: x (tuple[Tensor]): Features from FPN. data_samples (list[:obj:`DetDataSample`]): Each item contains the meta information of each image and corresponding annotations. Returns: tuple: the return value is a tuple contains: - losses: (dict[str, Tensor]): A dictionary of loss components. - predictions (list[:obj:`InstanceData`]): Detection results of each image after the post process. """ outs = self(x, data_samples) losses = self.module_loss(outs, data_samples) predictions = self.postprocessor(outs, data_samples) return losses, predictions
[文档] def predict(self, x: torch.Tensor, data_samples: DetSampleList) -> DetSampleList: """Perform forward propagation of the detection head and predict detection results on the features of the upstream network. Args: x (tuple[Tensor]): Multi-level features from the upstream network, each is a 4D-tensor. data_samples (List[:obj:`DetDataSample`]): The Data Samples. It usually includes information such as `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. Returns: SampleList: Detection results of each image after the post process. """ outs = self(x, data_samples) predictions = self.postprocessor(outs, data_samples) return predictions