scAdam model optimization#

Model optimization refers to the process of improving the performance and efficiency of machine learning models. scAdam models support two types of optimization:

  1. Warm start: Warm starting is a technique in machine learning that involves initializing a model with weights from a previously trained model on the same or a similar task. This method allows the training process to begin from a more advantageous point on the loss surface, leveraging prior knowledge to improve efficiency and performance.

  2. Hyperparameter tuning: This involves adjusting the hyperparameters of the model, such as learning rates, batch sizes, and other configuration settings, to optimize performance on a specific dataset.
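The warm start idea can be illustrated with a toy problem. The sketch below is not scAdam code: it fits a one-parameter model by gradient descent, and the function names, the learning rate and the "pretrained" weight of 2.8 are assumptions chosen for illustration. It compares how many steps are needed from an uninformed starting weight and from a weight learned on a similar task.

```python
# Toy illustration of warm starting: fit y = w * x by gradient descent.
# All names and numbers are illustrative assumptions, not scAdam internals.

def mse(w, xs, ys):
    """Mean squared error of the one-parameter model y = w * x."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def fit(w_init, xs, ys, lr=0.1, tol=1e-6, max_steps=1000):
    """Run gradient descent from w_init; return (final weight, steps used)."""
    w = w_init
    for step in range(max_steps):
        if mse(w, xs, ys) < tol:
            return w, step
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        w -= lr * grad
    return w, max_steps

xs = [0.5, 1.0, 1.5, 2.0]
ys = [3.0 * x for x in xs]                 # new task: the true slope is 3.0

w_pretrained = 2.8                         # weight from a similar, earlier task (assumed)
_, cold_steps = fit(0.0, xs, ys)           # cold start from an uninformed weight
_, warm_steps = fit(w_pretrained, xs, ys)  # warm start from the pretrained weight

print(f"cold start: {cold_steps} steps, warm start: {warm_steps} steps")
```

Because the warm start begins closer to the minimum of the loss, it reaches the same tolerance in fewer steps. Real models have many parameters and non-convex losses, so the gain is not guaranteed to be this clean.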

[1]:
# Python packages
import warnings
warnings.simplefilter('ignore')

import scanpy as sc
import scparadise
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sc.set_figure_params(dpi = 120)

Recommendations for the training dataset#

We recommend the shifted logarithm method for data normalization:

sc.pp.normalize_total(adata, target_sum=None)

sc.pp.log1p(adata)

You can use any other normalization method, but apply the same method to the test dataset.
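As a rough illustration of what the shifted logarithm does, the following plain-Python sketch rescales each cell to a common total and then applies log(1 + x). It assumes, following the scanpy documentation for target_sum=None, that the common total is the median of the per-cell totals. The function name and the toy counts are hypothetical; real data should go through the scanpy calls above.

```python
import math
from statistics import median

def shifted_log_normalize(counts):
    """counts: list of cells, each a list of raw gene counts."""
    totals = [sum(cell) for cell in counts]
    target = median(totals)  # assumption: mirrors target_sum=None
    normalized = []
    for cell, total in zip(counts, totals):
        scale = target / total if total > 0 else 0.0
        # log1p(x) = log(1 + x), the "shift" keeps zero counts at zero
        normalized.append([math.log1p(c * scale) for c in cell])
    return normalized

raw = [
    [10, 0, 30],   # 40 counts in total
    [5, 5, 10],    # 20 counts in total
    [0, 60, 20],   # 80 counts in total
]
norm = shifted_log_normalize(raw)
for row in norm:
    print([round(v, 3) for v in row])
```

Undoing the logarithm with expm1 gives every cell the same total, which is what makes cells with different sequencing depth comparable.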

The training dataset should contain all genes that you want to use for model training in adata_train.X. We recommend removing all non-marker genes from adata_train.X: removing such uninformative genes may improve model quality metrics and also speed up training.

Data preparation#

Here we download and preprocess a dataset from cellxgene:

Single cell RNA sequencing of oropharyngeal squamous cell carcinoma https://cellxgene.cziscience.com/collections/3c34e6f1-6827-47dd-8e19-9edcd461893f

[2]:
!wget https://datasets.cellxgene.cziscience.com/915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad
--2025-01-15 22:51:33--  https://datasets.cellxgene.cziscience.com/915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad
Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 52.85.49.17, 52.85.49.24, 52.85.49.125, ...
Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|52.85.49.17|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 981941987 (936M) [binary/octet-stream]
Saving to: ‘915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad’

915069db-1df2-49a1- 100%[===================>] 936.45M  3.21MB/s    in 4m 10s

2025-01-15 22:55:43 (3.75 MB/s) - ‘915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad’ saved [981941987/981941987]

[3]:
# Load the AnnData object prepared for training
adata = sc.read_h5ad('915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad')
[4]:
# Get raw counts from adata.raw
adata = adata.raw.to_adata()
[5]:
# Convert var_names from ENSG codes to gene names
adata.var.set_index('feature_name', inplace=True)
adata.var_names_make_unique()
[6]:
# Select datasets for model training and testing
# Every cell type must be present in both the training and the test datasets
df = scparadise.scnoah.cell_counter(adata, celltype='cell_type', sample='donor_id')
df
[6]:
                                      HN481  HN482  HN483  HN485  HN487  HN488  HN489  HN490  HN492  HN494
epithelial cell                       6747   758    934    575    2061   1335   2533   2955   2619   780
plasma cell                           868    1051   498    609    106    596    1005   975    642    579
CD8-positive, alpha-beta T cell       613    1129   1204   740    279    941    1004   472    1408   911
fibroblast                            512    164    256    163    371    118    856    204    234    560
myeloid cell                          415    311    159    361    995    136    929    394    203    2998
regulatory T cell                     370    316    623    351    135    69     881    385    321    616
endothelial cell                      283    174    144    111    122    152    311    240    135    138
CD4-positive, alpha-beta T cell       282    1272   2629   911    87     224    1310   1076   789    1784
mural cell                            267    170    98     156    195    173    286    233    87     128
B cell                                213    1837   3369   764    30     2053   1338   2297   942    1832
mast cell                             92     56     91     56     20     39     33     29     74     27
natural killer cell                   88     135    205    65     40     120    465    94     58     128
mature alpha-beta T cell              44     31     13     7      8      7      82     30     38     42
endothelial cell of lymphatic vessel  29     23     25     10     8      9      59     6      28     8
plasmacytoid dendritic cell           19     31     43     25     13     22     87     24     20     163
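Before fixing a donor split, it helps to confirm that even the rarest cell types are represented in every subset. The sketch below does this in plain Python for the four rarest cell types, with counts copied from the table above and the donor grouping used in the next cell. The helper name split_totals is hypothetical and not part of scparadise.

```python
# Per-donor cell counts for the four rarest cell types, copied from the table above.
donors = ['HN481', 'HN482', 'HN483', 'HN485', 'HN487',
          'HN488', 'HN489', 'HN490', 'HN492', 'HN494']
counts = {
    'mast cell':                            [92, 56, 91, 56, 20, 39, 33, 29, 74, 27],
    'mature alpha-beta T cell':             [44, 31, 13, 7, 8, 7, 82, 30, 38, 42],
    'endothelial cell of lymphatic vessel': [29, 23, 25, 10, 8, 9, 59, 6, 28, 8],
    'plasmacytoid dendritic cell':          [19, 31, 43, 25, 13, 22, 87, 24, 20, 163],
}
# Donor split used in this tutorial.
splits = {
    'test':   ['HN481', 'HN482'],
    'train1': ['HN483', 'HN485', 'HN488', 'HN489'],
    'train2': ['HN490', 'HN487', 'HN492', 'HN494'],
}

def split_totals(counts, donors, splits):
    """Sum the per-donor counts of every cell type within each split."""
    column = {donor: i for i, donor in enumerate(donors)}
    return {
        split: {cell_type: sum(row[column[d]] for d in members)
                for cell_type, row in counts.items()}
        for split, members in splits.items()
    }

totals = split_totals(counts, donors, splits)
for split, per_type in totals.items():
    absent = [cell_type for cell_type, n in per_type.items() if n == 0]
    print(split, per_type, 'absent:', absent)
```

With this split, every subset keeps at least 50 cells of each of these cell types, so no cell type is absent from training or testing.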
[7]:
# Create adata_train1, adata_train2 and adata_test datasets
adata_test = adata[adata.obs['donor_id'].isin(['HN481', 'HN482'])].copy()
adata_train1 = adata[adata.obs['donor_id'].isin(['HN483', 'HN485', 'HN488', 'HN489'])].copy()
adata_train2 = adata[adata.obs['donor_id'].isin(['HN490', 'HN487', 'HN492', 'HN494'])].copy()
#del adata
[8]:
# Normalize data, find highly variable features
for i in [adata_train1, adata_train2, adata_test]:
    i.layers['counts'] = i.X.copy()
    sc.pp.normalize_total(i, target_sum=None)
    sc.pp.log1p(i)
    i.raw = i
    sc.pp.highly_variable_genes(i,
                                layer='counts',
                                flavor='seurat_v3',
                                n_top_genes=1200,
                                subset=True)

Default model training#

Use ‘balanced_accuracy’ as the evaluation metric in this case due to the imbalance in the number of cells among cell types.
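To see why plain accuracy can mislead on imbalanced data, consider a minimal sketch that computes both metrics by hand. Balanced accuracy is taken here as the unweighted mean of per-class recall; the function names and the toy labels are illustrative and are not part of scparadise.

```python
from collections import Counter

def accuracy(y_true, y_pred):
    """Fraction of cells whose predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def balanced_accuracy(y_true, y_pred):
    """Unweighted mean of per-class recall."""
    n_per_class = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    recalls = [hits[c] / n for c, n in n_per_class.items()]
    return sum(recalls) / len(recalls)

# 95 cells of a common type and 5 of a rare type;
# the classifier labels every cell as the common type.
y_true = ['common'] * 95 + ['rare'] * 5
y_pred = ['common'] * 100

acc = accuracy(y_true, y_pred)
bal_acc = balanced_accuracy(y_true, y_pred)
print(f"accuracy = {acc:.2f}, balanced accuracy = {bal_acc:.2f}")
```

A classifier that ignores the rare type still scores high on accuracy, while balanced accuracy exposes the missed class. That is the reason to prefer it for early stopping on imbalanced datasets.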

[9]:
# Train scadam model using adata_train1 dataset
scparadise.scadam.train(adata_train1,
                        path='', # path to save model
                        model_name='model_scadam', # folder name with model
                        celltype_l1='cell_type',
                        eval_metric=['accuracy','balanced_accuracy']) # If you are using an imbalanced training dataset, we recommend using balanced_accuracy for early stopping
Successfully saved genes names for training model

Successfully saved dictionary of dataset annotations

Train dataset contains: 15455 cells, it is 90.0 % of input dataset
Test dataset contains: 1718 cells, it is 10.0 % of input dataset

Accelerator: cuda
Start training
epoch 0  | loss: 2.55518 | train_accuracy: 0.34455 | train_balanced_accuracy: 0.1102  | valid_accuracy: 0.34517 | valid_balanced_accuracy: 0.11042 |  0:00:01s
epoch 1  | loss: 2.00387 | train_accuracy: 0.42077 | train_balanced_accuracy: 0.13547 | valid_accuracy: 0.41909 | valid_balanced_accuracy: 0.13511 |  0:00:02s
epoch 2  | loss: 1.65271 | train_accuracy: 0.58932 | train_balanced_accuracy: 0.29457 | valid_accuracy: 0.57625 | valid_balanced_accuracy: 0.28318 |  0:00:03s
epoch 3  | loss: 1.28917 | train_accuracy: 0.60679 | train_balanced_accuracy: 0.31877 | valid_accuracy: 0.60419 | valid_balanced_accuracy: 0.31184 |  0:00:04s
epoch 4  | loss: 1.09993 | train_accuracy: 0.68515 | train_balanced_accuracy: 0.41748 | valid_accuracy: 0.68393 | valid_balanced_accuracy: 0.4168  |  0:00:04s
epoch 5  | loss: 0.96115 | train_accuracy: 0.7144  | train_balanced_accuracy: 0.4634  | valid_accuracy: 0.71478 | valid_balanced_accuracy: 0.46791 |  0:00:05s
epoch 6  | loss: 0.84477 | train_accuracy: 0.75865 | train_balanced_accuracy: 0.50364 | valid_accuracy: 0.76368 | valid_balanced_accuracy: 0.51123 |  0:00:06s
epoch 7  | loss: 0.74068 | train_accuracy: 0.82472 | train_balanced_accuracy: 0.55661 | valid_accuracy: 0.82363 | valid_balanced_accuracy: 0.55338 |  0:00:07s
epoch 8  | loss: 0.66158 | train_accuracy: 0.86005 | train_balanced_accuracy: 0.61093 | valid_accuracy: 0.86205 | valid_balanced_accuracy: 0.61738 |  0:00:08s
epoch 9  | loss: 0.6103  | train_accuracy: 0.86962 | train_balanced_accuracy: 0.63885 | valid_accuracy: 0.87078 | valid_balanced_accuracy: 0.62554 |  0:00:09s
epoch 10 | loss: 0.56654 | train_accuracy: 0.89097 | train_balanced_accuracy: 0.70924 | valid_accuracy: 0.88882 | valid_balanced_accuracy: 0.67687 |  0:00:10s
epoch 11 | loss: 0.50903 | train_accuracy: 0.90191 | train_balanced_accuracy: 0.72043 | valid_accuracy: 0.90047 | valid_balanced_accuracy: 0.70244 |  0:00:11s
epoch 12 | loss: 0.48274 | train_accuracy: 0.91239 | train_balanced_accuracy: 0.77949 | valid_accuracy: 0.91967 | valid_balanced_accuracy: 0.78529 |  0:00:12s
epoch 13 | loss: 0.44446 | train_accuracy: 0.91252 | train_balanced_accuracy: 0.7516  | valid_accuracy: 0.91502 | valid_balanced_accuracy: 0.73566 |  0:00:13s
epoch 14 | loss: 0.42764 | train_accuracy: 0.92333 | train_balanced_accuracy: 0.8083  | valid_accuracy: 0.92549 | valid_balanced_accuracy: 0.7785  |  0:00:14s
epoch 15 | loss: 0.40217 | train_accuracy: 0.9386  | train_balanced_accuracy: 0.88841 | valid_accuracy: 0.94121 | valid_balanced_accuracy: 0.88421 |  0:00:15s
epoch 16 | loss: 0.38251 | train_accuracy: 0.9397  | train_balanced_accuracy: 0.89722 | valid_accuracy: 0.94005 | valid_balanced_accuracy: 0.897   |  0:00:16s
epoch 17 | loss: 0.37016 | train_accuracy: 0.9474  | train_balanced_accuracy: 0.91188 | valid_accuracy: 0.94237 | valid_balanced_accuracy: 0.89974 |  0:00:17s
epoch 18 | loss: 0.36254 | train_accuracy: 0.95244 | train_balanced_accuracy: 0.93253 | valid_accuracy: 0.94703 | valid_balanced_accuracy: 0.92078 |  0:00:18s
epoch 19 | loss: 0.35562 | train_accuracy: 0.94998 | train_balanced_accuracy: 0.92859 | valid_accuracy: 0.94179 | valid_balanced_accuracy: 0.90541 |  0:00:19s
epoch 20 | loss: 0.33606 | train_accuracy: 0.95917 | train_balanced_accuracy: 0.94093 | valid_accuracy: 0.94645 | valid_balanced_accuracy: 0.92007 |  0:00:20s
epoch 21 | loss: 0.32085 | train_accuracy: 0.9615  | train_balanced_accuracy: 0.95026 | valid_accuracy: 0.95343 | valid_balanced_accuracy: 0.94456 |  0:00:21s
epoch 22 | loss: 0.32289 | train_accuracy: 0.96357 | train_balanced_accuracy: 0.95312 | valid_accuracy: 0.9546  | valid_balanced_accuracy: 0.93717 |  0:00:22s
epoch 23 | loss: 0.31242 | train_accuracy: 0.96189 | train_balanced_accuracy: 0.95122 | valid_accuracy: 0.9546  | valid_balanced_accuracy: 0.93809 |  0:00:23s
epoch 24 | loss: 0.30792 | train_accuracy: 0.96364 | train_balanced_accuracy: 0.95162 | valid_accuracy: 0.95518 | valid_balanced_accuracy: 0.94025 |  0:00:24s
epoch 25 | loss: 0.3059  | train_accuracy: 0.96687 | train_balanced_accuracy: 0.9607  | valid_accuracy: 0.95867 | valid_balanced_accuracy: 0.943   |  0:00:25s
epoch 26 | loss: 0.29367 | train_accuracy: 0.96901 | train_balanced_accuracy: 0.96614 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.93963 |  0:00:26s
epoch 27 | loss: 0.29061 | train_accuracy: 0.96855 | train_balanced_accuracy: 0.96414 | valid_accuracy: 0.95984 | valid_balanced_accuracy: 0.94939 |  0:00:27s
epoch 28 | loss: 0.29253 | train_accuracy: 0.9692  | train_balanced_accuracy: 0.96904 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.95159 |  0:00:28s
epoch 29 | loss: 0.29325 | train_accuracy: 0.97257 | train_balanced_accuracy: 0.97368 | valid_accuracy: 0.95867 | valid_balanced_accuracy: 0.94992 |  0:00:29s
epoch 30 | loss: 0.28561 | train_accuracy: 0.97114 | train_balanced_accuracy: 0.96873 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.95207 |  0:00:30s
epoch 31 | loss: 0.28216 | train_accuracy: 0.97321 | train_balanced_accuracy: 0.97322 | valid_accuracy: 0.96217 | valid_balanced_accuracy: 0.96296 |  0:00:31s
epoch 32 | loss: 0.27522 | train_accuracy: 0.97185 | train_balanced_accuracy: 0.97129 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.95845 |  0:00:32s
epoch 33 | loss: 0.27716 | train_accuracy: 0.97444 | train_balanced_accuracy: 0.97277 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.95505 |  0:00:33s
epoch 34 | loss: 0.26944 | train_accuracy: 0.97587 | train_balanced_accuracy: 0.9787  | valid_accuracy: 0.96158 | valid_balanced_accuracy: 0.96142 |  0:00:34s
epoch 35 | loss: 0.26409 | train_accuracy: 0.97289 | train_balanced_accuracy: 0.97552 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.9599  |  0:00:35s
epoch 36 | loss: 0.26652 | train_accuracy: 0.97554 | train_balanced_accuracy: 0.97432 | valid_accuracy: 0.95751 | valid_balanced_accuracy: 0.95575 |  0:00:36s
epoch 37 | loss: 0.26731 | train_accuracy: 0.97574 | train_balanced_accuracy: 0.97458 | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.95274 |  0:00:37s
epoch 38 | loss: 0.25573 | train_accuracy: 0.97703 | train_balanced_accuracy: 0.97853 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.96639 |  0:00:38s
epoch 39 | loss: 0.26136 | train_accuracy: 0.97975 | train_balanced_accuracy: 0.97911 | valid_accuracy: 0.95984 | valid_balanced_accuracy: 0.95097 |  0:00:39s
epoch 40 | loss: 0.25477 | train_accuracy: 0.97981 | train_balanced_accuracy: 0.9822  | valid_accuracy: 0.961   | valid_balanced_accuracy: 0.96062 |  0:00:40s
epoch 41 | loss: 0.2501  | train_accuracy: 0.97832 | train_balanced_accuracy: 0.97982 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.95776 |  0:00:41s
epoch 42 | loss: 0.24979 | train_accuracy: 0.98188 | train_balanced_accuracy: 0.982   | valid_accuracy: 0.961   | valid_balanced_accuracy: 0.95845 |  0:00:42s
epoch 43 | loss: 0.2472  | train_accuracy: 0.98175 | train_balanced_accuracy: 0.98184 | valid_accuracy: 0.96333 | valid_balanced_accuracy: 0.96133 |  0:00:43s
epoch 44 | loss: 0.23966 | train_accuracy: 0.98156 | train_balanced_accuracy: 0.98064 | valid_accuracy: 0.96158 | valid_balanced_accuracy: 0.95964 |  0:00:43s
epoch 45 | loss: 0.25145 | train_accuracy: 0.98078 | train_balanced_accuracy: 0.98265 | valid_accuracy: 0.95634 | valid_balanced_accuracy: 0.96661 |  0:00:44s
epoch 46 | loss: 0.24429 | train_accuracy: 0.97988 | train_balanced_accuracy: 0.97911 | valid_accuracy: 0.95343 | valid_balanced_accuracy: 0.95564 |  0:00:45s
epoch 47 | loss: 0.25503 | train_accuracy: 0.98234 | train_balanced_accuracy: 0.98221 | valid_accuracy: 0.95634 | valid_balanced_accuracy: 0.96131 |  0:00:46s
epoch 48 | loss: 0.24714 | train_accuracy: 0.98382 | train_balanced_accuracy: 0.98409 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.95744 |  0:00:47s
epoch 49 | loss: 0.23765 | train_accuracy: 0.97988 | train_balanced_accuracy: 0.97965 | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.95392 |  0:00:48s
epoch 50 | loss: 0.23015 | train_accuracy: 0.98305 | train_balanced_accuracy: 0.98415 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.96781 |  0:00:49s
epoch 51 | loss: 0.23351 | train_accuracy: 0.97755 | train_balanced_accuracy: 0.97853 | valid_accuracy: 0.9546  | valid_balanced_accuracy: 0.95501 |  0:00:50s
epoch 52 | loss: 0.22766 | train_accuracy: 0.98531 | train_balanced_accuracy: 0.98407 | valid_accuracy: 0.95576 | valid_balanced_accuracy: 0.95683 |  0:00:51s
epoch 53 | loss: 0.23931 | train_accuracy: 0.98499 | train_balanced_accuracy: 0.98428 | valid_accuracy: 0.95402 | valid_balanced_accuracy: 0.94877 |  0:00:52s
epoch 54 | loss: 0.23413 | train_accuracy: 0.98479 | train_balanced_accuracy: 0.98549 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.96727 |  0:00:53s
epoch 55 | loss: 0.24018 | train_accuracy: 0.98505 | train_balanced_accuracy: 0.98458 | valid_accuracy: 0.95751 | valid_balanced_accuracy: 0.96351 |  0:00:54s
epoch 56 | loss: 0.22409 | train_accuracy: 0.97988 | train_balanced_accuracy: 0.9813  | valid_accuracy: 0.95343 | valid_balanced_accuracy: 0.96193 |  0:00:55s
epoch 57 | loss: 0.2251  | train_accuracy: 0.98434 | train_balanced_accuracy: 0.98721 | valid_accuracy: 0.95402 | valid_balanced_accuracy: 0.96176 |  0:00:56s
epoch 58 | loss: 0.22227 | train_accuracy: 0.9868  | train_balanced_accuracy: 0.98474 | valid_accuracy: 0.95984 | valid_balanced_accuracy: 0.96604 |  0:00:57s
epoch 59 | loss: 0.22204 | train_accuracy: 0.98098 | train_balanced_accuracy: 0.98357 | valid_accuracy: 0.95402 | valid_balanced_accuracy: 0.96383 |  0:00:58s
epoch 60 | loss: 0.22922 | train_accuracy: 0.98369 | train_balanced_accuracy: 0.98709 | valid_accuracy: 0.95052 | valid_balanced_accuracy: 0.95798 |  0:00:59s

Early stopping occurred at epoch 60 with best_epoch = 50 and best_valid_balanced_accuracy = 0.96781

Successfully saved training history and parameters
Successfully saved model at model_scadam/model.zip
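The end of the log above is consistent with patience-based early stopping on valid_balanced_accuracy: training stops once the metric has not improved for a fixed number of epochs, and the best epoch is reported. The sketch below reproduces that behaviour for the last epochs of the log. The patience of 10 epochs is inferred from the log (60 - 50) rather than taken from the scparadise documentation, and the function name is hypothetical.

```python
def early_stopping(scores, patience, first_epoch=0):
    """Scan a metric to maximize; return (stop_epoch, best_epoch, best_score)."""
    best_score, best_epoch = float('-inf'), None
    epoch = first_epoch
    for epoch, score in enumerate(scores, start=first_epoch):
        if score > best_score:
            best_score, best_epoch = score, epoch
        elif epoch - best_epoch >= patience:
            break  # no improvement for `patience` epochs
    return epoch, best_epoch, best_score

# valid_balanced_accuracy for epochs 50-60, copied from the log above.
# Earlier epochs are omitted because none of them exceeds the value at epoch 50.
valid_tail = [0.96781, 0.95501, 0.95683, 0.94877, 0.96727, 0.96351,
              0.96193, 0.96176, 0.96604, 0.96383, 0.95798]

stop_epoch, best_epoch, best_score = early_stopping(valid_tail, patience=10,
                                                    first_epoch=50)
print(stop_epoch, best_epoch, best_score)
```

The saved model corresponds to the best epoch, not the last one, so the ten extra epochs cost training time but do not degrade the result.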
[10]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
                                       path_model = 'model_scadam')
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities

Evaluation of prediction quality#

The model predicts mature alpha-beta T cells unreliably (precision = 0.0665): most cells that receive this label belong to other cell types.

The model also misses about half of the epithelial cells (recall/sensitivity = 0.4994).

Sensitivity is also low for regulatory T cells, and precision is low for CD8-positive and CD4-positive alpha-beta T cells and for endothelial cells.
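A short calculation shows how a rare cell type can combine high recall and high specificity with very low precision. It uses the values reported for mature alpha-beta T cell in the table below, together with the total of 18300 test cells obtained by summing the "number of cells" column. Because the reported metrics are rounded, the derived counts are approximate.

```python
# Values for 'mature alpha-beta T cell' from the classification report.
n_true = 75         # cells of this type in the test dataset
n_total = 18300     # all test cells (sum of the 'number of cells' column)
recall = 0.84
precision = 0.0665

tp = round(recall * n_true)          # correctly labelled cells
n_predicted = round(tp / precision)  # cells that received this label (approximate)
fp = n_predicted - tp                # cells of other types given this label

# Specificity implied by these counts; close to the reported 0.9514
specificity_est = 1 - fp / (n_total - n_true)

print(f"true positives: {tp}")
print(f"cells labelled as this type: about {n_predicted}")
print(f"false positives: about {fp}")
print(f"implied specificity: {specificity_est:.4f}")
```

Fewer than 5% of the other cells are mislabelled, yet they outnumber the 75 true cells more than tenfold, which is why precision collapses while specificity stays high.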

[11]:
## Check model quality
df = scparadise.scnoah.report_classif_full(adata_test,
                                           celltype='cell_type',
                                           pred_celltype='pred_celltype_l1')
df
[11]:
                                      precision  recall/sensitivity  specificity  f1-score  geometric mean  index balanced accuracy  number of cells
B cell                                0.8170     0.98                0.9723       0.8911    0.9761          0.9536                   2050
CD4-positive, alpha-beta T cell       0.5112     0.8481              0.9248       0.6379    0.8856          0.7783                   1554
CD8-positive, alpha-beta T cell       0.6327     0.8829              0.9461       0.7371    0.9139          0.83                     1742
endothelial cell                      0.7873     0.9803              0.9932       0.8733    0.9867          0.9724                   457
endothelial cell of lymphatic vessel  0.8475     0.9615              0.9995       0.9009    0.9803          0.9574                   52
epithelial cell                       0.9907     0.4994              0.9968       0.6641    0.7055          0.473                    7505
fibroblast                            0.9379     0.8935              0.9977       0.9152    0.9442          0.8822                   676
mast cell                             0.9797     0.9797              0.9998       0.9797    0.9897          0.9776                   148
mature alpha-beta T cell              0.0665     0.84                0.9514       0.1232    0.894           0.7903                   75
mural cell                            0.8173     0.9519              0.9948       0.8795    0.9731          0.9429                   437
myeloid cell                          0.8784     0.9449              0.9946       0.9104    0.9694          0.9351                   726
natural killer cell                   0.5929     0.9731              0.9918       0.7368    0.9824          0.9633                   223
plasma cell                           0.9810     0.9937              0.9977       0.9873    0.9957          0.9911                   1919
plasmacytoid dendritic cell           0.9231     0.96                0.9998       0.9412    0.9797          0.956                    50
regulatory T cell                     0.4762     0.7143              0.9694       0.5714    0.8321          0.6748                   686
macro avg                             0.7493     0.8936              0.982        0.7833    0.9339          0.8719
weighted avg                          0.8512     0.7479              0.9818       0.7567    0.8468          0.7198

Accuracy                              0.7479
Balanced accuracy                     0.8936
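The two summary values of the report can be recovered from the recall and "number of cells" columns: accuracy is the recall averaged with cell numbers as weights, and balanced accuracy is the plain mean of recall over cell types. The sketch below checks this with the values copied from the report above; small differences come from rounding in the table.

```python
# (recall, number of cells) per cell type, copied from the report above
report = {
    'B cell': (0.98, 2050),
    'CD4-positive, alpha-beta T cell': (0.8481, 1554),
    'CD8-positive, alpha-beta T cell': (0.8829, 1742),
    'endothelial cell': (0.9803, 457),
    'endothelial cell of lymphatic vessel': (0.9615, 52),
    'epithelial cell': (0.4994, 7505),
    'fibroblast': (0.8935, 676),
    'mast cell': (0.9797, 148),
    'mature alpha-beta T cell': (0.84, 75),
    'mural cell': (0.9519, 437),
    'myeloid cell': (0.9449, 726),
    'natural killer cell': (0.9731, 223),
    'plasma cell': (0.9937, 1919),
    'plasmacytoid dendritic cell': (0.96, 50),
    'regulatory T cell': (0.7143, 686),
}

n_total = sum(n for _, n in report.values())
# Accuracy: recall weighted by the number of cells of each type
accuracy = sum(r * n for r, n in report.values()) / n_total
# Balanced accuracy: every cell type counts equally
balanced_accuracy = sum(r for r, _ in report.values()) / len(report)
epithelial_share = report['epithelial cell'][1] / n_total

print(f"accuracy = {accuracy:.4f}")
print(f"balanced accuracy = {balanced_accuracy:.4f}")
print(f"epithelial cells: {epithelial_share:.1%} of the test dataset")
```

Epithelial cells make up roughly 41% of the test dataset and have a recall of only 0.4994, so they pull accuracy well below balanced accuracy for this model.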
[12]:
# Order cell type colors
celltype = np.unique(adata_test.obs['cell_type']).tolist()
adata_test.obs['cell_type'] = pd.Categorical(
    values=adata_test.obs['cell_type'], categories=celltype, ordered=True
)
adata_test.obs['pred_celltype_l1'] = pd.Categorical(
    values=adata_test.obs['pred_celltype_l1'], categories=celltype, ordered=True
)
adata_test = scparadise.scnoah.pred_status(adata_test, celltype='cell_type', pred_celltype='pred_celltype_l1')
[13]:
# Visualise predicted cell types, prediction probabilities and prediction status
sc.pl.embedding(adata_test,
                color=[
                    'cell_type',
                    'pred_celltype_l1',
                    'prob_celltype_l1',
                    'pred_status'
                ],
                basis = 'X_umap',
                frameon = False,
                add_outline = True,
                legend_loc = 'right margin',
                legend_fontsize = 7,
                legend_fontoutline = 1,
                ncols=2,
                wspace = 0.7,
                hspace = 0.1)
[Figure: UMAP of the test dataset colored by cell_type, pred_celltype_l1, prob_celltype_l1 and pred_status (default model)]

Warm start model training#

Warning! Warm start training of an scAdam model requires a dataset that contains exactly the same cell types as the original training dataset. The cell types must be named exactly as they were during the initial training of the model.

  1. There should be no new additional cell types or levels of annotation.

  2. Additionally, none of the cell types used for the initial training of the model should be missing.
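Both requirements can be checked before calling warm_start by comparing the two sets of labels. The helper below is a hypothetical sketch, not a scparadise function. In practice the label lists could come from something like adata_train1.obs['cell_type'].unique() and adata_train2.obs['cell_type'].unique(), which is an assumption about how the annotations are stored.

```python
def check_warm_start_labels(original_labels, new_labels):
    """Compare cell type names of the original and the warm start datasets."""
    original, new = set(original_labels), set(new_labels)
    return {
        'missing': sorted(original - new),     # used originally, absent now
        'unexpected': sorted(new - original),  # not seen during initial training
    }

original = ['B cell', 'fibroblast', 'mast cell']
candidate = ['B cell', 'Fibroblast', 'mast cell']  # note the capital letter

problems = check_warm_start_labels(original, candidate)
print(problems)

compatible = not problems['missing'] and not problems['unexpected']
print('compatible for warm start:', compatible)
```

The comparison is exact, so a difference in capitalization or spacing is reported both as a missing and as an unexpected cell type. Such labels need to be renamed before warm start training.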

[14]:
# Warm start requires second training dataset and path to pretrained model
scparadise.scadam.warm_start(adata_train2,
                             path='', # path to save model
                             model_name='model_scadam', # folder name with pretrained model
                             celltype_l1='cell_type',
                             eval_metric=['accuracy','balanced_accuracy']) # If you are using an imbalanced training dataset, we recommend using balanced_accuracy for early stopping
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Train dataset contains: 28958 cells, it is 90.0 % of input dataset
Test dataset contains: 3218 cells, it is 10.0 % of input dataset

Successfully loaded parameters

Accelerator: cuda
Start training
epoch 0  | loss: 0.39464 | train_accuracy: 0.95114 | train_balanced_accuracy: 0.9084  | valid_accuracy: 0.95059 | valid_balanced_accuracy: 0.91401 |  0:00:01s
epoch 1  | loss: 0.31958 | train_accuracy: 0.95701 | train_balanced_accuracy: 0.94684 | valid_accuracy: 0.95401 | valid_balanced_accuracy: 0.94443 |  0:00:03s
epoch 2  | loss: 0.30485 | train_accuracy: 0.95849 | train_balanced_accuracy: 0.95216 | valid_accuracy: 0.94624 | valid_balanced_accuracy: 0.9447  |  0:00:05s
epoch 3  | loss: 0.29525 | train_accuracy: 0.96074 | train_balanced_accuracy: 0.94526 | valid_accuracy: 0.94779 | valid_balanced_accuracy: 0.93132 |  0:00:07s
epoch 4  | loss: 0.27713 | train_accuracy: 0.96564 | train_balanced_accuracy: 0.95579 | valid_accuracy: 0.95494 | valid_balanced_accuracy: 0.9461  |  0:00:09s
epoch 5  | loss: 0.27371 | train_accuracy: 0.96706 | train_balanced_accuracy: 0.958   | valid_accuracy: 0.95339 | valid_balanced_accuracy: 0.94601 |  0:00:11s
epoch 6  | loss: 0.27291 | train_accuracy: 0.96785 | train_balanced_accuracy: 0.95893 | valid_accuracy: 0.95183 | valid_balanced_accuracy: 0.93815 |  0:00:13s
epoch 7  | loss: 0.26626 | train_accuracy: 0.96913 | train_balanced_accuracy: 0.96329 | valid_accuracy: 0.95339 | valid_balanced_accuracy: 0.94487 |  0:00:15s
epoch 8  | loss: 0.25926 | train_accuracy: 0.96681 | train_balanced_accuracy: 0.95935 | valid_accuracy: 0.94935 | valid_balanced_accuracy: 0.93925 |  0:00:16s
epoch 9  | loss: 0.25558 | train_accuracy: 0.97134 | train_balanced_accuracy: 0.9652  | valid_accuracy: 0.95339 | valid_balanced_accuracy: 0.94177 |  0:00:18s
epoch 10 | loss: 0.25424 | train_accuracy: 0.97237 | train_balanced_accuracy: 0.96595 | valid_accuracy: 0.9509  | valid_balanced_accuracy: 0.93876 |  0:00:20s
epoch 11 | loss: 0.25194 | train_accuracy: 0.96809 | train_balanced_accuracy: 0.97495 | valid_accuracy: 0.94096 | valid_balanced_accuracy: 0.93964 |  0:00:22s
epoch 12 | loss: 0.24886 | train_accuracy: 0.97241 | train_balanced_accuracy: 0.96691 | valid_accuracy: 0.9537  | valid_balanced_accuracy: 0.94268 |  0:00:24s
epoch 13 | loss: 0.24931 | train_accuracy: 0.97344 | train_balanced_accuracy: 0.96327 | valid_accuracy: 0.95059 | valid_balanced_accuracy: 0.92561 |  0:00:25s
epoch 14 | loss: 0.24412 | train_accuracy: 0.97248 | train_balanced_accuracy: 0.97067 | valid_accuracy: 0.95028 | valid_balanced_accuracy: 0.9415  |  0:00:27s

Early stopping occurred at epoch 14 with best_epoch = 4 and best_valid_balanced_accuracy = 0.9461

Successfully saved training history and parameters
Successfully saved model at model_scadam/model.zip
[15]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
                                       path_model = 'model_scadam')
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities

Evaluation of prediction results#

Warm start increases accuracy (0.9538 vs 0.7479), balanced accuracy (0.9341 vs 0.8936) and macro-averaged precision (0.9167 vs 0.7493) on the test dataset.

Below are examples of improved model performance for certain cell types.

  1. mature alpha-beta T cell (precision - 0.6036 vs 0.0665)

  2. regulatory T cell (precision - 0.7766 vs 0.4762)

  3. natural killer cell (precision - 0.8991 vs 0.5929)

  4. fibroblast (sensitivity - 0.9896 vs 0.8935)

  5. epithelial cell (sensitivity - 0.9824 vs 0.4994)

  6. CD4-positive, alpha-beta T cell (precision - 0.8644 vs 0.5112)

  7. CD8-positive, alpha-beta T cell (precision - 0.8985 vs 0.6327)

Full comparison is available below.

[16]:
## Check model quality
df_warm_start = scparadise.scnoah.report_classif_full(adata_test,
                                                      celltype='cell_type',
                                                      pred_celltype='pred_celltype_l1')
df_warm_start
[16]:
                                      precision  recall/sensitivity  specificity  f1-score  geometric mean  index balanced accuracy  number of cells
B cell                                0.9738     0.978               0.9967       0.9759    0.9873          0.973                    2050
CD4-positive, alpha-beta T cell       0.8644     0.8411              0.9878       0.8526    0.9115          0.8186                   1554
CD8-positive, alpha-beta T cell       0.8985     0.9248              0.989        0.9115    0.9564          0.9088                   1742
endothelial cell                      0.9868     0.9803              0.9997       0.9835    0.9899          0.9781                   457
endothelial cell of lymphatic vessel  0.9592     0.9038              0.9999       0.9307    0.9507          0.8951                   52
epithelial cell                       0.9879     0.9824              0.9917       0.9852    0.987           0.9733                   7505
fibroblast                            0.9585     0.9896              0.9984       0.9738    0.994           0.9872                   676
mast cell                             0.9931     0.973               0.9999       0.9829    0.9864          0.9703                   148
mature alpha-beta T cell              0.6036     0.8933              0.9976       0.7204    0.944           0.8819                   75
mural cell                            0.9761     0.9336              0.9994       0.9544    0.966           0.927                    437
myeloid cell                          0.9832     0.9656              0.9993       0.9743    0.9823          0.9616                   726
natural killer cell                   0.8991     0.9193              0.9987       0.9091    0.9582          0.9108                   223
plasma cell                           0.9830     0.9917              0.998        0.9873    0.9948          0.989                    1919
plasmacytoid dendritic cell           0.9074     0.98                0.9997       0.9423    0.9898          0.9778                   50
regulatory T cell                     0.7766     0.7551              0.9915       0.7657    0.8653          0.731                    686
macro avg                             0.9167     0.9341              0.9965       0.9233    0.9642          0.9256
weighted avg                          0.9544     0.9538              0.9935       0.9539    0.973           0.9443

Accuracy                              0.9538
Balanced accuracy                     0.9341
[17]:
# Order cell type colors
celltype = np.unique(adata_test.obs['cell_type']).tolist()
adata_test.obs['cell_type'] = pd.Categorical(
    values=adata_test.obs['cell_type'], categories=celltype, ordered=True
)
adata_test.obs['pred_celltype_l1'] = pd.Categorical(
    values=adata_test.obs['pred_celltype_l1'], categories=celltype, ordered=True
)
# Add prediction status. Label cells as correct or incorrect based on the comparison between ground truth cell types and predictions.
adata_test = scparadise.scnoah.pred_status(adata_test, celltype='cell_type', pred_celltype='pred_celltype_l1')
[18]:
# Visualise predicted cell types, prediction probabilities and prediction status
sc.pl.embedding(adata_test,
                color=[
                    'cell_type',
                    'pred_celltype_l1',
                    'prob_celltype_l1',
                    'pred_status'
                ],
                basis = 'X_umap',
                frameon = False,
                add_outline = True,
                legend_loc = 'right margin',
                legend_fontsize = 7,
                legend_fontoutline = 1,
                ncols=2,
                wspace = 0.7,
                hspace = 0.1)
[Figure: UMAP of the test dataset colored by cell_type, pred_celltype_l1, prob_celltype_l1 and pred_status (warm start model)]
[19]:
# Compare prediction results
# 'untuned' row represents untuned model
# 'warm start' row represents warm start trained model
df.compare(df_warm_start, keep_equal=True, align_axis = 0, result_names=('untuned', 'warm start'))
[19]:
                                                  precision  recall/sensitivity  specificity  f1-score  geometric mean  index balanced accuracy
B cell                                untuned     0.8170     0.98                0.9723       0.8911    0.9761          0.9536
                                      warm start  0.9738     0.978               0.9967       0.9759    0.9873          0.973
CD4-positive, alpha-beta T cell       untuned     0.5112     0.8481              0.9248       0.6379    0.8856          0.7783
                                      warm start  0.8644     0.8411              0.9878       0.8526    0.9115          0.8186
CD8-positive, alpha-beta T cell       untuned     0.6327     0.8829              0.9461       0.7371    0.9139          0.83
                                      warm start  0.8985     0.9248              0.989        0.9115    0.9564          0.9088
endothelial cell                      untuned     0.7873     0.9803              0.9932       0.8733    0.9867          0.9724
                                      warm start  0.9868     0.9803              0.9997       0.9835    0.9899          0.9781
endothelial cell of lymphatic vessel  untuned     0.8475     0.9615              0.9995       0.9009    0.9803          0.9574
                                      warm start  0.9592     0.9038              0.9999       0.9307    0.9507          0.8951
epithelial cell                       untuned     0.9907     0.4994              0.9968       0.6641    0.7055          0.473
                                      warm start  0.9879     0.9824              0.9917       0.9852    0.987           0.9733
fibroblast                            untuned     0.9379     0.8935              0.9977       0.9152    0.9442          0.8822
                                      warm start  0.9585     0.9896              0.9984       0.9738    0.994           0.9872
mast cell                             untuned     0.9797     0.9797              0.9998       0.9797    0.9897          0.9776
                                      warm start  0.9931     0.973               0.9999       0.9829    0.9864          0.9703
mature alpha-beta T cell              untuned     0.0665     0.84                0.9514       0.1232    0.894           0.7903
                                      warm start  0.6036     0.8933              0.9976       0.7204    0.944           0.8819
mural cell                            untuned     0.8173     0.9519              0.9948       0.8795    0.9731          0.9429
                                      warm start  0.9761     0.9336              0.9994       0.9544    0.966           0.927
myeloid cell                          untuned     0.8784     0.9449              0.9946       0.9104    0.9694          0.9351
                                      warm start  0.9832     0.9656              0.9993       0.9743    0.9823          0.9616
natural killer cell                   untuned     0.5929     0.9731              0.9918       0.7368    0.9824          0.9633
                                      warm start  0.8991     0.9193              0.9987       0.9091    0.9582          0.9108
plasma cell                           untuned     0.9810     0.9937              0.9977       0.9873    0.9957          0.9911
                                      warm start  0.9830     0.9917              0.998        0.9873    0.9948          0.989
plasmacytoid dendritic cell           untuned     0.9231     0.96                0.9998       0.9412    0.9797          0.956
                                      warm start  0.9074     0.98                0.9997       0.9423    0.9898          0.9778
regulatory T cell                     untuned     0.4762     0.7143              0.9694       0.5714    0.8321          0.6748
                                      warm start  0.7766     0.7551              0.9915       0.7657    0.8653          0.731
macro avg                             untuned     0.7493     0.8936              0.982        0.7833    0.9339          0.8719
                                      warm start  0.9167     0.9341              0.9965       0.9233    0.9642          0.9256
weighted avg                          untuned     0.8512     0.7479              0.9818       0.7567    0.8468          0.7198
                                      warm start  0.9544     0.9538              0.9935       0.9539    0.973           0.9443

Accuracy                              untuned     0.7479
                                      warm start  0.9538
Balanced accuracy                     untuned     0.8936
                                      warm start  0.9341

Hyperparameter tuning#

Warning! Hyperparameter tuning is very time-consuming and may be interrupted for various reasons (such as system shutdowns or CUDA errors). This is not a problem: restarting the hyperparameter tuning continues from the last unfinished trial. The output below shows such a case. Its fourth line states: “Using an existing study with name ‘model_scadam_hp_tune’ instead of creating a new one.” This indicates that tuning was restarted (in this case, from trial 8).
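The resume behaviour can be pictured as a study whose finished trials are persisted after each trial, so that a restart skips them. The sketch below is a conceptual illustration in plain Python with a JSON file as storage. It is not how scparadise stores its study (the log format suggests Optuna, which is an inference), and run_study, fail_at and the toy objective are hypothetical names.

```python
import json
import os
import tempfile

def run_study(path, num_trials, objective, fail_at=None):
    """Run trials 0..num_trials-1 and persist every finished result to `path`."""
    results = {}
    if os.path.exists(path):
        with open(path) as fh:
            results = json.load(fh)   # reuse the existing study
    for trial in range(num_trials):
        key = str(trial)
        if key in results:
            continue                  # finished in an earlier run
        if fail_at is not None and trial == fail_at:
            raise RuntimeError(f'interrupted during trial {trial}')
        results[key] = objective(trial)
        with open(path, 'w') as fh:
            json.dump(results, fh)    # save progress after each trial
    return results

evaluated = []                        # records which trials were actually run

def toy_objective(trial):
    evaluated.append(trial)
    return 1.0 - abs(trial - 3) / 10  # stand-in for a cross-validated score

path = os.path.join(tempfile.mkdtemp(), 'study.json')
try:
    run_study(path, num_trials=6, objective=toy_objective, fail_at=4)
except RuntimeError as err:
    print(err)                        # simulated shutdown during trial 4

results = run_study(path, num_trials=6, objective=toy_objective)
print('trials evaluated:', evaluated)
```

The second call reuses the stored results and only runs the unfinished trials, so no finished trial is evaluated twice. The interrupted trial itself starts again from the beginning.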

[20]:
scparadise.scadam.hyperparameter_tuning(adata_train1,
                                        path='',
                                        model_name='model_scadam_hp_tune', # Folder to save hyperparameter tuning results
                                        celltype_l1='cell_type',
                                        num_trials=50, # The number of attempts to find the optimal hyperparameters for the model (recommended - minimum 100)
                                        eval_metric=['accuracy','balanced_accuracy'])
Successfully saved genes names for training model

Successfully saved dictionary of dataset annotations

Accelerator: cuda

[I 2024-12-17 16:56:48,662] Using an existing study with name 'model_scadam_hp_tune' instead of creating a new one.
Fold 1:

Early stopping occurred at epoch 19 with best_epoch = 4 and best_valid_balanced_accuracy = 0.12159
[I 2024-12-17 16:57:06,116] Trial 8 pruned.
Fold 1:
Stop training because you reached max_epochs = 50 with best_epoch = 42 and best_valid_balanced_accuracy = 0.30773
[I 2024-12-17 16:57:48,788] Trial 9 pruned.
Fold 1:
Stop training because you reached max_epochs = 45 with best_epoch = 40 and best_valid_balanced_accuracy = 0.95197

Fold 2:
Stop training because you reached max_epochs = 45 with best_epoch = 32 and best_valid_balanced_accuracy = 0.95653

Fold 3:
Stop training because you reached max_epochs = 45 with best_epoch = 33 and best_valid_balanced_accuracy = 0.95143

Fold 4:
Stop training because you reached max_epochs = 45 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95292
[I 2024-12-17 17:00:38,489] Trial 10 finished with value: 0.9532140393068294 and parameters: {'n_d': 124, 'n_a': 92, 'n_steps': 4, 'n_shared': 5, 'cat_emb_dim': 7, 'n_independent': 1, 'gamma': 1.6667667154456676, 'momentum': 0.27154876915108217, 'lr': 0.10527024228081307, 'mask_type': 'sparsemax', 'lambda_sparse': 0.0009586155312153562, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 10 with value: 0.9532140393068294.

Fold 1:

Early stopping occurred at epoch 9 with best_epoch = 4 and best_valid_balanced_accuracy = 0.09804
[I 2024-12-17 17:01:06,583] Trial 11 pruned.
Fold 1:

Early stopping occurred at epoch 12 with best_epoch = 2 and best_valid_balanced_accuracy = 0.10296
[I 2024-12-17 17:01:21,539] Trial 12 pruned.
Fold 1:

Early stopping occurred at epoch 35 with best_epoch = 25 and best_valid_balanced_accuracy = 0.94294

Fold 2:

Early stopping occurred at epoch 43 with best_epoch = 33 and best_valid_balanced_accuracy = 0.95294

Fold 3:

Early stopping occurred at epoch 40 with best_epoch = 30 and best_valid_balanced_accuracy = 0.94829

Fold 4:

Early stopping occurred at epoch 38 with best_epoch = 28 and best_valid_balanced_accuracy = 0.95632
[I 2024-12-17 17:03:00,448] Trial 13 finished with value: 0.9501192387079038 and parameters: {'n_d': 48, 'n_a': 44, 'n_steps': 1, 'n_shared': 3, 'cat_emb_dim': 7, 'n_independent': 3, 'gamma': 1.3155465245255713, 'momentum': 0.015628332308846832, 'lr': 0.022300516115261078, 'mask_type': 'entmax', 'lambda_sparse': 0.001323640069167409, 'patience': 10, 'max_epochs': 75, 'virtual_batch_size': 128, 'batch_size': 768}. Best is trial 10 with value: 0.9532140393068294.

Fold 1:

Early stopping occurred at epoch 22 with best_epoch = 12 and best_valid_balanced_accuracy = 0.9479

Fold 2:

Early stopping occurred at epoch 23 with best_epoch = 13 and best_valid_balanced_accuracy = 0.9517

Fold 3:

Early stopping occurred at epoch 30 with best_epoch = 20 and best_valid_balanced_accuracy = 0.95081
[I 2024-12-17 17:04:23,766] Trial 14 pruned.
Fold 1:

Early stopping occurred at epoch 42 with best_epoch = 32 and best_valid_balanced_accuracy = 0.94818

Fold 2:

Early stopping occurred at epoch 40 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95425

Fold 3:

Early stopping occurred at epoch 28 with best_epoch = 18 and best_valid_balanced_accuracy = 0.9401
[I 2024-12-17 17:05:54,421] Trial 15 pruned.
Fold 1:

Early stopping occurred at epoch 20 with best_epoch = 15 and best_valid_balanced_accuracy = 0.75417
[I 2024-12-17 17:06:21,883] Trial 16 pruned.
Fold 1:

Early stopping occurred at epoch 58 with best_epoch = 38 and best_valid_balanced_accuracy = 0.9491

Fold 2:
Stop training because you reached max_epochs = 65 with best_epoch = 51 and best_valid_balanced_accuracy = 0.95384

Fold 3:
Stop training because you reached max_epochs = 65 with best_epoch = 49 and best_valid_balanced_accuracy = 0.95196

Fold 4:

Early stopping occurred at epoch 50 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95086
[I 2024-12-17 17:09:16,394] Trial 17 finished with value: 0.9514360808085562 and parameters: {'n_d': 84, 'n_a': 112, 'n_steps': 2, 'n_shared': 7, 'cat_emb_dim': 6, 'n_independent': 1, 'gamma': 1.5030605109294553, 'momentum': 0.20760649764410705, 'lr': 0.05660643128255937, 'mask_type': 'entmax', 'lambda_sparse': 0.008471890017944976, 'patience': 20, 'max_epochs': 65, 'virtual_batch_size': 128, 'batch_size': 1024}. Best is trial 10 with value: 0.9532140393068294.

Fold 1:
Stop training because you reached max_epochs = 30 with best_epoch = 26 and best_valid_balanced_accuracy = 0.88263

Fold 2:
Stop training because you reached max_epochs = 30 with best_epoch = 28 and best_valid_balanced_accuracy = 0.92126

Fold 3:
Stop training because you reached max_epochs = 30 with best_epoch = 25 and best_valid_balanced_accuracy = 0.88974
[I 2024-12-17 17:10:09,304] Trial 18 pruned.
Fold 1:
Stop training because you reached max_epochs = 60 with best_epoch = 51 and best_valid_balanced_accuracy = 0.94394

Fold 2:
Stop training because you reached max_epochs = 60 with best_epoch = 42 and best_valid_balanced_accuracy = 0.94609

Fold 3:
Stop training because you reached max_epochs = 60 with best_epoch = 59 and best_valid_balanced_accuracy = 0.94985
[I 2024-12-17 17:12:35,400] Trial 19 pruned.
Fold 1:
Stop training because you reached max_epochs = 15 with best_epoch = 12 and best_valid_balanced_accuracy = 0.25899
[I 2024-12-17 17:12:56,643] Trial 20 pruned.
Fold 1:
Stop training because you reached max_epochs = 80 with best_epoch = 63 and best_valid_balanced_accuracy = 0.95048

Fold 2:
Stop training because you reached max_epochs = 80 with best_epoch = 73 and best_valid_balanced_accuracy = 0.95068

Fold 3:

Early stopping occurred at epoch 79 with best_epoch = 59 and best_valid_balanced_accuracy = 0.95031
[I 2024-12-17 17:18:31,330] Trial 21 pruned.
Fold 1:

Early stopping occurred at epoch 30 with best_epoch = 20 and best_valid_balanced_accuracy = 0.94582

Fold 2:

Early stopping occurred at epoch 36 with best_epoch = 26 and best_valid_balanced_accuracy = 0.95586

Fold 3:

Early stopping occurred at epoch 35 with best_epoch = 25 and best_valid_balanced_accuracy = 0.94643
[I 2024-12-17 17:19:54,289] Trial 22 pruned.
Fold 1:
Stop training because you reached max_epochs = 40 with best_epoch = 35 and best_valid_balanced_accuracy = 0.9449

Fold 2:
Stop training because you reached max_epochs = 40 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95382

Fold 3:
Stop training because you reached max_epochs = 40 with best_epoch = 26 and best_valid_balanced_accuracy = 0.95427

Fold 4:
Stop training because you reached max_epochs = 40 with best_epoch = 38 and best_valid_balanced_accuracy = 0.95487
[I 2024-12-17 17:21:47,902] Trial 23 finished with value: 0.9519649530929056 and parameters: {'n_d': 28, 'n_a': 104, 'n_steps': 2, 'n_shared': 4, 'cat_emb_dim': 6, 'n_independent': 1, 'gamma': 1.178399586222898, 'momentum': 0.11624408126387431, 'lr': 0.05461797882680373, 'mask_type': 'entmax', 'lambda_sparse': 0.004533711957980013, 'patience': 15, 'max_epochs': 40, 'virtual_batch_size': 128, 'batch_size': 768}. Best is trial 10 with value: 0.9532140393068294.

Fold 1:
Stop training because you reached max_epochs = 40 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95087

Fold 2:
Stop training because you reached max_epochs = 40 with best_epoch = 32 and best_valid_balanced_accuracy = 0.95126

Fold 3:
Stop training because you reached max_epochs = 40 with best_epoch = 27 and best_valid_balanced_accuracy = 0.95497

Fold 4:
Stop training because you reached max_epochs = 40 with best_epoch = 33 and best_valid_balanced_accuracy = 0.95485
[I 2024-12-17 17:23:37,508] Trial 24 finished with value: 0.9529842642572143 and parameters: {'n_d': 20, 'n_a': 104, 'n_steps': 2, 'n_shared': 4, 'cat_emb_dim': 6, 'n_independent': 1, 'gamma': 1.004354856556712, 'momentum': 0.09745170594845298, 'lr': 0.1361115286436767, 'mask_type': 'entmax', 'lambda_sparse': 0.004700932119213813, 'patience': 15, 'max_epochs': 40, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 10 with value: 0.9532140393068294.

Fold 1:
Stop training because you reached max_epochs = 40 with best_epoch = 29 and best_valid_balanced_accuracy = 0.95132

Fold 2:
Stop training because you reached max_epochs = 40 with best_epoch = 36 and best_valid_balanced_accuracy = 0.95143

Fold 3:
Stop training because you reached max_epochs = 40 with best_epoch = 32 and best_valid_balanced_accuracy = 0.94893
[I 2024-12-17 17:25:50,067] Trial 25 pruned.
Fold 1:
Stop training because you reached max_epochs = 25 with best_epoch = 24 and best_valid_balanced_accuracy = 0.93793
[I 2024-12-17 17:26:07,067] Trial 26 pruned.
Fold 1:
Stop training because you reached max_epochs = 40 with best_epoch = 34 and best_valid_balanced_accuracy = 0.91319
[I 2024-12-17 17:26:39,334] Trial 27 pruned.
Fold 1:
Stop training because you reached max_epochs = 55 with best_epoch = 53 and best_valid_balanced_accuracy = 0.89312
[I 2024-12-17 17:27:09,718] Trial 28 pruned.
Fold 1:
Stop training because you reached max_epochs = 20 with best_epoch = 17 and best_valid_balanced_accuracy = 0.69984
[I 2024-12-17 17:27:22,652] Trial 29 pruned.
Fold 1:
Stop training because you reached max_epochs = 35 with best_epoch = 27 and best_valid_balanced_accuracy = 0.95036

Fold 2:
Stop training because you reached max_epochs = 35 with best_epoch = 22 and best_valid_balanced_accuracy = 0.9522

Fold 3:
Stop training because you reached max_epochs = 35 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95624

Fold 4:

Early stopping occurred at epoch 34 with best_epoch = 19 and best_valid_balanced_accuracy = 0.95711
[I 2024-12-17 17:29:03,542] Trial 30 finished with value: 0.953979711233727 and parameters: {'n_d': 8, 'n_a': 104, 'n_steps': 3, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 1, 'gamma': 1.1086534148049974, 'momentum': 0.03604587542483932, 'lr': 0.04304648047194131, 'mask_type': 'entmax', 'lambda_sparse': 0.0056379815045209765, 'patience': 15, 'max_epochs': 35, 'virtual_batch_size': 128, 'batch_size': 640}. Best is trial 30 with value: 0.953979711233727.

Fold 1:

Early stopping occurred at epoch 29 with best_epoch = 19 and best_valid_balanced_accuracy = 0.94137
[I 2024-12-17 17:29:44,127] Trial 31 pruned.
Fold 1:

Early stopping occurred at epoch 44 with best_epoch = 29 and best_valid_balanced_accuracy = 0.95196

Fold 2:

Early stopping occurred at epoch 38 with best_epoch = 23 and best_valid_balanced_accuracy = 0.95521

Fold 3:

Early stopping occurred at epoch 35 with best_epoch = 20 and best_valid_balanced_accuracy = 0.95698

Fold 4:

Early stopping occurred at epoch 33 with best_epoch = 18 and best_valid_balanced_accuracy = 0.95593
[I 2024-12-17 17:31:43,742] Trial 32 finished with value: 0.9550171718117837 and parameters: {'n_d': 16, 'n_a': 100, 'n_steps': 3, 'n_shared': 2, 'cat_emb_dim': 8, 'n_independent': 1, 'gamma': 1.0958861456214881, 'momentum': 0.03574116393845864, 'lr': 0.043387929856948185, 'mask_type': 'entmax', 'lambda_sparse': 0.006794720166487761, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 640}. Best is trial 32 with value: 0.9550171718117837.

Fold 1:

Early stopping occurred at epoch 32 with best_epoch = 17 and best_valid_balanced_accuracy = 0.94463
[I 2024-12-17 17:32:09,558] Trial 33 pruned.
Fold 1:
Stop training because you reached max_epochs = 45 with best_epoch = 32 and best_valid_balanced_accuracy = 0.95222

Fold 2:
Stop training because you reached max_epochs = 45 with best_epoch = 39 and best_valid_balanced_accuracy = 0.95251

Fold 3:
Stop training because you reached max_epochs = 45 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95336

Fold 4:

Early stopping occurred at epoch 38 with best_epoch = 23 and best_valid_balanced_accuracy = 0.95417
[I 2024-12-17 17:34:43,722] Trial 34 finished with value: 0.9530666488089797 and parameters: {'n_d': 16, 'n_a': 84, 'n_steps': 4, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.2276108102899952, 'momentum': 0.012593582181250802, 'lr': 0.08059647626559789, 'mask_type': 'entmax', 'lambda_sparse': 0.002167146005271099, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 640}. Best is trial 32 with value: 0.9550171718117837.

Fold 1:

Early stopping occurred at epoch 39 with best_epoch = 24 and best_valid_balanced_accuracy = 0.95276

Fold 2:

Early stopping occurred at epoch 34 with best_epoch = 19 and best_valid_balanced_accuracy = 0.95563

Fold 3:
Stop training because you reached max_epochs = 45 with best_epoch = 37 and best_valid_balanced_accuracy = 0.95338

Fold 4:

Early stopping occurred at epoch 39 with best_epoch = 24 and best_valid_balanced_accuracy = 0.95118
[I 2024-12-17 17:37:52,876] Trial 35 finished with value: 0.9532372247143821 and parameters: {'n_d': 16, 'n_a': 64, 'n_steps': 6, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.2400761918697902, 'momentum': 0.03197771261456092, 'lr': 0.07586645313201903, 'mask_type': 'entmax', 'lambda_sparse': 0.0019824653467942583, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 512}. Best is trial 32 with value: 0.9550171718117837.

Fold 1:
Stop training because you reached max_epochs = 50 with best_epoch = 48 and best_valid_balanced_accuracy = 0.84368
[I 2024-12-17 17:38:20,704] Trial 36 pruned.
Fold 1:
Stop training because you reached max_epochs = 45 with best_epoch = 35 and best_valid_balanced_accuracy = 0.92418
[I 2024-12-17 17:39:19,414] Trial 37 pruned.
Fold 1:
Stop training because you reached max_epochs = 5 with best_epoch = 2 and best_valid_balanced_accuracy = 0.13015
[I 2024-12-17 17:39:24,766] Trial 38 pruned.
Fold 1:
Stop training because you reached max_epochs = 30 with best_epoch = 27 and best_valid_balanced_accuracy = 0.25774

Fold 2:
Stop training because you reached max_epochs = 30 with best_epoch = 25 and best_valid_balanced_accuracy = 0.2464

Fold 3:
Stop training because you reached max_epochs = 30 with best_epoch = 29 and best_valid_balanced_accuracy = 0.24999
[I 2024-12-17 17:41:38,906] Trial 39 pruned.
Fold 1:
Stop training because you reached max_epochs = 30 with best_epoch = 28 and best_valid_balanced_accuracy = 0.90695
[I 2024-12-17 17:42:12,114] Trial 40 pruned.
Fold 1:
Stop training because you reached max_epochs = 55 with best_epoch = 53 and best_valid_balanced_accuracy = 0.87599
[I 2024-12-17 17:42:35,563] Trial 41 pruned.
Fold 1:

Early stopping occurred at epoch 37 with best_epoch = 22 and best_valid_balanced_accuracy = 0.95393

Fold 2:
Stop training because you reached max_epochs = 45 with best_epoch = 37 and best_valid_balanced_accuracy = 0.95927

Fold 3:

Early stopping occurred at epoch 39 with best_epoch = 24 and best_valid_balanced_accuracy = 0.95694

Fold 4:

Early stopping occurred at epoch 34 with best_epoch = 19 and best_valid_balanced_accuracy = 0.95263
[I 2024-12-17 17:44:57,599] Trial 42 finished with value: 0.9556933956291884 and parameters: {'n_d': 12, 'n_a': 88, 'n_steps': 4, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.2378640594404067, 'momentum': 0.011566565604250411, 'lr': 0.07971165451794174, 'mask_type': 'entmax', 'lambda_sparse': 0.0019751637085813345, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 640}. Best is trial 42 with value: 0.9556933956291884.

Fold 1:

Early stopping occurred at epoch 38 with best_epoch = 23 and best_valid_balanced_accuracy = 0.94737

Fold 2:

Early stopping occurred at epoch 34 with best_epoch = 19 and best_valid_balanced_accuracy = 0.9551

Fold 3:

Early stopping occurred at epoch 39 with best_epoch = 24 and best_valid_balanced_accuracy = 0.94724
[I 2024-12-17 17:46:47,117] Trial 43 pruned.
Fold 1:
Stop training because you reached max_epochs = 35 with best_epoch = 34 and best_valid_balanced_accuracy = 0.94709

Fold 2:

Early stopping occurred at epoch 25 with best_epoch = 10 and best_valid_balanced_accuracy = 0.95678

Fold 3:
Stop training because you reached max_epochs = 35 with best_epoch = 32 and best_valid_balanced_accuracy = 0.95282
[I 2024-12-17 17:48:47,095] Trial 44 pruned.
Fold 1:

Early stopping occurred at epoch 34 with best_epoch = 19 and best_valid_balanced_accuracy = 0.94692

Fold 2:

Early stopping occurred at epoch 42 with best_epoch = 27 and best_valid_balanced_accuracy = 0.95551

Fold 3:

Early stopping occurred at epoch 32 with best_epoch = 17 and best_valid_balanced_accuracy = 0.95563

Fold 4:

Early stopping occurred at epoch 40 with best_epoch = 25 and best_valid_balanced_accuracy = 0.95928
[I 2024-12-17 17:50:56,243] Trial 45 finished with value: 0.9543346665963542 and parameters: {'n_d': 12, 'n_a': 96, 'n_steps': 3, 'n_shared': 1, 'cat_emb_dim': 7, 'n_independent': 3, 'gamma': 1.3202643716808666, 'momentum': 0.013752839994914707, 'lr': 0.11122640167419975, 'mask_type': 'entmax', 'lambda_sparse': 0.02401249896211509, 'patience': 15, 'max_epochs': 50, 'virtual_batch_size': 128, 'batch_size': 640}. Best is trial 42 with value: 0.9556933956291884.

Fold 1:

Early stopping occurred at epoch 39 with best_epoch = 24 and best_valid_balanced_accuracy = 0.94524
[I 2024-12-17 17:51:33,343] Trial 46 pruned.
Fold 1:
Stop training because you reached max_epochs = 60 with best_epoch = 50 and best_valid_balanced_accuracy = 0.95008

Fold 2:
Stop training because you reached max_epochs = 60 with best_epoch = 57 and best_valid_balanced_accuracy = 0.94405

Fold 3:

Early stopping occurred at epoch 54 with best_epoch = 39 and best_valid_balanced_accuracy = 0.95121
[I 2024-12-17 17:55:42,846] Trial 47 pruned.
Fold 1:
Stop training because you reached max_epochs = 35 with best_epoch = 34 and best_valid_balanced_accuracy = 0.93133
[I 2024-12-17 17:57:03,560] Trial 48 pruned.
Fold 1:
Stop training because you reached max_epochs = 60 with best_epoch = 59 and best_valid_balanced_accuracy = 0.94001
[I 2024-12-17 17:57:33,156] Trial 49 pruned.
Successfully saved best hyperparameters

Best hyperparameters: {'n_d': 12, 'n_a': 88, 'n_steps': 4, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.2378640594404067, 'momentum': 0.011566565604250411, 'lr': 0.07971165451794174, 'mask_type': 'entmax', 'lambda_sparse': 0.0019751637085813345, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 640}
[20]:
{'n_d': 12,
 'n_a': 88,
 'n_steps': 4,
 'n_shared': 1,
 'cat_emb_dim': 8,
 'n_independent': 2,
 'gamma': 1.2378640594404067,
 'momentum': 0.011566565604250411,
 'lr': 0.07971165451794174,
 'mask_type': 'entmax',
 'lambda_sparse': 0.0019751637085813345,
 'patience': 15,
 'max_epochs': 45,
 'virtual_batch_size': 128,
 'batch_size': 640}
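The value logged for each finished trial appears to be the mean of the per-fold `best_valid_balanced_accuracy` scores. A minimal check of that reading, using the four fold scores of trial 42 copied from the log above. The averaging rule is an inference from the log, not something the output states, and the fold scores are rounded in the log, so agreement is only approximate:

```python
from statistics import mean

# Per-fold best_valid_balanced_accuracy for trial 42, copied from the log
fold_scores = [0.95393, 0.95927, 0.95694, 0.95263]

# Value reported in the "Trial 42 finished with value: ..." line
reported_value = 0.9556933956291884

# Assumption: the trial value is the plain mean across the four folds
cv_mean = mean(fold_scores)
difference = abs(cv_mean - reported_value)

print(f"mean of folds: {cv_mean:.5f}")
print(f"reported:      {reported_value:.5f}")
print(f"difference:    {difference:.1e}")
```

Pruned trials stop before all four folds complete, which is why the log shows no final value for them.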
[23]:
# Train model using optimal parameters from model_scadam_hp_tune folder
scparadise.scadam.train_tuned(adata_train1,
                              path='', # path to save model
                              path_tuned='model_scadam_hp_tune', # path to a folder with tuned hyperparameters
                              model_name='model_scadam_tuned', # folder name with model
                              celltype_l1='cell_type',
                              eval_metric=['accuracy','balanced_accuracy']) # If you are using an imbalanced training dataset, we recommend using balanced_accuracy for early stopping
Successfully saved genes names for training model

Successfully saved dictionary of dataset annotations

Train dataset contains: 15455 cells, it is 90.0 % of input dataset
Test dataset contains: 1718 cells, it is 10.0 % of input dataset

Accelerator: cuda
Start training with following hyperparameters: {'n_d': 12, 'n_a': 88, 'n_steps': 4, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.2378640594404067, 'momentum': 0.011566565604250411, 'optimizer_params': {'lr': 0.07971165451794174}, 'mask_type': 'entmax', 'lambda_sparse': 0.0019751637085813345, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 640, 'device_name': 'cuda'}

epoch 0  | loss: 1.97705 | train_accuracy: 0.54274 | train_balanced_accuracy: 0.32096 | valid_accuracy: 0.54657 | valid_balanced_accuracy: 0.32691 |  0:00:01s
epoch 1  | loss: 0.93955 | train_accuracy: 0.72902 | train_balanced_accuracy: 0.53245 | valid_accuracy: 0.73108 | valid_balanced_accuracy: 0.54764 |  0:00:02s
epoch 2  | loss: 0.6556  | train_accuracy: 0.86069 | train_balanced_accuracy: 0.72309 | valid_accuracy: 0.85739 | valid_balanced_accuracy: 0.70182 |  0:00:03s
epoch 3  | loss: 0.51348 | train_accuracy: 0.88282 | train_balanced_accuracy: 0.68273 | valid_accuracy: 0.88708 | valid_balanced_accuracy: 0.67382 |  0:00:04s
epoch 4  | loss: 0.44289 | train_accuracy: 0.91388 | train_balanced_accuracy: 0.76913 | valid_accuracy: 0.91502 | valid_balanced_accuracy: 0.76522 |  0:00:06s
epoch 5  | loss: 0.39766 | train_accuracy: 0.87784 | train_balanced_accuracy: 0.71463 | valid_accuracy: 0.87602 | valid_balanced_accuracy: 0.72104 |  0:00:07s
epoch 6  | loss: 0.38373 | train_accuracy: 0.9287  | train_balanced_accuracy: 0.79359 | valid_accuracy: 0.92841 | valid_balanced_accuracy: 0.78694 |  0:00:08s
epoch 7  | loss: 0.34919 | train_accuracy: 0.94196 | train_balanced_accuracy: 0.8213  | valid_accuracy: 0.94529 | valid_balanced_accuracy: 0.8094  |  0:00:09s
epoch 8  | loss: 0.34294 | train_accuracy: 0.94746 | train_balanced_accuracy: 0.89272 | valid_accuracy: 0.94005 | valid_balanced_accuracy: 0.82458 |  0:00:11s
epoch 9  | loss: 0.31135 | train_accuracy: 0.95807 | train_balanced_accuracy: 0.9349  | valid_accuracy: 0.9482  | valid_balanced_accuracy: 0.89005 |  0:00:12s
epoch 10 | loss: 0.302   | train_accuracy: 0.95516 | train_balanced_accuracy: 0.91947 | valid_accuracy: 0.94412 | valid_balanced_accuracy: 0.87415 |  0:00:13s
epoch 11 | loss: 0.29094 | train_accuracy: 0.95613 | train_balanced_accuracy: 0.93584 | valid_accuracy: 0.9546  | valid_balanced_accuracy: 0.91292 |  0:00:14s
epoch 12 | loss: 0.28451 | train_accuracy: 0.96694 | train_balanced_accuracy: 0.9558  | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.93755 |  0:00:16s
epoch 13 | loss: 0.2717  | train_accuracy: 0.96881 | train_balanced_accuracy: 0.95787 | valid_accuracy: 0.9546  | valid_balanced_accuracy: 0.91791 |  0:00:17s
epoch 14 | loss: 0.271   | train_accuracy: 0.96894 | train_balanced_accuracy: 0.96448 | valid_accuracy: 0.95343 | valid_balanced_accuracy: 0.95216 |  0:00:18s
epoch 15 | loss: 0.27228 | train_accuracy: 0.96972 | train_balanced_accuracy: 0.97213 | valid_accuracy: 0.961   | valid_balanced_accuracy: 0.96207 |  0:00:19s
epoch 16 | loss: 0.24422 | train_accuracy: 0.96868 | train_balanced_accuracy: 0.97584 | valid_accuracy: 0.95343 | valid_balanced_accuracy: 0.95022 |  0:00:21s
epoch 17 | loss: 0.26597 | train_accuracy: 0.9725  | train_balanced_accuracy: 0.96771 | valid_accuracy: 0.95984 | valid_balanced_accuracy: 0.95716 |  0:00:22s
epoch 18 | loss: 0.25958 | train_accuracy: 0.97457 | train_balanced_accuracy: 0.97918 | valid_accuracy: 0.9546  | valid_balanced_accuracy: 0.95112 |  0:00:23s
epoch 19 | loss: 0.26462 | train_accuracy: 0.97095 | train_balanced_accuracy: 0.97068 | valid_accuracy: 0.94994 | valid_balanced_accuracy: 0.93543 |  0:00:24s
epoch 20 | loss: 0.25516 | train_accuracy: 0.98149 | train_balanced_accuracy: 0.98106 | valid_accuracy: 0.95984 | valid_balanced_accuracy: 0.95633 |  0:00:26s
epoch 21 | loss: 0.24411 | train_accuracy: 0.9758  | train_balanced_accuracy: 0.97992 | valid_accuracy: 0.95751 | valid_balanced_accuracy: 0.96093 |  0:00:27s
epoch 22 | loss: 0.24723 | train_accuracy: 0.98104 | train_balanced_accuracy: 0.97788 | valid_accuracy: 0.96275 | valid_balanced_accuracy: 0.94997 |  0:00:28s
epoch 23 | loss: 0.2381  | train_accuracy: 0.97373 | train_balanced_accuracy: 0.97587 | valid_accuracy: 0.94936 | valid_balanced_accuracy: 0.94231 |  0:00:29s
epoch 24 | loss: 0.23338 | train_accuracy: 0.98337 | train_balanced_accuracy: 0.98278 | valid_accuracy: 0.96508 | valid_balanced_accuracy: 0.95644 |  0:00:30s
epoch 25 | loss: 0.24744 | train_accuracy: 0.98382 | train_balanced_accuracy: 0.98654 | valid_accuracy: 0.96042 | valid_balanced_accuracy: 0.96381 |  0:00:31s
epoch 26 | loss: 0.22563 | train_accuracy: 0.98674 | train_balanced_accuracy: 0.98769 | valid_accuracy: 0.96275 | valid_balanced_accuracy: 0.96354 |  0:00:33s
epoch 27 | loss: 0.22583 | train_accuracy: 0.98266 | train_balanced_accuracy: 0.9851  | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.9587  |  0:00:34s
epoch 28 | loss: 0.22655 | train_accuracy: 0.98589 | train_balanced_accuracy: 0.9853  | valid_accuracy: 0.96158 | valid_balanced_accuracy: 0.95896 |  0:00:35s
epoch 29 | loss: 0.21171 | train_accuracy: 0.98784 | train_balanced_accuracy: 0.99048 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.95705 |  0:00:36s
epoch 30 | loss: 0.20867 | train_accuracy: 0.98602 | train_balanced_accuracy: 0.9838  | valid_accuracy: 0.96449 | valid_balanced_accuracy: 0.95301 |  0:00:38s
epoch 31 | loss: 0.22263 | train_accuracy: 0.98958 | train_balanced_accuracy: 0.99077 | valid_accuracy: 0.96333 | valid_balanced_accuracy: 0.96619 |  0:00:39s
epoch 32 | loss: 0.21809 | train_accuracy: 0.99081 | train_balanced_accuracy: 0.99026 | valid_accuracy: 0.96042 | valid_balanced_accuracy: 0.9611  |  0:00:40s
epoch 33 | loss: 0.20374 | train_accuracy: 0.99114 | train_balanced_accuracy: 0.99207 | valid_accuracy: 0.961   | valid_balanced_accuracy: 0.95457 |  0:00:41s
epoch 34 | loss: 0.21158 | train_accuracy: 0.98848 | train_balanced_accuracy: 0.99016 | valid_accuracy: 0.95576 | valid_balanced_accuracy: 0.96417 |  0:00:42s
epoch 35 | loss: 0.21034 | train_accuracy: 0.99159 | train_balanced_accuracy: 0.99133 | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.95491 |  0:00:44s
epoch 36 | loss: 0.20458 | train_accuracy: 0.98589 | train_balanced_accuracy: 0.99006 | valid_accuracy: 0.95518 | valid_balanced_accuracy: 0.96309 |  0:00:45s
epoch 37 | loss: 0.2034  | train_accuracy: 0.99178 | train_balanced_accuracy: 0.99211 | valid_accuracy: 0.95751 | valid_balanced_accuracy: 0.95942 |  0:00:46s
epoch 38 | loss: 0.19649 | train_accuracy: 0.99185 | train_balanced_accuracy: 0.99273 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.94929 |  0:00:48s
epoch 39 | loss: 0.19649 | train_accuracy: 0.99159 | train_balanced_accuracy: 0.99244 | valid_accuracy: 0.95402 | valid_balanced_accuracy: 0.94112 |  0:00:49s
epoch 40 | loss: 0.1909  | train_accuracy: 0.99152 | train_balanced_accuracy: 0.99123 | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.95582 |  0:00:50s
epoch 41 | loss: 0.19694 | train_accuracy: 0.99398 | train_balanced_accuracy: 0.98818 | valid_accuracy: 0.95867 | valid_balanced_accuracy: 0.96127 |  0:00:51s
epoch 42 | loss: 0.19382 | train_accuracy: 0.96635 | train_balanced_accuracy: 0.98592 | valid_accuracy: 0.92841 | valid_balanced_accuracy: 0.94994 |  0:00:52s
epoch 43 | loss: 0.19076 | train_accuracy: 0.99534 | train_balanced_accuracy: 0.99435 | valid_accuracy: 0.95634 | valid_balanced_accuracy: 0.94817 |  0:00:54s
epoch 44 | loss: 0.18032 | train_accuracy: 0.99159 | train_balanced_accuracy: 0.99387 | valid_accuracy: 0.9546  | valid_balanced_accuracy: 0.96242 |  0:00:55s
Stop training because you reached max_epochs = 45 with best_epoch = 31 and best_valid_balanced_accuracy = 0.96619

Successfully saved training history and parameters
Successfully saved model at model_scadam_tuned/model.zip
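The stopping messages in these logs are consistent with a patience rule: training ends once the monitored metric has not improved for `patience` consecutive epochs, or when `max_epochs` is reached. The sketch below replays that rule on the `valid_balanced_accuracy` values of the run above. `replay_early_stopping` is a hypothetical helper written for illustration, not part of scparadise, and it assumes that `balanced_accuracy` is the monitored metric and that higher is better:

```python
def replay_early_stopping(scores, patience, max_epochs):
    """Replay a patience-based stopping rule on per-epoch validation scores."""
    best_epoch, best_score = -1, float("-inf")
    wait, last_epoch, stopped_early = 0, -1, False
    for epoch, score in enumerate(scores[:max_epochs]):
        last_epoch = epoch
        if score > best_score:
            # New best score: remember it and reset the patience counter
            best_epoch, best_score, wait = epoch, score, 0
        else:
            wait += 1
            if wait >= patience:
                stopped_early = True
                break
    return {
        "stopped_early": stopped_early,
        "last_epoch": last_epoch,
        "best_epoch": best_epoch,
        "best_score": best_score,
    }


# valid_balanced_accuracy for epochs 0..44, copied from the training log above
valid_balanced_accuracy = [
    0.32691, 0.54764, 0.70182, 0.67382, 0.76522, 0.72104, 0.78694, 0.8094, 0.82458,
    0.89005, 0.87415, 0.91292, 0.93755, 0.91791, 0.95216, 0.96207, 0.95022, 0.95716,
    0.95112, 0.93543, 0.95633, 0.96093, 0.94997, 0.94231, 0.95644, 0.96381, 0.96354,
    0.9587, 0.95896, 0.95705, 0.95301, 0.96619, 0.9611, 0.95457, 0.96417, 0.95491,
    0.96309, 0.95942, 0.94929, 0.94112, 0.95582, 0.96127, 0.94994, 0.94817, 0.96242,
]

result = replay_early_stopping(valid_balanced_accuracy, patience=15, max_epochs=45)
print(result)
```

With `patience=15`, the best epoch (31) is only 13 epochs before the final epoch (44), so the run ends by reaching `max_epochs` rather than by early stopping, which matches the last log message.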
[24]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
                                       path_model = 'model_scadam_tuned')
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities

Evaluation of prediction results#

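Hyperparameter tuning increases the model's overall accuracy (0.7749 vs 0.7479 for the untuned model), balanced accuracy (0.9068 vs 0.8936), and macro-averaged precision (0.7870 vs 0.7493).

Below are examples of cell types where the tuned model performs better (values are tuned vs untuned).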

  1. mature alpha-beta T cell (precision - 0.2783 vs 0.0665)

  2. regulatory T cell (precision - 0.6378 vs 0.4762)

  3. natural killer cell (precision - 0.7616 vs 0.5929)

  4. endothelial cell (precision - 0.9934 vs 0.7873)

  5. fibroblast (sensitivity - 0.9911 vs 0.8935)

  6. epithelial cell (sensitivity - 0.5518 vs 0.4994)

  7. mural cell (precision - 0.9779 vs 0.8173)

However, precision decreases for some cell types (tuned vs untuned).

  1. B cell (precision - 0.7263 vs 0.8170)

  2. endothelial cell of lymphatic vessel (precision - 0.5532 vs 0.8475)
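Precision for a rare cell type can swing widely, because a handful of false positives is large relative to the few true cells. As a rough illustration, approximate confusion counts can be backed out of the reported metrics for endothelial cell of lymphatic vessel (52 cells in the test set; recall 1.0 for the tuned model and 0.9615 for the untuned model, taken from the report tables in this notebook). `approx_counts` is a hypothetical helper. It assumes the standard definitions precision = TP / (TP + FP) and recall = TP / (TP + FN), and because the reported metrics are rounded, the counts are estimates:

```python
def approx_counts(precision, recall, n_cells):
    """Estimate TP/FP/FN from rounded precision, recall and class size."""
    tp = round(recall * n_cells)          # true cells that were recovered
    predicted = round(tp / precision)     # all cells assigned to this type
    return {"tp": tp, "fp": predicted - tp, "fn": n_cells - tp}


# Values copied from the classification reports in this notebook
tuned = approx_counts(precision=0.5532, recall=1.0, n_cells=52)
untuned = approx_counts(precision=0.8475, recall=0.9615, n_cells=52)

print("hp tuned:", tuned)
print("untuned: ", untuned)
```

Under these assumptions the tuned model recovers all 52 cells but adds roughly 42 false positives, compared with about 9 for the untuned model, which is enough to drop precision from 0.8475 to 0.5532.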

[25]:
## Check model quality
df_hp_tuned = scparadise.scnoah.report_classif_full(adata_test,
                                                    celltype='cell_type',
                                                    pred_celltype='pred_celltype_l1')
df_hp_tuned
[25]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B cell | 0.7263 | 0.9863 | 0.9531 | 0.8366 | 0.9696 | 0.9432 | 2050 |
| CD4-positive, alpha-beta T cell | 0.4301 | 0.8616 | 0.8941 | 0.5738 | 0.8777 | 0.7679 | 1554 |
| CD8-positive, alpha-beta T cell | 0.6414 | 0.8852 | 0.9479 | 0.7438 | 0.916 | 0.8338 | 1742 |
| endothelial cell | 0.9934 | 0.9891 | 0.9998 | 0.9912 | 0.9944 | 0.9878 | 457 |
| endothelial cell of lymphatic vessel | 0.5532 | 1.0 | 0.9977 | 0.7123 | 0.9988 | 0.9979 | 52 |
| epithelial cell | 0.9897 | 0.5518 | 0.996 | 0.7085 | 0.7413 | 0.5252 | 7505 |
| fibroblast | 0.9585 | 0.9911 | 0.9984 | 0.9745 | 0.9947 | 0.9888 | 676 |
| mast cell | 0.9793 | 0.9595 | 0.9998 | 0.9693 | 0.9794 | 0.9554 | 148 |
| mature alpha-beta T cell | 0.2783 | 0.8533 | 0.9909 | 0.4197 | 0.9195 | 0.8339 | 75 |
| mural cell | 0.9779 | 0.9108 | 0.9995 | 0.9431 | 0.9541 | 0.9022 | 437 |
| myeloid cell | 0.9283 | 0.9628 | 0.9969 | 0.9452 | 0.9797 | 0.9566 | 726 |
| natural killer cell | 0.7616 | 0.9596 | 0.9963 | 0.8492 | 0.9778 | 0.9526 | 223 |
| plasma cell | 0.9881 | 0.9922 | 0.9986 | 0.9901 | 0.9954 | 0.9902 | 1919 |
| plasmacytoid dendritic cell | 0.9608 | 0.98 | 0.9999 | 0.9703 | 0.9899 | 0.9779 | 50 |
| regulatory T cell | 0.6378 | 0.7187 | 0.9841 | 0.6758 | 0.841 | 0.6885 | 686 |
| macro avg | 0.7870 | 0.9068 | 0.9835 | 0.8202 | 0.942 | 0.8868 | |
| weighted avg | 0.8553 | 0.7749 | 0.9781 | 0.7783 | 0.8627 | 0.7442 | |

Accuracy: 0.7749
Balanced accuracy: 0.9068
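The gap between accuracy (0.7749) and balanced accuracy (0.9068) in this report comes from class imbalance. Using the recall and number-of-cells columns of the report above, balanced accuracy matches the unweighted mean of per-class recall, while accuracy matches the mean weighted by cell count. The check below uses values copied from the table and does not call scparadise; small deviations are expected because the table is rounded to four decimals:

```python
# (recall, number of cells) per cell type, copied from the report above
report = {
    "B cell": (0.9863, 2050),
    "CD4-positive, alpha-beta T cell": (0.8616, 1554),
    "CD8-positive, alpha-beta T cell": (0.8852, 1742),
    "endothelial cell": (0.9891, 457),
    "endothelial cell of lymphatic vessel": (1.0, 52),
    "epithelial cell": (0.5518, 7505),
    "fibroblast": (0.9911, 676),
    "mast cell": (0.9595, 148),
    "mature alpha-beta T cell": (0.8533, 75),
    "mural cell": (0.9108, 437),
    "myeloid cell": (0.9628, 726),
    "natural killer cell": (0.9596, 223),
    "plasma cell": (0.9922, 1919),
    "plasmacytoid dendritic cell": (0.98, 50),
    "regulatory T cell": (0.7187, 686),
}

total_cells = sum(n for _, n in report.values())

# Unweighted mean of recall: every cell type counts equally
balanced_accuracy = sum(r for r, _ in report.values()) / len(report)

# Recall weighted by class size: large classes dominate
accuracy = sum(r * n for r, n in report.values()) / total_cells

print(f"cells: {total_cells}")
print(f"balanced accuracy ~ {balanced_accuracy:.4f}")
print(f"accuracy          ~ {accuracy:.4f}")
```

Epithelial cell accounts for 7505 of the 18300 test cells and has a recall of 0.5518, so it pulls accuracy down far more than balanced accuracy.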
[26]:
# Order cell type colors
celltype = np.unique(adata_test.obs['cell_type']).tolist()
adata_test.obs['cell_type'] = pd.Categorical(
    values=adata_test.obs['cell_type'], categories=celltype, ordered=True
)
adata_test.obs['pred_celltype_l1'] = pd.Categorical(
    values=adata_test.obs['pred_celltype_l1'], categories=celltype, ordered=True
)
# Add prediction status. Label cells as correct or incorrect based on the comparison between ground truth cell types and predictions.
adata_test = scparadise.scnoah.pred_status(adata_test, celltype='cell_type', pred_celltype='pred_celltype_l1')
[27]:
# Visualise true and predicted cell types, prediction probabilities and prediction status
sc.pl.embedding(adata_test,
                color=[
                    'cell_type',
                    'pred_celltype_l1',
                    'prob_celltype_l1',
                    'pred_status'
                ],
                basis = 'X_umap',
                frameon = False,
                add_outline = True,
                legend_loc = 'right margin',
                legend_fontsize = 7,
                legend_fontoutline = 1,
                ncols=2,
                wspace = 0.7,
                hspace = 0.1)
[Figure: UMAP embedding of adata_test colored by cell_type, pred_celltype_l1, prob_celltype_l1 and pred_status]
[28]:
# Compare prediction results
# 'untuned' row represents untuned model
# 'hp tuned' row represents tuned model
df.compare(df_hp_tuned, keep_equal=True, align_axis = 0, result_names=('untuned', 'hp tuned'))
[28]:
| | | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy |
|---|---|---|---|---|---|---|---|
| B cell | untuned | 0.8170 | 0.98 | 0.9723 | 0.8911 | 0.9761 | 0.9536 |
| | hp tuned | 0.7263 | 0.9863 | 0.9531 | 0.8366 | 0.9696 | 0.9432 |
| CD4-positive, alpha-beta T cell | untuned | 0.5112 | 0.8481 | 0.9248 | 0.6379 | 0.8856 | 0.7783 |
| | hp tuned | 0.4301 | 0.8616 | 0.8941 | 0.5738 | 0.8777 | 0.7679 |
| CD8-positive, alpha-beta T cell | untuned | 0.6327 | 0.8829 | 0.9461 | 0.7371 | 0.9139 | 0.83 |
| | hp tuned | 0.6414 | 0.8852 | 0.9479 | 0.7438 | 0.916 | 0.8338 |
| endothelial cell | untuned | 0.7873 | 0.9803 | 0.9932 | 0.8733 | 0.9867 | 0.9724 |
| | hp tuned | 0.9934 | 0.9891 | 0.9998 | 0.9912 | 0.9944 | 0.9878 |
| endothelial cell of lymphatic vessel | untuned | 0.8475 | 0.9615 | 0.9995 | 0.9009 | 0.9803 | 0.9574 |
| | hp tuned | 0.5532 | 1.0 | 0.9977 | 0.7123 | 0.9988 | 0.9979 |
| epithelial cell | untuned | 0.9907 | 0.4994 | 0.9968 | 0.6641 | 0.7055 | 0.473 |
| | hp tuned | 0.9897 | 0.5518 | 0.996 | 0.7085 | 0.7413 | 0.5252 |
| fibroblast | untuned | 0.9379 | 0.8935 | 0.9977 | 0.9152 | 0.9442 | 0.8822 |
| | hp tuned | 0.9585 | 0.9911 | 0.9984 | 0.9745 | 0.9947 | 0.9888 |
| mast cell | untuned | 0.9797 | 0.9797 | 0.9998 | 0.9797 | 0.9897 | 0.9776 |
| | hp tuned | 0.9793 | 0.9595 | 0.9998 | 0.9693 | 0.9794 | 0.9554 |
| mature alpha-beta T cell | untuned | 0.0665 | 0.84 | 0.9514 | 0.1232 | 0.894 | 0.7903 |
| | hp tuned | 0.2783 | 0.8533 | 0.9909 | 0.4197 | 0.9195 | 0.8339 |
| mural cell | untuned | 0.8173 | 0.9519 | 0.9948 | 0.8795 | 0.9731 | 0.9429 |
| | hp tuned | 0.9779 | 0.9108 | 0.9995 | 0.9431 | 0.9541 | 0.9022 |
| myeloid cell | untuned | 0.8784 | 0.9449 | 0.9946 | 0.9104 | 0.9694 | 0.9351 |
| | hp tuned | 0.9283 | 0.9628 | 0.9969 | 0.9452 | 0.9797 | 0.9566 |
| natural killer cell | untuned | 0.5929 | 0.9731 | 0.9918 | 0.7368 | 0.9824 | 0.9633 |
| | hp tuned | 0.7616 | 0.9596 | 0.9963 | 0.8492 | 0.9778 | 0.9526 |
| plasma cell | untuned | 0.9810 | 0.9937 | 0.9977 | 0.9873 | 0.9957 | 0.9911 |
| | hp tuned | 0.9881 | 0.9922 | 0.9986 | 0.9901 | 0.9954 | 0.9902 |
| plasmacytoid dendritic cell | untuned | 0.9231 | 0.96 | 0.9998 | 0.9412 | 0.9797 | 0.956 |
| | hp tuned | 0.9608 | 0.98 | 0.9999 | 0.9703 | 0.9899 | 0.9779 |
| regulatory T cell | untuned | 0.4762 | 0.7143 | 0.9694 | 0.5714 | 0.8321 | 0.6748 |
| | hp tuned | 0.6378 | 0.7187 | 0.9841 | 0.6758 | 0.841 | 0.6885 |
| macro avg | untuned | 0.7493 | 0.8936 | 0.982 | 0.7833 | 0.9339 | 0.8719 |
| | hp tuned | 0.7870 | 0.9068 | 0.9835 | 0.8202 | 0.942 | 0.8868 |
| weighted avg | untuned | 0.8512 | 0.7479 | 0.9818 | 0.7567 | 0.8468 | 0.7198 |
| | hp tuned | 0.8553 | 0.7749 | 0.9781 | 0.7783 | 0.8627 | 0.7442 |

Accuracy: untuned 0.7479, hp tuned 0.7749
Balanced accuracy: untuned 0.8936, hp tuned 0.9068
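To see at a glance which cell types gained or lost precision after tuning, the two precision columns can be differenced and ranked. The sketch below does this in plain Python with values copied from the untuned vs hp tuned comparison above. It is an illustration only and is not how the notebook itself compares the reports:

```python
# (untuned, hp tuned) precision per cell type, copied from the comparison above
precision = {
    "B cell": (0.8170, 0.7263),
    "CD4-positive, alpha-beta T cell": (0.5112, 0.4301),
    "CD8-positive, alpha-beta T cell": (0.6327, 0.6414),
    "endothelial cell": (0.7873, 0.9934),
    "endothelial cell of lymphatic vessel": (0.8475, 0.5532),
    "epithelial cell": (0.9907, 0.9897),
    "fibroblast": (0.9379, 0.9585),
    "mast cell": (0.9797, 0.9793),
    "mature alpha-beta T cell": (0.0665, 0.2783),
    "mural cell": (0.8173, 0.9779),
    "myeloid cell": (0.8784, 0.9283),
    "natural killer cell": (0.5929, 0.7616),
    "plasma cell": (0.9810, 0.9881),
    "plasmacytoid dendritic cell": (0.9231, 0.9608),
    "regulatory T cell": (0.4762, 0.6378),
}

# Positive delta means the tuned model has higher precision
delta = {
    name: round(tuned - untuned, 4)
    for name, (untuned, tuned) in precision.items()
}

# Rank from largest gain to largest loss
ranked = sorted(delta.items(), key=lambda item: item[1], reverse=True)
for name, change in ranked:
    print(f"{name:40s} {change:+.4f}")
```

Ranking the differences this way also surfaces changes that are easy to miss when reading the table row by row, such as the drop in precision for CD4-positive, alpha-beta T cell.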
[29]:
# Compare prediction results of the warm start and hyperparameter-tuned models
# 'warm start' rows: model trained with warm start
# 'hp tuned' rows: model trained with tuned hyperparameters
df_warm_start.compare(df_hp_tuned, keep_equal=True, align_axis=0, result_names=('warm start', 'hp tuned'))
[29]:
                                                  precision  recall/sensitivity  specificity  f1-score  geometric mean  index balanced accuracy
B cell                                warm start  0.9738     0.978               0.9967       0.9759    0.9873          0.973
                                      hp tuned    0.7263     0.9863              0.9531       0.8366    0.9696          0.9432
CD4-positive, alpha-beta T cell       warm start  0.8644     0.8411              0.9878       0.8526    0.9115          0.8186
                                      hp tuned    0.4301     0.8616              0.8941       0.5738    0.8777          0.7679
CD8-positive, alpha-beta T cell       warm start  0.8985     0.9248              0.989        0.9115    0.9564          0.9088
                                      hp tuned    0.6414     0.8852              0.9479       0.7438    0.916           0.8338
endothelial cell                      warm start  0.9868     0.9803              0.9997       0.9835    0.9899          0.9781
                                      hp tuned    0.9934     0.9891              0.9998       0.9912    0.9944          0.9878
endothelial cell of lymphatic vessel  warm start  0.9592     0.9038              0.9999       0.9307    0.9507          0.8951
                                      hp tuned    0.5532     1.0                 0.9977       0.7123    0.9988          0.9979
epithelial cell                       warm start  0.9879     0.9824              0.9917       0.9852    0.987           0.9733
                                      hp tuned    0.9897     0.5518              0.996        0.7085    0.7413          0.5252
fibroblast                            warm start  0.9585     0.9896              0.9984       0.9738    0.994           0.9872
                                      hp tuned    0.9585     0.9911              0.9984       0.9745    0.9947          0.9888
mast cell                             warm start  0.9931     0.973               0.9999       0.9829    0.9864          0.9703
                                      hp tuned    0.9793     0.9595              0.9998       0.9693    0.9794          0.9554
mature alpha-beta T cell              warm start  0.6036     0.8933              0.9976       0.7204    0.944           0.8819
                                      hp tuned    0.2783     0.8533              0.9909       0.4197    0.9195          0.8339
mural cell                            warm start  0.9761     0.9336              0.9994       0.9544    0.966           0.927
                                      hp tuned    0.9779     0.9108              0.9995       0.9431    0.9541          0.9022
myeloid cell                          warm start  0.9832     0.9656              0.9993       0.9743    0.9823          0.9616
                                      hp tuned    0.9283     0.9628              0.9969       0.9452    0.9797          0.9566
natural killer cell                   warm start  0.8991     0.9193              0.9987       0.9091    0.9582          0.9108
                                      hp tuned    0.7616     0.9596              0.9963       0.8492    0.9778          0.9526
plasma cell                           warm start  0.9830     0.9917              0.998        0.9873    0.9948          0.989
                                      hp tuned    0.9881     0.9922              0.9986       0.9901    0.9954          0.9902
plasmacytoid dendritic cell           warm start  0.9074     0.98                0.9997       0.9423    0.9898          0.9778
                                      hp tuned    0.9608     0.98                0.9999       0.9703    0.9899          0.9779
regulatory T cell                     warm start  0.7766     0.7551              0.9915       0.7657    0.8653          0.731
                                      hp tuned    0.6378     0.7187              0.9841       0.6758    0.841           0.6885
macro avg                             warm start  0.9167     0.9341              0.9965       0.9233    0.9642          0.9256
                                      hp tuned    0.7870     0.9068              0.9835       0.8202    0.942           0.8868
weighted avg                          warm start  0.9544     0.9538              0.9935       0.9539    0.973           0.9443
                                      hp tuned    0.8553     0.7749              0.9781       0.7783    0.8627          0.7442

Accuracy                              warm start  0.9538
                                      hp tuned    0.7749
Balanced accuracy                     warm start  0.9341
                                      hp tuned    0.9068
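
To see where the two models differ most, the stacked output of `compare` can be reshaped into per-class differences. The sketch below is a minimal illustration, assuming pandas 1.5 or later (needed for `result_names`). It rebuilds only three rows and three columns of the report above by hand, so the small frames are stand-ins for the notebook's full `df_warm_start` and `df_hp_tuned`, and the names `stacked`, `delta` and `worst` are introduced here for illustration only.

```python
import pandas as pd

metrics = ["precision", "recall/sensitivity", "f1-score"]
cell_types = ["B cell", "epithelial cell", "regulatory T cell"]

# Values copied from the comparison above (illustrative subset only).
df_warm_start = pd.DataFrame(
    [[0.9738, 0.9780, 0.9759],
     [0.9879, 0.9824, 0.9852],
     [0.7766, 0.7551, 0.7657]],
    index=cell_types,
    columns=metrics,
)
df_hp_tuned = pd.DataFrame(
    [[0.7263, 0.9863, 0.8366],
     [0.9897, 0.5518, 0.7085],
     [0.6378, 0.7187, 0.6758]],
    index=cell_types,
    columns=metrics,
)

# Same call as in the notebook: align_axis=0 stacks the two models
# as consecutive rows under each cell type (a two-level row index).
stacked = df_warm_start.compare(
    df_hp_tuned,
    keep_equal=True,
    align_axis=0,
    result_names=("warm start", "hp tuned"),
)

# Select each model by the second index level and subtract:
# positive values mean the warm start model scores higher.
delta = stacked.xs("warm start", level=1) - stacked.xs("hp tuned", level=1)
print(delta.round(4))

# Cell type where the tuned model loses the most F1 relative to warm start.
worst = delta["f1-score"].idxmax()
print(worst)
```

On this subset, the largest F1 gap is for epithelial cells, and it comes from recall (0.9824 for the warm start model versus 0.5518 for the tuned model) rather than precision.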

Recommendation#

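We recommend warm start training over hyperparameter tuning whenever an additional training dataset is available: in the comparison above, the warm start model reached an accuracy of 0.9538 and a balanced accuracy of 0.9341, versus 0.7749 and 0.9068 for the tuned model. If no additional training dataset is available, hyperparameter tuning can still improve on the untuned model (accuracy 0.7479 to 0.7749, balanced accuracy 0.8936 to 0.9068).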

[30]:
import session_info
session_info.show()
[30]:
-----
anndata             0.10.8
matplotlib          3.9.2
numpy               1.25.0
pandas              2.2.3
scanpy              1.10.3
scparadise          0.3.1_beta
session_info        1.0.0
-----
Modules imported as dependencies:
PIL                         10.4.0
anyio                       NA
arrow                       1.3.0
asciitree                   NA
asttokens                   NA
attr                        24.2.0
attrs                       24.2.0
awkward                     2.7.1
awkward_cpp                 NA
babel                       2.16.0
backports                   NA
certifi                     2024.08.30
cffi                        1.17.1
charset_normalizer          3.3.2
cloudpickle                 3.1.0
colorlog                    NA
comm                        0.2.2
cycler                      0.12.1
cython_runtime              NA
dask                        2024.8.0
dateutil                    2.9.0.post0
debugpy                     1.8.6
decorator                   5.1.1
defusedxml                  0.7.1
exceptiongroup              1.2.2
executing                   2.1.0
fastjsonschema              NA
fqdn                        NA
fsspec                      2023.6.0
h5py                        3.12.1
idna                        3.10
igraph                      0.11.6
imblearn                    0.12.3
importlib_metadata          NA
importlib_resources         NA
ipykernel                   6.29.5
isoduration                 NA
jaraco                      NA
jedi                        0.19.1
jinja2                      3.1.4
joblib                      1.4.2
json5                       0.9.25
jsonpointer                 3.0.0
jsonschema                  4.23.0
jsonschema_specifications   NA
jupyter_events              0.10.0
jupyter_server              2.14.2
jupyterlab_server           2.27.3
kiwisolver                  1.4.7
legacy_api_wrap             NA
leidenalg                   0.10.2
llvmlite                    0.43.0
markupsafe                  2.1.5
matplotlib_inline           0.1.7
more_itertools              10.5.0
mpl_toolkits                NA
mpmath                      1.3.0
msgpack                     1.1.0
mudata                      0.2.4
muon                        0.1.6
natsort                     8.4.0
nbformat                    5.10.4
numba                       0.60.0
numcodecs                   0.12.1
optuna                      4.0.0
overrides                   NA
packaging                   24.1
parso                       0.8.4
patsy                       0.5.6
pexpect                     4.9.0
platformdirs                4.3.6
plotly                      5.24.1
prometheus_client           NA
prompt_toolkit              3.0.48
psutil                      6.0.0
ptyprocess                  0.7.0
pure_eval                   0.2.3
pyarrow                     18.1.0
pycparser                   2.22
pydev_ipython               NA
pydevconsole                NA
pydevd                      3.1.0
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pydot                       3.0.3
pygments                    2.18.0
pynndescent                 0.5.13
pyparsing                   3.1.4
pythonjsonlogger            NA
pytorch_tabnet              NA
pytz                        2024.2
referencing                 NA
requests                    2.32.3
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rich                        NA
rpds                        NA
scipy                       1.13.1
seaborn                     0.13.2
send2trash                  NA
setuptools                  75.1.0
setuptools_scm              NA
shap                        0.46.0
six                         1.16.0
sklearn                     1.5.2
skmisc                      0.3.1
slicer                      NA
sniffio                     1.3.1
stack_data                  0.6.3
statsmodels                 0.14.3
sympy                       1.13.3
tblib                       3.0.0
texttable                   1.7.0
threadpoolctl               3.5.0
tlz                         0.12.1
tomli                       2.0.1
toolz                       0.12.1
torch                       2.4.1+cu121
torchgen                    NA
tornado                     6.4.1
tqdm                        4.66.5
traitlets                   5.14.3
triton                      3.0.0
typing_extensions           NA
umap                        0.5.6
uri_template                NA
urllib3                     1.26.20
wcwidth                     0.2.13
webcolors                   24.8.0
websocket                   1.8.0
yaml                        6.0.2
zarr                        2.18.2
zipp                        NA
zmq                         26.2.0
zoneinfo                    NA
-----
IPython             8.18.1
jupyter_client      8.6.3
jupyter_core        5.7.2
jupyterlab          4.2.5
-----
Python 3.9.19 (main, May  6 2024, 19:43:03) [GCC 11.2.0]
Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
-----
Session information updated at 2025-01-15 23:33
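
To check whether another environment matches the versions reported above, the installed distributions can be compared against them. The following is a minimal sketch that assumes only the standard library's `importlib.metadata`. The `EXPECTED` mapping copies a few versions from the session information above, and `check_versions` is a hypothetical helper introduced here, not part of scParadise or session_info.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions copied from the session information above (illustrative subset).
EXPECTED = {
    "anndata": "0.10.8",
    "matplotlib": "3.9.2",
    "numpy": "1.25.0",
    "pandas": "2.2.3",
    "scanpy": "1.10.3",
}


def check_versions(expected):
    """Return {name: (expected, installed)} for packages that differ.

    `installed` is None when the distribution is not installed.
    """
    mismatches = {}
    for name, wanted in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


# Report every package whose installed version differs from the notebook's.
for name, (wanted, installed) in check_versions(EXPECTED).items():
    print(f"{name}: expected {wanted}, found {installed or 'not installed'}")
```

Two caveats apply to this kind of check: the distribution name can differ from the import name listed above (for example, the `sklearn` module is distributed as `scikit-learn`), and a version string reported by a module may not be identical to the one stored in the package metadata.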