# Source code for pvcracks.utils.viz_functions

import matplotlib.pyplot as plt
import torch


def _predict_binary_masks(model, batch, threshold):
    """Return thresholded sigmoid predictions for a one-sample batch as a NumPy array."""
    # Visualization never needs gradients; no_grad avoids building the
    # autograd graph during the forward pass.
    with torch.no_grad():
        logits = model(batch).detach().cpu()  # expected: [1, n_classes, H, W]
    probs = torch.sigmoid(logits)
    return (probs > threshold).float().squeeze(0).numpy()


def _show_mask_row(axes_row, masks, names, title_prefix):
    """Draw ``masks[j]`` into the j-th axis of ``axes_row`` with a prefixed title."""
    # Index explicitly (rather than zip) so that a mask or name list shorter
    # than the number of columns still fails loudly with an IndexError.
    for j, ax in enumerate(axes_row):
        ax.imshow(masks[j], cmap="viridis")
        ax.set_title(f"{title_prefix}{names[j]}")
        ax.axis("off")


def channeled_inference_and_show(
    data_loader,
    device,
    model,
    category_mapping,
    idx,
    threshold=0.5,
    custom_title="Model Prediction",
    save_path=None,
    rows=3,
    model_two=None,
    label_prefix_one="",
    label_prefix_two="",
    secondary_mapping=None,
    secondary_mapping_index=1,
):
    """Run inference on a single dataset element and visualize predictions per class.

    The figure has one column per entry of ``category_mapping`` and ``rows``
    rows:

    * row 0 - the raw image (first column only, remaining cells are blanked),
    * row 1 - the ground-truth mask channels,
    * row 2 - thresholded predictions of ``model``,
    * row 3 - thresholded predictions of ``model_two``; drawn only when
      ``rows == 4`` and ``model_two`` is not ``None``.

    Any row that receives no content is blanked. Inference runs under
    ``torch.no_grad()``. The train/eval mode of the models is not modified
    here, so switch them to eval mode beforehand if that matters.

    Args:
        data_loader (torch.utils.data.DataLoader): Loader that provides dataset
            access. Its ``dataset`` must support indexing, returning
            ``(image, mask)``, and expose ``__getraw__(idx)`` whose first
            return value offers ``convert("L")`` (presumably a PIL image).
        device (torch.device): Computation device for running the model.
        model (torch.nn.Module): Trained segmentation network returning channel
            logits; it is called on a batch of one sample.
        category_mapping (Mapping[int, str]): Mapping from class index to
            display name. Its iteration order is assumed to match the channel
            order of the masks.
        idx (int): Index of the sample to visualize inside the dataset.
        threshold (float, optional): Probability cutoff for converting logits
            to masks; a pixel is foreground when its sigmoid probability is
            strictly greater than this value.
        custom_title (str, optional): Figure-level title.
        save_path (str | os.PathLike | None, optional): When not ``None``, the
            figure is written to this path before it is shown.
        rows (int, optional): Number of subplot rows. Must be at least 3.
        model_two (torch.nn.Module | None, optional): Second network whose
            predictions fill row 3 when ``rows == 4``.
        label_prefix_one (str, optional): Text prepended to the ``"Pred: ..."``
            titles of row 2.
        label_prefix_two (str, optional): Text prepended to the ``"Pred: ..."``
            titles of row 3.
        secondary_mapping (Mapping[int, str] | None, optional): Alternative
            index-to-name mapping used for the titles of the row selected by
            ``secondary_mapping_index``. It needs at least as many entries as
            ``category_mapping``.
        secondary_mapping_index (int, optional): ``2`` applies
            ``secondary_mapping`` to row 2, ``3`` applies it to row 3; any
            other value (including the default) leaves it unused.

    Raises:
        ValueError: If ``rows`` is smaller than 3.
    """
    # Rows 0-2 are always drawn; fail early instead of with an IndexError
    # after the figure has already been created.
    if rows < 3:
        raise ValueError(
            f"rows must be at least 3 (raw image, ground truth, prediction); got {rows}"
        )

    dataset = data_loader.dataset

    # Preprocessed image and multi-hot ground-truth mask.
    img, mask = dataset[idx]
    img = img.to(device)

    # Raw image for display only.
    raw_img, _ = dataset.__getraw__(idx)

    # --- Run inference ---
    batch = img.unsqueeze(0)
    pred_mask = _predict_binary_masks(model, batch, threshold)

    # Ground truth is assumed to already be an n_classes-channel multi-hot mask.
    gt_mask = mask.cpu().numpy()

    # --- Visualization ---
    n_classes = len(category_mapping)
    class_names = [f"({k}) {v}" for k, v in category_mapping.items()]
    class_names_secondary = (
        [f"({k}) {v}" for k, v in secondary_mapping.items()]
        if secondary_mapping
        else None
    )

    def names_for_row(row):
        # The secondary names replace the primary ones only on the selected row.
        if class_names_secondary is not None and secondary_mapping_index == row:
            return class_names_secondary
        return class_names

    # squeeze=False keeps ``axs`` two-dimensional even for a single class, so
    # the ``axs[row, col]`` indexing below is always valid.
    fig, axs = plt.subplots(
        rows, n_classes, figsize=(4 * n_classes, 4 * rows), squeeze=False
    )

    # Row 0: raw image in the first cell; blank the rest of the row.
    axs[0, 0].imshow(raw_img.convert("L"), cmap="viridis")
    axs[0, 0].set_title("Raw Image")
    for ax in axs[0]:
        ax.axis("off")

    # Row 1: ground truth, always labelled with the primary class names.
    _show_mask_row(axs[1], gt_mask, class_names, "GT: ")

    # Row 2: predictions of the first model.
    _show_mask_row(axs[2], pred_mask, names_for_row(2), f"{label_prefix_one}Pred: ")

    populated_rows = 3
    if rows == 4 and model_two is not None:
        # Row 3: predictions of the second model.
        pred_mask_two = _predict_binary_masks(model_two, batch, threshold)
        _show_mask_row(
            axs[3], pred_mask_two, names_for_row(3), f"{label_prefix_two}Pred: "
        )
        populated_rows = 4

    # Blank rows that received no content instead of showing empty axes.
    for ax in axs[populated_rows:].ravel():
        ax.axis("off")

    fig.suptitle(custom_title, fontsize=28, y=0.98)
    # Reserve the top 6% of the figure for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.94])

    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight")

    plt.show()