# Source code for scparadise.scadam

import warnings
warnings.filterwarnings("ignore")

import os 
import json 
import pandas as pd
import numpy as np
import anndata
import optuna
import copy
import fsspec
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
from typing import Dict, List, Optional, Tuple, Literal, Union
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from skimage.filters import threshold_otsu
from sklearn.mixture import GaussianMixture
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score
from sklearn.model_selection import StratifiedKFold
import optuna
from optuna.pruners import MedianPruner
# load scParadise dust module
import scparadise.dust as dust 

# Dataset Class
class scRNAseqDataset(Dataset):
    def __init__(self, adata, celltype_keys, layer):
        """
        Wrap an AnnData object with hierarchical annotations as a torch Dataset
        compatible with the scAdam model.

        Parameters
        ----------
        adata: AnnData
            Dataset with cell type annotations in adata.obs.
        celltype_keys: list
            Columns of adata.obs holding the annotation levels.
            Example: ['lineage', 'cell type', 'cell state']
        layer: str or None
            If given and present in adata.layers, expression values are taken from
            adata.layers[layer]; in every other case adata.X is used.

        Attributes
        ----------
        gn: number of genes in adata.var_names
        labels: encoded cell type labels (LongTensor) for each annotation level
        celltype_encoders: per level dict with the fitted encoder, number of classes and level index
        var_names: feature (gene) names
        obs_names: barcode (cell) names
        X: per-gene min-max scaled expression matrix (FloatTensor)

        Raises
        ------
        ValueError
            If one of celltype_keys is not a column of adata.obs.
        AssertionError
            If the scaled matrix contains NaN or Inf values.
        """
        self.adata = adata
        self.celltype_keys = celltype_keys
        self.layer = layer
        self.obs_names = adata.obs_names.tolist()
        self.var_names = adata.var_names.tolist()
        self.gn = len(self.var_names)
        self.labels = {}
        self.celltype_encoders = {}

        # Every requested annotation level must exist in adata.obs
        absent = [k for k in celltype_keys if k not in adata.obs.columns]
        if absent:
            raise ValueError(f"'{absent[0]}' not found in adata.obs")

        # Expression matrix: requested layer when available, adata.X otherwise
        use_layer = layer is not None and layer in adata.layers
        matrix = adata.layers[layer] if use_layer else adata.X
        if hasattr(matrix, 'toarray'):
            # Sparse input is densified
            matrix = matrix.toarray()

        # Negative values are clipped to zero before scaling
        # (the clip is only triggered when some per-gene minimum is non-zero)
        if matrix.min(0).any():
            matrix = np.maximum(matrix, 0)
        # Per-gene min-max scaling; the epsilon protects constant genes from division by zero
        scaled = (matrix - matrix.min(0)) / (np.ptp(matrix, axis=0) + 1e-10)

        self.X = torch.FloatTensor(scaled)
        assert not torch.isnan(self.X).any(), "NaN in adata data!"
        assert not torch.isinf(self.X).any(), "Inf in adata data!"

        # One label encoder per annotation level
        for level, key in enumerate(celltype_keys):
            encoder = LabelEncoder()
            codes = encoder.fit_transform(adata.obs[key].astype(str).values)
            self.celltype_encoders[key] = {
                'celltype_encoder': encoder,
                'n_classes': len(encoder.classes_),
                'level': level
            }
            self.labels[key] = torch.LongTensor(codes)

    def __len__(self):
        # Number of cells
        return self.X.shape[0]

    def __getitem__(self, idx):
        # Expression vector of the cell plus its label at every annotation level
        targets = {}
        for key in self.celltype_keys:
            targets[key] = self.labels[key][idx]
        return self.X[idx], targets


# Hierarchical Classification scAdam model
class scAdamClassifier(nn.Module):
    def __init__(self, input_dim, ncl, hd = 256, dropout = 0.2):
        """
        Hierarchical scAdam classifier: one MLP head per annotation level.

        Parameters
        ----------
        input_dim: int
            Size of the input feature vector.
        ncl: list of int
            Number of cell types at each annotation level (one entry per level).
        hd: int (default: 256)
            Width of the first hidden layer of every head; the second hidden layer has hd // 2 nodes.
        dropout: float (default: 0.2)
            Portion of neurons that are temporarily ignored during training (prevents overfitting).
        """
        super().__init__()
        self.num_levels = len(ncl)
        self.ncl = ncl

        # Head k receives the input features plus the class probabilities of all
        # previous levels, hence its input grows by sum(ncl[:k]).
        self.classifiers = nn.ModuleList()
        for level, n_out in enumerate(ncl):
            n_in = input_dim + sum(ncl[:level])
            self.classifiers.append(self._make_head(n_in, hd, n_out, dropout))

    @staticmethod
    def _make_head(n_in, hd, n_out, dropout):
        # Two hidden layers (Linear -> LayerNorm -> ReLU -> Dropout) followed by the output layer
        half = hd // 2
        layers = [
            nn.Linear(n_in, hd),
            nn.LayerNorm(hd),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hd, half),
            nn.LayerNorm(half),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half, n_out),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Run every level in order.

        Returns a dict keyed 'level_<idx>' whose values hold the 'logits' and
        softmax 'probs' of that level.
        """
        results = {}
        upstream = []  # probabilities produced by the levels already processed
        for level, head in enumerate(self.classifiers):
            features = torch.cat([x, *upstream], dim=1) if upstream else x
            logits = head(features)
            probs = F.softmax(logits, dim=-1)
            results[f'level_{level}'] = {'logits': logits, 'probs': probs}
            upstream.append(probs)

        return results


# scAdam transformer model
class scAdamTransformer(nn.Module):
    def __init__(
        self,
        gn,
        ncl,
        ed=256,
        nc=16,
        nb=5,
        nh=8,
        ff_hd=512,
        classifier_hd=256,
        dropout=0.3
    ):
        """
        Hierarchical transformer scAdam model for cell type annotation.
        Combines a gene embedding, a stack of transformer blocks and the scAdam classifier.

        Parameters
        ----------
        gn: int
            Number of genes in adata.var_names.
        ncl: list of int
            Number of cell types at each annotation level (one entry per level).
        ed: int (default: 256)
            Embedding dimensionality.
        nc: int (default: 16)
            Number of chunks for genes from adata.
        nb: int (default: 5)
            Number of blocks in scAdam model.
        nh: int (default: 8)
            Number of heads in scAdam model attention mechanism.
        ff_hd: int (default: 512)
            Number of nodes in each scAdam model layer in feed forward network.
        classifier_hd: int (default: 256)
            Number of nodes in the first hidden layer of each scAdam classifier head.
        dropout: float (default: 0.3)
            Portion of neurons that are temporarily ignored during training (prevents overfitting).

        """
        super().__init__()

        # Hyperparameters are kept on the instance so the model can be rebuilt from a checkpoint
        self.gn, self.ncl, self.num_levels = gn, ncl, len(ncl)
        self.ed, self.nc, self.nb, self.nh = ed, nc, nb, nh
        self.ff_hd, self.classifier_hd, self.dropout = ff_hd, classifier_hd, dropout

        # Gene embedding with chunking
        self.gene_embedding = dust.Embedding(gn, ed, nc)

        # Stack of nb transformer blocks
        self.blocks = nn.ModuleList(
            [dust.RibBlock(ed, nh, ff_hd, dropout) for _ in range(nb)]
        )

        # Hierarchical classifier working on the pooled embedding
        self.classifier = scAdamClassifier(ed, ncl, classifier_hd, dropout)

        # Layer norm applied after the last block
        self.norm = nn.LayerNorm(ed)

    def forward(self, x, return_attention=False):
        """
        Classify a batch of cells.

        Returns the classifier outputs; with return_attention=True returns a tuple
        (outputs, list with the attention returned by every block).
        """
        # Gene embedding
        # NOTE(review): presumably shaped (batch, chunks, ed) - confirm against dust.Embedding
        hidden = self.gene_embedding(x)

        # Transformer blocks; each block returns (hidden states, attention)
        attention_weights = []
        for block in self.blocks:
            hidden, attn = block(hidden)
            if return_attention:
                attention_weights.append(attn)

        # Normalize, then average over dim=1 (global average pooling over chunks)
        pooled = self.norm(hidden).mean(dim=1)

        # Hierarchical classification
        outputs = self.classifier(pooled)

        return (outputs, attention_weights) if return_attention else outputs
        

# Adaptive Hierarchical Focal Loss
class HierarchicalLoss(nn.Module):
    def __init__(
        self,
        num_levels,
        celltype_keys,
        alpha=0.25,
        gamma=2.0,
        adaptive=True
    ):
        """
        Adaptive hierarchical focal loss for scAdam model training.

        Parameters
        ----------
        num_levels: int
            Number of annotation levels.
        celltype_keys: list
            Annotation level names; celltype_keys[i] is the key of the targets
            that belong to the predictions stored under 'level_<i>'.
        alpha: float (default: 0.25)
            Constant scaling factor of the focal loss.
        gamma: float (default: 2.0)
            Focusing exponent applied to (1 - pt).
        adaptive: bool (default: True)
            If True, the per-level losses are combined with weights that are
            re-estimated from a running average of the level losses.

        """
        super().__init__()
        self.num_levels = num_levels
        self.celltype_keys = celltype_keys
        self.alpha = alpha
        self.gamma = gamma
        self.adaptive = adaptive

        # Adaptive weights and running (EMA) losses for each level
        self.register_buffer('level_weights', torch.ones(num_levels))
        self.register_buffer('level_losses', torch.zeros(num_levels))
        self.update_counter = 0

    def focal_loss(self, logits, targets, weight=None):
        """
        Focal loss averaged over the batch.

        Parameters
        ----------
        logits: tensor with raw class scores, one row per sample.
        targets: tensor with the class index of every sample.
        weight: optional tensor with one weight per class.
        """
        # pt has to be the predicted probability of the true class, so it is derived
        # from the UNWEIGHTED cross entropy. Deriving it from the class-weighted
        # cross entropy would yield pt ** weight and distort the focusing term.
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        if weight is not None:
            # Same per-sample factor F.cross_entropy(weight=...) applies with reduction='none'
            ce_loss = ce_loss * weight[targets]
        focal = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal.mean()

    def forward(self, predictions, targets, class_weights=None):
        """
        Combine the focal losses of all levels.

        Parameters
        ----------
        predictions: dict keyed 'level_<i>' with a 'logits' entry per level.
        targets: dict keyed by the names in celltype_keys.
        class_weights: optional dict keyed 'level_<i>' with per-class weight tensors.

        Returns
        -------
        (total_loss, level_losses) where level_losses is a dict keyed 'level_<i>'.
        """
        total_loss = 0.0
        level_losses = {}

        for level_idx in range(self.num_levels):
            level_key = f'level_{level_idx}'
            target = targets[self.celltype_keys[level_idx]]
            logits = predictions[level_key]['logits']

            # Class weights of this level, if any were supplied
            weight = class_weights.get(level_key) if class_weights else None

            loss = self.focal_loss(logits, target, weight)
            level_losses[level_key] = loss

            # Weighted sum over levels
            if self.adaptive:
                total_loss += self.level_weights[level_idx] * loss
            else:
                total_loss += loss

        # Adaptive weights are only tracked in training mode
        if self.adaptive and self.training:
            self.update_adaptive_weights(level_losses)

        return total_loss, level_losses

    def update_adaptive_weights(self, level_losses):
        """Update the running level losses and, periodically, the level weights."""
        self.update_counter += 1

        # Exponential moving average of the losses
        for level_idx in range(self.num_levels):
            current = level_losses[f'level_{level_idx}'].detach()
            self.level_losses[level_idx] = 0.9 * self.level_losses[level_idx] + 0.1 * current

        # Weights are refreshed on every 10th call, i.e. every 10th forward pass in
        # training mode (not every 10th epoch).
        if self.update_counter % 10 == 0:
            normalized_losses = self.level_losses / (self.level_losses.sum() + 1e-8)
            # Deliberately re-assigned instead of updated in place: the old weights are
            # part of the autograd graph of the loss that was just computed.
            self.level_weights = 1.0 / (normalized_losses + 1e-8)


# Unsupervised pretraining, analogous to masked language modeling in NLP
class Pretraining(nn.Module):
    def __init__(self, encoder, gn, ed, prob = 0.15):
        """
        Masked gene expression modeling for unsupervised pretraining.
        Analogous to masked language modeling in NLP.

        Parameters
        ----------
        encoder: model
            scAdam model; its gene_embedding, blocks and norm are used as the encoder.
        gn: int
            Number of genes in adata.var_names.
        ed: int
            Embedding dimensionality.
        prob: float (default: 0.15)
            Probability of gene masking.

        """
        super().__init__()
        self.encoder = encoder
        self.gn = gn
        self.ed = ed
        self.prob = prob

        # Decoder: expands the pooled embedding to 2 * ed and maps it back to ed
        wide = ed * 2
        self.decoder = nn.Sequential(
            nn.Linear(ed, wide),
            nn.GELU(),
            nn.LayerNorm(wide),
            nn.Dropout(0.1),
            nn.Linear(wide, ed),
            nn.GELU(),
            nn.LayerNorm(ed),
            nn.Dropout(0.1)
        )

        # Projection head: one reconstructed value per gene
        self.gene_projection = nn.Linear(ed, gn)

    def create_mask(self, x):
        """Return (copy of x with masked entries set to 0, boolean mask); x itself is untouched."""
        n_cells, n_genes = x.shape
        # Every entry is masked independently with probability self.prob
        mask = torch.rand(n_cells, n_genes, device=x.device) < self.prob
        return x.masked_fill(mask, 0.0), mask

    def forward(self, x):
        """Return (reconstructed expression, mask, original input x)."""
        x_masked, mask = self.create_mask(x)

        # Encode the masked input with the scAdam encoder
        hidden = self.encoder.gene_embedding(x_masked)
        for block in self.encoder.blocks:
            hidden, _ = block(hidden)

        # Normalize and pool over dim=1
        pooled = self.encoder.norm(hidden).mean(dim=1)

        # Decode and reconstruct gene expression
        reconstructed = self.gene_projection(self.decoder(pooled))

        return reconstructed, mask, x


# Function for unsupervised pretraining
def pretrain_unsupervised(
    adata,
    model,
    layer = None,
    batch_size = 128,
    epochs = 50,
    lr = 1e-4,
    prob = 0.15,
    device = 'auto',
    random_state = 42,
    verbose = True
):
    """
    Unsupervised pretraining using masked gene expression modeling.

    Parameters
    ----------
    adata: AnnData object
        Could be AnnData without cell type annotations.
    model: scAdam model
        Model whose gene embedding, blocks and final norm are pretrained.
    layer: str (default: None)
        If specified and present in adata.layers, use adata.layers[layer] for
        expression values instead of adata.X.
    batch_size: int (default: 128)
        Number of samples per batch during training.
    epochs: int (default: 50)
        Number of training epochs.
    lr: float (default: 1e-4)
        Controls how much a model's parameters are updated during training.
    prob: float (default: 0.15)
        Probability of gene masking.
    device: str (default: 'auto')
        Type of device to use for model training ('cpu' or 'cuda'). Set to 'auto' for automatic selection.
    random_state: int (default: 42)
        Seed for random number generators to ensure reproducibility.
    verbose: bool (default: True)
        Display a progress bar; the mean loss of the last finished epoch is shown next to it.

    Returns
    -------
    Pretrained scAdam model (the same object that was passed in, moved to the device).
    NOTE(review): the model appears to be returned in training mode, because
    pretraining_model.train() is not reverted - confirm callers switch to eval().

    """

    # Fix random state for reproducibility
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(random_state)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Set device
    if device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Move model to device
    model = model.to(device)

    # Wrap the model with the masking logic, decoder and projection head
    pretraining_model = Pretraining(
        encoder=model,
        gn=model.gn,
        ed=model.gene_embedding.norm.normalized_shape[0],
        prob=prob
    ).to(device)

    # Get gene expression data from adata from layer or X
    if layer is not None and layer in adata.layers:
        X = adata.layers[layer]
    else:
        X = adata.X
    X = X.toarray() if hasattr(X, 'toarray') else X

    # Clip negative values (only triggered when some per-gene minimum is non-zero),
    # then min-max scale every gene; the epsilon protects constant genes.
    if X.min(0).any():
        X = np.maximum(X, 0)
    X_norm = (X - X.min(0)) / (np.ptp(X, axis=0) + 1e-10)
    X_tensor = torch.FloatTensor(X_norm)

    # Create DataLoader
    dataset = torch.utils.data.TensorDataset(X_tensor)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Optimizer (covers the encoder as well as the decoder / projection head)
    optimizer = torch.optim.AdamW(
        pretraining_model.parameters(),
        lr=lr,
        weight_decay=1e-5
    )

    # MSE loss function
    criterion = nn.MSELoss()

    # Training loop
    pretraining_model.train()
    progress = tqdm(
        range(epochs),
        desc='Unsupervised pretraining',
        colour='blue',
        disable = not verbose
    )
    for _ in progress:
        total_loss = 0.0

        for (batch_x, ) in loader:
            batch_x = batch_x.to(device)

            # Forward
            reconstructed, mask, original = pretraining_model(batch_x)

            # Compute loss only using masked genes
            loss = criterion(reconstructed[mask], original[mask])

            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(pretraining_model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()

        # Report the mean batch loss of the epoch instead of silently discarding it
        avg_loss = total_loss / len(loader)
        progress.set_postfix(loss=f"{avg_loss:.4f}")

    return model


# Function to get weighted metric across all annotation levels
def weighted_metric(
    history,
    celltype_keys,
    metric,
    strategy = 'linear_offset'
):
    """
    Calculates a weighted metric for the quality of model training across all levels of the hierarchy.

    Parameters
    ----------
    history: dict
        Learning history; the last value of history['val_metrics'][key][metric]
        is used for every annotation level.
    celltype_keys: list
        List of cell type annotations in adata.obs.
        Example: ['lineage', 'cell type', 'cell state']
    metric: str
        Name of the metric stored in the history,
        e.g. 'accuracy', 'balanced_accuracy', 'f1_score'.
    strategy: str (default: 'linear_offset')
        Weighting strategy for different cell type annotation levels.
        The following weighting strategies are available: linear, exponential, linear_offset, equal.
        linear: linear increase in weight from level to level.
        exponential: exponential increase in weight from level to level
        linear_offset: linear increase in weight from level to level with offset.
        equal: equal weight for all cell type annotation levels.

    Returns
    -------
        Weighted mean of the per-level scores.

    Raises
    ------
    ValueError
        If strategy is not one of the available strategies.

    """
    n_levels = len(celltype_keys)

    # Most recent validation score of every annotation level
    per_level = history['val_metrics']
    scores = [per_level[key][metric][-1] for key in celltype_keys]

    # Raw (unnormalized) weights for every strategy
    schemes = {
        'linear': lambda n: list(range(1, n + 1)),
        'exponential': lambda n: [2**i for i in range(n)],
        'linear_offset': lambda n: [1.0 + (i / 2) for i in range(n)],
        'equal': lambda n: [1.0 for i in range(n)],
    }
    if strategy not in schemes:
        raise ValueError(f"Unknown weighting strategy: {strategy}")
    raw_weights = schemes[strategy](n_levels)

    # Normalize weights so that they sum to one
    norm = sum(raw_weights)
    shares = [w / norm for w in raw_weights]

    # Weighted mean
    return sum(s * w for s, w in zip(scores, shares))


def celltype_level_weights(num_levels, strategy = 'exponential'):
    """
    Returns normalized weights (summing to one) for multiple cell type annotation levels.

    Parameters
    ----------
    num_levels: int
        Number of cell type annotation levels
    strategy: str (default: 'exponential')
        Weighting strategy for different cell type annotation levels.
        The following weighting strategies are available: linear, exponential, linear_offset, equal.
        linear: linear increase in weight from level to level.
        exponential: exponential increase in weight from level to level
        linear_offset: linear increase in weight from level to level with offset.
        equal: equal weight for all cell type annotation levels.

    Raises
    ------
    ValueError
        If strategy is not one of the available strategies.

    Examples
    --------
        celltype_level_weights(3, 'linear')
        [0.167, 0.333, 0.500]  # 1/6, 2/6, 3/6

        celltype_level_weights(3, 'exponential')
        [0.143, 0.286, 0.571]  # 1/7, 2/7, 4/7

        celltype_level_weights(3, 'linear_offset')
        [0.222, 0.333, 0.444] # (1 + 0/2)/4.5, (1 + 1/2)/4.5, (1 + 2/2)/4.5

        celltype_level_weights(3, 'equal')
        [0.333, 0.333, 0.333] # 1/3, 1/3, 1/3
    """
    levels = range(num_levels)

    # Raw weight of level i for every supported strategy
    raw_weight = {
        'linear': lambda i: i + 1,
        'exponential': lambda i: 2**i,
        'linear_offset': lambda i: 1.0 + (i / 2),
        'equal': lambda i: 1.0,
    }
    if strategy not in raw_weight:
        raise ValueError(f"Unknown weighting strategy: {strategy}")

    weights = [raw_weight[strategy](i) for i in levels]
    norm = sum(weights)
    return [w / norm for w in weights]


# Save scAdam model function 
def save_model(model, path, model_name, verbose=True):
    """
    Function for scAdam model saving.
    Saves model in a folder 'model_name' inside 'path' (created if necessary).

    Parameters
    ----------
    model: scAdam model
        Must provide state_dict() and the architecture attributes gn, ncl, ed, nc,
        nb, nh, ff_hd, classifier_hd and dropout. The optional attributes
        celltype_keys, celltype_encoders, var_names and history are stored when
        present (None otherwise).
    path: str
        Parent directory.
    model_name: str
        Name of the folder created inside path.
    verbose: bool (default: True)
        Print the location of the saved model.

    Files written
    -------------
    model_v2.pth: weights, architecture hyperparameters and metadata.
    unknown_detector.json: class centroids and stds of model.unknown_detector,
        written only if the model has a (non-None) unknown_detector.
    """
    save_dict = {
        'state_dict': model.state_dict(),
        'gn': model.gn,
        'ncl': model.ncl,
        'ed': model.ed,
        'nc': model.nc,
        'nb': model.nb,
        'nh': model.nh,
        'ff_hd': model.ff_hd,
        'classifier_hd': model.classifier_hd,
        'dropout': model.dropout,
        'celltype_keys': getattr(model, 'celltype_keys', None),
        'celltype_encoders': getattr(model, 'celltype_encoders', None),
        'var_names': getattr(model, 'var_names', None),
        'history': getattr(model, 'history', None)
    }

    # One directory path for creation and for every file written into it
    model_dir = os.path.join(path, model_name)
    os.makedirs(model_dir, exist_ok = True)

    # Save scAdam model
    torch.save(save_dict, os.path.join(model_dir, 'model_v2.pth'))

    # Save unknown cell type detector
    detector = getattr(model, 'unknown_detector', None)
    if detector is not None:
        detector_state = {'class_centroids': None, 'class_stds': None}

        # JSON needs string keys and plain Python numbers / lists.
        # Centroids may be tensors, numpy arrays or already plain lists.
        if detector.class_centroids is not None:
            detector_state['class_centroids'] = {
                str(k): (
                    v.detach().cpu().numpy().tolist() if torch.is_tensor(v)
                    else v.tolist() if hasattr(v, 'tolist')
                    else v
                )
                for k, v in detector.class_centroids.items()
            }
        if detector.class_stds is not None:
            detector_state['class_stds'] = {
                str(k): float(v.item()) if torch.is_tensor(v) else float(v)
                for k, v in detector.class_stds.items()
            }

        # Save unknown cell type detector state
        with open(os.path.join(model_dir, 'unknown_detector.json'), 'w') as f:
            json.dump(detector_state, f, indent=2)

    if verbose:
        print(f"Model saved to {os.path.join(path, model_name)}")


# Load scAdam model function
def load_model(path, device = 'auto', verbose=True):
    """
    Function for scAdam model loading.
    Loads model from a folder 'path'.

    Parameters
    ----------
    path: str
        Folder that contains 'model_v2.pth' and, optionally, 'unknown_detector.json'.
    device: str (default: 'auto')
        Device the model is loaded to ('cpu' or 'cuda'). Set to 'auto' for automatic selection.
    verbose: bool (default: True)
        Print what has been loaded.

    Returns
    -------
    scAdam model in eval mode. If a detector file exists, the restored detector
    is attached as model.unknown_detector.
    """
    # Set device
    if device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load scAdam models checkpoint
    # NOTE: weights_only=False unpickles arbitrary Python objects (the checkpoint also
    # carries non-tensor metadata) - only load checkpoints from trusted sources.
    checkpoint = torch.load(
        os.path.join(path, 'model_v2.pth'),
        map_location=device,
        weights_only = False
    )

    # Rebuild the architecture from the stored hyperparameters
    arch_keys = ('gn', 'ncl', 'ed', 'nc', 'nb', 'nh', 'ff_hd', 'classifier_hd', 'dropout')
    model = scAdamTransformer(**{name: checkpoint[name] for name in arch_keys})

    # Load state to scAdam model
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    model.eval()

    # Restore metadata
    for name in ('celltype_keys', 'celltype_encoders', 'var_names', 'history'):
        if name in checkpoint:
            setattr(model, name, checkpoint[name])

    # Without a detector file the plain model is returned
    detector_path = os.path.join(path, 'unknown_detector.json')
    if not os.path.exists(detector_path):
        if verbose:
            print(f"scAdam model without unknown detector loaded from {path}")
        return model

    with open(detector_path, 'r') as f:
        detector_state = json.load(f)

    # Recreate the unknown cell type detector from the stored state;
    # JSON keys are strings, so they are converted back to integer class ids.
    detector = UnknownCellDetector()

    centroids = detector_state.get('class_centroids')
    if centroids is not None:
        detector.class_centroids = {}
        for label, values in centroids.items():
            detector.class_centroids[int(label)] = torch.tensor(values, dtype=torch.float32)

    stds = detector_state.get('class_stds')
    if stds is not None:
        detector.class_stds = {}
        for label, value in stds.items():
            detector.class_stds[int(label)] = value

    model.unknown_detector = detector
    if verbose:
        print(f"scAdam model with unknown detector loaded from {path}")

    return model


# Function for calculation of otsu threshold 
def otsu_threshold(data, threshold = 2.0, nbins = 256, random_state = 1):
    """
    Analyzes the distribution and returns the Otsu threshold only if the data is bimodal.

    Parameters
    ----------
    data : array-like
        An array of values (eg. gene expression).
    threshold: float (default = 2.0)
        Minimal Ashman's D required to treat the two peaks as separated.
        D > 2 indicates good peak separation; for a more subtle separation
        a threshold of 1.5 can be used.
    nbins: int (default = 256)
        Number of histogram bins passed to threshold_otsu.
    random_state: int (default = 1)
        Seed of the Gaussian mixture models.

    Returns
    -------
    float or None
        Threshold (float) if the distribution is bimodal.
        None if the distribution is considered unimodal.
    """
    values = np.array(data).reshape(-1, 1)

    # Fit Gaussian mixtures with one and with two components
    one_peak = GaussianMixture(n_components=1, random_state=random_state).fit(values)
    two_peaks = GaussianMixture(n_components=2, random_state=random_state).fit(values)

    # The single-component model wins on BIC -> unimodal
    if one_peak.bic(values) < two_peaks.bic(values):
        return None

    # Peak separability (Ashman's D); useful in case of a skewed distribution.
    # Components are ordered by their mean.
    means = two_peaks.means_.flatten()
    covs = two_peaks.covariances_.flatten()
    order = np.argsort(means)
    mu_low, mu_high = means[order]
    sd_low, sd_high = np.sqrt(covs[order])

    # D = sqrt(2) * |mu1 - mu2| / sqrt(sigma1^2 + sigma2^2)
    spread = np.sqrt(sd_low**2 + sd_high**2)
    ashman_d = np.sqrt(2) * abs(mu_low - mu_high) / spread

    # Peaks too close: most likely one "wide" distribution
    if ashman_d < threshold:
        return None

    # Both checks passed -> Otsu threshold on the raw values
    return threshold_otsu(values.flatten(), nbins=nbins)
    
class UnknownCellDetector:
    def __init__(
        self
    ):
        """
        UnknownCellDetector for detection cells with unknown for scAdam model cell type in data.
        
        """

        self.class_centroids = None
        self.class_stds = None

    def fit(
        self, 
        model, 
        dataloader, 
        device='cuda', 
        verbose=True
    ):
        """
        Fitting of UnknownCellDetector.

        Parameters
        ----------
        model: scAdam model.
        dataloader: full training dataset dataloader.
        device: str (default: 'auto')
            Type of device to use in training model ('cpu', 'cuda'). Set 'auto' for automatic selection.
        verbose: bool (default: True)
            Show UnknownCellDetector progress or not. 
        
        """
        # Create class_centroids and class_stds dictionaries for distance calculation.
        self.class_centroids = {}
        self.class_stds = {}
        
        model.eval()
        model = model.to(device)

        if verbose:
            print("Fitting unknown cells detector")

        all_embeddings = []
        all_labels = {key: [] for key in model.celltype_keys}

        last_level = f'level_{model.num_levels - 1}'

        with torch.enable_grad():
            for batch_x, batch_y in dataloader:
                batch_x = batch_x.to(device)

                # Embeddings for each batch
                with torch.no_grad():
                    x_emb = model.gene_embedding(batch_x)
                    for block in model.blocks:
                        x_emb, _ = block(x_emb)
                    x_emb = model.norm(x_emb).mean(dim=1)
                    all_embeddings.append(x_emb.cpu())

                # Labels
                for key in model.celltype_keys:
                    all_labels[key].append(batch_y[key])

        # Concatenate
        all_embeddings = torch.cat(all_embeddings, dim=0)

        # Thresholds
        last_key = model.celltype_keys[-1]
        labels = torch.cat(all_labels[last_key], dim=0)
        unique_labels = labels.unique()

        for label in unique_labels:
            label_idx = (labels == label)
            class_emb = all_embeddings[label_idx]
            centroid = class_emb.mean(dim=0)
            std = class_emb.std(dim=0).mean()
            self.class_centroids[label.item()] = centroid
            self.class_stds[label.item()] = std

    def detect(
        self, 
        model, 
        dataloader, 
        threshold = 2.0, 
        temperature = 2.0, 
        mc_passes = 5, 
        method = 'voting', 
        device = 'cuda', 
        verbose = True, 
        random_state = 1
    ):
        """
        Detect cells with unknown cell type.

        temperature: float (default: 2.0)
            A parameter that controls the confidence of the model in its predictions.
            If temperature > 1 - model appears less confident.
            If temperature < 1 - model appears more confident.
        
        """
        self.temperature = temperature
        self.mc_passes = mc_passes
        
        model.eval()
        model = model.to(device)

        all_gradient_norms = []
        all_entropy_scores = []
        all_dist_scores = []
        all_predictions = []

        last_level = f'level_{model.num_levels - 1}'

        with torch.enable_grad():
            for batch in dataloader:
                if isinstance(batch, (list, tuple)):
                    batch_x = batch[0]
                else:
                    batch_x = batch
                batch_x = batch_x.to(device)

                # Embeddings
                with torch.no_grad():
                    x_emb = model.gene_embedding(batch_x)
                    for block in model.blocks:
                        x_emb, _ = block(x_emb)
                    x_emb = model.norm(x_emb).mean(dim=1)

                # Classifier forward
                with torch.no_grad():
                    outputs_eval = model.classifier(x_emb)
                    logits_eval = outputs_eval[last_level]['logits']
                    probs_eval = F.softmax(logits_eval / self.temperature, dim=-1)
                    preds = logits_eval.argmax(dim=1)
                    all_predictions.append(preds.detach().cpu())

                # Gradient norm (with MC-dropout)
                if method == 'gradient' or method == 'voting' or method == 'combined':
                    x_emb_grad = x_emb.clone().detach().requires_grad_(True)

                    model_train_state = model.training
                    if self.mc_passes is not None and self.mc_passes > 1:
                        model.train()
                        loss_total = 0.0
                        for _ in range(self.mc_passes):
                            outputs = model.classifier(x_emb_grad)
                            logits = outputs[last_level]['logits']
                            probs = F.softmax(logits / self.temperature, dim=-1)
                            max_prob, _ = probs.max(dim=1)
                            loss = -torch.log(max_prob + 1e-10)
                            loss_total = loss_total + loss.sum()
                        loss_total.backward()
                    else:
                        model.eval()
                        outputs = model.classifier(x_emb_grad)
                        logits = outputs[last_level]['logits']
                        probs = F.softmax(logits / self.temperature, dim=-1)
                        max_prob, _ = probs.max(dim=1)
                        loss = -torch.log(max_prob + 1e-10)
                        loss.sum().backward()

                    if model_train_state:
                        model.train()
                    else:
                        model.eval()

                    grad_norms = torch.norm(x_emb_grad.grad, dim=1)
                    all_gradient_norms.append(grad_norms.detach().cpu())
                    x_emb_grad.grad.zero_()

                # Entropy (temperature + MC-dropout)
                if method == 'entropy' or method == 'voting' or method == 'combined':
                    with torch.no_grad():
                        probs_mc = []
                        if self.mc_passes is not None and self.mc_passes > 1:
                            model_train_state = model.training
                            model.train()
                            for _ in range(self.mc_passes):
                                outputs = model(batch_x)
                                logits = outputs[last_level]['logits']
                                probs = F.softmax(logits / self.temperature, dim=-1)
                                probs_mc.append(probs)
                            if model_train_state:
                                model.train()
                            else:
                                model.eval()
                            probs_mean = torch.stack(probs_mc, dim=0).mean(dim=0)
                        else:
                            logits = logits_eval
                            probs_mean = probs_eval

                        entropy = -(probs_mean * torch.log(probs_mean + 1e-10)).sum(dim=1)
                        all_entropy_scores.append(entropy.cpu())

                # Distance
                if method == 'distance' or method == 'voting' or method == 'combined':
                    if self.class_centroids:
                        min_dist = []
                        for i in range(x_emb.size(0)):
                            pred_label = preds[i].item()
                            if pred_label in self.class_centroids:
                                centroid = self.class_centroids[pred_label].to(device)
                                dist = torch.norm(x_emb[i] - centroid)
                                min_dist.append(dist.detach().cpu())
                            else:
                                min_dist.append(torch.tensor(float('inf')))
                        all_dist_scores.append(torch.stack(min_dist))

        # Concatenate results + convert to numpy
        all_predictions = torch.cat(all_predictions).numpy()
        all_gradient_norms = torch.cat(all_gradient_norms).numpy() if all_gradient_norms else None
        all_entropy_scores = torch.cat(all_entropy_scores).numpy() if all_entropy_scores else None
        all_dist_scores = torch.cat(all_dist_scores).numpy() if all_dist_scores else None

        # Calculating thresholds using Otsu method
        if all_gradient_norms is not None:
            self.gradient_threshold = otsu_threshold(all_gradient_norms, threshold = threshold, nbins=len(all_gradient_norms), random_state=random_state)
            print(f'Gradient threshold calculated using Otsu method: {self.gradient_threshold}')
        if all_entropy_scores is not None:
            self.entropy_threshold = otsu_threshold(all_entropy_scores, threshold = threshold, nbins=len(all_entropy_scores), random_state=random_state)
            print(f'Entropy threshold calculated using Otsu method: {self.entropy_threshold}')
        if all_dist_scores is not None:
            self.distance_threshold = otsu_threshold(all_dist_scores, threshold = threshold, nbins=len(all_dist_scores), random_state=random_state)
            print(f'Distance threshold calculated using Otsu method: {self.distance_threshold}')

        # Check thresholds
        if (method == 'entropy' or method == 'voting' or method == 'combined') and self.entropy_threshold is None:
            unknown_mask = None
            return unknown_mask, {
                'gradient': all_gradient_norms,
                'entropy': all_entropy_scores,
                'distance': all_dist_scores,
                'predictions': all_predictions
            }
        elif (method == 'gradient' or method == 'voting' or method == 'combined') and self.gradient_threshold is None:
            unknown_mask = None
            return unknown_mask, {
                'gradient': all_gradient_norms,
                'entropy': all_entropy_scores,
                'distance': all_dist_scores,
                'predictions': all_predictions
            }
        elif (method == 'distance' or method == 'voting' or method == 'combined') and self.distance_threshold is None:
            unknown_mask = None
            return unknown_mask, {
                'gradient': all_gradient_norms,
                'entropy': all_entropy_scores,
                'distance': all_dist_scores,
                'predictions': all_predictions
            }

        
        # Method selection
        if method == 'gradient':
            unknown_mask = all_gradient_norms > self.gradient_threshold

        elif method == 'entropy':
            unknown_mask = all_entropy_scores > self.entropy_threshold

        elif method == 'distance':
            unknown_mask = all_dist_scores > self.distance_threshold

        elif (method == 'voting') or (method == 'combined'):
            votes = np.zeros(len(all_predictions), dtype=int)
            num_methods = 0

            if all_gradient_norms is not None:
                votes += (all_gradient_norms > self.gradient_threshold).astype(int)
                num_methods += 1

            if all_entropy_scores is not None:
                votes += (all_entropy_scores > self.entropy_threshold).astype(int)
                num_methods += 1

            if all_dist_scores is not None:
                votes += (all_dist_scores > self.distance_threshold).astype(int)
                num_methods += 1

            if method == 'voting':
                if num_methods >= 2:
                    threshold_votes = 2
                else:
                    threshold_votes = 1
                unknown_mask = votes >= threshold_votes

            else:  # method == 'combined'
                g_denom = (np.max(all_gradient_norms) + 1e-12) if all_gradient_norms is not None else 1.0
                e_denom = (np.max(all_entropy_scores) + 1e-12) if all_entropy_scores is not None else 1.0
                d_denom = (np.max(all_dist_scores) + 1e-12) if all_dist_scores is not None else 1.0

                all_gradient_norms_scaled = all_gradient_norms / g_denom
                all_entropy_scores_scaled = all_entropy_scores / e_denom
                all_dist_scores_scaled = all_dist_scores / d_denom

                gradient_threshold = self.gradient_threshold / g_denom
                entropy_threshold = self.entropy_threshold / e_denom
                distance_threshold = self.distance_threshold / d_denom

                sum_of_thresholds = distance_threshold + entropy_threshold + gradient_threshold
                sum_of_scores = all_gradient_norms_scaled + all_entropy_scores_scaled + all_dist_scores_scaled
                unknown_mask = sum_of_scores >= sum_of_thresholds

        if verbose and (method == 'voting'):
            print(f"Unknown cell type detection using {num_methods} methods. Threshold: {threshold_votes} votes.")
        elif verbose and (method == 'combined'):
            print(f"Unknown cell type detection using combination of methods.")
        elif verbose:
            print(f"Unknown cell type detection using {method}.")

        if verbose:
            print(f"Detected {unknown_mask.sum()} unknown cells ({100*unknown_mask.mean():.1f}%)")

        return unknown_mask, {
            'gradient': all_gradient_norms,
            'entropy': all_entropy_scores,
            'distance': all_dist_scores,
            'predictions': all_predictions
        }


# Function for training scAdam model
def train(
    adata,
    celltype_keys,
    layer = None,
    path = '',
    model_name = 'scAdam_model',
    test_size = 0.2,
    eval_metric = ['accuracy', 'balanced_accuracy'],
    strategy = 'linear_offset',
    batch_size = 128,
    epochs = 200,
    patience = 10,
    nc = 16,
    nb = 5,
    nh = 8,
    ff_hd = 512,
    ed_nh_ratio = 32,
    classifier_hd = 256,
    dropout = 0.3,
    lr = 1e-4,
    weight_decay = 1e-4,
    use_augmentation = True,
    aug_probability = 0.5,
    prob = 0.15,
    noise_std = 0.1,
    dropout_aug = 0.1,
    alpha = 0.2,
    from_unsupervised = True,
    pretrain_epochs = 50,
    pretrain_data = None,
    adaptive_loss = True,
    device = 'auto',
    random_state = 0,
    return_model = False,
    unknown_detection = True,
    verbose = True
):
    """
    scAdam model training with unsupervised pretraining.
    scAdam is a model for automatic cell type annotation.
    This function creates a model that can be used for cell type annotation.

    Parameters
    ----------
    adata: AnnData
        Dataset with cell type annotations in adata.obs
    celltype_keys: list
        List of cell type annotations in adata.obs.
        Example: ['lineage', 'cell type', 'cell state']
    layer: str (default: None)
        If specified, use adata.layers[layer] for expression values instead of adata.X.
    path: str, path object
        Path to create a model folder containing the training history,
        cell annotation dictionary, and genes used for training.
    model_name: str (default: 'scAdam_model')
        Name of a folder to save model.
    test_size: float or int, (default: 0.2)
        If float, should be between 0.0 and 1.0 and represent the proportion of the dataset
        to include in the test split. If int, represents the absolute number of test cells.
    eval_metric: str or list (default: ['accuracy', 'balanced_accuracy'])
        Available evaluation metrics: 'accuracy', 'balanced_accuracy', 'f1_score'.
        The last metric is used as the target and for early stopping.
    strategy: str (default: 'linear_offset')
        Weighting strategy for different cell type annotation levels.
        The following weighting strategies are available:
        linear, exponential, linear_offset, equal, last.
        linear: linear increase in weight from level to level.
        exponential: exponential increase in weight from level to level
        linear_offset: linear increase in weight from level to level with offset.
        equal: equal weight for all cell type annotation levels.
        last: uses only last cell type annotation for model evaluation.
    batch_size: int, (default: 128)
        Number of examples per batch.
    epochs: int (default: 200)
        Maximum number of epochs for scAdam model training
    patience: int (default: 10)
        Number of consecutive epochs without improvement before performing early stopping.
        If patience is set to 0, then no early stopping will be performed.
        Note that if patience is enabled, then best weights from best epoch
        will automatically be loaded at the end of the training.
    nc: int (default: 16)
        Number of chunks for genes from adata.
    nb: int (default: 5)
        Number of blocks in scAdam model.
    nh: int (default: 8)
        Number of heads in scAdam model attention mechanism.
    ff_hd: int (default: 512)
        Number of nodes in each scAdam model layer in feed forward network.
    ed_nh_ratio: int (default: 32)
        Used for calculating embedding dimensionality ('ed') from 'nh'.
        Default ed = nh * ed_nh_ratio = 8 * 32 = 256.
    classifier_hd: int (default: 256)
        Number of nodes in each scAdam classifier.
    dropout: float (default: 0.3)
        Portion of neurons that temporarily ignored during training (prevents overfitting).
    lr: float (default: 1e-4)
        Determines the step size at each iteration while moving toward a minimum of a loss function.
    weight_decay: float (default: 1e-4)
        Weight decay coefficient.
    use_augmentation: bool (default: True)
        Use data augmentation or not.
    aug_probability: float (default: 0.5)
        The probability of applying augmentation to a batch
    prob: float (default: 0.15)
        Gene masking probability.
    noise_std: float (default: 0.1)
        Gaussian noise standard deviation.
    dropout_aug: float (default: 0.1)
        Dropout probability for simulating technical noise.
    alpha: float (default: 0.2)
        Alpha parameter for mixup augmentation.
    from_unsupervised: bool (default: True)
        Use a previously self supervised model as starting weights.
        Supervised model training included in function.
    pretrain_epochs: int (default: 50)
        Number of pretraining epochs.
    pretrain_data: AnnData (default: None)
        Anndata with the same expression matrix for unsupervised pretraining.
        Additional Anndata may not contain cell type annotations.
        If None, 'adata' itself is used for pretraining.
    adaptive_loss: bool (default: True)
        If True, enables adaptive weighting of hierarchical loss levels:
        each level's loss is tracked over training and its contribution to the total loss
        is increased if this level remains hard (high loss) and decreased if it becomes easy (low loss).
        This helps the model focus more on poorly performing levels of the hierarchy
        instead of weighting all levels equally.
        If False, all levels are summed with a fixed weight of 1.0.
    device: str (default: 'auto')
        Type of device to use in training model ('cpu', 'cuda').
        Set 'auto' for automatic selection.
    random_state: int (default: 0)
        Controls the data shuffling, splitting to folds and model training.
        Pass an int for reproducible output across multiple function calls.
    return_model: bool (default: False)
        Return model after training or not.
    unknown_detection: bool (default: True)
        Train unknown cell detector - identifies unknown cells
        when model used for prediction on a new data.
    verbose: bool (default: True)
        Show progress bar for each epoch during training.

    Returns
    -------
    Saves scAdam model for cell type annotation.
    Returns the trained model only if return_model = True.

    Raises
    ------
    ValueError
        If 'ed' = nh * ed_nh_ratio is not divisible by 'nh',
        or if an unknown metric name is passed in 'eval_metric'.
    """
    # Check 'ed' - 'nh' compatibility
    ed = nh * ed_nh_ratio
    if ed % nh != 0:
        raise ValueError(f"Incompatible parameters: 'ed' must be divisible by 'nh' without a remainder.")

    # Get evaluation metric (always work with a private list)
    if isinstance(eval_metric, str):
        eval_metric = [eval_metric]
    else:
        eval_metric = list(eval_metric)

    # Available metric functions
    available_metric_functions = {
        'accuracy': accuracy_score,
        'balanced_accuracy': balanced_accuracy_score,
        'f1_score': lambda y_true, y_pred: f1_score(y_true, y_pred, average='weighted', zero_division=0)
    }

    # Create dictionary of used for training metric functions.
    # Validated up front so that a typo fails before the (slow) pretraining step.
    metric_functions = {}
    for metric in eval_metric:
        if metric not in available_metric_functions:
            raise ValueError(f"Unknown metric: {metric}")
        metric_functions[metric] = available_metric_functions[metric]

    # Set random state (for reproducibility)
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(random_state)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Device selection
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device=='auto' else device

    if verbose:
        print(f"Device: {device}")

    # Level weights are only needed for the verbose report below
    level_weights = None
    if verbose:
        if strategy != 'last':
            level_weights = celltype_level_weights(len(celltype_keys), strategy)
        else:
            # 'last' evaluates only the final annotation level
            level_weights = [0.0] * (len(celltype_keys) - 1) + [1.0]

    # Create dataset
    dataset = scRNAseqDataset(
        adata,
        celltype_keys,
        layer=layer
    )
    ncl = [dataset.celltype_encoders[key]['n_classes'] for key in celltype_keys]

    if verbose:
        print(f"Number of features: {dataset.gn}")
        print(f"Label hierarchy: {' → '.join(celltype_keys)}")
        print(f"Annotation levels weights using strategy '{strategy}':")
        for key, n_celltypes, weight in zip(celltype_keys, ncl, level_weights):
            print(f" {key}: {n_celltypes} cell types, {round(weight, 3)} relative weight")

    # Train/val split
    indices = np.arange(len(dataset))
    train_idx, val_idx = train_test_split(
        indices,
        test_size=test_size,
        stratify=adata.obs[celltype_keys[-1]],  # Using last (most detailed) annotation level
        random_state=random_state
    )

    if verbose:
        print(f"\nDataset split:")
        print(f'Train dataset contains: {len(train_idx)} cells, it is {round(100*(len(train_idx)/(len(train_idx) + len(val_idx))), ndigits=2)} % of input dataset')
        print(f'Validation dataset contains: {len(val_idx)} cells, it is {round(100*(len(val_idx)/(len(train_idx) + len(val_idx))), ndigits=2)} % of input dataset')

    # Create dataloaders
    train_loader = DataLoader(
        Subset(dataset, train_idx),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=False
    )
    val_loader = DataLoader(
        Subset(dataset, val_idx),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        drop_last=False
    )

    # Initialize model
    model = scAdamTransformer(
        gn=dataset.gn,
        ncl=ncl,
        ed=ed,
        nc=nc,
        nb=nb,
        nh=nh,
        ff_hd=ff_hd,
        classifier_hd=classifier_hd,
        dropout=dropout
    ).to(device)

    # Unsupervised pretraining
    if from_unsupervised:
        pretrain_adata = pretrain_data if pretrain_data is not None else adata
        model = pretrain_unsupervised(
            pretrain_adata,
            model,
            layer=layer,
            batch_size=batch_size,
            epochs=pretrain_epochs,
            lr=lr,
            prob=prob,
            device=device,
            random_state=random_state,
            verbose=verbose
        )

    # Augmentation
    augmenter = None
    if use_augmentation:
        augmenter = dust.Augmentation(
            prob = prob,
            noise_std = noise_std,
            dropout_prob = dropout_aug,
            alpha = alpha
        )

    # Loss function
    criterion = HierarchicalLoss(
        num_levels=len(celltype_keys),
        celltype_keys=celltype_keys,
        alpha=0.25,
        gamma=2.0,
        adaptive=adaptive_loss
    ).to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=weight_decay
    )

    # Learning rate scheduler (cosine decay over the maximum number of epochs)
    def lr_lambda(epoch):
        return 0.5 * (1.0 + np.cos(np.pi * (epoch / epochs)))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Early stopping
    early_stopping = dust.EarlyStopping(
        patience=patience,
        delta=1e-4,
        mode='max',
        verbose=verbose
    )

    # Create training history dictionary
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_metrics': {key: {metric: [] for metric in eval_metric} for key in celltype_keys},
        'val_metrics': {key: {metric: [] for metric in eval_metric} for key in celltype_keys},
        'lr': []
    }

    # Training Loop
    for epoch in tqdm(range(epochs), desc='Training scAdam model', colour='blue', disable = not verbose):
        # Model training
        model.train()
        train_loss = 0.0
        train_preds = {key: [] for key in celltype_keys}
        train_targets = {key: [] for key in celltype_keys}

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = {key: val.to(device) for key, val in batch_y.items()}

            # Augmentation only if random < aug_probability
            if use_augmentation and np.random.random() < aug_probability:
                batch_x = augmenter(batch_x)

            # Get predictions using model
            predictions = model(batch_x)

            # Calculate model loss
            loss, level_losses = criterion(predictions, batch_y)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()

            # Predictions and targets
            with torch.no_grad():
                for level_idx, key in enumerate(celltype_keys):
                    level_key = f'level_{level_idx}'
                    preds = predictions[level_key]['logits'].argmax(dim=1)
                    train_preds[key].append(preds.cpu())
                    train_targets[key].append(batch_y[key].cpu())

        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)

        # Compute training metrics
        for key in celltype_keys:
            train_preds[key] = torch.cat(train_preds[key]).numpy()
            train_targets[key] = torch.cat(train_targets[key]).numpy()
            for metric_name, metric_func in metric_functions.items():
                score = metric_func(train_targets[key], train_preds[key])
                history['train_metrics'][key][metric_name].append(score)

        # Model validation
        model.eval()
        val_loss = 0.0
        val_preds = {key: [] for key in celltype_keys}
        val_targets = {key: [] for key in celltype_keys}

        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = {key: val.to(device) for key, val in batch_y.items()}

                # Get predictions using model
                predictions = model(batch_x)

                # Calculate model loss
                loss, level_losses = criterion(predictions, batch_y)
                val_loss += loss.item()

                # Store predictions and targets
                for level_idx, key in enumerate(celltype_keys):
                    level_key = f'level_{level_idx}'
                    preds = predictions[level_key]['logits'].argmax(dim=1)
                    val_preds[key].append(preds.cpu())
                    val_targets[key].append(batch_y[key].cpu())

        val_loss /= len(val_loader)
        history['val_loss'].append(val_loss)

        # Validation metrics
        for key in celltype_keys:
            val_preds[key] = torch.cat(val_preds[key]).numpy()
            val_targets[key] = torch.cat(val_targets[key]).numpy()
            for metric_name, metric_func in metric_functions.items():
                score = metric_func(val_targets[key], val_preds[key])
                history['val_metrics'][key][metric_name].append(score)

        # Learning rate
        history["lr"].append(optimizer.param_groups[0]["lr"])

        # Early Stopping (the last metric in eval_metric is the target)
        if strategy == 'last':
            es_score = history['val_metrics'][celltype_keys[-1]][eval_metric[-1]][-1]
        else:
            es_score = weighted_metric(
                history=history,
                celltype_keys=celltype_keys,
                metric=eval_metric[-1],
                strategy=strategy
            )

        if early_stopping(es_score, model):
            break

        scheduler.step()

    # Load best model
    early_stopping.load_bm(model)
    model.eval()

    # Store metadata
    model.celltype_keys = celltype_keys
    model.celltype_encoders = dataset.celltype_encoders
    model.var_names = dataset.var_names
    model.history = history

    if verbose:
        print("Training completed!")

    # Remove loaders
    del train_loader, val_loader

    # Train unknown cell type detector (independent of verbosity)
    if unknown_detection:
        # Create detector
        detector = UnknownCellDetector()

        # Create full dataset loader for detector training
        full_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=False
        )

        # Fit detector on full dataset
        detector.fit(model, full_loader, device=device)

        # Save detector
        model.unknown_detector = detector
        if verbose:
            print("UnknownCellDetector fitted successfully!")

    save_model(model, path, model_name, verbose)

    if return_model:
        return model
# Predict cell type annotation using trained scAdam model
def _align_genes_to_model(X, new_var_names, model_var_names):
    """Build a (cells x model genes) float32 matrix in model gene order; unmatched genes stay zero."""
    # First occurrence wins, matching list.index() semantics for duplicated names
    first_position = {}
    for position, name in enumerate(new_var_names):
        first_position.setdefault(name, position)

    X_aligned = np.zeros((X.shape[0], len(model_var_names)), dtype=np.float32)
    matched_features = 0
    for column, gene in enumerate(model_var_names):
        source = first_position.get(gene)
        if source is not None:
            X_aligned[:, column] = X[:, source]
            matched_features += 1
    return X_aligned, matched_features


def predict(
    adata,
    path_model,
    layer = None,
    batch_size = 256,
    device = 'auto',
    prefix = 'pred_',
    detect_unknown = False,
    method = 'voting',
    threshold = 2.0,
    temperature = 2.0,
    mc_passes = 5,
    verbose = True
):
    """
    Predict cell types and cell type probabilities using pretrained scAdam model.
    It is also possible to check if a cell has an unknown cell type
    for the scAdam model (detect_unknown = True).

    If the model folder contains 'model_v2.pth' the transformer-based model is used,
    otherwise the folder is treated as a TabNet-based v1 model
    ('genes.csv', 'dict.txt' and a '.zip' archive).

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix.
    path_model: str, path object
        Path to the folder containing the trained scAdam model.
    layer: str (default: None)
        If specified, use adata.layers[layer] for expression values instead of adata.X.
        If the layer is not present in adata.layers, adata.X is used.
    batch_size: int, (default: 256)
        Number of examples per batch.
    device: str (default: 'auto')
        Type of device to use in prediction ('cpu', 'cuda').
        Set 'auto' for automatic selection.
    prefix: str (default: 'pred_')
        Prefix for new columns in adata.obs
    detect_unknown: bool (default: False)
        Detection of unknown cell types.
        Parameters 'method', 'temperature', 'mc_passes', 'threshold'
        are not used if 'detect_unknown' = False.
    method: str (default: 'voting')
        Method of unknown cell type detection.
        gradient - uses the norm of the gradient with respect to the last annotation level
        embedding as an uncertainty score.
        entropy - uses the entropy of the softmax probabilities on the last annotation level.
        distance - uses the distance in embedding space to the centroid of the predicted class.
        voting - combines the three previous criteria with majority voting.
        combined - builds one continuous combined score from all three individual methods
        with combined threshold.
    threshold: float (default = 2.0)
        Ashman coefficient threshold.
        Threshold > 2 indicates good peak separation.
        For a more subtle separation, threshold > 1.5 can be used.
    temperature: float (default: 2.0)
        A parameter that controls the confidence of the model in its predictions.
        If temperature > 1 - model appears less confident.
        If temperature < 1 - model appears more confident.
    mc_passes: int (default: 5)
        It is the number of forward passes you run with dropout left on
        for the same data (Monte Carlo dropout).
        Necessary for calculating the gradient score.
    verbose: bool (default: True)
        Show progress bar for each batch during prediction.

    Returns
    -------
    adata: AnnData
        Annotated adata object with predicted cell types in adata.obs.
        if detect_unknown = True, adds methods scores in adata.obs.

    Raises
    ------
    ValueError
        If 'method' is not a supported detection method, if the loaded model lacks
        'var_names' / 'celltype_keys', or if detect_unknown = True but the model
        has no fitted unknown cell detector.
    """
    # Check method of unknown cell type detection
    if method not in ['voting', 'gradient', 'entropy', 'distance', 'combined']:
        raise ValueError(f"Unknown method: {method}")

    # Load scAdam model
    if os.path.exists(os.path.join(path_model, 'model_v2.pth')):
        model = load_model(path_model, device = 'auto', verbose=verbose)

        # Device selection
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device=='auto' else device
        model = model.to(device)
        model.eval()

        if not hasattr(model, 'var_names') or model.var_names is None:
            raise ValueError("Model does not have var_names attribute. Train the model first.")
        if not hasattr(model, 'celltype_keys') or model.celltype_keys is None:
            raise ValueError("Model does not have celltype_keys attribute. Train the model first.")
        # Fail early with a clear message instead of an AttributeError after prediction
        if detect_unknown and getattr(model, 'unknown_detector', None) is None:
            raise ValueError("Model does not have unknown_detector attribute. Train the model with unknown_detection = True.")

        # Get gene expression data
        if layer is not None and layer in adata.layers:
            X = adata.layers[layer]
        else:
            X = adata.X

        # Convert sparse to dense if needed
        if hasattr(X, 'toarray'):
            X = X.toarray()

        # Get feature names from new data
        new_var_names = adata.var_names.tolist()
        model_var_names = model.var_names

        # Align features: create matrix with genes in the same order as training
        n_model_genes = len(model_var_names)
        X_aligned, matched_features = _align_genes_to_model(X, new_var_names, model_var_names)

        if verbose:
            print(f"Gene alignment:")
            print(f" Model features: {n_model_genes}")
            print(f" Matched features: {matched_features} ({100*matched_features/n_model_genes:.1f}%)")

        # Raise warning if matched features is lower than 80% from used for training
        if matched_features < 0.80 * n_model_genes:
            warnings.warn(
                f"Only {matched_features}/{n_model_genes} genes matched! "
                "This may lead to poor predictions. "
                "Make sure the gene names are in the same format (e.g., gene symbols)."
            )

        # Normalize the same way as during training
        X_aligned = np.maximum(X_aligned, 0)  # Remove negative values
        X_norm = (X_aligned - X_aligned.min(0)) / (np.ptp(X_aligned, axis=0) + 1e-10)

        # Check for NaN/Inf
        if np.isnan(X_norm).any() or np.isinf(X_norm).any():
            warnings.warn("NaN or Inf found in normalized data. Replacing with zeros.")
            X_norm = np.nan_to_num(X_norm, nan=0.0, posinf=0.0, neginf=0.0)

        X_tensor = torch.FloatTensor(X_norm)

        # Create dataloader
        dataset = torch.utils.data.TensorDataset(X_tensor)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        # Predict
        predictions = {key: [] for key in model.celltype_keys}
        probabilities = {key: [] for key in model.celltype_keys}

        with torch.no_grad():
            for (batch_x,) in tqdm(loader, desc='Predicting', colour='blue', disable=not verbose):
                batch_x = batch_x.to(device)
                outputs = model(batch_x)
                for level_idx, key in enumerate(model.celltype_keys):
                    level_key = f'level_{level_idx}'
                    preds = outputs[level_key]['logits'].argmax(dim=1).cpu().numpy()
                    predictions[key].append(preds)
                    probs = outputs[level_key]['probs'].cpu().numpy()
                    probabilities[key].append(probs)

        # Concatenate predictions
        for key in model.celltype_keys:
            predictions[key] = np.concatenate(predictions[key])
            # Decode labels back to original names
            encoder = model.celltype_encoders[key]['celltype_encoder']
            predictions[key] = encoder.inverse_transform(predictions[key])
            probabilities[key] = np.concatenate(probabilities[key])

        # Unknown cell type detection
        unknown_mask = None
        if detect_unknown:
            unknown_mask, scores = model.unknown_detector.detect(
                model,
                loader,
                threshold = threshold,
                temperature = temperature,
                mc_passes = mc_passes,
                device = device,
                method = method,
                verbose = verbose
            )

            # Add mask
            if unknown_mask is not None:
                adata.obs[f'{prefix}unknown'] = unknown_mask
            else:
                print('No unknown cells found!')

            # Add predictions to adata.obs
            for i, key in enumerate(model.celltype_keys, start=1):
                col_name = f"{prefix}celltype_l{i}"

                # Use mask to set 'Unknown' cell types.
                # Object dtype avoids truncating "Unknown" in fixed-width string arrays.
                predictions_with_unknown = np.asarray(predictions[key]).astype(object)
                if unknown_mask is not None:
                    predictions_with_unknown[unknown_mask] = "Unknown"

                # Add to adata.obs new column
                adata.obs[col_name] = predictions_with_unknown.astype(str)
                if verbose:
                    print(f"Added cell type column: {col_name}")

                # Add probabilities
                max_probs = probabilities[key].max(axis=1)
                prob_col_name = f"{prefix}celltype_l{i}_probability"
                adata.obs[prob_col_name] = max_probs
                if verbose:
                    print(f"Added probabilities column: {prob_col_name}")

            # Add scores to adata.obs
            if scores['gradient'] is not None:
                adata.obs['gradient_score'] = scores['gradient']
            if scores['entropy'] is not None:
                adata.obs['entropy_score'] = scores['entropy']
            if scores['distance'] is not None:
                adata.obs['distance_score'] = scores['distance']

        else:
            # Without unknown cell detection
            for i, key in enumerate(model.celltype_keys, start=1):
                col_name = f"{prefix}celltype_l{i}"
                adata.obs[col_name] = predictions[key]
                if verbose:
                    print(f"Added cell type column: {col_name}")

                # Add probabilities
                max_probs = probabilities[key].max(axis=1)
                prob_col_name = f"{prefix}celltype_l{i}_probability"
                adata.obs[prob_col_name] = max_probs
                if verbose:
                    print(f"Added probabilities column: {prob_col_name}")

        return adata

    # Tabnet-based v1 model
    else:
        # Get gene expression data
        if layer is not None and layer in adata.layers:
            X = adata.layers[layer]
        else:
            X = adata.X

        # Convert sparse to dense if needed
        if hasattr(X, 'toarray'):
            X = X.toarray()

        # Get gene names from new data
        new_var_names = adata.var_names.tolist()

        # Load genes of trained model
        genes_table = pd.read_csv(os.path.join(path_model, 'genes.csv').replace("\\","/"))
        model_var_names = list(genes_table['feature_name'])
        if verbose:
            print('Successfully loaded list of genes used for training model')
            print()

        # Align genes: create matrix with genes in the same order as training
        n_model_genes = len(model_var_names)
        X_aligned, matched_features = _align_genes_to_model(X, new_var_names, model_var_names)

        if verbose:
            print(f"Gene alignment:")
            print(f" Model genes: {n_model_genes}")
            print(f" New data genes: {len(new_var_names)}")
            print(f" Matched genes: {matched_features} ({100*matched_features/n_model_genes:.1f}%)")

        if matched_features < 0.5 * n_model_genes:
            warnings.warn(
                f"Only {matched_features}/{n_model_genes} genes matched! "
                "This may lead to poor predictions. "
                "Make sure the gene names are in the same format (e.g., gene symbols)."
            )

        # Normalize the same way as during training
        X_aligned = np.maximum(X_aligned, 0)  # Remove negative values

        # Load dictionary of trained cell types
        with open(os.path.join(path_model, 'dict.txt')) as dict_file:
            dict_multi = json.load(dict_file)
        if verbose:
            print('Successfully loaded dictionary of dataset annotations')
            print()

        # Load pretrained model
        loaded_model = TabNetMultiTaskClassifier()
        for file in os.listdir(path_model):
            if file.endswith('.zip'):
                loaded_model.load_model(os.path.join(path_model, file).replace("\\","/"))
        if verbose:
            print('Successfully loaded model')
            print()

        # Predict cell types
        predictions = loaded_model.predict(X_aligned)

        # Get prediction probabilities
        probabilities = loaded_model.predict_proba(X_aligned)

        # Define get_key function for dictionaries (first key whose value matches, else None)
        def get_key(d, value):
            for k, v in d.items():
                if v == value:
                    return k

        # Add predictions and probabilities to adata
        for i in range(len(dict_multi)):
            prediction_i = [get_key(dict_multi[i], prediction) for prediction in predictions[i].astype(dtype=int)]
            adata.obs[prefix + 'celltype_l' + f'{i+1}'] = prediction_i

            # Highest class probability per cell.
            # NOTE: the v1 probability column keeps its legacy name and does not use 'prefix'.
            adata.obs['prob_celltype_l' + f'{i+1}'] = [max(row) for row in probabilities[i]]
            if verbose:
                print(f'Successfully added predicted celltype_l{i+1} and cell type probabilities')

        return adata
def get_default_tune_params():
    """
    Return the default search space for every tunable hyperparameter.

    The returned dictionary maps a hyperparameter name to a list whose layout
    depends on the kind of hyperparameter:

    - Integer hyperparameters ('nc', 'nb', 'nh', 'ed_nh_ratio', 'ff_hd',
      'classifier_hd', 'batch_size', 'patience', 'epochs', 'pretrain_epochs'):
      [minimum, maximum, step].
    - Float hyperparameters ('dropout', 'lr', 'weight_decay', 'aug_probability',
      'prob', 'noise_std', 'dropout_aug', 'alpha'): [minimum, maximum].
    - Categorical hyperparameters ('use_augmentation', 'from_unsupervised'):
      the list of choices, [True, False].

    'ed_nh_ratio' is the multiplier used to derive the embedding
    dimensionality: ed = nh * ed_nh_ratio.
    """
    # Model architecture
    architecture = {
        "nc": [2, 16, 2],
        "nb": [1, 8, 1],
        "nh": [2, 16, 2],
        "ed_nh_ratio": [8, 32, 4],
        "ff_hd": [128, 1024, 128],
        "classifier_hd": [128, 1024, 128],
        "dropout": [0.0, 0.5],
    }
    # Optimizer
    optimizer = {
        "lr": [1e-5, 1e-2],
        "weight_decay": [1e-6, 1e-2],
    }
    # Training loop
    training = {
        "batch_size": [64, 2048, 64],
        "patience": [5, 30, 5],
        "epochs": [50, 200, 5],
    }
    # Data augmentation
    augmentation = {
        "use_augmentation": [True, False],
        "aug_probability": [0.1, 1.0],
        "prob": [0.05, 0.4],
        "noise_std": [0.0, 0.4],
        "dropout_aug": [0.0, 0.4],
        "alpha": [0.0, 0.4],
    }
    # Unsupervised pretraining
    pretraining = {
        "from_unsupervised": [True, False],
        "pretrain_epochs": [10, 75, 5],
    }
    return {**architecture, **optimizer, **training, **augmentation, **pretraining}


# Function for hyperparameters tuning
def hyperparameter_tuning(
    adata,
    celltype_keys,
    path = '',
    layer = None,
    model_name = "scAdam_model_tuning",
    storage = "scadam_model_tuning.db",
    study_name = "study",
    load_if_exists = True,
    eval_metric = ['balanced_accuracy'],
    strategy = 'linear_offset',
    device = "auto",
    tune_params = "auto",
    adaptive_loss = True,
    num_trials = 100,
    n_splits = 5,
    epochs = None,
    patience = None,
    batch_size = None,
    use_augmentation = None,
    aug_probability = None,
    prob = None,
    noise_std = None,
    dropout_aug = None,
    alpha = None,
    nc = None,
    nb = None,
    nh = None,
    ed_nh_ratio = None,
    ff_hd = None,
    classifier_hd = None,
    dropout = None,
    lr = None,
    weight_decay = None,
    from_unsupervised = None,
    pretrain_epochs = None,
    pretrain_data = None,
    random_state = 0,
    verbose = True
):
    """
    Hyperparameter tuning for scAdam model with k-fold cross validation using Optuna.

    Parameters
    ----------
    adata: AnnData
        Dataset with cell type annotations in adata.obs
    path: str, path object
        Path to create a folder with best hyperparameters, dictionary of cell annotations
        and genes used for hyperparameters optimization.
    celltype_keys: list
        List of cell type annotations in adata.obs.
        Example: ['lineage', 'cell type', 'cell state']
    layer: str (default: None)
        If specified, use adata.layers[layer] for expression values instead of adata.X.
    model_name: str (default: 'scAdam_model_tuning')
        Name of a folder to save tuned hyperparameters.
    storage: str (default: 'scadam_model_tuning.db')
        Database URL. Any value containing '://' is passed to Optuna unchanged; any other
        value is treated as a SQLite file name created inside the model folder.
        If this argument is set to None, in-memory (RAM) storage is used, and the study will not be persistent.
        We don't recommend to use in-memory (RAM) storage to save optimization progress.
    study_name: str (default: 'study')
        Study's name. If this argument is set to None, a unique name is generated automatically.
    load_if_exists: bool (default: True)
        Flag to control the behavior to handle a conflict of study names.
        In the case where a study named study_name already exists in the storage,
        a DuplicatedStudyError is raised if load_if_exists is set to False.
        Otherwise, the creation of the study is skipped, and the existing one is returned.
        If the value is True, allows hyperparameter tuning to continue if interrupted
        (keyboard interrupt, or OS update).
    eval_metric: str or list (default: ['balanced_accuracy'])
        Available evaluation metrics: 'accuracy', 'balanced_accuracy', 'f1_score'.
        The last metric is used as the target and for early stopping.
    num_trials: int (default: 100)
        The number of trials to get optimized hyperparameters for model training.
    n_splits: int (default: 5)
        The number of data splits (folds) per trial.
        The data is divided into n_splits parts, where each part in turn is validation data,
        and the rest is training data.
        The number of folds determines the test_size.
        If n_splits = 5, then test_size = 0.2. If n_splits = 4, then test_size = 0.25.
    adaptive_loss: bool (default: True)
        If True, enables adaptive weighting of hierarchical loss levels:
        each level's loss is tracked over training and its contribution to the total loss
        is increased if this level remains hard (high loss) and decreased if it becomes easy (low loss).
        This helps the model focus more on poorly performing levels of the hierarchy
        instead of weighting all levels equally.
        If False, all levels are summed with a fixed weight of 1.0.
    strategy: str (default: 'linear_offset')
        Weighting strategy for different cell type annotation levels.
        The following weighting strategies are available: linear, exponential, linear_offset, equal, last.
        linear: linear increase in weight from level to level.
        exponential: exponential increase in weight from level to level
        linear_offset: linear increase in weight from level to level with offset.
        equal: equal weight for all cell type annotation levels.
        last: uses only last cell type annotation for model evaluation.
    device: str (default: 'auto')
        Type of device to use in training model ('cpu', 'cuda').
        Set 'auto' for automatic selection.
    tune_params: dict or 'auto' (default: 'auto')
        Dict specifying search spaces or "auto" to use built-in defaults.
        Ranges and step for scAdam model and training parameters.
        A dict may be partial: hyperparameters it does not mention use the default search space.
        Default tuning parameters are available using 'scparadise.scadam.get_default_tune_params'.
        The differences between setting parameters are available in '?scparadise.scadam.get_default_tune_params'.
        For a description of the parameters, see the 'scparadise.scadam.train' function.
    batch_size: int or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    epochs: int or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    patience: int or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    nc: int or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    nb: int or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    nh: int or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    ed_nh_ratio: int or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    ff_hd: int or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    classifier_hd: int or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    dropout: float or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    lr: float or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    weight_decay: float or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    use_augmentation: bool or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    aug_probability: float or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    prob: float or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    noise_std: float or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    dropout_aug: float or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    alpha: float or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    from_unsupervised: bool or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
        Supervised model training included in function.
    pretrain_epochs: int or None (default: None)
        If a value is specified, then the tuning of this parameter will not be performed.
    pretrain_data: AnnData
        Anndata with the same expression matrix for unsupervised pretraining.
        Additional Anndata may not contain cell type annotations.
    random_state: int (default: 0)
        Controls the data shuffling, splitting to folds and model training.
        Pass an int for reproducible output across multiple function calls.
    verbose: bool (default: True)
        Show progress messages for each trial during hyperparameter tuning.

    Returns
    -------
    dict
        Best hyperparameters found by the study. They are also written to
        'best_params.txt' (and the best value to 'best_score.txt') in the model folder.

    Raises
    ------
    ValueError
        If eval_metric contains an unknown metric name, or tune_params is a string other than 'auto'.
    """
    # Set random state (for reproducibility)
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(random_state)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Device selection
    if device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_dir = os.path.join(path, model_name).replace("\\", "/")
    best_params_path = os.path.join(path, model_name, 'best_params.txt').replace("\\", "/")
    best_score_path = os.path.join(path, model_name, 'best_score.txt').replace("\\", "/")

    # Set default best_score.
    # The file holds a bare number when written during a trial, and a dict with a
    # "best_value" entry when written at the end of a finished run; accept both.
    best_score = 0
    if os.path.isfile(best_score_path):
        with open(best_score_path) as f:
            stored_score = json.loads(f.read())
        if isinstance(stored_score, dict):
            stored_score = stored_score.get("best_value")
        if isinstance(stored_score, (int, float)):
            best_score = float(stored_score)

    # Create folder to save tuning results of scAdam model
    os.makedirs(model_dir, exist_ok = True)

    # Get evaluation metric
    if isinstance(eval_metric, str):
        eval_metric = [eval_metric]

    # Metric functions (validated once, before any training starts)
    available_metric_functions = {
        'accuracy': accuracy_score,
        'balanced_accuracy': balanced_accuracy_score,
        'f1_score': lambda y_true, y_pred: f1_score(y_true, y_pred, average='weighted', zero_division=0)
    }
    metric_functions = {}
    for metric in eval_metric:
        if metric not in available_metric_functions:
            raise ValueError(f"Unknown metric: {metric}")
        metric_functions[metric] = available_metric_functions[metric]

    if verbose:
        print(f"Device: {device}")

    # Level weights are only needed for the verbose report below.
    # With strategy 'last' only the final annotation level is scored.
    level_weights = None
    if verbose:
        if strategy == 'last':
            level_weights = [0.0] * (len(celltype_keys) - 1) + [1.0]
        else:
            level_weights = celltype_level_weights(len(celltype_keys), strategy)

    # Create dataset
    dataset = scRNAseqDataset(adata, celltype_keys, layer=layer)
    gn = dataset.gn
    ncl = [dataset.celltype_encoders[key]['n_classes'] for key in celltype_keys]
    if verbose:
        print(f"Number of features: {gn}")
        print(f"Number of cells: {len(dataset)}")
        print(f"Label hierarchy: {' → '.join(celltype_keys)}")
        print(f"Annotation levels weights using strategy '{strategy}':")
        for key, n_celltypes, weight in zip(celltype_keys, ncl, level_weights):
            print(f" {key}: {n_celltypes} cell types, {round(weight, 3)} relative weight")
        print("Start model optimization using optuna...")

    # Train/val split
    indices = np.arange(len(dataset))

    # Default tuning hyperparameters
    # (keep in sync with scparadise.scadam.get_default_tune_params)
    default_tune_params = {
        "nc": [2, 16, 2],
        "nb": [1, 8, 1],
        "nh": [2, 16, 2],
        "ed_nh_ratio": [8, 32, 4],
        "ff_hd": [128, 1024, 128],
        "classifier_hd": [128, 1024, 128],
        "dropout": [0.0, 0.5],
        "lr": [1e-5, 1e-2],
        "weight_decay": [1e-6, 1e-2],
        "batch_size": [64, 2048, 64],
        "patience": [5, 30, 5],
        "epochs": [50, 200, 5],
        "use_augmentation": [True, False],
        "aug_probability": [0.1, 1.0],
        "prob": [0.05, 0.4],
        "noise_std": [0.0, 0.4],
        "dropout_aug": [0.0, 0.4],
        "alpha": [0.0, 0.4],
        'from_unsupervised': [True, False],
        'pretrain_epochs': [10, 75, 5],
    }
    if isinstance(tune_params, str):
        if tune_params != "auto":
            raise ValueError("tune_params must be a dict or 'auto'")
        tune_params = default_tune_params
    else:
        # A partial dict only overrides the ranges it names; everything else
        # keeps its default range instead of silently becoming None.
        tune_params = {**default_tune_params, **tune_params}

    # Function for training a single fold
    def train_fold(params, train_idx, val_idx, device, fold_id):
        # Create dataloaders
        train_loader = DataLoader(
            Subset(dataset, train_idx),
            batch_size=params["batch_size"],
            shuffle=True,
            num_workers=0,
            drop_last=False,
        )
        val_loader = DataLoader(
            Subset(dataset, val_idx),
            batch_size=params["batch_size"],
            shuffle=False,
            num_workers=0,
            drop_last=False,
        )

        # Initialize model
        model = scAdamTransformer(
            gn=gn,
            ncl=ncl,
            ed=params["ed_nh_ratio"] * params["nh"],
            nc=params["nc"],
            nb=params["nb"],
            nh=params["nh"],
            ff_hd=params["ff_hd"],
            classifier_hd=params["classifier_hd"],
            dropout=params["dropout"],
        ).to(device)

        if params['from_unsupervised']:
            pretrain_adata = pretrain_data if pretrain_data is not None else adata
            model = pretrain_unsupervised(
                pretrain_adata,
                model,
                layer=layer,
                batch_size=params["batch_size"],
                epochs=params["pretrain_epochs"],
                lr=params["lr"],
                prob=params["prob"],
                device=device,
                random_state=random_state,
                verbose = False
            )

        # Augmentation
        augmenter = None
        if params["use_augmentation"]:
            augmenter = dust.Augmentation(
                prob=params["prob"],
                noise_std=params["noise_std"],
                dropout_prob=params["dropout_aug"],
                alpha=params["alpha"],
            )

        # Loss function
        criterion = HierarchicalLoss(
            num_levels=len(celltype_keys),
            celltype_keys=celltype_keys,
            alpha=0.25,
            gamma=2.0,
            adaptive=adaptive_loss
        ).to(device)

        # Optimizer
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=params["lr"],
            betas=(0.9, 0.95),
            weight_decay=params["weight_decay"],
        )

        # Learning rate scheduler (cosine decay over the trial's epoch budget)
        def lr_lambda(epoch):
            return 0.5 * (1.0 + np.cos(np.pi * epoch / max(params["epochs"], 1)))
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        # Early stopping
        early_stopping = dust.EarlyStopping(
            patience=params["patience"],
            delta=1e-4,
            mode="max",
            verbose=False
        )

        # Create training history dictionary
        history = {
            'train_loss': [],
            'val_loss': [],
            'train_metrics': {key: {metric: [] for metric in eval_metric} for key in celltype_keys},
            'val_metrics': {key: {metric: [] for metric in eval_metric} for key in celltype_keys},
            'lr': []
        }
        best_val = None

        # Training loop
        for epoch in range(params["epochs"]):
            model.train()
            train_loss = 0.0
            train_preds = {key: [] for key in celltype_keys}
            train_targets = {key: [] for key in celltype_keys}
            for batch_x, batch_y in train_loader:
                batch_x = batch_x.to(device)
                batch_y = {key: val.to(device) for key, val in batch_y.items()}
                # Augmentation block: driven by the trial's own parameters, not by the
                # (usually None) keyword arguments of the outer function.
                if augmenter is not None and np.random.random() < params["aug_probability"]:
                    batch_x = augmenter(batch_x)
                # Get predictions using model
                predictions = model(batch_x)
                # Calculate model loss
                loss, _ = criterion(predictions, batch_y)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                train_loss += loss.item()
                # Predictions and targets
                with torch.no_grad():
                    for level_idx, key in enumerate(celltype_keys):
                        level_key = f'level_{level_idx}'
                        preds = predictions[level_key]['logits'].argmax(dim=1)
                        train_preds[key].append(preds.cpu())
                        train_targets[key].append(batch_y[key].cpu())
            train_loss /= len(train_loader)
            history['train_loss'].append(train_loss)

            # Compute training metrics
            for key in celltype_keys:
                train_preds[key] = torch.cat(train_preds[key]).numpy()
                train_targets[key] = torch.cat(train_targets[key]).numpy()
                for metric_name, metric_func in metric_functions.items():
                    score = metric_func(train_targets[key], train_preds[key])
                    history['train_metrics'][key][metric_name].append(score)

            # Validation
            model.eval()
            val_loss = 0.0
            val_preds = {key: [] for key in celltype_keys}
            val_targets = {key: [] for key in celltype_keys}
            with torch.no_grad():
                for batch_x, batch_y in val_loader:
                    batch_x = batch_x.to(device)
                    batch_y = {key: val.to(device) for key, val in batch_y.items()}
                    # Forward pass
                    predictions = model(batch_x)
                    # Compute loss
                    loss, _ = criterion(predictions, batch_y)
                    val_loss += loss.item()
                    # Store predictions and targets
                    for level_idx, key in enumerate(celltype_keys):
                        level_key = f'level_{level_idx}'
                        preds = predictions[level_key]['logits'].argmax(dim=1)
                        val_preds[key].append(preds.cpu())
                        val_targets[key].append(batch_y[key].cpu())
            val_loss /= len(val_loader)
            history['val_loss'].append(val_loss)

            # Validation metrics
            for key in celltype_keys:
                val_preds[key] = torch.cat(val_preds[key]).numpy()
                val_targets[key] = torch.cat(val_targets[key]).numpy()
                for metric_name, metric_func in metric_functions.items():
                    score = metric_func(val_targets[key], val_preds[key])
                    history['val_metrics'][key][metric_name].append(score)

            # Learning rate
            current_lr = optimizer.param_groups[0]['lr']
            history['lr'].append(current_lr)

            # Calculate validation score on the target (last) metric
            if strategy == 'last':
                fold_score = history['val_metrics'][celltype_keys[-1]][eval_metric[-1]][-1]
            else:
                fold_score = weighted_metric(
                    history=history,
                    celltype_keys=celltype_keys,
                    metric=eval_metric[-1],
                    strategy=strategy
                )
            if best_val is None or fold_score > best_val:
                best_val = fold_score

            # early stop on chosen metric
            if early_stopping(fold_score, model):
                break
            scheduler.step()

        # Restore best weights after early stopping
        early_stopping.load_bm(model)
        if verbose:
            print(f"Fold {fold_id} finished with {eval_metric[-1]} value = {best_val:.6f}", flush=True)
        return float(best_val)

    # Function for processing integer hyperparameters
    def int_param(name, fixed_value, trial):
        if fixed_value is not None:
            return trial.suggest_int(name, fixed_value, fixed_value)
        if name in tune_params:
            lo, hi, step = tune_params[name][0], tune_params[name][1], tune_params[name][2]
            return trial.suggest_int(name, lo, hi, step=step)

    # Function for processing float hyperparameters
    def float_param(name, fixed_value, trial, log=False):
        if fixed_value is not None:
            return trial.suggest_float(name, fixed_value, fixed_value, log=log)
        if name in tune_params:
            lo, hi = tune_params[name][0], tune_params[name][1]
            return trial.suggest_float(name, lo, hi, log=log)

    # Function for processing categorical hyperparameters
    def categorical_param(name, fixed_value, trial):
        if fixed_value is not None:
            # The duplicated choice is kept so that studies already stored with this
            # distribution can still be resumed.
            return trial.suggest_categorical(name, [fixed_value, fixed_value])
        if name in tune_params:
            return trial.suggest_categorical(name, tune_params[name])

    # Function to get hyperparameters in a trial
    def suggest_params(trial):
        # helper to suggest values either from tune_params or from fixed user-specified ones
        params = {}
        # Model initialization parameters
        params["nc"] = int_param("nc", nc, trial)
        params["nb"] = int_param("nb", nb, trial)
        params["nh"] = int_param("nh", nh, trial)
        params["ed_nh_ratio"] = int_param("ed_nh_ratio", ed_nh_ratio, trial)
        params["ff_hd"] = int_param("ff_hd", ff_hd, trial)
        params["classifier_hd"] = int_param("classifier_hd", classifier_hd, trial)
        params["dropout"] = float_param("dropout", dropout, trial, log=False)
        # Optimizer parameters
        params["lr"] = float_param("lr", lr, trial, log=True)
        params["weight_decay"] = float_param("weight_decay", weight_decay, trial, log=True)
        # Training
        params["batch_size"] = int_param("batch_size", batch_size, trial)
        params["patience"] = int_param("patience", patience, trial)
        params["epochs"] = int_param("epochs", epochs, trial)
        # Augmentation parameters
        params["use_augmentation"] = categorical_param("use_augmentation", use_augmentation, trial)
        params["aug_probability"] = float_param("aug_probability", aug_probability, trial, log=False)
        params["prob"] = float_param("prob", prob, trial, log=False)
        params["noise_std"] = float_param("noise_std", noise_std, trial, log=False)
        params["dropout_aug"] = float_param("dropout_aug", dropout_aug, trial, log=False)
        params["alpha"] = float_param("alpha", alpha, trial, log=False)
        # Pretraining
        params["from_unsupervised"] = categorical_param("from_unsupervised", from_unsupervised, trial)
        params["pretrain_epochs"] = int_param("pretrain_epochs", pretrain_epochs, trial)
        return params

    # Function for define objective and params
    def objective(trial):
        # The running best must survive across trials, so it lives in the enclosing scope.
        nonlocal best_score
        # Get trial params
        params = suggest_params(trial)
        # N splits of data for K fold cross validation
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
        # Using last (most detailed) annotation level
        split_iter = skf.split(indices, adata.obs[celltype_keys[-1]])
        fold_scores = []
        for fold_id, (train_idx, val_idx) in enumerate(split_iter, start=1):
            score = train_fold(params, train_idx, val_idx, device, fold_id)
            fold_scores.append(score)
            # pruning hook
            trial.report(score, step=fold_id)
            if trial.should_prune():
                raise optuna.TrialPruned()
        # Get average score between folds
        score = float(np.mean(fold_scores))
        # The study maximizes the metric, so only a higher score is an improvement.
        if score > best_score:
            best_score = score
            # Write best_params to model folder
            with open(best_params_path, 'w') as f:
                f.write(json.dumps(params))
            with open(best_score_path, 'w') as f:
                f.write(json.dumps(best_score))
        return score

    # Study storage
    storage_url = None
    if storage is not None:
        # Anything that already looks like a URL (sqlite:///, postgresql://,
        # mysql+pymysql://, ...) is handed to optuna unchanged; otherwise it is a
        # SQLite file name inside the model folder.
        if "://" in storage:
            storage_url = storage
        else:
            storage_url = "sqlite:///" + os.path.join(path, model_name, storage).replace("\\", "/")

    # Create optuna study
    study = optuna.create_study(
        direction = "maximize",
        study_name = study_name,
        storage = storage_url,
        load_if_exists = load_if_exists,
        pruner = optuna.pruners.HyperbandPruner()
    )

    # Set default parameters
    params_default = {
        "nc": 16,
        "nb": 5,
        "nh": 8,
        "ed_nh_ratio": 32,
        "ff_hd": 512,
        "classifier_hd": 256,
        "dropout": 0.3,
        "lr": 1e-4,
        "weight_decay": 1e-4,
        "batch_size": 128,
        "patience": 10,
        "epochs": 200,
        "use_augmentation": True,
        "aug_probability": 0.5,
        "prob": 0.15,
        "noise_std": 0.1,
        "dropout_aug": 0.1,
        "alpha": 0.2,
        "from_unsupervised": True,
        "pretrain_epochs": 50
    }

    # Enqueue a trial which uses the default parameters
    if not study.trials:
        study.enqueue_trial(params_default)

    # Restart optimization: re-run the last trial if it was interrupted or failed
    trials = study.get_trials(deepcopy=False)
    if len(trials) > 0:
        last = trials[-1]
        if last.state in (optuna.trial.TrialState.FAIL, optuna.trial.TrialState.RUNNING):
            if len(last.params) > 0:
                if verbose:
                    print(f"Re-enqueue last interrupted trial: Trial number {last.number} (will run as new trial).", flush=True)
                study.enqueue_trial(last.params)

    # Study optimization
    study.optimize(objective, n_trials=num_trials, n_jobs=1)
    best_params = dict(study.best_params)
    best_score = float(study.best_value)

    # Save best parameters and score
    with open(best_params_path, "w") as f:
        f.write(json.dumps(best_params))
    with open(best_score_path, "w") as f:
        f.write(json.dumps({"best_value": best_score, "eval_metric": eval_metric[-1]}))
    if verbose:
        print(f"Best value ({eval_metric[-1]}) = {best_score}")
        print(f"Best hyperparameters saved to: {os.path.join(path, model_name, 'best_params.txt')}")
    return best_params
# Function for training model using parameters tuned by scparadise.scadam.hyperparameter_tuning
def train_tuned(
    adata,
    celltype_keys,
    layer = None,
    path = '',
    path_tuned = '',
    model_name = 'scAdam_model_tuned',
    test_size = 0.2,
    eval_metric = ['accuracy', 'balanced_accuracy'],
    strategy = 'linear_offset',
    batch_size = None,
    epochs = None,
    patience = None,
    nc = None,
    nb = None,
    nh = None,
    ed_nh_ratio = None,
    ff_hd = None,
    classifier_hd = None,
    dropout = None,
    lr = None,
    weight_decay = None,
    adaptive_loss = True,
    use_augmentation = None,
    aug_probability = None,
    prob = None,
    noise_std = None,
    dropout_aug = None,
    alpha = None,
    from_unsupervised = None,
    pretrain_epochs = None,
    pretrain_data = None,
    device = 'auto',
    random_state = 0,
    return_model = False,
    unknown_detection = True,
    verbose = True
):
    """
    Train custom scAdam model with tuned hyperparameters.
    The function automatically uses the tuned hyperparameters stored in
    'best_params.txt' inside path_tuned.
    However, you can change any hyperparameter by passing it via the corresponding parameter.
    Hyperparameters that are neither passed nor present in the file fall back to built-in defaults.

    Parameters
    ----------
    adata : AnnData
        Dataset with cell type annotations in adata.obs
    path: str, path object
        Path to create a model folder containing the training history,
        cell annotation dictionary, and genes used for training.
    path_tuned: str, path object
        Path to folder with tuned parameters by 'scparadise.scadam.hyperparameter_tuning' function.
    model_name: str (default: 'scAdam_model_tuned')
        Name of a folder to save model.
    celltype_keys: list
        List of cell type annotations in adata.obs.
        Example: ['lineage', 'cell type', 'cell state']
    layer: str (default: None)
        If specified, use adata.layers[layer] for expression values instead of adata.X.
    test_size: float or int (default: 0.2)
        If float, should be between 0.0 and 1.0 and represent the proportion of the dataset
        to include in the test split.
        If int, represents the absolute number of test cells.
    epochs: int (default: None)
        Maximum number of epochs for scAdam model training.
        If specified, the specified value is used.
    eval_metric: str or list (default: ['accuracy', 'balanced_accuracy'])
        Available evaluation metrics: 'accuracy', 'balanced_accuracy', 'f1_score'.
        The last metric is used as the target and for early stopping.
    strategy: str (default: 'linear_offset')
        Weighting strategy for cell type annotation levels; forwarded unchanged to
        'scparadise.scadam.train'.
    batch_size: int, (default: None)
        Number of examples per batch.
        If specified, the specified value is used.
    patience: int (default: None)
        Number of consecutive epochs without improvement before performing early stopping.
        If patience is set to 0, then no early stopping will be performed.
        Note that if patience is enabled, then best weights from best epoch will
        automatically be loaded at the end of the training.
        If specified, the specified value is used.
    use_augmentation: bool (default: None)
        Use data augmentation or not.
        If specified, the specified value is used.
    aug_probability: float (default: None)
        The probability of applying augmentation to a batch.
        If specified, the specified value is used.
    prob: float (default: None)
        Gene masking probability.
        If specified, the specified value is used.
    noise_std: float (default: None)
        Gaussian noise standard deviation.
        If specified, the specified value is used.
    dropout_aug: float (default: None)
        Dropout probability for simulating technical noise.
        If specified, the specified value is used.
    alpha: float (default: None)
        Alpha parameter for mixup augmentation.
        If specified, the specified value is used.
    nc: int (default: None)
        Number of chunks for genes from adata.
        If specified, the specified value is used.
    nb: int (default None)
        Number of blocks in scAdam model.
        If specified, the specified value is used.
    nh: int (default: None)
        Number of heads in scAdam model attention mechanism.
        If specified, the specified value is used.
    ed_nh_ratio: int (default: None)
        Used for calculating embedding dimensionality ('ed') from 'nh'.
        If specified, the specified value is used.
    ff_hd: int (default: None)
        Number of nodes in each scAdam model layer in feed forward network.
        If specified, the specified value is used.
    classifier_hd: int (default: None)
        Number of nodes in scAdam classifier.
        If specified, the specified value is used.
    dropout: float (default: None)
        Portion of neurons that temporarily ignored during training (prevents overfitting).
        If specified, the specified value is used.
    lr: float (default: None)
        Determines the step size at each iteration while moving toward a minimum of a loss function.
        If specified, the specified value is used.
    weight_decay: float (default: None)
        Weight decay coefficient.
        If specified, the specified value is used.
    adaptive_loss: bool (default: True)
        Accepted for signature compatibility.
        NOTE(review): this value is currently not forwarded to 'scparadise.scadam.train';
        confirm whether 'train' accepts it before wiring it through.
    from_unsupervised: bool (default: None)
        Use a previously self supervised model as starting weights.
        Supervised model training included in function.
    pretrain_epochs: int (default: None)
        Number of pretraining epochs.
    pretrain_data: AnnData (default: None)
        Anndata with the same expression matrix for unsupervised pretraining.
        Additional Anndata may not contain cell type annotations.
    device: str (default: 'auto')
        Type of device to use in training model ('cpu', 'cuda').
        Set 'auto' for automatic selection.
    random_state: int (default: 0)
        Controls the data shuffling, splitting to folds and model training.
        Pass an int for reproducible output across multiple function calls.
    unknown_detection: bool (default: True)
        Forwarded unchanged to 'scparadise.scadam.train'.
    verbose: bool (default: True)
        Show progress bar for each epoch during training.
    return_model: bool (default: False)
        Return model after training or not.

    Returns
    -------
    The object returned by 'scparadise.scadam.train' if return_model is True, otherwise None.
    """
    # Create new directory with model and list of genes
    os.makedirs(os.path.join(path, model_name).replace("\\", "/"), exist_ok=True)

    # Load parameters for scAdam model training
    with open(os.path.join(path_tuned, 'best_params.txt')) as f:
        params = json.loads(f.read())
    if verbose:
        print('Successfully loaded tuned hyperparameters!')

    # Dictionary of given parameters for a function
    params_given = {
        'epochs': epochs,
        'batch_size': batch_size,
        'patience': patience,
        'use_augmentation': use_augmentation,
        'aug_probability': aug_probability,
        'prob': prob,
        'noise_std': noise_std,
        'dropout_aug': dropout_aug,
        'alpha': alpha,
        'nc': nc,
        'nb': nb,
        'nh': nh,
        'ed_nh_ratio': ed_nh_ratio,
        'ff_hd': ff_hd,
        'classifier_hd': classifier_hd,
        'dropout': dropout,
        'lr': lr,
        'weight_decay': weight_decay,
        'from_unsupervised': from_unsupervised,
        'pretrain_epochs': pretrain_epochs
    }

    # Dictionary of default parameters
    params_default = {
        'epochs': 200,
        'batch_size': 128,
        'patience': 10,
        'use_augmentation': True,
        'aug_probability': 0.5,
        'prob': 0.15,
        'noise_std': 0.1,
        'dropout_aug': 0.1,
        'alpha': 0.2,
        'nc': 4,
        'nb': 4,
        'nh': 8,
        'ed_nh_ratio': 32,
        'ff_hd': 512,
        'classifier_hd': 512,
        'dropout': 0.3,
        'lr': 1e-4,
        'weight_decay': 1e-4,
        'from_unsupervised': True,
        'pretrain_epochs': 50
    }

    # Replace param in loaded parameters with a given value
    for name, value in params_given.items():
        if value is not None:
            params[name] = value

    # Fill every required hyperparameter that is still missing or None.
    # Iterating the defaults (not the loaded file) also covers keys that are
    # absent from best_params.txt and ignores unrelated extra keys.
    for name, default_value in params_default.items():
        if params.get(name) is None:
            params[name] = default_value

    # Train model with loaded parameters (corrected if given).
    # return_model is always True here so the trained model is available;
    # whether it is handed back to the caller is decided below.
    model = train(
        adata = adata,
        celltype_keys = celltype_keys,
        layer = layer,
        path = path,
        model_name = model_name,
        test_size = test_size,
        strategy = strategy,
        epochs = params['epochs'],
        eval_metric = eval_metric,
        batch_size = params['batch_size'],
        patience = params['patience'],
        use_augmentation = params['use_augmentation'],
        aug_probability = params['aug_probability'],
        prob = params['prob'],
        noise_std = params['noise_std'],
        dropout_aug = params['dropout_aug'],
        alpha = params['alpha'],
        nc = params['nc'],
        nb = params['nb'],
        nh = params['nh'],
        ed_nh_ratio = params['ed_nh_ratio'],
        ff_hd = params['ff_hd'],
        classifier_hd = params['classifier_hd'],
        dropout = params['dropout'],
        lr = params['lr'],
        weight_decay = params['weight_decay'],
        from_unsupervised = params['from_unsupervised'],
        pretrain_epochs = params['pretrain_epochs'],
        pretrain_data = pretrain_data,
        return_model = True,
        unknown_detection = unknown_detection,
        device = device,
        random_state = random_state,
        verbose = verbose
    )
    if return_model:
        return model
# Function for creation of dataset for warm-start training
class scRNAseqDataset_warm_start(Dataset):
    """
    Dataset for warm-start fine-tuning of a pretrained scAdam model.

    - Expression columns are arranged in the order of the pretrained model's gene
      list (model_var_names); genes missing from adata are left as zeros.
    - Labels are encoded with the pretrained encoders (they are never refit), so the
      classifier output dimensions stay valid.
    """

    def __init__(
        self,
        adata,
        celltype_keys,
        model_var_names,
        celltype_encoders,
        layer=None,
        allow_unseen_labels=False,
    ):
        self.adata = adata
        self.celltype_keys = list(celltype_keys)
        self.layer = layer
        self.obs_names = adata.obs_names.tolist()
        self.var_names = list(model_var_names)
        self.gn = len(self.var_names)

        # Pick the expression source: the requested layer when it exists, adata.X otherwise
        use_layer = layer is not None and layer in adata.layers
        expression = adata.layers[layer] if use_layer else adata.X
        if hasattr(expression, 'toarray'):
            expression = expression.toarray()

        # Copy every model gene found in adata into its model-ordered column
        source_column = {gene: col for col, gene in enumerate(adata.var_names.tolist())}
        aligned = np.zeros((adata.n_obs, self.gn), dtype=np.float32)
        n_matched = 0
        for target_col, gene in enumerate(self.var_names):
            if gene not in source_column:
                continue
            n_matched += 1
            aligned[:, target_col] = expression[:, source_column[gene]]
        self.matched_genes = n_matched
        self.match_fraction = n_matched / max(self.gn, 1)

        # Clip negatives, then min-max scale each gene column
        if np.any(aligned.min(0)):
            aligned = np.maximum(aligned, 0)
        column_min = aligned.min(0)
        column_range = np.ptp(aligned, axis=0) + 1e-10
        scaled = (aligned - column_min) / column_range
        self.X = torch.FloatTensor(scaled)
        assert not torch.isnan(self.X).any(), "NaN in adata data!"
        assert not torch.isinf(self.X).any(), "Inf in adata data!"

        # Encode labels of every annotation level with the pretrained encoders
        self.labels = {}
        self.celltype_encoders = celltype_encoders
        for key in self.celltype_keys:
            if key not in adata.obs.columns:
                raise ValueError(f"'{key}' not found in adata.obs")
            encoder = celltype_encoders[key]["celltype_encoder"]
            raw_labels = adata.obs[key].astype(str).values
            known = set(encoder.classes_.tolist())
            unseen = sorted(set(raw_labels.tolist()) - known)
            if unseen:
                if not allow_unseen_labels:
                    raise ValueError(
                        f"Unseen labels in '{key}': {unseen[:10]}{'...' if len(unseen) > 10 else ''}. "
                        f"Warm-start requires the same label space as the pretrained model."
                    )
                # Map unseen labels onto the first known class to keep the label space fixed
                fallback = encoder.classes_[0]
                raw_labels = np.array(
                    [label if label in known else fallback for label in raw_labels],
                    dtype=object,
                )
            self.labels[key] = torch.LongTensor(encoder.transform(raw_labels))

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        features = self.X[idx]
        targets = {key: self.labels[key][idx] for key in self.celltype_keys}
        return features, targets


# Function for warm-start scAdam model fine-tuning
def warm_start(
    adata,
    path_model,
    celltype_keys=None,
    layer=None,
    path='',
    model_name="scAdam_model_warm_start",
    test_size=0.2,
    eval_metric=["accuracy", "balanced_accuracy"],
    strategy="linear_offset",
    batch_size=128,
    epochs=100,
    patience=10,
    lr=5e-5,
    weight_decay=1e-4,
    use_augmentation=True,
    aug_probability=0.5,
    prob=0.15,
    noise_std=0.1,
    dropout_aug=0.1,
    alpha=0.2,
    adaptive_loss=True,
    freeze_transformer=False,
    allow_unseen_labels=False,
    unknown_detection=True,
    device="auto",
    random_state=0,
    return_model=False,
    verbose=True
):
    """
    Warm-start fine-tuning of an existing scAdam model on new data.
    Warm-start training is a technique in machine learning that involves initializing a model
    with parameters or states learned from a previously trained model.

    Parameters
    ----------
    adata: AnnData
        New dataset with cell type annotations in adata.obs
    path_model: str, path object
        Path to a model folder containing pretrained scAdam model.
    celltype_keys: list (default: None)
        List of cell type annotations in adata.obs.
        Example: ['lineage', 'cell type', 'cell state']
        If None, the keys stored in the pretrained model are used.
    layer: str (default: None)
        If specified, use adata.layers[layer] for expression values instead of adata.X.
    path: str, path object
        Path to create a model folder containing the training history, cell annotation dictionary,
        and genes used for scAdam model warm start training.
    model_name: str (default: 'scAdam_model_warm_start')
        Name of a folder to save model.
    test_size: float or int (default: 0.2)
        If float, should be between 0.0 and 1.0 and represent the proportion of the dataset
        to include in the test split. If int, represents the absolute number of test cells.
        The split is stratified by the last annotation level.
    eval_metric: str or list (default: ['accuracy', 'balanced_accuracy'])
        Available evaluation metrics: 'accuracy', 'balanced_accuracy', 'f1_score'.
        The last metric is used as the target and for early stopping.
    strategy: str (default: 'linear_offset')
        How annotation levels are combined into the early stopping score.
        'last' uses only the metric of the last annotation level; any other value is
        forwarded to `celltype_level_weights` and `weighted_metric`.
    batch_size: int, (default: 128)
        Number of examples per batch.
    epochs: int (default: 100)
        Maximum number of epochs for scAdam model training
    patience: int (default: 10)
        Number of consecutive epochs without improvement before performing early stopping.
        If patience is set to 0, then no early stopping will be performed.
        Note that if patience is enabled, then best weights from best epoch will automatically
        be loaded at the end of the training.
    lr: float (default: 5e-5)
        Initial learning rate of the AdamW optimizer (decayed with a cosine schedule over `epochs`).
    weight_decay: float (default: 1e-4)
        Weight decay of the AdamW optimizer.
    use_augmentation: bool (default: True)
        If True, training batches are augmented with `dust.Augmentation`.
    aug_probability: float (default: 0.5)
        Probability that a given training batch is augmented.
    prob, noise_std, dropout_aug, alpha: float
        Forwarded to `dust.Augmentation` as `prob`, `noise_std`, `dropout_prob` and `alpha`.
    adaptive_loss: bool (default: True)
        Forwarded to `HierarchicalLoss` as `adaptive`.
    freeze_transformer: bool (default = False)
        If True, freezes `model.gene_embedding` and `model.blocks` and keeps the classifier
        trainable. If False, fine-tunes the full model.
    allow_unseen_labels: bool (default = False)
        If False, raise an error if new data contains labels not present in the pretrained
        `LabelEncoder` for any level. If True, unseen labels are mapped to a fallback known
        label to keep the label space unchanged.
    unknown_detection: bool (default: True)
        Train unknown cell detector - identifies unknown cells when model used for prediction on a new data.
    device: str or torch.device (default: 'auto')
        'auto' selects CUDA when available and CPU otherwise; any other value is used as given.
    random_state: int (default: 0)
        Seed for numpy, torch and the train/validation split.
    return_model: bool (default: False)
        Return model after training or not.
    verbose: bool (default: True)
        Show progress bar for each epoch during training.

    Returns
    -------
    Saves fine-tuned scAdam model for cell type annotation.
    Returns the fine-tuned model if `return_model` is True, otherwise None.

    Raises
    ------
    ValueError
        If `celltype_keys` is None and the model stores no keys, if an unknown metric is
        requested, or if the number of classes per level does not match the checkpoint.
    """
    # `eval_metric` is documented as "str or list": wrap a single name so it is not
    # iterated character by character (also copies the shared default list).
    eval_metric = [eval_metric] if isinstance(eval_metric, str) else list(eval_metric)

    # Set random state (for reproducibility)
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(random_state)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Device selection
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device == 'auto' else device

    # load model
    model = load_model(path_model, device=device, verbose=verbose)
    model = model.to(device)

    if celltype_keys is None:
        if not hasattr(model, "celltype_keys") or model.celltype_keys is None:
            raise ValueError("celltype_keys is None and model has no celltype_keys metadata.")
        celltype_keys = list(model.celltype_keys)

    if verbose:
        print(f"Device: {device}")

    # Level weights are only needed for the verbose summary; with strategy 'last'
    # there are none, so keep an explicit None instead of an undefined name.
    level_weights = None
    if verbose and strategy != 'last':
        level_weights = celltype_level_weights(len(celltype_keys), strategy)

    # Align genes and create dataset
    dataset = scRNAseqDataset_warm_start(
        adata,
        celltype_keys=celltype_keys,
        model_var_names=model.var_names,
        celltype_encoders=model.celltype_encoders,
        layer=layer,
        allow_unseen_labels=allow_unseen_labels,
    )
    ncl = [dataset.celltype_encoders[key]['n_classes'] for key in celltype_keys]

    if verbose:
        print("Gene alignment:")
        print(f"  Number of features: {dataset.gn}")
        print(f"  Matched features: {dataset.matched_genes} ({100*dataset.match_fraction:.1f}%)")

    if dataset.matched_genes < 0.80 * dataset.gn:
        # The module installs a global "ignore" warnings filter; override it locally,
        # otherwise this warning could never reach the user.
        with warnings.catch_warnings():
            warnings.simplefilter("always")
            warnings.warn(
                f"Only {100*dataset.match_fraction:.1f}% of genes matched. Fine-tuning may be unstable; "
                "Ensure consistent feature naming and data preprocessing (filtering features, reference genome, etc.)",
                stacklevel=2,
            )

    if verbose:
        print(f"Label hierarchy: {' → '.join(celltype_keys)}")
        if level_weights is not None:
            print(f"Annotation levels weights using strategy '{strategy}':")
            for key, n_celltypes, weight in zip(celltype_keys, ncl, level_weights):
                print(f"  {key}: {n_celltypes} cell types, {round(weight, 3)} relative weight")
        else:
            print(f"Strategy 'last': only annotation level '{celltype_keys[-1]}' is used for early stopping:")
            for key, n_celltypes in zip(celltype_keys, ncl):
                print(f"  {key}: {n_celltypes} cell types")

    # Sanity check: output dims must match checkpoint ncl
    ncl_expected = list(model.ncl)
    ncl_now = [len(model.celltype_encoders[k]["celltype_encoder"].classes_) for k in celltype_keys]
    if ncl_now != ncl_expected:
        raise ValueError(
            f"Mismatch cell type numbers: model cell types number = {ncl_expected}, given cell types number = {ncl_now}. "
            "This should not happen unless metadata is inconsistent."
        )

    # Train/val split
    indices = np.arange(len(dataset))
    train_idx, val_idx = train_test_split(
        indices,
        test_size=test_size,
        stratify=adata.obs[celltype_keys[-1]],  # Using last (most detailed) annotation level
        random_state=random_state
    )
    if verbose:
        print(f"\nDataset split:")
        print(f'Train dataset contains: {len(train_idx)} cells, it is {round(100*(len(train_idx)/(len(train_idx) + len(val_idx))), ndigits=2)} % of input dataset')
        print(f'Validation dataset contains: {len(val_idx)} cells, it is {round(100*(len(val_idx)/(len(train_idx) + len(val_idx))), ndigits=2)} % of input dataset')

    # Create dataloaders
    train_loader = DataLoader(
        Subset(dataset, train_idx),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=False
    )
    val_loader = DataLoader(
        Subset(dataset, val_idx),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        drop_last=False
    )

    # freeze backbone (optional)
    if freeze_transformer:
        for p in model.gene_embedding.parameters():
            p.requires_grad = False
        for blk in model.blocks:
            for p in blk.parameters():
                p.requires_grad = False
        # keep classifier trainable
        for p in model.classifier.parameters():
            p.requires_grad = True

    # Augmentation
    augmenter = None
    if use_augmentation:
        augmenter = dust.Augmentation(
            prob=prob,
            noise_std=noise_std,
            dropout_prob=dropout_aug,
            alpha=alpha
        )

    # loss function (its alpha/gamma are fixed here and unrelated to the augmentation `alpha`)
    criterion = HierarchicalLoss(
        num_levels=len(celltype_keys),
        celltype_keys=celltype_keys,
        alpha=0.25,
        gamma=2.0,
        adaptive=adaptive_loss
    ).to(device)

    # Optimizer over trainable params only
    params_to_train = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        params_to_train,
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=weight_decay
    )

    # Learning rate scheduler (cosine decay from 1 to 0 over `epochs`)
    def lr_lambda(epoch):
        return 0.5 * (1.0 + np.cos(np.pi * (epoch / epochs)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Early stopping
    early_stopping = dust.EarlyStopping(
        patience=patience,
        delta=1e-4,
        mode='max',
        verbose=verbose
    )

    # Available metric functions
    available_metric_functions = {
        'accuracy': accuracy_score,
        'balanced_accuracy': balanced_accuracy_score,
        'f1_score': lambda y_true, y_pred: f1_score(y_true, y_pred, average='weighted', zero_division=0)
    }

    # Create dictionary of used for training metric functions
    metric_functions = {}
    for metric in eval_metric:
        if metric not in available_metric_functions:
            raise ValueError(f"Unknown metric: {metric}")
        metric_functions[metric] = available_metric_functions[metric]

    # Create training history dictionary
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_metrics': {key: {metric: [] for metric in eval_metric} for key in celltype_keys},
        'val_metrics': {key: {metric: [] for metric in eval_metric} for key in celltype_keys},
        'lr': []
    }

    # Training Loop
    for epoch in tqdm(range(epochs), desc="Warm-start model fine-tuning", colour="blue", disable=not verbose):
        # Model training
        model.train()
        train_loss = 0.0
        train_preds = {key: [] for key in celltype_keys}
        train_targets = {key: [] for key in celltype_keys}

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = {k: v.to(device) for k, v in batch_y.items()}

            # Augmentation only if random < aug_probability
            if use_augmentation and np.random.random() < aug_probability:
                batch_x = augmenter(batch_x)

            # Get predictions using model
            predictions = model(batch_x)

            # Calculate model loss
            loss, _ = criterion(predictions, batch_y)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params_to_train, max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()

            # Predictions and targets
            with torch.no_grad():
                for level_idx, key in enumerate(celltype_keys):
                    level_key = f'level_{level_idx}'
                    preds = predictions[level_key]['logits'].argmax(dim=1)
                    train_preds[key].append(preds.cpu())
                    train_targets[key].append(batch_y[key].cpu())

        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)

        # Compute training metrics
        for key in celltype_keys:
            train_preds[key] = torch.cat(train_preds[key]).numpy()
            train_targets[key] = torch.cat(train_targets[key]).numpy()
            for metric_name, metric_func in metric_functions.items():
                score = metric_func(train_targets[key], train_preds[key])
                history['train_metrics'][key][metric_name].append(score)

        # Model validation
        model.eval()
        val_loss = 0.0
        val_preds = {key: [] for key in celltype_keys}
        val_targets = {key: [] for key in celltype_keys}

        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = {k: v.to(device) for k, v in batch_y.items()}

                # Get predictions using model
                predictions = model(batch_x)

                # Calculate model loss
                loss, _ = criterion(predictions, batch_y)
                val_loss += loss.item()

                # Store predictions and targets
                for level_idx, key in enumerate(celltype_keys):
                    level_key = f"level_{level_idx}"
                    preds = predictions[level_key]["logits"].argmax(dim=1)
                    val_preds[key].append(preds.cpu())
                    val_targets[key].append(batch_y[key].cpu())

        val_loss /= len(val_loader)
        history['val_loss'].append(val_loss)

        # Validation metrics
        for key in celltype_keys:
            val_preds[key] = torch.cat(val_preds[key]).numpy()
            val_targets[key] = torch.cat(val_targets[key]).numpy()
            for metric_name, metric_func in metric_functions.items():
                score = metric_func(val_targets[key], val_preds[key])
                history['val_metrics'][key][metric_name].append(score)

        # Learning rate
        history["lr"].append(optimizer.param_groups[0]["lr"])

        # Early Stopping
        if strategy == "last":
            es_score = history["val_metrics"][celltype_keys[-1]][eval_metric[-1]][-1]
        else:
            es_score = weighted_metric(
                history=history,
                celltype_keys=celltype_keys,
                metric=eval_metric[-1],
                strategy=strategy
            )
        if early_stopping(es_score, model):
            break

        scheduler.step()

    # Load best model
    early_stopping.load_bm(model)
    model.eval()

    # Update metadata and training history
    model.celltype_keys = celltype_keys
    model.history = history

    # Remove loaders
    del train_loader, val_loader

    # Refit unknown detector on new data
    if unknown_detection:
        # Create detector
        detector = UnknownCellDetector()
        # Create full dataset loader for detector training
        full_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=False
        )
        # Fit detector on full dataset
        detector.fit(model, full_loader, device=device)
        # Save detector
        model.unknown_detector = detector
        if verbose:
            print("UnknownCellDetector fitted successfully!")

    # Save warm start fine tuned model
    save_model(model, path, model_name, verbose)

    if return_model:
        return model
# Function to display available models in github
def available_models():
    """
    Load the table of available pretrained scAdam models.

    Returns
    -------
    pandas.DataFrame read from the `scadam_available_models.csv` file
    hosted in the scParadise GitHub repository (requires network access).
    """
    catalog_url = 'https://raw.githubusercontent.com/Chechekhins/scParadise/main/scadam_available_models.csv'
    return pd.read_csv(catalog_url, sep=',')
# Function for downloading tuned pretrained models from github
def download_model(
    model_name='',
    save_path='',
    github_username=None,
    github_token=None
):
    """
    Download pretrained tuned model for highly accurate cell type annotation.

    Parameters
    ----------
    model_name: str
        Name of the model from column 'model' from scparadise.scadam.available_models().
    save_path: str, path object
        Path to save trained scAdam model. The model is stored in the
        subfolder `<model_name>_scAdam` of this path.
    github_username: str
        Your GitHub username. Falls back to the GITHUB_USERNAME environment variable.
    github_token: str
        Token for GitHub API. Falls back to the GITHUB_TOKEN, then GH_TOKEN
        environment variables.

    Raises
    ------
    ValueError
        If `model_name` is empty.
    """
    if not model_name:
        # An empty name would only create a stray '_scAdam' folder and then fail remotely
        raise ValueError(
            "model_name must be a non-empty model name; "
            "see scparadise.scadam.available_models() for valid names."
        )

    # read creds from args or env
    github_username = github_username or os.getenv("GITHUB_USERNAME")
    github_token = github_token or os.getenv("GITHUB_TOKEN") or os.getenv("GH_TOKEN")

    fs_kwargs = dict(org="Chechekhins", repo="scParadise")
    # Credentials are only passed when both are present; otherwise access is anonymous
    if github_username and github_token:
        fs_kwargs.update(username=github_username, token=github_token)
    fs = fsspec.filesystem("github", **fs_kwargs)

    # List the remote model folder first: if the model does not exist this fails
    # before any local directory has been created.
    remote_dir = os.path.join("models_scadam", model_name + "_scAdam").replace("\\", "/")
    remote_files = fs.ls(remote_dir)

    # Create new directory with model
    save = os.path.join(save_path, model_name + '_scAdam').replace("\\", "/")
    os.makedirs(save, exist_ok=True)

    # Download content of model
    fs.get(remote_files, save)