import warnings
warnings.filterwarnings("ignore")
import os
import json
import pandas as pd
import numpy as np
import anndata
import optuna
import copy
import fsspec
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
from typing import Dict, List, Optional, Tuple, Literal, Union
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from skimage.filters import threshold_otsu
from sklearn.mixture import GaussianMixture
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score
from sklearn.model_selection import StratifiedKFold
import optuna
from optuna.pruners import MedianPruner
# load scParadise dust module
import scparadise.dust as dust
# Dataset Class
class scRNAseqDataset(Dataset):
    """
    Torch dataset built from an AnnData object with hierarchical cell type annotations.

    Each item is a pair ``(x, y)``: ``x`` is the per-gene min-max scaled expression
    vector of one cell and ``y`` maps every annotation key to its encoded label.
    """

    def __init__(self, adata, celltype_keys, layer):
        """
        Convert an AnnData object with hierarchical annotations to a scAdam compatible format.

        Parameters
        ----------
        adata: AnnData
            Dataset with cell type annotations in adata.obs
        celltype_keys: list
            List of cell type annotations in adata.obs, ordered from the first to the last level.
            Example: ['lineage', 'cell type', 'cell state']
        layer: str or None
            If given and present in adata.layers, use adata.layers[layer] for expression values.
            Otherwise (None, or a name missing from adata.layers) adata.X is used.

        Attributes
        ----------
        gn: number of genes in adata.var_names
        labels: dict, annotation key -> LongTensor of encoded labels
        celltype_encoders: dict, annotation key -> {'celltype_encoder', 'n_classes', 'level'}
        var_names: feature (gene) names
        obs_names: barcode (cell) names
        X: FloatTensor with the scaled expression matrix

        Raises
        ------
        ValueError
            If one of celltype_keys is not a column of adata.obs.
        """
        self.adata = adata
        self.celltype_keys = celltype_keys
        self.layer = layer
        self.obs_names = adata.obs_names.tolist()
        self.var_names = adata.var_names.tolist()
        self.gn = len(adata.var_names.tolist())
        self.labels = {}
        self.celltype_encoders = {}

        # Every requested annotation level has to exist; report the first missing one
        missing = [key for key in celltype_keys if key not in adata.obs.columns]
        if missing:
            raise ValueError(f"'{missing[0]}' not found in adata.obs")

        # Expression matrix: requested layer when available, adata.X otherwise
        use_layer = layer is not None and layer in adata.layers
        matrix = adata.layers[layer] if use_layer else adata.X
        if hasattr(matrix, 'toarray'):
            # objects exposing toarray() (e.g. sparse matrices) are densified
            matrix = matrix.toarray()

        # Negative values are clipped to zero whenever some gene has a non-zero minimum
        if matrix.min(0).any():
            matrix = np.maximum(matrix, 0)
        # Per-gene min-max scaling; the epsilon protects constant genes from division by zero
        scaled = (matrix - matrix.min(0)) / (np.ptp(matrix, axis=0) + 1e-10)

        self.X = torch.FloatTensor(scaled)
        assert not torch.isnan(self.X).any(), "NaN in adata data!"
        assert not torch.isinf(self.X).any(), "Inf in adata data!"

        # One independent label encoder per annotation level
        for level_idx, key in enumerate(celltype_keys):
            encoder = LabelEncoder()
            encoded = encoder.fit_transform(adata.obs[key].astype(str).values)
            self.celltype_encoders[key] = {
                'celltype_encoder': encoder,
                'n_classes': len(encoder.classes_),
                'level': level_idx
            }
            self.labels[key] = torch.LongTensor(encoded)

    def __len__(self):
        """Number of cells (rows of the expression matrix)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the expression vector and the per-level labels of cell ``idx``."""
        targets = {}
        for key in self.celltype_keys:
            targets[key] = self.labels[key][idx]
        return self.X[idx], targets
# Hierarchical Classification scAdam model
class scAdamClassifier(nn.Module):
    """
    Hierarchical scAdam classifier: one MLP head per annotation level.

    The head of level ``i`` receives the input features concatenated with the
    softmax probabilities produced by all previous levels.
    """

    def __init__(self, input_dim, ncl, hd = 256, dropout = 0.2):
        """
        Parameters
        ----------
        input_dim: int
            Size of the feature vector fed to the first level.
        ncl: list of int
            Number of cell types for each annotation level, in level order.
        hd: int (default: 256)
            Width of the first hidden layer of each head (hd = hidden dim);
            the second hidden layer has hd // 2 nodes.
        dropout: float (default: 0.2)
            Portion of neurons that are temporarily ignored during training (prevents overfitting).
        """
        super().__init__()
        self.num_levels = len(ncl)
        self.ncl = ncl
        # Heads are created in level order; each later head is widened by the
        # number of classes of all levels before it.
        self.classifiers = nn.ModuleList()
        extra_inputs = 0
        for n_classes in ncl:
            self.classifiers.append(
                self._build_head(input_dim + extra_inputs, hd, n_classes, dropout)
            )
            extra_inputs += n_classes

    @staticmethod
    def _build_head(in_features, hd, n_classes, dropout):
        """Two hidden layers (hd, hd // 2) with LayerNorm, ReLU and Dropout, then the logits layer."""
        half = hd // 2
        return nn.Sequential(
            nn.Linear(in_features, hd),
            nn.LayerNorm(hd),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hd, half),
            nn.LayerNorm(half),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half, n_classes)
        )

    def forward(self, x):
        """
        Run all heads in level order.

        Returns
        -------
        dict
            'level_<i>' -> {'logits': tensor, 'probs': softmax of the logits}.
        """
        outputs = {}
        # x followed by the probabilities of every level processed so far
        pieces = [x]
        for level_idx, head in enumerate(self.classifiers):
            level_input = torch.cat(pieces, dim=1) if level_idx > 0 else x
            logits = head(level_input)
            probs = F.softmax(logits, dim=-1)
            outputs[f'level_{level_idx}'] = {'logits': logits, 'probs': probs}
            pieces.append(probs)
        return outputs
# scAdam transformer model
class scAdamTransformer(nn.Module):
    """
    Hierarchical transformer scAdam model for cell type annotation.

    Pipeline: dust.Embedding -> nb x dust.RibBlock -> LayerNorm -> mean over
    dim 1 -> scAdamClassifier.
    """

    def __init__(
        self,
        gn,
        ncl,
        ed=256,
        nc=16,
        nb=5,
        nh=8,
        ff_hd=512,
        classifier_hd=256,
        dropout=0.3
    ):
        """
        Parameters
        ----------
        gn: int
            Number of genes in adata.var_names.
        ncl: list of int
            Number of cell types for each annotation level, in level order.
        ed: int (default: 256)
            Embedding dimensionality.
        nc: int (default: 16)
            Number of chunks for genes from adata (passed to dust.Embedding).
        nb: int (default: 5)
            Number of blocks in scAdam model.
        nh: int (default: 8)
            Number of heads in scAdam model attention mechanism (passed to dust.RibBlock).
        ff_hd: int (default: 512)
            Number of nodes in the feed forward network of each block.
        classifier_hd: int (default: 256)
            Number of nodes in each scAdam classifier head.
        dropout: float (default: 0.3)
            Portion of neurons that are temporarily ignored during training (prevents overfitting).
        """
        super().__init__()
        # Hyperparameters stay on the instance; save_model reads them back
        # to rebuild the same architecture when the model is loaded.
        self.gn = gn
        self.ncl = ncl
        self.num_levels = len(ncl)
        self.ed = ed
        self.nc = nc
        self.nb = nb
        self.nh = nh
        self.ff_hd = ff_hd
        self.classifier_hd = classifier_hd
        self.dropout = dropout

        # Gene embedding with chunking
        self.gene_embedding = dust.Embedding(gn, ed, nc)
        # Stack of nb transformer blocks
        encoder_blocks = []
        for _ in range(nb):
            encoder_blocks.append(dust.RibBlock(ed, nh, ff_hd, dropout))
        self.blocks = nn.ModuleList(encoder_blocks)
        # Hierarchical classification heads working on the pooled embedding
        self.classifier = scAdamClassifier(ed, ncl, classifier_hd, dropout)
        # Normalisation applied to the block output before pooling
        self.norm = nn.LayerNorm(ed)

    def forward(self, x, return_attention=False):
        """
        Classify a batch of cells.

        Parameters
        ----------
        x: tensor
            Input passed to the gene embedding.
        return_attention: bool (default: False)
            If True, also return the second output of every block, in block order.

        Returns
        -------
        The classifier output dict, or the tuple (outputs, attention_weights)
        when return_attention is True.
        """
        hidden = self.gene_embedding(x)
        attention_weights = []
        for block in self.blocks:
            hidden, attn = block(hidden)
            if return_attention:
                attention_weights.append(attn)
        # Normalise, then average over dim 1 (the chunk axis of the embedding)
        pooled = self.norm(hidden).mean(dim=1)
        outputs = self.classifier(pooled)
        return (outputs, attention_weights) if return_attention else outputs
# Adaptive Hierarchical Focal Loss
class HierarchicalLoss(nn.Module):
    """
    Adaptive hierarchical focal loss for scAdam model training.

    A focal loss is computed for every annotation level and the levels are
    summed, optionally scaled by adaptive per-level weights that are refreshed
    from an exponential moving average of the level losses.
    """

    def __init__(
        self,
        num_levels,
        celltype_keys,
        alpha=0.25,
        gamma=2.0,
        adaptive=True
    ):
        """
        Parameters
        ----------
        num_levels: int
            Number of annotation levels.
        celltype_keys: list
            Annotation keys in level order; used to look up the targets.
        alpha: float (default: 0.25)
            Constant factor applied to the focal loss.
        gamma: float (default: 2.0)
            Focusing exponent applied to (1 - pt).
        adaptive: bool (default: True)
            Scale each level by its adaptive weight and update the weights
            while the module is in training mode.
        """
        super().__init__()
        self.num_levels = num_levels
        self.celltype_keys = celltype_keys
        self.alpha = alpha
        self.gamma = gamma
        self.adaptive = adaptive
        # Per-level weights and the running average of the per-level losses
        self.register_buffer('level_weights', torch.ones(num_levels))
        self.register_buffer('level_losses', torch.zeros(num_levels))
        # Number of training-mode forward calls seen so far
        self.update_counter = 0

    def focal_loss(self, logits, targets, weight=None):
        """
        Mean focal loss: alpha * (1 - pt) ** gamma * CE, optionally class-weighted.

        pt (probability of the true class) is taken from the UNWEIGHTED cross
        entropy. Deriving it from a weighted cross entropy would give
        pt = p ** w, which is not a probability and distorts the focusing term.
        The class weight is applied to the per-sample loss afterwards.
        """
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        per_sample = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if weight is not None:
            sample_weight = weight.to(device=logits.device, dtype=logits.dtype)[targets]
            per_sample = per_sample * sample_weight
        return per_sample.mean()

    def forward(self, predictions, targets, class_weights=None):
        """
        Parameters
        ----------
        predictions: dict
            'level_<i>' -> dict with a 'logits' tensor.
        targets: dict
            Annotation key (from celltype_keys) -> tensor of class indices.
        class_weights: dict or None
            Optional 'level_<i>' -> per-class weight tensor.

        Returns
        -------
        (total_loss, level_losses) where level_losses maps 'level_<i>' to the
        unscaled focal loss of that level.
        """
        total_loss = 0.0
        level_losses = {}
        for level_idx in range(self.num_levels):
            level_key = f'level_{level_idx}'
            logits = predictions[level_key]['logits']
            target = targets[self.celltype_keys[level_idx]]
            weight = class_weights.get(level_key) if class_weights else None
            loss = self.focal_loss(logits, target, weight)
            level_losses[level_key] = loss
            if self.adaptive:
                total_loss += self.level_weights[level_idx] * loss
            else:
                total_loss += loss
        # Weights only evolve during training, never during evaluation
        if self.adaptive and self.training:
            self.update_adaptive_weights(level_losses)
        return total_loss, level_losses

    def update_adaptive_weights(self, level_losses):
        """
        Update the running loss averages and, every 10th call, the level weights.

        Note that the counter counts calls (i.e. training forward passes /
        batches), not epochs.
        """
        self.update_counter += 1
        # Exponential moving average of losses (decay 0.9)
        for level_idx in range(self.num_levels):
            current = level_losses[f'level_{level_idx}'].detach()
            self.level_losses[level_idx] = 0.9 * self.level_losses[level_idx] + 0.1 * current
        # Every 10 calls: weight = inverse of the level's share of the total loss
        if self.update_counter % 10 == 0:
            normalized_losses = self.level_losses / (self.level_losses.sum() + 1e-8)
            self.level_weights = 1.0 / (normalized_losses + 1e-8)
# Unsupervised pretraining, analogous to masked language modeling in NLP
class Pretraining(nn.Module):
    """
    Masked gene expression modeling for unsupervised pretraining.

    Analogous to masked language modeling in NLP: random entries of the input
    are zeroed, the encoder embeds the corrupted input and a decoder plus a
    linear projection reconstruct a value for every gene.
    """

    def __init__(self, encoder, gn, ed, prob = 0.15):
        """
        Parameters
        ----------
        encoder: model
            scAdam model; its gene_embedding, blocks and norm are used.
        gn: int
            Number of genes in adata.var_names (size of the reconstruction).
        ed: int
            Embedding dimensionality produced by the encoder.
        prob: float (default: 0.15)
            Probability of gene masking.
        """
        super().__init__()
        self.encoder = encoder
        self.gn = gn
        self.ed = ed
        self.prob = prob
        # Decoder: ed -> 2 * ed -> ed
        wide = ed * 2
        decoder_layers = [
            nn.Linear(ed, wide),
            nn.GELU(),
            nn.LayerNorm(wide),
            nn.Dropout(0.1),
            nn.Linear(wide, ed),
            nn.GELU(),
            nn.LayerNorm(ed),
            nn.Dropout(0.1),
        ]
        self.decoder = nn.Sequential(*decoder_layers)
        # Projection head mapping the decoded embedding back to gene space
        self.gene_projection = nn.Linear(ed, gn)

    def create_mask(self, x):
        """
        Zero out each entry of the 2-D input independently with probability ``prob``.

        Returns
        -------
        (x_masked, mask): a masked copy of x and the boolean mask (True = masked).
        """
        n_rows, n_cols = x.shape
        mask = torch.rand(n_rows, n_cols, device=x.device) < self.prob
        # masked_fill returns a new tensor, the input itself is left untouched
        return x.masked_fill(mask, 0.0), mask

    def forward(self, x):
        """
        Returns
        -------
        (reconstructed, mask, x): the reconstruction for every gene, the mask
        that was applied and the original (unmasked) input.
        """
        x_masked, mask = self.create_mask(x)
        # Encode the corrupted input with the shared encoder
        hidden = self.encoder.gene_embedding(x_masked)
        for block in self.encoder.blocks:
            hidden, _ = block(hidden)
        # Normalise and pool over dim 1
        pooled = self.encoder.norm(hidden).mean(dim=1)
        # Decode and project back to gene expression
        reconstructed = self.gene_projection(self.decoder(pooled))
        return reconstructed, mask, x
# Function for unsupervised pretraining
def pretrain_unsupervised(
    adata,
    model,
    layer = None,
    batch_size = 128,
    epochs = 50,
    lr = 1e-4,
    prob = 0.15,
    device = 'auto',
    random_state = 42,
    verbose = True
):
    """
    Unsupervised pretraining using masked gene expression modeling.

    The model is wrapped in a Pretraining module, trained to reconstruct the
    masked expression values with an MSE loss and returned. The encoder
    weights are updated in place; the decoder is discarded.

    Parameters
    ----------
    adata: AnnData object
        Could be AnnData without cell type annotations.
    model: scAdam model
    layer: str (default: None)
        If specified and present in adata.layers, use adata.layers[layer] for
        expression values instead of adata.X.
    batch_size: int (default: 128)
        Number of samples per batch during training.
    epochs: int (default: 50)
        Number of pretraining epochs.
    lr: float (default: 1e-4)
        Controls how much a model's parameters are updated during training.
    prob: float (default: 0.15)
        Probability of gene masking.
    device: str (default: 'auto')
        Type of device to use for model training ('cpu' or 'cuda'). Set to 'auto' for automatic selection.
    random_state: int (default: 42)
        Seed for random number generators to ensure reproducibility.
    verbose: bool (default: True)
        Display a progress bar with the mean loss of the last epoch.

    Returns
    -------
    Pretrained scAdam model (the same object that was passed in, moved to
    `device` and left in training mode).

    Notes
    -----
    When CUDA is available this function also sets the global flags
    torch.backends.cudnn.deterministic = True and cudnn.benchmark = False.
    """
    # Fix random state for reproducibility
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(random_state)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # Set device
    if device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Move model to device
    model = model.to(device)
    # Create pretraining wrapper (encoder = model, plus a throw-away decoder)
    pretraining_model = Pretraining(
        encoder=model,
        gn=model.gn,
        ed=model.gene_embedding.norm.normalized_shape[0],
        prob=prob
    ).to(device)
    # Get gene expression data from adata from layer or X
    if layer is not None and layer in adata.layers:
        X = adata.layers[layer]
    else:
        X = adata.X
    X = X.toarray() if hasattr(X, 'toarray') else X
    # Clip negative values whenever some gene has a non-zero minimum,
    # then min-max scale every gene (epsilon guards constant genes).
    if X.min(0).any():
        X = np.maximum(X, 0)
    X_norm = (X - X.min(0)) / (np.ptp(X, axis=0) + 1e-10)
    X_tensor = torch.FloatTensor(X_norm)
    # Create DataLoader
    dataset = torch.utils.data.TensorDataset(X_tensor)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Optimizer
    optimizer = torch.optim.AdamW(
        pretraining_model.parameters(),
        lr=lr,
        weight_decay=1e-5
    )
    # MSE loss function
    criterion = nn.MSELoss()
    # Training loop
    pretraining_model.train()
    progress = tqdm(
        range(epochs),
        desc='Unsupervised pretraining',
        colour='blue',
        disable = not verbose
    )
    for epoch in progress:
        total_loss = 0.0
        n_batches = 0
        for (batch_x, ) in loader:
            batch_x = batch_x.to(device)
            # Forward
            reconstructed, mask, original = pretraining_model(batch_x)
            # Nothing was masked in this batch: the MSE over an empty
            # selection is NaN and carries no training signal, so skip it.
            if not mask.any():
                continue
            # Compute loss only using masked genes
            loss = criterion(reconstructed[mask], original[mask])
            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(pretraining_model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # Report the mean loss of the epoch in the progress bar
        if verbose and n_batches:
            avg_loss = total_loss / n_batches
            progress.set_postfix(loss=f'{avg_loss:.4f}')
    return model
# Function to get weighted metric across all annotation levels
def weighted_metric(
    history,
    celltype_keys,
    metric,
    strategy = 'linear_offset'
):
    """
    Weighted mean of the latest validation metric over all annotation levels.

    Parameters
    ----------
    history: dict
        Learning history; the value history['val_metrics'][key][metric][-1]
        is read for every annotation key.
    celltype_keys: list
        List of cell type annotations in adata.obs, in level order.
        Example: ['lineage', 'cell type', 'cell state']
    metric: str
        Name of the metric stored in the history,
        e.g. 'accuracy', 'balanced_accuracy', 'f1_score'.
    strategy: str (default: 'linear_offset')
        Weighting strategy for different cell type annotation levels:
        linear: weights 1, 2, 3, ...
        exponential: weights 1, 2, 4, ...
        linear_offset: weights 1, 1.5, 2, ...
        equal: equal weight for all cell type annotation levels.

    Returns
    -------
    Weighted score (weights are normalised to sum to 1).

    Raises
    ------
    ValueError
        If the strategy is unknown.
    """
    num_levels = len(celltype_keys)
    # Most recent value of the metric for every level
    scores = [history['val_metrics'][key][metric][-1] for key in celltype_keys]

    # Raw (unnormalised) weights per strategy, built lazily
    weight_builders = {
        'linear': lambda n: list(range(1, n + 1)),
        'exponential': lambda n: [2**i for i in range(n)],
        'linear_offset': lambda n: [1.0 + (i / 2) for i in range(n)],
        'equal': lambda n: [1.0 for _ in range(n)],
    }
    if strategy not in weight_builders:
        raise ValueError(f"Unknown weighting strategy: {strategy}")
    weights = weight_builders[strategy](num_levels)

    # Normalise the weights, then accumulate the weighted mean
    total_weight = sum(weights)
    normalized_weights = [w / total_weight for w in weights]
    weighted_score = 0
    for score, weight in zip(scores, normalized_weights):
        weighted_score = weighted_score + score * weight
    return weighted_score
def celltype_level_weights(num_levels, strategy = 'exponential'):
    """
    Normalised weights for multiple cell type annotation levels.

    Parameters
    ----------
    num_levels: int
        Number of cell type annotation levels
    strategy: str (default: 'exponential')
        Weighting strategy for different cell type annotation levels:
        linear: linear increase in weight from level to level.
        exponential: exponential increase in weight from level to level
        linear_offset: linear increase in weight from level to level with offset.
        equal: equal weight for all cell type annotation levels.

    Returns
    -------
    list of float
        One weight per level; the weights sum to 1.

    Raises
    ------
    ValueError
        If the strategy is unknown.

    Examples
    --------
    celltype_level_weights(3, 'linear')
    [0.167, 0.333, 0.500] # 1/6, 2/6, 3/6
    celltype_level_weights(3, 'exponential')
    [0.143, 0.286, 0.571] # 1/7, 2/7, 4/7
    celltype_level_weights(3, 'linear_offset')
    [0.222, 0.333, 0.444] # (1 + 0/2)/4.5, (1 + 1/2)/4.5, (1 + 2/2)/4.5
    celltype_level_weights(3, 'equal')
    [0.333, 0.333, 0.333] # 1/3, 1/3, 1/3
    """
    levels = range(num_levels)
    # Raw weight of a single level (0-based index) for every strategy
    raw_weight = {
        'linear': lambda i: i + 1,
        'exponential': lambda i: 2**i,
        'linear_offset': lambda i: 1.0 + (i / 2),
        'equal': lambda i: 1.0,
    }
    if strategy not in raw_weight:
        raise ValueError(f"Unknown weighting strategy: {strategy}")
    weights = [raw_weight[strategy](i) for i in levels]
    total = sum(weights)
    return [w / total for w in weights]
# Save scAdam model function
def save_model(model, path, model_name, verbose=True):
    """
    Save a scAdam model into the folder ``<path>/<model_name>``.

    Two files are written:
    - 'model_v2.pth': weights, architecture hyperparameters and the optional
      metadata attributes celltype_keys, celltype_encoders, var_names, history
      (stored as None when the model does not have them);
    - 'unknown_detector.json': class centroids and stds of the unknown cell
      type detector, only when model.unknown_detector is set.

    Parameters
    ----------
    model: scAdam model
    path: str, path object
        Parent directory of the model folder.
    model_name: str
        Name of the folder to create (existing folders are reused).
    verbose: bool (default: True)
        Print the destination after saving.
    """
    # One directory path used for creation and for every file written
    model_dir = os.path.join(path, model_name)
    save_dict = {
        'state_dict': model.state_dict(),
        'gn': model.gn,
        'ncl': model.ncl,
        'ed': model.ed,
        'nc': model.nc,
        'nb': model.nb,
        'nh': model.nh,
        'ff_hd': model.ff_hd,
        'classifier_hd': model.classifier_hd,
        'dropout': model.dropout,
        'celltype_keys': getattr(model, 'celltype_keys', None),
        'celltype_encoders': getattr(model, 'celltype_encoders', None),
        'var_names': getattr(model, 'var_names', None),
        'history': getattr(model, 'history', None)
    }
    # Create folder to save model
    os.makedirs(model_dir, exist_ok = True)
    # Save scAdam model
    torch.save(save_dict, os.path.join(model_dir, 'model_v2.pth'))

    def _as_list(value):
        # Tensors and array-likes become nested lists; anything else is kept as is
        if torch.is_tensor(value):
            return value.detach().cpu().numpy().tolist()
        return value.tolist() if hasattr(value, 'tolist') else value

    def _as_float(value):
        return float(value.item()) if torch.is_tensor(value) else float(value)

    # Save unknown cell type detector
    detector = getattr(model, 'unknown_detector', None)
    if detector is not None:
        centroids = detector.class_centroids
        stds = detector.class_stds
        # JSON object keys must be strings; load_model converts them back to int
        detector_state = {
            'class_centroids': None if centroids is None else {
                str(k): _as_list(v) for k, v in centroids.items()
            },
            'class_stds': None if stds is None else {
                str(k): _as_float(v) for k, v in stds.items()
            }
        }
        # Save unknown cell type detector state
        with open(os.path.join(model_dir, 'unknown_detector.json'), 'w') as f:
            json.dump(detector_state, f, indent=2)
    if verbose:
        print(f"Model saved to {os.path.join(path, model_name)}")
# Load scAdam model function
def load_model(path, device = 'auto', verbose=True):
    """
    Load a scAdam model from the folder ``path``.

    Reads 'model_v2.pth' (weights, hyperparameters, metadata) and, when the
    file exists, 'unknown_detector.json' to rebuild the unknown cell type
    detector, which is attached as model.unknown_detector.

    Parameters
    ----------
    path: str, path object
        Folder that contains the saved model files.
    device: str (default: 'auto')
        Device to load the model on ('cpu' or 'cuda'). 'auto' selects 'cuda'
        when available and 'cpu' otherwise.
    verbose: bool (default: True)
        Print whether the model was loaded with or without the detector.

    Returns
    -------
    scAdam model in evaluation mode.
    """
    # Set device
    if device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load scAdam models checkpoint
    # NOTE(review): weights_only=False unpickles arbitrary Python objects
    # (presumably required for the stored celltype_encoders - confirm).
    # Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(
        os.path.join(path, 'model_v2.pth'),
        map_location=device,
        weights_only = False
    )

    # Create scAdam model with the architecture stored in the checkpoint
    hyperparameter_names = (
        'gn', 'ncl', 'ed', 'nc', 'nb', 'nh', 'ff_hd', 'classifier_hd', 'dropout'
    )
    model = scAdamTransformer(**{name: checkpoint[name] for name in hyperparameter_names})

    # Load state to scAdam model
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    model.eval()

    # Restore metadata
    for key in ('celltype_keys', 'celltype_encoders', 'var_names', 'history'):
        if key in checkpoint:
            setattr(model, key, checkpoint[key])

    # Without a detector file the model is returned as is
    detector_path = os.path.join(path, 'unknown_detector.json')
    if not os.path.exists(detector_path):
        if verbose:
            print(f"scAdam model without unknown detector loaded from {path}")
        return model

    # Load unknown cell type detector
    with open(detector_path, 'r') as f:
        detector_state = json.load(f)
    detector = UnknownCellDetector()
    # JSON keys are strings; class labels are converted back to int
    saved_centroids = detector_state.get('class_centroids')
    if saved_centroids is not None:
        detector.class_centroids = {
            int(label): torch.tensor(values, dtype=torch.float32)
            for label, values in saved_centroids.items()
        }
    saved_stds = detector_state.get('class_stds')
    if saved_stds is not None:
        detector.class_stds = {int(label): value for label, value in saved_stds.items()}
    model.unknown_detector = detector
    if verbose:
        print(f"scAdam model with unknown detector loaded from {path}")
    return model
# Function for calculation of otsu threshold
def otsu_threshold(data, threshold = 2.0, nbins = 256, random_state = 1):
    """
    Analyzes the distribution and returns the Otsu threshold only if the data is bimodal.

    Bimodality is decided in two steps: a 2-component Gaussian mixture must
    not have a higher BIC than a 1-component one, and the two components must
    be separated by an Ashman's D of at least `threshold`.

    Parameters
    ----------
    data : array-like
        An array of values (eg. gene expression). It is flattened; non-finite
        values (NaN, +/-inf) are ignored.
    threshold: float (default = 2.0)
        Ashman coefficient threshold. Threshold > 2 indicates good peak separation.
        For a more subtle separation, threshold > 1.5 can be used.
    nbins: int (default = 256)
        Number of histogram bins passed to skimage's threshold_otsu.
    random_state: int (default = 1)
        Seed of the Gaussian mixture models.

    Returns
    -------
    float or None
        Threshold (float) if the distribution is bimodal.
        None if the distribution is considered unimodal or fewer than two
        finite values are available.
    """
    values = np.asarray(data).ravel()
    # Non-finite values would make the mixture fit raise; they carry no
    # information about where the two modes are.
    values = values[np.isfinite(values)]
    # A two-component mixture needs at least two samples
    if values.size < 2:
        return None
    column = values.reshape(-1, 1)
    # Calculate Gaussian models with 1 and 2 components
    gmm1 = GaussianMixture(n_components=1, random_state=random_state).fit(column)
    gmm2 = GaussianMixture(n_components=2, random_state=random_state).fit(column)
    # Compare models with 1 and 2 components using BIC (lower is better)
    if gmm1.bic(column) < gmm2.bic(column):
        return None
    # Checking the separability of peaks (Ashman's D)
    # Useful in case of skewed distribution
    means = gmm2.means_.flatten()
    variances = gmm2.covariances_.flatten()
    # Sort the components by mean
    order = np.argsort(means)
    mu1, mu2 = means[order]
    var1, var2 = variances[order]
    # Ashman's D statistic
    # D = sqrt(2) * |mu1 - mu2| / sqrt(sigma1^2 + sigma2^2)
    # If the peaks are too close, most likely it is one "wide" distribution
    d_score = np.sqrt(2) * abs(mu1 - mu2) / np.sqrt(var1 + var2)
    if d_score < threshold:
        return None
    # If the checks are passed, we calculate the Otsu threshold
    return threshold_otsu(values, nbins=nbins)
class UnknownCellDetector:
    """
    Detector of cells whose cell type is unknown to a scAdam model.

    Three scores can be computed for the last annotation level: the gradient
    norm of the prediction loss w.r.t. the cell embedding, the predictive
    entropy and the distance of the embedding to the centroid of the predicted
    class. Thresholds for the scores are derived with otsu_threshold.
    """

    # Values accepted by the `method` argument of detect()
    _METHODS = ('gradient', 'entropy', 'distance', 'voting', 'combined')

    def __init__(
        self
    ):
        """
        UnknownCellDetector for detection cells with unknown for scAdam model cell type in data.
        """
        # Filled by fit() (or restored by load_model)
        self.class_centroids = None
        self.class_stds = None
        # Filled by detect(); None means "not computed / distribution not bimodal"
        self.gradient_threshold = None
        self.entropy_threshold = None
        self.distance_threshold = None

    @staticmethod
    def _resolve_device(device):
        """Translate 'auto' into 'cuda' when available and 'cpu' otherwise."""
        if device == 'auto':
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return device

    @staticmethod
    def _embed(model, batch_x):
        """Pooled encoder output for a batch (embedding -> blocks -> norm -> mean over dim 1)."""
        with torch.no_grad():
            x_emb = model.gene_embedding(batch_x)
            for block in model.blocks:
                x_emb, _ = block(x_emb)
            return model.norm(x_emb).mean(dim=1)

    def fit(
        self,
        model,
        dataloader,
        device='cuda',
        verbose=True
    ):
        """
        Fitting of UnknownCellDetector.

        Computes, for every class of the last annotation level, the centroid
        of the cell embeddings and the mean per-dimension standard deviation.

        Parameters
        ----------
        model: scAdam model.
        dataloader: full training dataset dataloader yielding (x, labels dict).
        device: str (default: 'cuda')
            Type of device to use ('cpu', 'cuda'). Set 'auto' for automatic selection.
        verbose: bool (default: True)
            Show UnknownCellDetector progress or not.
        """
        device = self._resolve_device(device)
        # Create class_centroids and class_stds dictionaries for distance calculation.
        self.class_centroids = {}
        self.class_stds = {}
        model.eval()
        model = model.to(device)
        if verbose:
            print("Fitting unknown cells detector")
        # Only the labels of the last (most detailed) level are needed
        last_key = model.celltype_keys[-1]
        embeddings = []
        labels = []
        for batch_x, batch_y in dataloader:
            embeddings.append(self._embed(model, batch_x.to(device)).cpu())
            labels.append(batch_y[last_key])
        embeddings = torch.cat(embeddings, dim=0)
        labels = torch.cat(labels, dim=0)
        for label in labels.unique():
            class_emb = embeddings[labels == label]
            self.class_centroids[label.item()] = class_emb.mean(dim=0)
            # NOTE: std is NaN for a class with a single cell
            self.class_stds[label.item()] = class_emb.std(dim=0).mean()

    def _gradient_norms(self, model, x_emb, last_level, use_mc):
        """
        Norm of the gradient of -log(max prob) w.r.t. the embedding, per cell.

        With MC-dropout the loss is summed over mc_passes stochastic passes of
        the classifier (model in train mode); otherwise one eval-mode pass is used.
        """
        x_in = x_emb.clone().detach().requires_grad_(True)
        was_training = model.training
        n_passes = self.mc_passes if use_mc else 1
        model.train(use_mc)
        try:
            # enable_grad makes this work even when the caller is inside no_grad
            with torch.enable_grad():
                loss_total = 0.0
                for _ in range(n_passes):
                    logits = model.classifier(x_in)[last_level]['logits']
                    probs = F.softmax(logits / self.temperature, dim=-1)
                    max_prob, _ = probs.max(dim=1)
                    loss_total = loss_total + (-torch.log(max_prob + 1e-10)).sum()
                # autograd.grad returns the input gradient without writing
                # .grad on the model parameters
                (grad,) = torch.autograd.grad(loss_total, x_in)
        finally:
            model.train(was_training)
        return torch.norm(grad, dim=1).detach().cpu()

    def _entropy(self, model, batch_x, probs_eval, last_level, use_mc):
        """Entropy of the (temperature scaled, optionally MC-dropout averaged) probabilities."""
        with torch.no_grad():
            if use_mc:
                was_training = model.training
                model.train()
                try:
                    probs_mc = []
                    for _ in range(self.mc_passes):
                        logits = model(batch_x)[last_level]['logits']
                        probs_mc.append(F.softmax(logits / self.temperature, dim=-1))
                finally:
                    model.train(was_training)
                probs_mean = torch.stack(probs_mc, dim=0).mean(dim=0)
            else:
                probs_mean = probs_eval
            entropy = -(probs_mean * torch.log(probs_mean + 1e-10)).sum(dim=1)
        return entropy.cpu()

    @staticmethod
    def _centroid_distances(x_emb, preds, centroids):
        """Euclidean distance of every embedding to the centroid of its predicted class."""
        distances = []
        for emb, pred_label in zip(x_emb, preds.tolist()):
            centroid = centroids.get(pred_label)
            if centroid is None:
                # Predicted class was never seen by fit(): infinitely far away
                distances.append(torch.tensor(float('inf')))
            else:
                distances.append(torch.norm(emb - centroid).detach().cpu())
        return torch.stack(distances)

    def detect(
        self,
        model,
        dataloader,
        threshold = 2.0,
        temperature = 2.0,
        mc_passes = 5,
        method = 'voting',
        device = 'cuda',
        verbose = True,
        random_state = 1
    ):
        """
        Detect cells with unknown cell type.

        Parameters
        ----------
        model: scAdam model.
        dataloader: dataloader yielding either x or a list/tuple whose first item is x.
        threshold: float (default: 2.0)
            Ashman's D threshold passed to otsu_threshold.
        temperature: float (default: 2.0)
            A parameter that controls the confidence of the model in its predictions.
            If temperature > 1 - model appears less confident.
            If temperature < 1 - model appears more confident.
        mc_passes: int or None (default: 5)
            Number of MC-dropout passes; None or <= 1 disables MC-dropout.
        method: str (default: 'voting')
            'gradient', 'entropy', 'distance': use a single score.
            'voting': a cell is unknown when at least two scores exceed their thresholds.
            'combined': compare the sum of max-scaled scores with the sum of scaled thresholds.
        device: str (default: 'cuda')
            Type of device to use ('cpu', 'cuda'). Set 'auto' for automatic selection.
        verbose: bool (default: True)
            Print thresholds and detection summary.
        random_state: int (default: 1)
            Seed passed to otsu_threshold.

        Returns
        -------
        (unknown_mask, scores)
            unknown_mask is a boolean numpy array, or None when a threshold
            required by `method` could not be determined (score unavailable or
            its distribution not bimodal).
            scores is a dict with keys 'gradient', 'entropy', 'distance'
            (numpy arrays or None) and 'predictions'.

        Raises
        ------
        ValueError
            If `method` is not one of the supported methods.
        """
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown detection method: {method}. "
                f"Available methods: {', '.join(self._METHODS)}"
            )
        device = self._resolve_device(device)
        self.temperature = temperature
        self.mc_passes = mc_passes
        # Thresholds of a previous call must never leak into this one
        self.gradient_threshold = None
        self.entropy_threshold = None
        self.distance_threshold = None

        use_gradient = method in ('gradient', 'voting', 'combined')
        use_entropy = method in ('entropy', 'voting', 'combined')
        use_distance = method in ('distance', 'voting', 'combined')
        use_mc = self.mc_passes is not None and self.mc_passes > 1

        model.eval()
        model = model.to(device)
        last_level = f'level_{model.num_levels - 1}'

        # Move the centroids to the device once instead of once per cell
        centroids = None
        if use_distance and self.class_centroids:
            centroids = {
                label: centroid.to(device)
                for label, centroid in self.class_centroids.items()
            }

        all_gradient_norms = []
        all_entropy_scores = []
        all_dist_scores = []
        all_predictions = []
        for batch in dataloader:
            batch_x = batch[0] if isinstance(batch, (list, tuple)) else batch
            batch_x = batch_x.to(device)
            # Embeddings and deterministic (eval mode) predictions
            x_emb = self._embed(model, batch_x)
            with torch.no_grad():
                logits_eval = model.classifier(x_emb)[last_level]['logits']
                probs_eval = F.softmax(logits_eval / self.temperature, dim=-1)
                preds = logits_eval.argmax(dim=1)
            all_predictions.append(preds.detach().cpu())
            # Gradient norm (with MC-dropout)
            if use_gradient:
                all_gradient_norms.append(
                    self._gradient_norms(model, x_emb, last_level, use_mc)
                )
            # Entropy (temperature + MC-dropout)
            if use_entropy:
                all_entropy_scores.append(
                    self._entropy(model, batch_x, probs_eval, last_level, use_mc)
                )
            # Distance to the centroid of the predicted class
            if centroids is not None:
                all_dist_scores.append(self._centroid_distances(x_emb, preds, centroids))

        # Concatenate results + convert to numpy
        all_predictions = torch.cat(all_predictions).numpy()
        all_gradient_norms = torch.cat(all_gradient_norms).numpy() if all_gradient_norms else None
        all_entropy_scores = torch.cat(all_entropy_scores).numpy() if all_entropy_scores else None
        all_dist_scores = torch.cat(all_dist_scores).numpy() if all_dist_scores else None
        scores = {
            'gradient': all_gradient_norms,
            'entropy': all_entropy_scores,
            'distance': all_dist_scores,
            'predictions': all_predictions
        }

        # Calculating thresholds using Otsu method
        if all_gradient_norms is not None:
            self.gradient_threshold = otsu_threshold(all_gradient_norms, threshold = threshold, nbins=len(all_gradient_norms), random_state=random_state)
            if verbose:
                print(f'Gradient threshold calculated using Otsu method: {self.gradient_threshold}')
        if all_entropy_scores is not None:
            self.entropy_threshold = otsu_threshold(all_entropy_scores, threshold = threshold, nbins=len(all_entropy_scores), random_state=random_state)
            if verbose:
                print(f'Entropy threshold calculated using Otsu method: {self.entropy_threshold}')
        if all_dist_scores is not None:
            self.distance_threshold = otsu_threshold(all_dist_scores, threshold = threshold, nbins=len(all_dist_scores), random_state=random_state)
            if verbose:
                print(f'Distance threshold calculated using Otsu method: {self.distance_threshold}')

        # Check thresholds: every score used by the method needs a threshold
        required = []
        if use_entropy:
            required.append(self.entropy_threshold)
        if use_gradient:
            required.append(self.gradient_threshold)
        if use_distance:
            required.append(self.distance_threshold)
        if any(value is None for value in required):
            return None, scores

        # Method selection
        if method == 'gradient':
            unknown_mask = all_gradient_norms > self.gradient_threshold
        elif method == 'entropy':
            unknown_mask = all_entropy_scores > self.entropy_threshold
        elif method == 'distance':
            unknown_mask = all_dist_scores > self.distance_threshold
        elif method == 'voting':
            votes = np.zeros(len(all_predictions), dtype=int)
            num_methods = 0
            if all_gradient_norms is not None:
                votes += (all_gradient_norms > self.gradient_threshold).astype(int)
                num_methods += 1
            if all_entropy_scores is not None:
                votes += (all_entropy_scores > self.entropy_threshold).astype(int)
                num_methods += 1
            if all_dist_scores is not None:
                votes += (all_dist_scores > self.distance_threshold).astype(int)
                num_methods += 1
            threshold_votes = 2 if num_methods >= 2 else 1
            unknown_mask = votes >= threshold_votes
        else:  # method == 'combined'
            # Scale every score (and its threshold) by the maximum score
            g_denom = np.max(all_gradient_norms) + 1e-12
            e_denom = np.max(all_entropy_scores) + 1e-12
            d_denom = np.max(all_dist_scores) + 1e-12
            gradient_threshold = self.gradient_threshold / g_denom
            entropy_threshold = self.entropy_threshold / e_denom
            distance_threshold = self.distance_threshold / d_denom
            sum_of_thresholds = distance_threshold + entropy_threshold + gradient_threshold
            sum_of_scores = (
                all_gradient_norms / g_denom
                + all_entropy_scores / e_denom
                + all_dist_scores / d_denom
            )
            unknown_mask = sum_of_scores >= sum_of_thresholds

        if verbose:
            if method == 'voting':
                print(f"Unknown cell type detection using {num_methods} methods. Threshold: {threshold_votes} votes.")
            elif method == 'combined':
                print("Unknown cell type detection using combination of methods.")
            else:
                print(f"Unknown cell type detection using {method}.")
            print(f"Detected {unknown_mask.sum()} unknown cells ({100*unknown_mask.mean():.1f}%)")
        return unknown_mask, scores
# Function for training scAdam model
def train(
    adata,
    celltype_keys,
    layer = None,
    path = '',
    model_name = 'scAdam_model',
    test_size = 0.2,
    eval_metric = ['accuracy', 'balanced_accuracy'],
    strategy = 'linear_offset',
    batch_size = 128,
    epochs = 200,
    patience = 10,
    nc = 16,
    nb = 5,
    nh = 8,
    ff_hd = 512,
    ed_nh_ratio = 32,
    classifier_hd = 256,
    dropout = 0.3,
    lr = 1e-4,
    weight_decay = 1e-4,
    use_augmentation = True,
    aug_probability = 0.5,
    prob = 0.15,
    noise_std = 0.1,
    dropout_aug = 0.1,
    alpha = 0.2,
    from_unsupervised = True,
    pretrain_epochs = 50,
    pretrain_data = None,
    adaptive_loss = True,
    device = 'auto',
    random_state = 0,
    return_model = False,
    unknown_detection = True,
    verbose = True
):
    """
    scAdam model training with optional unsupervised pretraining.
    scAdam is a model for automatic cell type annotation.
    This function trains a model, optionally fits an unknown cell detector on it,
    and saves the result so it can be used for cell type annotation.

    Parameters
    ----------
    adata: AnnData
        Dataset with cell type annotations in adata.obs
    celltype_keys: list
        List of cell type annotations in adata.obs.
        Example: ['lineage', 'cell type', 'cell state']
    layer: str (default: None)
        If specified, use adata.layers[layer] for expression values instead of adata.X.
    path: str, path object
        Path to create a model folder containing the training history, cell annotation dictionary, and genes used for training.
    model_name: str (default: 'scAdam_model')
        Name of a folder to save model.
    test_size: float or int, (default: 0.2)
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split. If int, represents the
        absolute number of test cells.
    eval_metric: str or list (default: ['accuracy', 'balanced_accuracy'])
        Available evaluation metrics: 'accuracy', 'balanced_accuracy', 'f1_score'.
        The last metric is used as the target and for early stopping.
    strategy: str (default: 'linear_offset')
        Weighting strategy for different cell type annotation levels.
        The following weighting strategies are available: linear, exponential, linear_offset, equal, last.
        linear: linear increase in weight from level to level.
        exponential: exponential increase in weight from level to level
        linear_offset: linear increase in weight from level to level with offset.
        equal: equal weight for all cell type annotation levels.
        last: uses only last cell type annotation for model evaluation.
    batch_size: int, (default: 128)
        Number of examples per batch.
    epochs: int (default: 200)
        Maximum number of epochs for scAdam model training
    patience: int (default: 10)
        Number of consecutive epochs without improvement before performing early stopping.
        If patience is set to 0, then no early stopping will be performed.
        Note that if patience is enabled, then best weights from best epoch will automatically be loaded at the end of the training.
    nc: int (default: 16)
        Number of chunks for genes from adata.
    nb: int (default: 5)
        Number of blocks in scAdam model.
    nh: int (default: 8)
        Number of heads in scAdam model attention mechanism.
    ff_hd: int (default: 512)
        Number of nodes in each scAdam model layer in feed forward network.
    ed_nh_ratio: int (default: 32)
        Used for calculating embedding dimensionality ('ed') from 'nh'.
        Default ed = nh * ed_nh_ratio = 8 * 32 = 256.
    classifier_hd: int (default: 256)
        Number of nodes in each scAdam classifier.
    dropout: float (default: 0.3)
        Portion of neurons that temporarily ignored during training (prevents overfitting).
    lr: float (default: 1e-4)
        Determines the step size at each iteration while moving toward a minimum of a loss function.
    weight_decay: float (default: 1e-4)
        Weight decay coefficient.
    use_augmentation: bool (default: True)
        Use data augmentation or not.
    aug_probability: float (default: 0.5)
        The probability of applying augmentation to a batch
    prob: float (default: 0.15)
        Gene masking probability. Passed both to the unsupervised pretraining
        step and to the augmenter.
    noise_std: float (default: 0.1)
        Gaussian noise standard deviation.
    dropout_aug: float (default: 0.1)
        Dropout probability for simulating technical noise.
    alpha: float (default: 0.2)
        Alpha parameter for mixup augmentation.
    from_unsupervised: bool (default: True)
        Run self supervised pretraining first and use the result as starting weights.
        Supervised model training included in function.
    pretrain_epochs: int (default: 50)
        Number of pretraining epochs.
    pretrain_data: AnnData
        Anndata with the same expression matrix for unsupervised pretraining.
        Additional Anndata may not contain cell type annotations.
        If None, 'adata' itself is used for pretraining.
    adaptive_loss: bool (default: True)
        If True, enables adaptive weighting of hierarchical loss levels:
        each level's loss is tracked over training and its contribution to
        the total loss is increased if this level remains hard (high loss)
        and decreased if it becomes easy (low loss). This helps the model
        focus more on poorly performing levels of the hierarchy instead of
        weighting all levels equally. If False, all levels are summed with
        a fixed weight of 1.0.
    device: str (default: 'auto')
        Type of device to use in training model ('cpu', 'cuda'). Set 'auto' for automatic selection.
    random_state: int (default: 0)
        Controls the data shuffling, splitting to folds and model training.
        Pass an int for reproducible output across multiple function calls.
    return_model: bool (default: False)
        Return model after training or not.
    unknown_detection: bool (default: True)
        Train unknown cell detector - identifies unknown cells when model used for prediction on a new data.
        The detector is fitted regardless of 'verbose'.
    verbose: bool (default: True)
        Show progress bar for each epoch during training.

    Returns
    -------
    Saves scAdam model for cell type annotation.
    Returns the trained model if return_model is True, otherwise None.

    Raises
    ------
    ValueError
        If 'ed' is not divisible by 'nh', or if an unknown evaluation metric is requested.
    """
    # Check 'ed' - 'nh' compatibility
    ed = nh * ed_nh_ratio
    if ed % nh != 0:
        raise ValueError(f"Incompatible parameters: 'ed' must be divisible by 'nh' without a remainder.")

    # Get evaluation metrics (copy so the shared default list is never touched)
    if isinstance(eval_metric, str):
        eval_metric = [eval_metric]
    else:
        eval_metric = list(eval_metric)

    # Available metric functions
    available_metric_functions = {
        'accuracy': accuracy_score,
        'balanced_accuracy': balanced_accuracy_score,
        'f1_score': lambda y_true, y_pred: f1_score(y_true, y_pred, average='weighted', zero_division=0)
    }
    # Validate requested metrics up front so a typo fails before pretraining starts
    metric_functions = {}
    for metric in eval_metric:
        if metric not in available_metric_functions:
            raise ValueError(f"Unknown metric: {metric}")
        metric_functions[metric] = available_metric_functions[metric]

    # Set random state (for reproducibility)
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(random_state)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Device selection
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device=='auto' else device
    if verbose:
        print(f"Device: {device}")

    # Create dataset
    dataset = scRNAseqDataset(
        adata,
        celltype_keys,
        layer=layer
    )
    ncl = [dataset.celltype_encoders[key]['n_classes'] for key in celltype_keys]

    if verbose:
        # With strategy 'last' only the final annotation level is evaluated,
        # so report it as the single level carrying weight.
        if strategy == 'last':
            level_weights = [0.0] * (len(celltype_keys) - 1) + [1.0]
        else:
            level_weights = celltype_level_weights(len(celltype_keys), strategy)
        print(f"Number of features: {dataset.gn}")
        print(f"Label hierarchy: {' → '.join(celltype_keys)}")
        print(f"Annotation levels weights using strategy '{strategy}':")
        for key, n_celltypes, weight in zip(celltype_keys, ncl, level_weights):
            print(f" {key}: {n_celltypes} cell types, {round(weight, 3)} relative weight")

    # Train/val split
    indices = np.arange(len(dataset))
    train_idx, val_idx = train_test_split(
        indices,
        test_size=test_size,
        stratify=adata.obs[celltype_keys[-1]],  # Using last (most detailed) annotation level
        random_state=random_state
    )
    if verbose:
        print(f"\nDataset split:")
        print(f'Train dataset contains: {len(train_idx)} cells, it is {round(100*(len(train_idx)/(len(train_idx) + len(val_idx))), ndigits=2)} % of input dataset')
        print(f'Validation dataset contains: {len(val_idx)} cells, it is {round(100*(len(val_idx)/(len(train_idx) + len(val_idx))), ndigits=2)} % of input dataset')

    # Create dataloaders
    train_loader = DataLoader(
        Subset(dataset, train_idx),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=False
    )
    val_loader = DataLoader(
        Subset(dataset, val_idx),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        drop_last=False
    )

    # Initialize model
    model = scAdamTransformer(
        gn=dataset.gn,
        ncl=ncl,
        ed=ed,
        nc=nc,
        nb=nb,
        nh=nh,
        ff_hd=ff_hd,
        classifier_hd=classifier_hd,
        dropout=dropout
    ).to(device)

    # Unsupervised pretraining
    if from_unsupervised:
        pretrain_adata = pretrain_data if pretrain_data is not None else adata
        model = pretrain_unsupervised(
            pretrain_adata,
            model,
            layer=layer,
            batch_size=batch_size,
            epochs=pretrain_epochs,
            lr=lr,
            prob=prob,
            device=device,
            random_state=random_state,
            verbose=verbose
        )

    # Augmentation
    augmenter = None
    if use_augmentation:
        augmenter = dust.Augmentation(
            prob = prob,
            noise_std = noise_std,
            dropout_prob = dropout_aug,
            alpha = alpha
        )

    # Loss function
    criterion = HierarchicalLoss(
        num_levels=len(celltype_keys),
        celltype_keys=celltype_keys,
        alpha=0.25,
        gamma=2.0,
        adaptive=adaptive_loss
    ).to(device)

    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=weight_decay
    )

    # Learning rate scheduler (cosine decay over the epoch budget).
    # max(..., 1) keeps the lambda defined when epochs == 0, matching the tuning code.
    def lr_lambda(epoch):
        return 0.5 * (1.0 + np.cos(np.pi * (epoch / max(epochs, 1))))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Early stopping
    early_stopping = dust.EarlyStopping(
        patience=patience,
        delta=1e-4,
        mode='max',
        verbose=verbose
    )

    # Create training history dictionary
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_metrics': {key: {metric: [] for metric in eval_metric}
                          for key in celltype_keys},
        'val_metrics': {key: {metric: [] for metric in eval_metric}
                        for key in celltype_keys},
        'lr': []
    }

    # Training Loop
    for epoch in tqdm(range(epochs), desc='Training scAdam model', colour='blue', disable = not verbose):
        # Model training.
        # The validation phase below switches to eval mode, so training mode
        # (dropout etc.) must be re-enabled at the start of every epoch.
        model.train()
        train_loss = 0.0
        train_preds = {key: [] for key in celltype_keys}
        train_targets = {key: [] for key in celltype_keys}
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = {key: val.to(device) for key, val in batch_y.items()}
            # Augmentation only if random < aug_probability
            if use_augmentation and np.random.random() < aug_probability:
                batch_x = augmenter(batch_x)
            # Get predictions using model
            predictions = model(batch_x)
            # Calculate model loss
            loss, _ = criterion(predictions, batch_y)
            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
            # Predictions and targets
            with torch.no_grad():
                for level_idx, key in enumerate(celltype_keys):
                    level_key = f'level_{level_idx}'
                    preds = predictions[level_key]['logits'].argmax(dim=1)
                    train_preds[key].append(preds.cpu())
                    train_targets[key].append(batch_y[key].cpu())
        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)

        # Compute training metrics
        for key in celltype_keys:
            train_preds[key] = torch.cat(train_preds[key]).numpy()
            train_targets[key] = torch.cat(train_targets[key]).numpy()
            for metric_name, metric_func in metric_functions.items():
                score = metric_func(train_targets[key], train_preds[key])
                history['train_metrics'][key][metric_name].append(score)

        # Model validation
        model.eval()
        val_loss = 0.0
        val_preds = {key: [] for key in celltype_keys}
        val_targets = {key: [] for key in celltype_keys}
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = {key: val.to(device) for key, val in batch_y.items()}
                # Get predictions using model
                predictions = model(batch_x)
                # Calculate model loss
                loss, _ = criterion(predictions, batch_y)
                val_loss += loss.item()
                # Store predictions and targets
                for level_idx, key in enumerate(celltype_keys):
                    level_key = f'level_{level_idx}'
                    preds = predictions[level_key]['logits'].argmax(dim=1)
                    val_preds[key].append(preds.cpu())
                    val_targets[key].append(batch_y[key].cpu())
        val_loss /= len(val_loader)
        history['val_loss'].append(val_loss)

        # Validation metrics
        for key in celltype_keys:
            val_preds[key] = torch.cat(val_preds[key]).numpy()
            val_targets[key] = torch.cat(val_targets[key]).numpy()
            for metric_name, metric_func in metric_functions.items():
                score = metric_func(val_targets[key], val_preds[key])
                history['val_metrics'][key][metric_name].append(score)

        # Learning rate
        history["lr"].append(optimizer.param_groups[0]["lr"])

        # Early Stopping: the last requested metric is the target
        if strategy == 'last':
            # eval_metric is a list here; index the history with the metric name
            es_score = history['val_metrics'][celltype_keys[-1]][eval_metric[-1]][-1]
        else:
            es_score = weighted_metric(
                history=history,
                celltype_keys=celltype_keys,
                metric=eval_metric[-1],
                strategy=strategy
            )
        if early_stopping(es_score, model):
            break
        scheduler.step()

    # Load best model
    early_stopping.load_bm(model)
    model.eval()

    # Store metadata
    model.celltype_keys = celltype_keys
    model.celltype_encoders = dataset.celltype_encoders
    model.var_names = dataset.var_names
    model.history = history
    if verbose:
        print("Training completed!")

    # Remove loaders
    del train_loader, val_loader

    # Train unknown cell type detector (independent of verbosity)
    if unknown_detection:
        # Create detector
        detector = UnknownCellDetector()
        # Create full dataset loader for detector training
        full_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=False
        )
        # Fit detector on full dataset
        detector.fit(model, full_loader, device=device)
        # Save detector
        model.unknown_detector = detector
        if verbose:
            print("UnknownCellDetector fitted successfully!")

    save_model(model, path, model_name, verbose)
    if return_model:
        return model
# Predict cell type annotation using trained scAdam model
def _align_expression_to_model(adata, layer, model_var_names):
    """
    Build a dense expression matrix whose columns follow the model's gene order.

    Genes the model was trained on but that are absent from adata stay at zero.
    When a gene name occurs more than once in adata.var_names, the first
    occurrence is used.

    Parameters
    ----------
    adata: AnnData
        Data to read expression values and gene names from.
    layer: str or None
        If given and present in adata.layers, that layer is used; otherwise adata.X.
    model_var_names: list
        Gene names in the order used for model training.

    Returns
    -------
    X_aligned: np.ndarray
        float32 array of shape (n_cells, len(model_var_names)).
    matched_features: int
        Number of model genes found in adata.
    n_new_genes: int
        Number of genes in adata.
    """
    # Get gene expression data
    if layer is not None and layer in adata.layers:
        X = adata.layers[layer]
    else:
        if layer is not None:
            warnings.warn(f"Layer '{layer}' not found in adata.layers. Using adata.X instead.")
        X = adata.X
    # Convert sparse to dense if needed
    if hasattr(X, 'toarray'):
        X = X.toarray()
    new_var_names = adata.var_names.tolist()
    # One-pass name -> column lookup; setdefault keeps the first occurrence
    column_of = {}
    for col, gene in enumerate(new_var_names):
        column_of.setdefault(gene, col)
    X_aligned = np.zeros((X.shape[0], len(model_var_names)), dtype=np.float32)
    matched_features = 0
    for i, gene in enumerate(model_var_names):
        col = column_of.get(gene)
        if col is not None:
            X_aligned[:, i] = X[:, col]
            matched_features += 1
    return X_aligned, matched_features, len(new_var_names)


def predict(
    adata,
    path_model,
    layer = None,
    batch_size = 256,
    device = 'auto',
    prefix = 'pred_',
    detect_unknown = False,
    method = 'voting',
    threshold = 2.0,
    temperature = 2.0,
    mc_passes = 5,
    verbose = True
):
    """
    Predict cell types and cell type probabilities using pretrained scAdam model.
    It is also possible to check if a cell has an unknown cell type for the scAdam model (detect_unknown = True).

    If the model folder contains 'model_v2.pth' the transformer model is used;
    otherwise the folder is treated as a TabNet-based v1 model
    ('genes.csv', 'dict.txt' and a '.zip' model file).

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix. It is modified in place and also returned.
    path_model: str, path object
        Path to the folder containing the trained scAdam model.
    layer: str (default: None)
        If specified, use adata.layers[layer] for expression values instead of adata.X.
        If the layer is not present in adata.layers, adata.X is used.
    batch_size: int, (default: 256)
        Number of examples per batch.
    device: str (default: 'auto')
        Type of device to use in prediction ('cpu', 'cuda'). Set 'auto' for automatic selection.
    prefix: str (default: 'pred_')
        Prefix for new columns in adata.obs
    detect_unknown: bool (default: False)
        Detection of unknown cell types.
        Parameters 'method', 'temperature', 'mc_passes', 'threshold' are not used if 'detect_unknown' = False.
        Not used by the TabNet-based v1 branch.
    method: str (default: 'voting')
        Method of unknown cell type detection.
        gradient - uses the norm of the gradient with respect to the last annotation level embedding as an uncertainty score.
        entropy - uses the entropy of the softmax probabilities on the last annotation level.
        distance - uses the distance in embedding space to the centroid of the predicted class.
        voting - combines the three previous criteria with majority voting.
        combined - builds one continuous combined score from all three individual methods with combined threshold.
    threshold: float (default = 2.0)
        Ashman coefficient threshold. Threshold > 2 indicates good peak separation.
        For a more subtle separation, threshold > 1.5 can be used.
    temperature: float (default: 2.0)
        A parameter that controls the confidence of the model in its predictions.
        If temperature > 1 - model appears less confident.
        If temperature < 1 - model appears more confident.
    mc_passes: int (default: 5)
        It is the number of forward passes you run with dropout left on for the same data (Monte Carlo dropout).
        Necessary for calculating the gradient score.
    verbose: bool (default: True)
        Show progress bar for each batch during prediction.

    Returns
    -------
    adata: AnnData
        Annotated adata object with predicted cell types in adata.obs.
        if detect_unknown = True, adds methods scores in adata.obs.

    Raises
    ------
    ValueError
        If 'method' is unknown, if the loaded model lacks 'var_names' or
        'celltype_keys', or if detect_unknown is True but the model has no
        fitted unknown cell detector.
    FileNotFoundError
        If a v1 model folder contains no '.zip' model file.
    """
    # Check method of unknown cell type detection
    if method not in ['voting', 'gradient', 'entropy', 'distance', 'combined']:
        raise ValueError(f"Unknown method: {method}")

    # Load scAdam model
    if os.path.exists(os.path.join(path_model, 'model_v2.pth')):
        model = load_model(path_model, device = 'auto', verbose=verbose)
        # Device selection
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device=='auto' else device
        model = model.to(device)
        model.eval()
        if not hasattr(model, 'var_names') or model.var_names is None:
            raise ValueError("Model does not have var_names attribute. Train the model first.")
        if not hasattr(model, 'celltype_keys') or model.celltype_keys is None:
            raise ValueError("Model does not have celltype_keys attribute. Train the model first.")
        # Fail early, before running inference, if detection was requested
        # but the model carries no fitted detector.
        detector = getattr(model, 'unknown_detector', None)
        if detect_unknown and detector is None:
            raise ValueError(
                "Model does not have unknown_detector attribute. "
                "Train the model with unknown_detection=True or set detect_unknown=False."
            )

        # Align features: create matrix with genes in the same order as training
        model_var_names = model.var_names
        n_model_genes = len(model_var_names)
        X_aligned, matched_features, _ = _align_expression_to_model(adata, layer, model_var_names)
        if verbose:
            print(f"Gene alignment:")
            print(f" Model features: {n_model_genes}")
            print(f" Matched features: {matched_features} ({100*matched_features/n_model_genes:.1f}%)")
        # Raise warning if matched features is lower than 80% from used for training
        if matched_features < 0.80 * n_model_genes:
            warnings.warn(
                f"Only {matched_features}/{n_model_genes} genes matched! "
                "This may lead to poor predictions. "
                "Make sure the gene names are in the same format (e.g., gene symbols)."
            )

        # Per-gene min-max scaling after clipping negatives
        # (intended to mirror the preprocessing applied during training)
        X_aligned = np.maximum(X_aligned, 0)  # Remove negative values
        X_norm = (X_aligned - X_aligned.min(0)) / (np.ptp(X_aligned, axis=0) + 1e-10)
        # Check for NaN/Inf
        if np.isnan(X_norm).any() or np.isinf(X_norm).any():
            warnings.warn("NaN or Inf found in normalized data. Replacing with zeros.")
            X_norm = np.nan_to_num(X_norm, nan=0.0, posinf=0.0, neginf=0.0)
        X_tensor = torch.FloatTensor(X_norm)

        # Create dataloader
        dataset = torch.utils.data.TensorDataset(X_tensor)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        # Predict
        predictions = {key: [] for key in model.celltype_keys}
        probabilities = {key: [] for key in model.celltype_keys}
        with torch.no_grad():
            for (batch_x,) in tqdm(loader, desc='Predicting', colour='blue', disable=not verbose):
                batch_x = batch_x.to(device)
                outputs = model(batch_x)
                for level_idx, key in enumerate(model.celltype_keys):
                    level_key = f'level_{level_idx}'
                    preds = outputs[level_key]['logits'].argmax(dim=1).cpu().numpy()
                    predictions[key].append(preds)
                    probs = outputs[level_key]['probs'].cpu().numpy()
                    probabilities[key].append(probs)

        # Concatenate predictions
        for key in model.celltype_keys:
            predictions[key] = np.concatenate(predictions[key])
            # Decode labels back to original names
            encoder = model.celltype_encoders[key]['celltype_encoder']
            predictions[key] = encoder.inverse_transform(predictions[key])
            probabilities[key] = np.concatenate(probabilities[key])

        # Unknown cell type detection
        unknown_mask = None
        scores = None
        if detect_unknown:
            unknown_mask, scores = detector.detect(
                model,
                loader,
                threshold = threshold,
                temperature = temperature,
                mc_passes = mc_passes,
                device = device,
                method = method,
                verbose = verbose
            )
            # Add mask
            if unknown_mask is not None:
                adata.obs[f'{prefix}unknown'] = unknown_mask
            else:
                print('No unknown cells found!')

        # Add predictions and probabilities to adata.obs
        for i, key in enumerate(model.celltype_keys, start=1):
            col_name = f"{prefix}celltype_l{i}"
            if detect_unknown:
                # Work on an object-dtype copy: writing "Unknown" into a
                # fixed-width string array would truncate it to the width
                # of the longest known label.
                labels = np.asarray(predictions[key]).astype(object)
                if unknown_mask is not None:
                    labels[unknown_mask] = "Unknown"
                adata.obs[col_name] = labels.astype(str)
            else:
                adata.obs[col_name] = predictions[key]
            if verbose:
                print(f"Added cell type column: {col_name}")
            # Add probabilities
            max_probs = probabilities[key].max(axis=1)
            prob_col_name = f"{prefix}celltype_l{i}_probability"
            adata.obs[prob_col_name] = max_probs
            if verbose:
                print(f"Added probabilities column: {prob_col_name}")

        # Add scores to adata.obs
        if detect_unknown:
            if scores['gradient'] is not None:
                adata.obs['gradient_score'] = scores['gradient']
            if scores['entropy'] is not None:
                adata.obs['entropy_score'] = scores['entropy']
            if scores['distance'] is not None:
                adata.obs['distance_score'] = scores['distance']
        return adata

    # Tabnet-based v1 model
    # Load genes of trained model
    genes_path = os.path.join(path_model, 'genes.csv').replace("\\","/")
    model_var_names = list(pd.read_csv(genes_path)['feature_name'])
    if verbose:
        print('Successfully loaded list of genes used for training model')
        print()

    # Align genes: create matrix with genes in the same order as training
    n_model_genes = len(model_var_names)
    X_aligned, matched_features, n_new_genes = _align_expression_to_model(adata, layer, model_var_names)
    if verbose:
        print(f"Gene alignment:")
        print(f" Model genes: {n_model_genes}")
        print(f" New data genes: {n_new_genes}")
        print(f" Matched genes: {matched_features} ({100*matched_features/n_model_genes:.1f}%)")
    if matched_features < 0.5 * n_model_genes:
        warnings.warn(
            f"Only {matched_features}/{n_model_genes} genes matched! "
            "This may lead to poor predictions. "
            "Make sure the gene names are in the same format (e.g., gene symbols)."
        )
    X_aligned = np.maximum(X_aligned, 0)  # Remove negative values

    # Load dictionary of trained cell types (one {label: code} mapping per level)
    with open(os.path.join(path_model, 'dict.txt')) as dict_file:
        dict_multi = json.loads(dict_file.read())
    if verbose:
        print('Successfully loaded dictionary of dataset annotations')
        print()

    # Load pretrained model
    loaded_model = TabNetMultiTaskClassifier()
    found_model_file = False
    for file in os.listdir(path_model):
        if file.endswith('.zip'):
            loaded_model.load_model(os.path.join(path_model, file).replace("\\","/"))
            found_model_file = True
    if not found_model_file:
        raise FileNotFoundError(f"No '.zip' model file found in '{path_model}'")
    if verbose:
        print('Successfully loaded model')
        print()

    # Predict cell types
    predictions = loaded_model.predict(X_aligned)
    # Get prediction probabilities
    probabilities = loaded_model.predict_proba(X_aligned)

    # Add predictions and probabilities to adata
    for i in range(len(dict_multi)):
        # Invert {label: code} once per level; setdefault keeps the first
        # label for a code, and unknown codes map to None.
        label_of = {}
        for label, code in dict_multi[i].items():
            label_of.setdefault(code, label)
        codes = predictions[i].astype(dtype=int).tolist()
        adata.obs[prefix + 'celltype_l' + f'{i+1}'] = [label_of.get(code) for code in codes]
        # Highest class probability per cell
        adata.obs['prob_celltype_l' + f'{i+1}'] = np.asarray(probabilities[i]).max(axis=1)
        if verbose:
            print(f'Successfully added predicted celltype_l{i+1} and cell type probabilities')
    return adata
def get_default_tune_params():
    """
    Return the built-in search space used for hyperparameter tuning.

    Every call builds and returns a new dictionary mapping a parameter name
    to a list describing its range:

    - Integer parameters ('nc', 'nb', 'nh', 'ed_nh_ratio', 'ff_hd',
      'classifier_hd', 'batch_size', 'patience', 'epochs', 'pretrain_epochs'):
      [minimum, maximum, step].
    - Float parameters ('dropout', 'lr', 'weight_decay', 'aug_probability',
      'prob', 'noise_std', 'dropout_aug', 'alpha'): [minimum, maximum].
    - Categorical parameters ('use_augmentation', 'from_unsupervised'):
      [True, False].

    'ed_nh_ratio' is used for the "ed" calculation: ed = nh * ed_nh_ratio.
    """
    search_space = dict(
        nc=[2, 16, 2],
        nb=[1, 8, 1],
        nh=[2, 16, 2],
        ed_nh_ratio=[8, 32, 4],
        ff_hd=[128, 1024, 128],
        classifier_hd=[128, 1024, 128],
        dropout=[0.0, 0.5],
        lr=[1e-5, 1e-2],
        weight_decay=[1e-6, 1e-2],
        batch_size=[64, 2048, 64],
        patience=[5, 30, 5],
        epochs=[50, 200, 5],
        use_augmentation=[True, False],
        aug_probability=[0.1, 1.0],
        prob=[0.05, 0.4],
        noise_std=[0.0, 0.4],
        dropout_aug=[0.0, 0.4],
        alpha=[0.0, 0.4],
        from_unsupervised=[True, False],
        pretrain_epochs=[10, 75, 5],
    )
    return search_space
# Function for hyperparameter tuning
def hyperparameter_tuning(
adata,
celltype_keys,
path = '',
layer = None,
model_name = "scAdam_model_tuning",
storage = "scadam_model_tuning.db",
study_name = "study",
load_if_exists = True,
eval_metric = ['balanced_accuracy'],
strategy = 'linear_offset',
device = "auto",
tune_params = "auto",
adaptive_loss = True,
num_trials = 100,
n_splits = 5,
epochs = None,
patience = None,
batch_size = None,
use_augmentation = None,
aug_probability = None,
prob = None,
noise_std = None,
dropout_aug = None,
alpha = None,
nc = None,
nb = None,
nh = None,
ed_nh_ratio = None,
ff_hd = None,
classifier_hd = None,
dropout = None,
lr = None,
weight_decay = None,
from_unsupervised = None,
pretrain_epochs = None,
pretrain_data = None,
random_state = 0,
verbose = True
):
"""
Hyperparameter tuning for scAdam model with k-fold cross validation using Optuna.
Parameters
----------
adata: AnnData
Dataset with cell type annotations in adata.obs
path: str, path object
Path to create a folder with best hyperparameters, dictionary of cell annotations and genes used for hyperparameters optimization.
celltype_keys: list
List of cell type annotations in adata.obs.
Example: ['lineage', 'cell type', 'cell state']
layer: str (default: None)
If specified, use adata.layers[layer] for expression values instead of adata.X.
model_name: str (default: 'scAdam_model_tuning')
Name of a folder to save tuned hyperparameters.
storage: str (default: 'scadam_model_tuning.db')
Database URL. If this argument is set to None, in-memory (RAM) storage is used, and the study will not be persistent.
We don't recommend to use in-memory (RAM) storage to save optimization progress.
study_name: str (default: 'study')
Study’s name. If this argument is set to None, a unique name is generated automatically.
load_if_exists: bool (default: True)
Flag to control the behavior to handle a conflict of study names.
In the case where a study named study_name already exists in the storage,
a DuplicatedStudyError is raised if load_if_exists is set to False.
Otherwise, the creation of the study is skipped, and the existing one is returned.
If the value is True, allows hyperparameter tuning to continue if interrupted (keyboard interrupt, or OS update).
eval_metric: str or list (default: ['balanced_accuracy'])
Available evaluation metrics:'accuracy', 'balanced_accuracy', 'f1_score'.
The last metric is used as the target and for early stopping.
num_trials: int (default: 100)
The number of trials to get optimized hyperparameters for model training.
n_splits: int (default: 5)
The number of data splits (folds) per trial. The data is divided into n_splits parts,
where each part in turn is validation data, and the rest is training data.
The number of folds determines the test_size.
If n_splits = 5, then test_size = 0.2.
If n_splits = 4, then test_size = 0.25.
adaptive_loss: bool (default: True)
If True, enables adaptive weighting of hierarchical loss levels:
each level’s loss is tracked over training and its contribution to
the total loss is increased if this level remains hard (high loss)
and decreased if it becomes easy (low loss). This helps the model
focus more on poorly performing levels of the hierarchy instead of
weighting all levels equally. If False, all levels are summed with
a fixed weight of 1.0.
strategy: str (default: 'linear_offset')
Weighting strategy for different cell type annotation levels.
The following weighting strategies are available: linear, exponential, linear_offset, equal, last.
linear: linear increase in weight from level to level.
exponential: exponential increase in weight from level to level
linear_offset: linear increase in weight from level to level with offset.
equal: equal weight for all cell type annotation levels.
last: uses only last cell type annotation for model evaluation.
device: str (default: 'auto')
Type of device to use in training model ('cpu', 'cuda'). Set 'auto' for automatic selection.
tune_params: dict or 'auto' (default: 'auto')
Dict specifying search spaces or "auto" to use built‑in defaults.
Ranges and step for scAdam model and training parameters.
Default tuning parameters are available using 'scparadise.scadam.get_default_tune_params'.
The differences between setting parameters are available in '?scparadise.scadam.get_default_tune_params'.
For a description of the parameters, see the 'scparadise.scadam.train' function.
batch_size: int or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
epochs: int or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
patience: int or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
nc: int or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
nb: int or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
nh: int or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
ed_nh_ratio: int or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
ff_hd: int or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
classifier_hd: int or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
dropout: float or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
lr: float or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
weight_decay: float or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
use_augmentation: bool or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
aug_probability: float or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
prob: float or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
noise_std: float or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
dropout_aug: float or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
alpha: float or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
from_unsupervised: bool or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
Supervised model training included in function.
pretrain_epochs: int or None (default: None)
If a value is specified, then the tuning of this parameter will not be performed.
pretrain_data: AnnData
Anndata with the same expression matrix for unsupervised pretraining.
Additional Anndata may not contain cell type annotations.
random_state: int (default: 0)
Controls the data shuffling, splitting to folds and model training.
Pass an int for reproducible output across multiple function calls.
verbose: bool (default: True)
Show progress bar for each trail during hyperparameter tuning.
"""
# Set random state (for reproducibility)
np.random.seed(random_state)
torch.manual_seed(random_state)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(random_state)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Device selection
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device=='auto' else device
# Set default best_score
if os.path.isfile(os.path.join(path, model_name, 'best_score.txt').replace("\\","/")):
with open(os.path.join(path, model_name, 'best_score.txt').replace("\\","/")) as best_score:
best_score = best_score.read()
best_score = json.loads(best_score)
else:
best_score = 0
# Create folder to save tuning results of scAdam model
os.makedirs(os.path.join(path, model_name).replace("\\","/"), exist_ok = True)
# Get evaluation metric
if isinstance(eval_metric, str):
eval_metric = [eval_metric]
if verbose:
print(f"Device: {device}")
if verbose and strategy != 'last':
level_weights = celltype_level_weights(len(celltype_keys), strategy)
# Create dataset
dataset = scRNAseqDataset(adata, celltype_keys, layer=layer)
gn = dataset.gn
ncl = [dataset.celltype_encoders[key]['n_classes'] for key in celltype_keys]
if verbose:
print(f"Number of features: {gn}")
print(f"Number of cells: {len(dataset)}")
print(f"Label hierarchy: {' → '.join(celltype_keys)}")
print(f"Annotation levels weights using strategy '{strategy}':")
for key, n_celltypes, weight in zip(celltype_keys, ncl, level_weights):
print(f" {key}: {n_celltypes} cell types, {round(weight, 3)} relative weight")
print(f"Start model optimization using optuna...")
# Train/val split
indices = np.arange(len(dataset))
# Default tuning hyperparameters
default_tune_params = {
"nc": [2, 16, 2],
"nb": [1, 8, 1],
"nh": [2, 16, 2],
"ed_nh_ratio": [8, 32, 4],
"ff_hd": [128, 1024, 128],
"classifier_hd": [128, 1024, 128],
"dropout": [0.0, 0.5],
"lr": [1e-5, 1e-2],
"weight_decay": [1e-6, 1e-2],
"batch_size": [64, 2048, 64],
"patience": [5, 30, 5],
"epochs": [50, 200, 5],
"use_augmentation": [True, False],
"aug_probability": [0.1, 1.0],
"prob": [0.05, 0.4],
"noise_std": [0.0, 0.4],
"dropout_aug": [0.0, 0.4],
"alpha": [0.0, 0.4],
'from_unsupervised': [True, False],
'pretrain_epochs': [10, 75, 5],
}
if tune_params == "auto":
tune_params = default_tune_params
# Function for training a single fold
def train_fold(
    params,
    train_idx,
    val_idx,
    epochs,
    device,
    fold_id
):
    """
    Train and validate a scAdam model on a single cross-validation fold.

    Parameters
    ----------
    params: dict
        Hyperparameters of the current trial (as returned by 'suggest_params').
    train_idx: array-like
        Indices of 'dataset' used for training in this fold.
    val_idx: array-like
        Indices of 'dataset' used for validation in this fold.
    epochs: int or None
        Not used: the number of epochs is read from params["epochs"].
        Kept so that existing calls stay valid.
    device: str or torch.device
        Device used for training.
    fold_id: int
        Fold number (used only in the progress message).

    Returns
    -------
    float
        Best validation score reached in this fold. The score is computed from
        the last metric in 'eval_metric' (last annotation level only if
        strategy == 'last', otherwise combined by 'weighted_metric').
    """
    # Metric functions (validated first, so a wrong metric name fails before any training)
    available_metric_functions = {
        'accuracy': accuracy_score,
        'balanced_accuracy': balanced_accuracy_score,
        'f1_score': lambda y_true, y_pred: f1_score(y_true, y_pred, average='weighted', zero_division=0)
    }
    metric_functions = {}
    for metric in eval_metric:
        if metric not in available_metric_functions:
            raise ValueError(f"Unknown metric: {metric}")
        metric_functions[metric] = available_metric_functions[metric]
    # Create dataloaders
    train_loader = DataLoader(
        Subset(dataset, train_idx),
        batch_size=params["batch_size"],
        shuffle=True,
        num_workers=0,
        drop_last=False,
    )
    val_loader = DataLoader(
        Subset(dataset, val_idx),
        batch_size=params["batch_size"],
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )
    # Initialize model
    model = scAdamTransformer(
        gn=gn,
        ncl=ncl,
        ed=params["ed_nh_ratio"] * params["nh"],
        nc=params["nc"],
        nb=params["nb"],
        nh=params["nh"],
        ff_hd=params["ff_hd"],
        classifier_hd=params["classifier_hd"],
        dropout=params["dropout"],
    ).to(device)
    # Optional unsupervised pretraining (on 'pretrain_data' if given, otherwise on 'adata')
    if params['from_unsupervised']:
        pretrain_adata = pretrain_data if pretrain_data is not None else adata
        model = pretrain_unsupervised(
            pretrain_adata,
            model,
            layer=layer,
            batch_size=params["batch_size"],
            epochs=params["pretrain_epochs"],
            lr=params["lr"],
            prob=params["prob"],
            device=device,
            random_state=random_state,
            verbose = False
        )
    # Augmentation (created only if the trial enables it)
    augmenter = None
    if params["use_augmentation"]:
        augmenter = dust.Augmentation(
            prob=params["prob"],
            noise_std=params["noise_std"],
            dropout_prob=params["dropout_aug"],
            alpha=params["alpha"],
        )
    # Loss function
    criterion = HierarchicalLoss(
        num_levels=len(celltype_keys),
        celltype_keys=celltype_keys,
        alpha=0.25,
        gamma=2.0,
        adaptive=adaptive_loss
    ).to(device)
    # Optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=params["lr"],
        betas=(0.9, 0.95),
        weight_decay=params["weight_decay"],
    )
    # Learning rate scheduler (cosine decay over the trial's number of epochs)
    def lr_lambda(epoch):
        return 0.5 * (1.0 + np.cos(np.pi * epoch / max(params["epochs"], 1)))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    # Early stopping
    early_stopping = dust.EarlyStopping(
        patience=params["patience"],
        delta=1e-4,
        mode="max",
        verbose=False
    )
    # Create training history dictionary
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_metrics': {key: {metric: [] for metric in eval_metric}
                          for key in celltype_keys},
        'val_metrics': {key: {metric: [] for metric in eval_metric}
                        for key in celltype_keys},
        'lr': []
    }
    best_val = None
    # Training loop
    for epoch in range(params["epochs"]):
        model.train()
        train_loss = 0.0
        train_preds = {key: [] for key in celltype_keys}
        train_targets = {key: [] for key in celltype_keys}
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = {key: val.to(device) for key, val in batch_y.items()}
            # Augmentation block: use the values of the current trial, not the
            # (possibly None) arguments of the enclosing function
            if augmenter is not None and np.random.random() < params["aug_probability"]:
                batch_x = augmenter(batch_x)
            # Get predictions using model
            predictions = model(batch_x)
            # Calculate model loss
            loss, _ = criterion(predictions, batch_y)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
            # Predictions and targets
            with torch.no_grad():
                for level_idx, key in enumerate(celltype_keys):
                    level_key = f'level_{level_idx}'
                    preds = predictions[level_key]['logits'].argmax(dim=1)
                    train_preds[key].append(preds.cpu())
                    train_targets[key].append(batch_y[key].cpu())
        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)
        # Compute training metrics
        for key in celltype_keys:
            train_preds[key] = torch.cat(train_preds[key]).numpy()
            train_targets[key] = torch.cat(train_targets[key]).numpy()
            for metric_name, metric_func in metric_functions.items():
                score = metric_func(train_targets[key], train_preds[key])
                history['train_metrics'][key][metric_name].append(score)
        # Validation
        model.eval()
        val_loss = 0.0
        val_preds = {key: [] for key in celltype_keys}
        val_targets = {key: [] for key in celltype_keys}
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = {key: val.to(device) for key, val in batch_y.items()}
                # Forward pass
                predictions = model(batch_x)
                # Compute loss
                loss, _ = criterion(predictions, batch_y)
                val_loss += loss.item()
                # Store predictions and targets
                for level_idx, key in enumerate(celltype_keys):
                    level_key = f'level_{level_idx}'
                    preds = predictions[level_key]['logits'].argmax(dim=1)
                    val_preds[key].append(preds.cpu())
                    val_targets[key].append(batch_y[key].cpu())
        val_loss /= len(val_loader)
        history['val_loss'].append(val_loss)
        # Validation metrics
        for key in celltype_keys:
            val_preds[key] = torch.cat(val_preds[key]).numpy()
            val_targets[key] = torch.cat(val_targets[key]).numpy()
            for metric_name, metric_func in metric_functions.items():
                score = metric_func(val_targets[key], val_preds[key])
                history['val_metrics'][key][metric_name].append(score)
        # Learning rate
        current_lr = optimizer.param_groups[0]['lr']
        history['lr'].append(current_lr)
        # Score of the epoch: last metric of 'eval_metric'
        if strategy == 'last':
            # 'eval_metric' is a list, the history is keyed by metric name
            fold_score = history['val_metrics'][celltype_keys[-1]][eval_metric[-1]][-1]
        else:
            fold_score = weighted_metric(
                history=history,
                celltype_keys=celltype_keys,
                metric=eval_metric[-1],
                strategy=strategy
            )
        if best_val is None or fold_score > best_val:
            best_val = fold_score
        # early stop on chosen metric
        if early_stopping(fold_score, model):
            break
        scheduler.step()
    # Restore best weights after early stopping
    early_stopping.load_bm(model)
    if verbose:
        print(f"Fold {fold_id} finished with {eval_metric[-1]} value = {best_val:.6f}", flush=True)
    return float(best_val)
# Function for processing integer hyperparameters
def int_param(name, fixed_value, trial):
    """
    Suggest an integer hyperparameter.

    A user-specified 'fixed_value' is registered as a degenerate range
    (low == high). Otherwise the [low, high, step] entry of 'tune_params' is
    used. Returns None when the parameter is neither fixed nor tunable.
    """
    if fixed_value is None:
        if name not in tune_params:
            return None
        bounds = tune_params[name]
        return trial.suggest_int(name, bounds[0], bounds[1], step=bounds[2])
    return trial.suggest_int(name, fixed_value, fixed_value)
# Function for processing float hyperparameters
def float_param(name, fixed_value, trial, log=False):
    """
    Suggest a float hyperparameter (optionally sampled in the log domain).

    A user-specified 'fixed_value' is registered as a degenerate range
    (low == high). Otherwise the [low, high] entry of 'tune_params' is used.
    Returns None when the parameter is neither fixed nor tunable.
    """
    if fixed_value is None:
        if name not in tune_params:
            return None
        bounds = tune_params[name]
        return trial.suggest_float(name, bounds[0], bounds[1], log=log)
    return trial.suggest_float(name, fixed_value, fixed_value, log=log)
# Function for processing categorical hyperparameters
def categorical_param(name, fixed_value, trial):
    """
    Suggest a categorical hyperparameter.

    A user-specified 'fixed_value' is registered as the only possible choice
    (listed twice). Otherwise the choices are taken from 'tune_params'.
    Returns None when the parameter is neither fixed nor tunable.
    """
    if fixed_value is None:
        if name not in tune_params:
            return None
        return trial.suggest_categorical(name, tune_params[name])
    return trial.suggest_categorical(name, [fixed_value, fixed_value])
# Function to get hyperparameters in a trial
def suggest_params(trial):
    """
    Collect all hyperparameters of a trial.

    Every value is either the fixed, user-specified value or a value suggested
    from 'tune_params'. Entries are evaluated top to bottom, so the order of
    the suggestions is the order of the dictionary below.
    """
    return {
        # Model initialization parameters
        "nc": int_param("nc", nc, trial),
        "nb": int_param("nb", nb, trial),
        "nh": int_param("nh", nh, trial),
        "ed_nh_ratio": int_param("ed_nh_ratio", ed_nh_ratio, trial),
        "ff_hd": int_param("ff_hd", ff_hd, trial),
        "classifier_hd": int_param("classifier_hd", classifier_hd, trial),
        "dropout": float_param("dropout", dropout, trial, log=False),
        # Optimizer parameters
        "lr": float_param("lr", lr, trial, log=True),
        "weight_decay": float_param("weight_decay", weight_decay, trial, log=True),
        # Training
        "batch_size": int_param("batch_size", batch_size, trial),
        "patience": int_param("patience", patience, trial),
        "epochs": int_param("epochs", epochs, trial),
        # Augmentation parameters
        "use_augmentation": categorical_param("use_augmentation", use_augmentation, trial),
        "aug_probability": float_param("aug_probability", aug_probability, trial, log=False),
        "prob": float_param("prob", prob, trial, log=False),
        "noise_std": float_param("noise_std", noise_std, trial, log=False),
        "dropout_aug": float_param("dropout_aug", dropout_aug, trial, log=False),
        "alpha": float_param("alpha", alpha, trial, log=False),
        # Pretraining
        "from_unsupervised": categorical_param("from_unsupervised", from_unsupervised, trial),
        "pretrain_epochs": int_param("pretrain_epochs", pretrain_epochs, trial),
    }
# Function defining the optuna objective (evaluates the hyperparameters of one trial)
def objective(trial, best_score = best_score):
    """
    Optuna objective: mean K-fold cross-validation score of one trial.

    Parameters
    ----------
    trial: optuna.trial.Trial
        Current trial.
    best_score: float or dict
        Best score known before the optimization started (content of
        'best_score.txt', either a plain number or a dictionary with the
        key 'best_value'; 0 if the file does not exist).

    Returns
    -------
    float
        Average score between folds (the study maximizes this value).

    Raises
    ------
    optuna.TrialPruned
        If the pruner decides to stop the trial after a fold.
    """
    # Get trial params
    params = suggest_params(trial)
    # N splits of data for K fold cross validation
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    split_iter = skf.split(indices, adata.obs[celltype_keys[-1]]) # Using last (most detailed) annotation level
    fold_scores = []
    for fold_id, (train_idx, val_idx) in enumerate(split_iter, start=1):
        score = train_fold(params, train_idx, val_idx, epochs, device, fold_id)
        fold_scores.append(score)
        # pruning hook
        trial.report(score, step=fold_id)
        if trial.should_prune():
            raise optuna.TrialPruned()
    # Get average score between folds
    score = float(np.mean(fold_scores))
    # Best score known so far: value loaded from disk (number or dictionary) ...
    if isinstance(best_score, dict):
        previous_best = float(best_score.get("best_value", 0))
    else:
        previous_best = float(best_score)
    # ... or the best completed trial of this study, whichever is higher
    try:
        previous_best = max(previous_best, float(trial.study.best_value))
    except ValueError:
        # No trial has been completed yet
        pass
    # The study maximizes the score, so only a higher score is an improvement
    if score > previous_best:
        # Write best_params and best_score to model folder (keeps progress if the run is interrupted)
        with open(os.path.join(path, model_name, 'best_params.txt').replace("\\","/"), 'w') as f:
            f.write(json.dumps(params))
        with open(os.path.join(path, model_name, 'best_score.txt').replace("\\","/"), 'w') as f:
            f.write(json.dumps({"best_value": score, "eval_metric": eval_metric[-1]}))
    return score
# Study storage
storage_url = None
if storage is not None:
# optuna expects URL-like string or sqlite
if storage.startswith("sqlite:///") or storage.startswith("postgresql://") or storage.startswith("mysql://"):
storage_url = storage
else:
storage_url = "sqlite:///" + os.path.join(path, model_name, storage).replace("\\","/")
# Create optuna study
study = optuna.create_study(
direction = "maximize",
study_name = study_name,
storage = storage_url,
load_if_exists = load_if_exists,
pruner = optuna.pruners.HyperbandPruner()
)
# Set default parameters
params_default = {
"nc": 16, "nb": 5, "nh": 8, "ed_nh_ratio": 32, "ff_hd": 512, "classifier_hd": 256,
"dropout": 0.3, "lr": 1e-4, "weight_decay": 1e-4, "batch_size": 128, "patience": 10,
"epochs": 200, "use_augmentation": True, "aug_probability": 0.5, "prob": 0.15,
"noise_std": 0.1, "dropout_aug": 0.1, "alpha": 0.2, "from_unsupervised": True,
"pretrain_epochs": 50
}
# Enqueue a trial which uses the default parameters
if not study.trials:
study.enqueue_trial(params_default)
# Restart optimization
trials = study.get_trials(deepcopy=False)
if len(trials) > 0:
last = trials[-1]
if last.state in (optuna.trial.TrialState.FAIL, optuna.trial.TrialState.RUNNING):
if len(last.params) > 0:
if verbose:
print(f"Re-enqueue last interrupted trial: Trial number {last.number} (will run as new trial).", flush=True)
study.enqueue_trial(last.params)
# Study optimization
study.optimize(objective, n_trials=num_trials, n_jobs=1)
best_params = dict(study.best_params)
best_score = float(study.best_value)
# Save best parameters and score
with open(os.path.join(path, model_name, 'best_params.txt').replace("\\","/"), "w") as f:
f.write(json.dumps(best_params))
with open(os.path.join(path, model_name, 'best_score.txt').replace("\\","/"), "w") as f:
f.write(json.dumps({"best_value": best_score, "eval_metric": eval_metric[-1]}))
if verbose:
print(f"Best value ({eval_metric[-1]}) = {best_score}")
print(f"Best hyperparameters saved to: {os.path.join(path, model_name, 'best_params.txt')}")
return best_params
# Function for training model using parameters tuned by scparadise.scadam.hyperparameter_tuning
def train_tuned(
    adata,
    celltype_keys,
    layer = None,
    path = '',
    path_tuned = '',
    model_name = 'scAdam_model_tuned',
    test_size = 0.2,
    eval_metric = ['accuracy', 'balanced_accuracy'],
    strategy = 'linear_offset',
    batch_size = None,
    epochs = None,
    patience = None,
    nc = None,
    nb = None,
    nh = None,
    ed_nh_ratio = None,
    ff_hd = None,
    classifier_hd = None,
    dropout = None,
    lr = None,
    weight_decay = None,
    adaptive_loss = True,
    use_augmentation = None,
    aug_probability = None,
    prob = None,
    noise_std = None,
    dropout_aug = None,
    alpha = None,
    from_unsupervised = None,
    pretrain_epochs = None,
    pretrain_data = None,
    device = 'auto',
    random_state = 0,
    return_model = False,
    unknown_detection = True,
    verbose = True
):
    """
    Train custom scAdam model with tuned hyperparameters.
    The function automatically uses the hyperparameters stored in 'best_params.txt'.
    However, you can change any hyperparameter by passing it via the corresponding parameter.
    Value priority: explicitly given value > tuned value > built-in default.

    Parameters
    ----------
    adata : AnnData
        Dataset with cell type annotations in adata.obs
    celltype_keys: list
        List of cell type annotations in adata.obs.
        Example: ['lineage', 'cell type', 'cell state']
    layer: str (default: None)
        If specified, use adata.layers[layer] for expression values instead of adata.X.
    path: str, path object
        Path to create a model folder containing the training history, cell annotation dictionary, and genes used for training.
    path_tuned: str, path object
        Path to folder with tuned parameters ('best_params.txt') created by 'scparadise.scadam.hyperparameter_tuning' function.
    model_name: str (default: 'scAdam_model_tuned')
        Name of a folder to save model.
    test_size: float or int (default: 0.2)
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split. If int, represents the
        absolute number of test cells.
    eval_metric: str or list (default: ['accuracy', 'balanced_accuracy'])
        Available evaluation metrics:'accuracy', 'balanced_accuracy', 'f1_score'.
        The last metric is used as the target and for early stopping.
    strategy: str (default: 'linear_offset')
        Strategy for weighting annotation levels. Passed unchanged to 'train'.
    batch_size: int, (default: None)
        Number of examples per batch.
        If specified, the specified value is used.
    epochs: int (default: None)
        Maximum number of epochs for scAdam model training.
        If specified, the specified value is used.
    patience: int (default: None)
        Number of consecutive epochs without improvement before performing early stopping.
        If patience is set to 0, then no early stopping will be performed.
        Note that if patience is enabled, then best weights from best epoch will automatically be loaded at the end of the training.
        If specified, the specified value is used.
    nc: int (default: None)
        Number of chunks for genes from adata.
        If specified, the specified value is used.
    nb: int (default None)
        Number of blocks in scAdam model.
        If specified, the specified value is used.
    nh: int (default: None)
        Number of heads in scAdam model attention mechanism.
        If specified, the specified value is used.
    ed_nh_ratio: int (default: None)
        Used for calculating embedding dimensionality ('ed') from 'nh'.
        If specified, the specified value is used.
    ff_hd: int (default: None)
        Number of nodes in each scAdam model layer in feed forward network.
        If specified, the specified value is used.
    classifier_hd: int (default: None)
        Number of nodes in scAdam classifier.
        If specified, the specified value is used.
    dropout: float (default: None)
        Portion of neurons that temporarily ignored during training (prevents overfitting).
        If specified, the specified value is used.
    lr: float (default: None)
        Determines the step size at each iteration while moving toward a minimum of a loss function.
        If specified, the specified value is used.
    weight_decay: float (default: None)
        Weight decay coefficient.
        If specified, the specified value is used.
    adaptive_loss: bool (default: True)
        NOTE(review): accepted for interface compatibility but currently not forwarded to 'train'.
    use_augmentation: bool (default: None)
        Use data augmentation or not.
        If specified, the specified value is used.
    aug_probability: float (default: None)
        The probability of applying augmentation to a batch.
        If specified, the specified value is used.
    prob: float (default: None)
        Gene masking probability.
        If specified, the specified value is used.
    noise_std: float (default: None)
        Gaussian noise standard deviation.
        If specified, the specified value is used.
    dropout_aug: float (default: None)
        Dropout probability for simulating technical noise.
        If specified, the specified value is used.
    alpha: float (default: None)
        Alpha parameter for mixup augmentation.
        If specified, the specified value is used.
    from_unsupervised: bool (default: None)
        Use a previously self supervised model as starting weights.
        Supervised model training included in function.
    pretrain_epochs: int (default: None)
        Number of pretraining epochs.
    pretrain_data: AnnData (default: None)
        Anndata with the same expression matrix for unsupervised pretraining.
        Additional Anndata may not contain cell type annotations.
    device: str (default: 'auto')
        Type of device to use in training model ('cpu', 'cuda'). Set 'auto' for automatic selection.
    random_state: int (default: 0)
        Controls the data shuffling, splitting to folds and model training.
        Pass an int for reproducible output across multiple function calls.
    return_model: bool (default: False)
        Return model after training or not.
    unknown_detection: bool (default: True)
        Passed unchanged to 'train' (presumably enables training of the unknown cell detector).
    verbose: bool (default: True)
        Show progress bar for each epoch during training.

    Returns
    -------
    Trained model if return_model is True, otherwise None.
    """
    # Create new directory with model and list of genes
    os.makedirs(os.path.join(path, model_name).replace("\\","/"), exist_ok = True)
    # Load tuned parameters for scAdam model training
    with open(os.path.join(path_tuned, 'best_params.txt')) as f:
        params_tuned = json.load(f)
    if verbose:
        print('Successfully loaded tuned hyperparameters!')
    # Dictionary of given parameters for a function
    params_given = {
        'epochs': epochs,
        'batch_size': batch_size,
        'patience': patience,
        'use_augmentation': use_augmentation,
        'aug_probability': aug_probability,
        'prob': prob,
        'noise_std': noise_std,
        'dropout_aug': dropout_aug,
        'alpha': alpha,
        'nc': nc,
        'nb': nb,
        'nh': nh,
        'ed_nh_ratio': ed_nh_ratio,
        'ff_hd': ff_hd,
        'classifier_hd': classifier_hd,
        'dropout': dropout,
        'lr': lr,
        'weight_decay': weight_decay,
        'from_unsupervised': from_unsupervised,
        'pretrain_epochs': pretrain_epochs
    }
    # Dictionary of default parameters
    params_default = {
        'epochs': 200,
        'batch_size': 128,
        'patience': 10,
        'use_augmentation': True,
        'aug_probability': 0.5,
        'prob': 0.15,
        'noise_std': 0.1,
        'dropout_aug': 0.1,
        'alpha': 0.2,
        'nc': 4,
        'nb': 4,
        'nh': 8,
        'ed_nh_ratio': 32,
        'ff_hd': 512,
        'classifier_hd': 512,
        'dropout': 0.3,
        'lr': 1e-4,
        'weight_decay': 1e-4,
        'from_unsupervised': True,
        'pretrain_epochs': 50
    }
    # Start from defaults so that parameters missing (or None) in the tuned file
    # still get a value, then apply tuned values, then explicitly given values
    params = dict(params_default)
    params.update({name: value for name, value in params_tuned.items() if value is not None})
    params.update({name: value for name, value in params_given.items() if value is not None})
    # Train model with loaded parameters (corrected if given)
    # The model is always requested from 'train'; 'return_model' only controls what this function returns
    model = train(
        adata = adata,
        celltype_keys = celltype_keys,
        layer = layer,
        path = path,
        model_name = model_name,
        test_size = test_size,
        strategy = strategy,
        epochs = params['epochs'],
        eval_metric = eval_metric,
        batch_size = params['batch_size'],
        patience = params['patience'],
        use_augmentation = params['use_augmentation'],
        aug_probability = params['aug_probability'],
        prob = params['prob'],
        noise_std = params['noise_std'],
        dropout_aug = params['dropout_aug'],
        alpha = params['alpha'],
        nc = params['nc'],
        nb = params['nb'],
        nh = params['nh'],
        ed_nh_ratio = params['ed_nh_ratio'],
        ff_hd = params['ff_hd'],
        classifier_hd = params['classifier_hd'],
        dropout = params['dropout'],
        lr = params['lr'],
        weight_decay = params['weight_decay'],
        from_unsupervised = params['from_unsupervised'],
        pretrain_epochs = params['pretrain_epochs'],
        pretrain_data = pretrain_data,
        return_model = True,
        unknown_detection = unknown_detection,
        device = device,
        random_state = random_state,
        verbose = verbose
    )
    if return_model:
        return model
# Function for creation of dataset for warm-start training
class scRNAseqDataset_warm_start(Dataset):
    """
    Warm-start dataset:
    - Aligns new adata genes to the model genes (order preserved, genes missing in adata -> zeros)
    - Encodes labels using pretrained encoders (no refit), so classifier output dims stay valid

    Parameters
    ----------
    adata: AnnData
        New dataset with cell type annotations in adata.obs
    celltype_keys: list
        List of cell type annotations in adata.obs.
    model_var_names: list
        Feature (gene) names of the pretrained model, in model order.
    celltype_encoders: dict
        For each annotation level a dictionary whose 'celltype_encoder' entry
        is the fitted label encoder of the pretrained model.
    layer: str (default: None)
        If specified and present in adata.layers, use adata.layers[layer]
        for expression values, otherwise adata.X is used.
    allow_unseen_labels: bool (default: False)
        If False, labels unknown to the pretrained encoder raise a ValueError.
        If True, they are replaced by the first class of the encoder.

    Internal abbreviations:
    gn: number of model genes
    matched_genes: number of model genes found in adata.var_names
    match_fraction: matched_genes / gn
    X: min-max normalized expression tensor (cells x model genes)
    labels: encoded labels for each annotation level
    """
    def __init__(
        self,
        adata,
        celltype_keys,
        model_var_names,
        celltype_encoders,
        layer=None,
        allow_unseen_labels=False,
    ):
        self.adata = adata
        self.celltype_keys = list(celltype_keys)
        self.layer = layer
        self.obs_names = adata.obs_names.tolist()
        self.var_names = list(model_var_names)
        self.gn = len(self.var_names)
        # Expression source: requested layer if it exists, otherwise adata.X
        if layer is not None and layer in adata.layers:
            X = adata.layers[layer]
        else:
            X = adata.X
        # Map every model gene (destination column) to its column in adata (source column)
        idx_map = {g: i for i, g in enumerate(adata.var_names.tolist())}
        dst_cols, src_cols = [], []
        for j, g in enumerate(self.var_names):
            i = idx_map.get(g)
            if i is not None:
                dst_cols.append(j)
                src_cols.append(i)
        # Genes missing in adata stay zero
        X_aligned = np.zeros((adata.n_obs, self.gn), dtype=np.float32)
        if src_cols:
            X_aligned[:, dst_cols] = self._dense_columns(X, src_cols)
        self.matched_genes = len(src_cols)
        self.match_fraction = self.matched_genes / max(self.gn, 1)
        # min-max per gene as in scRNAseqDataset (negative values are clipped to zero first)
        if X_aligned.min(0).any():
            X_aligned = np.maximum(X_aligned, 0)
        X_norm = (X_aligned - X_aligned.min(0)) / (np.ptp(X_aligned, axis=0) + 1e-10)
        self.X = torch.FloatTensor(X_norm)
        assert not torch.isnan(self.X).any(), "NaN in adata data!"
        assert not torch.isinf(self.X).any(), "Inf in adata data!"
        # Labels encoding using pretrained models encoders
        self.labels = {}
        self.celltype_encoders = celltype_encoders
        for key in self.celltype_keys:
            if key not in adata.obs.columns:
                raise ValueError(f"'{key}' not found in adata.obs")
            enc = celltype_encoders[key]["celltype_encoder"]
            y_str = adata.obs[key].astype(str).values
            known = set(enc.classes_.tolist())
            unseen = sorted(set(y_str.tolist()) - known)
            if len(unseen) > 0 and not allow_unseen_labels:
                raise ValueError(
                    f"Unseen labels in '{key}': {unseen[:10]}{'...' if len(unseen) > 10 else ''}. "
                    f"Warm-start requires the same label space as the pretrained model."
                )
            if len(unseen) > 0 and allow_unseen_labels:
                # Unseen labels are mapped to the first known class to keep the label space unchanged
                fallback = enc.classes_[0]
                y_str = np.array([v if v in known else fallback for v in y_str], dtype=object)
            y_enc = enc.transform(y_str)
            self.labels[key] = torch.LongTensor(y_enc)

    @staticmethod
    def _dense_columns(X, cols):
        """
        Return X[:, cols] as a dense 2D array.
        Only the selected columns are densified when X supports column
        indexing, which avoids converting the whole (sparse) matrix.
        """
        try:
            sub = X[:, cols]
        except (TypeError, NotImplementedError):
            # Containers without column indexing (e.g. COO sparse matrices): densify first
            dense = X.toarray() if hasattr(X, 'toarray') else np.asarray(X)
            sub = dense[:, cols]
        return sub.toarray() if hasattr(sub, 'toarray') else np.asarray(sub)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        x = self.X[idx]
        y = {key: self.labels[key][idx] for key in self.celltype_keys}
        return x, y
# Function for warm-start scAdam model fine-tuning
def warm_start(
    adata,
    path_model,
    celltype_keys=None,
    layer=None,
    path='',
    model_name="scAdam_model_warm_start",
    test_size=0.2,
    eval_metric=["accuracy", "balanced_accuracy"],
    strategy="linear_offset",
    batch_size=128,
    epochs=100,
    patience=10,
    lr=5e-5,
    weight_decay=1e-4,
    use_augmentation=True,
    aug_probability=0.5,
    prob=0.15,
    noise_std=0.1,
    dropout_aug=0.1,
    alpha=0.2,
    adaptive_loss=True,
    freeze_transformer=False,
    allow_unseen_labels=False,
    unknown_detection=True,
    device="auto",
    random_state=0,
    return_model=False,
    verbose=True
):
    """
    Warm-start fine-tuning of an existing scAdam model on new data.
    Warm-start training is a technique in machine learning that involves initializing a model with parameters or states learned from a previously trained model.

    Parameters
    ----------
    adata: AnnData
        New dataset with cell type annotations in adata.obs
    path_model: str, path object
        Path to a model folder containing pretrained scAdam model.
    celltype_keys: list (default: None)
        List of cell type annotations in adata.obs.
        Example: ['lineage', 'cell type', 'cell state']
        If None, the celltype_keys stored in the pretrained model are used.
    layer: str (default: None)
        If specified, use adata.layers[layer] for expression values instead of adata.X.
    path: str, path object
        Path to create a model folder containing the training history, cell annotation dictionary, and genes used for scAdam model warm start training.
    model_name: str (default: 'scAdam_model_warm_start')
        Name of a folder to save model.
    test_size: float or int (default: 0.2)
        If float, should be between 0.0 and 1.0 and represent the proportion
        of the dataset to include in the test split. If int, represents the
        absolute number of test cells.
    eval_metric: str or list (default: ['accuracy', 'balanced_accuracy'])
        Available evaluation metrics:'accuracy', 'balanced_accuracy', 'f1_score'.
        The last metric is used as the target and for early stopping.
    strategy: str (default: 'linear_offset')
        Strategy for weighting annotation levels in the early stopping score.
        If 'last', only the last (most detailed) annotation level is used.
    batch_size: int, (default: 128)
        Number of examples per batch.
    epochs: int (default: 100)
        Maximum number of epochs for scAdam model training
    patience: int (default: 10)
        Number of consecutive epochs without improvement before performing early stopping.
        If patience is set to 0, then no early stopping will be performed.
        Note that if patience is enabled, then best weights from best epoch will automatically be loaded at the end of the training.
    lr: float (default: 5e-5)
        Determines the step size at each iteration while moving toward a minimum of a loss function.
    weight_decay: float (default: 1e-4)
        Weight decay coefficient.
    use_augmentation: bool (default: True)
        Use data augmentation or not.
    aug_probability: float (default: 0.5)
        The probability of applying augmentation to a batch.
    prob: float (default: 0.15)
        Gene masking probability.
    noise_std: float (default: 0.1)
        Gaussian noise standard deviation.
    dropout_aug: float (default: 0.1)
        Dropout probability for simulating technical noise.
    alpha: float (default: 0.2)
        Alpha parameter for mixup augmentation.
    adaptive_loss: bool (default: True)
        Passed to HierarchicalLoss as 'adaptive'.
    freeze_transformer: bool (default = False)
        If True, freezes the transformer backbone (gene embedding + transformer blocks) and trains only the classifier head.
        If False, fine-tunes the full model.
    allow_unseen_labels: bool (default = False)
        If False, raise an error if new data contains labels not present in the pretrained `LabelEncoder` for any level.
        If True, unseen labels are mapped to a fallback known label to keep the label space unchanged.
    unknown_detection: bool (default: True)
        Train unknown cell detector - identifies unknown cells when model used for prediction on a new data.
    device: str (default: 'auto')
        Type of device to use in training model ('cpu', 'cuda'). Set 'auto' for automatic selection.
    random_state: int (default: 0)
        Controls the data shuffling, splitting and model training.
        Pass an int for reproducible output across multiple function calls.
    return_model: bool (default: False)
        Return model after training or not.
    verbose: bool (default: True)
        Show progress bar for each epoch during training.

    Returns
    -------
    Saves fine-tuned scAdam model for cell type annotation.
    Returns the model if return_model is True.
    """
    # Set random state (for reproducibility)
    np.random.seed(random_state)
    torch.manual_seed(random_state)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(random_state)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # Device selection
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device=='auto' else device
    # A single metric may be given as a string
    if isinstance(eval_metric, str):
        eval_metric = [eval_metric]
    # load model
    model = load_model(path_model, device=device, verbose=verbose)
    model = model.to(device)
    if celltype_keys is None:
        if not hasattr(model, "celltype_keys") or model.celltype_keys is None:
            raise ValueError("celltype_keys is None and model has no celltype_keys metadata.")
        celltype_keys = list(model.celltype_keys)
    if verbose:
        print(f"Device: {device}")
    # Level weights are only reported; with strategy 'last' there is nothing to report
    level_weights = None
    if verbose and strategy != 'last':
        level_weights = celltype_level_weights(len(celltype_keys), strategy)
    # Align genes and create dataset
    dataset = scRNAseqDataset_warm_start(
        adata,
        celltype_keys=celltype_keys,
        model_var_names=model.var_names,
        celltype_encoders=model.celltype_encoders,
        layer=layer,
        allow_unseen_labels=allow_unseen_labels,
    )
    ncl = [dataset.celltype_encoders[key]['n_classes'] for key in celltype_keys]
    if verbose:
        print("Gene alignment:")
        print(f" Number of features: {dataset.gn}")
        print(f" Matched features: {dataset.matched_genes} ({100*dataset.match_fraction:.1f}%)")
    if dataset.matched_genes < 0.80 * dataset.gn:
        # NOTE: the module sets warnings.filterwarnings("ignore") at import time,
        # so this warning is only shown if the user re-enables warnings
        warnings.warn(
            f"Only {100*dataset.match_fraction:.1f}% of genes matched. Fine-tuning may be unstable; "
            "Ensure consistent feature naming and data preprocessing (filtering features, reference genome, etc.)"
        )
    if verbose:
        print(f"Label hierarchy: {' → '.join(celltype_keys)}")
        if level_weights is not None:
            print(f"Annotation levels weights using strategy '{strategy}':")
            for key, n_celltypes, weight in zip(celltype_keys, ncl, level_weights):
                print(f" {key}: {n_celltypes} cell types, {round(weight, 3)} relative weight")
        else:
            print(f"Strategy 'last': early stopping uses only the last annotation level '{celltype_keys[-1]}'")
    # Sanity check: output dims must match checkpoint ncl
    ncl_expected = list(model.ncl)
    ncl_now = [len(model.celltype_encoders[k]["celltype_encoder"].classes_) for k in celltype_keys]
    if ncl_now != ncl_expected:
        raise ValueError(
            f"Mismatch cell type numbers: model cell types number = {ncl_expected}, given cell types number = {ncl_now}. "
            "This should not happen unless metadata is inconsistent."
        )
    # Train/val split
    indices = np.arange(len(dataset))
    train_idx, val_idx = train_test_split(
        indices,
        test_size=test_size,
        stratify=adata.obs[celltype_keys[-1]], # Using last (most detailed) annotation level
        random_state=random_state
    )
    if verbose:
        print(f"\nDataset split:")
        print(f'Train dataset contains: {len(train_idx)} cells, it is {round(100*(len(train_idx)/(len(train_idx) + len(val_idx))), ndigits=2)} % of input dataset')
        print(f'Validation dataset contains: {len(val_idx)} cells, it is {round(100*(len(val_idx)/(len(train_idx) + len(val_idx))), ndigits=2)} % of input dataset')
    # Create dataloaders
    train_loader = DataLoader(
        Subset(dataset, train_idx),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=False
    )
    val_loader = DataLoader(
        Subset(dataset, val_idx),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        drop_last=False
    )
    # freeze backbone (optional)
    if freeze_transformer:
        for p in model.gene_embedding.parameters():
            p.requires_grad = False
        for blk in model.blocks:
            for p in blk.parameters():
                p.requires_grad = False
        # keep classifier trainable
        for p in model.classifier.parameters():
            p.requires_grad = True
    # Augmentation
    augmenter = None
    if use_augmentation:
        augmenter = dust.Augmentation(
            prob = prob,
            noise_std = noise_std,
            dropout_prob = dropout_aug,
            alpha = alpha
        )
    # loss function
    criterion = HierarchicalLoss(
        num_levels=len(celltype_keys),
        celltype_keys=celltype_keys,
        alpha=0.25,
        gamma=2.0,
        adaptive=adaptive_loss
    ).to(device)
    # Optimizer over trainable params only
    params_to_train = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        params_to_train,
        lr=lr,
        betas=(0.9, 0.95),
        weight_decay=weight_decay
    )
    # Learning rate scheduler (cosine decay; max() avoids division by zero for epochs=0)
    def lr_lambda(epoch):
        return 0.5 * (1.0 + np.cos(np.pi * (epoch / max(epochs, 1))))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    # Early stopping
    early_stopping = dust.EarlyStopping(
        patience=patience,
        delta=1e-4,
        mode='max',
        verbose=verbose
    )
    # Available metric functions
    available_metric_functions = {
        'accuracy': accuracy_score,
        'balanced_accuracy': balanced_accuracy_score,
        'f1_score': lambda y_true, y_pred: f1_score(y_true, y_pred, average='weighted', zero_division=0)
    }
    # Create dictionary of used for training metric functions
    metric_functions = {}
    for metric in eval_metric:
        if metric not in available_metric_functions:
            raise ValueError(f"Unknown metric: {metric}")
        metric_functions[metric] = available_metric_functions[metric]
    # Create training history dictionary
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_metrics': {key: {metric: [] for metric in eval_metric}
                          for key in celltype_keys},
        'val_metrics': {key: {metric: [] for metric in eval_metric}
                        for key in celltype_keys},
        'lr': []
    }
    # Training Loop
    for epoch in tqdm(range(epochs), desc="Warm-start model fine-tuning", colour="blue", disable=not verbose):
        # Model training
        model.train()
        train_loss = 0.0
        train_preds = {key: [] for key in celltype_keys}
        train_targets = {key: [] for key in celltype_keys}
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = {k: v.to(device) for k, v in batch_y.items()}
            # Augmentation only if random < aug_probability
            if use_augmentation and np.random.random() < aug_probability:
                batch_x = augmenter(batch_x)
            # Get predictions using model
            predictions = model(batch_x)
            # Calculate model loss
            loss, _ = criterion(predictions, batch_y)
            # Backward
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params_to_train, max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
            # Predictions and targets
            with torch.no_grad():
                for level_idx, key in enumerate(celltype_keys):
                    level_key = f'level_{level_idx}'
                    preds = predictions[level_key]['logits'].argmax(dim=1)
                    train_preds[key].append(preds.cpu())
                    train_targets[key].append(batch_y[key].cpu())
        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)
        # Compute training metrics
        for key in celltype_keys:
            train_preds[key] = torch.cat(train_preds[key]).numpy()
            train_targets[key] = torch.cat(train_targets[key]).numpy()
            for metric_name, metric_func in metric_functions.items():
                score = metric_func(train_targets[key], train_preds[key])
                history['train_metrics'][key][metric_name].append(score)
        # Model validation
        model.eval()
        val_loss = 0.0
        val_preds = {key: [] for key in celltype_keys}
        val_targets = {key: [] for key in celltype_keys}
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = {k: v.to(device) for k, v in batch_y.items()}
                # Get predictions using model
                predictions = model(batch_x)
                # Calculate model loss
                loss, _ = criterion(predictions, batch_y)
                val_loss += loss.item()
                # Store predictions and targets
                for level_idx, key in enumerate(celltype_keys):
                    level_key = f"level_{level_idx}"
                    preds = predictions[level_key]["logits"].argmax(dim=1)
                    val_preds[key].append(preds.cpu())
                    val_targets[key].append(batch_y[key].cpu())
        val_loss /= len(val_loader)
        history['val_loss'].append(val_loss)
        # Validation metrics
        for key in celltype_keys:
            val_preds[key] = torch.cat(val_preds[key]).numpy()
            val_targets[key] = torch.cat(val_targets[key]).numpy()
            for metric_name, metric_func in metric_functions.items():
                score = metric_func(val_targets[key], val_preds[key])
                history['val_metrics'][key][metric_name].append(score)
        # Learning rate
        history["lr"].append(optimizer.param_groups[0]["lr"])
        # Early Stopping
        if strategy == "last":
            es_score = history["val_metrics"][celltype_keys[-1]][eval_metric[-1]][-1]
        else:
            es_score = weighted_metric(
                history=history,
                celltype_keys=celltype_keys,
                metric=eval_metric[-1],
                strategy=strategy
            )
        if early_stopping(es_score, model):
            break
        scheduler.step()
    # Load best model
    early_stopping.load_bm(model)
    model.eval()
    # Update metadata and training history
    model.celltype_keys = celltype_keys
    model.history = history
    # Remove loaders
    del train_loader, val_loader
    # Refit unknown detector on new data
    if unknown_detection:
        # Create detector
        detector = UnknownCellDetector()
        # Create full dataset loader for detector training
        full_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=False
        )
        # Fit detector on full dataset
        detector.fit(model, full_loader, device=device)
        # Save detector
        model.unknown_detector = detector
        if verbose:
            print("UnknownCellDetector fitted successfully!")
    # Save warm start fine tuned model
    save_model(model, path, model_name, verbose)
    if return_model:
        return model
# Function to display available models on GitHub
def available_models(
):
    """
    Download dataframe with available pretrained scAdam models.

    Returns
    -------
    pandas.DataFrame read from the 'scadam_available_models.csv' file of the scParadise GitHub repository.
    """
    csv_url = 'https://raw.githubusercontent.com/Chechekhins/scParadise/main/scadam_available_models.csv'
    return pd.read_csv(csv_url, sep=',')
# Function for downloading tuned pretrained models from GitHub
def download_model(
    model_name='',
    save_path='',
    github_username=None,
    github_token=None
):
    """
    Download pretrained tuned model for highly accurate cell type annotation.

    Parameters
    ----------
    model_name: str
        Name of the model from column 'model' from scparadise.scadam.available_models().
    save_path: str, path object
        Path to save trained scAdam model.
    github_username: str
        Your GitHub username.
        If not given, the environment variable GITHUB_USERNAME is used.
    github_token: str
        Token for GitHub API.
        If not given, the environment variable GITHUB_TOKEN or GH_TOKEN is used.
    """
    folder_name = model_name + '_scAdam'
    # Local folder for the model (created if needed)
    save = os.path.join(save_path, folder_name).replace("\\", "/")
    os.makedirs(save, exist_ok=True)
    # Credentials: arguments first, environment variables as fallback
    if not github_username:
        github_username = os.getenv("GITHUB_USERNAME")
    if not github_token:
        github_token = os.getenv("GITHUB_TOKEN") or os.getenv("GH_TOKEN")
    fs_kwargs = {"org": "Chechekhins", "repo": "scParadise"}
    # Authenticate only when both username and token are available
    if github_username and github_token:
        fs_kwargs["username"] = github_username
        fs_kwargs["token"] = github_token
    fs = fsspec.filesystem("github", **fs_kwargs)
    # Download content of model folder from the repository
    remote_dir = os.path.join("models_scadam", folder_name).replace("\\", "/")
    remote_files = fs.ls(remote_dir)
    fs.get(remote_files, save)