# Source code for models.kronos.kronos_model

from typing import Tuple

import numpy as np
import pandas as pd
import torch
from einops import rearrange
from kronos import create_model_from_pretrained
from loguru import logger
from omegaconf import OmegaConf
from sklearn.decomposition import PCA
from tqdm import tqdm
from spora_io.datasets._types import Tissue
from spora_io.datasets.multiplex import MultiplexImagingDataset
from torchvision.transforms.functional import pad, resize
from skimage.transform import rescale

from spora_bench.wrapper import SporaModelWrapper


class SporaKronosWrapper(SporaModelWrapper):
    """Spora benchmark wrapper around a pretrained Kronos multiplex-imaging model.

    Provides tile embeddings (``embed_tile``), a bag of tile embeddings per
    tissue (``embed_tissue``), per-cell tokens (``compute_cell_tokens``) and PCA
    post-processing of tile embeddings. ``__init__`` moves the model to CUDA, so
    a GPU is required.
    """

    def __init__(
        self,
        model_name: str,
        checkpoint_path: str,
        marker_metadata: str,
        uniprot_mapping: str,
        tile_size: int = 224,
        pca_dimensions: int = 256,
        image_mpp: float = 1,
    ):
        """Load the Kronos checkpoint and the marker lookup tables.

        Args:
            model_name: Name forwarded to ``SporaModelWrapper``.
            checkpoint_path: Checkpoint passed to ``create_model_from_pretrained``.
            marker_metadata: Path to a CSV with ``marker_id``, ``marker_mean`` and
                ``marker_std`` columns; it is indexed by ``marker_id`` after loading.
            uniprot_mapping: Path to a CSV with ``uniprot_id`` and ``marker_id``
                columns mapping UniProt IDs to Kronos marker IDs.
            tile_size: Tile edge length in pixels used by ``embed_tissue`` after the
                image has been rescaled to 0.5 microns per pixel.
            pca_dimensions: Number of components kept by ``postprocess_tile_embeddings``.
            image_mpp: Resolution of the input images in microns per pixel.
        """
        super().__init__(model_name)
        self.model, self.precision, self.feat_dim = create_model_from_pretrained(
            checkpoint_path=checkpoint_path,
        )
        self.model.eval()
        self.model.cuda()
        self.marker_metadata = pd.read_csv(marker_metadata)
        self.uniprot_mapping = pd.read_csv(uniprot_mapping)
        self.marker_metadata.set_index("marker_id", inplace=True)
        stat_columns = ["marker_mean", "marker_std"]
        self.marker_metadata[stat_columns] = self.marker_metadata[stat_columns].astype(np.float32)
        self.tile_size = tile_size
        self.pca_dimensions = pca_dimensions
        self.image_mpp = image_mpp

    @staticmethod
    def _load_shared_channels(dataset: MultiplexImagingDataset, tissue_id: str) -> Tissue:
        """Load a QC-filtered tissue restricted to the channels shared by every tissue.

        ``dataset.image_channel_map`` is presumably a tissue x channel boolean
        DataFrame (its columns are filtered with ``.all(axis=0)``); only channels
        present in all tissues are kept, so every tissue has the same channel set.
        """
        channels_per_tissue = dataset.image_channel_map
        valid_channels = channels_per_tissue.columns[channels_per_tissue.all(axis=0)]
        tissue = dataset.get_tissue(tissue_id, kind="qc_filtered", preprocess=True, image_mode="CHW")
        channel_mask = np.isin(tissue.channel_names, valid_channels)
        tissue.image = tissue.image[channel_mask]
        tissue.channel_names = tissue.channel_names[channel_mask]
        tissue.uniprot_ids = tissue.uniprot_ids[channel_mask]
        return tissue

    @staticmethod
    def _tile_starts(length: int, tile_size: int) -> list:
        """Return tile start offsets along one axis that cover ``[0, length)``.

        Tiles are laid out on a ``tile_size`` stride; if the stride does not end
        exactly at the border, one extra tile is added flush with the border (it
        overlaps its predecessor). Returns ``[]`` when ``length < tile_size``.
        """
        if length < tile_size:
            return []
        # ``+ 1`` so that a side exactly one tile long still yields offset 0.
        starts = list(range(0, length - tile_size + 1, tile_size))
        if starts[-1] != length - tile_size:
            starts.append(length - tile_size)
        return starts

    def _preprocess(self, tissue: Tissue) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Map the tissue's UniProt IDs to Kronos marker IDs and drop unmapped channels.

        Arguments:
            tissue (Tissue): The input tissue containing the image and protein IDs.

        Returns:
            tissue_image (torch.Tensor): Float32 image restricted to mapped channels.
            marker_ids_kronos (torch.Tensor): int64 Kronos marker ID per kept channel.
            means (torch.Tensor): ``marker_mean`` per kept channel; used for
                background imputation when computing cell tokens.

        Raises:
            KeyError: if a mapped marker ID is missing from the marker metadata.
        """
        # One lookup table instead of a string-built DataFrame.query per channel;
        # the first listed marker_id wins, as in the previous per-query ``[0]``.
        uniprot_to_marker = (
            self.uniprot_mapping.drop_duplicates("uniprot_id", keep="first")
            .set_index("uniprot_id")["marker_id"]
        )
        marker_ids_kronos = np.array(
            [uniprot_to_marker.get(protein_id, np.nan) for protein_id in tissue.uniprot_ids]
        )
        valid_indices = ~pd.isna(marker_ids_kronos)
        tissue_image = tissue.image[valid_indices]
        marker_ids_kronos = marker_ids_kronos[valid_indices].astype(int)
        stats = self.marker_metadata.loc[marker_ids_kronos, ["marker_mean", "marker_std"]]
        means = torch.tensor(stats["marker_mean"].to_numpy())
        return tissue_image.float(), torch.tensor(marker_ids_kronos).long(), means

    @torch.inference_mode()
    def embed_tile(self, tissue: Tissue) -> torch.Tensor:
        """Embed one tile as the mean of its patch tokens.

        The image is used at its given resolution (no rescaling). The result
        stays on the model's CUDA device.
        """
        tissue_image, marker_ids_kronos, _ = self._preprocess(tissue)
        # The model lives on CUDA (see __init__), so the inputs must too.
        _, _, patch_token_features = self.model(
            x=tissue_image.cuda().unsqueeze(0),
            marker_ids=marker_ids_kronos.cuda().unsqueeze(0),
        )
        patch_token_features = rearrange(patch_token_features, "b c h w d -> b (h w) (c d)")
        return patch_token_features.mean(dim=1)[0]

    @torch.inference_mode()
    def embed_tissue(
        self,
        dataset: MultiplexImagingDataset,
        tissue_id: str,
        tissue_threshold: float = 0.3,
    ) -> torch.Tensor:
        """Embed every sufficiently tissue-covered tile of a tissue.

        The image and tissue mask are rescaled to 0.5 microns per pixel and
        tiled with ``self.tile_size`` tiles covering the whole image (the last
        row/column of tiles overlaps its neighbour if needed).

        Args:
            dataset: Dataset providing the tissue image and tissue mask.
            tissue_id: Tissue to embed.
            tissue_threshold: Tiles whose tissue-mask coverage fraction is not
                above this value are skipped.

        Returns:
            ``(n_tiles, dim)`` CPU tensor of mean patch tokens, or an empty
            tensor if the tissue is smaller than one tile or no tile qualifies.
        """
        # 0. Select the cut of channels across all tissues for dimensionality reasons.
        tissue = self._load_shared_channels(dataset, tissue_id)
        tissue_image, marker_ids_kronos, _ = self._preprocess(tissue)

        # Rescale the tissue to match the 0.5mpp resolution.
        scale = self.image_mpp / 0.5
        tissue_mask = dataset.get_tissue_mask(tissue_id).mask
        tissue_mask = rescale(tissue_mask, scale, order=0, preserve_range=True).astype(bool)
        tissue_image = rescale(tissue_image.numpy(), (1, scale, scale), preserve_range=True)
        # skimage returns float64; feed the model float32 like _preprocess does.
        tissue_image = torch.from_numpy(tissue_image).float()
        _, height, width = tissue_image.shape

        row_starts = self._tile_starts(height, self.tile_size)
        col_starts = self._tile_starts(width, self.tile_size)
        if not row_starts or not col_starts:
            logger.warning(f"Tissue {tissue_id} is smaller than the tile size. Returning an empty embedding.")
            return torch.tensor([])

        tissue_image = tissue_image.cuda()
        marker_ids_batch = marker_ids_kronos.cuda().unsqueeze(0)
        size = self.tile_size
        embedding_bag = []
        with torch.amp.autocast("cuda", dtype=torch.float16):
            for top in row_starts:
                for left in col_starts:
                    tile_mask = tissue_mask[top:top + size, left:left + size]
                    # Only embed tiles that contain more than the specified threshold of tissue.
                    if tile_mask.mean() > tissue_threshold:
                        tile = tissue_image[:, top:top + size, left:left + size]
                        _, _, patch_token_features = self.model(
                            x=tile.unsqueeze(0), marker_ids=marker_ids_batch
                        )
                        patch_token_features = rearrange(
                            patch_token_features, "b c h w d -> b (h w) (c d)"
                        )
                        embedding_bag.append(patch_token_features.mean(dim=1)[0].cpu())

        if not embedding_bag:
            # No tile passed the tissue threshold: return an empty tensor.
            return torch.tensor([])
        return torch.vstack(embedding_bag)

    @torch.inference_mode()
    def compute_cell_tokens(
        self,
        dataset: MultiplexImagingDataset,
        tissue_id: str,
        batch_size: int = 32,
        cell_tile_size: int = 64,
        num_workers: int = 24,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute one Kronos token per segmented cell.

        Each cell is cropped (see ``KronosCellDataset``), pixels outside the cell
        are replaced by a per-channel background value, and the flattened
        per-marker features from the model's second output form the token.

        Args:
            dataset: Dataset providing the tissue image and cell instance mask.
            tissue_id: Tissue to process.
            batch_size: Cells per forward pass.
            cell_tile_size: Edge length of each cell crop after resizing.
            num_workers: DataLoader worker processes used to build the crops.

        Returns:
            ``(cell_ids, cell_embeddings)``; two empty tensors when the cell
            mask is missing or empty, or the tissue is smaller than ``tile_size``.
        """
        # 0. Select the cut of channels across all tissues for dimensionality reasons.
        tissue = self._load_shared_channels(dataset, tissue_id)
        # Pad image (and, below, the mask) by half the crop size so crops of
        # cells at the border stay inside the image.
        border = cell_tile_size // 2
        tissue.image = pad(tissue.image, (border, border))
        tissue_image, marker_ids_kronos, channel_means = self._preprocess(tissue)
        # NOTE(review): background is the *negated* marker mean; the reason is not
        # visible here — confirm against how the model normalizes its inputs.
        background_values = -channel_means

        try:
            segmentation_mask = dataset.get_cell_instance_mask(tissue_id).mask
        except ValueError:
            logger.warning(f"Cell instance mask not found for tissue ID {tissue_id}. Skipping cell token computation.")
            return torch.tensor([]), torch.tensor([])

        if min(tissue_image.shape[1:]) < self.tile_size:
            logger.warning(f"Tissue {tissue_id} is smaller than the tile size. Skipping cell token computation.")
            return torch.tensor([]), torch.tensor([])
        if segmentation_mask.max() == 0:
            logger.warning(f"Segmentation mask for tissue {tissue_id} is empty. Skipping cell token computation.")
            return torch.tensor([]), torch.tensor([])

        # torchvision's pad only accepts tensors/PIL images; as_tensor is a no-op
        # for tensors and converts a numpy mask if the dataset returns one.
        cell_instance_mask = pad(torch.as_tensor(segmentation_mask), (border, border))
        marker_ids_batch = marker_ids_kronos.to("cuda", non_blocking=True).unsqueeze(0)

        cell_dataset = KronosCellDataset(
            tissue_image,
            cell_instance_mask,
            background_values,
            tile_size=cell_tile_size,
            image_mpp=self.image_mpp,
        )
        cell_loader = torch.utils.data.DataLoader(
            cell_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )

        id_batches = []
        embedding_batches = []
        with torch.amp.autocast("cuda", dtype=torch.float16):
            for cell_tiles, ids in cell_loader:
                cell_tiles = cell_tiles.to("cuda")
                _, patch_marker_features, _ = self.model(
                    x=cell_tiles,
                    marker_ids=marker_ids_batch.expand(cell_tiles.shape[0], -1),
                )
                cell_embedding = rearrange(patch_marker_features, "b c d -> b (c d)")
                embedding_batches.append(cell_embedding.to("cpu"))
                id_batches.append(ids)
        return torch.cat(id_batches), torch.cat(embedding_batches)

    def postprocess_tile_embeddings(self, tissue_embeddings: torch.Tensor) -> torch.Tensor:
        """Reduce tile embeddings to ``self.pca_dimensions`` components with PCA.

        The PCA is fitted on ``tissue_embeddings`` itself, a ``(n_samples,
        n_features)`` CPU tensor; sklearn raises ``ValueError`` if
        ``pca_dimensions`` exceeds ``min(n_samples, n_features)``.
        """
        pca = PCA(n_components=self.pca_dimensions)
        reduced_embeddings = pca.fit_transform(np.asarray(tissue_embeddings))
        return torch.tensor(reduced_embeddings)
class KronosCellDataset(torch.utils.data.Dataset):
    """Dataset yielding one cell-centred crop per cell instance.

    Each item is a crop around the bounding-box centre of one cell, taken at
    the image's native resolution and resized to ``tile_size`` x ``tile_size``
    (i.e. to 0.5 microns per pixel). Pixels outside the cell's instance mask are
    replaced with ``background_values``.

    Args:
        tissue_image: Channel-first image of shape ``(C, H, W)``.
        cell_instance_mask: Instance mask of shape ``(H, W)``; 0 is background.
        background_values: Per-channel fill value, shape ``(C,)``.
        tile_size: Output crop edge length in pixels (at 0.5 microns per pixel).
        image_mpp: Resolution of ``tissue_image`` in microns per pixel.
    """

    def __init__(
        self,
        tissue_image: torch.Tensor,
        cell_instance_mask: torch.Tensor,
        background_values: torch.Tensor,
        tile_size: int = 64,
        image_mpp: float = 1,
    ):
        self.tile_size = tile_size
        cell_instances = torch.unique(cell_instance_mask)
        self.cell_instances = cell_instances[cell_instances != 0]  # Exclude background
        self.tissue_image = tissue_image
        self.cell_instance_mask = cell_instance_mask
        self.background_values = background_values
        self.image_mpp = image_mpp

    def __len__(self):
        return len(self.cell_instances)

    def __getitem__(self, idx):
        """
        Creates a {tile_size}x{tile_size} tile centered around the cell instance with the given index.

        Args:
            idx (int): The index of the cell instance to create a tile for.

        Returns:
            torch.Tensor: The tile centered around the cell instance. Shape: (C, tile_size, tile_size)
            torch.Tensor: The (0-d) cell instance ID corresponding to the tile.

        Raises:
            ValueError: if the cell instance has no pixels in the mask.
        """
        cell_id = self.cell_instances[idx]
        cell_mask = self.cell_instance_mask == cell_id
        coords = torch.nonzero(cell_mask)
        if coords.numel() == 0:
            raise ValueError(f"Cell instance with ID {cell_id} has no pixels in the cell instance mask.")

        # Centre of the cell's bounding box.
        y_min, x_min = coords.min(dim=0).values.tolist()
        y_max, x_max = coords.max(dim=0).values.tolist()
        y_center = (y_min + y_max) // 2
        x_center = (x_min + x_max) // 2

        # Native-resolution crop size that maps to tile_size pixels at 0.5 mpp.
        # Computed as a float ratio so image_mpp < 0.5 or non-integer ratios
        # (e.g. 0.75) no longer divide by zero or truncate; at least 2 px so the
        # crop is never empty.
        native_size = max(2, round(self.tile_size * 0.5 / self.image_mpp))
        half_size = native_size // 2
        _, height, width = self.tissue_image.shape
        y_start = max(0, y_center - half_size)
        y_end = min(height, y_center + half_size)
        x_start = max(0, x_center - half_size)
        x_end = min(width, x_center + half_size)

        # clone(): slicing returns a view of the shared image, and the in-place
        # background fill below would otherwise erase neighbouring cells' pixels.
        tile = self.tissue_image[:, y_start:y_end, x_start:x_end].clone()
        outside_cell = ~cell_mask[y_start:y_end, x_start:x_end]
        tile[:, outside_cell] = self.background_values.to(tile.dtype)[:, None]

        # Rescale the image to the desired tile size (effectively, this does the rescaling to 0.5mpp).
        tile = resize(tile, (self.tile_size, self.tile_size))
        return tile, cell_id