Source code for PytorchWildlife.data.datasets
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
from PIL import Image
import numpy as np
import supervision as sv
from torch.utils.data import Dataset
# Public dataset classes exported by a star-import of this module.
# DetectionCrops is listed too so ``from ... import *`` does not silently drop it.
__all__ = [
    "DetectionImageFolder",
    "DetectionCrops",
]
class DetectionImageFolder(Dataset):
    """
    A PyTorch Dataset for loading images from a specified directory.

    Each item in the dataset is a tuple containing the image data,
    the image's path, and the original size of the image.
    """

    def __init__(self, image_dir, transform=None):
        """
        Initializes the dataset.

        Parameters:
            image_dir (str): Path to the directory containing the images.
            transform (callable, optional): Optional transform to be applied on the image.
        """
        self.image_dir = image_dir
        # Sorted file names (not full paths) of the regular files directly inside
        # image_dir. Sub-directories are skipped: they cannot be opened as images
        # and would make __getitem__ fail for their index.
        self.images = sorted(
            name
            for name in os.listdir(self.image_dir)
            if os.path.isfile(os.path.join(self.image_dir, name))
        )
        self.transform = transform

    def __getitem__(self, idx):
        """
        Retrieves an image from the dataset.

        Parameters:
            idx (int): Index of the image to retrieve.

        Returns:
            tuple: (image, image path, original size). The image is an RGB numpy
            array of shape (H, W, 3), or whatever ``transform`` returns for it;
            the original size is ``np.array([H, W, 3])`` taken before the transform.
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        # Open with a context manager so the file handle is always released;
        # PIL keeps it open after loading for some (e.g. multi-frame) formats.
        with Image.open(img_path) as pil_img:
            img = np.asarray(pil_img.convert("RGB"))
        img_size_ori = img.shape
        # Apply transformation if specified
        if self.transform is not None:
            img = self.transform(img)
        return img, img_path, np.array(img_size_ori)

    def __len__(self):
        """
        Returns the total number of images in the dataset.

        Returns:
            int: Total number of images.
        """
        return len(self.images)
class DetectionCrops(Dataset):
    """
    A PyTorch Dataset of image crops, one per animal detection.

    Each item is a tuple of the cropped image (optionally transformed) and the
    path of the source image it was cropped from.
    """

    def __init__(self, detection_results, transform=None, path_head=None, animal_cls_id=0):
        """
        Initializes the dataset and indexes the animal detections to crop.

        Parameters:
            detection_results (iterable): Per-image detection results. Each entry is
                read as ``det["img_id"]`` and ``det["detections"]``; the latter must
                expose parallel ``xyxy`` and ``class_id`` sequences (presumably a
                ``supervision.Detections`` object — TODO confirm with callers).
            transform (callable, optional): Transform applied to each crop, which is
                passed in as a PIL image.
            path_head (str, optional): Directory joined in front of each ``img_id`` to
                build the image path; if falsy, ``img_id`` is used as the path.
            animal_cls_id (int): Detection class id that represents animals; only
                detections with this id are cropped.
        """
        self.detection_results = detection_results
        self.transform = transform
        self.path_head = path_head
        # This determines which detection class id represents animals.
        self.animal_cls_id = animal_cls_id
        self.img_ids = []
        self.xyxys = []
        self.load_detection_results()

    def load_detection_results(self):
        """
        Rebuilds the flat index of animal detections.

        Sets ``self.img_ids`` and ``self.xyxys`` to hold one entry per detection
        whose class id equals ``self.animal_cls_id``. The lists are rebuilt from
        scratch (and only assigned once fully built), so calling this again does
        not duplicate entries.
        """
        img_ids = []
        xyxys = []
        for det in self.detection_results:
            detections = det["detections"]
            for xyxy, det_id in zip(detections.xyxy, detections.class_id):
                # Only run recognition on animal detections
                if det_id == self.animal_cls_id:
                    img_ids.append(det["img_id"])
                    xyxys.append(xyxy)
        self.img_ids = img_ids
        self.xyxys = xyxys

    def __getitem__(self, idx):
        """
        Retrieves one detection crop from the dataset.

        Parameters:
            idx (int): Index of the detection (crop) to retrieve.

        Returns:
            tuple: (crop, image path). The crop is the array returned by
            ``sv.crop_image`` on the RGB image, or the result of ``transform``
            applied to it as a PIL image when a transform is set.
        """
        # Get image path and corresponding bbox xyxy for cropping
        img_id = self.img_ids[idx]
        xyxy = self.xyxys[idx]
        img_path = os.path.join(self.path_head, img_id) if self.path_head else img_id
        # Load the full image inside a context manager so its file handle is
        # released, then crop the bounding box with supervision.
        with Image.open(img_path) as pil_img:
            full_img = np.array(pil_img.convert("RGB"))
        img = sv.crop_image(full_img, xyxy=xyxy)
        # Apply transformation if specified
        if self.transform is not None:
            img = self.transform(Image.fromarray(img))
        return img, img_path

    def __len__(self):
        """Returns the number of animal detections (crops) in the dataset."""
        return len(self.img_ids)