Wrapper

class spora_bench.wrapper.SporaModelWrapper(model_name)

Bases: ABC

Parameters:

model_name (str)
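Because SporaModelWrapper derives from ABC, a model is presumably integrated by subclassing it and implementing its methods. This page does not say which methods are abstract, so the sketch below uses a local stand-in base class (WrapperStandIn, hypothetical, not the real class) and assumes for illustration that only embed_tile is abstract. The toy MeanIntensityWrapper and its list-of-lists tile format are likewise hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Sequence


class WrapperStandIn(ABC):
    """Hypothetical local stand-in for SporaModelWrapper; not the real class."""

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    @abstractmethod
    def embed_tile(self, tissue) -> Sequence[float]:
        """Subclasses must return a D-dimensional embedding."""


class MeanIntensityWrapper(WrapperStandIn):
    """Hypothetical toy wrapper: embeds a tile as the mean intensity of each channel."""

    def embed_tile(self, tissue) -> list[float]:
        # Assumption: tissue is a list of channels, each a list of pixel rows
        return [
            sum(sum(row) for row in channel) / sum(len(row) for row in channel)
            for channel in tissue
        ]
```

Since embed_tile is declared abstract, instantiating WrapperStandIn directly raises TypeError; only subclasses that implement every abstract method can be constructed.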

compute_cell_tokens(dataset, tissue_id)

Compute cell tokens for the given dataset and tissue ID.

Parameters:
  • dataset (spora_io.datasets.MultiplexImagingDataset): The input dataset containing the multiplexed images and segmentation masks.

  • tissue_id (str): The ID of the tissue to compute cell tokens for.
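How cell tokens are defined is not documented here. One common choice, shown below as a hypothetical sketch, is to pool the marker intensities of each segmented cell into a per-cell vector. It assumes the image has already been pulled out of the dataset as a (C, H, W) NumPy array and the segmentation mask as an (H, W) integer array with 0 as background; the helper name mean_cell_tokens is illustrative, and NumPy stands in for torch.

```python
import numpy as np


def mean_cell_tokens(image: np.ndarray, mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: one token per cell = mean intensity of each channel.

    image: (C, H, W) marker intensities; mask: (H, W) integer labels, 0 = background.
    Returns (cell_ids, tokens), where tokens has shape (num_cells, C).
    """
    labels = mask.ravel()
    flat = image.reshape(image.shape[0], -1)  # (C, H*W)
    cell_ids = np.unique(labels[labels > 0])  # sorted cell labels, background excluded
    n_labels = int(mask.max()) + 1
    # Pixel count per label, and per-channel intensity sums per label
    counts = np.bincount(labels, minlength=n_labels)
    sums = np.stack(
        [np.bincount(labels, weights=ch, minlength=n_labels) for ch in flat]
    )  # (C, n_labels)
    tokens = (sums[:, cell_ids] / counts[cell_ids]).T  # (num_cells, C)
    return cell_ids, tokens
```

np.bincount computes all per-cell sums in one pass per channel, which avoids looping over cells when a tissue contains many thousands of them.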

embed_tile(tissue)

Embed a tile from the input image.

Parameters:

tissue (spora_io.datasets.MultiplexTissue): The input tissue containing the image and protein IDs.

Returns:

The embedded tile tensor. Shape: (D,)

Return type:

torch.Tensor
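The contract is that one tile maps to a single (D,) vector. As a hypothetical illustration of that shape contract (not any real model), the sketch below globally averages each channel of a (C, H, W) tile and applies a fixed (D, C) linear projection. toy_embed_tile and the projection matrix are assumptions, and NumPy arrays stand in for torch.Tensor.

```python
import numpy as np


def toy_embed_tile(tile: np.ndarray, projection: np.ndarray) -> np.ndarray:
    """Hypothetical tile embedder: (C, H, W) tile -> (D,) vector.

    Averages each channel over space, then applies a (D, C) linear projection.
    """
    pooled = tile.mean(axis=(1, 2))  # (C,)
    return projection @ pooled  # (D,)
```

A real wrapper would replace the pooling and projection with the model's forward pass, but the output must still be a single D-dimensional vector per tile.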

embed_tissue(dataset, tissue_id, tissue_threshold=0.3)

Embed the tissue from the input image.

Parameters:
  • dataset (spora_io.datasets.MultiplexImagingDataset): The input dataset containing the multiplexed images and segmentation masks.

  • tissue_id (str): The ID of the tissue to embed.

  • tissue_threshold (float): The threshold used to decide whether a tile is considered to contain tissue. Defaults to 0.3.

Returns:

A sequence-shaped embedding of the tissue. Shape: (N, D)

Return type:

torch.Tensor
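A plausible reading of tissue_threshold is the minimum fraction of tissue pixels a tile needs before it is embedded. The sketch below assumes exactly that: it cuts a (C, H, W) image into non-overlapping square tiles, keeps tiles whose foreground fraction is at least the threshold, embeds each kept tile, and stacks the results into (N, D). The foreground mask, tile_size, and embed_tissue_sketch are hypothetical, and the real method may tile or threshold differently.

```python
from typing import Callable

import numpy as np


def embed_tissue_sketch(
    image: np.ndarray,
    foreground: np.ndarray,
    embed_tile: Callable[[np.ndarray], np.ndarray],
    tile_size: int = 4,
    tissue_threshold: float = 0.3,
) -> np.ndarray:
    """Hypothetical: tile a (C, H, W) image, keep tiles whose foreground fraction
    is at least tissue_threshold, embed each kept tile, and stack to (N, D)."""
    _, h, w = image.shape
    embeddings = []
    # Non-overlapping tiles in row-major order; partial edge tiles are skipped
    for y in range(0, h - tile_size + 1, tile_size):
        for x in range(0, w - tile_size + 1, tile_size):
            frac = foreground[y:y + tile_size, x:x + tile_size].mean()
            if frac >= tissue_threshold:
                embeddings.append(embed_tile(image[:, y:y + tile_size, x:x + tile_size]))
    if not embeddings:
        return np.empty((0, 0))  # no tile passed the threshold; D is unknown
    return np.stack(embeddings)
```

Raising tissue_threshold drops mostly-background tiles, so N shrinks and the sequence concentrates on tissue regions.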

postprocess_tile_embeddings(tissue_embedding)

Postprocess the tissue embedding if necessary (e.g., dimensionality reduction for KRONOS patient-level tasks).

Parameters:

tissue_embedding (torch.Tensor): The raw tissue embedding tensor. Shape: (N, D)

Returns:

The postprocessed tissue token tensor. Shape: (N, D')

Return type:

torch.Tensor
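The docstring names dimensionality reduction as one use of this hook but not the method. As a hedged example, the sketch below substitutes plain PCA computed with an SVD, mapping (N, D) to (N, D') with D' = out_dim. pca_reduce is hypothetical and is not claimed to be what any wrapper actually does; NumPy stands in for torch.

```python
import numpy as np


def pca_reduce(tile_embeddings: np.ndarray, out_dim: int) -> np.ndarray:
    """Project (N, D) embeddings onto their top out_dim principal components.

    Returns (N, out_dim); if out_dim exceeds min(N, D), only min(N, D) columns exist.
    """
    centered = tile_embeddings - tile_embeddings.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:out_dim].T
```

Fitting PCA per tissue, as here, makes the projected axes differ between tissues; a patient-level pipeline would more likely fit the projection once on training data and reuse it.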

predict_marker(tissue, target_channel_name, target_uniprot_id=None)

Inpaint the target marker channel in the given image based on the segmentation mask.

Parameters:
  • tissue (spora_io.datasets.MultiplexTissue): The input tissue containing the multiplexed image. Image shape: (C, H, W)

  • target_channel_name (str): The name of the target channel to inpaint.

  • target_uniprot_id (str | None): The UniProt ID of the target channel to inpaint. Optional; defaults to None.

Returns:

The inpainted marker image. Shape: (1, H, W)

Return type:

torch.Tensor
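To make the (1, H, W) output contract concrete, the sketch below pairs a deliberately naive, hypothetical baseline (predict the held-out channel as the mean of the remaining channels, ignoring the segmentation mask) with a pixel-wise Pearson correlation for comparing a prediction to the true channel. Both function names are illustrative; the benchmark's actual inpainting metric is not stated here, and NumPy stands in for torch.

```python
import numpy as np


def mean_of_other_channels(image: np.ndarray, target_index: int) -> np.ndarray:
    """Naive hypothetical baseline: predict the held-out channel as the mean of
    the remaining channels. image: (C, H, W) -> prediction: (1, H, W)."""
    others = np.delete(image, target_index, axis=0)  # (C - 1, H, W)
    return others.mean(axis=0, keepdims=True)


def pearson_r(pred: np.ndarray, target: np.ndarray) -> float:
    """Pixel-wise Pearson correlation between two (1, H, W) marker images."""
    return float(np.corrcoef(pred.ravel(), target.ravel())[0, 1])
```

Slicing the ground truth as image[i:i + 1] keeps its leading channel axis, so it matches the (1, H, W) shape that predict_marker returns.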