Source code for histo_kit.grand_qc.model

import torch

class GrandQC(torch.nn.Module):
    """
    Wrapper module for a pre-trained GrandQC model.

    The checkpoint is expected to be a fully pickled ``torch.nn.Module``
    (as written by ``torch.save(model, path)``), not a ``state_dict``.

    Parameters
    ----------
    gqc_weights : str or pathlib.Path
        Path to the `.pt` or `.pth` file containing the serialized GrandQC
        model.
    device : str or torch.device, optional
        Device on which to load the model, e.g. ``"cuda"`` or ``"cpu"``
        (default is ``"cuda"``). Passed to ``torch.load`` as ``map_location``.
    weights_only : bool, optional
        Keyword-only. Forwarded to ``torch.load``. Defaults to ``False``
        because the checkpoint holds a whole pickled module, which
        PyTorch >= 2.6 refuses to load under its ``weights_only=True``
        default.

    Attributes
    ----------
    grandQC : torch.nn.Module
        The loaded GrandQC model instance.

    Raises
    ------
    TypeError
        If the file does not deserialize to a ``torch.nn.Module``
        (for example, when it contains only a ``state_dict``).

    Warnings
    --------
    Loading with ``weights_only=False`` unpickles arbitrary Python objects
    and can execute code embedded in the file. Only load checkpoints from
    trusted sources.

    Examples
    --------
    >>> model = GrandQC("grandqc.pth", device="cpu")  # doctest: +SKIP
    >>> x = torch.randn(1, 3, 224, 224)
    >>> y = model(x)  # doctest: +SKIP
    """

    def __init__(self, gqc_weights, device="cuda", *, weights_only: bool = False):
        super().__init__()
        # PyTorch 2.6 changed torch.load's default to weights_only=True,
        # which rejects pickled nn.Module objects such as this checkpoint,
        # so the flag is passed explicitly.
        # SECURITY: with weights_only=False the file is fully unpickled;
        # only use trusted checkpoint files.
        model = torch.load(
            gqc_weights, map_location=device, weights_only=weights_only
        )
        # Fail fast with a clear message instead of a confusing
        # "object is not callable" error on the first forward pass.
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                f"Expected {gqc_weights!r} to contain a serialized "
                f"torch.nn.Module, got {type(model).__name__}. "
                "A state_dict must be loaded into a model instance instead."
            )
        # Assigning an nn.Module attribute registers it as a submodule, so
        # .eval(), .to(), .parameters() etc. on the wrapper propagate to it.
        self.grandQC = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run a forward pass through the GrandQC model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor shaped appropriately for the underlying model
            (e.g., ``(N, C, H, W)`` for images).

        Returns
        -------
        torch.Tensor
            The model output tensor.
        """
        return self.grandQC(x)