scparadise.scadam.train
- scparadise.scadam.train(adata, celltype_keys, layer=None, path='', model_name='scAdam_model', test_size=0.2, eval_metric=['accuracy', 'balanced_accuracy'], strategy='linear_offset', batch_size=128, epochs=200, patience=10, nc=16, nb=5, nh=8, ff_hd=512, ed_nh_ratio=32, classifier_hd=256, dropout=0.3, lr=0.0001, weight_decay=0.0001, use_augmentation=True, aug_probability=0.5, prob=0.15, noise_std=0.1, dropout_aug=0.1, alpha=0.2, from_unsupervised=True, pretrain_epochs=50, pretrain_data=None, adaptive_loss=True, device='auto', random_state=0, return_model=False, unknown_detection=True, verbose=True)
Train a scAdam model with optional unsupervised pretraining. scAdam is a model for automatic cell type annotation; this function trains one and saves it so that it can be used to annotate new data.
- Parameters:
adata (AnnData) – Dataset with cell type annotations in adata.obs
path (str, path object) – Path to create a model folder containing the training history, cell annotation dictionary, and genes used for training.
celltype_keys (list) – List of keys in adata.obs that hold cell type annotations. Example: ['lineage', 'cell type', 'cell state']
layer (str (default: None)) – If specified, use adata.layers[layer] for expression values instead of adata.X.
model_name (str (default: 'scAdam_model')) – Name of the folder in which the model is saved.
test_size (float or int, (default: 0.2)) – If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test cells.
batch_size (int, (default: 128)) – Number of examples per batch.
epochs (int (default: 200)) – Maximum number of epochs for scAdam model training.
patience (int (default: 10)) – Number of consecutive epochs without improvement before early stopping is performed. If patience is set to 0, no early stopping is performed. When patience is enabled, the weights from the best epoch are automatically loaded at the end of training.
eval_metric (str or list (default: ['accuracy', 'balanced_accuracy'])) – Available evaluation metrics: 'accuracy', 'balanced_accuracy', 'f1_score'. The last metric is used as the target and for early stopping.
strategy (str (default: 'linear_offset')) – Weighting strategy for the different cell type annotation levels. Available strategies: linear (linear increase in weight from level to level); exponential (exponential increase in weight from level to level); linear_offset (linear increase in weight from level to level, with an offset); equal (equal weight for all cell type annotation levels); last (only the last cell type annotation level is used for model evaluation).
nc (int (default: 16)) – Number of chunks for genes from adata.
nb (int (default: 5)) – Number of blocks in scAdam model.
nh (int (default: 8)) – Number of heads in scAdam model attention mechanism.
ed_nh_ratio (int (default: 32)) – Used for calculating embedding dimensionality (‘ed’) from ‘nh’. Default ed = nh * ed_nh_ratio = 8 * 32 = 256.
ff_hd (int (default: 512)) – Number of nodes in the feed-forward network of each scAdam model layer.
classifier_hd (int (default: 256)) – Number of nodes in each scAdam classifier.
dropout (float (default: 0.3)) – Fraction of neurons that are temporarily ignored during training (helps prevent overfitting).
lr (float (default: 1e-4)) – Learning rate. Determines the step size at each iteration while moving toward a minimum of the loss function.
weight_decay (float (default: 1e-4)) – Weight decay coefficient.
adaptive_loss (bool (default: True)) – If True, enables adaptive weighting of hierarchical loss levels: each level’s loss is tracked over training and its contribution to the total loss is increased if this level remains hard (high loss) and decreased if it becomes easy (low loss). This helps the model focus more on poorly performing levels of the hierarchy instead of weighting all levels equally. If False, all levels are summed with a fixed weight of 1.0.
use_augmentation (bool (default: True)) – Whether to use data augmentation.
aug_probability (float (default: 0.5)) – Probability of applying augmentation to a batch.
prob (float (default: 0.15)) – Gene masking probability.
noise_std (float (default: 0.1)) – Gaussian noise standard deviation.
dropout_aug (float (default: 0.1)) – Dropout probability for simulating technical noise.
alpha (float (default: 0.2)) – Alpha parameter for mixup augmentation.
from_unsupervised (bool (default: True)) – Use the weights of a self-supervised (pretrained) model as the starting point. Both the pretraining and the supervised training are performed inside this function.
pretrain_epochs (int (default: 50)) – Number of pretraining epochs.
pretrain_data (AnnData (default: None)) – Additional AnnData with the same expression matrix, used for unsupervised pretraining. This additional AnnData need not contain cell type annotations.
device (str (default: 'auto')) – Device to use for model training ('cpu' or 'cuda'). Set 'auto' for automatic selection.
random_state (int (default: 0)) – Controls the data shuffling, splitting to folds and model training. Pass an int for reproducible output across multiple function calls.
verbose (bool (default: True)) – Show progress bar for each epoch during training.
return_model (bool (default: False)) – Whether to return the model after training.
unknown_detection (bool (default: True)) – Train an unknown-cell detector, which identifies unknown cells when the model is used for prediction on new data.
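Several of the parameters above are tied together by simple rules stated in their descriptions: ed = nh * ed_nh_ratio, the last entry of eval_metric is the early-stopping target, and test_size is either a proportion or an absolute cell count. The sketch below restates those rules in plain Python. The helper names (embedding_dim, target_metric, n_test_cells) are hypothetical and not part of scparadise, and rounding a fractional test_size up is an assumption of this sketch, not documented behavior.

```python
import math

# Metric names listed in the eval_metric description.
ALLOWED_METRICS = {"accuracy", "balanced_accuracy", "f1_score"}


def embedding_dim(nh=8, ed_nh_ratio=32):
    """Embedding dimensionality as documented: ed = nh * ed_nh_ratio."""
    return nh * ed_nh_ratio


def target_metric(eval_metric=("accuracy", "balanced_accuracy")):
    """Return the metric used as the target and for early stopping (the last one)."""
    metrics = [eval_metric] if isinstance(eval_metric, str) else list(eval_metric)
    unknown = [m for m in metrics if m not in ALLOWED_METRICS]
    if unknown:
        raise ValueError(f"unsupported metrics: {unknown}")
    return metrics[-1]


def n_test_cells(n_cells, test_size=0.2):
    """Number of held-out cells implied by test_size.

    A float is read as a proportion of the dataset, an int as an absolute
    number of cells. Rounding up is an assumption made for illustration.
    """
    if isinstance(test_size, float):
        if not 0.0 < test_size < 1.0:
            raise ValueError("a float test_size must lie between 0.0 and 1.0")
        return math.ceil(test_size * n_cells)
    if not 0 < test_size < n_cells:
        raise ValueError("an int test_size must be between 1 and n_cells - 1")
    return test_size
```

With the default settings, embedding_dim() gives 256 and target_metric() gives 'balanced_accuracy', matching the defaults described above.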
- Returns:
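Saves the scAdam model for cell type annotation in the folder given by path and model_name. If return_model is True, the trained model is also returned.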
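A minimal sketch of a training call, using only arguments that appear in the signature above. It assumes scparadise is installed and that adata.obs contains the listed annotation columns. The column names, the folder 'models', and the wrapper train_scadam are illustrative. The output location computed by expected_model_dir (path joined with model_name) is an assumption based on the descriptions of those two parameters.

```python
from pathlib import Path


def expected_model_dir(path="", model_name="scAdam_model"):
    """Folder where the model is assumed to be saved: path / model_name."""
    return Path(path) / model_name


def train_scadam(adata, celltype_keys, path="models", model_name="scAdam_model"):
    """Hypothetical wrapper around scparadise.scadam.train."""
    # Imported lazily so this module loads even without scparadise installed.
    from scparadise import scadam

    return scadam.train(
        adata,
        celltype_keys=celltype_keys,
        path=path,
        model_name=model_name,
        eval_metric=["accuracy", "balanced_accuracy"],  # last metric drives early stopping
        epochs=200,
        patience=10,
        device="auto",
        random_state=0,
        return_model=True,  # ask for the trained model to be returned as well
    )


# Example call (requires scparadise and an annotated AnnData object):
# model = train_scadam(adata, ["lineage", "cell type", "cell state"])
# print(expected_model_dir("models"))
```

Passing return_model=True is convenient in interactive sessions; with the default of False the function only writes the model folder to disk.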