SporaVirTuesWrapper

class models.virtues.virtues_model.SporaVirTuesWrapper(model_name, checkpoint_path, marker_embeddings_dir, patch_size, model_dim, feedforward_dim, encoder_pattern, num_encoder_heads, decoder_pattern, num_decoder_heads, num_decoder_hidden_layers, positional_embedding_type, dropout, group_layers, norm_after_encoder_decoder, tile_size=128)

Bases: SporaModelWrapper

Parameters:
  • model_name (str)

  • checkpoint_path (str)

  • marker_embeddings_dir (str)

  • patch_size (int)

  • model_dim (int)

  • feedforward_dim (int)

  • encoder_pattern (str)

  • num_encoder_heads (int)

  • decoder_pattern (str)

  • num_decoder_heads (int)

  • num_decoder_hidden_layers (int)

  • positional_embedding_type (str)

  • dropout (float)

  • group_layers (bool)

  • norm_after_encoder_decoder (bool)

  • tile_size (int)
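The constructor takes sixteen arguments, and only tile_size has a default. Collecting them in one typed object makes a configuration easy to check before it is passed to the wrapper. The sketch below is hypothetical: VirTuesWrapperConfig is not part of the package, and every value in the example (paths, dimensions, the pattern strings) is a placeholder. The formats of encoder_pattern, decoder_pattern and positional_embedding_type are not documented on this page.

```python
from dataclasses import asdict, dataclass, fields


@dataclass
class VirTuesWrapperConfig:
    """Hypothetical container mirroring the documented constructor arguments."""

    model_name: str
    checkpoint_path: str
    marker_embeddings_dir: str
    patch_size: int
    model_dim: int
    feedforward_dim: int
    encoder_pattern: str
    num_encoder_heads: int
    decoder_pattern: str
    num_decoder_heads: int
    num_decoder_hidden_layers: int
    positional_embedding_type: str
    dropout: float
    group_layers: bool
    norm_after_encoder_decoder: bool
    tile_size: int = 128  # the only documented default

    def __post_init__(self) -> None:
        # Check every value against the type listed in the documentation.
        for f in fields(self):
            value = getattr(self, f.name)
            # bool is a subclass of int in Python, so reject it explicitly for int fields.
            if f.type is int and isinstance(value, bool):
                raise TypeError(f"{f.name} must be int, got bool")
            if not isinstance(value, f.type):
                raise TypeError(
                    f"{f.name} must be {f.type.__name__}, got {type(value).__name__}"
                )


# All values below are placeholders, not recommended settings.
cfg = VirTuesWrapperConfig(
    model_name="virtues",
    checkpoint_path="path/to/checkpoint",
    marker_embeddings_dir="path/to/marker_embeddings",
    patch_size=8,
    model_dim=512,
    feedforward_dim=1024,
    encoder_pattern="<encoder pattern>",
    num_encoder_heads=8,
    decoder_pattern="<decoder pattern>",
    num_decoder_heads=8,
    num_decoder_hidden_layers=2,
    positional_embedding_type="<positional embedding type>",
    dropout=0.0,
    group_layers=False,
    norm_after_encoder_decoder=True,
)
kwargs = asdict(cfg)
# wrapper = SporaVirTuesWrapper(**kwargs)  # needs the real package and a checkpoint
```

Because the dataclass fields use the documented parameter names, asdict(cfg) can be unpacked directly as keyword arguments.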

compute_cell_tokens(dataset, tissue_id)

Compute cell tokens for the given dataset and tissue ID.

Parameters:
  • dataset (spora_io.MultiplexImagingDataset): The input dataset containing the multiplexed images and segmentation masks.

  • tissue_id (str): The ID of the tissue to compute cell tokens for.

Return type:

Tensor
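This page does not document the shape of the returned cell tokens or how they are derived. One common, generic way to turn a dense feature map and a segmentation mask into per-cell tokens is mask-based average pooling. The sketch below illustrates only that technique and is an assumption, not necessarily what compute_cell_tokens does internally. The function pool_cell_tokens, the (H, W, D) feature layout, and the convention that label 0 means background are all hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def pool_cell_tokens(
    features: Sequence[Sequence[Sequence[float]]],  # (H, W, D) per-pixel feature vectors
    mask: Sequence[Sequence[int]],                  # (H, W) cell labels, 0 = background
) -> Dict[int, List[float]]:
    """Average the feature vectors of all pixels that belong to each labelled cell."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = defaultdict(int)
    for row_feats, row_labels in zip(features, mask):
        for vec, label in zip(row_feats, row_labels):
            if label == 0:
                continue  # background pixels do not contribute to any cell
            if label not in sums:
                sums[label] = [0.0] * len(vec)
            acc = sums[label]
            for i, v in enumerate(vec):
                acc[i] += v
            counts[label] += 1
    # One D-dimensional token per cell: the mean over that cell's pixels.
    return {label: [s / counts[label] for s in acc] for label, acc in sums.items()}


features = [[[1.0, 0.0], [3.0, 0.0]],
            [[0.0, 2.0], [5.0, 5.0]]]
mask = [[1, 1],
        [2, 0]]
tokens = pool_cell_tokens(features, mask)
```

Stacking the values of tokens in label order would give a (number of cells, D) matrix. This is one plausible layout for a tensor of cell tokens, though the real return layout is not stated here.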

embed_tile(tissue)

Embed a single tile of the input image into one D-dimensional vector.

Parameters:
  • tissue (spora_io.datasets._types.MultiplexTissue): The input tissue containing the image and its protein IDs.

Returns:

The embedded tile tensor. Shape: (D,)

Return type:

Tensor
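The constructor exposes both patch_size and tile_size, which suggests that a tile is cut into smaller patches before it is reduced to a single (D,) vector. This page does not confirm how that happens. The sketch below assumes that each channel of a (C, H, W) tile is split independently into non-overlapping patch_size x patch_size patches, as in channel-wise ViT-style tokenizers. The function patchify_per_channel is hypothetical and uses nested lists in place of tensors.

```python
from typing import List, Sequence

Tile = Sequence[Sequence[Sequence[float]]]  # (C, H, W)


def patchify_per_channel(tile: Tile, patch_size: int) -> List[List[float]]:
    """Split every channel of a (C, H, W) tile into flattened, non-overlapping patches."""
    c = len(tile)
    h, w = len(tile[0]), len(tile[0][0])
    if h % patch_size or w % patch_size:
        raise ValueError("tile height and width must be divisible by patch_size")
    tokens: List[List[float]] = []
    for ch in range(c):
        for top in range(0, h, patch_size):
            for left in range(0, w, patch_size):
                # Row-major flattening of one patch_size x patch_size window.
                patch = [
                    tile[ch][y][x]
                    for y in range(top, top + patch_size)
                    for x in range(left, left + patch_size)
                ]
                tokens.append(patch)
    return tokens


# A 2-channel 4 x 4 toy tile whose values encode (channel, row, column).
tile = [[[ch * 100 + y * 4 + x for x in range(4)] for y in range(4)] for ch in range(2)]
tokens = patchify_per_channel(tile, patch_size=2)
```

Under this assumption, a tile yields C * (H / patch_size) * (W / patch_size) patch tokens. With the default tile_size of 128 and a hypothetical patch_size of 8, that is 16 * 16 = 256 patches per channel.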

embed_tissue(dataset, tissue_id, tissue_threshold=0.3)

Embed the tissue with the given ID from the input dataset.

Parameters:
  • dataset (spora_io.MultiplexImagingDataset): The input dataset containing the multiplexed images and segmentation masks.

  • tissue_id (str): The ID of the tissue to embed.

  • tissue_threshold (float): The threshold for considering a tile as containing tissue. Defaults to 0.3.

Returns:

A sequence-shaped embedding of the tissue. Shape: (N, D)

Return type:

Tensor
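The tissue_threshold parameter decides which tiles count as containing tissue, and the result has shape (N, D). A plausible reading is that N is the number of tiles that pass the threshold, each contributing one (D,) embedding as returned by embed_tile. The sketch below illustrates that reading under several stated assumptions: tissue is given as a binary (H, W) mask, the tissue fraction of a tile is compared with >=, and partial tiles at the image border are dropped. The functions select_tissue_tiles and embed_selected_tiles, and the embed_fn callback, are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Mask = Sequence[Sequence[int]]  # (H, W) binary tissue mask, 1 = tissue


def select_tissue_tiles(
    mask: Mask, tile_size: int, tissue_threshold: float = 0.3
) -> List[Tuple[int, int]]:
    """Return (top, left) corners of full tiles whose tissue fraction meets the threshold."""
    h, w = len(mask), len(mask[0])
    kept: List[Tuple[int, int]] = []
    # Only full tiles are considered; a partial tile at the border is skipped.
    for top in range(0, h - tile_size + 1, tile_size):
        for left in range(0, w - tile_size + 1, tile_size):
            tissue = sum(
                mask[y][x]
                for y in range(top, top + tile_size)
                for x in range(left, left + tile_size)
            )
            if tissue / (tile_size * tile_size) >= tissue_threshold:
                kept.append((top, left))
    return kept


def embed_selected_tiles(
    corners: List[Tuple[int, int]], embed_fn: Callable[[int, int], List[float]]
) -> List[List[float]]:
    """Stack one D-dimensional embedding per kept tile into an (N, D) nested list."""
    return [embed_fn(top, left) for top, left in corners]


mask = [
    [1, 1, 0, 1],
    [1, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
]
corners = select_tissue_tiles(mask, tile_size=2, tissue_threshold=0.3)
# Toy embedding function standing in for a real per-tile encoder (D = 3).
embedding = embed_selected_tiles(corners, lambda top, left: [float(top), float(left), 1.0])
```

In this toy mask the four 2 x 2 tiles have tissue fractions 1.0, 0.25, 0.0 and 0.5. Only the first and last pass the 0.3 threshold, so N = 2.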

predict_marker(tissue, target_channel_name, target_uniprot_id)

Inpaint the target marker channel in the given image based on the segmentation mask.

Parameters:
  • tissue (spora_io.datasets._types.MultiplexTissue): The input tissue, whose multiplexed image has shape (C, H, W).

  • target_channel_name (str): The name of the target channel to be inpainted.

  • target_uniprot_id (str): The UniProt ID of the target channel to be inpainted.

Returns:

The inpainted marker channel. Shape: (1, H, W)

Return type:

torch.Tensor
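predict_marker returns a single channel of shape (1, H, W), while the input image has shape (C, H, W). A typical follow-up step is to append the predicted marker to the existing panel so that it can be analysed alongside the measured channels. The hypothetical helper add_predicted_channel below sketches that step with nested lists standing in for tensors. With real tensors, torch.cat([image, prediction], dim=0) performs the same concatenation along the channel axis.

```python
from typing import List, Sequence, Tuple

Image = List[List[List[float]]]  # (C, H, W), C >= 1


def add_predicted_channel(
    image: Image,
    channel_names: List[str],
    prediction: Sequence[Sequence[Sequence[float]]],  # (1, H, W), as returned by predict_marker
    target_channel_name: str,
) -> Tuple[Image, List[str]]:
    """Return a new (C + 1, H, W) stack with the predicted marker appended last."""
    if len(prediction) != 1:
        raise ValueError("expected a single-channel prediction of shape (1, H, W)")
    h, w = len(image[0]), len(image[0][0])
    pred = prediction[0]
    if len(pred) != h or any(len(row) != w for row in pred):
        raise ValueError("prediction height and width must match the input image")
    # Copy rows so the caller's image and prediction are left unchanged.
    new_image = [[list(row) for row in ch] for ch in image] + [[list(row) for row in pred]]
    return new_image, channel_names + [target_channel_name]


image = [[[1.0, 2.0], [3.0, 4.0]],
         [[5.0, 6.0], [7.0, 8.0]]]
prediction = [[[0.5, 0.5], [0.5, 0.5]]]
stacked, names = add_predicted_channel(image, ["marker_a", "marker_b"], prediction, "marker_c")
```

The channel names here are placeholders. The helper returns new lists instead of modifying its inputs, so the measured image stays unchanged.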
