# Source code for PytorchWildlife.models.classification.resnet.base_classifier

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import numpy as np
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
from torch.hub import load_state_dict_from_url
from tqdm import tqdm
from collections import OrderedDict

# Public API: a star-import of this module exposes only PlainResNetInference;
# the other classes defined here must be imported explicitly by name.
__all__ = ["PlainResNetInference"]


class ResNetBackbone(ResNet):
    """
    ResNet variant used purely as a feature extractor.

    The overridden ``_forward_impl`` runs the stem, the four residual stages
    and the average pool, then flattens everything after the batch dimension.
    No ``fc`` layer is invoked here, so the result is the pooled feature
    vector and not class scores.
    """
    def _forward_impl(self, x):
        # Stem -> residual stages -> pooling, applied strictly in this order.
        pipeline = (
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.avgpool,
        )
        for stage in pipeline:
            x = stage(x)

        # Collapse all non-batch dimensions into a single feature axis.
        return torch.flatten(x, 1)


class PlainResNetClassifier(nn.Module):
    """
    Basic ResNet classifier: a ResNetBackbone feature extractor followed by a
    single linear classification layer.

    Args:
        num_cls (int): Number of outputs of the linear classification layer.
        num_layers (int): ResNet depth; only 18 and 50 are supported.

    Raises:
        ValueError: If ``num_layers`` is not one of the supported depths.
    """
    name = "PlainResNetClassifier"

    def __init__(self, num_cls=1, num_layers=50):
        super().__init__()
        self.num_cls = num_cls
        self.num_layers = num_layers
        # Placeholders; `feature` and `classifier` are built by setup_net(),
        # `criterion_cls` only when setup_criteria() is called.
        self.feature = None
        self.classifier = None
        self.criterion_cls = None
        # Initialize the network
        self.setup_net()

    def setup_net(self):
        """
        Set up the ResNet classifier according to the specified number of layers.

        Builds ``self.feature`` (the backbone) and ``self.classifier`` (a linear
        layer mapping ``512 * block.expansion`` features to ``num_cls`` outputs).

        Raises:
            ValueError: If ``self.num_layers`` is neither 18 nor 50.
        """
        kwargs = {}

        # NOTE: no pretrained-weight definition is attached here for either
        # depth, so `self.pretrained_weights` is NOT set by this method;
        # feat_init() needs it to be provided separately.
        if self.num_layers == 18:
            block = BasicBlock
            layers = [2, 2, 2, 2]
        elif self.num_layers == 50:
            block = Bottleneck
            layers = [3, 4, 6, 3]
        else:
            # ValueError is still caught by callers using `except Exception`.
            raise ValueError("ResNet Type not supported.")

        self.feature = ResNetBackbone(block, layers, **kwargs)
        self.classifier = nn.Linear(512 * block.expansion, self.num_cls)

    def setup_criteria(self):
        """
        Setup the criterion for classification (cross-entropy loss).
        """
        self.criterion_cls = nn.CrossEntropyLoss()

    def feat_init(self):
        """
        Initialize the backbone from ``self.pretrained_weights``.

        ``self.pretrained_weights`` must expose ``get_state_dict(progress=True)``
        (presumably a torchvision weights enum member -- confirm with caller).
        The ``"module."`` and ``"feature."`` substrings are stripped from the
        keys before a non-strict load into ``self.feature``; the keys that
        did not match in either direction are printed for debugging.

        Raises:
            AttributeError: If ``self.pretrained_weights`` has not been set.
        """
        pretrained = getattr(self, "pretrained_weights", None)
        if pretrained is None:
            # Same exception type as the implicit failure, but with a message
            # that says what is actually missing.
            raise AttributeError(
                "feat_init() requires `self.pretrained_weights` to be set; "
                "setup_net() does not define it."
            )

        init_weights = pretrained.get_state_dict(progress=True)
        init_weights = OrderedDict(
            (key.replace("module.", "").replace("feature.", ""), value)
            for key, value in init_weights.items()
        )
        self.feature.load_state_dict(init_weights, strict=False)
        # Print missing and unused keys for debugging purposes
        load_keys = set(init_weights.keys())
        self_keys = set(self.feature.state_dict().keys())
        missing_keys = self_keys - load_keys
        unused_keys = load_keys - self_keys
        print("missing keys:", sorted(missing_keys))
        print("unused_keys:", sorted(unused_keys))


class PlainResNetInference(nn.Module):
    """
    Inference module for the PlainResNet Classifier.

    Wraps a PlainResNetClassifier, loads its weights from a local checkpoint
    or a URL, switches to eval mode and moves the network to ``device``.

    Args:
        num_cls (int): Number of classifier outputs.
        num_layers (int): ResNet depth passed to PlainResNetClassifier.
        weights (str, optional): Path to a checkpoint loaded with ``torch.load``.
            Takes precedence over ``url`` when both are given.
        device (str): Device the checkpoint is mapped to and the net runs on.
        url (str, optional): URL of a checkpoint, used when ``weights`` is falsy.

    Raises:
        ValueError: If neither ``weights`` nor ``url`` is provided.
    """

    def __init__(self, num_cls=36, num_layers=50, weights=None, device="cpu", url=None):
        super().__init__()
        self.device = device
        self.net = PlainResNetClassifier(num_cls=num_cls, num_layers=num_layers)
        if weights:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            clf_weights = torch.load(weights, map_location=torch.device(self.device))
        elif url:
            clf_weights = load_state_dict_from_url(url, map_location=torch.device(self.device))
        else:
            # ValueError is still caught by callers using `except Exception`.
            raise ValueError("Need weights for inference.")
        # The checkpoint is expected to be a mapping with a "state_dict" entry
        # whose keys match this module exactly (strict load).
        self.load_state_dict(clf_weights["state_dict"], strict=True)
        self.eval()
        self.net.to(self.device)

    def results_generation(self, logits, img_id, id_strip=None):
        """
        Process logits to produce final results.

        This base implementation does nothing and returns None; it is a hook
        that subclasses are expected to override.

        Args:
            logits (torch.Tensor): Logits from the network (already on CPU when
                called from the classification helpers of this class).
            img_id (sequence): Image ids/paths, one per row of ``logits``.
            id_strip (str): Stripping string for better image id saving.

        Returns:
            Per-image results, indexable by position (the single-image helper
            takes element ``[0]``).
        """
        pass

    def forward(self, img):
        """
        Run the backbone and the linear classifier on a batch of images.

        Args:
            img (torch.Tensor): Batched input accepted by the backbone.

        Returns:
            torch.Tensor: Raw logits from the classifier.
        """
        feats = self.net.feature(img)
        logits = self.net.classifier(feats)
        return logits

    def single_image_classification(self, img, img_id=None, id_strip=None):
        """
        Classify one image.

        Args:
            img (torch.Tensor): A single, un-batched image; a batch dimension
                is added before the forward pass.
            img_id: Identifier of the image, forwarded to results_generation.
            id_strip (str): Forwarded to results_generation.

        Returns:
            The first (and only) entry produced by results_generation.
        """
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = self.forward(img.unsqueeze(0).to(self.device))
        return self.results_generation(logits.cpu(), [img_id], id_strip=id_strip)[0]

    def batch_image_classification(self, dataloader, id_strip=None):
        """
        Process a batch of images for classification.

        Args:
            dataloader: Sized iterable yielding ``(imgs, paths)`` pairs.
            id_strip (str): Forwarded to results_generation.

        Returns:
            Whatever results_generation returns for the concatenated logits
            and paths of all batches.
        """
        total_logits = []
        total_paths = []

        with torch.no_grad(), tqdm(total=len(dataloader)) as pbar:
            for imgs, paths in dataloader:
                imgs = imgs.to(self.device)
                # Move each batch's logits off the device right away so device
                # memory does not grow with the size of the dataset.
                total_logits.append(self.forward(imgs).cpu())
                total_paths.append(paths)
                pbar.update(1)

        total_logits = torch.cat(total_logits, dim=0)
        total_paths = np.concatenate(total_paths, axis=0)

        return self.results_generation(total_logits, total_paths, id_strip=id_strip)