from typing import Tuple
import torch
from einops import rearrange, repeat
from loguru import logger
from spora_io import MultiplexImagingDataset
from spora_io.datasets._types import Tissue
import numpy as np
from spora_bench.wrapper import SporaModelWrapper
from torchvision.models import resnet50, ResNet50_Weights
from sklearn.decomposition import MiniBatchSparsePCA
# [docs]  (Sphinx "view source" link marker left over from the HTML listing; not code)
class SporaResNetWrapper(SporaModelWrapper):
    """Embed multiplex tissue images channel-by-channel with an ImageNet ResNet-50.

    Each image channel is replicated three times so that it can be fed to the
    RGB network as its own image; the per-channel network outputs of one tile
    are then concatenated into a single embedding vector.

    Attributes:
        model: torchvision ``resnet50`` with ``IMAGENET1K_V2`` weights, in eval
            mode and moved to CUDA.
        model_dim: ``in_features`` of the network's ``fc`` layer.
        dim_per_channel: ``out_features`` of the network's ``fc`` layer, i.e.
            the length of the vector the model returns for one input image.
        tile_size: Side length (in pixels) of the square tiles cut from a tissue.
        pca_dim_channel: Number of sparse-PCA components kept by
            :meth:`postprocess_tile_embeddings`.
    """

    def __init__(
        self,
        model_name: str = "resnet50",
        pca_dim_channel: int = 9,
        tile_size: int = 224,
    ):
        """Load the pretrained network and store the tiling/PCA settings.

        Args:
            model_name: Name forwarded to ``SporaModelWrapper.__init__``.
            pca_dim_channel: Number of components for the sparse PCA used in
                post-processing.
            tile_size: Side length of the square tiles, in pixels.
        """
        super().__init__(model_name)
        self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        # NOTE: the classification head (``fc``) is left in place, so a forward
        # pass returns ``fc.out_features`` values per input image. Stripping it
        # (e.g. ``Sequential(*children()[:-1])``) would also remove the ``fc``
        # attribute that is read below.
        self.model.eval()
        self.model.cuda()
        self.model_dim = self.model.fc.in_features
        self.tile_size = tile_size
        self.pca_dim_channel = pca_dim_channel
        self.dim_per_channel = self.model.fc.out_features

    @torch.inference_mode()
    def embed_tile(
        self,
        tissue: Tissue,
    ) -> torch.Tensor:
        """Embed a single, already-cropped tile.

        Args:
            tissue: Object whose ``image`` attribute holds the tile
                (expected layout ``C x H x W``).

        Returns:
            The raw model output computed under fp16 autocast. For a
            ``C x H x W`` input this has one row per channel.
        """
        # 1. Load data
        tissue_image = tissue.image.cuda()  # C x H x W
        if tissue_image.ndim == 3:
            # ResNet needs a batched 3-channel input. Treat every image channel
            # as its own grey-scale image replicated to RGB, exactly as
            # ``embed_tissue`` does; a bare C x H x W tensor cannot be
            # forwarded through the network's 2-D batch-norm layers.
            tissue_image = repeat(tissue_image, "c h w -> c rgb h w", rgb=3)
        # 2. Embed tile using the model
        with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
            embeddings = self.model(tissue_image)
        return embeddings

    @staticmethod
    def _tile_origins(extent: int, tile_size: int) -> list:
        """Return the start offsets of tiles covering ``[0, extent)`` along one axis.

        Tiles are laid out without overlap from offset 0; when ``tile_size``
        does not divide ``extent`` a final tile flush with the far edge is
        added (it overlaps its predecessor). An empty list is returned when
        the axis is shorter than one tile.
        """
        if extent < tile_size:
            return []
        last_origin = extent - tile_size
        # ``+ 1`` keeps ``last_origin`` itself reachable, so an axis of exactly
        # one tile yields ``[0]`` rather than an empty list.
        origins = list(range(0, last_origin + 1, tile_size))
        if origins[-1] != last_origin:
            origins.append(last_origin)
        return origins

    @torch.inference_mode()
    def embed_tissue(
        self,
        dataset: MultiplexImagingDataset,
        tissue_id: str,
        tissue_threshold: float = 0.3,
    ) -> torch.Tensor:
        """Tile a whole tissue image and embed every tile that contains tissue.

        Only channels flagged for all tissues in ``dataset.image_channel_map``
        are used, so that embeddings have the same length across tissues.

        Args:
            dataset: Dataset providing the image, channel map and tissue mask.
            tissue_id: Identifier of the tissue to embed.
            tissue_threshold: A tile is embedded only if the mean of the tissue
                mask over the tile is strictly greater than this value.

        Returns:
            Tensor of shape ``(n_tiles, n_channels * model_output_dim)`` on the
            CPU, or an empty tensor (``torch.tensor([])``) when the tissue is
            smaller than one tile or no tile passes the threshold.
        """
        # 0. Select the cut of channels across all tissues for dimensionality reasons.
        channels_per_tissue = dataset.image_channel_map
        valid_channels = channels_per_tissue.columns[channels_per_tissue.all(axis=0)]

        # 1. Load data
        tissue = dataset.get_tissue(
            tissue_id, kind="complete", preprocess=True, image_mode="CHW"
        )
        channel_mask = np.isin(tissue.channel_names, valid_channels)
        tissue_image = tissue.image[channel_mask]
        _, H, W = tissue_image.shape
        # NOTE(review): assumed to be a 2-D H x W mask aligned with the image -- confirm.
        tissue_mask = dataset.get_tissue_mask(tissue_id).mask

        # 2. Get crop coordinates
        # Cover the entire image with tiles. If the tile size does not divide an
        # image dimension, an extra tile overlapping the last regular one is
        # added to cover the remaining area.
        size = self.tile_size
        x_crops = self._tile_origins(H, size)
        y_crops = self._tile_origins(W, size)
        if not x_crops or not y_crops:
            logger.warning(f"Tissue {tissue_id} is smaller than the tile size. Returning an empty embedding.")
            return torch.tensor([])

        # 3. Embed every tile with enough tissue in it
        embedding_bag = []
        for i in x_crops:
            for j in y_crops:
                tile_mask = tissue_mask[i:i + size, j:j + size]
                if tile_mask.mean() > tissue_threshold:
                    tile = tissue_image[:, i:i + size, j:j + size]
                    # Repeat channels to get 3 channels for resnet input (rgb)
                    tile = repeat(tile, "c h w -> c rgb h w", rgb=3)
                    embedding = self.model(tile.cuda())
                    # Concatenate the per-channel vectors into one tile vector.
                    embedding = rearrange(embedding, "c d -> (c d)")
                    embedding_bag.append(embedding.cpu())

        if not embedding_bag:
            logger.warning(f"No tiles in tissue {tissue_id} passed the tissue threshold. Returning an empty embedding.")
            return torch.tensor([])
        return torch.vstack(embedding_bag)

    def postprocess_tile_embeddings(
        self,
        tissue_embeddings: torch.Tensor,
    ):
        """Reduce tile embeddings to ``pca_dim_channel`` dims with mini-batch sparse PCA.

        The PCA is fitted on the embeddings passed in (``random_state=42``).

        Args:
            tissue_embeddings: Tile embeddings, one row per tile.

        Returns:
            Float tensor with the PCA-transformed embeddings. Empty input (the
            "no embedding" sentinel returned by ``embed_tissue``) is returned
            as an empty float tensor instead of being passed to the PCA fit.
        """
        if tissue_embeddings.numel() == 0:
            return tissue_embeddings.cpu().float()
        sparse_pca = MiniBatchSparsePCA(
            n_components=self.pca_dim_channel, batch_size=500, random_state=42
        )
        reduced_embedding = sparse_pca.fit_transform(tissue_embeddings.cpu().numpy())
        return torch.from_numpy(reduced_embedding).float()