scparadise.sceve.train_tuned
- scparadise.sceve.train_tuned(mdata, path='', path_tuned='', rna_modality_name='rna', second_modality_name='adt', detailed_annotation=None, model_name='model_regression_tuned', accelerator='auto', random_state=0, test_size=0.1, optimizer_fn=<class 'torch.optim.adam.Adam'>, scheduler_fn=<class 'torch.optim.lr_scheduler.StepLR'>, loss_fn=MSELoss(), step_size=10, gamma_scheduler=0.95, verbose=True, eval_metric=['accuracy'], drop_last=True, return_model=False)
Train the scEve model using parameters tuned by the ‘scparadise.sceve.hyperparameter_tuning’ function.
- Parameters:
mdata (MuData) – MuData object.
rna_modality_name (str, (default: 'rna')) – Name of RNA (GEX) modality in MuData object.
second_modality_name (str, (default: 'adt')) – Name of protein (ADT) or ATAC-seq modality in MuData object.
path (str, path object) – Path where a folder is created containing the model, training history, dictionary of cell annotations, and genes used for training.
path_tuned (str, path object) – Path to the folder containing the parameters tuned by the scparadise.sceve.hyperparameter_tuning function.
detailed_annotation (str, (default: None)) – The most detailed level of cell annotation. Key in the mdata.obs dataframe. If given, it may improve the model evaluation score.
model_name (str, (default: 'model_regression_tuned')) – Name of a folder to save model.
accelerator (str, (default: 'auto')) – Type of accelerator to use in training model (‘cpu’, ‘cuda’). Set ‘auto’ for automatic selection.
random_state (int, (default: 0)) – Controls data shuffling, splitting into folds, and model training. Pass an int for reproducible output across multiple function calls.
test_size (float or int, (default: 0.1)) – If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test cells.
optimizer_fn (func, (default: torch.optim.Adam)) – PyTorch optimizer function.
scheduler_fn (func, (default: torch.optim.lr_scheduler.StepLR)) – PyTorch scheduler used to change the learning rate during training.
loss_fn (torch.nn loss function, (default: torch.nn.MSELoss)) – Loss function for training.
step_size (int, (default: 10)) – Period of learning rate decay, in epochs.
gamma_scheduler (float, (default: 0.95)) – Multiplicative factor of scheduler learning rate decay. step_size and gamma_scheduler are passed to scheduler_fn as its dictionary of parameters.
verbose (int (0 or 1), bool (True or False), (default: True)) – Show training progress for each epoch. Set to 1 or ‘True’ to see the progress of every epoch, 0 or ‘False’ to suppress the output.
eval_metric (list, (default: ['rmse'])) – List of evaluation metrics (‘mse’, ‘mae’, ‘rmse’, ‘rmsle’). The last metric is used for early stopping. Root Mean Squared Logarithmic Error (rmsle) cannot be used when targets contain negative values.
drop_last (bool, (default: True)) – Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller.
return_model (bool, (default: False)) – Whether to return the model after training.
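The parameters above can be combined as in the following minimal sketch. It assumes a MuData object is already loaded, and the two folder paths are hypothetical placeholders, not paths defined by the library. eval_metric is set explicitly to one of the regression metrics listed above. The helper steplr_factor is also hypothetical: it only illustrates how step_size and gamma_scheduler would shape the learning rate under the default StepLR scheduler, which multiplies the learning rate by gamma once every step_size epochs.

```python
# Sketch only: folder paths are hypothetical placeholders.
train_kwargs = dict(
    path="models/",                    # hypothetical output folder
    path_tuned="models/tuned_params",  # hypothetical folder with tuned parameters
    rna_modality_name="rna",
    second_modality_name="adt",
    model_name="model_regression_tuned",
    test_size=0.1,                     # hold out 10% of cells for testing
    step_size=10,
    gamma_scheduler=0.95,
    eval_metric=["rmse"],              # last metric is used for early stopping
    return_model=True,
)


def run_training(mdata):
    """Call train_tuned on an already loaded MuData object."""
    # Imported lazily so the rest of the sketch runs without scparadise installed.
    from scparadise import sceve

    return sceve.train_tuned(mdata, **train_kwargs)


def steplr_factor(epoch, step_size=10, gamma=0.95):
    """Factor StepLR applies to the initial learning rate at a given epoch."""
    return gamma ** (epoch // step_size)


# Print the decay factor at a few epochs for the default settings.
for epoch in (0, 9, 10, 25):
    print(epoch, steplr_factor(epoch))
```

With the defaults, the learning rate stays at its initial value for the first 10 epochs and is then reduced by 5% at every subsequent multiple of 10 epochs. The initial learning rate itself is not an argument of this function.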