Source code for spora_io.datasets.base

"""Base class for all imaging datasets. 

Defines the interface for loading tissue images, tissue masks, and cell masks, as well as patient-level retrieval. Also includes functionality for loading tile coordinates if tiling is used in the dataset.
Supports lazy loading of tissue masks and tiles. 
"""

from __future__ import annotations

import os
from abc import ABC, abstractmethod
from pathlib import Path
import numpy as np
from loguru import logger
import pandas as pd
from numpy.typing import NDArray
from typing import Any, Union, Tuple, Sequence, Callable, Optional
import torch
from spora_io._config import get_datasets_dir
from spora_io.datasets._types import get_modality_from_str, is_valid_modality_instance, ModKey, Tissue, \
                                            TissueMask, CellMask
from spora_io.utils.utils import print_verbose

# Module-level cache of parsed tile coordinates. It lives at module scope, so it is shared by
# every dataset instance in the process. Keys are the tile-coordinate parquet paths; values are
# the per-tissue mappings (tissue_id -> list of (row, col) pairs) built from those files.
_TILE_COORDS_CACHE: dict[Path, dict[Any, list[tuple[int, int]]]] = {}




class BaseImagingDataset(ABC):
    """Base class for all imaging datasets.

    Reads ``metadata/tissues.parquet`` (and optionally ``metadata/cells.parquet``)
    below the dataset root, filters it by split / label / modality, and exposes
    loaders for tissue masks, cell masks and (padded) tiles. Image loading itself
    is left to subclasses through the abstract methods.

    Attributes:
        name (str): The name of the dataset.
        path (Path): The root path to the dataset. If omitted, this resolves to
            ``SPORA_DATASETS_DIR / name``.
        modality (ModKey): The modality of the dataset.
        resolution (str): The resolution of the dataset in mpp, formatted as a string
            with underscores instead of decimals (e.g. "0_5mpp").
        tile_size (Optional[int]): The tile size in pixels. If None, tiling
            functionality will be disabled.
        tile_strategy (Optional[str]): The tiling strategy used for the dataset. If None,
            defaults to "default". This is used to determine the subdirectory under
            tiling/<resolution>/ where tile coordinates are stored.
        split (Optional[str]): If given, only tissues whose ``split`` metadata column
            equals this value are kept.
        label (Optional[str]): The name of the label column in the tissue metadata.
            If None, no labels will be loaded.
        labels_to_keep (Optional[Sequence[str]]): The list of label values to keep if
            label is not None. If None, all labels will be kept.
        label_modifying_fn (Optional[Callable]): A function to modify the labels after
            loading. For example, this can be used to binarize labels or group certain
            labels together. If None, labels will not be modified.
        label_type (str): The type of the label, either "classification" or
            "regression". This is used to determine how to encode the labels.
            Default is "classification".
        tissue_masks_dir (Optional[Path]): Directory holding ``<tissue_id>.npz`` tissue
            masks, or None when that directory does not exist.
        tissue_metadata (pd.DataFrame): Tissue metadata indexed by ``tissue_id`` after
            split / label filtering.
        tissue_modality_metadata (pd.DataFrame): Rows of ``tissue_metadata`` whose
            ``modality`` column equals ``modality.name``.
        patient_tissue_map (dict): Maps each ``patient_id`` to the list of its tissue ids.
    """

    # Sub-folders of ``cell_masks`` that hold instance segmentations rather than
    # task (category) masks; see ``get_cell_instance_mask``.
    _INSTANCE_MASK_FOLDERS = ("instances", "instances_virtues")

    def __init__(
        self,
        name: str,
        modality: ModKey | str,
        resolution: float | str,
        path: os.PathLike | str | None = None,
        tile_size: Optional[int] = None,
        load_cell_metadata: bool = False,
        verbose: bool = True,
        label: Optional[str] = None,
        labels_to_keep: Optional[Sequence[str]] = None,
        label_modifying_fn: Optional[Callable] = None,
        label_type: str = "classification",
        tile_strategy: Optional[str] = None,
        split: Optional[str] = None,
    ):
        """Load and filter the dataset metadata.

        Args:
            name: Dataset name; also the default sub-directory of the datasets dir.
            modality: Modality instance, or a string resolved via ``get_modality_from_str``.
            resolution: Resolution in mpp. Anything that is neither ``float`` nor ``str``
                is converted with ``float()`` first.
            path: Dataset root. Defaults to ``get_datasets_dir() / name``.
            tile_size: Tile edge length in pixels, or None to disable tiling.
            load_cell_metadata: If True, also read ``metadata/cells.parquet`` into
                ``self.cell_metadata``.
            verbose: Stored on the instance; gates some log messages.
            label: Name of the label column in the tissue metadata.
            labels_to_keep: Label values to keep; None keeps every label.
            label_modifying_fn: Function mapped over the label column; None means identity.
            label_type: "classification" builds ``unique_labels`` / ``label_encoder``.
            tile_strategy: Tiling strategy name; requires ``tile_size``.
            split: Value of the ``split`` metadata column to restrict the dataset to.

        Raises:
            ValueError: If a tile strategy is given without a tile size, if ``split`` is
                given but the metadata has no ``split`` column, or if the split is empty.
            TypeError: If ``modality`` is not a string and not a valid modality instance.
        """
        self.name = name
        self.path = Path(path) if path is not None else get_datasets_dir() / name
        self.verbose = verbose
        self.resolution = resolution
        self.tile_size = tile_size
        self.tile_strategy = tile_strategy
        self.split = split
        self._tissue_size_cache: dict[str, Tuple[int, int, int]] = {}

        if self.tile_strategy is not None and self.tile_size is None:
            raise ValueError(f"Tile strategy {self.tile_strategy} provided without tile size. Please provide a tile size to use tiling functionality.")
        if self.tile_size is None:
            print_verbose("No tile size is provided, tiling functionality will break. Please provide a tile size if you intend to use tiling functionality.", level="WARNING")
        if self.tile_strategy is None:
            self.tile_strategy = "default"
            print_verbose("No tile strategy provided, using default.", level="WARNING")

        self.label = label
        self.labels_to_keep = labels_to_keep
        self.label_modifying_fn = label_modifying_fn
        self.label_type = label_type

        if not isinstance(self.resolution, (float, str)):
            try:
                self.resolution = float(self.resolution)
            except Exception:
                print_verbose(f"Failed auto-conversion of resolution argument. Expected str/float, but got {type(self.resolution)}")
                raise
        # Directory names cannot carry the decimal point: 0.5 -> "0_5mpp".
        self.resolution = f"{str(self.resolution).replace('.', '_')}mpp"

        if isinstance(modality, str):
            self.modality = get_modality_from_str(modality)
        else:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            if not is_valid_modality_instance(modality):
                raise TypeError(f"Invalid modality instance {type(modality)} provided.")
            self.modality = modality

        # check existence of tissue masks
        self.tissue_masks_dir: Optional[Path] = self.path / "segmentations" / self.resolution / "tissue_masks"
        if not self.tissue_masks_dir.exists():
            print_verbose(f"Tissue masks directory {self.tissue_masks_dir} does not exist. Tissue masks will not be loaded.", level="WARNING")
            self.tissue_masks_dir = None

        self.tissue_metadata = pd.read_parquet(self.path / "metadata" / "tissues.parquet").set_index("tissue_id")

        if self.split is not None:
            if "split" not in self.tissue_metadata.columns:
                raise ValueError(f"Split column not found in tissue metadata, but split argument {self.split} was provided.")
            self.tissue_metadata = self.tissue_metadata[self.tissue_metadata["split"] == self.split]
            if self.tissue_metadata.empty:
                raise ValueError(f"No tissue metadata found for split {self.split}. Please check the split argument and the contents of the tissue metadata.")

        if self.label is not None:
            # ``labels_to_keep=None`` means "keep everything"; ``isin(None)`` would raise.
            if self.labels_to_keep is not None:
                keep = self.tissue_metadata[self.label].isin(self.labels_to_keep)
                self.tissue_metadata = self.tissue_metadata[keep]
            if self.label_modifying_fn is None:
                print_verbose("Did not find label_modifying_fn, using identity.")
                # A named function (not a lambda) keeps the dataset picklable.
                self.label_modifying_fn = self._identity
            # Work on an owned copy so the column assignment below does not write
            # into a filtered view of the original frame.
            self.tissue_metadata = self.tissue_metadata.copy()
            self.tissue_metadata[self.label] = self.tissue_metadata[self.label].map(self.label_modifying_fn)
            if self.label_type == "classification":
                # ``Series.unique()`` returns an ndarray for NumPy-backed columns and an
                # ExtensionArray otherwise; ``np.asarray`` handles both.
                self.unique_labels = np.asarray(self.tissue_metadata[self.label].unique())
                self.label_encoder = {unique_label: i for (i, unique_label) in enumerate(self.unique_labels)}

        if load_cell_metadata:
            print_verbose("Loading cell metadata")
            self.cell_metadata = pd.read_parquet(self.path / "metadata" / "cells.parquet").set_index("tissue_id")

        self.tissue_modality_metadata = self.tissue_metadata[self.tissue_metadata["modality"] == self.modality.name]
        # Built with a comprehension so that an empty frame yields an empty dict.
        self.patient_tissue_map = {
            patient_id: group.index.tolist()
            for patient_id, group in self.tissue_modality_metadata.groupby("patient_id")
        }

    @staticmethod
    def _identity(value: Any) -> Any:
        """Return ``value`` unchanged; default for ``label_modifying_fn``."""
        return value

    def get_tissue_ids(self, kind="modality") -> NDArray[np.str_]:
        """Get the unique tissue IDs from the tissue annotations.

        Args:
            kind (str): The kind of tissue IDs to retrieve. Default is "modality", which
                returns tissue ids for tissues that have the specified modality.
                If "all", returns tissue ids for all tissues in the dataset.

        Returns:
            numpy.ndarray: An array of unique tissue IDs.

        Raises:
            ValueError: If ``kind`` is neither "modality" nor "all".
        """
        if kind == "modality":
            return self.tissue_modality_metadata.index.values
        if kind == "all":
            return self.tissue_metadata.index.values
        raise ValueError(f"Invalid kind {kind} provided. Expected 'modality' or 'all'.")

    @abstractmethod
    def get_tissue(self, tissue_id: str, kind="complete", preprocess: bool = True, image_mode: str = "CHW") -> Tissue:
        """Get a tissue image by tissue ID.

        Subclasses define the valid ``kind`` values and preprocessing behavior.
        """
        pass

    def get_tissue_mask(self, tissue_id: str) -> TissueMask:
        """Get the tissue mask for a given tissue id.

        If the mask file for this tissue does not exist, an all-True mask with the
        tissue's height and width is returned instead.

        Args:
            tissue_id (str): The tissue ID to retrieve the mask for.

        Returns:
            TissueMask: The tissue mask as a TissueMask instance.

        Raises:
            ValueError: If the tissue masks directory is not set (it did not exist
                when the dataset was constructed).
        """
        if self.tissue_masks_dir is None:
            raise ValueError("Tissue masks directory is not set.")
        mask_path = self.tissue_masks_dir / f"{tissue_id}.npz"
        if not mask_path.exists():
            print_verbose(f"Tissue mask file {mask_path} does not exist. Returning full mask.", level="WARNING")
            # Go through the per-instance cache so repeated misses do not re-query the size.
            _, height, width = self._get_cached_tissue_size(tissue_id)
            return TissueMask(mask=np.ones((height, width), dtype=np.bool_), tissue_id=tissue_id)
        # ``np.load`` on an .npz returns an NpzFile holding an open handle; close it.
        with np.load(mask_path) as archive:
            mask = archive["mask"]
        return TissueMask(mask=mask, tissue_id=tissue_id)

    @abstractmethod
    def _get_tissue_size(self, tissue_id: str) -> Tuple[int, int, int]:
        """Get the tissue size (C,H,W) for a given tissue id.

        Args:
            tissue_id (str): The tissue ID to retrieve the size for.

        Returns:
            Tuple[int, int, int]: The tissue size as a tuple (C, H, W).
        """
        pass

    def _get_cached_tissue_size(self, tissue_id: str) -> Tuple[int, int, int]:
        """Return ``_get_tissue_size(tissue_id)``, memoized per instance."""
        if tissue_id not in self._tissue_size_cache:
            self._tissue_size_cache[tissue_id] = self._get_tissue_size(tissue_id)
        return self._tissue_size_cache[tissue_id]

    def _channel_count(self, channel_index: Any, total_channels: int) -> int:
        """Return how many channels ``channel_index`` selects out of ``total_channels``.

        Handles slices, boolean masks, scalar indices and integer index sequences.
        """
        if isinstance(channel_index, slice):
            return len(range(*channel_index.indices(total_channels)))
        index_array = np.asarray(channel_index)
        if index_array.dtype == bool:
            return int(index_array.sum())
        if index_array.ndim == 0:
            return 1
        return int(index_array.size)

    def _load_padded_tile_chw(self, img: Any, row: int, col: int, channel_index: Any = slice(None)) -> torch.Tensor:
        """Load a ``tile_size`` x ``tile_size`` tile from a (C, H, W) image as a float tensor.

        With the "default" tile strategy the tile is sliced directly. With any other
        strategy, tiles that extend past the image border are zero-padded so the
        result always has spatial size ``tile_size`` x ``tile_size``.

        Args:
            img: Array-like indexed as ``img[channels, rows, cols]`` with a 3-element
                ``shape``; indexing must yield a NumPy array.
            row: Top row of the tile (may be negative or past the border when padding).
            col: Left column of the tile (same remark as ``row``).
            channel_index: Channel selector passed straight to ``img``'s indexing.

        Returns:
            torch.Tensor: The tile converted with ``.float()``.

        Raises:
            ValueError: If ``tile_size`` is None.
        """
        if self.tile_size is None:
            raise ValueError("tile_size must be set to load tiles.")
        size = self.tile_size

        if self.tile_strategy == "default":
            return torch.from_numpy(img[channel_index, row: row + size, col: col + size]).float()

        _, height, width = img.shape
        if 0 <= row and 0 <= col and row + size <= height and col + size <= width:
            # Fully inside the image: no padding needed.
            return torch.from_numpy(img[channel_index, row: row + size, col: col + size]).float()

        channel_count = self._channel_count(channel_index, int(img.shape[0]))
        tile = np.zeros((channel_count, size, size), dtype=img.dtype)

        # Intersection of the requested window with the image.
        src_row0 = max(row, 0)
        src_col0 = max(col, 0)
        src_row1 = min(row + size, height)
        src_col1 = min(col + size, width)
        if src_row0 >= src_row1 or src_col0 >= src_col1:
            # Window lies entirely outside the image.
            return torch.from_numpy(tile).float()

        loaded = img[channel_index, src_row0:src_row1, src_col0:src_col1]
        if loaded.ndim == 2:
            # A scalar channel index drops the channel axis; restore it so the
            # assignment into the (1, size, size) tile below is well-formed.
            loaded = loaded[np.newaxis, ...]

        dst_row0 = src_row0 - row
        dst_col0 = src_col0 - col
        tile[:, dst_row0: dst_row0 + loaded.shape[1], dst_col0: dst_col0 + loaded.shape[2]] = loaded
        return torch.from_numpy(tile).float()

    @abstractmethod
    def get_tile(self, tissue_id: str, tile_id: int, kind: str = "complete", image_mode: str = "CHW", preprocess: bool = True) -> Tissue:
        """Get a specific tile based on the tissue id and tile id.

        Tile IDs are precomputed during tiling. If tile coordinates are not available,
        this method should return a random tile from the tissue image. Subclasses that
        implement tiling should override this method to load the tile based on the
        precomputed tile coordinates.

        Args:
            tissue_id (str): The tissue ID to retrieve the tile for.
            tile_id (int): The tile ID to retrieve.
            kind (str): The kind of tile image to retrieve. Default is "complete".
            image_mode (str): The image mode of the tile image. Valid options are
                "CHW" and "HWC". Default is "CHW".
            preprocess (bool): Whether to preprocess the tile image (e.g. normalize)
                before returning it. Default is True.

        Returns:
            Tissue: The tile image as a Tissue instance.
        """

    @abstractmethod
    def get_tile_by_coordinates(self, tissue_id: str, row: int, col: int, kind: str = "complete", image_mode: str = "CHW", preprocess: bool = True) -> Tissue:
        """Get a specific tile based on the tissue id and tile coordinates.

        This method is used when tile coordinates are not precomputed and get_tile
        should return a random tile.

        Args:
            tissue_id (str): The tissue ID to retrieve the tile for.
            row (int): The row coordinate of the tile.
            col (int): The column coordinate of the tile.
            kind (str): The kind of tile image to retrieve. Default is "complete".
            image_mode (str): The image mode of the tile image. Valid options are
                "CHW" and "HWC". Default is "CHW".
            preprocess (bool): Whether to preprocess the tile image (e.g. normalize)
                before returning it. Default is True.

        Returns:
            Tissue: The tile image as a Tissue instance.
        """

    def get_cell_instance_mask(self, tissue_id: str, use_instances_from_virtues: bool = False) -> CellMask:
        """Get the cell instance mask for a given tissue id.

        Args:
            tissue_id (str): The tissue ID to retrieve the cell instance mask for.
            use_instances_from_virtues (bool): Whether to use instances generated using
                the VirTues foundation model.

        Returns:
            CellMask: The cell instance mask (int32 tensor) as a CellMask instance.

        Raises:
            ValueError: If the mask file does not exist.
        """
        instance_folder = "instances_virtues" if use_instances_from_virtues else "instances"
        ci_mask_path = self.path / "segmentations" / self.resolution / "cell_masks" / instance_folder / f"{tissue_id}.npz"
        if not ci_mask_path.exists():
            raise ValueError(f"Cell instance mask file {ci_mask_path} does not exist for tissue_id {tissue_id}.")
        # Close the NpzFile handle once the array has been read.
        with np.load(ci_mask_path) as archive:
            mask = torch.from_numpy(archive["mask"].astype(np.int32))
        return CellMask(mask=mask, tissue_id=tissue_id)

    def get_cell_task_mask(self, tissue_id: str, mask_type: str) -> CellMask:
        """Get the cell task mask for a given tissue id and mask type.

        The label encoder of each mask type is read once from
        ``cell_masks/<mask_type>/label_encoder.parquet`` and then kept on the instance
        as ``<mask_type>_label_encoder`` / ``<mask_type>_mapping``.

        Args:
            tissue_id (str): The tissue ID to retrieve the cell task mask for.
            mask_type (str): The type of cell task mask to retrieve. Valid options can
                be retrieved from `get_cell_task_mask_types` method.

        Returns:
            CellMask: The cell task mask as a CellMask instance, with ``mapping``
            translating encoded ids to names.

        Raises:
            ValueError: If the mask file does not exist.
        """
        mask_type_dir = self.path / "segmentations" / self.resolution / "cell_masks" / mask_type
        ct_mask_path = mask_type_dir / f"{tissue_id}.npz"
        if not ct_mask_path.exists():
            raise ValueError(f"Cell task mask file {ct_mask_path} does not exist for tissue_id {tissue_id} and mask_type {mask_type}.")

        if hasattr(self, f"{mask_type}_label_encoder"):
            mapping = getattr(self, f"{mask_type}_mapping")
        else:
            # label_encoder is df with columns name and id
            label_encoder = pd.read_parquet(mask_type_dir / "label_encoder.parquet")
            mapping = dict(zip(label_encoder["id"], label_encoder["name"]))
            setattr(self, f"{mask_type}_label_encoder", label_encoder)
            setattr(self, f"{mask_type}_mapping", mapping)

        # Close the NpzFile handle once the array has been read.
        with np.load(ct_mask_path) as archive:
            mask = torch.from_numpy(archive["mask"])
        return CellMask(mask=mask, tissue_id=tissue_id, mapping=mapping)

    def get_cell_task_mask_types(self) -> Sequence[str]:
        """Get the available cell task mask types for the dataset.

        Every sub-directory of ``cell_masks`` counts as a task mask type, except the
        instance-segmentation folders ("instances" and "instances_virtues").

        Returns:
            Sequence[str]: A sorted list of available cell task mask types.

        Raises:
            ValueError: If the ``cell_masks`` directory does not exist.
        """
        categories_dir = self.path / "segmentations" / self.resolution / "cell_masks"
        if not categories_dir.exists():
            raise ValueError(f"Categories directory {categories_dir} does not exist.")
        return sorted(
            entry.name
            for entry in categories_dir.iterdir()
            if entry.is_dir() and entry.name not in self._INSTANCE_MASK_FOLDERS
        )

    def get_tissue_by_patient(self, patient_id: str, kind="complete", preprocess: bool = True, image_mode: str = "CHW") -> Sequence[Tissue]:
        """Get all tissue images associated with a patient ID.

        The patient id is looked up as given first and as ``str(patient_id)`` second,
        so both string and non-string ``patient_id`` metadata columns are supported.
        Unknown patients yield an empty list.
        """
        tissue_ids = self.patient_tissue_map.get(patient_id)
        if tissue_ids is None:
            tissue_ids = self.patient_tissue_map.get(str(patient_id), [])
        return [
            self.get_tissue(tissue_id, kind=kind, preprocess=preprocess, image_mode=image_mode)
            for tissue_id in tissue_ids
        ]

    def _try_to_load_tile_coords(self):
        """Populate ``self.tile_coordinates`` from the precomputed parquet, if available.

        Sets ``self.tile_coordinates`` to a dict ``tissue_id -> [(row, col), ...]``
        (ordered by ``tile_id``), or to None when tiling is disabled or no coordinate
        file exists. Parsed files are shared through the module-level cache.

        Raises:
            ValueError: If the parquet lacks one of the required columns.
        """
        if self.tile_size is None:
            self.tile_coordinates = None
            if self.verbose:
                print_verbose("No tile size provided, skipping loading of tile coordinates.", level="WARNING")
            return

        tile_coords_path = self.path / "tiling" / self.resolution / self.tile_strategy / f"{self.tile_size}_tile_coordinates.parquet"
        if tile_coords_path in _TILE_COORDS_CACHE:
            self.tile_coordinates = _TILE_COORDS_CACHE[tile_coords_path]
            return

        if tile_coords_path.exists():
            coords_df = pd.read_parquet(tile_coords_path)
            required_columns = {"tissue_id", "tile_id", "row", "col"}
            if not required_columns.issubset(coords_df.columns):
                raise ValueError(
                    f"Tile coordinate parquet {tile_coords_path} is missing required columns "
                    f"{sorted(required_columns)}."
                )
            # Stable sort so list position within each tissue follows tile_id.
            coords_df = coords_df.sort_values(["tissue_id", "tile_id"], kind="stable")
            self.tile_coordinates = {
                tissue_id: list(zip(group["row"].astype(int), group["col"].astype(int), strict=False))
                for tissue_id, group in coords_df.groupby("tissue_id", sort=False)
            }
            _TILE_COORDS_CACHE[tile_coords_path] = self.tile_coordinates
            if self.verbose:
                print_verbose(f"Loaded tile coordinates from {tile_coords_path}")
        else:
            self.tile_coordinates = None
            if self.verbose:
                print_verbose(
                    f"No tile coordinates found at {tile_coords_path}. get_tile will return random tiles.",
                    level="WARNING",
                )

        tile_count = self._count_tiles()
        if self.verbose:
            print_verbose(f"Dataset {self.name} has {tile_count} tiles of size {self.tile_size} at resolution {self.resolution}.", level="DEBUG" if tile_count > 0 else "WARNING")

    def _count_tiles(self) -> int:
        """Count the number of tiles in the dataset.

        Returns 0 when tile coordinates are unavailable or were never loaded.
        """
        coordinates = getattr(self, "tile_coordinates", None)
        if coordinates is None:
            return 0
        return sum(len(coords) for coords in coordinates.values())

    def __repr__(self) -> str:
        """Summarize the dataset configuration and its tissue / tile counts."""
        modality = getattr(getattr(self, "modality", None), "name", None)
        n_tissues = len(self.tissue_modality_metadata) if hasattr(self, "tissue_modality_metadata") else "?"
        n_tiles = self._count_tiles()
        return (
            f"{self.__class__.__name__}("
            f"name={self.name!r}, modality={modality!r}, resolution={self.resolution!r}, "
            f"tile_size={self.tile_size!r}, tile_strategy={self.tile_strategy!r}, "
            f"split={self.split!r}, n_tissues={n_tissues}, n_tiles={n_tiles})"
        )