Wrapper
- class spora_bench.wrapper.SporaModelWrapper(model_name)
Bases: ABC
- Parameters:
model_name (str)
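Because the class derives from ABC, a model is presumably plugged into the benchmark by subclassing it. The sketch below illustrates that pattern with a hypothetical stand-in base class (`WrapperSketch`) that mirrors only the constructor and `embed_tile`, so it runs without `spora_bench` installed. Whether the real methods are declared with `@abstractmethod` is an assumption; `ConstantWrapper` is a toy subclass for illustration only.

```python
from abc import ABC, abstractmethod


# Hypothetical stand-in for spora_bench.wrapper.SporaModelWrapper, mirroring
# only the constructor and embed_tile; whether the real methods are declared
# abstract is an assumption.
class WrapperSketch(ABC):
    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    @abstractmethod
    def embed_tile(self, tissue):
        """Return a 1-D embedding of shape (D,) for one tile."""


class ConstantWrapper(WrapperSketch):
    """Toy subclass that ignores its input and returns a zero vector."""

    def __init__(self, model_name: str, dim: int = 4) -> None:
        super().__init__(model_name)
        self.dim = dim

    def embed_tile(self, tissue):
        return [0.0] * self.dim


# An ABC with unimplemented abstract methods cannot be instantiated.
try:
    WrapperSketch("base")
    base_is_abstract = False
except TypeError:
    base_is_abstract = True

wrapper = ConstantWrapper("toy-model")
print(base_is_abstract, wrapper.model_name, len(wrapper.embed_tile(None)))
# prints: True toy-model 4
```

The try/except shows the practical consequence of the ABC base: a subclass that forgets to implement a required method fails at construction time rather than mid-benchmark.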
- compute_cell_tokens(dataset, tissue_id)
Compute cell tokens for the given dataset and tissue ID.
- Parameters:
dataset (spora_io.datasets.MultiplexImagingDataset): The input dataset containing the multiplexed images and segmentation masks.
tissue_id (str): The ID of the tissue to compute cell tokens for.
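The docstring does not specify what a cell token contains. As a hedged illustration of the kind of per-cell reduction that combines an image with a segmentation mask, the sketch below averages every marker channel over each labeled cell. The function name `mean_intensity_per_cell`, the (C, H, W) image and (H, W) integer mask layout, and the convention that label 0 is background are all assumptions; a model-based wrapper would more likely pool model features than raw intensities.

```python
import numpy as np


def mean_intensity_per_cell(image, mask):
    """Average each channel of a (C, H, W) image over every labeled cell.

    mask is an (H, W) integer label image where 0 marks background (assumed).
    Returns (cell_ids, tokens) with tokens of shape (K, C), one row per cell.
    """
    cell_ids = np.unique(mask)
    cell_ids = cell_ids[cell_ids != 0]  # drop the background label
    flat_img = image.reshape(image.shape[0], -1)  # (C, H*W)
    flat_mask = mask.ravel()  # (H*W,)
    if cell_ids.size == 0:
        return cell_ids, np.empty((0, image.shape[0]))
    tokens = np.stack([flat_img[:, flat_mask == cid].mean(axis=1) for cid in cell_ids])
    return cell_ids, tokens


image = np.arange(12, dtype=float).reshape(2, 2, 3)  # 2 channels, 2x3 pixels
mask = np.array([[1, 1, 0],
                 [2, 2, 2]])
cell_ids, tokens = mean_intensity_per_cell(image, mask)
print(cell_ids.tolist(), tokens.tolist())
# prints: [1, 2] [[0.5, 6.5], [4.0, 10.0]]
```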
- embed_tile(tissue)
Embed a single tile from the input image into a D-dimensional vector.
- Parameters:
tissue (spora_io.datasets.MultiplexTissue): The input tissue containing the image and protein IDs.
- Returns:
The embedded tile tensor. Shape: (D,)
- Return type:
torch.Tensor
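`embed_tile` is documented only by its contract: one tile in, a (D,) tensor out. The sketch below is a toy encoder that satisfies that contract and nothing more. It mean-pools each channel over the spatial dimensions and applies a fixed random projection from C channels to D dimensions. The function `toy_embed_tile` is hypothetical, it takes a bare (C, H, W) NumPy array instead of a `MultiplexTissue`, and a real wrapper would run its model at this point instead.

```python
import numpy as np


def toy_embed_tile(tile, dim=8, seed=0):
    """Map a (C, H, W) tile to a (dim,) vector (hypothetical toy encoder).

    Spatial mean pooling per channel, then a fixed random linear projection
    from C to dim. The fixed seed makes the projection deterministic.
    """
    channels = tile.shape[0]
    pooled = tile.reshape(channels, -1).mean(axis=1)  # (C,)
    rng = np.random.default_rng(seed)
    projection = rng.standard_normal((channels, dim))  # (C, D)
    return pooled @ projection  # (D,)


tile = np.ones((3, 16, 16))
embedding = toy_embed_tile(tile)
print(embedding.shape)
# prints: (8,)
```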
- embed_tissue(dataset, tissue_id, tissue_threshold=0.3)
Embed the tissue with the given ID from the dataset as a sequence of embeddings.
- Parameters:
dataset (spora_io.datasets.MultiplexImagingDataset): The input dataset containing the multiplexed images and segmentation masks.
tissue_id (str): The ID of the tissue to embed.
tissue_threshold (float): The threshold for considering a tile as containing tissue. Defaults to 0.3.
- Returns:
A sequence-shaped embedding of the tissue. Shape: (N, D)
- Return type:
torch.Tensor
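The `tissue_threshold` parameter suggests that the tissue is split into tiles and only tiles with enough tissue are embedded and stacked into the (N, D) output. The sketch below shows that idea under stated assumptions: non-overlapping square tiles, a boolean (H, W) tissue mask as input, "tissue fraction >= threshold" as the keep rule, and partial border tiles skipped. The helper names and the mask source are hypothetical, not the wrapper's actual tiling logic.

```python
import numpy as np


def select_tissue_tiles(tissue_mask, tile_size, tissue_threshold=0.3):
    """Return (row, col) origins of non-overlapping tiles whose tissue
    fraction is at least tissue_threshold; partial border tiles are skipped."""
    height, width = tissue_mask.shape
    kept = []
    for row in range(0, height - tile_size + 1, tile_size):
        for col in range(0, width - tile_size + 1, tile_size):
            window = tissue_mask[row:row + tile_size, col:col + tile_size]
            if window.mean() >= tissue_threshold:  # fraction of tissue pixels
                kept.append((row, col))
    return kept


def embed_selected_tiles(image, origins, tile_size, embed_fn):
    """Embed each kept tile of a (C, H, W) image and stack into (N, D)."""
    if not origins:
        raise ValueError("no tile passed the tissue threshold")
    rows = [embed_fn(image[:, r:r + tile_size, c:c + tile_size]) for r, c in origins]
    return np.stack(rows)


def channel_means(tile):
    """Trivial per-tile embedding: one mean per channel, shape (C,)."""
    return tile.reshape(tile.shape[0], -1).mean(axis=1)


tissue_mask = np.zeros((4, 4), dtype=bool)
tissue_mask[:2, :2] = True  # top-left tile is all tissue
tissue_mask[0, 2] = True    # top-right tile is 1/4 tissue, below 0.3
origins = select_tissue_tiles(tissue_mask, tile_size=2)
image = np.ones((3, 4, 4))
embeddings = embed_selected_tiles(image, origins, 2, channel_means)
print(origins, embeddings.shape)
# prints: [(0, 0)] (1, 3)
```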
- postprocess_tile_embeddings(tissue_embedding)
Postprocess the tissue embedding if necessary (e.g., dimensionality reduction for kronos patient-level tasks).
- Parameters:
tissue_embedding (torch.Tensor): The raw tissue embedding tensor. Shape: (N, D)
- Returns:
The postprocessed tissue token tensor. Shape: (N, D')
- Return type:
torch.Tensor
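The docstring names dimensionality reduction as one reason to postprocess, but does not say which method the kronos patient-level tasks use. As one common choice, the sketch below applies PCA via NumPy's SVD to map (N, D) embeddings to (N, D'). The function `pca_reduce` and its `out_dim` parameter are hypothetical. D' cannot exceed min(N, D), and a PCA fitted per tissue yields a different projection for each tissue; for cross-patient comparisons one would fit the projection once on pooled embeddings.

```python
import numpy as np


def pca_reduce(embeddings, out_dim):
    """Project (N, D) embeddings onto their top out_dim principal components."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    if out_dim > vt.shape[0]:
        raise ValueError("out_dim must not exceed min(N, D)")
    return centered @ vt[:out_dim].T  # (N, out_dim)


rng = np.random.default_rng(0)
tile_embeddings = rng.standard_normal((10, 32))  # N=10 tiles, D=32
reduced = pca_reduce(tile_embeddings, out_dim=4)
print(reduced.shape)
# prints: (10, 4)
```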
- predict_marker(tissue, target_channel_name, target_uniprot_id=None)
Inpaint the target marker channel in the given tissue image based on the segmentation mask.
- Parameters:
tissue (spora_io.datasets.MultiplexTissue): The input tissue; its multiplexed image has shape (C, H, W).
target_channel_name (str): The name of the target channel to inpaint.
target_uniprot_id (str | None): The UniProt ID of the target channel to inpaint. Defaults to None.
- Returns:
The inpainted marker channel. Shape: (1, H, W)
- Return type:
torch.Tensor
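One natural way to evaluate `predict_marker` is to compare its (1, H, W) output with the held-out ground-truth channel. The sketch below shows two hypothetical helpers for that. `resolve_channel` finds the target channel index and prefers the UniProt ID when one is given (an assumed rule, not documented behavior). `pearson_r` scores a prediction with pixel-wise Pearson correlation. The channel lists use placeholder IDs, and the prediction is faked as a linear transform of the truth in place of a real `wrapper.predict_marker(...)` call.

```python
import numpy as np


def resolve_channel(channel_names, uniprot_ids, target_channel_name, target_uniprot_id=None):
    """Return the index of the target channel (hypothetical lookup rule).

    Matches on the UniProt ID when one is given, else on the channel name;
    list.index raises ValueError if the target is absent.
    """
    if target_uniprot_id is not None:
        return uniprot_ids.index(target_uniprot_id)
    return channel_names.index(target_channel_name)


def pearson_r(pred, target):
    """Pixel-wise Pearson correlation between two same-shaped images."""
    return float(np.corrcoef(pred.ravel(), target.ravel())[0, 1])


channel_names = ["DAPI", "CD3", "CD20"]
uniprot_ids = [None, "ID_CD3", "ID_CD20"]  # placeholder IDs, not real accessions
image = np.random.default_rng(0).random((3, 8, 8))  # (C, H, W)

idx = resolve_channel(channel_names, uniprot_ids, "CD20")
truth = image[idx:idx + 1]  # keep the (1, H, W) shape of predict_marker's output
fake_pred = 2.0 * truth + 1.0  # stand-in for wrapper.predict_marker(...)
print(idx, round(pearson_r(fake_pred, truth), 6))
# prints: 2 1.0
```

Pearson correlation is insensitive to the affine rescaling in the fake prediction, which is why it scores 1.0; benchmarks that care about absolute intensities would add an error metric such as MSE.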