scparadise.scadam.train_tuned


scparadise.scadam.train_tuned(adata, path='', path_tuned='', celltype_l1=None, celltype_l2=None, celltype_l3=None, celltype_l4=None, celltype_l5=None, model_name='model_annotation_tuned', accelerator='auto', random_state=0, test_size=0.1, optimizer_fn=<class 'torch.optim.adamw.AdamW'>, scheduler_fn=<class 'torch.optim.lr_scheduler.StepLR'>, loss_fn=CrossEntropyLoss(), step_size=10, gamma_scheduler=0.95, verbose=True, eval_metric=['accuracy'], drop_last=True, return_model=False)

Train the scAdam model using parameters tuned by the ‘scparadise.scadam.hyperparameter_tuning’ function.

Parameters:
  • adata (AnnData) – Annotated data matrix.

  • path (str, path object) – Path where a folder is created containing the model, training history, dictionary of cell annotations, and genes used for training.

  • path_tuned (str, path object) – Path to the folder containing the parameters tuned by the scparadise.scadam.hyperparameter_tuning function.

  • celltype_l1 (str, (default: None)) – First level of cell annotation. Key in adata.obs dataframe.

  • celltype_l2 (str, (default: None)) – Second level of cell annotation. Key in adata.obs dataframe.

  • celltype_l3 (str, (default: None)) – Third level of cell annotation. Key in adata.obs dataframe.

  • celltype_l4 (str, (default: None)) – Fourth level of cell annotation. Key in adata.obs dataframe.

  • celltype_l5 (str, (default: None)) – Fifth level of cell annotation. Key in adata.obs dataframe.

  • model_name (str, (default: 'model_annotation_tuned')) – Name of the folder in which the model is saved.

  • accelerator (str, (default: 'auto')) – Type of accelerator used to train the model (‘cpu’, ‘cuda’). Set ‘auto’ for automatic selection.

  • random_state (int, (default: 0)) – Controls data shuffling, splitting into folds, and model training. Pass an int for reproducible output across multiple function calls.

  • test_size (float or int, (default: 0.1)) – If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split. If int, represents the absolute number of test cells.

  • optimizer_fn (func, (default: torch.optim.AdamW)) – PyTorch optimizer function.

  • scheduler_fn (func, (default: torch.optim.lr_scheduler.StepLR)) – PyTorch scheduler used to change the learning rate during training.

  • loss_fn (torch.nn loss function, (default: torch.nn.CrossEntropyLoss())) – Loss function for training.

  • step_size (int, (default: 10)) – Period of the scheduler's learning rate decay, in epochs.

  • gamma_scheduler (float, (default: 0.95)) – Multiplicative factor of the scheduler's learning rate decay. step_size and gamma_scheduler are passed to scheduler_fn as a dictionary of parameters.

  • verbose (int (0 or 1), bool (True or False), (default: True)) – Show a progress bar for each epoch during training. Set to 1 or True to see the progress of every epoch, 0 or False to suppress the output.

  • eval_metric (list, (default: ['accuracy'])) – List of evaluation metrics (‘accuracy’, ‘balanced_accuracy’, ‘logloss’). The last metric is used for early stopping.

  • drop_last (bool, (default: True)) – Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller.

  • return_model (bool, (default: False)) – Whether to return the model after training.
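
The effect of several of these parameters can be checked with simple arithmetic. The sketch below is illustrative rather than part of scparadise: the helper names (step_lr, n_test_cells, n_batches, train_from_tuned), the batch_size argument, the base learning rate and the adata.obs keys are all hypothetical. It assumes that the default StepLR scheduler multiplies the learning rate by gamma_scheduler every step_size epochs, and that test_size follows the rounding convention of scikit-learn's train_test_split (the number of test cells is rounded up), which this page does not confirm.

```python
import math


def step_lr(base_lr, epoch, step_size=10, gamma_scheduler=0.95):
    """Learning rate at a given epoch under a StepLR-style schedule."""
    # The rate is multiplied by gamma_scheduler once per completed step_size epochs.
    return base_lr * gamma_scheduler ** (epoch // step_size)


def n_test_cells(n_cells, test_size=0.1):
    """Number of held-out cells (assumption: scikit-learn style rounding up)."""
    if isinstance(test_size, float):
        return math.ceil(n_cells * test_size)  # proportion of the dataset
    return test_size  # absolute number of test cells


def n_batches(n_train, batch_size, drop_last=True):
    """Batches per epoch; batch_size is a hypothetical value, not an argument of train_tuned."""
    if drop_last:
        return n_train // batch_size  # the incomplete last batch is dropped
    return math.ceil(n_train / batch_size)  # the last batch is smaller


def train_from_tuned(adata, model_dir, tuned_dir):
    """Hypothetical wrapper showing a call with two annotation levels."""
    import scparadise  # imported here so the helpers above run without it

    return scparadise.scadam.train_tuned(
        adata,
        path=model_dir,
        path_tuned=tuned_dir,
        celltype_l1="celltype_l1",  # assumed keys in adata.obs
        celltype_l2="celltype_l2",
        eval_metric=["balanced_accuracy", "accuracy"],  # last one drives early stopping
        return_model=True,
    )
```

With the defaults, the learning rate stays at its initial value for epochs 0 to 9 and is scaled by 0.95 from epoch 10 onward. Because the last entry of eval_metric is used for early stopping, the order of the list matters: in the wrapper above, ‘accuracy’ decides when training stops while ‘balanced_accuracy’ is only reported.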