Shortcuts

Source code for mmocr.models.textdet.dense_heads.pan_head

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import BaseModule

from mmocr.models.builder import HEADS, build_loss
from mmocr.utils import check_argument
from . import HeadMixin


@HEADS.register_module()
class PANHead(HeadMixin, BaseModule):
    """The class for PANet head.

    Concatenates multi-level features along the channel dimension (when given
    a sequence) and projects them with a single 1x1 convolution.

    Args:
        in_channels (list[int]): Channel counts of the input feature maps
            (typically 4 levels). Their sum is the input channel count of the
            output convolution.
        out_channels (int): Number of output channels.
        text_repr_type (str): Use polygon or quad to represent. Available
            options are "poly" or "quad".
        downsample_ratio (float): Downsample ratio, in ``[0, 1]``.
        loss (dict): Configuration dictionary for loss type. Supported loss
            types are "PANLoss" and "PSELoss".
        train_cfg, test_cfg (dict): Deprecated.
        init_cfg (dict or list[dict], optional): Initialization configs.

    Raises:
        AssertionError: If any argument fails its validation check.
        NotImplementedError: If ``loss['type']`` is neither "PANLoss" nor
            "PSELoss".
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            text_repr_type='poly',  # 'poly' or 'quad'
            downsample_ratio=0.25,
            loss=dict(type='PANLoss'),
            train_cfg=None,
            test_cfg=None,
            init_cfg=dict(
                type='Normal',
                mean=0,
                std=0.01,
                override=dict(name='out_conv'))):
        super().__init__(init_cfg=init_cfg)

        assert check_argument.is_type_list(in_channels, int)
        assert isinstance(out_channels, int)
        assert text_repr_type in ['poly', 'quad']
        assert 0 <= downsample_ratio <= 1

        self.loss_module = build_loss(loss)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.text_repr_type = text_repr_type
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.downsample_ratio = downsample_ratio

        # The decoding strategy is tied to the loss the head is trained with.
        # NOTE(review): presumably consumed by HeadMixin during
        # post-processing; confirm against HeadMixin.
        loss_type = loss['type']
        if loss_type == 'PANLoss':
            self.decoding_type = 'pan'
        elif loss_type == 'PSELoss':
            self.decoding_type = 'pse'
        else:
            raise NotImplementedError(f'unsupported loss type {loss_type}.')

        # Plain Python ``sum`` yields a builtin int, so no numpy scalar ends
        # up in the conv's configuration.
        self.out_conv = nn.Conv2d(
            in_channels=sum(in_channels),
            out_channels=out_channels,
            kernel_size=1)

    def forward(self, inputs):
        r"""Fuse the input features and project them to the output maps.

        Args:
            inputs (list[Tensor] | tuple[Tensor] | Tensor): Either a single
                tensor of shape :math:`(N, C_{in}, H, W)`, or a sequence of
                tensors each of shape :math:`(N, C_i, H, W)`, where
                :math:`\sum_i C_i = C_{in}` and :math:`C_{in}` is
                ``sum(in_channels)``. Tensors in a sequence must share batch
                and spatial sizes so they can be concatenated.

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, H, W)` where
            :math:`C_{out}` is ``out_channels``.
        """
        # Concatenate any sequence (list or tuple) of per-level features along
        # the channel dim. A list must be handled here too: passing it through
        # unchanged would make ``out_conv`` fail on a non-tensor input.
        if isinstance(inputs, (list, tuple)):
            inputs = torch.cat(inputs, dim=1)
        return self.out_conv(inputs)
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.