Source code for spora_io.utils.helpers.tile

import sys

import numpy as np
from dataclasses import dataclass
from typing import Optional
from tqdm import tqdm


@dataclass
class Tile:
    """Rectangular tile placed on a mask, plus its selection statistics."""

    # Top-left corner of the tile: row (y) and column (x).
    y: int
    x: int
    # Tile extent in pixels: height (h) and width (w).
    h: int
    w: int
    # Fraction of the tile's area made up of valid (non-zero) mask pixels.
    valid_ratio: float
    # Newly covered valid pixels when this tile was selected.
    gain: int
def _integral_image(arr: np.ndarray) -> np.ndarray: """Integral image with one zero-padded row/col at the top-left.""" arr = arr.astype(np.int64, copy=False) return np.pad(arr, ((1, 0), (1, 0)), mode="constant").cumsum(0).cumsum(1) def _rect_sums_vec(ii: np.ndarray, ys: np.ndarray, xs: np.ndarray, h: int, w: int) -> np.ndarray: """Vectorised rect-sum for arrays of (y, x) positions. Returns int64 array of shape (N,).""" y2 = ys + h x2 = xs + w return ii[y2, x2] - ii[ys, x2] - ii[y2, xs] + ii[ys, xs] def _candidate_starts(n: int, tile: int, stride: int) -> np.ndarray: """Candidate start positions along one dimension.""" if n <= tile: return np.array([0], dtype=np.int32) starts = np.arange(0, n - tile + 1, stride, dtype=np.int32) if starts[-1] != n - tile: starts = np.append(starts, n - tile) return starts def _pad_to_tile(mask: np.ndarray, tile: int) -> np.ndarray: h, w = mask.shape H, W = max(h, tile), max(w, tile) if H == h and W == w: return mask out = np.zeros((H, W), dtype=mask.dtype) out[:h, :w] = mask return out
def get_grid_tile(
    mask: np.ndarray,
    tile_size: int,
    stride: Optional[int] = None,
    tolerance: float = 0.85,
):
    """Return fixed-grid tiles, padding image edges with background.

    The mask is binarised (``> 0``), zero-padded at the bottom/right so that a
    regular grid of ``tile_size`` x ``tile_size`` tiles spaced by ``stride``
    fits exactly, and every grid tile that contains enough valid pixels is
    kept. Padding pixels count as invalid.

    Parameters
    ----------
    mask : np.ndarray
        2-D array; values > 0 are treated as valid.
    tile_size : int
        Side length of each square tile. Must be positive.
    stride : int or None
        Grid spacing. Defaults to ``tile_size`` (non-overlapping grid).
    tolerance : float
        Maximum fraction of invalid pixels allowed inside a tile, in [0, 1).
        A tile is kept when it has at least one valid pixel and its valid
        fraction is >= ``1 - tolerance``.

    Returns
    -------
    tiles : list[Tile]
        Kept tiles in row-major grid order. ``gain`` is the tile's own valid
        pixel count (tiles are not selected greedily here).
    stats : dict
        Summary counters; ``stop_reason`` is ``"empty_mask"`` or
        ``"fixed_grid_complete"``.
    covered_mask : np.ndarray
        uint8 array of the *padded* shape marking valid pixels covered by at
        least one kept tile.

    Raises
    ------
    ValueError
        If ``tile_size`` or ``stride`` is not positive, ``tolerance`` is
        outside [0, 1), or ``mask`` is not 2-D.
    """
    if tile_size <= 0:
        raise ValueError("tile_size must be positive.")
    if stride is None:
        stride = tile_size
    if stride <= 0:
        raise ValueError("stride must be positive.")
    if not (0.0 <= tolerance < 1.0):
        raise ValueError("tolerance must be in [0, 1).")

    original_mask = (np.asarray(mask) > 0).astype(np.uint8)
    if original_mask.ndim != 2:
        raise ValueError("mask must be a 2-D array.")
    h, w = original_mask.shape

    # Smallest padded extent >= the image extent on which the stride grid
    # ends exactly flush with the border.
    H = tile_size if h <= tile_size else int(np.ceil((h - tile_size) / stride) * stride + tile_size)
    W = tile_size if w <= tile_size else int(np.ceil((w - tile_size) / stride) * stride + tile_size)

    padded_mask = np.zeros((H, W), dtype=np.uint8)
    padded_mask[:h, :w] = original_mask

    ys = np.arange(0, H - tile_size + 1, stride, dtype=np.int32)
    xs = np.arange(0, W - tile_size + 1, stride, dtype=np.int32)
    ys_grid, xs_grid = np.meshgrid(ys, xs, indexing="ij")
    all_ys = ys_grid.ravel().astype(np.int32)
    all_xs = xs_grid.ravel().astype(np.int32)

    total_valid = int(original_mask.sum())
    tile_area = tile_size * tile_size
    covered_mask = np.zeros_like(padded_mask, dtype=np.uint8)

    if total_valid == 0:
        # Nothing to cover: report full coverage with zero tiles. The key set
        # matches the non-empty path so callers can index either result alike.
        return [], {
            "num_tiles": 0,
            "candidate_count": 0,
            "grid_candidate_count": int(len(all_ys)),
            "accepted_candidates": 0,
            "covered_valid_pixels": 0,
            "total_valid_pixels": 0,
            "coverage_ratio": 1.0,
            "total_tile_area": 0,
            "overlap_pixels": 0,
            "overlap_ratio": 0.0,
            "padded_height": int(H),
            "padded_width": int(W),
            "stop_reason": "empty_mask",
        }, covered_mask

    # Valid-pixel count of every grid tile in one vectorised pass.
    ii = _integral_image(padded_mask)
    valids = _rect_sums_vec(ii, all_ys, all_xs, tile_size, tile_size)
    ratios = valids / float(tile_area)
    keep = (valids > 0) & (ratios >= (1.0 - tolerance))

    selected_tiles = [
        Tile(
            y=int(y),
            x=int(x),
            h=int(tile_size),
            w=int(tile_size),
            valid_ratio=float(ratio),
            gain=int(valid),
        )
        for y, x, ratio, valid in zip(all_ys[keep], all_xs[keep], ratios[keep], valids[keep])
    ]

    # Mark the valid pixels that fall inside at least one kept tile.
    for tile in selected_tiles:
        rows = slice(tile.y, tile.y + tile_size)
        cols = slice(tile.x, tile.x + tile_size)
        covered_mask[rows, cols] |= padded_mask[rows, cols]

    covered_valid = int((covered_mask[:h, :w] & original_mask).sum())
    total_tile_area = len(selected_tiles) * tile_area
    # Tile area not accounted for by distinct covered valid pixels
    # (tile-on-tile overlap as well as invalid/padding pixels inside tiles).
    overlap_pixels = max(0, total_tile_area - covered_valid)

    stats = {
        "num_tiles": len(selected_tiles),
        "candidate_count": int(keep.sum()),
        "grid_candidate_count": int(len(all_ys)),
        "accepted_candidates": len(selected_tiles),
        "covered_valid_pixels": covered_valid,
        "total_valid_pixels": total_valid,
        "coverage_ratio": covered_valid / float(total_valid),
        "total_tile_area": total_tile_area,
        "overlap_pixels": overlap_pixels,
        "overlap_ratio": overlap_pixels / total_tile_area if total_tile_area > 0 else 0.0,
        "padded_height": int(H),
        "padded_width": int(W),
        "stop_reason": "fixed_grid_complete",
    }
    return selected_tiles, stats, covered_mask
def best_mask_tiling_try_to_stop(
    mask: np.ndarray,
    tile_size: int,
    stride: Optional[int] = None,
    tolerance: float = 0.2,
    coverage_goal: float = 0.99,
    min_gain_ratio: float = 0.05,
    max_tiles: Optional[int] = None,
    allow_overlap: bool = True,
    progress: bool = False,
    progress_desc: str = "Tiling",
):
    """
    Find a good tiling of the unmasked region with adaptive stopping.

    Tiles are chosen greedily: at each step the candidate covering the most
    still-uncovered valid pixels is selected.

    The adaptive stop fires only when BOTH of the following hold:

      1. covered_valid / total_valid >= coverage_goal
         (always considered true when ``coverage_goal == 1.0``)
      2. best_gain / tile_area < min_gain_ratio

    This makes the two parameters complementary:

    - ``coverage_goal=0.98, min_gain_ratio=0.05``
        Runs past 0.98 as long as tiles still contribute ≥5 % new pixels,
        potentially reaching near-full coverage for free.

    - ``coverage_goal=1.0, min_gain_ratio=0.05``
        Aims for full coverage but bails early once tiles become mostly
        redundant (< 5 % new pixels), avoiding useless overlap.

    With ``min_gain_ratio=0.0`` criterion 2 can never hold, so the adaptive
    stop is disabled and selection continues until no candidate adds new
    pixels, candidates run out, or ``max_tiles`` is reached.

    Parameters
    ----------
    mask : np.ndarray
        Binary mask of shape (H, W), with 1 = valid/unmasked, 0 = masked.
        Any value > 0 is treated as valid. A mask smaller than one tile is
        zero-padded at the bottom/right up to ``tile_size``.
    tile_size : int
        Tile size C, so each tile is C x C. Must be positive.
    stride : int or None
        Sliding stride. Defaults to tile_size (non-overlapping grid).
        Must be positive.
    tolerance : float
        Maximum fraction of invalid pixels allowed inside a tile (0 = strict).
        Range [0, 1).
    coverage_goal : float
        Soft lower bound on coverage, in (0, 1]. The adaptive stop will not
        fire below this value; the loop can still end below it when gains hit
        zero, candidates are exhausted, or ``max_tiles`` is reached.
    min_gain_ratio : float
        Soft bound on marginal efficiency — once the best remaining tile
        covers less than this fraction of its area in new pixels, AND
        coverage_goal has been reached, the loop stops. Range [0, 1).
        Default 0.05.
    max_tiles : int or None
        Hard cap on number of selected tiles (0 selects nothing).
    allow_overlap : bool
        If False, selected tiles cannot overlap each other.
    progress : bool
        Show a tqdm progress bar on stderr.
    progress_desc : str
        Label prefix on the progress bar.

    Returns
    -------
    tiles : list[Tile]
        Selected tiles in selection order; ``gain`` is the number of newly
        covered valid pixels at the time of selection.
    stats : dict
        Summary counters; ``stop_reason`` is one of ``"empty_mask"``,
        ``"candidates_exhausted"``, ``"no_gain"``, ``"adaptive_stop"`` or
        ``"max_tiles"``.
    covered_mask : np.ndarray
        uint8 array (shape of the possibly padded mask) marking valid pixels
        covered by at least one selected tile.

    Raises
    ------
    ValueError
        If ``tile_size`` or ``stride`` is not positive, ``max_tiles`` is
        negative, ``mask`` is not 2-D, or ``tolerance`` / ``coverage_goal`` /
        ``min_gain_ratio`` are outside their documented ranges.
    """
    if tile_size <= 0:
        raise ValueError("tile_size must be positive.")
    if stride is None:
        stride = tile_size
    if stride <= 0:
        raise ValueError("stride must be positive.")
    if not (0.0 <= tolerance < 1.0):
        raise ValueError("tolerance must be in [0, 1).")
    if not (0.0 < coverage_goal <= 1.0):
        raise ValueError("coverage_goal must be in (0, 1].")
    if not (0.0 <= min_gain_ratio < 1.0):
        raise ValueError("min_gain_ratio must be in [0, 1).")
    if max_tiles is not None and max_tiles < 0:
        raise ValueError("max_tiles must be non-negative or None.")

    mask = (np.asarray(mask) > 0).astype(np.uint8)
    if mask.ndim != 2:
        raise ValueError("mask must be a 2-D array.")
    mask = _pad_to_tile(mask, tile_size)
    H, W = mask.shape
    total_valid = int(mask.sum())
    tile_area = tile_size * tile_size

    # Absolute pixel threshold derived from min_gain_ratio.
    min_gain_px = int(np.ceil(min_gain_ratio * tile_area))

    if total_valid == 0:
        # Nothing to cover: report full coverage with zero tiles. The key set
        # matches the non-empty path so callers can index either result alike.
        return [], {
            "num_tiles": 0,
            "candidate_count": 0,
            "accepted_candidates": 0,
            "covered_valid_pixels": 0,
            "total_valid_pixels": 0,
            "coverage_ratio": 1.0,
            "total_tile_area": 0,
            "overlap_pixels": 0,
            "overlap_ratio": 0.0,
            "stop_reason": "empty_mask",
        }, np.zeros_like(mask, dtype=np.uint8)

    min_valid_ratio = 1.0 - tolerance

    # ------------------------------------------------------------------ #
    # Phase 1 – vectorised candidate filtering                           #
    # ------------------------------------------------------------------ #
    ys_starts = _candidate_starts(H, tile_size, stride)
    xs_starts = _candidate_starts(W, tile_size, stride)

    ys_grid, xs_grid = np.meshgrid(ys_starts, xs_starts, indexing="ij")
    all_ys = ys_grid.ravel().astype(np.int32)
    all_xs = xs_grid.ravel().astype(np.int32)

    ii_full = _integral_image(mask)
    all_valids = _rect_sums_vec(ii_full, all_ys, all_xs, tile_size, tile_size)
    all_ratios = all_valids / float(tile_area)

    keep = (all_valids > 0) & (all_ratios >= min_valid_ratio)
    cand_y = all_ys[keep]
    cand_x = all_xs[keep]
    cand_vr = all_ratios[keep].astype(np.float32)
    total_candidates = int(keep.sum())

    if progress:
        tqdm.write(
            f"{progress_desc} – Phase 1 done: "
            f"{total_candidates:,} / {len(all_ys):,} candidates pass tolerance filter",
            file=sys.stderr,
        )

    # ------------------------------------------------------------------ #
    # Phase 2 – greedy selection with incremental gain updates           #
    # ------------------------------------------------------------------ #
    uncovered = mask.copy()
    selected_tiles = []
    covered_valid = 0
    stop_reason = "candidates_exhausted"

    # Nothing is covered yet, so each candidate's initial gain equals its
    # valid-pixel count from Phase 1 (astype returns a fresh, writable array).
    gains = all_valids[keep].astype(np.int64)

    coverage_bar: Optional["tqdm"] = None
    if progress:
        coverage_bar = tqdm(
            total=total_valid,
            desc=f"{progress_desc}",
            unit="px",
            dynamic_ncols=True,
            file=sys.stderr,
            miniters=1,
            mininterval=0.1,
        )

    try:
        while len(cand_y):
            # Checked before selecting so that max_tiles=0 selects nothing.
            if max_tiles is not None and len(selected_tiles) >= max_tiles:
                stop_reason = "max_tiles"
                break

            best_idx = int(np.argmax(gains))
            best_gain = int(gains[best_idx])

            if best_gain <= 0:
                stop_reason = "no_gain"
                break

            # ---- adaptive dual stopping criterion ----
            # coverage_goal=1.0 means "aim for full coverage" — since exactly
            # 1.0 is nearly unreachable, treat it as no floor and let
            # min_gain_ratio alone govern stopping. Any other value acts as a
            # minimum floor: we won't stop until coverage has reached it.
            coverage_reached = (
                coverage_goal >= 1.0
                or (covered_valid / total_valid) >= coverage_goal
            )
            gain_too_low = best_gain < min_gain_px

            if coverage_reached and gain_too_low:
                stop_reason = "adaptive_stop"
                break

            sy, sx = int(cand_y[best_idx]), int(cand_x[best_idx])

            selected_tiles.append(
                Tile(
                    y=sy,
                    x=sx,
                    h=tile_size,
                    w=tile_size,
                    valid_ratio=float(cand_vr[best_idx]),
                    gain=best_gain,
                )
            )

            # Identify spatially overlapping candidates.
            remove = np.zeros(len(cand_y), dtype=bool)
            remove[best_idx] = True

            affected = (
                ~remove
                & (cand_y < sy + tile_size)
                & (cand_y + tile_size > sy)
                & (cand_x < sx + tile_size)
                & (cand_x + tile_size > sx)
            )

            if not allow_overlap:
                # Overlapping candidates are discarded outright, so the gains
                # of the survivors are unaffected by this selection.
                remove |= affected
            elif affected.any():
                # Local integral image over the selected patch — O(tile²).
                local_patch = uncovered[sy: sy + tile_size, sx: sx + tile_size]
                local_ii = _integral_image(local_patch)

                ay = cand_y[affected]
                ax = cand_x[affected]

                # Intersection of each affected candidate with the selected
                # tile, expressed in the patch's local coordinates.
                iy1 = np.maximum(ay, sy) - sy
                iy2 = np.minimum(ay + tile_size, sy + tile_size) - sy
                ix1 = np.maximum(ax, sx) - sx
                ix2 = np.minimum(ax + tile_size, sx + tile_size) - sx

                decrements = (
                    local_ii[iy2, ix2]
                    - local_ii[iy1, ix2]
                    - local_ii[iy2, ix1]
                    + local_ii[iy1, ix1]
                )
                gains[affected] -= decrements

            # Zero out uncovered AFTER computing decrements.
            uncovered[sy: sy + tile_size, sx: sx + tile_size] = 0

            keep_mask = ~remove
            cand_y = cand_y[keep_mask]
            cand_x = cand_x[keep_mask]
            cand_vr = cand_vr[keep_mask]
            gains = gains[keep_mask]

            # Incremental coverage (no uncovered.sum() scan needed).
            covered_valid += best_gain

            if coverage_bar is not None:
                bar_delta = covered_valid - coverage_bar.n
                if bar_delta > 0:
                    coverage_bar.update(bar_delta)
                coverage_bar.set_postfix(
                    tiles=len(selected_tiles),
                    cov=f"{covered_valid / total_valid:.1%}",
                    gain=f"{best_gain / tile_area:.1%}",
                    cands=len(cand_y),
                    refresh=True,
                )

            # Checked right after selection so the stop reason is "max_tiles"
            # even when the last permitted tile also exhausts the candidates.
            if max_tiles is not None and len(selected_tiles) >= max_tiles:
                stop_reason = "max_tiles"
                break

    finally:
        if coverage_bar is not None:
            coverage_bar.close()

    # A pixel is covered when it was valid in the mask and has been cleared
    # from `uncovered`.
    covered_mask = (mask > uncovered).astype(np.uint8)
    covered_valid_final = int(covered_mask.sum())

    total_tile_area = len(selected_tiles) * tile_area
    # Tile area not accounted for by distinct covered valid pixels.
    overlap_pixels = max(0, total_tile_area - covered_valid_final)
    overlap_ratio = overlap_pixels / total_tile_area if total_tile_area > 0 else 0.0

    stats = {
        "num_tiles": len(selected_tiles),
        "candidate_count": total_candidates,
        "accepted_candidates": len(selected_tiles),
        "covered_valid_pixels": covered_valid_final,
        "total_valid_pixels": total_valid,
        "coverage_ratio": covered_valid_final / float(total_valid),
        # overlap_ratio=0.0 means every placed tile pixel was a newly covered
        # valid pixel; 0.5 means half the placed tile area was not.
        "total_tile_area": total_tile_area,
        "overlap_pixels": overlap_pixels,
        "overlap_ratio": overlap_ratio,
        "stop_reason": stop_reason,
    }
    return selected_tiles, stats, covered_mask