Source code for PytorchWildlife.data.transforms

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import numpy as np
import torch
from torchvision import transforms
from yolov5.utils.augmentations import letterbox

# Making the provided classes available for import from this module
__all__ = [
    "MegaDetector_v5_Transform",
    "Classification_Inference_Transform"
]

class MegaDetector_v5_Transform:
    """
    A transformation class to preprocess images for the MegaDetector v5 model.
    This includes resizing, transposing, and normalization operations.
    This is a required transformation for the YOLOv5 model.
    """

    def __init__(self, target_size=1280, stride=32):
        """
        Initializes the transform.

        Args:
            target_size (int): Desired size for the image's longest side after resizing.
            stride (int): Stride value for resizing.
        """
        self.target_size = target_size
        self.stride = stride

    def __call__(self, np_img):
        """
        Applies the transformation on the provided image.

        Args:
            np_img (np.ndarray): Input image as a numpy array.

        Returns:
            torch.Tensor: Transformed image.
        """
        # Resize and pad the image using the letterbox function
        img = letterbox(np_img, new_shape=self.target_size, stride=self.stride, auto=False)[0]

        # Transpose HWC -> CHW and convert the image to a PyTorch tensor
        img = img.transpose((2, 0, 1))
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).float()

        # Scale pixel values to [0, 1]
        img /= 255.0

        return img
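# A minimal usage sketch (not part of the original module), assuming an RGB image
# loaded with PIL; the file name "example.jpg" is an illustrative placeholder.
#
#     from PIL import Image
#     import numpy as np
#
#     np_img = np.asarray(Image.open("example.jpg").convert("RGB"))  # HWC uint8 array
#     transform = MegaDetector_v5_Transform(target_size=1280, stride=32)
#     tensor = transform(np_img)       # CHW float tensor scaled to [0, 1]
#     batch = tensor.unsqueeze(0)      # add a batch dimension before running the detector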
class Classification_Inference_Transform:
    """
    A transformation class to preprocess images for classification inference.
    This includes resizing, normalization, and conversion to a tensor.
    """
    # Normalization constants
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def __init__(self, target_size=224):
        """
        Initializes the transform.

        Args:
            target_size (int): Desired size for the height and width after resizing.
        """
        # Define the sequence of transformations
        self.trans = transforms.Compose([
            transforms.Resize((target_size, target_size)),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])

    def __call__(self, img):
        """
        Applies the transformation on the provided image.

        Args:
            img (PIL.Image.Image): Input image in PIL format.

        Returns:
            torch.Tensor: Transformed image.
        """
        img = self.trans(img)
        return img
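# A minimal usage sketch (not part of the original module), assuming a PIL image;
# the file name "animal_crop.jpg" is an illustrative placeholder.
#
#     from PIL import Image
#
#     img = Image.open("animal_crop.jpg").convert("RGB")
#     transform = Classification_Inference_Transform(target_size=224)
#     tensor = transform(img)          # 3 x 224 x 224 tensor, normalized with the constants above
#     batch = tensor.unsqueeze(0)      # add a batch dimension before running the classifier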