import os
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional
import torch
from spora_io.datasets import MultiplexImagingDataset, MultiplexTissue
class SporaModelWrapper(ABC):
    """Common interface over models that operate on multiplexed imaging data.

    Each capability (cell tokens, tile embedding, tissue embedding, marker
    prediction) is an optional hook. The default implementation raises
    ``NotImplementedError``, so a subclass overrides only the capabilities its
    model supports. ``postprocess_tile_embeddings`` is the exception: it
    defaults to returning its input unchanged.

    Args:
        model_name (str): Stored on the instance as ``self.model_name``.
    """

    # NOTE(review): although this inherits from ABC, no method is marked
    # @abstractmethod, so the base class itself can be instantiated. The hooks
    # are intentionally left non-abstract so subclasses implementing only a
    # subset of capabilities keep working.

    def __init__(self, model_name: str):
        self.model_name = model_name

    def compute_cell_tokens(self,
                            dataset: MultiplexImagingDataset,
                            tissue_id: str,
                            ):
        """Compute cell tokens for the given dataset and tissue ID.

        Args:
            dataset (MultiplexImagingDataset): The input dataset containing
                the multiplexed images and segmentation masks.
            tissue_id (str): The ID of the tissue to compute cell tokens for.

        Returns:
            Model-specific cell tokens. The base class does not fix the return
            type; it is defined by the overriding subclass.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Cell token computation is not implemented for this model.")

    def embed_tile(self,
                   tissue: MultiplexTissue
                   ) -> torch.Tensor:
        """Embed a single tile.

        Args:
            tissue (MultiplexTissue): The input tissue containing the image
                and protein IDs.

        Returns:
            torch.Tensor: The embedded tile tensor. Shape: (D,)

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Tile embedding is not implemented for this model.")

    def embed_tissue(self,
                     dataset: MultiplexImagingDataset,
                     tissue_id: str,
                     tissue_threshold: float = 0.3
                     ) -> torch.Tensor:
        """Embed a whole tissue as a sequence of embeddings.

        Args:
            dataset (MultiplexImagingDataset): The input dataset containing
                the multiplexed images and segmentation masks.
            tissue_id (str): The ID of the tissue to embed.
            tissue_threshold (float): The threshold for considering a tile as
                containing tissue. Defaults to 0.3.

        Returns:
            torch.Tensor: A sequence-shaped embedding of the tissue.
            Shape: (N, D)

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Tissue embedding is not implemented for this model.")

    def postprocess_tile_embeddings(self, tissue_embedding: torch.Tensor) -> torch.Tensor:
        """Optionally postprocess a tissue's sequence of tile embeddings.

        The default implementation is the identity and returns
        ``tissue_embedding`` unchanged. Subclasses may override it when extra
        processing is needed (e.g., dimensionality reduction for kronos
        patient-level tasks).

        Args:
            tissue_embedding (torch.Tensor): The raw tissue embedding tensor.
                Shape: (N, D)

        Returns:
            torch.Tensor: The postprocessed tissue token tensor. Shape: (N, D')
        """
        return tissue_embedding

    def predict_marker(self,
                       tissue: MultiplexTissue,
                       target_channel_name: str,
                       target_uniprot_id: Optional[str] = None,
                       ) -> torch.Tensor:
        """Inpaint a single target marker channel for the given tissue.

        Args:
            tissue (MultiplexTissue): The input tissue holding the multiplexed
                image (previously documented as shape (C, H, W)).
            target_channel_name (str): The name of the target channel to be
                inpainted.
            target_uniprot_id (Optional[str]): The UniProt ID of the target
                channel to be inpainted. Defaults to None.

        Returns:
            torch.Tensor: The inpainted marker image. Shape: (1, H, W)

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Inpainting is not implemented for this model.")