PytorchWildlife.models

This module contains the PytorchWildlife models, which are currently divided into two groups: detection and classification.

Detection models

Yolov5

class PytorchWildlife.models.YOLOV5Base(weights=None, device='cpu', url=None)

Bases: object

Base detector class for YOLOv5. This class provides utility methods for loading the model, generating results, and running detection on single images and on batches of images.

CLASS_NAMES = None
IMAGE_SIZE = None
STRIDE = None
TRANSFORM = None
batch_image_detection(dataloader, conf_thres=0.2, id_strip=None)

Perform detection on a batch of images.

Args:
    dataloader (DataLoader): DataLoader containing image batches.
    conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2.
    id_strip (str, optional): Characters to strip from img_id. Defaults to None.

Returns:
    list: List of detection results for all images.
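Conceptually, batch detection filters each image's predictions by `conf_thres` and collects one result per image. The following is a minimal, framework-free sketch of that idea, not the library's implementation. It assumes each prediction row uses YOLOv5's post-NMS layout `[x1, y1, x2, y2, confidence, class_id]`, and it replaces the real `DataLoader` and model with a plain list of `(img_ids, per_image_preds)` pairs; `filter_by_confidence` and `batch_detection_sketch` are hypothetical names.

```python
from typing import Iterable, List, Sequence, Tuple

# Assumed row layout (YOLOv5 post-NMS): [x1, y1, x2, y2, confidence, class_id]
Row = Sequence[float]


def filter_by_confidence(preds: Iterable[Row], conf_thres: float = 0.2) -> List[Row]:
    """Keep only the rows whose confidence (index 4) is above the threshold."""
    return [row for row in preds if row[4] > conf_thres]


def batch_detection_sketch(
    batches: Iterable[Tuple[Sequence[str], Sequence[Sequence[Row]]]],
    conf_thres: float = 0.2,
) -> List[dict]:
    """Hypothetical stand-in for batch_image_detection: one result per image."""
    results = []
    for img_ids, per_image_preds in batches:
        # Pair every image in the batch with its own predictions.
        for img_id, preds in zip(img_ids, per_image_preds):
            results.append({"img_id": img_id,
                            "detections": filter_by_confidence(preds, conf_thres)})
    return results


# One batch with two images (made-up numbers).
batches = [
    (["a.jpg", "b.jpg"],
     [[[10, 10, 50, 50, 0.9, 0], [5, 5, 8, 8, 0.1, 0]],  # a.jpg: one kept, one dropped
      [[0, 0, 20, 20, 0.15, 1]]]),                         # b.jpg: nothing kept
]
results = batch_detection_sketch(batches, conf_thres=0.2)
print([len(r["detections"]) for r in results])  # [1, 0]
```

The sketch uses a strict `>` comparison; whether `batch_image_detection` treats a confidence exactly equal to `conf_thres` as kept is not stated here.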

results_generation(preds, img_id, id_strip=None)

Generate results for detection based on model predictions.

Args:
    preds (numpy.ndarray): Model predictions.
    img_id (str): Image identifier.
    id_strip (str, optional): Strip specific characters from img_id. Defaults to None.

Returns:
    dict: Dictionary containing image ID, detections, and labels.
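`id_strip` is described as characters to strip from `img_id`. If it is applied with Python's `str.strip`, as that wording suggests (an assumption, since the implementation is not shown here), note that `str.strip` removes any of the listed characters from both ends of the string rather than a literal prefix. The sketch below shows the difference against `str.removeprefix` (Python 3.9+); `make_result` and its dictionary keys are hypothetical and only mirror the documented contents (image ID, detections, labels).

```python
from typing import Optional


def make_result(img_id: str, detections: list, labels: list,
                id_strip: Optional[str] = None) -> dict:
    """Hypothetical result builder holding the documented fields:
    image ID, detections, and labels."""
    # str.strip(chars) treats `chars` as a set of characters, not as a prefix.
    clean_id = img_id.strip(id_strip) if id_strip else img_id
    return {"img_id": clean_id, "detections": detections, "labels": labels}


path = "/data/dog.jpg"
print(make_result(path, [], [], id_strip="/data/")["img_id"])  # og.jpg (the "d" of "dog" is stripped too)
print(path.removeprefix("/data/"))                             # dog.jpg (exact prefix only)
```

If your image IDs share a directory prefix, check the stripped IDs on a few real paths before relying on them.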

single_image_detection(img, img_size=None, img_path=None, conf_thres=0.2, id_strip=None)

Perform detection on a single image.

Args:
    img (torch.Tensor): Input image tensor.
    img_size (tuple, optional): Original image size. Defaults to None.
    img_path (str, optional): Image path or identifier. Defaults to None.
    conf_thres (float, optional): Confidence threshold for predictions. Defaults to 0.2.
    id_strip (str, optional): Characters to strip from the image identifier (img_path). Defaults to None.

Returns:
    dict: Detection results.
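`img_size` is the original image size. In YOLOv5-style pipelines the image is letterboxed (resized with its aspect ratio preserved, then padded) to the model input size, so predicted boxes have to be mapped back to original-image coordinates. The function below is a minimal pure-Python sketch of that mapping, modelled on YOLOv5's box-rescaling step. It assumes sizes are `(height, width)` and that padding is split evenly on both sides; whether `single_image_detection` uses `img_size` exactly this way is an assumption, and `rescale_box` is a hypothetical name.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2) in pixels


def _clip(value: float, upper: float) -> float:
    """Clamp a coordinate to the range [0, upper]."""
    return max(0.0, min(value, float(upper)))


def rescale_box(box: Box, input_hw: Tuple[int, int], orig_hw: Tuple[int, int]) -> Box:
    """Map a box from letterboxed model-input coordinates back to the original image.

    Sizes are (height, width); this ordering is an assumption for the sketch.
    """
    in_h, in_w = input_hw
    orig_h, orig_w = orig_hw
    # Resize factor that made the original image fit inside the model input.
    gain = min(in_h / orig_h, in_w / orig_w)
    # Padding added on each side to fill the remaining input area.
    pad_x = (in_w - orig_w * gain) / 2
    pad_y = (in_h - orig_h * gain) / 2
    x1, y1, x2, y2 = box
    # Undo the padding, then undo the resize.
    x1, x2 = (x1 - pad_x) / gain, (x2 - pad_x) / gain
    y1, y2 = (y1 - pad_y) / gain, (y2 - pad_y) / gain
    return (_clip(x1, orig_w), _clip(y1, orig_h), _clip(x2, orig_w), _clip(y2, orig_h))


# A 960x1280 (h x w) photo letterboxed into a 640x640 input: gain 0.5, 80 px padding top and bottom.
print(rescale_box((50, 130, 150, 230), input_hw=(640, 640), orig_hw=(960, 1280)))
# (100.0, 100.0, 300.0, 300.0)
```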

Classification models

ResNet

class PytorchWildlife.models.PlainResNetInference(num_cls=36, num_layers=50, weights=None, device='cpu', url=None)

Bases: Module

Inference module for the PlainResNet Classifier.

batch_image_classification(dataloader, id_strip=None)

Process a batch of images for classification.

forward(img)

Defines the computation performed at every call. This method overrides torch.nn.Module.forward.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.

results_generation(logits, img_id, id_strip=None)

Process logits to produce final results.

Args:
    logits (torch.Tensor): Logits from the network.
    img_id (str): Image path.
    id_strip (str, optional): Characters to strip from img_id so that saved image IDs are cleaner. Defaults to None.

Returns:
    dict: Dictionary containing the results.
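One common way to turn logits into final results, assumed here because the implementation is not shown, is to apply a softmax, take the highest-probability class as the prediction, and report that probability as the confidence. This pure-Python sketch uses hypothetical function names, dictionary keys, and made-up class labels.

```python
import math
from typing import Dict, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_result(logits: Sequence[float], img_id: str,
                     class_names: Dict[int, str]) -> dict:
    """Hypothetical: pick the argmax class and report its softmax probability."""
    probs = softmax(logits)
    class_id = max(range(len(probs)), key=probs.__getitem__)
    return {"img_id": img_id,
            "prediction": class_names[class_id],
            "class_id": class_id,
            "confidence": probs[class_id]}


names = {0: "cat", 1: "dog", 2: "fox"}  # illustrative labels only
r = logits_to_result([1.0, 2.0, 0.5], "img_001.jpg", names)
print(r["prediction"], round(r["confidence"], 3))  # dog 0.629
```

Subtracting the maximum logit does not change the softmax output but prevents `math.exp` from overflowing on large logits.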

single_image_classification(img, img_id=None, id_strip=None)

Perform classification on a single image.

training: bool

Pretrained Classification Weights

This section describes the pretrained weights currently available for the classification models. The table below details the available pretrained weights:

AI4GOpossum

class PytorchWildlife.models.AI4GOpossum(weights=None, device='cpu', pretrained=True)

Bases: PlainResNetInference

Opossum classifier that inherits from PlainResNetInference. This classifier is specialized for distinguishing opossums from non-opossums.

CLASS_NAMES = {0: 'Non-opossum', 1: 'Opossum'}
IMAGE_SIZE = 224
results_generation(logits, img_ids, id_strip=None)

Generate results for classification.

Args:
    logits (torch.Tensor): Output tensor from the model.
    img_ids (list): List of image identifiers.
    id_strip (str, optional): Characters to strip from the image IDs so that saved IDs are cleaner. Defaults to None.

Returns:
    dict: Dictionary containing image ID, prediction, and confidence score.

training: bool
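For a two-class model such as this one, a softmax over the two logits is equivalent to a logistic sigmoid of their difference, so class 1 ('Opossum') gets probability sigmoid(l1 - l0). The sketch below uses that identity to build one result per image from a batch of logits and an `img_ids` list. It assumes (hypothetically) that row i of the logits belongs to `img_ids[i]` and that the confidence is the probability of the predicted class; the function names and dictionary keys are illustrative, and the list-of-dictionaries return value is a choice of the sketch, not a statement about the library's return structure.

```python
import math
from typing import Dict, List, Sequence

CLASS_NAMES = {0: "Non-opossum", 1: "Opossum"}  # as documented for AI4GOpossum


def _sigmoid(d: float) -> float:
    """Numerically stable logistic function."""
    if d >= 0:
        return 1.0 / (1.0 + math.exp(-d))
    e = math.exp(d)
    return e / (1.0 + e)


def binary_results(logits: Sequence[Sequence[float]],
                   img_ids: Sequence[str]) -> List[Dict]:
    """Hypothetical per-image results for a two-class batch.

    With two classes, softmax reduces to a sigmoid of the logit difference:
    p(class 1) = sigmoid(l1 - l0).
    """
    results = []
    # Assumes row i of `logits` belongs to img_ids[i].
    for (l0, l1), img_id in zip(logits, img_ids):
        p1 = _sigmoid(l1 - l0)
        class_id = 1 if p1 > 0.5 else 0  # ties go to class 0, like argmax
        results.append({"img_id": img_id,
                        "prediction": CLASS_NAMES[class_id],
                        "confidence": p1 if class_id == 1 else 1.0 - p1})
    return results


for r in binary_results([[0.0, 2.0], [1.5, -0.5]], ["cam1_001.jpg", "cam1_002.jpg"]):
    print(r["img_id"], r["prediction"], round(r["confidence"], 3))
# cam1_001.jpg Opossum 0.881
# cam1_002.jpg Non-opossum 0.881
```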

AI4GAmazonRainforest

class PytorchWildlife.models.AI4GAmazonRainforest(weights=None, device='cpu', pretrained=True)

Bases: PlainResNetInference

Amazon Rainforest animal classifier that inherits from PlainResNetInference. This classifier is specialized for recognizing 36 classes in the Amazon Rainforest: 34 animal genera plus 'Human' and 'Unknown'.

CLASS_NAMES = {0: 'Dasyprocta', 1: 'Bos', 2: 'Pecari', 3: 'Mazama', 4: 'Cuniculus', 5: 'Leptotila', 6: 'Human', 7: 'Aramides', 8: 'Tinamus', 9: 'Eira', 10: 'Crax', 11: 'Procyon', 12: 'Capra', 13: 'Dasypus', 14: 'Sciurus', 15: 'Crypturellus', 16: 'Tamandua', 17: 'Proechimys', 18: 'Leopardus', 19: 'Equus', 20: 'Columbina', 21: 'Nyctidromus', 22: 'Ortalis', 23: 'Emballonura', 24: 'Odontophorus', 25: 'Geotrygon', 26: 'Metachirus', 27: 'Catharus', 28: 'Cerdocyon', 29: 'Momotus', 30: 'Tapirus', 31: 'Canis', 32: 'Furnarius', 33: 'Didelphis', 34: 'Sylvilagus', 35: 'Unknown'}
IMAGE_SIZE = 224
results_generation(logits, img_ids, id_strip=None)

Generate results for classification.

Args:
    logits (torch.Tensor): Output tensor from the model.
    img_ids (list): Image identifiers.
    id_strip (str, optional): Characters to strip from the image IDs so that saved IDs are cleaner. Defaults to None.

Returns:
    dict: Dictionary containing image ID, prediction, and confidence score.

training: bool