"""Module ``models.virtues.virtues_model``: Spora benchmark wrapper around the VirTues model."""

from typing import Optional, Tuple

import torch
from einops import rearrange
from loguru import logger
from safetensors.torch import load_file
from spora_io import MultiplexImagingDataset
from spora_io.datasets._types import MultiplexTissue
from virtues.modules.multiplex_virtues import MultiplexVirtues
from virtues.utils.cell_tokens import Tuple, compute_cell_tokens
from virtues.utils.utils import (load_marker_embedding_dict,
                                 load_marker_embeddings)

from spora_bench.wrapper import SporaModelWrapper


class SporaVirTuesWrapper(SporaModelWrapper):
    """Spora benchmark wrapper around a pretrained ``MultiplexVirtues`` model.

    The model is built from explicit hyperparameters and loaded from a
    safetensors checkpoint. It is then put in eval mode and moved to CUDA, so
    the inference methods of this class require a CUDA device.
    """

    def __init__(
        self,
        model_name: str,
        checkpoint_path: str,
        marker_embeddings_dir: str,
        patch_size: int,
        model_dim: int,
        feedforward_dim: int,
        encoder_pattern: str,
        num_encoder_heads: int,
        decoder_pattern: str,
        num_decoder_heads: int,
        num_decoder_hidden_layers: int,
        positional_embedding_type: str,
        dropout: float,
        group_layers: bool,
        norm_after_encoder_decoder: bool,
        tile_size: int = 128,
    ):
        """Build ``MultiplexVirtues`` and load its weights.

        Args:
            model_name: Name passed to ``SporaModelWrapper.__init__``.
            checkpoint_path: Path of a safetensors file holding the model's
                state dict (loaded with ``load_state_dict``).
            marker_embeddings_dir: Directory from which both the marker
                embeddings (given to the model as prior-bias embeddings) and
                the UniProt-ID -> marker-embedding-index mapping are loaded.
            patch_size: Patch size of the model. ``predict_marker`` also uses
                it to size its patch mask.
            model_dim, feedforward_dim, encoder_pattern, num_encoder_heads,
            decoder_pattern, num_decoder_heads, positional_embedding_type,
            dropout, group_layers, norm_after_encoder_decoder:
                Forwarded unchanged to ``MultiplexVirtues``.
            num_decoder_hidden_layers: Forwarded to ``MultiplexVirtues`` as
                ``num_hidden_layers``.
            tile_size: Side length in pixels of the square tiles used by
                ``embed_tissue``. ``compute_cell_tokens`` also skips images
                smaller than this along either spatial axis.
        """
        super().__init__(model_name)
        marker_embeddings = load_marker_embeddings(marker_embeddings_dir)
        # Maps UniProt IDs to the marker embedding indices fed to the model.
        self.marker_embedding_dict = load_marker_embedding_dict(marker_embeddings_dir)
        self.patch_size = patch_size
        self.tile_size = tile_size

        self.model = MultiplexVirtues(
            use_default_config=False,
            custom_config=None,
            prior_bias_embeddings=marker_embeddings,
            prior_bias_embedding_type='esm',
            prior_bias_embedding_fusion_type='add',
            patch_size=patch_size,
            model_dim=model_dim,
            feedforward_dim=feedforward_dim,
            encoder_pattern=encoder_pattern,
            num_encoder_heads=num_encoder_heads,
            decoder_pattern=decoder_pattern,
            num_decoder_heads=num_decoder_heads,
            num_hidden_layers=num_decoder_hidden_layers,
            positional_embedding_type=positional_embedding_type,
            dropout=dropout,
            group_layers=group_layers,
            norm_after_encoder_decoder=norm_after_encoder_decoder,
            verbose=False,
        )
        checkpoint = load_file(checkpoint_path)
        self.model.load_state_dict(checkpoint)
        self.model.eval()
        self.model.cuda()

    def _marker_indices(self, uniprot_ids) -> torch.Tensor:
        """Map UniProt IDs to a CPU ``torch.long`` tensor of marker embedding indices.

        Args:
            uniprot_ids: Any iterable of UniProt IDs (e.g. a list or a NumPy
                array of strings).

        Raises:
            KeyError: If an ID has no entry in ``marker_embedding_dict``.
        """
        return torch.tensor(
            [self.marker_embedding_dict[uniprot] for uniprot in uniprot_ids],
            dtype=torch.long,
        )

    def _tile_starts(self, length: int) -> list:
        """Return tile start offsets that cover ``[0, length)`` with ``tile_size`` tiles.

        Tiles are placed with stride ``tile_size``. If that grid does not end
        exactly at ``length``, one extra tile aligned to the end (overlapping
        its predecessor) is appended. An axis exactly ``tile_size`` long yields
        ``[0]``. Returns ``[]`` when ``length < tile_size``.
        """
        last_start = length - self.tile_size
        if last_start < 0:
            return []
        starts = list(range(0, last_start + 1, self.tile_size))
        if starts[-1] != last_start:
            starts.append(last_start)
        return starts

    def compute_cell_tokens(
        self,
        dataset: MultiplexImagingDataset,
        tissue_id: str,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute VirTues cell tokens for every segmented cell of a tissue.

        Steps:
            1. Load the raw (``preprocess=False``) tissue image and its cell
               instance mask.
            2. Zero-pad both.
            3. Standardize the image with ``dataset.standardizer``.
            4. Run ``virtues.utils.cell_tokens.compute_cell_tokens`` on the
               result.

        Args:
            dataset: Dataset providing the tissue, its cell instance mask,
                the standardizer and channel-metadata refinement.
            tissue_id: Identifier of the tissue to process.

        Returns:
            ``(cell_ids, cell_tokens)`` as produced by
            ``virtues.utils.cell_tokens.compute_cell_tokens``. Two empty
            tensors are returned when:

            - the cell instance mask is missing (``ValueError`` from the
              dataset);
            - the image is smaller than ``tile_size`` along either spatial
              axis;
            - the mask is all zeros.

        Raises:
            AssertionError: If the image's spatial shape differs from the mask's.
        """
        # 1. Load data
        tissue = dataset.get_tissue(tissue_id, kind='uniprot_filtered', preprocess=False, image_mode="CHW")
        try:
            segmentation_mask = dataset.get_cell_instance_mask(tissue_id).mask
        except ValueError:
            logger.warning(f"Cell instance mask not found for tissue ID {tissue_id}. Skipping cell token computation.")
            return torch.tensor([]), torch.tensor([])
        if min(tissue.image.shape[1:]) < self.tile_size:
            logger.warning(f"Tissue {tissue_id} is smaller than the tile size. Skipping cell token computation.")
            return torch.tensor([]), torch.tensor([])
        if segmentation_mask.max() == 0:
            logger.warning(f"Segmentation mask for tissue {tissue_id} is empty. Skipping cell token computation.")
            return torch.tensor([]), torch.tensor([])
        assert tissue.image.shape[1:] == segmentation_mask.shape, f"Image and segmentation mask shapes do not match for tissue ID {tissue_id}."

        # 2. Pad image and segmentation mask pre-standardization (last two dims, all four sides).
        # NOTE(review): 120 px is a fixed choice, presumably to give border cells
        # full spatial context -- confirm before changing.
        pad_size = 120
        padding = (pad_size, pad_size, pad_size, pad_size)
        x = torch.nn.functional.pad(tissue.image, pad=padding, mode='constant', value=0)
        segmentation_mask = torch.nn.functional.pad(segmentation_mask, pad=padding, mode='constant', value=0)

        # 3. Standardize image; only the refined UniProt IDs are needed below.
        x, refined_mask = dataset.standardizer.apply(x, tissue_id, tissue.measured_mask, tissue.image_loading_mask)
        _, _, uniprot_ids = dataset._refine_channel_metadata(
            tissue.image_loading_mask,
            tissue.channel_names,
            tissue.uniprot_ids,
            refined_mask,
        )

        # 4. Map uniprot IDs to marker embedding indices
        marker_indices = self._marker_indices(uniprot_ids)

        # 5. Compute cell tokens. Inside this method the bare name refers to the
        # module-level import from virtues.utils.cell_tokens, not to this method.
        cell_ids, cell_tokens, _, _ = compute_cell_tokens(self.model, x, marker_indices, segmentation_mask)
        return cell_ids, cell_tokens

    @torch.inference_mode()
    def embed_tile(self, tissue: MultiplexTissue) -> torch.Tensor:
        """Embed a single tile with the VirTues encoder.

        Args:
            tissue: Tile whose ``image`` is fed to the encoder as-is and whose
                ``uniprot_ids`` identify its channels.

        Returns:
            The encoder's patch summary tokens flattened from ``(h, w, d)`` to
            ``(h * w, d)``, on the CPU.
        """
        tissue_image = tissue.image.cuda()
        # Keep the indices on the same device as the image, as embed_tissue and
        # predict_marker do; previously they were left on the CPU here.
        marker_indices = self._marker_indices(tissue.uniprot_ids).cuda()

        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            virtues_output = self.model.encoder.forward_list([tissue_image], [marker_indices])
        tile_embedding = virtues_output.patch_summary_tokens[0].cpu()
        return rearrange(tile_embedding, 'h w d -> (h w) d')

    @torch.inference_mode()
    def embed_tissue(
        self,
        dataset: MultiplexImagingDataset,
        tissue_id: str,
        tissue_threshold: float = 0.3,
    ) -> torch.Tensor:
        """Embed a whole tissue as a bag of patch tokens from its tissue-rich tiles.

        The preprocessed image is covered with ``tile_size`` x ``tile_size``
        tiles; see ``_tile_starts``, where the last tile per axis may overlap
        its predecessor. Only tiles whose tissue-mask mean exceeds
        ``tissue_threshold`` are encoded.

        Args:
            dataset: Dataset providing the preprocessed tissue and its tissue mask.
            tissue_id: Identifier of the tissue to embed.
            tissue_threshold: A tile is embedded only if the mean of its
                tissue mask is strictly greater than this value.

        Returns:
            Patch summary tokens of all kept tiles concatenated along dim 0
            (``(num_tokens, d)``, CPU). Returns an empty tensor, with a
            warning, when the image is smaller than ``tile_size`` along an
            axis or when no tile passes the threshold.
        """
        # 1. Load data
        tissue = dataset.get_tissue(tissue_id, kind="uniprot_filtered", preprocess=True, image_mode="CHW")
        tissue_image = tissue.image.cuda()
        # NOTE(review): assumes this mask supports .mean() (e.g. NumPy array or
        # float tensor); a bool torch tensor would need .float() first.
        tissue_mask = dataset.get_tissue_mask(tissue_id).mask
        marker_indices = self._marker_indices(tissue.uniprot_ids).cuda()
        _, height, width = tissue_image.shape

        # 2. Tile offsets covering the full image. An axis exactly tile_size long
        # yields one tile instead of being rejected as too small.
        row_starts = self._tile_starts(height)
        col_starts = self._tile_starts(width)
        if not row_starts or not col_starts:
            logger.warning(f"Tissue {tissue_id} is smaller than the tile size. Returning an empty embedding.")
            return torch.tensor([])

        # 3. Encode every tile with enough tissue.
        embedding_bag = []
        with torch.amp.autocast("cuda", dtype=torch.float16):
            for top in row_starts:
                bottom = top + self.tile_size
                for left in col_starts:
                    right = left + self.tile_size
                    if tissue_mask[top:bottom, left:right].mean() > tissue_threshold:
                        tile = tissue_image[:, top:bottom, left:right]
                        virtues_output = self.model.encoder.forward_list([tile], [marker_indices])
                        patch_tokens = rearrange(virtues_output.patch_summary_tokens[0], "h w d -> (h w) d")
                        embedding_bag.append(patch_tokens.cpu())

        if not embedding_bag:
            logger.warning(f"No tiles in tissue {tissue_id} passed the tissue threshold. Returning an empty embedding.")
            return torch.tensor([])
        return torch.cat(embedding_bag, dim=0)

    @torch.inference_mode()
    def predict_marker(
        self,
        tissue: MultiplexTissue,
        target_channel_name: str,
        target_uniprot_id: str,
    ) -> torch.Tensor:
        """Predict one marker channel from the tissue's existing channels.

        An all-zero channel tagged with ``target_uniprot_id`` is prepended to
        the image. Every patch of that channel is set to ``True`` in the patch
        mask passed to the model, and the model's decoded output for that
        channel is returned.

        Args:
            tissue: Tissue whose ``image`` ``(C, H, W)`` and ``uniprot_ids``
                provide the context channels.
            target_channel_name: Not used by this method (presumably part of
                the shared wrapper interface).
            target_uniprot_id: UniProt ID of the marker to predict.

        Returns:
            The predicted channel, shape ``(1, H, W)``, on the CPU.
        """
        x = tissue.image  # (C, H, W)
        target = torch.zeros_like(x[0:1])  # (1, H, W) placeholder for the predicted marker
        x = torch.concat([target, x], dim=0)  # (C+1, H, W)
        # Unpacking works for both lists and NumPy arrays of IDs (no .tolist() needed).
        marker_indices = self._marker_indices([target_uniprot_id, *tissue.uniprot_ids])

        num_channels, height, width = x.shape
        grid_h, grid_w = height // self.patch_size, width // self.patch_size
        mask = torch.zeros((num_channels, grid_h, grid_w), dtype=torch.bool)
        mask[0] = True  # every patch of the prepended target channel

        x = x.cuda()
        marker_indices = marker_indices.cuda()
        mask = mask.cuda()
        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
            output = self.model.forward([x], [marker_indices], [mask])
        return output.decoded_multiplex[0][0:1].cpu()  # (1, H, W)