
mmocr.models.textdet.dense_heads.textsnake_head 源代码

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner import BaseModule

from mmocr.models.builder import HEADS, build_loss
from . import HeadMixin

@HEADS.register_module()
class TextSnakeHead(HeadMixin, BaseModule):
    """Prediction head for the TextSnake text detector.

    Implements the head described in "TextSnake: A Flexible Representation
    for Detecting Text of Arbitrary Shapes". A single 1x1 convolution maps
    the incoming feature map to five output channels.

    Args:
        in_channels (int): Number of channels of the input feature map.
        decoding_type (str): Name of the decoding scheme. It usually should
            not be changed.
        text_repr_type (str): Representation used for detected text
            instances. Available options are "poly" or "quad".
        loss (dict): Config dict from which the loss module is built.
        train_cfg, test_cfg: Deprecated. They are only stored on the
            instance.
        init_cfg (dict or list[dict], optional): Initialization configs.
    """

    def __init__(self,
                 in_channels,
                 decoding_type='textsnake',
                 text_repr_type='poly',
                 loss=dict(type='TextSnakeLoss'),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=dict(
                     type='Normal',
                     override=dict(name='out_conv'),
                     mean=0,
                     std=0.01)):
        super().__init__(init_cfg=init_cfg)

        assert isinstance(in_channels, int)

        # Channel layout: the output width is fixed at five maps.
        self.in_channels = in_channels
        self.out_channels = 5
        self.downsample_ratio = 1.0

        # Decoding / representation options, stored for later use.
        self.decoding_type = decoding_type
        self.text_repr_type = text_repr_type

        # Deprecated configs are kept as plain attributes.
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # The loss is built before the conv so that sub-module registration
        # order stays the same as before.
        self.loss_module = build_loss(loss)

        # 1x1 projection; ``init_cfg`` targets it by the name ``out_conv``.
        self.out_conv = nn.Conv2d(
            in_channels,
            self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0)

    def forward(self, inputs):
        """Project the input features to the five prediction maps.

        Args:
            inputs (Tensor): Shape :math:`(N, C_{in}, H, W)`, where
                :math:`C_{in}` is ``in_channels``. :math:`H` and :math:`W`
                should be the same as the input of backbone.

        Returns:
            Tensor: A tensor of shape :math:`(N, 5, H, W)`.
        """
        return self.out_conv(inputs)
Read the Docs v: latest
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.