Source code for spora_io.datasets.compose

from __future__ import annotations

import os
from pathlib import Path
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

import pandas as pd

from spora_io._config import get_datasets_dir
from spora_io.datasets.he import HEImagingDataset
from spora_io.datasets.ihc import SingleIHCImagingDataset
from spora_io.datasets.multiplex import MultiplexImagingDataset
from spora_io.datasets._types import ModKey, ComposedTissue
from spora_io.utils.utils import print_verbose


def _norm_modality_key(mod: ModKey) -> str:
    if isinstance(mod, str):
        return mod
    elif hasattr(mod, "name"):
        return mod.name
    else:
        raise ValueError(f"Invalid modality key: {mod}. Must be a string or an object with a 'name' attribute.")


class ComposedImagingDataset:
    """
    Compose multiple unimodal datasets (HE, Multiplex, IHC) into a single handle.

    - Uniform interface to fetch tissues/tiles per modality.
    - Ensures consistent tile strategy across modalities by construction: every
      unimodal dataset receives the same tile size, resolution, tile strategy and split.
    - Extensible via `modality_kwargs` to pass per-modality constructor arguments.

    Notes on behavior:
    - `get_tissue_ids()` returns the unique tissue IDs listed in the tissue metadata
      for the instantiated modalities (restricted to `split` when one is given).
    - `get_modalities_of_tissue(tissue_id)` lists which modalities contain that tissue ID.
    - Per-modality operations (tissues, masks, sizes, tiles) are delegated to the
      underlying unimodal dataset returned by `get_dataset`.
    """

    def __init__(
        self,
        name: str,
        modalities: Iterable[ModKey],
        tile_size: int,
        resolution: float | str,
        path: Union[str, Path] | None = None,
        verbose: bool = True,
        load_cell_metadata: bool = False,
        tile_strategy: Optional[str] = None,
        split: Optional[str] = None,
        *,
        modality_kwargs: Optional[Mapping[str, Mapping[str, Any]]] = None,
    ) -> None:
        """
        Build every requested unimodal dataset and the shared tissue indices.

        Args:
            name: Dataset name; the default root is ``get_datasets_dir() / name``.
            modalities: Modality keys (strings or objects with a ``name`` attribute).
                ``"ihc"`` expands to every non-hidden entry of ``<path>/ihc``.
                Duplicate keys are instantiated only once.
            tile_size: Tile size forwarded to every unimodal dataset.
            resolution: Resolution forwarded to every unimodal dataset.
            path: Dataset root directory. Defaults to ``get_datasets_dir() / name``.
            verbose: If True, print progress/summary information.
            load_cell_metadata: If True, load ``metadata/cells.parquet`` into
                ``self.cell_metadata`` (indexed by ``tissue_id``).
            tile_strategy: Tile strategy forwarded to every unimodal dataset.
            split: If given, keep only tissue-metadata rows whose ``split`` column equals it.
            modality_kwargs: Extra constructor kwargs per modality key. For IHC markers,
                kwargs under ``"ihc"`` apply to every marker and are overridden by
                kwargs under the marker's own ``"ihc_<marker>"`` key.

        Raises:
            FileNotFoundError: If ``"ihc"`` is requested but ``<path>/ihc`` is not a directory.
            ValueError: On an unsupported modality key, when ``split`` is given but the
                tissue metadata has no ``split`` column, or when the split selects no rows.
        """
        self.name = name
        self.path = Path(path) if path is not None else get_datasets_dir() / name
        self.verbose = verbose
        self.tile_size = tile_size
        self.tile_strategy = tile_strategy
        self.resolution = resolution
        self.load_cell_metadata = load_cell_metadata
        self.split = split
        self._unimodal: Dict[str, Any] = {}
        self._raw_modality_keys: List[str] = []

        # Normalize kwargs keys the same way as modality keys so enum keys also work.
        per_mod_kwargs = {
            _norm_modality_key(key): dict(value)
            for key, value in (modality_kwargs or {}).items()
        }
        modality_keys = self._expand_modality_keys([_norm_modality_key(mod) for mod in modalities])

        for key in modality_keys:
            if self.verbose:
                print_verbose(f"Initializing modality '{key}'...")
            self._raw_modality_keys.append(key)
            self._unimodal[key] = self._create_unimodal(key, self._modality_kwargs_for(key, per_mod_kwargs))

        self.modalities: List[str] = list(self._unimodal.keys())

        # Reverse index: tissue ID (normalized to str) -> modalities that contain it.
        self._tissue_id_to_modalities: Dict[str, List[str]] = {}
        for mod_key, ds in self._unimodal.items():
            for tissue_id in ds.get_tissue_ids():
                self._tissue_id_to_modalities.setdefault(str(tissue_id), []).append(mod_key)

        if self.load_cell_metadata:
            if self.verbose:
                print_verbose("Loading cell metadata")
            self.cell_metadata = pd.read_parquet(self.path / "metadata" / "cells.parquet").set_index("tissue_id")

        self._load_tissue_metadata()

        if self.verbose:
            print_verbose(f"Composed dataset initialized with modalities: {self.modalities}")
            print_verbose(f"Total unique tissue samples across all modalities: {len(self.all_tissue_ids)}")

    def _expand_modality_keys(self, modality_keys: List[str]) -> List[str]:
        """Expand ``"ihc"`` into the marker entries of ``<path>/ihc`` and drop duplicate keys.

        Explicit keys keep their order; expanded IHC markers follow in sorted order.
        The first occurrence of a key wins.

        Raises:
            FileNotFoundError: If ``"ihc"`` is requested but ``<path>/ihc`` is not a directory.
        """
        if "ihc" in modality_keys:
            ihc_dir = self.path / "ihc"
            if not ihc_dir.is_dir():
                raise FileNotFoundError(f"Requested all IHC markers, but IHC directory does not exist: {ihc_dir}")
            # Hidden entries (e.g. ".DS_Store") are filesystem noise, not markers.
            all_ihc_markers = sorted(entry for entry in os.listdir(ihc_dir) if not entry.startswith("."))
            modality_keys = [key for key in modality_keys if key != "ihc"] + all_ihc_markers
        # dict.fromkeys keeps first-occurrence order, so a marker requested both
        # explicitly and via "ihc" is instantiated only once.
        return list(dict.fromkeys(modality_keys))

    @staticmethod
    def _modality_kwargs_for(key: str, per_mod_kwargs: Mapping[str, Mapping[str, Any]]) -> Dict[str, Any]:
        """Return a fresh dict of extra constructor kwargs for modality ``key``.

        IHC markers (``ihc_<marker>``) start from the shared ``"ihc"`` kwargs and are
        overridden by their own entry; other modalities use only their own entry.
        """
        if key.startswith("ihc_"):
            return {**per_mod_kwargs.get("ihc", {}), **per_mod_kwargs.get(key, {})}
        return dict(per_mod_kwargs.get(key, {}))

    def _create_unimodal(self, key: str, kwargs_extra: Dict[str, Any]) -> Any:
        """Instantiate the unimodal dataset for ``key`` with the shared settings plus ``kwargs_extra``.

        ``kwargs_extra`` may be mutated (``standardization`` is popped for multiplex keys).

        Raises:
            ValueError: If ``key`` is not a supported modality.
        """
        kwargs_common = dict(
            name=self.name,
            path=str(self.path),
            verbose=self.verbose,
            tile_size=self.tile_size,
            resolution=self.resolution,
            # Unimodal datasets never load cell metadata; the composed dataset
            # loads it itself when requested.
            load_cell_metadata=False,
            tile_strategy=self.tile_strategy,
            split=self.split,
        )
        if key == "he":
            return HEImagingDataset(**kwargs_common, **kwargs_extra)
        if key in MultiplexImagingDataset.VALID_MODALITIES:
            standardization = kwargs_extra.pop("standardization", "identity")
            if standardization == "identity":
                print_verbose(f"No standardization will be applied to modality '{key}'. Ensure this is intentional.", level="WARNING")
            return MultiplexImagingDataset(
                **kwargs_common,
                modality=key,
                standardization=standardization,
                **kwargs_extra,
            )
        if key.startswith("ihc_"):
            return SingleIHCImagingDataset(
                **kwargs_common,
                marker_name=key,
                **kwargs_extra,
            )
        raise ValueError(
            f"Unsupported modality '{key}' in ComposedImagingDataset. "
            "Supported: he, imc, codex, cycif, mibi, ihc, and ihc_<marker>."
        )

    def _load_tissue_metadata(self) -> None:
        """Load ``metadata/tissues.parquet`` and derive the tissue and patient indices.

        Filters rows by ``self.split`` (when set) and by the instantiated modality keys.
        Sets ``all_tissue_metadata``, ``all_tissue_ids`` and ``patient_tissue_map``.

        Raises:
            ValueError: If ``split`` is set but there is no ``split`` column, or no rows match it.
        """
        metadata = pd.read_parquet(self.path / "metadata" / "tissues.parquet").set_index("tissue_id")
        if self.split is not None:
            if "split" not in metadata.columns:
                raise ValueError(f"Split column not found in tissue metadata, but split argument {self.split} was provided.")
            metadata = metadata[metadata["split"] == self.split]
            if metadata.empty:
                raise ValueError(
                    f"No tissue metadata found for split {self.split}. "
                    "Please check the split argument and the contents of the tissue metadata."
                )
        self.all_tissue_metadata = metadata[metadata["modality"].isin(self._raw_modality_keys)]
        self.all_tissue_ids = self.all_tissue_metadata.index.unique().tolist()
        self.patient_tissue_map = self._build_patient_tissue_map(self.all_tissue_metadata)

    @staticmethod
    def _build_patient_tissue_map(tissue_metadata: pd.DataFrame) -> Dict[Any, List[Any]]:
        """Map each patient ID to its unique tissue IDs, in first-occurrence order.

        ``tissue_metadata`` is indexed by ``tissue_id`` and may list the same tissue on
        several rows (e.g. one row per modality), so tissue IDs are de-duplicated per
        patient. Patients are ordered as ``groupby`` sorts them; NaN patient IDs are dropped.
        """
        pairs = tissue_metadata.reset_index()[["tissue_id", "patient_id"]].drop_duplicates()
        return {
            patient_id: group.tolist()
            for patient_id, group in pairs.groupby("patient_id")["tissue_id"]
        }

    def get_dataset(self, modality: ModKey) -> Any:
        """
        Get the unimodal dataset instance for a given modality key.

        Args:
            modality (ModKey): The modality key (string or object with 'name' attribute) to retrieve the dataset for.

        Returns:
            Any: The unimodal dataset instance corresponding to the modality key.

        Raises:
            KeyError: If the modality key is not part of this composed dataset.
        """
        key = _norm_modality_key(modality)
        if key not in self._unimodal:
            raise KeyError(f"Modality '{key}' is not part of this composed dataset.")
        return self._unimodal[key]

    def get_available_modalities(self) -> List[str]:
        """
        Get the list of available modalities in this composed dataset.

        Returns:
            List[str]: A new list of modality keys representing the available modalities.
        """
        return list(self.modalities)

    def get_tissue_ids(self, modality: Optional[ModKey] = None) -> List[str]:
        """
        Get the list of tissue IDs available in the dataset.

        Args:
            modality (Optional[ModKey]): The modality key to filter tissue IDs by.
                If None, returns the unique tissue IDs from the tissue metadata across all modalities.

        Returns:
            List[str]: Tissue IDs (filtered by modality if specified). When ``modality`` is
            None a copy is returned, so callers cannot mutate the internal list.
        """
        if modality is None:
            return list(self.all_tissue_ids)
        return self.get_dataset(modality).get_tissue_ids()

    def get_modalities_of_tissue(self, tissue_id: str) -> List[str]:
        """
        Get the list of modalities available for a given tissue ID.

        Args:
            tissue_id (str): The tissue ID to query modalities for (compared as a string).

        Returns:
            List[str]: A new list of modality keys available for the given tissue ID.

        Raises:
            KeyError: If no instantiated modality contains the tissue ID.
        """
        try:
            modalities = self._tissue_id_to_modalities[str(tissue_id)]
        except KeyError:
            raise KeyError(f"Tissue '{tissue_id}' is not part of this composed dataset.") from None
        return list(modalities)

    def get_unimodal_tissue(self, tissue_id: str, modality: ModKey, kind: str = "uniprot_filtered", preprocess: bool = True, image_mode: str = "CHW"):
        """
        Get the tissue image for a given tissue ID and modality, with options for kind of image and preprocessing.

        Args:
            tissue_id (str): The tissue ID to retrieve.
            modality (ModKey): The modality key to specify which unimodal dataset to query.
            kind (str): The kind of tissue image to retrieve. Default is "uniprot_filtered".
                Valid options are "complete", "qc_filtered", and "uniprot_filtered".
            preprocess (bool): If True, preprocess the image (normalize). Default is True.
            image_mode (str): The desired image mode of the returned tissue image.
                Valid options are "CHW" and "HWC". Default is "CHW".

        Returns:
            Tissue: The tissue image as returned by the unimodal dataset's `get_tissue` method.
        """
        ds = self.get_dataset(modality)
        return ds.get_tissue(tissue_id, kind=kind, preprocess=preprocess, image_mode=image_mode)

    def get_unimodal_tissue_mask(self, tissue_id: str, modality: ModKey):
        """
        Get the quality control mask for a given tissue ID and modality.

        Args:
            tissue_id (str): The tissue ID to retrieve the mask for.
            modality (ModKey): The modality key to specify which unimodal dataset to query.

        Returns:
            The mask as returned by the unimodal dataset's `get_tissue_mask` method.
        """
        ds = self.get_dataset(modality)
        return ds.get_tissue_mask(tissue_id)

    def get_unimodal_tissue_size(self, tissue_id: str, modality: ModKey) -> Tuple[int, int, int]:
        """
        Get the tissue size (C,H,W) for a given tissue ID and modality.

        Args:
            tissue_id (str): The tissue ID to retrieve the size for.
            modality (ModKey): The modality key to specify which unimodal dataset to query.

        Returns:
            Tuple[int, int, int]: The tissue size (C,H,W) as returned by the unimodal dataset's `_get_tissue_size` method.
        """
        ds = self.get_dataset(modality)
        return ds._get_tissue_size(tissue_id)

    def get_unimodal_tile(
        self,
        tissue_id: str,
        tile_id: int,
        modality: ModKey,
        kind: str = "uniprot_filtered",
        preprocess: bool = True,
        image_mode: str = "CHW",
    ):
        """
        Get a specific tile for a given tissue ID and modality.

        Args:
            tissue_id (str): The tissue ID to retrieve the tile for.
            tile_id (int): The tile ID to retrieve.
            modality (ModKey): The modality key to specify which unimodal dataset to query.
            kind (str): The kind of image to retrieve. Valid options depend on modality.
            preprocess (bool): If True, preprocess the tile before returning.
            image_mode (str): The returned image layout, usually "CHW" or "HWC".

        Returns:
            Tissue: The tile image as returned by the unimodal dataset's `get_tile` method.
        """
        ds = self.get_dataset(modality)
        return ds.get_tile(tissue_id, tile_id, kind=kind, preprocess=preprocess, image_mode=image_mode)

    def get_composed_tissue(self, tissue_id: str, kind: str = "uniprot_filtered", preprocess: bool = True, image_mode: str = "CHW") -> ComposedTissue:
        """
        Get a composed tissue sample for a given tissue ID, which includes all available modalities for that tissue.

        Args:
            tissue_id (str): The tissue ID to retrieve.
            kind (str): The kind of tissue image to retrieve. Default is "uniprot_filtered".
                Valid options are "complete", "qc_filtered", and "uniprot_filtered".
            preprocess (bool): If True, preprocess the images (normalize). Default is True.
            image_mode (str): The desired image mode of the returned tissue images.
                Valid options are "CHW" and "HWC". Default is "CHW".

        Returns:
            ComposedTissue: A ComposedTissue instance containing the tissue ID and a dictionary of
            modality-specific Tissue instances.

        Raises:
            KeyError: If no instantiated modality contains the tissue ID.
        """
        modality_tissues = {
            mod: self.get_unimodal_tissue(tissue_id, mod, kind=kind, preprocess=preprocess, image_mode=image_mode)
            for mod in self.get_modalities_of_tissue(tissue_id)
        }
        return ComposedTissue(tissue_id=tissue_id, modalities=modality_tissues)

    def get_composed_tile(self, tissue_id: str, tile_id: int, kind: str = "uniprot_filtered", preprocess: bool = True, image_mode: str = "CHW") -> ComposedTissue:
        """
        Get a composed tile for a given tissue ID and tile ID, which includes all available modalities for that tissue.

        The same ``tile_id`` is requested from every modality; all unimodal datasets were
        built with the same tile size, resolution and tile strategy.

        Args:
            tissue_id (str): The tissue ID to retrieve.
            tile_id (int): The tile ID to retrieve.
            kind (str): The kind of tile image to retrieve. Default is "uniprot_filtered".
                Valid options are "complete", "qc_filtered", and "uniprot_filtered".
            preprocess (bool): If True, preprocess the images (normalize). Default is True.
            image_mode (str): The desired image mode of the returned tile images.
                Valid options are "CHW" and "HWC". Default is "CHW".

        Returns:
            ComposedTissue: A ComposedTissue instance containing the tissue ID and a dictionary of
            modality-specific tile images.

        Raises:
            KeyError: If no instantiated modality contains the tissue ID.
        """
        modality_tiles = {
            mod: self.get_unimodal_tile(tissue_id, tile_id, mod, kind=kind, preprocess=preprocess, image_mode=image_mode)
            for mod in self.get_modalities_of_tissue(tissue_id)
        }
        return ComposedTissue(tissue_id=tissue_id, modalities=modality_tiles)

    def get_composed_tissue_by_patient(self, patient_id: str, kind: str = "uniprot_filtered", preprocess: bool = True, image_mode: str = "CHW") -> Sequence[ComposedTissue]:
        """
        Get composed tissue samples for all tissues associated with a given patient ID.

        The patient ID is matched by its string form, so ``5`` and ``"5"`` find the same
        patient regardless of the dtype of the metadata's ``patient_id`` column.

        Args:
            patient_id (str): The patient ID to retrieve tissues for.
            kind (str): The kind of tissue image to retrieve. Default is "uniprot_filtered".
                Valid options are "complete", "qc_filtered", and "uniprot_filtered".
            preprocess (bool): If True, preprocess the images (normalize). Default is True.
            image_mode (str): The desired image mode of the returned tissue images.
                Valid options are "CHW" and "HWC". Default is "CHW".

        Returns:
            Sequence[ComposedTissue]: One ComposedTissue per unique tissue of the patient
            (empty if the patient is unknown).
        """
        wanted = str(patient_id)
        tissue_ids = self.patient_tissue_map.get(wanted)
        if tissue_ids is None:
            # Map keys keep the metadata dtype (e.g. int), so fall back to comparing string forms.
            tissue_ids = next(
                (tids for pid, tids in self.patient_tissue_map.items() if str(pid) == wanted),
                [],
            )
        return [
            self.get_composed_tissue(tid, kind=kind, preprocess=preprocess, image_mode=image_mode)
            for tid in tissue_ids
        ]

    def __repr__(self) -> str:
        n_tiles = {
            modality: dataset._count_tiles()
            for modality, dataset in self._unimodal.items()
            if hasattr(dataset, "_count_tiles")
        }
        return (
            f"{self.__class__.__name__}("
            f"name={self.name!r}, modalities={self.modalities!r}, resolution={self.resolution!r}, "
            f"tile_size={self.tile_size!r}, tile_strategy={self.tile_strategy!r}, "
            f"split={self.split!r}, n_tissues={len(self.all_tissue_ids)}, n_tiles={n_tiles})"
        )