from typing import List, Tuple, Union
import os
from loguru import logger
import random
import numpy as np
import torch
from safetensors import safe_open
from pathlib import Path
from spora_io._config import get_datasets_dir
from spora_io.datasets._types import MULTIPLEX_MODALITIES
def is_rank0() -> bool:
"""
Check if the current process runs in RANK 0
"""
return os.environ.get('RANK', '0') == '0'
def print_verbose(msg, level="INFO") -> None:
"""
Print a message if the current process is RANK 0
Args:
msg: The message to print
"""
if is_rank0():
if level == "INFO":
logger.info(msg)
elif level == "WARNING":
logger.warning(msg)
elif level == "ERROR":
logger.error(msg)
elif level == "DEBUG":
logger.debug(msg)
else:
logger.info(msg)
def set_seed(seed: int | None, omit_random: bool=False, omit_numpy: bool=False, omit_torch: bool=False, set_cudnn_deterministic: bool=True,
set_cudnn_benchmark: bool = False, use_deterministic_algorithms: bool = False) -> None:
"""
Configure the seed settings for reproducibility.
Args:
seed (int | None): The seed value to set. If None, no seed is set.
omit_random (bool): If True, do not set the seed for the random module.
omit_numpy (bool): If True, do not set the seed for numpy.
omit_torch (bool): If True, do not set the seed for torch.
set_cudnn_deterministic (bool): If True, sets torch.backends.cudnn.deterministic to True.
set_cudnn_benchmark (bool): If True, sets torch.backends.cudnn.benchmark to True.
use_deterministic_algorithms (bool): If True, enables the use of deterministic algorithms in PyTorch.
"""
if seed is None:
if is_rank0():
logger.warning("seed is None, no seed is set.")
return
if not omit_random:
random.seed(seed)
if not omit_numpy:
np.random.seed(seed)
if not omit_torch:
torch.manual_seed(seed) # type: ignore
torch.cuda.manual_seed_all(seed)
if set_cudnn_deterministic:
torch.backends.cudnn.deterministic = True
if set_cudnn_benchmark:
torch.backends.cudnn.benchmark = True
if use_deterministic_algorithms:
torch.use_deterministic_algorithms(True, warn_only=True)
def load_checkpoint_safetensors(path):
tensors = {}
with safe_open(path, framework="pt", device='cpu') as f:
for k in f.keys():
tensors[k] = f.get_tensor(k)
return tensors
[docs]
def get_modalities_of_dataset(dataset_name, base_path):
dataset_path = os.path.join(base_path, dataset_name)
if not os.path.exists(dataset_path):
raise ValueError(f"Dataset {dataset_name} does not exist in {base_path}")
possible_modalities = ['he', 'imc', 'cycif', 'mibi', 'mif', 'ihc', 'codex']
modalities = []
for modality in possible_modalities:
modality_path = os.path.join(dataset_path, modality)
if os.path.exists(modality_path):
modalities.append(modality)
return modalities
def get_all_datasets_of_modality(modality, dataset_dir: Path | str | None = None) -> List[str]:
if dataset_dir is None:
dataset_dir = get_datasets_dir()
else:
dataset_dir = Path(dataset_dir)
if modality == 'multiplex':
modalities_to_check = MULTIPLEX_MODALITIES
else:
modalities_to_check = [modality]
datasets = os.listdir(dataset_dir)
datasets_with_modality = []
for dataset in datasets:
ds_modalities = get_modalities_of_dataset(dataset, dataset_dir)
for modality in modalities_to_check:
if modality in ds_modalities:
datasets_with_modality.append(dataset)
return list(set(datasets_with_modality))