# Source code for mmocr.datasets.transforms.wrappers
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Optional, Tuple, Union
import imgaug
import imgaug.augmenters as iaa
import numpy as np
import torchvision.transforms as torchvision_transforms
from mmcv.transforms.base import BaseTransform
from PIL import Image
from mmocr.registry import TRANSFORMS
from mmocr.utils import poly2bbox
@TRANSFORMS.register_module()
class ImgAugWrapper(BaseTransform):
    """A wrapper around imgaug https://github.com/aleju/imgaug.

    Find available augmenters at
    https://imgaug.readthedocs.io/en/latest/source/overview_of_augmenters.html.

    Required Keys:

    - img
    - gt_polygons (optional for text recognition)
    - gt_bboxes (optional for text recognition)
    - gt_bboxes_labels (optional for text recognition)
    - gt_ignored (optional for text recognition)
    - gt_texts (optional)

    Modified Keys:

    - img
    - gt_polygons (optional for text recognition)
    - gt_bboxes (optional for text recognition)
    - gt_bboxes_labels (optional for text recognition)
    - gt_ignored (optional for text recognition)
    - img_shape (optional)
    - gt_texts (optional)

    Args:
        args (list[list or dict], optional): The augmentation list. For
            details, please refer to imgaug document. Take
            args=[['Fliplr', 0.5], dict(cls='Affine', rotate=[-10, 10]),
            ['Resize', [0.5, 3.0]]] as an example. The args horizontally flip
            images with probability 0.5, followed by random rotation with
            angles in range [-10, 10], and resize with an independent scale in
            range [0.5, 3.0] for each side of images. Defaults to None.
    """

    def __init__(self, args: Optional[List[Union[List, Dict]]] = None) -> None:
        # ``args`` must either be omitted or be a non-empty list whose items
        # are themselves lists or dicts.
        assert args is None or isinstance(args, list) and len(args) > 0
        if args is not None:
            for arg in args:
                assert isinstance(arg, (list, dict)), \
                    'args should be a list of list or dict'
        self.args = args
        self.augmenter = self._build_augmentation(args)

    def transform(self, results: Dict) -> Optional[Dict]:
        """Transform the image and annotation data.

        Args:
            results (dict): Result dict containing the data to transform.

        Returns:
            dict or None: The transformed data, or None if the augmentation
            leaves no valid polygon annotation. When no augmenter was
            configured, ``results`` is returned unchanged.
        """
        # img is bgr
        image = results['img']
        ori_shape = image.shape

        if self.augmenter:
            # Freeze the random state so that the annotations and the image
            # below go through the very same sampled augmentation.
            aug = self.augmenter.to_deterministic()
            if not self._augment_annotations(aug, ori_shape, results):
                return None
            results['img'] = aug.augment_image(image)
            results['img_shape'] = (results['img'].shape[0],
                                    results['img'].shape[1])

        return results

    def _augment_annotations(self, aug: imgaug.augmenters.meta.Augmenter,
                             ori_shape: Tuple[int, int],
                             results: Dict) -> bool:
        """Augment annotations following the pre-defined augmentation sequence.

        ``results`` is updated in place. Nothing is modified when it has no
        ``gt_polygons`` key.

        Args:
            aug (imgaug.augmenters.meta.Augmenter): The imgaug augmenter.
            ori_shape (tuple[int, int]): The ori_shape of the original image.
            results (dict): Result dict containing annotations to transform.

        Returns:
            bool: Whether the transformation has been successfully applied. If
            the transform results in empty polygon/bbox annotations, return
            False.
        """
        # Assume co-existence of `gt_polygons`, `gt_bboxes` and `gt_ignored`
        # for text detection
        if 'gt_polygons' in results:

            # augment polygons
            transformed_polygons, removed_poly_inds = self._augment_polygons(
                aug, ori_shape, results['gt_polygons'])
            if len(transformed_polygons) == 0:
                return False
            results['gt_polygons'] = transformed_polygons

            # remove instances that are no longer inside the augmented image
            results['gt_bboxes_labels'] = np.delete(
                results['gt_bboxes_labels'], removed_poly_inds, axis=0)
            results['gt_ignored'] = np.delete(
                results['gt_ignored'], removed_poly_inds, axis=0)
            # TODO: deal with gt_texts corresponding to clipped polygons
            if 'gt_texts' in results:
                # A set gives constant-time membership tests per text.
                removed = set(removed_poly_inds)
                results['gt_texts'] = [
                    text for i, text in enumerate(results['gt_texts'])
                    if i not in removed
                ]

            # Generate new bboxes
            bboxes = [poly2bbox(poly) for poly in transformed_polygons]
            if len(bboxes) > 0:
                results['gt_bboxes'] = np.stack(bboxes)
            else:
                results['gt_bboxes'] = np.zeros((0, 4), dtype=np.float32)

        return True

    def _augment_polygons(self, aug: imgaug.augmenters.meta.Augmenter,
                          ori_shape: Tuple[int, int], polys: List[np.ndarray]
                          ) -> Tuple[List[np.ndarray], List[int]]:
        """Augment polygons.

        Args:
            aug (imgaug.augmenters.meta.Augmenter): The imgaug augmenter.
            ori_shape (tuple[int, int]): The shape of the original image.
            polys (list[np.ndarray]): The polygons to be augmented.

        Returns:
            tuple(list[np.ndarray], list[int]): The augmented polygons, and the
            indices of polygons removed as they are out of the augmented image
            or no longer form a valid polygon.
        """
        imgaug_polys = [imgaug.Polygon(poly.reshape(-1, 2)) for poly in polys]
        imgaug_polys = aug.augment_polygons(
            [imgaug.PolygonsOnImage(imgaug_polys, shape=ori_shape)])[0]

        new_polys = []
        removed_poly_inds = []
        for i, poly in enumerate(imgaug_polys.polygons):
            if poly.is_out_of_image(imgaug_polys.shape):
                removed_poly_inds.append(i)
                continue
            # Clipping returns a list of polygons which may be empty; treat
            # an empty result as a removed instance instead of indexing it.
            clipped = poly.clip_out_of_image(imgaug_polys.shape)
            if len(clipped) == 0:
                removed_poly_inds.append(i)
                continue
            # Only the first clipped part is kept.
            new_poly = np.array(
                [np.array(point, dtype=np.float32) for point in clipped[0]],
                dtype=np.float32).flatten()
            # Under some conditions, imgaug can generate "polygon" with only
            # two points, which is not a valid polygon.
            if len(new_poly) <= 4:
                removed_poly_inds.append(i)
                continue
            new_polys.append(new_poly)

        return new_polys, removed_poly_inds

    def _build_augmentation(self, args, root=True):
        """Build ImgAugWrapper augmentations.

        Args:
            args (list or dict or int or float or str, optional): Arguments
                to be passed to imgaug. A root-level list is built into an
                ``iaa.Sequential``; a nested list ``[name, *params]`` or a
                dict with a ``cls`` key is built into the named augmenter.
            root (bool): Whether it's building the root augmenter.

        Returns:
            imgaug.augmenters.meta.Augmenter: The built augmenter. None is
            returned for ``args=None``, scalars are returned unchanged, and
            a dict without ``cls`` is returned as a dict of built values.

        Raises:
            RuntimeError: If ``args`` is of an unsupported type.
        """
        if args is None:
            return None
        if isinstance(args, (int, float, str)):
            return args
        if isinstance(args, list):
            if root:
                sequence = [
                    self._build_augmentation(value, root=False)
                    for value in args
                ]
                return iaa.Sequential(sequence)
            # Non-root list: first item is the augmenter name, the rest are
            # its positional arguments (lists are passed as tuples).
            arg_list = [self._to_tuple_if_list(a) for a in args[1:]]
            return getattr(iaa, args[0])(*arg_list)
        if isinstance(args, dict):
            if 'cls' in args:
                cls = getattr(iaa, args['cls'])
                return cls(
                    **{
                        k: self._to_tuple_if_list(v)
                        for k, v in args.items() if k != 'cls'
                    })
            return {
                key: self._build_augmentation(value, root=False)
                for key, value in args.items()
            }
        raise RuntimeError('unknown augmenter arg: ' + str(args))

    def _to_tuple_if_list(self, obj: Any) -> Any:
        """Convert an object into a tuple if it is a list."""
        if isinstance(obj, list):
            return tuple(obj)
        return obj

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(args = {self.args})'
        return repr_str
@TRANSFORMS.register_module()
class TorchVisionWrapper(BaseTransform):
    """A wrapper around torchvision transforms. It applies specific transform
    to ``img`` and updates ``img_shape`` accordingly.

    Required Keys:

    - img (ndarray): The input image.

    Modified Keys:

    - img (ndarray): The modified image.
    - img_shape (tuple(int, int)): The shape of the image in (height, width).

    Warning:
        This transform only affects the image but not its associated
        annotations, such as word bounding boxes and polygons. Therefore,
        it may only be applicable to text recognition tasks.

    Args:
        op (str): The name of any transform class in
            :func:`torchvision.transforms`.
        **kwargs: Arguments that will be passed to initializer of torchvision
            transform.
    """

    def __init__(self, op: str, **kwargs) -> None:
        assert isinstance(op, str)
        obj_cls = getattr(torchvision_transforms, op)
        self.torchvision = obj_cls(**kwargs)
        # Kept for ``__repr__``.
        self.op = op
        self.kwargs = kwargs

    def transform(self, results):
        """Transform the image.

        Args:
            results (dict): Result dict from the data loader.

        Returns:
            dict: Transformed results.
        """
        assert 'img' in results
        img = results['img']
        # BGR -> RGB. Only reverse the last axis when it is a channel axis;
        # on a 2-D (grayscale) array it would mirror the image horizontally.
        if img.ndim == 3:
            img = img[..., ::-1]
        img = Image.fromarray(img)
        img = self.torchvision(img)
        img = np.asarray(img)
        # RGB -> BGR, with the same guard since the transform may have
        # produced a single-channel 2-D image.
        if img.ndim == 3:
            img = img[..., ::-1]
        # Return a C-contiguous array rather than a negative-stride view.
        img = np.ascontiguousarray(img)
        results['img'] = img
        results['img_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(op = {self.op}'
        for k, v in self.kwargs.items():
            repr_str += f', {k} = {v}'
        repr_str += ')'
        return repr_str