# Source code for PytorchWildlife.models.classification.resnet.opossum

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import torch
from .base_classifier import PlainResNetInference

# Names exported by ``from <module> import *`` — only the opossum classifier.
__all__ = ["AI4GOpossum"]


class AI4GOpossum(PlainResNetInference):
    """
    Opossum Classifier that inherits from PlainResNetInference.
    This classifier is specialized for distinguishing between Opossums and Non-opossums.
    """

    # Image size for the Opossum classifier
    IMAGE_SIZE = 224

    # Mapping from predicted class index to human-readable class name
    CLASS_NAMES = {
        0: "Non-opossum",
        1: "Opossum"
    }

    def __init__(self, weights=None, device="cpu", pretrained=True):
        """
        Initialize the Opossum Classifier.

        Args:
            weights (str, optional): Path to the model weights. Defaults to None.
            device (str, optional): Device for model inference. Defaults to "cpu".
            pretrained (bool, optional): Whether to use pretrained weights. Defaults to True.
        """
        # If pretrained, hand the base class the published checkpoint URL;
        # otherwise no URL is passed.
        if pretrained:
            url = "https://zenodo.org/records/10023414/files/OpossumClassification_v0.0.0.ckpt?download=1"
        else:
            url = None
        # num_cls=1: a single output logit, which results_generation turns into
        # a probability with a sigmoid (binary classification).
        super().__init__(weights=weights, device=device, num_cls=1, num_layers=50, url=url)

    def results_generation(self, logits, img_ids, id_strip=None):
        """
        Generate results for classification.

        Args:
            logits (torch.Tensor): Output tensor from the model; presumably shape
                (N, 1) — one logit per image (``squeeze(1)`` below requires a
                second dimension).
            img_ids (list): List of image identifiers, one per row of ``logits``.
            id_strip (str, optional): Characters passed to ``str.strip`` to clean
                each image ID before saving. Defaults to None (strips whitespace).

        Returns:
            list[dict]: One dictionary per image with keys "img_id",
            "prediction", "class_id" and "confidence".
        """
        # Detach and move to CPU so that ``.numpy()`` also works for CUDA tensors
        # and for tensors that still require grad.
        probs = torch.sigmoid(logits).detach().cpu()
        preds = (probs > 0.5).squeeze(1).numpy().astype(int)

        results = []
        for pred, img_id, prob in zip(preds, img_ids, probs):
            # Plain Python int: JSON-serializable, unlike numpy.int64.
            pred = int(pred)
            p = prob.item()
            # NOTE(review): str.strip removes any of the *characters* in id_strip
            # from both ends, not a prefix/suffix substring — confirm this is intended.
            r = {"img_id": str(img_id).strip(id_strip)}
            r["prediction"] = self.CLASS_NAMES[pred]
            r["class_id"] = pred
            # Report the probability of the predicted class.
            r["confidence"] = p if pred == 1 else (1 - p)
            results.append(r)
        return results