scAdam model optimization
Model optimization refers to the process of improving the performance and efficiency of machine learning models. scAdam models support two types of optimization:
Warm start: Warm starting is a machine learning technique that initializes a model with the weights of a model previously trained on the same or a similar task. Training then begins from a more advantageous point on the loss surface, so prior knowledge is used to improve efficiency and performance.
Hyperparameter tuning: This involves adjusting the hyperparameters of the model, such as learning rates, batch sizes, and other configuration settings, to optimize performance on a specific dataset.
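The idea behind warm starting can be shown with a toy example in plain Python. This is unrelated to scAdam's internals. The sketch fits a one-parameter linear model by gradient descent on a new task twice: once from zero (cold start) and once from the weight learned on a similar earlier task (warm start). The function `fit` and the toy data are hypothetical.

```python
# Toy illustration of warm starting (not scAdam code): fit y = w * x by
# gradient descent on mean squared error, starting either from w = 0
# ("cold start") or from a weight learned on a related task ("warm start").

def fit(xs, ys, w_init, lr=0.01, tol=1e-6, max_steps=10_000):
    """Return (w, steps) once the gradient is below tol or max_steps is reached."""
    w = w_init
    n = len(xs)
    for step in range(1, max_steps + 1):
        # Gradient of (1/n) * sum((w*x - y)^2) with respect to w
        grad = 2.0 / n * sum((w * x - y) * x for x, y in zip(xs, ys))
        w -= lr * grad
        if abs(grad) < tol:
            return w, step
    return w, max_steps


xs = [1.0, 2.0, 3.0, 4.0]
old_task = [2.0 * x for x in xs]  # previous task: slope 2.0
new_task = [2.1 * x for x in xs]  # similar new task: slope 2.1

w_old, _ = fit(xs, old_task, w_init=0.0)
w_cold, cold_steps = fit(xs, new_task, w_init=0.0)
w_warm, warm_steps = fit(xs, new_task, w_init=w_old)

print(f"cold start: w={w_cold:.4f} after {cold_steps} steps")
print(f"warm start: w={w_warm:.4f} after {warm_steps} steps")
```

The warm start needs fewer steps because it starts closer to the optimum. For a real model, how much it saves depends on how similar the old and new datasets are.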
[1]:
# Python packages
import warnings
warnings.simplefilter('ignore')
import scanpy as sc
import scparadise
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
sc.set_figure_params(dpi = 120)
Recommendations for the training dataset
We recommend the shifted logarithm normalization method:
sc.pp.normalize_total(adata, target_sum=None)
sc.pp.log1p(adata)
You can use another normalization method instead, but apply the same method to the test dataset.
The training dataset should contain, in adata_train.X, all genes that you want to use for model training. We recommend removing non-marker genes from adata_train.X. Removing such uninformative genes may improve performance and model quality metrics, and it also speeds up training.
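The shifted logarithm can be written out explicitly. The sketch below is a plain-Python approximation of what `sc.pp.normalize_total(adata, target_sum=None)` followed by `sc.pp.log1p(adata)` computes on a small dense matrix. It assumes scanpy's documented behavior that `target_sum=None` scales every cell to the median total count. It ignores options such as `exclude_highly_expressed`, and the function name `shifted_log_normalize` is hypothetical.

```python
import math
import statistics


def shifted_log_normalize(counts):
    """Scale each cell to the median total count, then apply log(1 + x).

    counts: list of per-cell lists of raw gene counts (cells x genes).
    Cells with zero total counts are left at zero.
    """
    totals = [sum(cell) for cell in counts]
    target = statistics.median([t for t in totals if t > 0])  # median of non-zero totals
    normalized = []
    for cell, total in zip(counts, totals):
        scale = target / total if total > 0 else 0.0
        normalized.append([math.log1p(c * scale) for c in cell])
    return normalized


cells = [[1, 3, 0], [4, 4, 2], [0, 5, 1]]  # totals 4, 10, 6 -> median 6
normalized = shifted_log_normalize(cells)
for row in normalized:
    print([round(v, 3) for v in row])
```

After the transform, `expm1` of each row sums to the same value (here 6). This is why the method is usually described as "size-factor normalization followed by log1p".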
Data preparation
Here we download and preprocess a dataset from cellxgene:
Single cell RNA sequencing of oropharyngeal squamous cell carcinoma https://cellxgene.cziscience.com/collections/3c34e6f1-6827-47dd-8e19-9edcd461893f
[2]:
!wget https://datasets.cellxgene.cziscience.com/915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad
--2025-01-15 22:51:33-- https://datasets.cellxgene.cziscience.com/915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad
Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 52.85.49.17, 52.85.49.24, 52.85.49.125, ...
Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|52.85.49.17|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 981941987 (936M) [binary/octet-stream]
Saving to: ‘915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad’
915069db-1df2-49a1- 100%[===================>] 936.45M 3.21MB/s in 4m 10s
2025-01-15 22:55:43 (3.75 MB/s) - ‘915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad’ saved [981941987/981941987]
[3]:
# Load the downloaded AnnData object
adata = sc.read_h5ad('915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad')
[4]:
# Get raw counts from adata.raw
adata = adata.raw.to_adata()
[5]:
# Convert var_names from ENSG codes to gene names
adata.var.set_index('feature_name', inplace=True)
adata.var_names_make_unique()
[6]:
# Select datasets for model training and testing
# It is necessary for all cell types to be present in both the training datasets and the testing datasets
df = scparadise.scnoah.cell_counter(adata, celltype='cell_type', sample='donor_id')
df
[6]:
| cell type | HN481 | HN482 | HN483 | HN485 | HN487 | HN488 | HN489 | HN490 | HN492 | HN494 |
|---|---|---|---|---|---|---|---|---|---|---|
| epithelial cell | 6747 | 758 | 934 | 575 | 2061 | 1335 | 2533 | 2955 | 2619 | 780 |
| plasma cell | 868 | 1051 | 498 | 609 | 106 | 596 | 1005 | 975 | 642 | 579 |
| CD8-positive, alpha-beta T cell | 613 | 1129 | 1204 | 740 | 279 | 941 | 1004 | 472 | 1408 | 911 |
| fibroblast | 512 | 164 | 256 | 163 | 371 | 118 | 856 | 204 | 234 | 560 |
| myeloid cell | 415 | 311 | 159 | 361 | 995 | 136 | 929 | 394 | 203 | 2998 |
| regulatory T cell | 370 | 316 | 623 | 351 | 135 | 69 | 881 | 385 | 321 | 616 |
| endothelial cell | 283 | 174 | 144 | 111 | 122 | 152 | 311 | 240 | 135 | 138 |
| CD4-positive, alpha-beta T cell | 282 | 1272 | 2629 | 911 | 87 | 224 | 1310 | 1076 | 789 | 1784 |
| mural cell | 267 | 170 | 98 | 156 | 195 | 173 | 286 | 233 | 87 | 128 |
| B cell | 213 | 1837 | 3369 | 764 | 30 | 2053 | 1338 | 2297 | 942 | 1832 |
| mast cell | 92 | 56 | 91 | 56 | 20 | 39 | 33 | 29 | 74 | 27 |
| natural killer cell | 88 | 135 | 205 | 65 | 40 | 120 | 465 | 94 | 58 | 128 |
| mature alpha-beta T cell | 44 | 31 | 13 | 7 | 8 | 7 | 82 | 30 | 38 | 42 |
| endothelial cell of lymphatic vessel | 29 | 23 | 25 | 10 | 8 | 9 | 59 | 6 | 28 | 8 |
| plasmacytoid dendritic cell | 19 | 31 | 43 | 25 | 13 | 22 | 87 | 24 | 20 | 163 |
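The requirement that every cell type appears in every split can be checked directly from counts like the ones in this table. The helper `underrepresented_cell_types` below is hypothetical and not part of scparadise. It is a plain-Python sketch with three cell-type rows copied from the table above.

```python
# Hypothetical helper (not part of scparadise): report cell types that have
# fewer than `min_cells` cells in any donor split.
def underrepresented_cell_types(counts, splits, min_cells=1):
    """counts: {cell_type: {donor_id: n_cells}}; splits: {split_name: [donor_ids]}.

    Returns {split_name: [cell types with fewer than min_cells cells]}.
    """
    problems = {}
    for split, donors in splits.items():
        low = [cell_type for cell_type, per_donor in counts.items()
               if sum(per_donor.get(d, 0) for d in donors) < min_cells]
        if low:
            problems[split] = low
    return problems


# Three small cell types, copied from the table above
counts = {
    "mast cell": {"HN481": 92, "HN482": 56, "HN483": 91, "HN485": 56,
                  "HN487": 20, "HN488": 39, "HN489": 33},
    "endothelial cell of lymphatic vessel": {"HN481": 29, "HN482": 23, "HN483": 25, "HN485": 10,
                                             "HN487": 8, "HN488": 9, "HN489": 59},
    "mature alpha-beta T cell": {"HN481": 44, "HN482": 31, "HN483": 13, "HN485": 7,
                                 "HN487": 8, "HN488": 7, "HN489": 82},
}
splits = {"test": ["HN481", "HN482"], "train1": ["HN483", "HN485", "HN488", "HN489"]}
print(underrepresented_cell_types(counts, splits))  # {}
print(underrepresented_cell_types(counts, {"HN487 only": ["HN487"]}, min_cells=10))
```

A minimum above one cell is often more useful in practice, because a cell type with only a handful of training cells is hard to learn.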
[7]:
# Create adata_train1, adata_train2 and adata_test datasets
adata_test = adata[adata.obs['donor_id'].isin(['HN481', 'HN482'])].copy()
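# Note: the outputs shown below correspond to a run in which adata_train1 contained only HN488 and HN489 (17,173 cells)
adata_train1 = adata[adata.obs['donor_id'].isin(['HN483', 'HN485', 'HN488', 'HN489'])].copy()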
adata_train2 = adata[adata.obs['donor_id'].isin(['HN490', 'HN487', 'HN492', 'HN494'])].copy()
#del adata
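`isin()` returns False for an ID that does not exist, and it does so without any error. A mistyped donor list therefore shrinks a training set silently. One common cause is a missing comma, because Python concatenates adjacent string literals. The guard below is a hypothetical sketch and not part of scparadise. In the notebook, `available` could be taken from `adata.obs['donor_id'].unique()`.

```python
# Hypothetical guard (not part of scparadise): fail early if a donor list
# contains IDs that do not occur in the data.
def check_donors(requested, available):
    unknown = sorted(set(requested) - set(available))
    if unknown:
        raise ValueError(f"Unknown donor IDs: {unknown}")
    return list(requested)


available = ["HN481", "HN482", "HN483", "HN485", "HN487",
             "HN488", "HN489", "HN490", "HN492", "HN494"]

# Python joins adjacent string literals, so a missing comma merges two IDs
typo = ['HN483' 'HN485', 'HN488', 'HN489']
print(typo)  # ['HN483HN485', 'HN488', 'HN489']

try:
    check_donors(typo, available)
except ValueError as err:
    print(err)  # Unknown donor IDs: ['HN483HN485']
```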
[8]:
# Normalize data, find highly variable features
for i in [adata_train1, adata_train2, adata_test]:
    i.layers['counts'] = i.X.copy()  # keep raw counts for seurat_v3 HVG selection
    sc.pp.normalize_total(i, target_sum=None)
    sc.pp.log1p(i)
    i.raw = i  # store all normalized genes before subsetting to HVGs
    sc.pp.highly_variable_genes(i,
                                layer='counts',
                                flavor='seurat_v3',
                                n_top_genes=1200,
                                subset=True)
Default model training
We use ‘balanced_accuracy’ as the evaluation metric here because the number of cells differs strongly among cell types.
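Balanced accuracy is the mean recall over classes, so a rare cell type weighs as much as an abundant one. scikit-learn's `balanced_accuracy_score` uses the same definition. The plain-Python sketch below uses toy labels to show why plain accuracy can be misleading on imbalanced data.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def balanced_accuracy(y_true, y_pred):
    """Mean of per-class recall, so every cell type counts equally."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return sum(hits[c] / n for c, n in totals.items()) / len(totals)


# 90 epithelial cells and 10 mast cells; a model that predicts
# "epithelial cell" for every cell
y_true = ["epithelial cell"] * 90 + ["mast cell"] * 10
y_pred = ["epithelial cell"] * 100
print(accuracy(y_true, y_pred))           # 0.9
print(balanced_accuracy(y_true, y_pred))  # 0.5
```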
[9]:
# Train scadam model using adata_train1 dataset
scparadise.scadam.train(adata_train1,
path='', # path to save model
model_name='model_scadam', # folder name with model
celltype_l1='cell_type',
eval_metric=['accuracy','balanced_accuracy']) # If you are using an imbalanced training dataset, we recommend using balanced_accuracy for early stopping
Successfully saved genes names for training model
Successfully saved dictionary of dataset annotations
Train dataset contains: 15455 cells, it is 90.0 % of input dataset
Test dataset contains: 1718 cells, it is 10.0 % of input dataset
Accelerator: cuda
Start training
epoch 0 | loss: 2.55518 | train_accuracy: 0.34455 | train_balanced_accuracy: 0.1102 | valid_accuracy: 0.34517 | valid_balanced_accuracy: 0.11042 | 0:00:01s
epoch 1 | loss: 2.00387 | train_accuracy: 0.42077 | train_balanced_accuracy: 0.13547 | valid_accuracy: 0.41909 | valid_balanced_accuracy: 0.13511 | 0:00:02s
epoch 2 | loss: 1.65271 | train_accuracy: 0.58932 | train_balanced_accuracy: 0.29457 | valid_accuracy: 0.57625 | valid_balanced_accuracy: 0.28318 | 0:00:03s
epoch 3 | loss: 1.28917 | train_accuracy: 0.60679 | train_balanced_accuracy: 0.31877 | valid_accuracy: 0.60419 | valid_balanced_accuracy: 0.31184 | 0:00:04s
epoch 4 | loss: 1.09993 | train_accuracy: 0.68515 | train_balanced_accuracy: 0.41748 | valid_accuracy: 0.68393 | valid_balanced_accuracy: 0.4168 | 0:00:04s
epoch 5 | loss: 0.96115 | train_accuracy: 0.7144 | train_balanced_accuracy: 0.4634 | valid_accuracy: 0.71478 | valid_balanced_accuracy: 0.46791 | 0:00:05s
epoch 6 | loss: 0.84477 | train_accuracy: 0.75865 | train_balanced_accuracy: 0.50364 | valid_accuracy: 0.76368 | valid_balanced_accuracy: 0.51123 | 0:00:06s
epoch 7 | loss: 0.74068 | train_accuracy: 0.82472 | train_balanced_accuracy: 0.55661 | valid_accuracy: 0.82363 | valid_balanced_accuracy: 0.55338 | 0:00:07s
epoch 8 | loss: 0.66158 | train_accuracy: 0.86005 | train_balanced_accuracy: 0.61093 | valid_accuracy: 0.86205 | valid_balanced_accuracy: 0.61738 | 0:00:08s
epoch 9 | loss: 0.6103 | train_accuracy: 0.86962 | train_balanced_accuracy: 0.63885 | valid_accuracy: 0.87078 | valid_balanced_accuracy: 0.62554 | 0:00:09s
epoch 10 | loss: 0.56654 | train_accuracy: 0.89097 | train_balanced_accuracy: 0.70924 | valid_accuracy: 0.88882 | valid_balanced_accuracy: 0.67687 | 0:00:10s
epoch 11 | loss: 0.50903 | train_accuracy: 0.90191 | train_balanced_accuracy: 0.72043 | valid_accuracy: 0.90047 | valid_balanced_accuracy: 0.70244 | 0:00:11s
epoch 12 | loss: 0.48274 | train_accuracy: 0.91239 | train_balanced_accuracy: 0.77949 | valid_accuracy: 0.91967 | valid_balanced_accuracy: 0.78529 | 0:00:12s
epoch 13 | loss: 0.44446 | train_accuracy: 0.91252 | train_balanced_accuracy: 0.7516 | valid_accuracy: 0.91502 | valid_balanced_accuracy: 0.73566 | 0:00:13s
epoch 14 | loss: 0.42764 | train_accuracy: 0.92333 | train_balanced_accuracy: 0.8083 | valid_accuracy: 0.92549 | valid_balanced_accuracy: 0.7785 | 0:00:14s
epoch 15 | loss: 0.40217 | train_accuracy: 0.9386 | train_balanced_accuracy: 0.88841 | valid_accuracy: 0.94121 | valid_balanced_accuracy: 0.88421 | 0:00:15s
epoch 16 | loss: 0.38251 | train_accuracy: 0.9397 | train_balanced_accuracy: 0.89722 | valid_accuracy: 0.94005 | valid_balanced_accuracy: 0.897 | 0:00:16s
epoch 17 | loss: 0.37016 | train_accuracy: 0.9474 | train_balanced_accuracy: 0.91188 | valid_accuracy: 0.94237 | valid_balanced_accuracy: 0.89974 | 0:00:17s
epoch 18 | loss: 0.36254 | train_accuracy: 0.95244 | train_balanced_accuracy: 0.93253 | valid_accuracy: 0.94703 | valid_balanced_accuracy: 0.92078 | 0:00:18s
epoch 19 | loss: 0.35562 | train_accuracy: 0.94998 | train_balanced_accuracy: 0.92859 | valid_accuracy: 0.94179 | valid_balanced_accuracy: 0.90541 | 0:00:19s
epoch 20 | loss: 0.33606 | train_accuracy: 0.95917 | train_balanced_accuracy: 0.94093 | valid_accuracy: 0.94645 | valid_balanced_accuracy: 0.92007 | 0:00:20s
epoch 21 | loss: 0.32085 | train_accuracy: 0.9615 | train_balanced_accuracy: 0.95026 | valid_accuracy: 0.95343 | valid_balanced_accuracy: 0.94456 | 0:00:21s
epoch 22 | loss: 0.32289 | train_accuracy: 0.96357 | train_balanced_accuracy: 0.95312 | valid_accuracy: 0.9546 | valid_balanced_accuracy: 0.93717 | 0:00:22s
epoch 23 | loss: 0.31242 | train_accuracy: 0.96189 | train_balanced_accuracy: 0.95122 | valid_accuracy: 0.9546 | valid_balanced_accuracy: 0.93809 | 0:00:23s
epoch 24 | loss: 0.30792 | train_accuracy: 0.96364 | train_balanced_accuracy: 0.95162 | valid_accuracy: 0.95518 | valid_balanced_accuracy: 0.94025 | 0:00:24s
epoch 25 | loss: 0.3059 | train_accuracy: 0.96687 | train_balanced_accuracy: 0.9607 | valid_accuracy: 0.95867 | valid_balanced_accuracy: 0.943 | 0:00:25s
epoch 26 | loss: 0.29367 | train_accuracy: 0.96901 | train_balanced_accuracy: 0.96614 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.93963 | 0:00:26s
epoch 27 | loss: 0.29061 | train_accuracy: 0.96855 | train_balanced_accuracy: 0.96414 | valid_accuracy: 0.95984 | valid_balanced_accuracy: 0.94939 | 0:00:27s
epoch 28 | loss: 0.29253 | train_accuracy: 0.9692 | train_balanced_accuracy: 0.96904 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.95159 | 0:00:28s
epoch 29 | loss: 0.29325 | train_accuracy: 0.97257 | train_balanced_accuracy: 0.97368 | valid_accuracy: 0.95867 | valid_balanced_accuracy: 0.94992 | 0:00:29s
epoch 30 | loss: 0.28561 | train_accuracy: 0.97114 | train_balanced_accuracy: 0.96873 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.95207 | 0:00:30s
epoch 31 | loss: 0.28216 | train_accuracy: 0.97321 | train_balanced_accuracy: 0.97322 | valid_accuracy: 0.96217 | valid_balanced_accuracy: 0.96296 | 0:00:31s
epoch 32 | loss: 0.27522 | train_accuracy: 0.97185 | train_balanced_accuracy: 0.97129 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.95845 | 0:00:32s
epoch 33 | loss: 0.27716 | train_accuracy: 0.97444 | train_balanced_accuracy: 0.97277 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.95505 | 0:00:33s
epoch 34 | loss: 0.26944 | train_accuracy: 0.97587 | train_balanced_accuracy: 0.9787 | valid_accuracy: 0.96158 | valid_balanced_accuracy: 0.96142 | 0:00:34s
epoch 35 | loss: 0.26409 | train_accuracy: 0.97289 | train_balanced_accuracy: 0.97552 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.9599 | 0:00:35s
epoch 36 | loss: 0.26652 | train_accuracy: 0.97554 | train_balanced_accuracy: 0.97432 | valid_accuracy: 0.95751 | valid_balanced_accuracy: 0.95575 | 0:00:36s
epoch 37 | loss: 0.26731 | train_accuracy: 0.97574 | train_balanced_accuracy: 0.97458 | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.95274 | 0:00:37s
epoch 38 | loss: 0.25573 | train_accuracy: 0.97703 | train_balanced_accuracy: 0.97853 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.96639 | 0:00:38s
epoch 39 | loss: 0.26136 | train_accuracy: 0.97975 | train_balanced_accuracy: 0.97911 | valid_accuracy: 0.95984 | valid_balanced_accuracy: 0.95097 | 0:00:39s
epoch 40 | loss: 0.25477 | train_accuracy: 0.97981 | train_balanced_accuracy: 0.9822 | valid_accuracy: 0.961 | valid_balanced_accuracy: 0.96062 | 0:00:40s
epoch 41 | loss: 0.2501 | train_accuracy: 0.97832 | train_balanced_accuracy: 0.97982 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.95776 | 0:00:41s
epoch 42 | loss: 0.24979 | train_accuracy: 0.98188 | train_balanced_accuracy: 0.982 | valid_accuracy: 0.961 | valid_balanced_accuracy: 0.95845 | 0:00:42s
epoch 43 | loss: 0.2472 | train_accuracy: 0.98175 | train_balanced_accuracy: 0.98184 | valid_accuracy: 0.96333 | valid_balanced_accuracy: 0.96133 | 0:00:43s
epoch 44 | loss: 0.23966 | train_accuracy: 0.98156 | train_balanced_accuracy: 0.98064 | valid_accuracy: 0.96158 | valid_balanced_accuracy: 0.95964 | 0:00:43s
epoch 45 | loss: 0.25145 | train_accuracy: 0.98078 | train_balanced_accuracy: 0.98265 | valid_accuracy: 0.95634 | valid_balanced_accuracy: 0.96661 | 0:00:44s
epoch 46 | loss: 0.24429 | train_accuracy: 0.97988 | train_balanced_accuracy: 0.97911 | valid_accuracy: 0.95343 | valid_balanced_accuracy: 0.95564 | 0:00:45s
epoch 47 | loss: 0.25503 | train_accuracy: 0.98234 | train_balanced_accuracy: 0.98221 | valid_accuracy: 0.95634 | valid_balanced_accuracy: 0.96131 | 0:00:46s
epoch 48 | loss: 0.24714 | train_accuracy: 0.98382 | train_balanced_accuracy: 0.98409 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.95744 | 0:00:47s
epoch 49 | loss: 0.23765 | train_accuracy: 0.97988 | train_balanced_accuracy: 0.97965 | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.95392 | 0:00:48s
epoch 50 | loss: 0.23015 | train_accuracy: 0.98305 | train_balanced_accuracy: 0.98415 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.96781 | 0:00:49s
epoch 51 | loss: 0.23351 | train_accuracy: 0.97755 | train_balanced_accuracy: 0.97853 | valid_accuracy: 0.9546 | valid_balanced_accuracy: 0.95501 | 0:00:50s
epoch 52 | loss: 0.22766 | train_accuracy: 0.98531 | train_balanced_accuracy: 0.98407 | valid_accuracy: 0.95576 | valid_balanced_accuracy: 0.95683 | 0:00:51s
epoch 53 | loss: 0.23931 | train_accuracy: 0.98499 | train_balanced_accuracy: 0.98428 | valid_accuracy: 0.95402 | valid_balanced_accuracy: 0.94877 | 0:00:52s
epoch 54 | loss: 0.23413 | train_accuracy: 0.98479 | train_balanced_accuracy: 0.98549 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.96727 | 0:00:53s
epoch 55 | loss: 0.24018 | train_accuracy: 0.98505 | train_balanced_accuracy: 0.98458 | valid_accuracy: 0.95751 | valid_balanced_accuracy: 0.96351 | 0:00:54s
epoch 56 | loss: 0.22409 | train_accuracy: 0.97988 | train_balanced_accuracy: 0.9813 | valid_accuracy: 0.95343 | valid_balanced_accuracy: 0.96193 | 0:00:55s
epoch 57 | loss: 0.2251 | train_accuracy: 0.98434 | train_balanced_accuracy: 0.98721 | valid_accuracy: 0.95402 | valid_balanced_accuracy: 0.96176 | 0:00:56s
epoch 58 | loss: 0.22227 | train_accuracy: 0.9868 | train_balanced_accuracy: 0.98474 | valid_accuracy: 0.95984 | valid_balanced_accuracy: 0.96604 | 0:00:57s
epoch 59 | loss: 0.22204 | train_accuracy: 0.98098 | train_balanced_accuracy: 0.98357 | valid_accuracy: 0.95402 | valid_balanced_accuracy: 0.96383 | 0:00:58s
epoch 60 | loss: 0.22922 | train_accuracy: 0.98369 | train_balanced_accuracy: 0.98709 | valid_accuracy: 0.95052 | valid_balanced_accuracy: 0.95798 | 0:00:59s
Early stopping occurred at epoch 60 with best_epoch = 50 and best_valid_balanced_accuracy = 0.96781
Successfully saved training history and parameters
Successfully saved model at model_scadam/model.zip
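The log above stops at epoch 60 with best epoch 50, and the warm-start run later stops at epoch 14 with best epoch 4. This is consistent with patience-based early stopping and a patience of 10. That value is inferred from the logs rather than taken from the scAdam documentation, and `patience` also appears among the tuned hyperparameters further down. A minimal sketch of the rule, with a hypothetical function and toy scores:

```python
def early_stopping(scores, patience):
    """Return (stop_epoch, best_epoch, best_score) for a higher-is-better metric.

    Training stops once `patience` consecutive epochs fail to improve on the
    best score so far; if that never happens, stop_epoch is the last epoch.
    """
    best_epoch, best_score, wait = 0, float("-inf"), 0
    for epoch, score in enumerate(scores):
        if score > best_score:
            best_epoch, best_score, wait = epoch, score, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch, best_score
    return len(scores) - 1, best_epoch, best_score


valid_balanced_accuracy = [0.61, 0.72, 0.70, 0.74, 0.73, 0.71, 0.73]
print(early_stopping(valid_balanced_accuracy, patience=3))  # (6, 3, 0.74)
```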
[10]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
path_model = 'model_scadam')
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Successfully loaded model
Successfully added predicted celltype_l1 and cell type probabilities
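Evaluation of prediction quality
The model over-assigns the mature alpha-beta T cell label (precision = 0.0665). It recovers about 63 of the 75 true cells (recall 0.84), but roughly 63 / 0.0665 ≈ 950 cells receive this label, so most of these predictions are wrong.
The model also misses half of the epithelial cells (recall/sensitivity = 0.4994).
Regulatory T cells have both low sensitivity (0.7143) and low precision (0.4762). Precision is also low for CD4-positive (0.5112) and CD8-positive (0.6327) alpha-beta T cells, natural killer cells (0.5929) and endothelial cells (0.7873).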
[11]:
## Check model quality
df = scparadise.scnoah.report_classif_full(adata_test,
celltype='cell_type',
pred_celltype='pred_celltype_l1')
df
[11]:
| cell type | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B cell | 0.8170 | 0.98 | 0.9723 | 0.8911 | 0.9761 | 0.9536 | 2050 |
| CD4-positive, alpha-beta T cell | 0.5112 | 0.8481 | 0.9248 | 0.6379 | 0.8856 | 0.7783 | 1554 |
| CD8-positive, alpha-beta T cell | 0.6327 | 0.8829 | 0.9461 | 0.7371 | 0.9139 | 0.83 | 1742 |
| endothelial cell | 0.7873 | 0.9803 | 0.9932 | 0.8733 | 0.9867 | 0.9724 | 457 |
| endothelial cell of lymphatic vessel | 0.8475 | 0.9615 | 0.9995 | 0.9009 | 0.9803 | 0.9574 | 52 |
| epithelial cell | 0.9907 | 0.4994 | 0.9968 | 0.6641 | 0.7055 | 0.473 | 7505 |
| fibroblast | 0.9379 | 0.8935 | 0.9977 | 0.9152 | 0.9442 | 0.8822 | 676 |
| mast cell | 0.9797 | 0.9797 | 0.9998 | 0.9797 | 0.9897 | 0.9776 | 148 |
| mature alpha-beta T cell | 0.0665 | 0.84 | 0.9514 | 0.1232 | 0.894 | 0.7903 | 75 |
| mural cell | 0.8173 | 0.9519 | 0.9948 | 0.8795 | 0.9731 | 0.9429 | 437 |
| myeloid cell | 0.8784 | 0.9449 | 0.9946 | 0.9104 | 0.9694 | 0.9351 | 726 |
| natural killer cell | 0.5929 | 0.9731 | 0.9918 | 0.7368 | 0.9824 | 0.9633 | 223 |
| plasma cell | 0.9810 | 0.9937 | 0.9977 | 0.9873 | 0.9957 | 0.9911 | 1919 |
| plasmacytoid dendritic cell | 0.9231 | 0.96 | 0.9998 | 0.9412 | 0.9797 | 0.956 | 50 |
| regulatory T cell | 0.4762 | 0.7143 | 0.9694 | 0.5714 | 0.8321 | 0.6748 | 686 |
| macro avg | 0.7493 | 0.8936 | 0.982 | 0.7833 | 0.9339 | 0.8719 | |
| weighted avg | 0.8512 | 0.7479 | 0.9818 | 0.7567 | 0.8468 | 0.7198 | |
| Accuracy | 0.7479 | | | | | | |
| Balanced accuracy | 0.8936 | | | | | | |
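Each row of this report is a one-vs-rest summary for a single cell type. The geometric mean and index balanced accuracy (IBA) columns are consistent with the definitions used by imbalanced-learn's `classification_report_imbalanced`. Those definitions are geometric mean = sqrt(sensitivity × specificity), and IBA with alpha = 0.1 applied to the squared geometric mean. We checked this against the B cell row only, not against scparadise's source. The sketch below is plain Python, and `class_report` is a hypothetical name.

```python
import math


def class_report(y_true, y_pred, label, alpha=0.1):
    """One-vs-rest metrics for a single cell type `label`."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == label and p == label for t, p in pairs)
    fp = sum(t != label and p == label for t, p in pairs)
    fn = sum(t == label and p != label for t, p in pairs)
    tn = len(pairs) - tp - fp - fn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0  # sensitivity
    specificity = tn / (tn + fp) if tn + fp else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    geo = math.sqrt(recall * specificity)  # geometric mean
    # Index balanced accuracy, assuming alpha = 0.1 and a squared geometric mean
    iba = (1 + alpha * (recall - specificity)) * geo ** 2
    return {"precision": precision, "recall": recall, "specificity": specificity,
            "f1": f1, "geometric mean": geo, "iba": iba}


y_true = ["T cell"] * 4 + ["B cell"] * 6
y_pred = ["T cell", "T cell", "T cell", "B cell"] + ["T cell", "T cell"] + ["B cell"] * 4
report = class_report(y_true, y_pred, "T cell")
print({name: round(value, 4) for name, value in report.items()})

# Cross-check with the "B cell" row above: sensitivity 0.98, specificity 0.9723
geo_b = math.sqrt(0.98 * 0.9723)
print(round(geo_b, 4), round((1 + 0.1 * (0.98 - 0.9723)) * geo_b ** 2, 4))  # 0.9761 0.9536
```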
[12]:
# Use the same ordered categories for true and predicted labels so that cell type colors match
celltype = np.unique(adata_test.obs['cell_type']).tolist()
adata_test.obs['cell_type'] = pd.Categorical(
values=adata_test.obs['cell_type'], categories=celltype, ordered=True
)
adata_test.obs['pred_celltype_l1'] = pd.Categorical(
values=adata_test.obs['pred_celltype_l1'], categories=celltype, ordered=True
)
adata_test = scparadise.scnoah.pred_status(adata_test, celltype='cell_type', pred_celltype='pred_celltype_l1')
[13]:
# Visualise true and predicted cell types, prediction probabilities and prediction status
sc.pl.embedding(adata_test,
color=[
'cell_type',
'pred_celltype_l1',
'prob_celltype_l1',
'pred_status'
],
basis = 'X_umap',
frameon = False,
add_outline = True,
legend_loc = 'right margin',
legend_fontsize = 7,
legend_fontoutline = 1,
ncols=2,
wspace = 0.7,
hspace = 0.1)
Warm start model training
Warning! Warm-start training of an scAdam model requires a dataset that contains exactly the same cell types as the original training dataset. The cell types must be named exactly as they were during the initial training of the model.
The new dataset must not contain any additional cell types or annotation levels.
None of the cell types used for the initial training may be missing.
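The requirements above can be checked before calling `scparadise.scadam.warm_start`. Below is a hypothetical plain-Python helper that compares two label sets. In the notebook, the inputs could be `adata_train1.obs['cell_type'].unique()` and `adata_train2.obs['cell_type'].unique()`. It covers a single annotation level only.

```python
def warm_start_label_problems(original_labels, new_labels):
    """Compare cell type names from the first training run with a new dataset.

    Returns (missing, extra): labels absent from the new dataset, and labels the
    original model never saw. Both must be empty for a warm start.
    """
    original, new = set(original_labels), set(new_labels)
    return sorted(original - new), sorted(new - original)


original = ["B cell", "mast cell", "regulatory T cell"]
new = ["B cell", "Mast cell", "regulatory T cell", "NK cell"]
missing, extra = warm_start_label_problems(original, new)
print(missing)  # ['mast cell']
print(extra)    # ['Mast cell', 'NK cell']
```

The case mismatch ("mast cell" vs "Mast cell") counts as both a missing and an extra label. This reflects the requirement that names match exactly.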
[14]:
# Warm start requires second training dataset and path to pretrained model
scparadise.scadam.warm_start(adata_train2,
path='', # path to save model
model_name='model_scadam', # folder name with pretrained model
celltype_l1='cell_type',
eval_metric=['accuracy','balanced_accuracy']) # If you are using an imbalanced training dataset, we recommend using balanced_accuracy for early stopping
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Train dataset contains: 28958 cells, it is 90.0 % of input dataset
Test dataset contains: 3218 cells, it is 10.0 % of input dataset
Successfully loaded parameters
Accelerator: cuda
Start training
epoch 0 | loss: 0.39464 | train_accuracy: 0.95114 | train_balanced_accuracy: 0.9084 | valid_accuracy: 0.95059 | valid_balanced_accuracy: 0.91401 | 0:00:01s
epoch 1 | loss: 0.31958 | train_accuracy: 0.95701 | train_balanced_accuracy: 0.94684 | valid_accuracy: 0.95401 | valid_balanced_accuracy: 0.94443 | 0:00:03s
epoch 2 | loss: 0.30485 | train_accuracy: 0.95849 | train_balanced_accuracy: 0.95216 | valid_accuracy: 0.94624 | valid_balanced_accuracy: 0.9447 | 0:00:05s
epoch 3 | loss: 0.29525 | train_accuracy: 0.96074 | train_balanced_accuracy: 0.94526 | valid_accuracy: 0.94779 | valid_balanced_accuracy: 0.93132 | 0:00:07s
epoch 4 | loss: 0.27713 | train_accuracy: 0.96564 | train_balanced_accuracy: 0.95579 | valid_accuracy: 0.95494 | valid_balanced_accuracy: 0.9461 | 0:00:09s
epoch 5 | loss: 0.27371 | train_accuracy: 0.96706 | train_balanced_accuracy: 0.958 | valid_accuracy: 0.95339 | valid_balanced_accuracy: 0.94601 | 0:00:11s
epoch 6 | loss: 0.27291 | train_accuracy: 0.96785 | train_balanced_accuracy: 0.95893 | valid_accuracy: 0.95183 | valid_balanced_accuracy: 0.93815 | 0:00:13s
epoch 7 | loss: 0.26626 | train_accuracy: 0.96913 | train_balanced_accuracy: 0.96329 | valid_accuracy: 0.95339 | valid_balanced_accuracy: 0.94487 | 0:00:15s
epoch 8 | loss: 0.25926 | train_accuracy: 0.96681 | train_balanced_accuracy: 0.95935 | valid_accuracy: 0.94935 | valid_balanced_accuracy: 0.93925 | 0:00:16s
epoch 9 | loss: 0.25558 | train_accuracy: 0.97134 | train_balanced_accuracy: 0.9652 | valid_accuracy: 0.95339 | valid_balanced_accuracy: 0.94177 | 0:00:18s
epoch 10 | loss: 0.25424 | train_accuracy: 0.97237 | train_balanced_accuracy: 0.96595 | valid_accuracy: 0.9509 | valid_balanced_accuracy: 0.93876 | 0:00:20s
epoch 11 | loss: 0.25194 | train_accuracy: 0.96809 | train_balanced_accuracy: 0.97495 | valid_accuracy: 0.94096 | valid_balanced_accuracy: 0.93964 | 0:00:22s
epoch 12 | loss: 0.24886 | train_accuracy: 0.97241 | train_balanced_accuracy: 0.96691 | valid_accuracy: 0.9537 | valid_balanced_accuracy: 0.94268 | 0:00:24s
epoch 13 | loss: 0.24931 | train_accuracy: 0.97344 | train_balanced_accuracy: 0.96327 | valid_accuracy: 0.95059 | valid_balanced_accuracy: 0.92561 | 0:00:25s
epoch 14 | loss: 0.24412 | train_accuracy: 0.97248 | train_balanced_accuracy: 0.97067 | valid_accuracy: 0.95028 | valid_balanced_accuracy: 0.9415 | 0:00:27s
Early stopping occurred at epoch 14 with best_epoch = 4 and best_valid_balanced_accuracy = 0.9461
Successfully saved training history and parameters
Successfully saved model at model_scadam/model.zip
[15]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
path_model = 'model_scadam')
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Successfully loaded model
Successfully added predicted celltype_l1 and cell type probabilities
Evaluation of prediction results
In this example, warm start increases accuracy (0.7479 to 0.9538), balanced accuracy (0.8936 to 0.9341) and macro-averaged precision (0.7493 to 0.9167).
Examples of improved performance for individual cell types (warm start vs untuned):
mature alpha-beta T cell: precision 0.6036 vs 0.0665
regulatory T cell: precision 0.7766 vs 0.4762
natural killer cell: precision 0.8991 vs 0.5929
fibroblast: sensitivity 0.9896 vs 0.8935
epithelial cell: sensitivity 0.9824 vs 0.4994
CD4-positive, alpha-beta T cell: precision 0.8644 vs 0.5112
CD8-positive, alpha-beta T cell: precision 0.8985 vs 0.6327
The full comparison is shown below.
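Per-cell-type changes can also be ranked programmatically. The plain-Python sketch below uses the precision values listed above as (untuned, warm start) pairs and sorts cell types by the size of the improvement. It is not part of scparadise.

```python
# Precision (untuned, warm start) for the cell types listed above
precision = {
    "mature alpha-beta T cell": (0.0665, 0.6036),
    "regulatory T cell": (0.4762, 0.7766),
    "natural killer cell": (0.5929, 0.8991),
    "CD4-positive, alpha-beta T cell": (0.5112, 0.8644),
    "CD8-positive, alpha-beta T cell": (0.6327, 0.8985),
}

# Sort by absolute improvement, largest first
ranked = sorted(((name, after - before) for name, (before, after) in precision.items()),
                key=lambda item: item[1], reverse=True)
for name, delta in ranked:
    print(f"{name}: {delta:+.4f}")
```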
[16]:
## Check model quality
df_warm_start = scparadise.scnoah.report_classif_full(adata_test,
celltype='cell_type',
pred_celltype='pred_celltype_l1')
df_warm_start
[16]:
| cell type | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B cell | 0.9738 | 0.978 | 0.9967 | 0.9759 | 0.9873 | 0.973 | 2050 |
| CD4-positive, alpha-beta T cell | 0.8644 | 0.8411 | 0.9878 | 0.8526 | 0.9115 | 0.8186 | 1554 |
| CD8-positive, alpha-beta T cell | 0.8985 | 0.9248 | 0.989 | 0.9115 | 0.9564 | 0.9088 | 1742 |
| endothelial cell | 0.9868 | 0.9803 | 0.9997 | 0.9835 | 0.9899 | 0.9781 | 457 |
| endothelial cell of lymphatic vessel | 0.9592 | 0.9038 | 0.9999 | 0.9307 | 0.9507 | 0.8951 | 52 |
| epithelial cell | 0.9879 | 0.9824 | 0.9917 | 0.9852 | 0.987 | 0.9733 | 7505 |
| fibroblast | 0.9585 | 0.9896 | 0.9984 | 0.9738 | 0.994 | 0.9872 | 676 |
| mast cell | 0.9931 | 0.973 | 0.9999 | 0.9829 | 0.9864 | 0.9703 | 148 |
| mature alpha-beta T cell | 0.6036 | 0.8933 | 0.9976 | 0.7204 | 0.944 | 0.8819 | 75 |
| mural cell | 0.9761 | 0.9336 | 0.9994 | 0.9544 | 0.966 | 0.927 | 437 |
| myeloid cell | 0.9832 | 0.9656 | 0.9993 | 0.9743 | 0.9823 | 0.9616 | 726 |
| natural killer cell | 0.8991 | 0.9193 | 0.9987 | 0.9091 | 0.9582 | 0.9108 | 223 |
| plasma cell | 0.9830 | 0.9917 | 0.998 | 0.9873 | 0.9948 | 0.989 | 1919 |
| plasmacytoid dendritic cell | 0.9074 | 0.98 | 0.9997 | 0.9423 | 0.9898 | 0.9778 | 50 |
| regulatory T cell | 0.7766 | 0.7551 | 0.9915 | 0.7657 | 0.8653 | 0.731 | 686 |
| macro avg | 0.9167 | 0.9341 | 0.9965 | 0.9233 | 0.9642 | 0.9256 | |
| weighted avg | 0.9544 | 0.9538 | 0.9935 | 0.9539 | 0.973 | 0.9443 | |
| Accuracy | 0.9538 | | | | | | |
| Balanced accuracy | 0.9341 | | | | | | |
[17]:
# Use the same ordered categories for true and predicted labels so that cell type colors match
celltype = np.unique(adata_test.obs['cell_type']).tolist()
adata_test.obs['cell_type'] = pd.Categorical(
values=adata_test.obs['cell_type'], categories=celltype, ordered=True
)
adata_test.obs['pred_celltype_l1'] = pd.Categorical(
values=adata_test.obs['pred_celltype_l1'], categories=celltype, ordered=True
)
# Add prediction status. Label cells as correct or incorrect based on the comparison between ground truth cell types and predictions.
adata_test = scparadise.scnoah.pred_status(adata_test, celltype='cell_type', pred_celltype='pred_celltype_l1')
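The prediction status is described above as a per-cell comparison of ground-truth and predicted labels. The idea is sketched below in plain Python. The label strings 'correct' and 'incorrect' are assumptions and may differ from the values scparadise writes to `adata.obs['pred_status']`.

```python
def prediction_status(true_labels, pred_labels):
    """Label each cell 'correct' or 'incorrect' by comparing truth and prediction."""
    return ["correct" if t == p else "incorrect"
            for t, p in zip(true_labels, pred_labels)]


truth = ["B cell", "mast cell", "fibroblast", "B cell"]
pred = ["B cell", "mast cell", "mural cell", "plasma cell"]
status = prediction_status(truth, pred)
print(status)                                  # ['correct', 'correct', 'incorrect', 'incorrect']
print(status.count("correct") / len(status))  # 0.5
```

The fraction of 'correct' cells is the overall accuracy reported in the tables above.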
[18]:
# Visualise true and predicted cell types, prediction probabilities and prediction status
sc.pl.embedding(adata_test,
color=[
'cell_type',
'pred_celltype_l1',
'prob_celltype_l1',
'pred_status'
],
basis = 'X_umap',
frameon = False,
add_outline = True,
legend_loc = 'right margin',
legend_fontsize = 7,
legend_fontoutline = 1,
ncols=2,
wspace = 0.7,
hspace = 0.1)
[19]:
# Compare prediction results
# 'untuned' row represents untuned model
# 'warm start' row represents warm start trained model
df.compare(df_warm_start, keep_equal=True, align_axis = 0, result_names=('untuned', 'warm start'))
[19]:
| cell type | model | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy |
|---|---|---|---|---|---|---|---|
| B cell | untuned | 0.8170 | 0.98 | 0.9723 | 0.8911 | 0.9761 | 0.9536 |
| | warm start | 0.9738 | 0.978 | 0.9967 | 0.9759 | 0.9873 | 0.973 |
| CD4-positive, alpha-beta T cell | untuned | 0.5112 | 0.8481 | 0.9248 | 0.6379 | 0.8856 | 0.7783 |
| | warm start | 0.8644 | 0.8411 | 0.9878 | 0.8526 | 0.9115 | 0.8186 |
| CD8-positive, alpha-beta T cell | untuned | 0.6327 | 0.8829 | 0.9461 | 0.7371 | 0.9139 | 0.83 |
| | warm start | 0.8985 | 0.9248 | 0.989 | 0.9115 | 0.9564 | 0.9088 |
| endothelial cell | untuned | 0.7873 | 0.9803 | 0.9932 | 0.8733 | 0.9867 | 0.9724 |
| | warm start | 0.9868 | 0.9803 | 0.9997 | 0.9835 | 0.9899 | 0.9781 |
| endothelial cell of lymphatic vessel | untuned | 0.8475 | 0.9615 | 0.9995 | 0.9009 | 0.9803 | 0.9574 |
| | warm start | 0.9592 | 0.9038 | 0.9999 | 0.9307 | 0.9507 | 0.8951 |
| epithelial cell | untuned | 0.9907 | 0.4994 | 0.9968 | 0.6641 | 0.7055 | 0.473 |
| | warm start | 0.9879 | 0.9824 | 0.9917 | 0.9852 | 0.987 | 0.9733 |
| fibroblast | untuned | 0.9379 | 0.8935 | 0.9977 | 0.9152 | 0.9442 | 0.8822 |
| | warm start | 0.9585 | 0.9896 | 0.9984 | 0.9738 | 0.994 | 0.9872 |
| mast cell | untuned | 0.9797 | 0.9797 | 0.9998 | 0.9797 | 0.9897 | 0.9776 |
| | warm start | 0.9931 | 0.973 | 0.9999 | 0.9829 | 0.9864 | 0.9703 |
| mature alpha-beta T cell | untuned | 0.0665 | 0.84 | 0.9514 | 0.1232 | 0.894 | 0.7903 |
| | warm start | 0.6036 | 0.8933 | 0.9976 | 0.7204 | 0.944 | 0.8819 |
| mural cell | untuned | 0.8173 | 0.9519 | 0.9948 | 0.8795 | 0.9731 | 0.9429 |
| | warm start | 0.9761 | 0.9336 | 0.9994 | 0.9544 | 0.966 | 0.927 |
| myeloid cell | untuned | 0.8784 | 0.9449 | 0.9946 | 0.9104 | 0.9694 | 0.9351 |
| | warm start | 0.9832 | 0.9656 | 0.9993 | 0.9743 | 0.9823 | 0.9616 |
| natural killer cell | untuned | 0.5929 | 0.9731 | 0.9918 | 0.7368 | 0.9824 | 0.9633 |
| | warm start | 0.8991 | 0.9193 | 0.9987 | 0.9091 | 0.9582 | 0.9108 |
| plasma cell | untuned | 0.9810 | 0.9937 | 0.9977 | 0.9873 | 0.9957 | 0.9911 |
| | warm start | 0.9830 | 0.9917 | 0.998 | 0.9873 | 0.9948 | 0.989 |
| plasmacytoid dendritic cell | untuned | 0.9231 | 0.96 | 0.9998 | 0.9412 | 0.9797 | 0.956 |
| | warm start | 0.9074 | 0.98 | 0.9997 | 0.9423 | 0.9898 | 0.9778 |
| regulatory T cell | untuned | 0.4762 | 0.7143 | 0.9694 | 0.5714 | 0.8321 | 0.6748 |
| | warm start | 0.7766 | 0.7551 | 0.9915 | 0.7657 | 0.8653 | 0.731 |
| macro avg | untuned | 0.7493 | 0.8936 | 0.982 | 0.7833 | 0.9339 | 0.8719 |
| | warm start | 0.9167 | 0.9341 | 0.9965 | 0.9233 | 0.9642 | 0.9256 |
| weighted avg | untuned | 0.8512 | 0.7479 | 0.9818 | 0.7567 | 0.8468 | 0.7198 |
| | warm start | 0.9544 | 0.9538 | 0.9935 | 0.9539 | 0.973 | 0.9443 |
| Accuracy | untuned | 0.7479 | | | | | |
| | warm start | 0.9538 | | | | | |
| Balanced accuracy | untuned | 0.8936 | | | | | |
| | warm start | 0.9341 | | | | | |
Hyperparameter tuning
Warning! Hyperparameter tuning is very time-consuming and may be interrupted for various reasons, for example a system shutdown or a CUDA error. This is not a problem: restarting hyperparameter tuning continues from the last unfinished trial. The example below shows such a restart. The fourth line of its output, “Using an existing study with name ‘model_scadam_hp_tune’ instead of creating a new one.”, indicates that tuning was resumed (in this case, from trial 8).
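The resume behavior can be illustrated with a plain-Python random search that records each finished trial in a file. The log messages ("Using an existing study…", "Trial 8 pruned.") look like those of the Optuna library, which scparadise appears to use. This sketch does not use Optuna, omits pruning, and every name in it is hypothetical. The log shows that the value of trial 10 matches the mean of its four fold scores to the printed precision. Following that, the toy objective returns the mean of per-fold scores.

```python
import json
import os
import random
import statistics
import tempfile


def resumable_random_search(objective, space, n_trials, path, seed=0):
    """Random search that records every finished trial in a JSON file.

    Calling it again with the same `path` skips the trials already stored, so
    an interrupted search resumes at the first unfinished trial.
    """
    trials = []
    if os.path.exists(path):
        with open(path) as fh:
            trials = json.load(fh)
    for number in range(len(trials), n_trials):
        # Seed each trial by its number so a resumed search samples the same
        # parameters an uninterrupted one would have sampled.
        rng = random.Random(seed + number)
        params = {name: rng.choice(options) for name, options in space.items()}
        trials.append({"number": number, "params": params, "value": objective(params)})
        with open(path, "w") as fh:  # persist after every trial
            json.dump(trials, fh)
    return max(trials, key=lambda trial: trial["value"])


def toy_objective(params):
    # Stand-in for 4-fold cross-validation: the trial value is the mean of the
    # per-fold scores. A real objective would train and score a model per fold.
    fold_scores = [0.95 - abs(params["lr"] - 0.02) - 0.001 * params["n_steps"] - 0.002 * fold
                   for fold in range(4)]
    return statistics.mean(fold_scores)


space = {"lr": [0.005, 0.01, 0.02, 0.1], "n_steps": [1, 2, 3, 4]}  # hypothetical search space
path = os.path.join(tempfile.mkdtemp(), "study.json")

resumable_random_search(toy_objective, space, n_trials=3, path=path)         # "interrupted" after 3 trials
best = resumable_random_search(toy_objective, space, n_trials=8, path=path)  # resumes at trial 3
print(best)
```

Seeding per trial number makes a resumed search sample the same trials as an uninterrupted one. Real tuners usually rely on the stored study state for this instead.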
[20]:
scparadise.scadam.hyperparameter_tuning(adata_train1,
path='',
model_name='model_scadam_hp_tune', # Folder to save hyperparameter tuning results
celltype_l1='cell_type',
num_trials=50, # The number of attempts to find the optimal hyperparameters for the model (we recommend at least 100)
eval_metric=['accuracy','balanced_accuracy'])
Successfully saved genes names for training model
Successfully saved dictionary of dataset annotations
Accelerator: cuda
[I 2024-12-17 16:56:48,662] Using an existing study with name 'model_scadam_hp_tune' instead of creating a new one.
Fold 1:
Early stopping occurred at epoch 19 with best_epoch = 4 and best_valid_balanced_accuracy = 0.12159
[I 2024-12-17 16:57:06,116] Trial 8 pruned.
Fold 1:
Stop training because you reached max_epochs = 50 with best_epoch = 42 and best_valid_balanced_accuracy = 0.30773
[I 2024-12-17 16:57:48,788] Trial 9 pruned.
Fold 1:
Stop training because you reached max_epochs = 45 with best_epoch = 40 and best_valid_balanced_accuracy = 0.95197
Fold 2:
Stop training because you reached max_epochs = 45 with best_epoch = 32 and best_valid_balanced_accuracy = 0.95653
Fold 3:
Stop training because you reached max_epochs = 45 with best_epoch = 33 and best_valid_balanced_accuracy = 0.95143
Fold 4:
Stop training because you reached max_epochs = 45 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95292
[I 2024-12-17 17:00:38,489] Trial 10 finished with value: 0.9532140393068294 and parameters: {'n_d': 124, 'n_a': 92, 'n_steps': 4, 'n_shared': 5, 'cat_emb_dim': 7, 'n_independent': 1, 'gamma': 1.6667667154456676, 'momentum': 0.27154876915108217, 'lr': 0.10527024228081307, 'mask_type': 'sparsemax', 'lambda_sparse': 0.0009586155312153562, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 10 with value: 0.9532140393068294.
Fold 1:
Early stopping occurred at epoch 9 with best_epoch = 4 and best_valid_balanced_accuracy = 0.09804
[I 2024-12-17 17:01:06,583] Trial 11 pruned.
Fold 1:
Early stopping occurred at epoch 12 with best_epoch = 2 and best_valid_balanced_accuracy = 0.10296
[I 2024-12-17 17:01:21,539] Trial 12 pruned.
Fold 1:
Early stopping occurred at epoch 35 with best_epoch = 25 and best_valid_balanced_accuracy = 0.94294
Fold 2:
Early stopping occurred at epoch 43 with best_epoch = 33 and best_valid_balanced_accuracy = 0.95294
Fold 3:
Early stopping occurred at epoch 40 with best_epoch = 30 and best_valid_balanced_accuracy = 0.94829
Fold 4:
Early stopping occurred at epoch 38 with best_epoch = 28 and best_valid_balanced_accuracy = 0.95632
[I 2024-12-17 17:03:00,448] Trial 13 finished with value: 0.9501192387079038 and parameters: {'n_d': 48, 'n_a': 44, 'n_steps': 1, 'n_shared': 3, 'cat_emb_dim': 7, 'n_independent': 3, 'gamma': 1.3155465245255713, 'momentum': 0.015628332308846832, 'lr': 0.022300516115261078, 'mask_type': 'entmax', 'lambda_sparse': 0.001323640069167409, 'patience': 10, 'max_epochs': 75, 'virtual_batch_size': 128, 'batch_size': 768}. Best is trial 10 with value: 0.9532140393068294.
Fold 1:
Early stopping occurred at epoch 22 with best_epoch = 12 and best_valid_balanced_accuracy = 0.9479
Fold 2:
Early stopping occurred at epoch 23 with best_epoch = 13 and best_valid_balanced_accuracy = 0.9517
Fold 3:
Early stopping occurred at epoch 30 with best_epoch = 20 and best_valid_balanced_accuracy = 0.95081
[I 2024-12-17 17:04:23,766] Trial 14 pruned.
Fold 1:
Early stopping occurred at epoch 42 with best_epoch = 32 and best_valid_balanced_accuracy = 0.94818
Fold 2:
Early stopping occurred at epoch 40 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95425
Fold 3:
Early stopping occurred at epoch 28 with best_epoch = 18 and best_valid_balanced_accuracy = 0.9401
[I 2024-12-17 17:05:54,421] Trial 15 pruned.
Fold 1:
Early stopping occurred at epoch 20 with best_epoch = 15 and best_valid_balanced_accuracy = 0.75417
[I 2024-12-17 17:06:21,883] Trial 16 pruned.
Fold 1:
Early stopping occurred at epoch 58 with best_epoch = 38 and best_valid_balanced_accuracy = 0.9491
Fold 2:
Stop training because you reached max_epochs = 65 with best_epoch = 51 and best_valid_balanced_accuracy = 0.95384
Fold 3:
Stop training because you reached max_epochs = 65 with best_epoch = 49 and best_valid_balanced_accuracy = 0.95196
Fold 4:
Early stopping occurred at epoch 50 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95086
[I 2024-12-17 17:09:16,394] Trial 17 finished with value: 0.9514360808085562 and parameters: {'n_d': 84, 'n_a': 112, 'n_steps': 2, 'n_shared': 7, 'cat_emb_dim': 6, 'n_independent': 1, 'gamma': 1.5030605109294553, 'momentum': 0.20760649764410705, 'lr': 0.05660643128255937, 'mask_type': 'entmax', 'lambda_sparse': 0.008471890017944976, 'patience': 20, 'max_epochs': 65, 'virtual_batch_size': 128, 'batch_size': 1024}. Best is trial 10 with value: 0.9532140393068294.
Fold 1:
Stop training because you reached max_epochs = 30 with best_epoch = 26 and best_valid_balanced_accuracy = 0.88263
Fold 2:
Stop training because you reached max_epochs = 30 with best_epoch = 28 and best_valid_balanced_accuracy = 0.92126
Fold 3:
Stop training because you reached max_epochs = 30 with best_epoch = 25 and best_valid_balanced_accuracy = 0.88974
[I 2024-12-17 17:10:09,304] Trial 18 pruned.
Fold 1:
Stop training because you reached max_epochs = 60 with best_epoch = 51 and best_valid_balanced_accuracy = 0.94394
Fold 2:
Stop training because you reached max_epochs = 60 with best_epoch = 42 and best_valid_balanced_accuracy = 0.94609
Fold 3:
Stop training because you reached max_epochs = 60 with best_epoch = 59 and best_valid_balanced_accuracy = 0.94985
[I 2024-12-17 17:12:35,400] Trial 19 pruned.
Fold 1:
Stop training because you reached max_epochs = 15 with best_epoch = 12 and best_valid_balanced_accuracy = 0.25899
[I 2024-12-17 17:12:56,643] Trial 20 pruned.
Fold 1:
Stop training because you reached max_epochs = 80 with best_epoch = 63 and best_valid_balanced_accuracy = 0.95048
Fold 2:
Stop training because you reached max_epochs = 80 with best_epoch = 73 and best_valid_balanced_accuracy = 0.95068
Fold 3:
Early stopping occurred at epoch 79 with best_epoch = 59 and best_valid_balanced_accuracy = 0.95031
[I 2024-12-17 17:18:31,330] Trial 21 pruned.
Fold 1:
Early stopping occurred at epoch 30 with best_epoch = 20 and best_valid_balanced_accuracy = 0.94582
Fold 2:
Early stopping occurred at epoch 36 with best_epoch = 26 and best_valid_balanced_accuracy = 0.95586
Fold 3:
Early stopping occurred at epoch 35 with best_epoch = 25 and best_valid_balanced_accuracy = 0.94643
[I 2024-12-17 17:19:54,289] Trial 22 pruned.
Fold 1:
Stop training because you reached max_epochs = 40 with best_epoch = 35 and best_valid_balanced_accuracy = 0.9449
Fold 2:
Stop training because you reached max_epochs = 40 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95382
Fold 3:
Stop training because you reached max_epochs = 40 with best_epoch = 26 and best_valid_balanced_accuracy = 0.95427
Fold 4:
Stop training because you reached max_epochs = 40 with best_epoch = 38 and best_valid_balanced_accuracy = 0.95487
[I 2024-12-17 17:21:47,902] Trial 23 finished with value: 0.9519649530929056 and parameters: {'n_d': 28, 'n_a': 104, 'n_steps': 2, 'n_shared': 4, 'cat_emb_dim': 6, 'n_independent': 1, 'gamma': 1.178399586222898, 'momentum': 0.11624408126387431, 'lr': 0.05461797882680373, 'mask_type': 'entmax', 'lambda_sparse': 0.004533711957980013, 'patience': 15, 'max_epochs': 40, 'virtual_batch_size': 128, 'batch_size': 768}. Best is trial 10 with value: 0.9532140393068294.
Fold 1:
Stop training because you reached max_epochs = 40 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95087
Fold 2:
Stop training because you reached max_epochs = 40 with best_epoch = 32 and best_valid_balanced_accuracy = 0.95126
Fold 3:
Stop training because you reached max_epochs = 40 with best_epoch = 27 and best_valid_balanced_accuracy = 0.95497
Fold 4:
Stop training because you reached max_epochs = 40 with best_epoch = 33 and best_valid_balanced_accuracy = 0.95485
[I 2024-12-17 17:23:37,508] Trial 24 finished with value: 0.9529842642572143 and parameters: {'n_d': 20, 'n_a': 104, 'n_steps': 2, 'n_shared': 4, 'cat_emb_dim': 6, 'n_independent': 1, 'gamma': 1.004354856556712, 'momentum': 0.09745170594845298, 'lr': 0.1361115286436767, 'mask_type': 'entmax', 'lambda_sparse': 0.004700932119213813, 'patience': 15, 'max_epochs': 40, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 10 with value: 0.9532140393068294.
Fold 1:
Stop training because you reached max_epochs = 40 with best_epoch = 29 and best_valid_balanced_accuracy = 0.95132
Fold 2:
Stop training because you reached max_epochs = 40 with best_epoch = 36 and best_valid_balanced_accuracy = 0.95143
Fold 3:
Stop training because you reached max_epochs = 40 with best_epoch = 32 and best_valid_balanced_accuracy = 0.94893
[I 2024-12-17 17:25:50,067] Trial 25 pruned.
Fold 1:
Stop training because you reached max_epochs = 25 with best_epoch = 24 and best_valid_balanced_accuracy = 0.93793
[I 2024-12-17 17:26:07,067] Trial 26 pruned.
Fold 1:
Stop training because you reached max_epochs = 40 with best_epoch = 34 and best_valid_balanced_accuracy = 0.91319
[I 2024-12-17 17:26:39,334] Trial 27 pruned.
Fold 1:
Stop training because you reached max_epochs = 55 with best_epoch = 53 and best_valid_balanced_accuracy = 0.89312
[I 2024-12-17 17:27:09,718] Trial 28 pruned.
Fold 1:
Stop training because you reached max_epochs = 20 with best_epoch = 17 and best_valid_balanced_accuracy = 0.69984
[I 2024-12-17 17:27:22,652] Trial 29 pruned.
Fold 1:
Stop training because you reached max_epochs = 35 with best_epoch = 27 and best_valid_balanced_accuracy = 0.95036
Fold 2:
Stop training because you reached max_epochs = 35 with best_epoch = 22 and best_valid_balanced_accuracy = 0.9522
Fold 3:
Stop training because you reached max_epochs = 35 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95624
Fold 4:
Early stopping occurred at epoch 34 with best_epoch = 19 and best_valid_balanced_accuracy = 0.95711
[I 2024-12-17 17:29:03,542] Trial 30 finished with value: 0.953979711233727 and parameters: {'n_d': 8, 'n_a': 104, 'n_steps': 3, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 1, 'gamma': 1.1086534148049974, 'momentum': 0.03604587542483932, 'lr': 0.04304648047194131, 'mask_type': 'entmax', 'lambda_sparse': 0.0056379815045209765, 'patience': 15, 'max_epochs': 35, 'virtual_batch_size': 128, 'batch_size': 640}. Best is trial 30 with value: 0.953979711233727.
Fold 1:
Early stopping occurred at epoch 29 with best_epoch = 19 and best_valid_balanced_accuracy = 0.94137
[I 2024-12-17 17:29:44,127] Trial 31 pruned.
Fold 1:
Early stopping occurred at epoch 44 with best_epoch = 29 and best_valid_balanced_accuracy = 0.95196
Fold 2:
Early stopping occurred at epoch 38 with best_epoch = 23 and best_valid_balanced_accuracy = 0.95521
Fold 3:
Early stopping occurred at epoch 35 with best_epoch = 20 and best_valid_balanced_accuracy = 0.95698
Fold 4:
Early stopping occurred at epoch 33 with best_epoch = 18 and best_valid_balanced_accuracy = 0.95593
[I 2024-12-17 17:31:43,742] Trial 32 finished with value: 0.9550171718117837 and parameters: {'n_d': 16, 'n_a': 100, 'n_steps': 3, 'n_shared': 2, 'cat_emb_dim': 8, 'n_independent': 1, 'gamma': 1.0958861456214881, 'momentum': 0.03574116393845864, 'lr': 0.043387929856948185, 'mask_type': 'entmax', 'lambda_sparse': 0.006794720166487761, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 640}. Best is trial 32 with value: 0.9550171718117837.
Fold 1:
Early stopping occurred at epoch 32 with best_epoch = 17 and best_valid_balanced_accuracy = 0.94463
[I 2024-12-17 17:32:09,558] Trial 33 pruned.
Fold 1:
Stop training because you reached max_epochs = 45 with best_epoch = 32 and best_valid_balanced_accuracy = 0.95222
Fold 2:
Stop training because you reached max_epochs = 45 with best_epoch = 39 and best_valid_balanced_accuracy = 0.95251
Fold 3:
Stop training because you reached max_epochs = 45 with best_epoch = 30 and best_valid_balanced_accuracy = 0.95336
Fold 4:
Early stopping occurred at epoch 38 with best_epoch = 23 and best_valid_balanced_accuracy = 0.95417
[I 2024-12-17 17:34:43,722] Trial 34 finished with value: 0.9530666488089797 and parameters: {'n_d': 16, 'n_a': 84, 'n_steps': 4, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.2276108102899952, 'momentum': 0.012593582181250802, 'lr': 0.08059647626559789, 'mask_type': 'entmax', 'lambda_sparse': 0.002167146005271099, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 640}. Best is trial 32 with value: 0.9550171718117837.
Fold 1:
Early stopping occurred at epoch 39 with best_epoch = 24 and best_valid_balanced_accuracy = 0.95276
Fold 2:
Early stopping occurred at epoch 34 with best_epoch = 19 and best_valid_balanced_accuracy = 0.95563
Fold 3:
Stop training because you reached max_epochs = 45 with best_epoch = 37 and best_valid_balanced_accuracy = 0.95338
Fold 4:
Early stopping occurred at epoch 39 with best_epoch = 24 and best_valid_balanced_accuracy = 0.95118
[I 2024-12-17 17:37:52,876] Trial 35 finished with value: 0.9532372247143821 and parameters: {'n_d': 16, 'n_a': 64, 'n_steps': 6, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.2400761918697902, 'momentum': 0.03197771261456092, 'lr': 0.07586645313201903, 'mask_type': 'entmax', 'lambda_sparse': 0.0019824653467942583, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 512}. Best is trial 32 with value: 0.9550171718117837.
Fold 1:
Stop training because you reached max_epochs = 50 with best_epoch = 48 and best_valid_balanced_accuracy = 0.84368
[I 2024-12-17 17:38:20,704] Trial 36 pruned.
Fold 1:
Stop training because you reached max_epochs = 45 with best_epoch = 35 and best_valid_balanced_accuracy = 0.92418
[I 2024-12-17 17:39:19,414] Trial 37 pruned.
Fold 1:
Stop training because you reached max_epochs = 5 with best_epoch = 2 and best_valid_balanced_accuracy = 0.13015
[I 2024-12-17 17:39:24,766] Trial 38 pruned.
Fold 1:
Stop training because you reached max_epochs = 30 with best_epoch = 27 and best_valid_balanced_accuracy = 0.25774
Fold 2:
Stop training because you reached max_epochs = 30 with best_epoch = 25 and best_valid_balanced_accuracy = 0.2464
Fold 3:
Stop training because you reached max_epochs = 30 with best_epoch = 29 and best_valid_balanced_accuracy = 0.24999
[I 2024-12-17 17:41:38,906] Trial 39 pruned.
Fold 1:
Stop training because you reached max_epochs = 30 with best_epoch = 28 and best_valid_balanced_accuracy = 0.90695
[I 2024-12-17 17:42:12,114] Trial 40 pruned.
Fold 1:
Stop training because you reached max_epochs = 55 with best_epoch = 53 and best_valid_balanced_accuracy = 0.87599
[I 2024-12-17 17:42:35,563] Trial 41 pruned.
Fold 1:
Early stopping occurred at epoch 37 with best_epoch = 22 and best_valid_balanced_accuracy = 0.95393
Fold 2:
Stop training because you reached max_epochs = 45 with best_epoch = 37 and best_valid_balanced_accuracy = 0.95927
Fold 3:
Early stopping occurred at epoch 39 with best_epoch = 24 and best_valid_balanced_accuracy = 0.95694
Fold 4:
Early stopping occurred at epoch 34 with best_epoch = 19 and best_valid_balanced_accuracy = 0.95263
[I 2024-12-17 17:44:57,599] Trial 42 finished with value: 0.9556933956291884 and parameters: {'n_d': 12, 'n_a': 88, 'n_steps': 4, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.2378640594404067, 'momentum': 0.011566565604250411, 'lr': 0.07971165451794174, 'mask_type': 'entmax', 'lambda_sparse': 0.0019751637085813345, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 640}. Best is trial 42 with value: 0.9556933956291884.
Fold 1:
Early stopping occurred at epoch 38 with best_epoch = 23 and best_valid_balanced_accuracy = 0.94737
Fold 2:
Early stopping occurred at epoch 34 with best_epoch = 19 and best_valid_balanced_accuracy = 0.9551
Fold 3:
Early stopping occurred at epoch 39 with best_epoch = 24 and best_valid_balanced_accuracy = 0.94724
[I 2024-12-17 17:46:47,117] Trial 43 pruned.
Fold 1:
Stop training because you reached max_epochs = 35 with best_epoch = 34 and best_valid_balanced_accuracy = 0.94709
Fold 2:
Early stopping occurred at epoch 25 with best_epoch = 10 and best_valid_balanced_accuracy = 0.95678
Fold 3:
Stop training because you reached max_epochs = 35 with best_epoch = 32 and best_valid_balanced_accuracy = 0.95282
[I 2024-12-17 17:48:47,095] Trial 44 pruned.
Fold 1:
Early stopping occurred at epoch 34 with best_epoch = 19 and best_valid_balanced_accuracy = 0.94692
Fold 2:
Early stopping occurred at epoch 42 with best_epoch = 27 and best_valid_balanced_accuracy = 0.95551
Fold 3:
Early stopping occurred at epoch 32 with best_epoch = 17 and best_valid_balanced_accuracy = 0.95563
Fold 4:
Early stopping occurred at epoch 40 with best_epoch = 25 and best_valid_balanced_accuracy = 0.95928
[I 2024-12-17 17:50:56,243] Trial 45 finished with value: 0.9543346665963542 and parameters: {'n_d': 12, 'n_a': 96, 'n_steps': 3, 'n_shared': 1, 'cat_emb_dim': 7, 'n_independent': 3, 'gamma': 1.3202643716808666, 'momentum': 0.013752839994914707, 'lr': 0.11122640167419975, 'mask_type': 'entmax', 'lambda_sparse': 0.02401249896211509, 'patience': 15, 'max_epochs': 50, 'virtual_batch_size': 128, 'batch_size': 640}. Best is trial 42 with value: 0.9556933956291884.
Fold 1:
Early stopping occurred at epoch 39 with best_epoch = 24 and best_valid_balanced_accuracy = 0.94524
[I 2024-12-17 17:51:33,343] Trial 46 pruned.
Fold 1:
Stop training because you reached max_epochs = 60 with best_epoch = 50 and best_valid_balanced_accuracy = 0.95008
Fold 2:
Stop training because you reached max_epochs = 60 with best_epoch = 57 and best_valid_balanced_accuracy = 0.94405
Fold 3:
Early stopping occurred at epoch 54 with best_epoch = 39 and best_valid_balanced_accuracy = 0.95121
[I 2024-12-17 17:55:42,846] Trial 47 pruned.
Fold 1:
Stop training because you reached max_epochs = 35 with best_epoch = 34 and best_valid_balanced_accuracy = 0.93133
[I 2024-12-17 17:57:03,560] Trial 48 pruned.
Fold 1:
Stop training because you reached max_epochs = 60 with best_epoch = 59 and best_valid_balanced_accuracy = 0.94001
[I 2024-12-17 17:57:33,156] Trial 49 pruned.
Successfully saved best hyperparameters
Best hyperparameters: {'n_d': 12, 'n_a': 88, 'n_steps': 4, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.2378640594404067, 'momentum': 0.011566565604250411, 'lr': 0.07971165451794174, 'mask_type': 'entmax', 'lambda_sparse': 0.0019751637085813345, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 640}
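The log above follows Optuna's format (`Trial N pruned.`, `Trial N finished with value: ...`). Each completed trial's value appears to be the mean of the per-fold `best_valid_balanced_accuracy` values. Pruned trials, which stop after one or three folds, never receive a final value, and Optuna picks the best trial only among completed ones. The check below uses numbers copied from the log. It is a consistency check under that assumption, not scparadise code.

```python
from statistics import mean

# best_valid_balanced_accuracy of each fold, copied from the log above
fold_scores = {
    10: [0.95197, 0.95653, 0.95143, 0.95292],
    32: [0.95196, 0.95521, 0.95698, 0.95593],
    42: [0.95393, 0.95927, 0.95694, 0.95263],
}

# Values of all trials in this run that finished (were not pruned)
completed = {
    10: 0.9532140393068294, 13: 0.9501192387079038, 17: 0.9514360808085562,
    23: 0.9519649530929056, 24: 0.9529842642572143, 30: 0.953979711233727,
    32: 0.9550171718117837, 34: 0.9530666488089797, 35: 0.9532372247143821,
    42: 0.9556933956291884, 45: 0.9543346665963542,
}

for trial, scores in fold_scores.items():
    # Fold scores are logged with 5 decimals, so agreement is only up to ~1e-5
    print(f"trial {trial}: mean of folds = {mean(scores):.5f}, reported = {completed[trial]:.5f}")

# The best trial is the completed trial with the highest value
best_trial = max(completed, key=completed.get)
print("best trial:", best_trial)  # 42, as in "Best is trial 42"
```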
[20]:
{'n_d': 12,
'n_a': 88,
'n_steps': 4,
'n_shared': 1,
'cat_emb_dim': 8,
'n_independent': 2,
'gamma': 1.2378640594404067,
'momentum': 0.011566565604250411,
'lr': 0.07971165451794174,
'mask_type': 'entmax',
'lambda_sparse': 0.0019751637085813345,
'patience': 15,
'max_epochs': 45,
'virtual_batch_size': 128,
'batch_size': 640}
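The returned keys match the constructor arguments (`n_d`, `n_a`, `n_steps`, `gamma`, `mask_type`, `lambda_sparse`, ...) and the `fit()` arguments (`max_epochs`, `patience`, `batch_size`, `virtual_batch_size`) of pytorch-tabnet's `TabNetClassifier`. The `train_tuned` log in the next cell also shows `lr` moved into `optimizer_params` and `device_name` added. The helper `split_tuned_params` below is hypothetical. It sketches how such a dict could be routed to a TabNet-style model. We have not confirmed that scparadise works this way internally.

```python
# Hypothetical helper: route tuned values to model-constructor vs. fit() arguments.
FIT_KEYS = {"max_epochs", "patience", "batch_size", "virtual_batch_size"}


def split_tuned_params(best_params, device_name="cuda"):
    model_kwargs, fit_kwargs = {}, {}
    for key, value in best_params.items():
        if key == "lr":
            # The learning rate belongs to the optimizer, not to the network itself
            model_kwargs["optimizer_params"] = {"lr": value}
        elif key in FIT_KEYS:
            fit_kwargs[key] = value
        else:
            model_kwargs[key] = value
    model_kwargs["device_name"] = device_name
    return model_kwargs, fit_kwargs


best_params = {
    "n_d": 12, "n_a": 88, "n_steps": 4, "n_shared": 1, "cat_emb_dim": 8,
    "n_independent": 2, "gamma": 1.2378640594404067,
    "momentum": 0.011566565604250411, "lr": 0.07971165451794174,
    "mask_type": "entmax", "lambda_sparse": 0.0019751637085813345,
    "patience": 15, "max_epochs": 45, "virtual_batch_size": 128, "batch_size": 640,
}
model_kwargs, fit_kwargs = split_tuned_params(best_params)
print(model_kwargs)
print(fit_kwargs)
```

With pytorch-tabnet installed, the two dicts could then be passed as `TabNetClassifier(**model_kwargs)` and `clf.fit(X_train, y_train, **fit_kwargs)`, where `X_train` and `y_train` are placeholder names.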
[23]:
# Train model using optimal parameters from model_scadam_hp_tune folder
scparadise.scadam.train_tuned(adata_train1,
path='', # path to save model
path_tuned='model_scadam_hp_tune', # path to a folder with tuned hyperparameters
model_name='model_scadam_tuned', # name of the folder in which the trained model is saved
celltype_l1='cell_type',
eval_metric=['accuracy','balanced_accuracy']) # If you are using an imbalanced training dataset, we recommend using balanced_accuracy for early stopping
Successfully saved genes names for training model
Successfully saved dictionary of dataset annotations
Train dataset contains: 15455 cells, it is 90.0 % of input dataset
Test dataset contains: 1718 cells, it is 10.0 % of input dataset
Accelerator: cuda
Start training with following hyperparameters: {'n_d': 12, 'n_a': 88, 'n_steps': 4, 'n_shared': 1, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.2378640594404067, 'momentum': 0.011566565604250411, 'optimizer_params': {'lr': 0.07971165451794174}, 'mask_type': 'entmax', 'lambda_sparse': 0.0019751637085813345, 'patience': 15, 'max_epochs': 45, 'virtual_batch_size': 128, 'batch_size': 640, 'device_name': 'cuda'}
epoch 0 | loss: 1.97705 | train_accuracy: 0.54274 | train_balanced_accuracy: 0.32096 | valid_accuracy: 0.54657 | valid_balanced_accuracy: 0.32691 | 0:00:01s
epoch 1 | loss: 0.93955 | train_accuracy: 0.72902 | train_balanced_accuracy: 0.53245 | valid_accuracy: 0.73108 | valid_balanced_accuracy: 0.54764 | 0:00:02s
epoch 2 | loss: 0.6556 | train_accuracy: 0.86069 | train_balanced_accuracy: 0.72309 | valid_accuracy: 0.85739 | valid_balanced_accuracy: 0.70182 | 0:00:03s
epoch 3 | loss: 0.51348 | train_accuracy: 0.88282 | train_balanced_accuracy: 0.68273 | valid_accuracy: 0.88708 | valid_balanced_accuracy: 0.67382 | 0:00:04s
epoch 4 | loss: 0.44289 | train_accuracy: 0.91388 | train_balanced_accuracy: 0.76913 | valid_accuracy: 0.91502 | valid_balanced_accuracy: 0.76522 | 0:00:06s
epoch 5 | loss: 0.39766 | train_accuracy: 0.87784 | train_balanced_accuracy: 0.71463 | valid_accuracy: 0.87602 | valid_balanced_accuracy: 0.72104 | 0:00:07s
epoch 6 | loss: 0.38373 | train_accuracy: 0.9287 | train_balanced_accuracy: 0.79359 | valid_accuracy: 0.92841 | valid_balanced_accuracy: 0.78694 | 0:00:08s
epoch 7 | loss: 0.34919 | train_accuracy: 0.94196 | train_balanced_accuracy: 0.8213 | valid_accuracy: 0.94529 | valid_balanced_accuracy: 0.8094 | 0:00:09s
epoch 8 | loss: 0.34294 | train_accuracy: 0.94746 | train_balanced_accuracy: 0.89272 | valid_accuracy: 0.94005 | valid_balanced_accuracy: 0.82458 | 0:00:11s
epoch 9 | loss: 0.31135 | train_accuracy: 0.95807 | train_balanced_accuracy: 0.9349 | valid_accuracy: 0.9482 | valid_balanced_accuracy: 0.89005 | 0:00:12s
epoch 10 | loss: 0.302 | train_accuracy: 0.95516 | train_balanced_accuracy: 0.91947 | valid_accuracy: 0.94412 | valid_balanced_accuracy: 0.87415 | 0:00:13s
epoch 11 | loss: 0.29094 | train_accuracy: 0.95613 | train_balanced_accuracy: 0.93584 | valid_accuracy: 0.9546 | valid_balanced_accuracy: 0.91292 | 0:00:14s
epoch 12 | loss: 0.28451 | train_accuracy: 0.96694 | train_balanced_accuracy: 0.9558 | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.93755 | 0:00:16s
epoch 13 | loss: 0.2717 | train_accuracy: 0.96881 | train_balanced_accuracy: 0.95787 | valid_accuracy: 0.9546 | valid_balanced_accuracy: 0.91791 | 0:00:17s
epoch 14 | loss: 0.271 | train_accuracy: 0.96894 | train_balanced_accuracy: 0.96448 | valid_accuracy: 0.95343 | valid_balanced_accuracy: 0.95216 | 0:00:18s
epoch 15 | loss: 0.27228 | train_accuracy: 0.96972 | train_balanced_accuracy: 0.97213 | valid_accuracy: 0.961 | valid_balanced_accuracy: 0.96207 | 0:00:19s
epoch 16 | loss: 0.24422 | train_accuracy: 0.96868 | train_balanced_accuracy: 0.97584 | valid_accuracy: 0.95343 | valid_balanced_accuracy: 0.95022 | 0:00:21s
epoch 17 | loss: 0.26597 | train_accuracy: 0.9725 | train_balanced_accuracy: 0.96771 | valid_accuracy: 0.95984 | valid_balanced_accuracy: 0.95716 | 0:00:22s
epoch 18 | loss: 0.25958 | train_accuracy: 0.97457 | train_balanced_accuracy: 0.97918 | valid_accuracy: 0.9546 | valid_balanced_accuracy: 0.95112 | 0:00:23s
epoch 19 | loss: 0.26462 | train_accuracy: 0.97095 | train_balanced_accuracy: 0.97068 | valid_accuracy: 0.94994 | valid_balanced_accuracy: 0.93543 | 0:00:24s
epoch 20 | loss: 0.25516 | train_accuracy: 0.98149 | train_balanced_accuracy: 0.98106 | valid_accuracy: 0.95984 | valid_balanced_accuracy: 0.95633 | 0:00:26s
epoch 21 | loss: 0.24411 | train_accuracy: 0.9758 | train_balanced_accuracy: 0.97992 | valid_accuracy: 0.95751 | valid_balanced_accuracy: 0.96093 | 0:00:27s
epoch 22 | loss: 0.24723 | train_accuracy: 0.98104 | train_balanced_accuracy: 0.97788 | valid_accuracy: 0.96275 | valid_balanced_accuracy: 0.94997 | 0:00:28s
epoch 23 | loss: 0.2381 | train_accuracy: 0.97373 | train_balanced_accuracy: 0.97587 | valid_accuracy: 0.94936 | valid_balanced_accuracy: 0.94231 | 0:00:29s
epoch 24 | loss: 0.23338 | train_accuracy: 0.98337 | train_balanced_accuracy: 0.98278 | valid_accuracy: 0.96508 | valid_balanced_accuracy: 0.95644 | 0:00:30s
epoch 25 | loss: 0.24744 | train_accuracy: 0.98382 | train_balanced_accuracy: 0.98654 | valid_accuracy: 0.96042 | valid_balanced_accuracy: 0.96381 | 0:00:31s
epoch 26 | loss: 0.22563 | train_accuracy: 0.98674 | train_balanced_accuracy: 0.98769 | valid_accuracy: 0.96275 | valid_balanced_accuracy: 0.96354 | 0:00:33s
epoch 27 | loss: 0.22583 | train_accuracy: 0.98266 | train_balanced_accuracy: 0.9851 | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.9587 | 0:00:34s
epoch 28 | loss: 0.22655 | train_accuracy: 0.98589 | train_balanced_accuracy: 0.9853 | valid_accuracy: 0.96158 | valid_balanced_accuracy: 0.95896 | 0:00:35s
epoch 29 | loss: 0.21171 | train_accuracy: 0.98784 | train_balanced_accuracy: 0.99048 | valid_accuracy: 0.95925 | valid_balanced_accuracy: 0.95705 | 0:00:36s
epoch 30 | loss: 0.20867 | train_accuracy: 0.98602 | train_balanced_accuracy: 0.9838 | valid_accuracy: 0.96449 | valid_balanced_accuracy: 0.95301 | 0:00:38s
epoch 31 | loss: 0.22263 | train_accuracy: 0.98958 | train_balanced_accuracy: 0.99077 | valid_accuracy: 0.96333 | valid_balanced_accuracy: 0.96619 | 0:00:39s
epoch 32 | loss: 0.21809 | train_accuracy: 0.99081 | train_balanced_accuracy: 0.99026 | valid_accuracy: 0.96042 | valid_balanced_accuracy: 0.9611 | 0:00:40s
epoch 33 | loss: 0.20374 | train_accuracy: 0.99114 | train_balanced_accuracy: 0.99207 | valid_accuracy: 0.961 | valid_balanced_accuracy: 0.95457 | 0:00:41s
epoch 34 | loss: 0.21158 | train_accuracy: 0.98848 | train_balanced_accuracy: 0.99016 | valid_accuracy: 0.95576 | valid_balanced_accuracy: 0.96417 | 0:00:42s
epoch 35 | loss: 0.21034 | train_accuracy: 0.99159 | train_balanced_accuracy: 0.99133 | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.95491 | 0:00:44s
epoch 36 | loss: 0.20458 | train_accuracy: 0.98589 | train_balanced_accuracy: 0.99006 | valid_accuracy: 0.95518 | valid_balanced_accuracy: 0.96309 | 0:00:45s
epoch 37 | loss: 0.2034 | train_accuracy: 0.99178 | train_balanced_accuracy: 0.99211 | valid_accuracy: 0.95751 | valid_balanced_accuracy: 0.95942 | 0:00:46s
epoch 38 | loss: 0.19649 | train_accuracy: 0.99185 | train_balanced_accuracy: 0.99273 | valid_accuracy: 0.95809 | valid_balanced_accuracy: 0.94929 | 0:00:48s
epoch 39 | loss: 0.19649 | train_accuracy: 0.99159 | train_balanced_accuracy: 0.99244 | valid_accuracy: 0.95402 | valid_balanced_accuracy: 0.94112 | 0:00:49s
epoch 40 | loss: 0.1909 | train_accuracy: 0.99152 | train_balanced_accuracy: 0.99123 | valid_accuracy: 0.95693 | valid_balanced_accuracy: 0.95582 | 0:00:50s
epoch 41 | loss: 0.19694 | train_accuracy: 0.99398 | train_balanced_accuracy: 0.98818 | valid_accuracy: 0.95867 | valid_balanced_accuracy: 0.96127 | 0:00:51s
epoch 42 | loss: 0.19382 | train_accuracy: 0.96635 | train_balanced_accuracy: 0.98592 | valid_accuracy: 0.92841 | valid_balanced_accuracy: 0.94994 | 0:00:52s
epoch 43 | loss: 0.19076 | train_accuracy: 0.99534 | train_balanced_accuracy: 0.99435 | valid_accuracy: 0.95634 | valid_balanced_accuracy: 0.94817 | 0:00:54s
epoch 44 | loss: 0.18032 | train_accuracy: 0.99159 | train_balanced_accuracy: 0.99387 | valid_accuracy: 0.9546 | valid_balanced_accuracy: 0.96242 | 0:00:55s
Stop training because you reached max_epochs = 45 with best_epoch = 31 and best_valid_balanced_accuracy = 0.96619
Successfully saved training history and parameters
Successfully saved model at model_scadam_tuned/model.zip
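This run stopped because it reached `max_epochs`, not because patience ran out. The best validation balanced accuracy (0.96619) came at epoch 31, and only 13 epochs (32 to 44) followed, which is fewer than `patience = 15`. The tuning logs are consistent with a "stop once `patience` epochs pass without improvement" rule. For example, Trial 13 had patience 10, best_epoch 25, and stopped at epoch 35. The sketch below replays that rule on the logged `valid_balanced_accuracy` values. It is an illustration of patience-based early stopping, not the library's implementation.

```python
def early_stopping(scores, patience, max_epochs):
    """Replay patience-based early stopping on a metric that should be maximized.

    Returns (best_epoch, best_score, stopped_epoch); stopped_epoch is None when
    training ends because max_epochs is reached.
    """
    best_epoch, best_score = 0, float("-inf")
    for epoch, score in enumerate(scores[:max_epochs]):
        if score > best_score:
            best_epoch, best_score = epoch, score
        elif epoch - best_epoch >= patience:
            return best_epoch, best_score, epoch
    return best_epoch, best_score, None


# valid_balanced_accuracy for epochs 0-44, copied from the training log above
valid_bal_acc = [
    0.32691, 0.54764, 0.70182, 0.67382, 0.76522, 0.72104, 0.78694, 0.8094, 0.82458, 0.89005,
    0.87415, 0.91292, 0.93755, 0.91791, 0.95216, 0.96207, 0.95022, 0.95716, 0.95112, 0.93543,
    0.95633, 0.96093, 0.94997, 0.94231, 0.95644, 0.96381, 0.96354, 0.9587, 0.95896, 0.95705,
    0.95301, 0.96619, 0.9611, 0.95457, 0.96417, 0.95491, 0.96309, 0.95942, 0.94929, 0.94112,
    0.95582, 0.96127, 0.94994, 0.94817, 0.96242,
]
result = early_stopping(valid_bal_acc, patience=15, max_epochs=45)
print(result)  # (31, 0.96619, None)
```

pytorch-tabnet restores the best-epoch weights at the end of training. If scAdam keeps that behavior (an assumption), the saved model corresponds to epoch 31 rather than epoch 44.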
[24]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
path_model = 'model_scadam_tuned')
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Successfully loaded model
Successfully added predicted celltype_l1 and cell type probabilities
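`predict` stores its results in `adata_test.obs`, and the plotting cell below reads them as `pred_celltype_l1` and `prob_celltype_l1`. A common way to derive such columns from a classifier's per-class probability matrix is shown here. We assume, without checking the scparadise source, that `prob_celltype_l1` holds the probability of the top-ranked class. The class names, probabilities and 0.5 threshold are made up for illustration.

```python
import numpy as np

# Hypothetical class order and per-cell class probabilities (each row sums to 1)
classes = np.array(["B cell", "mast cell", "plasma cell"])
proba = np.array([
    [0.90, 0.05, 0.05],
    [0.20, 0.10, 0.70],
    [0.40, 0.35, 0.25],
])

pred_label = classes[proba.argmax(axis=1)]  # most probable class for each cell
pred_prob = proba.max(axis=1)               # probability of that class (a confidence score)

low_confidence = pred_prob < 0.5            # hypothetical threshold for flagging uncertain cells
print(pred_label.tolist(), pred_prob.tolist(), low_confidence.tolist())
```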
Evaluation of prediction results
Hyperparameter tuning increased the model's overall accuracy (0.7749 vs 0.7479), balanced accuracy (0.9068 vs 0.8936) and macro-averaged precision (0.7870 vs 0.7493).
Below are examples of cell types whose predictions improved after tuning (hp tuned vs untuned):
- mature alpha-beta T cell (precision: 0.2783 vs 0.0665)
- regulatory T cell (precision: 0.6378 vs 0.4762)
- natural killer cell (precision: 0.7616 vs 0.5929)
- endothelial cell (precision: 0.9934 vs 0.7873)
- fibroblast (sensitivity: 0.9911 vs 0.8935)
- epithelial cell (sensitivity: 0.5518 vs 0.4994)
- mural cell (precision: 0.9779 vs 0.8173)
However, precision decreased for some cell types (hp tuned vs untuned):
- B cell (precision: 0.7263 vs 0.8170)
- CD4-positive, alpha-beta T cell (precision: 0.4301 vs 0.5112)
- endothelial cell of lymphatic vessel (precision: 0.5532 vs 0.8475)
[25]:
## Check model quality
df_hp_tuned = scparadise.scnoah.report_classif_full(adata_test,
celltype='cell_type',
pred_celltype='pred_celltype_l1')
df_hp_tuned
[25]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B cell | 0.7263 | 0.9863 | 0.9531 | 0.8366 | 0.9696 | 0.9432 | 2050 |
| CD4-positive, alpha-beta T cell | 0.4301 | 0.8616 | 0.8941 | 0.5738 | 0.8777 | 0.7679 | 1554 |
| CD8-positive, alpha-beta T cell | 0.6414 | 0.8852 | 0.9479 | 0.7438 | 0.916 | 0.8338 | 1742 |
| endothelial cell | 0.9934 | 0.9891 | 0.9998 | 0.9912 | 0.9944 | 0.9878 | 457 |
| endothelial cell of lymphatic vessel | 0.5532 | 1.0 | 0.9977 | 0.7123 | 0.9988 | 0.9979 | 52 |
| epithelial cell | 0.9897 | 0.5518 | 0.996 | 0.7085 | 0.7413 | 0.5252 | 7505 |
| fibroblast | 0.9585 | 0.9911 | 0.9984 | 0.9745 | 0.9947 | 0.9888 | 676 |
| mast cell | 0.9793 | 0.9595 | 0.9998 | 0.9693 | 0.9794 | 0.9554 | 148 |
| mature alpha-beta T cell | 0.2783 | 0.8533 | 0.9909 | 0.4197 | 0.9195 | 0.8339 | 75 |
| mural cell | 0.9779 | 0.9108 | 0.9995 | 0.9431 | 0.9541 | 0.9022 | 437 |
| myeloid cell | 0.9283 | 0.9628 | 0.9969 | 0.9452 | 0.9797 | 0.9566 | 726 |
| natural killer cell | 0.7616 | 0.9596 | 0.9963 | 0.8492 | 0.9778 | 0.9526 | 223 |
| plasma cell | 0.9881 | 0.9922 | 0.9986 | 0.9901 | 0.9954 | 0.9902 | 1919 |
| plasmacytoid dendritic cell | 0.9608 | 0.98 | 0.9999 | 0.9703 | 0.9899 | 0.9779 | 50 |
| regulatory T cell | 0.6378 | 0.7187 | 0.9841 | 0.6758 | 0.841 | 0.6885 | 686 |
| macro avg | 0.7870 | 0.9068 | 0.9835 | 0.8202 | 0.942 | 0.8868 | |
| weighted avg | 0.8553 | 0.7749 | 0.9781 | 0.7783 | 0.8627 | 0.7442 | |
| Accuracy | 0.7749 | | | | | | |
| Balanced accuracy | 0.9068 | | | | | | |
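The composite columns look consistent with the definitions used by imbalanced-learn's `classification_report_imbalanced`: geometric mean $G = \sqrt{\text{sensitivity} \cdot \text{specificity}}$ and index balanced accuracy $\text{IBA} = (1 + \alpha(\text{sensitivity} - \text{specificity}))\,G^2$ with $\alpha = 0.1$. We checked this only against rounded table values, so treat it as an assumption about how `report_classif_full` computes them. The sketch recomputes the B cell row. It also shows two standard identities visible in the table: balanced accuracy is the unweighted mean of per-class recall (the macro avg recall), and accuracy is the recall averaged with weights given by the number of cells (the weighted avg recall).

```python
import math
from statistics import mean


def f1_score(precision, recall):
    return 2 * precision * recall / (precision + recall)


def gmean_iba(sensitivity, specificity, alpha=0.1):
    """Geometric mean and index balanced accuracy (imbalanced-learn style, assumed)."""
    g = math.sqrt(sensitivity * specificity)
    iba = (1 + alpha * (sensitivity - specificity)) * g**2
    return g, iba


# B cell row of the hp tuned report: precision 0.7263, recall 0.9863, specificity 0.9531
f1 = f1_score(0.7263, 0.9863)
g, iba = gmean_iba(0.9863, 0.9531)
print(round(f1, 4), round(g, 4), round(iba, 4))  # 0.8366 0.9696 0.9432

# Per-class recall and number of cells, in table order
recall = [0.9863, 0.8616, 0.8852, 0.9891, 1.0, 0.5518, 0.9911, 0.9595,
          0.8533, 0.9108, 0.9628, 0.9596, 0.9922, 0.98, 0.7187]
n_cells = [2050, 1554, 1742, 457, 52, 7505, 676, 148, 75, 437, 726, 223, 1919, 50, 686]

balanced_accuracy = mean(recall)  # unweighted mean of per-class recall
accuracy = sum(r * n for r, n in zip(recall, n_cells)) / sum(n_cells)  # weighted by cell counts
print(round(balanced_accuracy, 4), round(accuracy, 4))  # 0.9068 0.7749
```

The large gap between the two summary numbers comes mostly from epithelial cells. They are the largest class (7505 cells) and have the lowest recall (0.5518), so they pull accuracy down far more than balanced accuracy.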
[26]:
# Give ground truth and predictions the same ordered cell type categories so both use identical colors
celltype = np.unique(adata_test.obs['cell_type']).tolist()
adata_test.obs['cell_type'] = pd.Categorical(
values=adata_test.obs['cell_type'], categories=celltype, ordered=True
)
adata_test.obs['pred_celltype_l1'] = pd.Categorical(
values=adata_test.obs['pred_celltype_l1'], categories=celltype, ordered=True
)
# Add prediction status. Label cells as correct or incorrect based on the comparison between ground truth cell types and predictions.
adata_test = scparadise.scnoah.pred_status(adata_test, celltype='cell_type', pred_celltype='pred_celltype_l1')
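Assigning the same ordered categories to both columns also makes an element-wise comparison possible, because pandas refuses to compare two categoricals whose categories differ. The toy sketch below uses made-up labels and assumes that `pred_status` does something equivalent to the `np.where` line. The actual scparadise implementation may differ. It also shows a caveat of cell [26]: a predicted label that is absent from the ground-truth categories is silently converted to NaN.

```python
import numpy as np
import pandas as pd

# Toy stand-in for adata_test.obs (made-up labels, not from the dataset)
obs = pd.DataFrame({
    "cell_type": ["B cell", "plasma cell", "B cell", "mast cell"],
    "pred_celltype_l1": ["B cell", "B cell", "B cell", "mast cell"],
})

# Same ordered categories for both columns, as in cell [26]
categories = np.unique(obs["cell_type"]).tolist()
for col in ["cell_type", "pred_celltype_l1"]:
    obs[col] = pd.Categorical(obs[col], categories=categories, ordered=True)

# Hypothetical equivalent of a correct/incorrect status column
obs["pred_status"] = np.where(
    obs["cell_type"] == obs["pred_celltype_l1"], "correct", "incorrect"
)
print(obs["pred_status"].tolist())

# Caveat: a label missing from `categories` becomes NaN
unseen = pd.Categorical(["T cell"], categories=categories, ordered=True)
print(unseen.isna())
```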
[27]:
# Visualize ground truth cell types, predicted cell types, prediction probabilities and prediction status
sc.pl.embedding(adata_test,
color=[
'cell_type',
'pred_celltype_l1',
'prob_celltype_l1',
'pred_status'
],
basis = 'X_umap',
frameon = False,
add_outline = True,
legend_loc = 'right margin',
legend_fontsize = 7,
legend_fontoutline = 1,
ncols=2,
wspace = 0.7,
hspace = 0.1)
[28]:
# Compare prediction results
# 'untuned' row represents untuned model
# 'hp tuned' row represents tuned model
df.compare(df_hp_tuned, keep_equal=True, align_axis = 0, result_names=('untuned', 'hp tuned'))
[28]:
| precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | ||
|---|---|---|---|---|---|---|---|
| B cell | untuned | 0.8170 | 0.98 | 0.9723 | 0.8911 | 0.9761 | 0.9536 |
| hp tuned | 0.7263 | 0.9863 | 0.9531 | 0.8366 | 0.9696 | 0.9432 | |
| CD4-positive, alpha-beta T cell | untuned | 0.5112 | 0.8481 | 0.9248 | 0.6379 | 0.8856 | 0.7783 |
| hp tuned | 0.4301 | 0.8616 | 0.8941 | 0.5738 | 0.8777 | 0.7679 | |
| CD8-positive, alpha-beta T cell | untuned | 0.6327 | 0.8829 | 0.9461 | 0.7371 | 0.9139 | 0.83 |
| hp tuned | 0.6414 | 0.8852 | 0.9479 | 0.7438 | 0.916 | 0.8338 | |
| endothelial cell | untuned | 0.7873 | 0.9803 | 0.9932 | 0.8733 | 0.9867 | 0.9724 |
| | hp tuned | 0.9934 | 0.9891 | 0.9998 | 0.9912 | 0.9944 | 0.9878 |
| endothelial cell of lymphatic vessel | untuned | 0.8475 | 0.9615 | 0.9995 | 0.9009 | 0.9803 | 0.9574 |
| | hp tuned | 0.5532 | 1.0 | 0.9977 | 0.7123 | 0.9988 | 0.9979 |
| epithelial cell | untuned | 0.9907 | 0.4994 | 0.9968 | 0.6641 | 0.7055 | 0.473 |
| | hp tuned | 0.9897 | 0.5518 | 0.996 | 0.7085 | 0.7413 | 0.5252 |
| fibroblast | untuned | 0.9379 | 0.8935 | 0.9977 | 0.9152 | 0.9442 | 0.8822 |
| | hp tuned | 0.9585 | 0.9911 | 0.9984 | 0.9745 | 0.9947 | 0.9888 |
| mast cell | untuned | 0.9797 | 0.9797 | 0.9998 | 0.9797 | 0.9897 | 0.9776 |
| | hp tuned | 0.9793 | 0.9595 | 0.9998 | 0.9693 | 0.9794 | 0.9554 |
| mature alpha-beta T cell | untuned | 0.0665 | 0.84 | 0.9514 | 0.1232 | 0.894 | 0.7903 |
| | hp tuned | 0.2783 | 0.8533 | 0.9909 | 0.4197 | 0.9195 | 0.8339 |
| mural cell | untuned | 0.8173 | 0.9519 | 0.9948 | 0.8795 | 0.9731 | 0.9429 |
| | hp tuned | 0.9779 | 0.9108 | 0.9995 | 0.9431 | 0.9541 | 0.9022 |
| myeloid cell | untuned | 0.8784 | 0.9449 | 0.9946 | 0.9104 | 0.9694 | 0.9351 |
| | hp tuned | 0.9283 | 0.9628 | 0.9969 | 0.9452 | 0.9797 | 0.9566 |
| natural killer cell | untuned | 0.5929 | 0.9731 | 0.9918 | 0.7368 | 0.9824 | 0.9633 |
| | hp tuned | 0.7616 | 0.9596 | 0.9963 | 0.8492 | 0.9778 | 0.9526 |
| plasma cell | untuned | 0.9810 | 0.9937 | 0.9977 | 0.9873 | 0.9957 | 0.9911 |
| | hp tuned | 0.9881 | 0.9922 | 0.9986 | 0.9901 | 0.9954 | 0.9902 |
| plasmacytoid dendritic cell | untuned | 0.9231 | 0.96 | 0.9998 | 0.9412 | 0.9797 | 0.956 |
| | hp tuned | 0.9608 | 0.98 | 0.9999 | 0.9703 | 0.9899 | 0.9779 |
| regulatory T cell | untuned | 0.4762 | 0.7143 | 0.9694 | 0.5714 | 0.8321 | 0.6748 |
| | hp tuned | 0.6378 | 0.7187 | 0.9841 | 0.6758 | 0.841 | 0.6885 |
| macro avg | untuned | 0.7493 | 0.8936 | 0.982 | 0.7833 | 0.9339 | 0.8719 |
| | hp tuned | 0.7870 | 0.9068 | 0.9835 | 0.8202 | 0.942 | 0.8868 |
| weighted avg | untuned | 0.8512 | 0.7479 | 0.9818 | 0.7567 | 0.8468 | 0.7198 |
| | hp tuned | 0.8553 | 0.7749 | 0.9781 | 0.7783 | 0.8627 | 0.7442 |
| Accuracy | untuned | 0.7479 | | | | | |
| | hp tuned | 0.7749 | | | | | |
| Balanced accuracy | untuned | 0.8936 | | | | | |
| | hp tuned | 0.9068 | | | | | |
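Hyperparameter tuning does not help every cell type equally. For example, the f1-score for endothelial cell of lymphatic vessel drops after tuning. The sketch below re-enters a few f1-score values from the table above and computes the per-class change. These values were transcribed by hand and are not recomputed from the models. The subset of cell types is an arbitrary choice for illustration, and the sketch needs only pandas.

```python
import pandas as pd

# f1-score values transcribed by hand from the untuned vs hp tuned table above
# (an illustrative subset of cell types, not the full report)
f1 = pd.DataFrame(
    {
        "untuned": [0.1232, 0.9009, 0.6641, 0.7368, 0.5714],
        "hp tuned": [0.4197, 0.7123, 0.7085, 0.8492, 0.6758],
    },
    index=[
        "mature alpha-beta T cell",
        "endothelial cell of lymphatic vessel",
        "epithelial cell",
        "natural killer cell",
        "regulatory T cell",
    ],
)

# Positive delta: tuning improved the f1-score; negative: tuning made it worse
f1["delta"] = (f1["hp tuned"] - f1["untuned"]).round(4)
print(f1.sort_values("delta"))
```

Sorting by `delta` puts regressions first, so they are easy to spot even in a longer report.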
[29]:
# Compare prediction results
# 'warm start' rows: model trained with warm start
# 'hp tuned' rows: model trained with tuned hyperparameters
df_warm_start.compare(df_hp_tuned, keep_equal=True, align_axis=0, result_names=('warm start', 'hp tuned'))
[29]:
| Cell type | Model | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy |
|---|---|---|---|---|---|---|---|
| B cell | warm start | 0.9738 | 0.978 | 0.9967 | 0.9759 | 0.9873 | 0.973 |
| | hp tuned | 0.7263 | 0.9863 | 0.9531 | 0.8366 | 0.9696 | 0.9432 |
| CD4-positive, alpha-beta T cell | warm start | 0.8644 | 0.8411 | 0.9878 | 0.8526 | 0.9115 | 0.8186 |
| | hp tuned | 0.4301 | 0.8616 | 0.8941 | 0.5738 | 0.8777 | 0.7679 |
| CD8-positive, alpha-beta T cell | warm start | 0.8985 | 0.9248 | 0.989 | 0.9115 | 0.9564 | 0.9088 |
| | hp tuned | 0.6414 | 0.8852 | 0.9479 | 0.7438 | 0.916 | 0.8338 |
| endothelial cell | warm start | 0.9868 | 0.9803 | 0.9997 | 0.9835 | 0.9899 | 0.9781 |
| | hp tuned | 0.9934 | 0.9891 | 0.9998 | 0.9912 | 0.9944 | 0.9878 |
| endothelial cell of lymphatic vessel | warm start | 0.9592 | 0.9038 | 0.9999 | 0.9307 | 0.9507 | 0.8951 |
| | hp tuned | 0.5532 | 1.0 | 0.9977 | 0.7123 | 0.9988 | 0.9979 |
| epithelial cell | warm start | 0.9879 | 0.9824 | 0.9917 | 0.9852 | 0.987 | 0.9733 |
| | hp tuned | 0.9897 | 0.5518 | 0.996 | 0.7085 | 0.7413 | 0.5252 |
| fibroblast | warm start | 0.9585 | 0.9896 | 0.9984 | 0.9738 | 0.994 | 0.9872 |
| | hp tuned | 0.9585 | 0.9911 | 0.9984 | 0.9745 | 0.9947 | 0.9888 |
| mast cell | warm start | 0.9931 | 0.973 | 0.9999 | 0.9829 | 0.9864 | 0.9703 |
| | hp tuned | 0.9793 | 0.9595 | 0.9998 | 0.9693 | 0.9794 | 0.9554 |
| mature alpha-beta T cell | warm start | 0.6036 | 0.8933 | 0.9976 | 0.7204 | 0.944 | 0.8819 |
| | hp tuned | 0.2783 | 0.8533 | 0.9909 | 0.4197 | 0.9195 | 0.8339 |
| mural cell | warm start | 0.9761 | 0.9336 | 0.9994 | 0.9544 | 0.966 | 0.927 |
| | hp tuned | 0.9779 | 0.9108 | 0.9995 | 0.9431 | 0.9541 | 0.9022 |
| myeloid cell | warm start | 0.9832 | 0.9656 | 0.9993 | 0.9743 | 0.9823 | 0.9616 |
| | hp tuned | 0.9283 | 0.9628 | 0.9969 | 0.9452 | 0.9797 | 0.9566 |
| natural killer cell | warm start | 0.8991 | 0.9193 | 0.9987 | 0.9091 | 0.9582 | 0.9108 |
| | hp tuned | 0.7616 | 0.9596 | 0.9963 | 0.8492 | 0.9778 | 0.9526 |
| plasma cell | warm start | 0.9830 | 0.9917 | 0.998 | 0.9873 | 0.9948 | 0.989 |
| | hp tuned | 0.9881 | 0.9922 | 0.9986 | 0.9901 | 0.9954 | 0.9902 |
| plasmacytoid dendritic cell | warm start | 0.9074 | 0.98 | 0.9997 | 0.9423 | 0.9898 | 0.9778 |
| | hp tuned | 0.9608 | 0.98 | 0.9999 | 0.9703 | 0.9899 | 0.9779 |
| regulatory T cell | warm start | 0.7766 | 0.7551 | 0.9915 | 0.7657 | 0.8653 | 0.731 |
| | hp tuned | 0.6378 | 0.7187 | 0.9841 | 0.6758 | 0.841 | 0.6885 |
| macro avg | warm start | 0.9167 | 0.9341 | 0.9965 | 0.9233 | 0.9642 | 0.9256 |
| | hp tuned | 0.7870 | 0.9068 | 0.9835 | 0.8202 | 0.942 | 0.8868 |
| weighted avg | warm start | 0.9544 | 0.9538 | 0.9935 | 0.9539 | 0.973 | 0.9443 |
| | hp tuned | 0.8553 | 0.7749 | 0.9781 | 0.7783 | 0.8627 | 0.7442 |
| Accuracy | warm start | 0.9538 | | | | | |
| | hp tuned | 0.7749 | | | | | |
| Balanced accuracy | warm start | 0.9341 | | | | | |
| | hp tuned | 0.9068 | | | | | |
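The interleaved layout of these comparison tables comes from pandas `DataFrame.compare` with `align_axis=0`, which stacks the two frames so that each index label gets one row per input. By default (`keep_shape=False`), `compare` keeps only the rows and columns that contain at least one difference. A cell type whose metrics were identical for both models would therefore silently disappear from the comparison. The toy frames below use made-up values, not results from this notebook, to show both behaviors. They assume pandas 1.5 or newer, where `result_names` is available.

```python
import pandas as pd

# Toy per-class reports with identical index and columns (made-up values)
left = pd.DataFrame(
    {"precision": [0.97, 0.86], "recall": [0.98, 0.84]},
    index=["B cell", "T cell"],
)
right = pd.DataFrame(
    {"precision": [0.73, 0.86], "recall": [0.99, 0.84]},
    index=["B cell", "T cell"],
)

# align_axis=0 yields one row per (index label, result name) pair;
# keep_equal=True prints equal values instead of NaN.
# "T cell" is identical in both frames, so it is dropped here.
diff = left.compare(right, keep_equal=True, align_axis=0,
                    result_names=("left", "right"))
print(diff)

# keep_shape=True keeps every row and column, including "T cell".
full = left.compare(right, keep_equal=True, keep_shape=True, align_axis=0,
                    result_names=("left", "right"))
print(full)
```

Every cell type in the tables above has at least one differing metric, so no rows were dropped there. Passing `keep_shape=True` guarantees the same behavior for other datasets.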
Recommendation
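We recommend warm start training over hyperparameter tuning whenever an additional training dataset is available. In this example, the warm start model outperformed the hyperparameter-tuned model on accuracy (0.9538 vs 0.7749), balanced accuracy (0.9341 vs 0.9068), and every macro- and weighted-average metric. Without an additional training dataset, hyperparameter tuning can still improve performance over the untuned model, although not every cell type benefits.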
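To put the recommendation in numbers, the sketch below re-enters the overall scores from the comparison tables, transcribed by hand rather than recomputed, and subtracts the untuned baseline. It assumes that all three models were evaluated on the same test dataset, which the identical 'hp tuned' rows in both tables suggest.

```python
import pandas as pd

# Overall scores transcribed by hand from the comparison tables
# (assumes all three models were scored on the same test dataset)
summary = pd.DataFrame(
    {
        "Accuracy": [0.7479, 0.7749, 0.9538],
        "Balanced accuracy": [0.8936, 0.9068, 0.9341],
        "macro avg f1-score": [0.7833, 0.8202, 0.9233],
    },
    index=["untuned", "hp tuned", "warm start"],
)

# Gain of each model over the untuned baseline
# (DataFrame minus Series aligns on column names)
gain = (summary - summary.loc["untuned"]).round(4)
print(gain)
```

Relative to the untuned model, warm start gains about 0.21 in accuracy, while hyperparameter tuning gains about 0.03.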
[30]:
import session_info
session_info.show()
[30]:
----- anndata 0.10.8 matplotlib 3.9.2 numpy 1.25.0 pandas 2.2.3 scanpy 1.10.3 scparadise 0.3.1_beta session_info 1.0.0 -----
Modules imported as dependencies:
PIL 10.4.0 anyio NA arrow 1.3.0 asciitree NA asttokens NA attr 24.2.0 attrs 24.2.0 awkward 2.7.1 awkward_cpp NA babel 2.16.0 backports NA certifi 2024.08.30 cffi 1.17.1 charset_normalizer 3.3.2 cloudpickle 3.1.0 colorlog NA comm 0.2.2 cycler 0.12.1 cython_runtime NA dask 2024.8.0 dateutil 2.9.0.post0 debugpy 1.8.6 decorator 5.1.1 defusedxml 0.7.1 exceptiongroup 1.2.2 executing 2.1.0 fastjsonschema NA fqdn NA fsspec 2023.6.0 h5py 3.12.1 idna 3.10 igraph 0.11.6 imblearn 0.12.3 importlib_metadata NA importlib_resources NA ipykernel 6.29.5 isoduration NA jaraco NA jedi 0.19.1 jinja2 3.1.4 joblib 1.4.2 json5 0.9.25 jsonpointer 3.0.0 jsonschema 4.23.0 jsonschema_specifications NA jupyter_events 0.10.0 jupyter_server 2.14.2 jupyterlab_server 2.27.3 kiwisolver 1.4.7 legacy_api_wrap NA leidenalg 0.10.2 llvmlite 0.43.0 markupsafe 2.1.5 matplotlib_inline 0.1.7 more_itertools 10.5.0 mpl_toolkits NA mpmath 1.3.0 msgpack 1.1.0 mudata 0.2.4 muon 0.1.6 natsort 8.4.0 nbformat 5.10.4 numba 0.60.0 numcodecs 0.12.1 optuna 4.0.0 overrides NA packaging 24.1 parso 0.8.4 patsy 0.5.6 pexpect 4.9.0 platformdirs 4.3.6 plotly 5.24.1 prometheus_client NA prompt_toolkit 3.0.48 psutil 6.0.0 ptyprocess 0.7.0 pure_eval 0.2.3 pyarrow 18.1.0 pycparser 2.22 pydev_ipython NA pydevconsole NA pydevd 3.1.0 pydevd_file_utils NA pydevd_plugins NA pydevd_tracing NA pydot 3.0.3 pygments 2.18.0 pynndescent 0.5.13 pyparsing 3.1.4 pythonjsonlogger NA pytorch_tabnet NA pytz 2024.2 referencing NA requests 2.32.3 rfc3339_validator 0.1.4 rfc3986_validator 0.1.1 rich NA rpds NA scipy 1.13.1 seaborn 0.13.2 send2trash NA setuptools 75.1.0 setuptools_scm NA shap 0.46.0 six 1.16.0 sklearn 1.5.2 skmisc 0.3.1 slicer NA sniffio 1.3.1 stack_data 0.6.3 statsmodels 0.14.3 sympy 1.13.3 tblib 3.0.0 texttable 1.7.0 threadpoolctl 3.5.0 tlz 0.12.1 tomli 2.0.1 toolz 0.12.1 torch 2.4.1+cu121 torchgen NA tornado 6.4.1 tqdm 4.66.5 traitlets 5.14.3 triton 3.0.0 typing_extensions NA umap 0.5.6 uri_template NA urllib3 1.26.20 wcwidth 0.2.13 webcolors 24.8.0 websocket 1.8.0 yaml 6.0.2 zarr 2.18.2 zipp NA zmq 26.2.0 zoneinfo NA
----- IPython 8.18.1 jupyter_client 8.6.3 jupyter_core 5.7.2 jupyterlab 4.2.5 ----- Python 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0] Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35 ----- Session information updated at 2025-01-15 23:33