SporaKronosWrapper

class models.kronos.kronos_model.SporaKronosWrapper(model_name, checkpoint_path, marker_metadata, uniprot_mapping, tile_size=224, pca_dimensions=256, image_mpp=1)

Bases: SporaModelWrapper

Parameters:
  • model_name (str)

  • checkpoint_path (str)

  • marker_metadata (str)

  • uniprot_mapping (str)

  • tile_size (int)

  • pca_dimensions (int)

  • image_mpp (float)

embed_tile(tissue)

Embed a tile from the input image.

Parameters:
  • tissue (spora_io.datasets._types.Tissue): The input tissue containing the image and protein IDs.

Returns:

The embedded tile tensor. Shape: (D,)

Return type:

Tensor

embed_tissue(dataset, tissue_id, tissue_threshold=0.3)

Embed the tissue from the input image.

Parameters:
  • dataset (spora_io.datasets.multiplex.MultiplexImagingDataset): The input dataset containing the multiplexed images and segmentation masks.

  • tissue_id (str): The ID of the tissue to embed.

  • tissue_threshold (float): The threshold for considering a tile as containing tissue.

Returns:

A sequence-shaped embedding of the tissue. Shape: (N, D)

Return type:

Tensor
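How tissue_threshold could act on a tile grid can be illustrated with a minimal, dependency-free sketch. Everything here is an assumption made for illustration: the helper select_tissue_tiles is hypothetical and not part of this module, the grid is non-overlapping, and a tile is kept when its tissue fraction is greater than or equal to the threshold. The actual wrapper may tile and compare differently.

```python
from typing import List, Tuple

Mask = List[List[int]]  # 1 = tissue pixel, 0 = background (assumed encoding)


def select_tissue_tiles(
    mask: Mask, tile_size: int, tissue_threshold: float = 0.3
) -> List[Tuple[int, int]]:
    """Return (top, left) origins of tiles whose tissue fraction meets the threshold."""
    height, width = len(mask), len(mask[0])
    kept = []
    # Walk a non-overlapping grid; partial tiles at the border are skipped.
    for top in range(0, height - tile_size + 1, tile_size):
        for left in range(0, width - tile_size + 1, tile_size):
            tissue_pixels = sum(
                mask[r][c]
                for r in range(top, top + tile_size)
                for c in range(left, left + tile_size)
            )
            if tissue_pixels / (tile_size * tile_size) >= tissue_threshold:
                kept.append((top, left))
    return kept


# A 4x4 mask split into four 2x2 tiles.
mask = [
    [1, 1, 0, 0],
    [1, 1, 0, 1],
    [0, 0, 1, 1],
    [0, 0, 0, 0],
]
tiles = select_tissue_tiles(mask, tile_size=2, tissue_threshold=0.3)
```

Under these assumptions, each kept tile would contribute one row of the (N, D) output, so N depends on both the tissue content and the threshold.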

compute_cell_tokens(dataset, tissue_id, batch_size=32, cell_tile_size=64)

Compute cell tokens for the given dataset and tissue ID.

Parameters:
  • dataset (spora_io.datasets.multiplex.MultiplexImagingDataset): The input dataset containing the multiplexed images and segmentation masks.

  • tissue_id (str): The ID of the tissue to compute cell tokens for.

  • batch_size (int)

  • cell_tile_size (int)

Return type:

Tuple[Tensor, Tensor]

postprocess_tile_embeddings(tissue_embeddings)

Postprocess the tissue embeddings if necessary (e.g., dimensionality reduction for Kronos patient-level tasks).

Parameters:
  • tissue_embeddings (torch.Tensor): The raw tissue embedding tensor. Shape: (N, D)

Returns:

The postprocessed tissue token tensor. Shape: (N, D')

Return type:

torch.Tensor
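The (N, D) to (N, D') reduction can be sketched as a principal component projection. This is only a guess at the behaviour, based on the pca_dimensions constructor argument; the source does not say how or on what data the reduction is fitted. The helper pca_reduce is hypothetical and uses plain Python with power iteration so that it has no dependencies. In practice, a library routine such as torch.pca_lowrank or scikit-learn's PCA would be the usual choice.

```python
import math
from typing import List

Matrix = List[List[float]]


def pca_reduce(embeddings: Matrix, n_components: int, n_iter: int = 100) -> Matrix:
    """Project (N, D) rows onto their top principal components, giving (N, D')."""
    n, d = len(embeddings), len(embeddings[0])
    # Center each feature column.
    means = [sum(row[j] for row in embeddings) / n for j in range(d)]
    centered = [[row[j] - means[j] for j in range(d)] for row in embeddings]
    # Sample covariance matrix, shape (D, D).
    cov = [
        [sum(row[i] * row[j] for row in centered) / (n - 1) for j in range(d)]
        for i in range(d)
    ]
    components = []
    for k in range(n_components):
        # Power iteration: converges to the dominant eigenvector of cov,
        # provided the start vector is not orthogonal to it.
        vec = [1.0 if j == k else 0.0 for j in range(d)]
        eigenvalue = 0.0
        for _ in range(n_iter):
            nxt = [sum(cov[i][j] * vec[j] for j in range(d)) for i in range(d)]
            norm = math.sqrt(sum(x * x for x in nxt))
            if norm == 0.0:
                break
            vec = [x / norm for x in nxt]
            eigenvalue = norm
        components.append(vec)
        # Deflate: remove the variance explained by this component.
        cov = [
            [cov[i][j] - eigenvalue * vec[i] * vec[j] for j in range(d)]
            for i in range(d)
        ]
    # Project the centered rows onto the components.
    return [
        [sum(row[j] * comp[j] for j in range(d)) for comp in components]
        for row in centered
    ]


# Five points on a line in 2-D reduce to one coordinate each (D = 2, D' = 1).
embeddings = [[-2.0, -2.0], [-1.0, -1.0], [0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
reduced = pca_reduce(embeddings, n_components=1)
```

The number of rows N is unchanged; only the feature dimension shrinks, which matches the documented output shape (N, D').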

Kronos module

class models.kronos.kronos_model.KronosCellDataset(tissue_image, cell_instance_mask, background_values, tile_size=64, image_mpp=1)

Parameters:
  • tissue_image (torch.Tensor)

  • cell_instance_mask (torch.Tensor)

  • background_values (torch.Tensor)

  • tile_size (int)

  • image_mpp (float)
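One way a per-cell dataset could turn cell_instance_mask and tile_size into crops is sketched below. This is an illustration under stated assumptions, not the class's implementation: the helpers cell_centroids and crop_window are hypothetical, 0 is assumed to mark background in the instance mask, and windows are clamped to the image rather than padded (the real class may instead pad, for example with background_values).

```python
from typing import Dict, List, Tuple

Grid = List[List[int]]


def cell_centroids(instance_mask: Grid) -> Dict[int, Tuple[int, int]]:
    """Map each nonzero cell ID to the floored (row, col) centroid of its pixels."""
    sums: Dict[int, List[int]] = {}
    for r, row in enumerate(instance_mask):
        for c, cell_id in enumerate(row):
            if cell_id == 0:  # 0 is assumed to mean background
                continue
            acc = sums.setdefault(cell_id, [0, 0, 0])
            acc[0] += r
            acc[1] += c
            acc[2] += 1
    return {cid: (sr // n, sc // n) for cid, (sr, sc, n) in sums.items()}


def crop_window(
    center: Tuple[int, int], tile_size: int, height: int, width: int
) -> Tuple[int, int, int, int]:
    """Return (top, left, bottom, right) of a tile_size window clamped to the image."""
    half = tile_size // 2
    top = min(max(center[0] - half, 0), max(height - tile_size, 0))
    left = min(max(center[1] - half, 0), max(width - tile_size, 0))
    return top, left, min(top + tile_size, height), min(left + tile_size, width)


# Two cells in a 6x6 instance mask, cropped with 4x4 windows.
instance_mask = [
    [1, 1, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0],
    [0, 0, 0, 2, 2, 2],
    [0, 0, 0, 2, 2, 2],
    [0, 0, 0, 2, 2, 2],
]
centroids = cell_centroids(instance_mask)
windows = {cid: crop_window(center, 4, 6, 6) for cid, center in centroids.items()}
```

Each window would then index one fixed-size crop of tissue_image per cell, which is what makes the crops batchable.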