PytorchWildlife.models
This module contains the PytorchWildlife models, which are currently divided into two groups: detection models and classification models.
Detection models
YOLOv5
- class PytorchWildlife.models.YOLOV5Base(weights=None, device='cpu', url=None)
Bases: object
Base detector class for YOLOv5. This class provides utility methods for loading the model, generating results, and performing single-image and batch image detection.
- CLASS_NAMES = None
- IMAGE_SIZE = None
- STRIDE = None
- TRANSFORM = None
- batch_image_detection(dataloader, conf_thres=0.2, id_strip=None)
Perform detection on a batch of images.
- Args:
- dataloader (DataLoader):
DataLoader containing image batches.
- conf_thres (float, optional):
Confidence threshold for predictions. Defaults to 0.2.
- id_strip (str, optional):
Characters to strip from img_id. Defaults to None.
- Returns:
list: List of detection results for all images.
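A minimal, framework-free sketch of what batch detection does conceptually: iterate over batches, keep detections above `conf_thres`, and return one result per image. The row layout, the result keys, and the strict `>` comparison (borrowed from YOLOv5's own `non_max_suppression`) are assumptions, not confirmed PytorchWildlife internals; real batches come from a torch `DataLoader` and carry tensors rather than Python tuples.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical row layout for one detection: (x1, y1, x2, y2, confidence, class_id).
Detection = Tuple[float, float, float, float, float, int]


def filter_by_confidence(preds: Sequence[Detection], conf_thres: float = 0.2) -> List[Detection]:
    """Keep detections whose confidence is strictly greater than conf_thres."""
    return [row for row in preds if row[4] > conf_thres]


def batch_detection_sketch(
    batches: Sequence[Tuple[Sequence[str], Sequence[Sequence[Detection]]]],
    conf_thres: float = 0.2,
) -> List[Dict[str, object]]:
    """Return one result dict per image, flattened across all batches."""
    results: List[Dict[str, object]] = []
    for img_ids, batch_preds in batches:
        # Each batch pairs image identifiers with the raw predictions for each image.
        for img_id, preds in zip(img_ids, batch_preds):
            results.append({"img_id": img_id, "detections": filter_by_confidence(preds, conf_thres)})
    return results


# Usage: a single batch holding two images.
batches = [
    (
        ["cam1/0001.jpg", "cam1/0002.jpg"],
        [
            [(10, 10, 50, 50, 0.9, 0), (60, 60, 80, 80, 0.1, 0)],
            [(5, 5, 20, 20, 0.2, 1)],
        ],
    ),
]
out = batch_detection_sketch(batches, conf_thres=0.2)
```

With the default threshold of 0.2, a detection scored exactly 0.2 is dropped under the strict comparison, so the second image ends up with no detections.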
- results_generation(preds, img_id, id_strip=None)
Generate results for detection based on model predictions.
- Args:
- preds (numpy.ndarray):
Model predictions.
- img_id (str):
Image identifier.
- id_strip (str, optional):
Strip specific characters from img_id. Defaults to None.
- Returns:
dict: Dictionary containing image ID, detections, and labels.
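A sketch of the result dictionary described above, assuming that `id_strip` is passed to Python's `str.strip`. If so, it removes any of the given *characters* from both ends of the ID rather than a literal prefix, which can eat into file names. The class mapping, row layout, and label format are hypothetical, and the real method may return a richer detections object than a list of tuples.

```python
from typing import Dict, List, Optional, Sequence

# Hypothetical class mapping; YOLOV5Base itself leaves CLASS_NAMES = None for subclasses to set.
CLASS_NAMES: Dict[int, str] = {0: "animal", 1: "person"}


def results_generation_sketch(
    preds: Sequence[Sequence[float]],
    img_id: str,
    id_strip: Optional[str] = None,
) -> Dict[str, object]:
    """Build a dict with the image ID, detections, and readable labels.

    Assumption: id_strip is passed to str.strip, which removes any of the given
    characters from both ends of the ID (it does not remove a literal prefix).
    """
    clean_id = str(img_id).strip(id_strip)  # strip(None) removes surrounding whitespace
    # Each row is assumed to be (x1, y1, x2, y2, confidence, class_id).
    labels: List[str] = [f"{CLASS_NAMES[int(row[5])]} {row[4]:0.2f}" for row in preds]
    return {"img_id": clean_id, "detections": [tuple(row) for row in preds], "labels": labels}


res = results_generation_sketch([(10, 10, 50, 50, 0.87, 0)], "data/img_001.jpg", id_strip="data/")
print(res["img_id"])  # img_001.jpg
print(res["labels"])  # ['animal 0.87']

# Pitfall of character-set stripping: the leading "ta" of "tapir" is removed too.
print("data/tapir.jpg".strip("data/"))  # pir.jpg
```

When the intent is to drop a literal path prefix, `str.removeprefix` (Python 3.9+) avoids this pitfall.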
- single_image_detection(img, img_size=None, img_path=None, conf_thres=0.2, id_strip=None)
Perform detection on a single image.
- Args:
- img (torch.Tensor):
Input image tensor.
- img_size (tuple, optional):
Original image size. Defaults to None.
- img_path (str, optional):
Image path or identifier. Defaults to None.
- conf_thres (float, optional):
Confidence threshold for predictions. Defaults to 0.2.
- id_strip (str, optional):
Characters to strip from img_id. Defaults to None.
- Returns:
dict: Detection results.
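The `img_size` argument (the original image size) is presumably what lets boxes predicted at the model's input resolution be mapped back onto the original photo. Below is a minimal sketch of that mapping, assuming YOLOv5-style letterbox preprocessing (aspect-preserving resize plus symmetric padding) and `(height, width)` ordering; neither detail is confirmed by the documentation above, and `rescale_box_letterbox` is a hypothetical helper.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2) in pixels


def rescale_box_letterbox(box: Box, model_hw: Tuple[int, int], orig_hw: Tuple[int, int]) -> Box:
    """Map a box from letterboxed model-input coordinates back to the original image.

    Assumes a YOLOv5-style letterbox: resize with the aspect ratio preserved,
    then pad equally on both sides up to model_hw (height, width).
    """
    model_h, model_w = model_hw
    orig_h, orig_w = orig_hw
    gain = min(model_h / orig_h, model_w / orig_w)  # resize factor applied before padding
    pad_w = (model_w - orig_w * gain) / 2  # padding added on the left (and right)
    pad_h = (model_h - orig_h * gain) / 2  # padding added on the top (and bottom)

    def clip(v: float, hi: float) -> float:
        return min(max(v, 0.0), hi)

    x1, y1, x2, y2 = box
    # Undo the padding, undo the resize, then clip to the original image bounds.
    return (
        clip((x1 - pad_w) / gain, orig_w),
        clip((y1 - pad_h) / gain, orig_h),
        clip((x2 - pad_w) / gain, orig_w),
        clip((y2 - pad_h) / gain, orig_h),
    )


# A 640x480 (width x height) photo letterboxed into a 640x640 input: 80 px of padding top and bottom.
print(rescale_box_letterbox((100, 180, 200, 280), model_hw=(640, 640), orig_hw=(480, 640)))
# (100.0, 100.0, 200.0, 200.0)
```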
Classification models
ResNet
- class PytorchWildlife.models.PlainResNetInference(num_cls=36, num_layers=50, weights=None, device='cpu', url=None)
Bases: Module
Inference module for the PlainResNet Classifier.
- batch_image_classification(dataloader, id_strip=None)
Process a batch of images for classification.
- forward(img)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- results_generation(logits, img_id, id_strip=None)
Process logits to produce final results.
- Args:
- logits (torch.Tensor):
Logits from the network.
- img_id (str):
Image path.
- id_strip (str, optional):
Characters to strip from img_id. Defaults to None.
- Returns:
dict: Dictionary containing the results.
- training: bool
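The documentation only says that `results_generation` turns logits into results. A common way to do that, and a plausible reading here, is a softmax followed by an argmax; with the default `num_cls=36`, each image would contribute 36 logits. The sketch below uses plain Python instead of torch, and the output keys (`class_id`, `confidence`) are assumptions rather than the library's actual dictionary layout.

```python
import math
from typing import Dict, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_to_result(logits: Sequence[float], img_id: str) -> Dict[str, object]:
    """Turn one image's logits (length num_cls) into a class index and its probability."""
    probs = softmax(logits)
    class_id = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return {"img_id": img_id, "class_id": class_id, "confidence": probs[class_id]}


result = logits_to_result([0.5, 2.0, -1.0], "site_a/img_042.jpg")
print(result["class_id"], round(result["confidence"], 3))  # 1 0.786
```

Subtracting the maximum logit leaves the probabilities unchanged but keeps `math.exp` from overflowing on large logits.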
Pretrained Classification Weights
This section lists the pretrained weights currently available for the classification models. Below is a table detailing the available pretrained weights:
AI4GOpossum
- class PytorchWildlife.models.AI4GOpossum(weights=None, device='cpu', pretrained=True)
Bases: PlainResNetInference
Opossum Classifier that inherits from PlainResNetInference. This classifier is specialized for distinguishing between Opossums and Non-opossums.
- CLASS_NAMES = {0: 'Non-opossum', 1: 'Opossum'}
- IMAGE_SIZE = 224
- results_generation(logits, img_ids, id_strip=None)
Generate results for classification.
- Args:
- logits (torch.Tensor):
Output tensor from the model.
- img_ids (list):
List of image identifiers.
- id_strip (str, optional):
Characters to strip from the image IDs. Defaults to None.
- Returns:
dict: Dictionary containing image ID, prediction, and confidence score.
- training: bool
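The documentation lists `img_ids` as a list but the return value as a single dict, without saying how per-image results are packed. The sketch below sidesteps that by returning one dict per image. It uses the documented `CLASS_NAMES` mapping to turn each image's two logits into a named prediction and a softmax confidence. The function name and output keys are hypothetical.

```python
import math
from typing import Dict, List, Sequence

# Copied from the documented AI4GOpossum.CLASS_NAMES.
CLASS_NAMES: Dict[int, str] = {0: "Non-opossum", 1: "Opossum"}


def opossum_results_sketch(
    batch_logits: Sequence[Sequence[float]], img_ids: Sequence[str]
) -> List[Dict[str, object]]:
    """Map each image's two logits to a named prediction and its softmax confidence."""
    if len(batch_logits) != len(img_ids):
        # zip() would silently drop the extra items, so fail loudly instead.
        raise ValueError("batch_logits and img_ids must have the same length")
    results = []
    for logits, img_id in zip(batch_logits, img_ids):
        m = max(logits)
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        probs = [e / total for e in exps]
        class_id = max(range(len(probs)), key=probs.__getitem__)  # argmax
        results.append(
            {"img_id": img_id, "prediction": CLASS_NAMES[class_id], "confidence": probs[class_id]}
        )
    return results


for r in opossum_results_sketch([[0.0, 3.0], [1.0, -1.0]], ["a.jpg", "b.jpg"]):
    print(r["img_id"], r["prediction"], round(r["confidence"], 3))
# a.jpg Opossum 0.953
# b.jpg Non-opossum 0.881
```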
AI4GAmazonRainforest
- class PytorchWildlife.models.AI4GAmazonRainforest(weights=None, device='cpu', pretrained=True)
Bases: PlainResNetInference
Amazon Rainforest animal classifier that inherits from PlainResNetInference. This classifier is specialized for recognizing 36 different animal classes in the Amazon Rainforest.
- CLASS_NAMES = {0: 'Dasyprocta', 1: 'Bos', 2: 'Pecari', 3: 'Mazama', 4: 'Cuniculus', 5: 'Leptotila', 6: 'Human', 7: 'Aramides', 8: 'Tinamus', 9: 'Eira', 10: 'Crax', 11: 'Procyon', 12: 'Capra', 13: 'Dasypus', 14: 'Sciurus', 15: 'Crypturellus', 16: 'Tamandua', 17: 'Proechimys', 18: 'Leopardus', 19: 'Equus', 20: 'Columbina', 21: 'Nyctidromus', 22: 'Ortalis', 23: 'Emballonura', 24: 'Odontophorus', 25: 'Geotrygon', 26: 'Metachirus', 27: 'Catharus', 28: 'Cerdocyon', 29: 'Momotus', 30: 'Tapirus', 31: 'Canis', 32: 'Furnarius', 33: 'Didelphis', 34: 'Sylvilagus', 35: 'Unknown'}
- IMAGE_SIZE = 224
- results_generation(logits, img_ids, id_strip=None)
Generate results for classification.
- Args:
- logits (torch.Tensor):
Output tensor from the model.
- img_ids (str):
Image identifier.
- id_strip (str, optional):
Characters to strip from the image ID. Defaults to None.
- Returns:
dict: Dictionary containing image ID, prediction, and confidence score.
- training: bool