import os
import numpy as np
import torch
from PIL import Image
from openslide import OpenSlide
from torch.utils.data import DataLoader
from .artifacts import Artifact
from .dataset import GrandQCDataset
from .visualisation import make_artifacts_color_map, make_overlay
from ..utils.file_utils import get_basename, save_rescaled
from ..utils.image import gaussian_window
from ..utils.matlab2python import list2cell
from ..utils.wsi import load_wsi_mag, get_regions_location
import scipy.io as sio
# NOTE(review): "[docs]" below the imports was a leftover Sphinx source-link
# marker from a documentation-HTML scrape; kept as a comment so the file parses.
def detect_artifacts_slide(slide_file, res_dict_path, batch_size, num_workers,
                           device, model, paths_dict, vis_mag, overlap=0.7, mag_model=10,
                           patch_size=512, mode="gaussian", classes=8, sigma=None, save_mag=2.5,
                           save_confidence_maps=True):
    """
    Run optimized GrandQC inference on a whole-slide image (WSI), aggregate patch-wise
    predictions into full-resolution confidence maps, produce visualizations and save
    results to desired output folders.
    This function:
    - loads a WSI and a precomputed tissue/background detection map,
    - extracts patches from tissue bounding boxes using a sliding grid,
    - runs the provided model on batches of patches,
    - merges patch predictions into full-image class confidence maps using a weighting
      (Gaussian or average) window,
    - normalizes the aggregated maps by the sum of weights,
    - produces a hard prediction mask and color visualizations,
    - optionally rescales results to a different magnification and saves both the
      hard masks and per-class confidence maps as MATLAB `.mat` files, and
    - saves an overlay visualization.
    Parameters
    ----------
    slide_file : str or pathlib.Path
        Path to the WSI file (e.g. `.svs`) to process.
    res_dict_path : str or pathlib.Path
        Path to `.mat` file containing precomputed tissue/background detection and
        related metadata used to compute bounding boxes (expected keys used in code:
        ``"mask_bg"``, ``"ind_WSI"``, ``"ratio"``, ``"thr"``).
    batch_size : int
        Batch size for the PyTorch DataLoader used during inference.
    num_workers : int
        Number of worker processes for the DataLoader.
    device : torch.device or str
        Device used for inference (e.g. ``"cuda"`` or ``"cpu"``).
    model : torch.nn.Module
        Pre-trained segmentation/classification model that accepts a batch of image
        tensors and returns per-pixel class confidence maps (shape ``(N, C, H, W)``).
    paths_dict : dict
        Dictionary with output directory paths. Expected keys (used in function):
        - ``"grandqc_overlay_vis"`` : directory to save overlay PNGs
        - ``"masks_grandqc"`` : directory to save `.mat` with hard masks
        - ``"masks_grandqc_confidence_maps"`` : directory to save per-class `.mat`
    vis_mag : float
        Magnification used for creating visualization overlays (e.g. 2.5, 5, 10).
    overlap : float, optional
        Fractional overlap between sliding patches (default ``0.7``).
    mag_model : float or int, optional
        Working magnification at which the model expects inputs (default ``10``).
    patch_size : int, optional
        Side length of extracted square patches in pixels (default ``512``).
    mode : {'gaussian', 'average'}, optional
        Weighting mode used when aggregating overlapping patch predictions. If
        ``'gaussian'``, a Gaussian window is applied; if ``'average'``, equal weights
        are used (default ``"gaussian"``).
    classes : int, optional
        Number of classes (including background) predicted by the network (default ``8``).
    sigma : float or None, optional
        Standard deviation parameter for the Gaussian window (only relevant when
        ``mode=='gaussian'``). If ``None``, a reasonable default (from gaussian_window)
        is assumed by that helper (default ``None``).
    save_mag : float or int, optional
        Magnification at which final masks/confidence maps are saved (default ``2.5``).
    save_confidence_maps : bool, optional
        If True, save per-class confidence maps as separate `.mat` files in the
        directory pointed by ``paths_dict["masks_grandqc_confidence_maps"]`` (default True).
    Returns
    -------
    dict
        A dictionary summarizing saved results and metadata. Keys include:
        - ``'basename'`` : str, base filename of the processed WSI (without extension)
        - ``'mask_art'`` : list or MATLAB cell (converted inside function) of hard masks
          for tissue regions
        - ``'ind_WSI'`` : value loaded from ``res_dict_path`` (indexes for WSI layers)
        - ``'ratio'`` : value loaded from ``res_dict_path`` (ratio per layer)
        - ``'scale_val'`` : float, final scale factor applied to masks
        - ``'thr'`` : thresholds loaded from ``res_dict_path`` for color channels
        - ``'bbox'`` : bounding boxes used for patch extraction (possibly rescaled)
        (The function also writes multiple `.mat` files and PNGs to disk as a side effect.)
    Raises
    ------
    ValueError
        If ``mode`` is neither ``'gaussian'`` nor ``'average'``.
    Notes
    -----
    - The function may produce coordinates outside original image bounds; these are
      handled by cropping and padding logic inside the dataset / merge loop.
    - The caller must ensure that ``paths_dict`` directories exist and are writable.
    - Large WSIs can consume substantial memory; the implementation aggregates patch
      predictions into full-resolution arrays sized (H, W, classes) — ensure sufficient
      RAM is available for the target magnification and image size.
    Examples
    --------
    >>> paths = {
    ...     "grandqc_overlay_vis": "/out/overlays",
    ...     "masks_grandqc": "/out/masks",
    ...     "masks_grandqc_confidence_maps": "/out/conf_maps"
    ... }
    >>> save_info = detect_artifacts_slide(
    ...     "slide.svs",
    ...     "slide_res.mat",
    ...     batch_size=8,
    ...     num_workers=4,
    ...     device="cuda",
    ...     model=my_model,
    ...     paths_dict=paths,
    ...     vis_mag=2.5,
    ...     overlap=0.7,
    ...     mag_model=10,
    ...     patch_size=512
    ... )
    >>> print(save_info["basename"])
    References
    ----------
    This function is based on \ :footcite:p:`Weng2024`
    """
    # slide basename (without extension) — used for all output filenames
    basename = get_basename(slide_file)
    # load slide
    slide = OpenSlide(slide_file)
    # read the WSI rescaled to the model's working magnification
    region, scale_val, info, mpp_slide, ratio = load_wsi_mag(slide, mag_model, allow_upscaling=True)
    W, H = region.size
    region = np.array(region)
    # level-0 dimensions and objective power determine the visualisation size
    w_l0, h_l0 = slide.level_dimensions[0]
    mag_l0 = float(slide.properties["openslide.objective-power"])
    scale_vis = vis_mag / mag_l0
    vis_size = (int(w_l0 * scale_vis), int(h_l0 * scale_vis))
    # precomputed tissue/background map, resized (nearest) to the working resolution
    data = sio.loadmat(res_dict_path)
    tis_det = data["mask_bg"]
    tis_det = np.array(Image.fromarray(tis_det).resize((W, H), Image.Resampling.NEAREST))
    bbox, images_list = get_regions_location(tis_det)
    dataset = GrandQCDataset(region, tis_det, bbox, patch_size, overlap)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    model.eval()
    # accumulators: weighted class confidences and the per-pixel sum of weights
    raw_mask = np.zeros((H, W, classes))
    weights = np.zeros((H, W))
    if mode == "gaussian":
        weight_patch = gaussian_window(patch_size, patch_size, sigma=sigma)
    elif mode == "average":
        weight_patch = np.ones((patch_size, patch_size))
    else:
        # fail fast with a clear message instead of a NameError inside the merge loop
        raise ValueError(f"Unknown mode '{mode}', expected 'gaussian' or 'average'")
    # single no_grad context around the whole inference loop
    with torch.no_grad():
        for batch in dataloader:
            images = batch["patch"].to(device)
            preds = model(images).to("cpu").numpy()
            for i, pred in enumerate(preds):
                # (C, H, W) -> (H, W, C) to match the numpy accumulators
                pred_hwc = pred.transpose(1, 2, 0)
                orig_x0 = int(batch["x_start"][i])
                orig_y0 = int(batch["y_start"][i])
                orig_x1 = int(batch["x_end"][i])
                orig_y1 = int(batch["y_end"][i])
                # clamp patch placement to the image bounds
                dst_x0 = max(0, orig_x0)
                dst_y0 = max(0, orig_y0)
                dst_x1 = min(W, orig_x1)
                dst_y1 = min(H, orig_y1)
                h = dst_y1 - dst_y0
                w = dst_x1 - dst_x0
                if h <= 0 or w <= 0:
                    # patch lies entirely outside the image — nothing to merge
                    continue
                # offset inside the patch corresponding to the clamped region
                src_x0 = dst_x0 - orig_x0
                src_y0 = dst_y0 - orig_y0
                pred_patch = pred_hwc[src_y0:src_y0 + h, src_x0:src_x0 + w, :]  # (h, w, C)
                gauss_patch_crop = weight_patch[src_y0:src_y0 + h, src_x0:src_x0 + w]  # (h, w)
                assert pred_patch.shape[0] == gauss_patch_crop.shape[0] and \
                    pred_patch.shape[1] == gauss_patch_crop.shape[1], \
                    f"Shape mismatch: pred_patch {pred_patch.shape}, gauss {gauss_patch_crop.shape}"
                raw_mask[dst_y0:dst_y1, dst_x0:dst_x1, :] += pred_patch * gauss_patch_crop[..., None]
                weights[dst_y0:dst_y1, dst_x0:dst_x1] += gauss_patch_crop
    # normalize the accumulated confidences by the per-pixel weight sum
    # (pixels never covered by any patch keep 0 instead of dividing by zero)
    for i in range(classes):
        raw_mask[:, :, i] = np.divide(
            raw_mask[:, :, i],
            weights,
            out=np.zeros_like(raw_mask[:, :, i]),
            where=weights != 0
        )
    del dataset
    pred_mask = np.argmax(raw_mask, axis=2).astype('int8')
    pred_mask = pred_mask[:region.shape[0], :region.shape[1]]
    # add information about the background detected by model, tissue is black, background is white (class probability)
    raw_mask[:, :, 0] = 1 - tis_det
    # remove the rest of bg pixels
    tis_det_bool = tis_det.astype(bool)
    pred_mask[~tis_det_bool] = 0
    # make color visualisation
    artifacts_color_map = Image.fromarray(make_artifacts_color_map(pred_mask))
    # overlay heatmap on the image
    overlay = make_overlay(region, artifacts_color_map, tis_det, vis_size)
    overlay = Image.fromarray(overlay)
    save_rescaled(overlay, vis_size, os.path.join(paths_dict["grandqc_overlay_vis"], f'{basename}.png'))
    # free large intermediates before the rescaling/saving phase
    del overlay
    del artifacts_color_map
    del weights
    # rescale result to desired magnification
    if mag_model != save_mag:
        scale = float(mag_model) / float(save_mag)
        new_H = int(round(H / scale))
        new_W = int(round(W / scale))
        resized_size = (new_W, new_H)
        pred_mask = np.array(Image.fromarray(pred_mask).resize(resized_size, Image.Resampling.NEAREST))
        # per-channel resize: PIL cannot resize an (H, W, C) float stack in one call
        raw_mask_rescaled = np.zeros((new_H, new_W, raw_mask.shape[2]), dtype=raw_mask.dtype)
        for c in range(raw_mask.shape[2]):
            img = Image.fromarray(raw_mask[:, :, c])
            img_resized = img.resize(resized_size, Image.Resampling.NEAREST)
            raw_mask_rescaled[:, :, c] = np.array(img_resized)
        raw_mask = raw_mask_rescaled
        # keep bounding boxes aligned with the rescaled masks
        bbox = np.array(bbox, dtype=float)
        bbox = np.round(bbox / scale).astype(int)
        images_list_scaled = []
        for n, image_bbox in enumerate(images_list):
            image_bbox = Image.fromarray(image_bbox)
            ymin, xmin, ymax, xmax = bbox[n]
            width = xmax - xmin
            height = ymax - ymin
            image_bbox = image_bbox.resize((width, height), Image.Resampling.NEAREST)
            images_list_scaled.append(np.array(image_bbox, dtype=np.uint8))
        images_list = images_list_scaled
    h_res, w_res = pred_mask.shape
    scale_val = save_mag / mag_l0
    save_dict = {
        'basename': basename,  # tissue file basename (without .svs extension)
        'mask_art': [],  # mask with artifacts detected by grandQC for given region
        'ind_WSI': data['ind_WSI'].astype(np.uint8),  # indexes for WSI image layers (idx from 1)
        'ratio': data['ratio'],  # ratio for each layer
        'scale_val': scale_val,  # scale factor of masks
        'thr': data['thr'],  # thresholds calculated for R, G, B color channels
        'bbox': bbox.astype(np.uint64),  # bboxes for tissue regions (indexing from 0)
        'mask_mag': save_mag,  # magnification of the final mask
        'mpp': mpp_slide,  # mpp of the slide
        'mag_l0': mag_l0  # magnification of the largest WSI layer
    }
    save_dict_raw = {
        'basename': basename,  # tissue file basename (without .svs extension)
        'mask_conf': [],  # mask with confidence scores generated by grandQC for given region, thresholding mask is also added.
        # Each mask is scaled to range 0-255, white pixels have high confidence scores
        'ind_WSI': data['ind_WSI'].astype(np.uint8),  # indexes for WSI image layers (idx from 1)
        'ratio': data['ratio'],  # ratio for each layer
        'scale_val': scale_val,  # scale factor of masks
        'thr': data['thr'],  # thresholds calculated for R, G, B color channels
        'bbox': bbox.astype(np.uint64),  # bboxes for tissue regions (indexing from 0)
        'mask_mag': save_mag,  # magnification of the final mask
        'mpp': mpp_slide,  # mpp of the slide
        'mag_l0': mag_l0  # magnification of the largest WSI layer
    }
    # crop per-region masks, zeroing pixels outside each region's tissue mask
    for n, (region_bbox, image_bbox) in enumerate(zip(bbox, images_list)):
        y0, x0, y1, x1 = map(int, region_bbox)
        # clamp bbox to the (possibly rescaled) mask extent
        y0 = max(0, min(y0, h_res))
        y1 = max(0, min(y1, h_res))
        x0 = max(0, min(x0, w_res))
        x1 = max(0, min(x1, w_res))
        pred_mask_region = pred_mask[y0:y1, x0:x1].astype(np.uint8) * image_bbox
        raw_mask_region = (raw_mask[y0:y1, x0:x1, :] * 255).astype(np.uint8) * image_bbox[..., None]
        save_dict['mask_art'].append(pred_mask_region)
        if save_confidence_maps:
            save_dict_raw['mask_conf'].append(raw_mask_region)
    # free the full-resolution arrays before the (compressing) savemat calls
    del pred_mask
    del raw_mask
    # convert to cells for matlab
    save_dict['mask_art'] = list2cell(save_dict['mask_art'])
    sio.savemat(os.path.join(paths_dict["masks_grandqc"], f'{basename}.mat'), save_dict, do_compression=True)
    if save_confidence_maps:
        save_dict_raw['mask_conf'] = list2cell(save_dict_raw['mask_conf'])
        # FIX: key was "grandqc_confidence_maps", which is not among the documented
        # paths_dict keys and raised KeyError whenever save_confidence_maps=True
        sio.savemat(os.path.join(paths_dict["masks_grandqc_confidence_maps"], f'{basename}.mat'),
                    save_dict_raw, do_compression=True)
    return save_dict