scAdam model training

[1]:
# Python packages
import warnings
warnings.simplefilter('ignore')

import scanpy as sc
import scparadise
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sc.set_figure_params(dpi = 120)

Recommendations for the training dataset

We recommend the shifted logarithm normalization method:

sc.pp.normalize_total(adata, target_sum=None)
sc.pp.log1p(adata)

Other normalization methods can be used as well, but apply the same method to the test dataset. The training dataset should contain all genes that you want to use for model training in adata_train.X. We recommend removing all non-marker genes from adata_train.X, because removing such uninformative genes improves performance and model quality metrics.
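To make the recommended normalization concrete, the NumPy sketch below re-implements what the two scanpy calls compute: each cell is rescaled to the median total count of the non-empty cells (the meaning of target_sum=None), then log1p is applied. It is a simplified illustration assuming a dense cells × genes count matrix; in practice, use the scanpy calls themselves.

```python
import numpy as np

def shifted_log_normalize(counts):
    """Approximate sc.pp.normalize_total(target_sum=None) followed by sc.pp.log1p."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)       # total counts per cell
    target = np.median(totals[totals > 0])           # target_sum=None: median total of non-empty cells
    # Scale every non-empty cell to the median total; empty cells stay all-zero
    scale = np.divide(target, totals, out=np.zeros_like(totals), where=totals > 0)
    return np.log1p(counts * scale)                  # natural-log shifted logarithm

# Toy matrix: 3 cells x 3 genes with totals 4, 8 and 10 (median = 8)
counts = np.array([[1, 3, 0],
                   [2, 2, 4],
                   [0, 5, 5]])
X = shifted_log_normalize(counts)
print(np.expm1(X).sum(axis=1))  # every cell now sums to ~8
```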

Data preparation

Here we download and preprocess a dataset from CELLxGENE:

Single cell RNA sequencing of oropharyngeal squamous cell carcinoma https://cellxgene.cziscience.com/collections/3c34e6f1-6827-47dd-8e19-9edcd461893f

[2]:
!wget https://datasets.cellxgene.cziscience.com/915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad
--2025-01-16 12:42:52--  https://datasets.cellxgene.cziscience.com/915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad
Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 52.85.49.125, 52.85.49.28, 52.85.49.24, ...
Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|52.85.49.125|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 981941987 (936M) [binary/octet-stream]
Saving to: ‘915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad’

915069db-1df2-49a1- 100%[===================>] 936.45M  4.91MB/s    in 2m 16s

2025-01-16 12:45:08 (6.89 MB/s) - ‘915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad’ saved [981941987/981941987]

[3]:
# Load the downloaded AnnData object
adata = sc.read_h5ad('915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad')
[4]:
# Get raw counts from adata.raw
adata = adata.raw.to_adata()
[5]:
# Convert var_names from ENSG codes to gene names
adata.var.set_index('feature_name', inplace=True)
adata.var_names_make_unique()
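Several Ensembl IDs can map to the same gene symbol, so using feature_name as the index can create duplicate var_names. var_names_make_unique resolves this by appending a numeric suffix (-1, -2, ...) to repeated names. A simplified pure-Python sketch of the idea; make_unique is a hypothetical helper, not anndata's implementation, which additionally avoids suffixes that collide with names already present:

```python
def make_unique(names, sep="-"):
    """Simplified illustration: suffix repeated names with -1, -2, ... in order.

    The first occurrence keeps its name. Unlike anndata, this sketch does not
    check whether a generated name (e.g. "GENE_A-1") already exists.
    """
    seen = {}
    unique = []
    for name in names:
        if name in seen:
            seen[name] += 1
            unique.append(f"{name}{sep}{seen[name]}")
        else:
            seen[name] = 0
            unique.append(name)
    return unique

print(make_unique(["GENE_A", "GENE_B", "GENE_A", "GENE_A"]))
# ['GENE_A', 'GENE_B', 'GENE_A-1', 'GENE_A-2']
```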
[6]:
# Create celltype_l1 and celltype_l2 annotation levels
adata.obs['celltype_l2'] = adata.obs['cell_type'].copy()
# Get current cluster names
cluster_list = adata.obs['celltype_l2'].unique()
# Make cluster annotation dictionary
annotation = {"epithelial":['epithelial cell'],
              "B":['B cell', 'plasma cell'],
              "T":['CD4-positive, alpha-beta T cell', 'CD8-positive, alpha-beta T cell', 'mature alpha-beta T cell', 'regulatory T cell'],
              "NK":['natural killer cell'],
              "myeloid": ['myeloid cell', 'plasmacytoid dendritic cell', 'mast cell'],
              "endothelial":['endothelial cell', 'endothelial cell of lymphatic vessel'],
              "stromal": ['fibroblast', 'mural cell']}

# Change dictionary format
annotation_rev = {}
for i in cluster_list:
    for k in annotation:
        if i in annotation[k]:
            annotation_rev[i] = k

adata.obs["celltype_l1"] = [annotation_rev[i] for i in adata.obs['celltype_l2']]
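The loop above inverts the annotation dictionary (celltype_l1 to a list of celltype_l2 labels) into a celltype_l2 to celltype_l1 lookup. If a label in cell_type is missing from the dictionary, the list comprehension fails with a bare KeyError. A slightly more defensive pandas sketch, using an abbreviated dictionary and toy labels for illustration:

```python
import pandas as pd

# Abbreviated coarse (celltype_l1) -> fine (celltype_l2) dictionary
annotation = {
    "B": ["B cell", "plasma cell"],
    "NK": ["natural killer cell"],
    "stromal": ["fibroblast", "mural cell"],
}

# Invert to fine -> coarse, refusing a fine label listed under two coarse labels
annotation_rev = {}
for coarse, fine_labels in annotation.items():
    for fine in fine_labels:
        if fine in annotation_rev:
            raise ValueError(f"{fine!r} is assigned to more than one coarse label")
        annotation_rev[fine] = coarse

celltype_l2 = pd.Series(["B cell", "fibroblast", "plasma cell", "natural killer cell"])
celltype_l1 = celltype_l2.map(annotation_rev)    # unmapped labels become NaN

# Report every unmapped fine label at once instead of failing on the first one
unmapped = celltype_l2[celltype_l1.isna()].unique()
if len(unmapped):
    raise ValueError(f"No celltype_l1 assigned for: {list(unmapped)}")
print(celltype_l1.tolist())  # ['B', 'stromal', 'B', 'NK']
```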
[7]:
# Check annotations
sc.pl.embedding(adata,
                basis = 'X_umap',
                color = ['celltype_l1',
                         'celltype_l2'],
                frameon = False,
                ncols = 1)
[Figure: UMAP embedding colored by celltype_l1 and celltype_l2]
[8]:
# Create adata_train and adata_test datasets (split by donor)
adata_test = adata[adata.obs['donor_id'].isin(['HN481', 'HN482'])].copy()
adata_train = adata[adata.obs['donor_id'].isin(['HN483', 'HN485', 'HN488', 'HN489', 'HN490', 'HN487', 'HN492', 'HN494'])].copy()
del adata
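Python joins adjacent string literals at compile time, so a missing comma in a donor list silently merges two IDs into one that matches no cells, and those donors are dropped without any error. A small guard can catch this. check_split below is a hypothetical helper; in the notebook it would be called with adata.obs['donor_id'].unique() before del adata.

```python
# Adjacent string literals are concatenated: the first two IDs become one
buggy = ['HN483' 'HN485', 'HN488']
print(buggy)  # ['HN483HN485', 'HN488']

def check_split(all_donors, train_donors, test_donors):
    """Hypothetical helper: fail loudly on unknown or overlapping donor IDs."""
    all_donors = set(all_donors)
    unknown = (set(train_donors) | set(test_donors)) - all_donors
    overlap = set(train_donors) & set(test_donors)
    if unknown:
        raise ValueError(f"Unknown donor IDs: {sorted(unknown)}")
    if overlap:
        raise ValueError(f"Donors in both splits: {sorted(overlap)}")

donors = ['HN481', 'HN483', 'HN485', 'HN488']                  # toy donor list
check_split(donors, ['HN483', 'HN485', 'HN488'], ['HN481'])    # passes silently
```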
[9]:
# Normalize data (shifted logarithm) and select highly variable genes
for ad in [adata_train, adata_test]:
    ad.layers['counts'] = ad.X.copy()     # keep raw counts: flavor='seurat_v3' expects count data
    sc.pp.normalize_total(ad, target_sum=None)
    sc.pp.log1p(ad)
    ad.raw = ad                           # store all log-normalized genes before subsetting
    sc.pp.highly_variable_genes(ad,
                                layer='counts',
                                flavor='seurat_v3',
                                n_top_genes=1200,
                                subset=True)
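Because highly variable genes are selected separately for adata_train and adata_test, the two gene sets will generally differ, while adata_test.raw still holds every gene (raw was set before subsetting). Before predicting, it may be worth checking how many training genes are available in the test data, for example with gene_overlap(adata_train.var_names, adata_test.raw.var_names). gene_overlap is a hypothetical helper, not part of scParadise; it is shown here on toy gene lists:

```python
import numpy as np

def gene_overlap(train_genes, test_genes):
    """Hypothetical helper: fraction of training genes present in the test data,
    plus the training genes that are missing."""
    train_genes = np.asarray(train_genes, dtype=object)
    present = np.isin(train_genes, np.asarray(test_genes, dtype=object))
    return float(present.mean()), train_genes[~present].tolist()

frac, missing = gene_overlap(["CD3E", "MS4A1", "KRT5", "COL1A1"],
                             ["CD3E", "KRT5", "PTPRC"])
print(frac, missing)  # 0.5 ['MS4A1', 'COL1A1']
```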

Balanced training

We recommend balancing the dataset based on the most detailed annotation level (celltype_l2 in this case). Balancing the training dataset increases the sensitivity (balanced accuracy), F1-score and geometric mean of the model, but leads to a slight decrease in precision.
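scparadise.scnoah.balance takes care of this (and also receives the sample and both annotation levels). To make "undersampled" and "oversampled" concrete, here is a generic per-label resampling sketch. balance_labels and its target parameter are hypothetical, and this is not the scNoah algorithm: abundant cell types are sampled down without replacement, rare ones are sampled up with replacement.

```python
import numpy as np
import pandas as pd

def balance_labels(labels, target, seed=0):
    """Hypothetical helper: return row positions so each label occurs `target` times.

    Labels with at least `target` cells are undersampled without replacement;
    rarer labels are oversampled with replacement (cells are duplicated).
    """
    rng = np.random.default_rng(seed)
    labels = pd.Series(labels).reset_index(drop=True)
    picked = []
    # .indices maps each observed label to the integer positions of its cells
    for positions in labels.groupby(labels, observed=True).indices.values():
        replace = len(positions) < target
        picked.append(rng.choice(positions, size=target, replace=replace))
    return np.concatenate(picked)

labels = ["epithelial cell"] * 10 + ["mast cell"] * 2
idx = balance_labels(labels, target=5)
print(pd.Series(np.asarray(labels)[idx]).value_counts())
```

Subsetting adata_train with such positions would duplicate the oversampled cells, so their obs names would need to be made unique afterwards.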

[10]:
# Balance dataset based on most detailed annotation level
adata_balanced = scparadise.scnoah.balance(adata_train,
                                           sample = 'donor_id',
                                           celltype_l1 = 'celltype_l1',
                                           celltype_l2 = 'celltype_l2')
Successfully undersampled cell types: epithelial cell, B cell, myeloid cell, CD4-positive, alpha-beta T cell, CD8-positive, alpha-beta T cell, plasma cell

Successfully oversampled cell types: regulatory T cell, fibroblast, mural cell, endothelial cell, natural killer cell, plasmacytoid dendritic cell, mast cell, mature alpha-beta T cell, endothelial cell of lymphatic vessel
[11]:
# Train scadam model using adata_balanced dataset
scparadise.scadam.train(adata_balanced,
                        path = '', # path to save model
                        model_name = 'model_scadam_balanced', # folder name with model
                        celltype_l1 = 'celltype_l1', # First (less detailed) annotation level
                        celltype_l2 = 'celltype_l2', # Second (most detailed) annotation level
                        eval_metric = ['balanced_accuracy','accuracy'])
Successfully saved genes names for training model

Successfully saved dictionary of dataset annotations

Train dataset contains: 44415 cells, it is 90.0 % of input dataset
Test dataset contains: 4935 cells, it is 10.0 % of input dataset

Accelerator: cuda
Start training
epoch 0  | loss: 2.21299 | train_balanced_accuracy: 0.49625 | train_accuracy: 0.55639 | valid_balanced_accuracy: 0.48656 | valid_accuracy: 0.5461  |  0:00:02s
epoch 1  | loss: 1.14552 | train_balanced_accuracy: 0.72033 | train_accuracy: 0.74447 | valid_balanced_accuracy: 0.70999 | valid_accuracy: 0.73587 |  0:00:05s
epoch 2  | loss: 0.74196 | train_balanced_accuracy: 0.84973 | train_accuracy: 0.8624  | valid_balanced_accuracy: 0.83588 | valid_accuracy: 0.84985 |  0:00:08s
epoch 3  | loss: 0.55414 | train_balanced_accuracy: 0.90071 | train_accuracy: 0.91104 | valid_balanced_accuracy: 0.89921 | valid_accuracy: 0.90942 |  0:00:11s
epoch 4  | loss: 0.46574 | train_balanced_accuracy: 0.92385 | train_accuracy: 0.93116 | valid_balanced_accuracy: 0.91397 | valid_accuracy: 0.92158 |  0:00:14s
epoch 5  | loss: 0.42424 | train_balanced_accuracy: 0.93255 | train_accuracy: 0.94035 | valid_balanced_accuracy: 0.92515 | valid_accuracy: 0.93333 |  0:00:17s
epoch 6  | loss: 0.38279 | train_balanced_accuracy: 0.95089 | train_accuracy: 0.95306 | valid_balanced_accuracy: 0.94334 | valid_accuracy: 0.94498 |  0:00:19s
epoch 7  | loss: 0.365   | train_balanced_accuracy: 0.94994 | train_accuracy: 0.95516 | valid_balanced_accuracy: 0.94515 | valid_accuracy: 0.95086 |  0:00:22s
epoch 8  | loss: 0.33252 | train_balanced_accuracy: 0.95308 | train_accuracy: 0.95708 | valid_balanced_accuracy: 0.94601 | valid_accuracy: 0.95035 |  0:00:25s
epoch 9  | loss: 0.31807 | train_balanced_accuracy: 0.96494 | train_accuracy: 0.96749 | valid_balanced_accuracy: 0.9574  | valid_accuracy: 0.96039 |  0:00:28s
epoch 10 | loss: 0.30404 | train_balanced_accuracy: 0.97066 | train_accuracy: 0.97151 | valid_balanced_accuracy: 0.96183 | valid_accuracy: 0.96353 |  0:00:31s
epoch 11 | loss: 0.29295 | train_balanced_accuracy: 0.96804 | train_accuracy: 0.97109 | valid_balanced_accuracy: 0.96141 | valid_accuracy: 0.96464 |  0:00:34s
epoch 12 | loss: 0.29203 | train_balanced_accuracy: 0.97248 | train_accuracy: 0.97422 | valid_balanced_accuracy: 0.96377 | valid_accuracy: 0.96616 |  0:00:37s
epoch 13 | loss: 0.27764 | train_balanced_accuracy: 0.97068 | train_accuracy: 0.97313 | valid_balanced_accuracy: 0.96204 | valid_accuracy: 0.96525 |  0:00:40s
epoch 14 | loss: 0.26716 | train_balanced_accuracy: 0.97676 | train_accuracy: 0.97791 | valid_balanced_accuracy: 0.96719 | valid_accuracy: 0.969   |  0:00:43s
epoch 15 | loss: 0.2658  | train_balanced_accuracy: 0.97736 | train_accuracy: 0.9776  | valid_balanced_accuracy: 0.96948 | valid_accuracy: 0.96971 |  0:00:45s
epoch 16 | loss: 0.26173 | train_balanced_accuracy: 0.97664 | train_accuracy: 0.9786  | valid_balanced_accuracy: 0.96713 | valid_accuracy: 0.9696  |  0:00:48s
epoch 17 | loss: 0.25269 | train_balanced_accuracy: 0.97566 | train_accuracy: 0.97638 | valid_balanced_accuracy: 0.96591 | valid_accuracy: 0.96717 |  0:00:51s
epoch 18 | loss: 0.25638 | train_balanced_accuracy: 0.98054 | train_accuracy: 0.98126 | valid_balanced_accuracy: 0.96883 | valid_accuracy: 0.97052 |  0:00:54s
epoch 19 | loss: 0.24683 | train_balanced_accuracy: 0.98036 | train_accuracy: 0.98016 | valid_balanced_accuracy: 0.97149 | valid_accuracy: 0.97153 |  0:00:56s
epoch 20 | loss: 0.24669 | train_balanced_accuracy: 0.98134 | train_accuracy: 0.98176 | valid_balanced_accuracy: 0.97291 | valid_accuracy: 0.97356 |  0:00:59s
epoch 21 | loss: 0.2413  | train_balanced_accuracy: 0.97987 | train_accuracy: 0.98138 | valid_balanced_accuracy: 0.96868 | valid_accuracy: 0.97143 |  0:01:02s
epoch 22 | loss: 0.23411 | train_balanced_accuracy: 0.9823  | train_accuracy: 0.98285 | valid_balanced_accuracy: 0.96949 | valid_accuracy: 0.97143 |  0:01:05s
epoch 23 | loss: 0.23256 | train_balanced_accuracy: 0.98111 | train_accuracy: 0.98152 | valid_balanced_accuracy: 0.96911 | valid_accuracy: 0.97031 |  0:01:08s
epoch 24 | loss: 0.23049 | train_balanced_accuracy: 0.97992 | train_accuracy: 0.98013 | valid_balanced_accuracy: 0.96913 | valid_accuracy: 0.97011 |  0:01:11s
epoch 25 | loss: 0.22792 | train_balanced_accuracy: 0.98018 | train_accuracy: 0.98138 | valid_balanced_accuracy: 0.96714 | valid_accuracy: 0.97042 |  0:01:14s
epoch 26 | loss: 0.22866 | train_balanced_accuracy: 0.98175 | train_accuracy: 0.98292 | valid_balanced_accuracy: 0.96839 | valid_accuracy: 0.97123 |  0:01:16s
epoch 27 | loss: 0.22948 | train_balanced_accuracy: 0.98347 | train_accuracy: 0.98354 | valid_balanced_accuracy: 0.97165 | valid_accuracy: 0.97214 |  0:01:19s
epoch 28 | loss: 0.23231 | train_balanced_accuracy: 0.98406 | train_accuracy: 0.98454 | valid_balanced_accuracy: 0.97292 | valid_accuracy: 0.97457 |  0:01:22s
epoch 29 | loss: 0.2214  | train_balanced_accuracy: 0.98382 | train_accuracy: 0.98372 | valid_balanced_accuracy: 0.97213 | valid_accuracy: 0.97244 |  0:01:25s
epoch 30 | loss: 0.21979 | train_balanced_accuracy: 0.98496 | train_accuracy: 0.9847  | valid_balanced_accuracy: 0.97303 | valid_accuracy: 0.97345 |  0:01:27s
epoch 31 | loss: 0.21397 | train_balanced_accuracy: 0.98468 | train_accuracy: 0.98523 | valid_balanced_accuracy: 0.97364 | valid_accuracy: 0.97538 |  0:01:30s
epoch 32 | loss: 0.2146  | train_balanced_accuracy: 0.98435 | train_accuracy: 0.98476 | valid_balanced_accuracy: 0.97197 | valid_accuracy: 0.97386 |  0:01:33s
epoch 33 | loss: 0.21473 | train_balanced_accuracy: 0.98328 | train_accuracy: 0.9838  | valid_balanced_accuracy: 0.96836 | valid_accuracy: 0.97042 |  0:01:36s
epoch 34 | loss: 0.21234 | train_balanced_accuracy: 0.98507 | train_accuracy: 0.98472 | valid_balanced_accuracy: 0.97059 | valid_accuracy: 0.97102 |  0:01:39s
epoch 35 | loss: 0.21444 | train_balanced_accuracy: 0.98667 | train_accuracy: 0.98633 | valid_balanced_accuracy: 0.97435 | valid_accuracy: 0.97437 |  0:01:42s
epoch 36 | loss: 0.21378 | train_balanced_accuracy: 0.98626 | train_accuracy: 0.98602 | valid_balanced_accuracy: 0.97264 | valid_accuracy: 0.97335 |  0:01:44s
epoch 37 | loss: 0.21579 | train_balanced_accuracy: 0.98637 | train_accuracy: 0.98651 | valid_balanced_accuracy: 0.97226 | valid_accuracy: 0.97366 |  0:01:47s
epoch 38 | loss: 0.21361 | train_balanced_accuracy: 0.98471 | train_accuracy: 0.98512 | valid_balanced_accuracy: 0.97023 | valid_accuracy: 0.97275 |  0:01:50s
epoch 39 | loss: 0.20431 | train_balanced_accuracy: 0.98699 | train_accuracy: 0.98694 | valid_balanced_accuracy: 0.9727  | valid_accuracy: 0.97406 |  0:01:53s
epoch 40 | loss: 0.20723 | train_balanced_accuracy: 0.98677 | train_accuracy: 0.98683 | valid_balanced_accuracy: 0.97398 | valid_accuracy: 0.97487 |  0:01:55s
epoch 41 | loss: 0.20137 | train_balanced_accuracy: 0.98697 | train_accuracy: 0.98698 | valid_balanced_accuracy: 0.97357 | valid_accuracy: 0.97447 |  0:01:58s

Early stopping occurred at epoch 41 with best_epoch = 31 and best_valid_accuracy = 0.97538

Successfully saved training history and parameters
Successfully saved model at model_scadam_balanced/model.zip
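Training stopped at epoch 41 because the monitored validation metric had not improved for 10 epochs after its best value at epoch 31. A minimal sketch of this patience rule, assuming a patience of 10 epochs (inferred from the log, not read from scAdam's defaults) and strict improvement of a metric that is maximized. Applied to the valid_accuracy values printed above, it reproduces best_epoch = 31 and a stop at epoch 41.

```python
def early_stop(valid_scores, patience=10):
    """Return (best_epoch, stop_epoch) for a metric that should be maximized.

    Training stops once `patience` consecutive epochs pass without a strict
    improvement over the best score seen so far.
    """
    best_score, best_epoch, wait = float("-inf"), 0, 0
    for epoch, score in enumerate(valid_scores):
        if score > best_score:
            best_score, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, len(valid_scores) - 1   # ran out of epochs without stopping

# valid_accuracy per epoch, copied from the training log above
valid_accuracy = [
    0.5461, 0.73587, 0.84985, 0.90942, 0.92158, 0.93333, 0.94498, 0.95086,
    0.95035, 0.96039, 0.96353, 0.96464, 0.96616, 0.96525, 0.969, 0.96971,
    0.9696, 0.96717, 0.97052, 0.97153, 0.97356, 0.97143, 0.97143, 0.97031,
    0.97011, 0.97042, 0.97123, 0.97214, 0.97457, 0.97244, 0.97345, 0.97538,
    0.97386, 0.97042, 0.97102, 0.97437, 0.97335, 0.97366, 0.97275, 0.97406,
    0.97487, 0.97447,
]
print(early_stop(valid_accuracy))  # (31, 41)
```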

Check model quality

[13]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
                                       path_model = 'model_scadam_balanced')
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities
[14]:
# Add prediction status to adata_test
scparadise.scnoah.pred_status(adata_test,
                              celltype = 'celltype_l1',
                              pred_celltype = 'pred_celltype_l1',
                              key_added = 'pred_status_l1')
scparadise.scnoah.pred_status(adata_test,
                              celltype = 'celltype_l2',
                              pred_celltype = 'pred_celltype_l2',
                              key_added = 'pred_status_l2')
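pred_status appears to record, for each cell, whether the predicted label matches the observed one. Conceptually this is an element-wise comparison of two obs columns. The sketch below uses hypothetical "correct"/"incorrect" labels on toy data; the strings scNoah actually writes may differ.

```python
import numpy as np
import pandas as pd

# Toy observed and predicted labels (illustrative only)
obs = pd.DataFrame({
    "celltype_l2": ["B cell", "regulatory T cell", "fibroblast"],
    "pred_celltype_l2": ["B cell", "CD4-positive, alpha-beta T cell", "fibroblast"],
})

# Cast to str: pandas refuses to compare categoricals with different categories
matches = obs["celltype_l2"].astype(str) == obs["pred_celltype_l2"].astype(str)
obs["pred_status_l2"] = np.where(matches, "correct", "incorrect")   # hypothetical labels
print(obs["pred_status_l2"].tolist())  # ['correct', 'incorrect', 'correct']
```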
[15]:
# Give observed and predicted labels the same ordered categories so cell type colors match
celltype_list = ['celltype_l1','celltype_l2']
for i in celltype_list:
    celltype = np.unique(adata_test.obs[i]).tolist()
    adata_test.obs[i] = pd.Categorical(
        values=adata_test.obs[i], categories=celltype, ordered=True
    )
    adata_test.obs['pred_' + i] = pd.Categorical(
        values=adata_test.obs['pred_' + i], categories=celltype, ordered=True
    )
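Assigning the same ordered category list to the observed and predicted columns gives each cell type the same category code in both, so the default palette colors it identically in both panels. One side effect worth knowing: a predicted label that does not occur among the observed labels is not in the category list and becomes NaN. A toy pandas illustration:

```python
import pandas as pd

observed = pd.Series(["T", "B", "NK", "B"])
predicted = pd.Series(["T", "B", "B", "myeloid"])

# Shared, sorted category list taken from the observed labels
categories = sorted(observed.unique())   # ['B', 'NK', 'T']
obs_cat = pd.Categorical(observed, categories=categories, ordered=True)
pred_cat = pd.Categorical(predicted, categories=categories, ordered=True)

# Same categories in the same order give the same codes, hence the same colors;
# code -1 marks "myeloid", which is not a category and is stored as NaN
print(pred_cat.codes.tolist())  # [2, 0, 0, -1]
```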
In the plots below:

  1. The left column shows observed cell type annotations.

  2. The central column shows predicted cell type annotations.

  3. The right column shows prediction probabilities.

[16]:
# Visualise observed and predicted cell type levels and prediction probabilities
sc.pl.embedding(adata_test,
                color=[
                    'celltype_l1',
                    'pred_celltype_l1',
                    'prob_celltype_l1',
                    'celltype_l2',
                    'pred_celltype_l2',
                    'prob_celltype_l2'
                ],
                basis = 'X_umap',
                frameon = False,
                cmap = 'viridis',
                legend_loc = 'on data',
                legend_fontsize = 7,
                legend_fontoutline = 1,
                ncols = 3,
                wspace = 0.1,
                hspace = 0.1)
[Figure: UMAP of adata_test colored by observed and predicted celltype_l1/celltype_l2 and prediction probabilities]
[17]:
# Visualise prediction status
sc.pl.embedding(adata_test,
                color=[
                    'pred_status_l1',
                    'pred_status_l2'
                ],
                basis = 'X_umap',
                frameon = False,
                cmap = 'viridis',
                legend_loc = 'right margin',
                legend_fontsize = 7,
                legend_fontoutline = 1,
                ncols = 3,
                wspace = 0.1,
                hspace = 0.1)
[Figure: UMAP of adata_test colored by prediction status (pred_status_l1, pred_status_l2)]

The prediction probabilities and prediction status indicate that the model has difficulty annotating T cell subtypes at the celltype_l2 level. Other cell types are annotated well, as confirmed by the quality analysis below.
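To put numbers on the visual impression, prediction probabilities can be summarized per observed cell type. A minimal pandas sketch on toy values, assuming prob_celltype_l2 holds one probability per cell as plotted above (in the notebook, use adata_test.obs instead of the toy obs):

```python
import pandas as pd

# Toy values for illustration only
obs = pd.DataFrame({
    "celltype_l2": ["B cell", "B cell", "regulatory T cell", "regulatory T cell"],
    "prob_celltype_l2": [0.99, 0.97, 0.55, 0.70],
})

# Mean and median prediction probability per observed cell type, least confident first
summary = (obs.groupby("celltype_l2", observed=True)["prob_celltype_l2"]
              .agg(["mean", "median", "count"])
              .sort_values("mean"))
print(summary)
```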

Comparison of observed and predicted cell type annotations

To compare observed and predicted cell types, we use report_classif_full and conf_matrix from the scNoah module. More information about the metrics reported by scparadise.scnoah.report_classif_full is available in the scParadise documentation. More information about the confusion matrix (scparadise.scnoah.conf_matrix) is available here.
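The per-class columns follow one-vs-rest definitions. The reported values appear consistent with the geometric mean and index balanced accuracy (IBA, with alpha = 0.1) as defined in imbalanced-learn's classification_report_imbalanced: for the B row of the celltype_l1 table, sqrt(0.984 × 0.995) ≈ 0.989 and (1 + 0.1 × (0.984 − 0.995)) × 0.984 × 0.995 ≈ 0.978. A NumPy sketch computing these metrics from a confusion matrix; alpha = 0.1 is an assumption and this is not the scNoah source:

```python
import numpy as np

def per_class_metrics(cm, alpha=0.1):
    """One-vs-rest metrics from a confusion matrix (rows = observed, columns = predicted).

    Classes that are never predicted give NaN precision (division by zero).
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    tn = cm.sum() - tp - fn - fp
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)                 # sensitivity
    specificity = tn / (tn + fp)
    f1 = 2 * precision * recall / (precision + recall)
    gmean = np.sqrt(recall * specificity)
    # Index balanced accuracy: (1 + alpha * dominance) * gmean**2, dominance = recall - specificity
    iba = (1 + alpha * (recall - specificity)) * gmean ** 2
    return {"precision": precision, "recall": recall, "specificity": specificity,
            "f1": f1, "gmean": gmean, "iba": iba}

cm = [[8, 2, 0],
      [1, 9, 0],
      [0, 0, 10]]
m = per_class_metrics(cm)
print({name: np.round(values, 3) for name, values in m.items()})
```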

[18]:
# First annotation level (celltype_l1)
df_l1 = scparadise.scnoah.report_classif_full(adata_test,
                                              celltype = 'celltype_l1',
                                              pred_celltype = 'pred_celltype_l1',
                                              ndigits = 3)
df_l1
[18]:
| cell type | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B | 0.982 | 0.984 | 0.995 | 0.983 | 0.989 | 0.978 | 3969 |
| NK | 0.835 | 0.951 | 0.998 | 0.889 | 0.974 | 0.944 | 223 |
| T | 0.936 | 0.972 | 0.981 | 0.954 | 0.977 | 0.953 | 4057 |
| endothelial | 0.994 | 0.992 | 1.0 | 0.993 | 0.996 | 0.991 | 509 |
| epithelial | 0.989 | 0.951 | 0.992 | 0.97 | 0.972 | 0.94 | 7505 |
| myeloid | 0.915 | 0.976 | 0.995 | 0.945 | 0.986 | 0.97 | 924 |
| stromal | 0.971 | 0.997 | 0.998 | 0.984 | 0.998 | 0.995 | 1113 |
| macro avg | 0.946 | 0.975 | 0.994 | 0.96 | 0.984 | 0.967 | |
| weighted avg | 0.969 | 0.968 | 0.991 | 0.968 | 0.98 | 0.958 | |

Accuracy: 0.968
Balanced accuracy: 0.975
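The summary rows relate directly to the overall scores: the unweighted (macro) mean of per-class recall is the balanced accuracy, and recall weighted by the number of cells is the overall accuracy (the weighted sum of true positives divided by all cells). Checking this with the celltype_l1 recall values and cell counts from the table above:

```python
import numpy as np

# Recall and cell counts per celltype_l1 class (B, NK, T, endothelial, epithelial, myeloid, stromal)
recall = np.array([0.984, 0.951, 0.972, 0.992, 0.951, 0.976, 0.997])
n_cells = np.array([3969, 223, 4057, 509, 7505, 924, 1113])

macro_recall = recall.mean()                            # every cell type counts equally
weighted_recall = np.average(recall, weights=n_cells)   # every cell counts equally
print(round(macro_recall, 3), round(weighted_recall, 3))  # 0.975 0.968
```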
[19]:
sns.set(font_scale = 0.7)
plt.figure(figsize = (3, 3))
scparadise.scnoah.conf_matrix(adata_test,
                              celltype = 'celltype_l1',
                              pred_celltype = 'pred_celltype_l1',
                              annot_kws = {"size":5},
                              linewidths = 0.1, linecolor = 'black',
                              fmt =  ".2f",
                              ndigits_metrics = 4,
                              vmin = 0, vmax = 1)
[Figure: confusion matrix of observed vs predicted celltype_l1]
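For intuition about a confusion-matrix heatmap with values between 0 and 1, here is a row-normalized confusion matrix built with pandas.crosstab on toy labels. That scnoah.conf_matrix normalizes by row in exactly this way is an assumption based on vmin = 0, vmax = 1 and fmt = ".2f"; with row normalization, the diagonal equals per-class recall.

```python
import numpy as np
import pandas as pd

observed = pd.Series(["B", "B", "T", "T", "T", "NK"], name="observed")
predicted = pd.Series(["B", "T", "T", "T", "NK", "NK"], name="predicted")

# Each row sums to 1 and shows how cells of one observed type were predicted
cm = pd.crosstab(observed, predicted, normalize="index")
print(cm.round(2))

# Rows and columns share the same sorted labels here, so the diagonal is per-class recall
print(pd.Series(np.diag(cm), index=cm.index))
```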
[20]:
# Second annotation level (celltype_l2)
df_l2 = scparadise.scnoah.report_classif_full(adata_test,
                                              celltype = 'celltype_l2',
                                              pred_celltype = 'pred_celltype_l2',
                                              ndigits = 3)
df_l2
[20]:
| cell type | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B cell | 0.965 | 0.969 | 0.996 | 0.967 | 0.982 | 0.962 | 2050 |
| CD4-positive, alpha-beta T cell | 0.892 | 0.719 | 0.992 | 0.796 | 0.845 | 0.694 | 1554 |
| CD8-positive, alpha-beta T cell | 0.886 | 0.943 | 0.987 | 0.914 | 0.965 | 0.927 | 1742 |
| endothelial cell | 0.989 | 0.991 | 1.0 | 0.99 | 0.995 | 0.99 | 457 |
| endothelial cell of lymphatic vessel | 1.000 | 0.981 | 1.0 | 0.99 | 0.99 | 0.979 | 52 |
| epithelial cell | 0.989 | 0.954 | 0.993 | 0.971 | 0.973 | 0.944 | 7505 |
| fibroblast | 0.928 | 0.979 | 0.997 | 0.953 | 0.988 | 0.975 | 676 |
| mast cell | 0.986 | 0.98 | 1.0 | 0.983 | 0.99 | 0.978 | 148 |
| mature alpha-beta T cell | 0.629 | 0.88 | 0.998 | 0.733 | 0.937 | 0.868 | 75 |
| mural cell | 0.968 | 0.973 | 0.999 | 0.97 | 0.986 | 0.969 | 437 |
| myeloid cell | 0.890 | 0.971 | 0.995 | 0.929 | 0.983 | 0.964 | 726 |
| natural killer cell | 0.839 | 0.955 | 0.998 | 0.893 | 0.976 | 0.949 | 223 |
| plasma cell | 0.983 | 0.988 | 0.998 | 0.985 | 0.993 | 0.985 | 1919 |
| plasmacytoid dendritic cell | 0.875 | 0.98 | 1.0 | 0.925 | 0.99 | 0.978 | 50 |
| regulatory T cell | 0.596 | 0.824 | 0.978 | 0.692 | 0.898 | 0.793 | 686 |
| macro avg | 0.894 | 0.939 | 0.995 | 0.913 | 0.966 | 0.93 | |
| weighted avg | 0.943 | 0.937 | 0.993 | 0.938 | 0.964 | 0.926 | |

Accuracy: 0.937
Balanced accuracy: 0.939
[21]:
sns.set(font_scale = 0.7)
plt.figure(figsize = (5, 3))
scparadise.scnoah.conf_matrix(adata_test,
                              celltype = 'celltype_l2',
                              pred_celltype = 'pred_celltype_l2',
                              annot_kws = {"size":5},
                              linewidths = 0.1, linecolor = 'black',
                              fmt =  ".2f",
                              ndigits_metrics = 4,
                              vmin = 0, vmax = 1)
[Figure: confusion matrix of observed vs predicted celltype_l2]
[22]:
# Save anndata with predicted annotations
adata_test.write_h5ad('adata_test_balanced_model.h5ad')

Imbalanced training

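Use ‘balanced_accuracy’ as the early-stopping metric when training on an imbalanced dataset. Early stopping monitors the last metric in eval_metric: the balanced run above listed ‘accuracy’ last and stopped on best_valid_accuracy, while the run below lists ‘balanced_accuracy’ last and stops on best_valid_balanced_accuracy.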
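Why balanced accuracy: with class imbalance, plain accuracy is dominated by the most abundant cell types, so a model can score well while missing rare types entirely. Balanced accuracy averages recall over cell types instead, the same definition as scikit-learn's balanced_accuracy_score. A small NumPy illustration on toy labels:

```python
import numpy as np

def accuracy(y_true, y_pred):
    """Fraction of all cells that are labelled correctly."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

def balanced_accuracy(y_true, y_pred):
    """Mean per-class recall: every cell type contributes equally."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))

# Toy imbalanced data: 95 epithelial cells, 5 regulatory T cells,
# and a classifier that labels every cell as epithelial
y_true = np.array(["epithelial"] * 95 + ["regulatory T cell"] * 5)
y_pred = np.array(["epithelial"] * 100)

print(accuracy(y_true, y_pred))           # 0.95
print(balanced_accuracy(y_true, y_pred))  # 0.5
```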

[23]:
# Train scAdam model using the imbalanced adata_train dataset
scparadise.scadam.train(adata_train,
                        path = '', # path to save model
                        model_name = 'model_scadam_imbalanced', # folder name with model
                        celltype_l1 = 'celltype_l1', # First (less detailed) annotation level
                        celltype_l2 = 'celltype_l2', # Second (most detailed) annotation level
                        eval_metric = ['accuracy','balanced_accuracy']) # For an imbalanced training dataset, we recommend balanced_accuracy (listed last) for early stopping
Successfully saved genes names for training model

Successfully saved dictionary of dataset annotations

Train dataset contains: 44414 cells, it is 90.0 % of input dataset
Test dataset contains: 4935 cells, it is 10.0 % of input dataset

Accelerator: cuda
Start training
epoch 0  | loss: 1.89983 | train_accuracy: 0.6205  | train_balanced_accuracy: 0.34227 | valid_accuracy: 0.6157  | valid_balanced_accuracy: 0.33701 |  0:00:02s
epoch 1  | loss: 0.93062 | train_accuracy: 0.76397 | train_balanced_accuracy: 0.5379  | valid_accuracy: 0.7613  | valid_balanced_accuracy: 0.54393 |  0:00:05s
epoch 2  | loss: 0.63447 | train_accuracy: 0.89784 | train_balanced_accuracy: 0.6967  | valid_accuracy: 0.89656 | valid_balanced_accuracy: 0.6991  |  0:00:08s
epoch 3  | loss: 0.45799 | train_accuracy: 0.92658 | train_balanced_accuracy: 0.79117 | valid_accuracy: 0.92533 | valid_balanced_accuracy: 0.7936  |  0:00:11s
epoch 4  | loss: 0.38553 | train_accuracy: 0.93557 | train_balanced_accuracy: 0.82882 | valid_accuracy: 0.93222 | valid_balanced_accuracy: 0.82831 |  0:00:13s
epoch 5  | loss: 0.34978 | train_accuracy: 0.94098 | train_balanced_accuracy: 0.83302 | valid_accuracy: 0.94032 | valid_balanced_accuracy: 0.8334  |  0:00:16s
epoch 6  | loss: 0.32005 | train_accuracy: 0.94931 | train_balanced_accuracy: 0.8714  | valid_accuracy: 0.94438 | valid_balanced_accuracy: 0.8629  |  0:00:19s
epoch 7  | loss: 0.30517 | train_accuracy: 0.96018 | train_balanced_accuracy: 0.90504 | valid_accuracy: 0.95623 | valid_balanced_accuracy: 0.8828  |  0:00:22s
epoch 8  | loss: 0.29713 | train_accuracy: 0.95839 | train_balanced_accuracy: 0.91825 | valid_accuracy: 0.95319 | valid_balanced_accuracy: 0.89912 |  0:00:25s
epoch 9  | loss: 0.27967 | train_accuracy: 0.96636 | train_balanced_accuracy: 0.92368 | valid_accuracy: 0.9613  | valid_balanced_accuracy: 0.91722 |  0:00:27s
epoch 10 | loss: 0.26854 | train_accuracy: 0.96917 | train_balanced_accuracy: 0.95137 | valid_accuracy: 0.9618  | valid_balanced_accuracy: 0.94174 |  0:00:30s
epoch 11 | loss: 0.2596  | train_accuracy: 0.9695  | train_balanced_accuracy: 0.96104 | valid_accuracy: 0.9614  | valid_balanced_accuracy: 0.94668 |  0:00:33s
epoch 12 | loss: 0.2611  | train_accuracy: 0.96822 | train_balanced_accuracy: 0.95705 | valid_accuracy: 0.9614  | valid_balanced_accuracy: 0.94765 |  0:00:36s
epoch 13 | loss: 0.24682 | train_accuracy: 0.97206 | train_balanced_accuracy: 0.95925 | valid_accuracy: 0.96596 | valid_balanced_accuracy: 0.94677 |  0:00:39s
epoch 14 | loss: 0.24491 | train_accuracy: 0.97435 | train_balanced_accuracy: 0.96788 | valid_accuracy: 0.96535 | valid_balanced_accuracy: 0.95488 |  0:00:42s
epoch 15 | loss: 0.23757 | train_accuracy: 0.97454 | train_balanced_accuracy: 0.96045 | valid_accuracy: 0.96596 | valid_balanced_accuracy: 0.94597 |  0:00:45s
epoch 16 | loss: 0.23232 | train_accuracy: 0.97702 | train_balanced_accuracy: 0.96977 | valid_accuracy: 0.96657 | valid_balanced_accuracy: 0.9553  |  0:00:48s
epoch 17 | loss: 0.22917 | train_accuracy: 0.97673 | train_balanced_accuracy: 0.97094 | valid_accuracy: 0.96636 | valid_balanced_accuracy: 0.95947 |  0:00:51s
epoch 18 | loss: 0.22392 | train_accuracy: 0.97735 | train_balanced_accuracy: 0.97009 | valid_accuracy: 0.96575 | valid_balanced_accuracy: 0.95331 |  0:00:53s
epoch 19 | loss: 0.21909 | train_accuracy: 0.9789  | train_balanced_accuracy: 0.97495 | valid_accuracy: 0.96738 | valid_balanced_accuracy: 0.95752 |  0:00:56s
epoch 20 | loss: 0.22133 | train_accuracy: 0.97781 | train_balanced_accuracy: 0.9717  | valid_accuracy: 0.96525 | valid_balanced_accuracy: 0.95094 |  0:00:59s
epoch 21 | loss: 0.21602 | train_accuracy: 0.9779  | train_balanced_accuracy: 0.97577 | valid_accuracy: 0.96586 | valid_balanced_accuracy: 0.95856 |  0:01:01s
epoch 22 | loss: 0.20957 | train_accuracy: 0.97917 | train_balanced_accuracy: 0.97793 | valid_accuracy: 0.96748 | valid_balanced_accuracy: 0.96025 |  0:01:04s
epoch 23 | loss: 0.21309 | train_accuracy: 0.9806  | train_balanced_accuracy: 0.97859 | valid_accuracy: 0.96636 | valid_balanced_accuracy: 0.95804 |  0:01:07s
epoch 24 | loss: 0.2073  | train_accuracy: 0.97963 | train_balanced_accuracy: 0.97953 | valid_accuracy: 0.96768 | valid_balanced_accuracy: 0.96242 |  0:01:10s
epoch 25 | loss: 0.20919 | train_accuracy: 0.98106 | train_balanced_accuracy: 0.9762  | valid_accuracy: 0.96839 | valid_balanced_accuracy: 0.95242 |  0:01:13s
epoch 26 | loss: 0.20778 | train_accuracy: 0.97898 | train_balanced_accuracy: 0.9773  | valid_accuracy: 0.96575 | valid_balanced_accuracy: 0.95438 |  0:01:16s
epoch 27 | loss: 0.20592 | train_accuracy: 0.98023 | train_balanced_accuracy: 0.97647 | valid_accuracy: 0.96798 | valid_balanced_accuracy: 0.95622 |  0:01:19s
epoch 28 | loss: 0.20345 | train_accuracy: 0.98064 | train_balanced_accuracy: 0.97579 | valid_accuracy: 0.96758 | valid_balanced_accuracy: 0.95307 |  0:01:21s
epoch 29 | loss: 0.20338 | train_accuracy: 0.98195 | train_balanced_accuracy: 0.97994 | valid_accuracy: 0.9693  | valid_balanced_accuracy: 0.96008 |  0:01:24s
epoch 30 | loss: 0.19875 | train_accuracy: 0.98133 | train_balanced_accuracy: 0.97684 | valid_accuracy: 0.96727 | valid_balanced_accuracy: 0.95396 |  0:01:27s
epoch 31 | loss: 0.19886 | train_accuracy: 0.98129 | train_balanced_accuracy: 0.98089 | valid_accuracy: 0.96768 | valid_balanced_accuracy: 0.9582  |  0:01:29s
epoch 32 | loss: 0.19664 | train_accuracy: 0.98207 | train_balanced_accuracy: 0.98338 | valid_accuracy: 0.96778 | valid_balanced_accuracy: 0.96199 |  0:01:32s
epoch 33 | loss: 0.19808 | train_accuracy: 0.98324 | train_balanced_accuracy: 0.98151 | valid_accuracy: 0.96768 | valid_balanced_accuracy: 0.95489 |  0:01:35s
epoch 34 | loss: 0.19552 | train_accuracy: 0.98102 | train_balanced_accuracy: 0.98076 | valid_accuracy: 0.96555 | valid_balanced_accuracy: 0.95149 |  0:01:38s

Early stopping occurred at epoch 34 with best_epoch = 24 and best_valid_balanced_accuracy = 0.96242

Successfully saved training history and parameters
Successfully saved model at model_scadam_imbalanced/model.zip

Check model quality

[24]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
                                       path_model = 'model_scadam_imbalanced')
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities
[25]:
# Add prediction status to adata_test
scparadise.scnoah.pred_status(adata_test,
                              celltype = 'celltype_l1',
                              pred_celltype = 'pred_celltype_l1',
                              key_added = 'pred_status_l1')
scparadise.scnoah.pred_status(adata_test,
                              celltype = 'celltype_l2',
                              pred_celltype = 'pred_celltype_l2',
                              key_added = 'pred_status_l2')
[26]:
# Give observed and predicted labels the same ordered categories so cell type colors match
celltype_list = ['celltype_l1','celltype_l2']
for i in celltype_list:
    celltype = np.unique(adata_test.obs[i]).tolist()
    adata_test.obs[i] = pd.Categorical(
        values=adata_test.obs[i], categories=celltype, ordered=True
    )
    adata_test.obs['pred_' + i] = pd.Categorical(
        values=adata_test.obs['pred_' + i], categories=celltype, ordered=True
    )
[27]:
# Visualise observed and predicted cell type levels, prediction probabilities and prediction status
sc.pl.embedding(adata_test,
                color=[
                    'celltype_l1',
                    'pred_celltype_l1',
                    'prob_celltype_l1',
                    'pred_status_l1',
                    'celltype_l2',
                    'pred_celltype_l2',
                    'prob_celltype_l2',
                    'pred_status_l2'
                ],
                basis = 'X_umap',
                frameon = False,
                cmap = 'viridis',
                legend_loc = 'on data',
                legend_fontsize = 7,
                legend_fontoutline = 1,
                ncols = 4,
                wspace = 0,
                hspace = 0.1)
[Figure: UMAP of adata_test colored by observed and predicted cell types, prediction probabilities and prediction status at both annotation levels]
[28]:
# Visualise prediction status
sc.pl.embedding(adata_test,
                color=[
                    'pred_status_l1',
                    'pred_status_l2'
                ],
                basis = 'X_umap',
                frameon = False,
                cmap = 'viridis',
                legend_loc = 'right margin',
                legend_fontsize = 7,
                legend_fontoutline = 1,
                ncols = 3,
                wspace = 0.1,
                hspace = 0.1)
[Figure: UMAP of adata_test colored by prediction status (pred_status_l1, pred_status_l2)]

The prediction probabilities and prediction status indicate that the model has difficulty annotating T cell subtypes at the celltype_l2 level. Other cell types are annotated well, as confirmed by the quality analysis below.

Comparison of observed and predicted cell type annotations

To compare observed and predicted cell types, we use report_classif_full and conf_matrix from the scNoah module. More information about the metrics reported by scparadise.scnoah.report_classif_full is available in the scParadise documentation. More information about the confusion matrix (scparadise.scnoah.conf_matrix) is available here.

[29]:
# First annotation level (celltype_l1)
df_l1 = scparadise.scnoah.report_classif_full(adata_test,
                                              celltype='celltype_l1',
                                              pred_celltype='pred_celltype_l1')
df_l1
[29]:
| cell type | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B | 0.9881 | 0.9829 | 0.9967 | 0.9855 | 0.9898 | 0.9783 | 3969 |
| NK | 0.7220 | 0.9552 | 0.9955 | 0.8224 | 0.9751 | 0.947 | 223 |
| T | 0.9732 | 0.9675 | 0.9924 | 0.9703 | 0.9799 | 0.9577 | 4057 |
| endothelial | 0.9980 | 0.9784 | 0.9999 | 0.9881 | 0.9891 | 0.9762 | 509 |
| epithelial | 0.9888 | 0.9859 | 0.9922 | 0.9873 | 0.989 | 0.9776 | 7505 |
| myeloid | 0.9825 | 0.9697 | 0.9991 | 0.976 | 0.9843 | 0.966 | 924 |
| stromal | 0.9832 | 0.9982 | 0.9989 | 0.9906 | 0.9985 | 0.997 | 1113 |
| macro avg | 0.9480 | 0.9768 | 0.9964 | 0.96 | 0.9865 | 0.9714 | |
| weighted avg | 0.9815 | 0.9805 | 0.9942 | 0.9808 | 0.9873 | 0.9735 | |

Accuracy: 0.9805
Balanced accuracy: 0.9768
[30]:
sns.set(font_scale = 0.7)
plt.figure(figsize = (3, 3))
scparadise.scnoah.conf_matrix(adata_test,
                              celltype = 'celltype_l1',
                              pred_celltype = 'pred_celltype_l1',
                              annot_kws = {"size":5},
                              linewidths = 0.1, linecolor = 'black',
                              fmt =  ".2f",
                              ndigits_metrics = 4,
                              vmin = 0, vmax = 1)
[Figure: confusion matrix of observed vs predicted celltype_l1 (imbalanced model)]
[31]:
# Second annotation level (celltype_l2)
df_l2 = scparadise.scnoah.report_classif_full(adata_test,
                                              celltype='celltype_l2',
                                              pred_celltype='pred_celltype_l2')
df_l2
[31]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B cell | 0.9819 | 0.9766 | 0.9977 | 0.9792 | 0.9871 | 0.9723 | 2050 |
| CD4-positive, alpha-beta T cell | 0.8686 | 0.8166 | 0.9885 | 0.8418 | 0.8985 | 0.7934 | 1554 |
| CD8-positive, alpha-beta T cell | 0.9166 | 0.9208 | 0.9912 | 0.9187 | 0.9553 | 0.9062 | 1742 |
| endothelial cell | 0.9956 | 0.9847 | 0.9999 | 0.9901 | 0.9923 | 0.9831 | 457 |
| endothelial cell of lymphatic vessel | 0.9273 | 0.9808 | 0.9998 | 0.9533 | 0.9902 | 0.9787 | 52 |
| epithelial cell | 0.9884 | 0.986 | 0.9919 | 0.9872 | 0.989 | 0.9775 | 7505 |
| fibroblast | 0.9844 | 0.9364 | 0.9994 | 0.9598 | 0.9674 | 0.93 | 676 |
| mast cell | 0.9728 | 0.9662 | 0.9998 | 0.9695 | 0.9829 | 0.9628 | 148 |
| mature alpha-beta T cell | 0.7901 | 0.8533 | 0.9991 | 0.8205 | 0.9233 | 0.8401 | 75 |
| mural cell | 0.8979 | 0.9863 | 0.9973 | 0.94 | 0.9917 | 0.9825 | 437 |
| myeloid cell | 0.9777 | 0.9656 | 0.9991 | 0.9716 | 0.9822 | 0.9615 | 726 |
| natural killer cell | 0.7245 | 0.9552 | 0.9955 | 0.824 | 0.9751 | 0.947 | 223 |
| plasma cell | 0.9854 | 0.9859 | 0.9983 | 0.9857 | 0.9921 | 0.983 | 1919 |
| plasmacytoid dendritic cell | 0.8704 | 0.94 | 0.9996 | 0.9038 | 0.9694 | 0.934 | 50 |
| regulatory T cell | 0.7528 | 0.7901 | 0.9899 | 0.771 | 0.8844 | 0.7665 | 686 |
| macro avg | 0.9090 | 0.9363 | 0.9965 | 0.9211 | 0.9654 | 0.9279 | |
| weighted avg | 0.9543 | 0.9531 | 0.9939 | 0.9533 | 0.9728 | 0.9439 | |

Accuracy: 0.9531
Balanced accuracy: 0.9363
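The summary rows can be reproduced from the per-class rows: `macro avg` is the unweighted mean over cell types, and `weighted avg` weights each cell type by its number of cells. For recall, these two averages are exactly the two accuracy figures, because balanced accuracy is the macro-averaged recall and accuracy equals recall weighted by class support. The sketch below uses a hypothetical helper and per-class values copied from the celltype_l2 table above to check this.

```python
# Per-class recall and number of cells, copied from the celltype_l2 report above
recall_support = {
    "B cell": (0.9766, 2050),
    "CD4-positive, alpha-beta T cell": (0.8166, 1554),
    "CD8-positive, alpha-beta T cell": (0.9208, 1742),
    "endothelial cell": (0.9847, 457),
    "endothelial cell of lymphatic vessel": (0.9808, 52),
    "epithelial cell": (0.986, 7505),
    "fibroblast": (0.9364, 676),
    "mast cell": (0.9662, 148),
    "mature alpha-beta T cell": (0.8533, 75),
    "mural cell": (0.9863, 437),
    "myeloid cell": (0.9656, 726),
    "natural killer cell": (0.9552, 223),
    "plasma cell": (0.9859, 1919),
    "plasmacytoid dendritic cell": (0.94, 50),
    "regulatory T cell": (0.7901, 686),
}


def macro_and_weighted(scores_and_supports):
    """Return (unweighted mean, support-weighted mean) of per-class scores."""
    scores = [score for score, _ in scores_and_supports]
    supports = [n for _, n in scores_and_supports]
    macro = sum(scores) / len(scores)
    weighted = sum(s * n for s, n in zip(scores, supports)) / sum(supports)
    return macro, weighted


macro_recall, weighted_recall = macro_and_weighted(list(recall_support.values()))
total_cells = sum(n for _, n in recall_support.values())
# macro_recall is about 0.9363 (balanced accuracy)
# weighted_recall is about 0.9531 (accuracy)
```

The gap between the two (0.9363 vs. 0.9531) reflects that the smaller T-cell subsets are predicted less reliably than the large epithelial population, which dominates the weighted average.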
[32]:
# Heatmap of true (celltype_l2) vs. predicted (pred_celltype_l2) labels
sns.set(font_scale=0.7)
plt.figure(figsize=(5, 3))
scparadise.scnoah.conf_matrix(adata_test,
                              celltype='celltype_l2',
                              pred_celltype='pred_celltype_l2',
                              annot_kws={"size": 5},
                              linewidths=0.1, linecolor='black',
                              fmt=".2f",
                              ndigits_metrics=4,
                              vmin=0, vmax=1)
[Figure: confusion matrix for the second annotation level (celltype_l2 vs. pred_celltype_l2)]
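At the second level, recall is lowest for `regulatory T cell` (0.7901) and `CD4-positive, alpha-beta T cell` (0.8166). To see which predicted labels absorb those cells, one option is to rank the off-diagonal (true, predicted) pairs by count. A minimal sketch with a hypothetical helper name; in this notebook, `y_true` and `y_pred` could be taken from `adata_test.obs['celltype_l2']` and `adata_test.obs['pred_celltype_l2']`, but the toy labels below are illustrative only.

```python
from collections import Counter


def top_confusions(y_true, y_pred, n=5):
    """Most frequent (true, predicted) pairs where the prediction is wrong."""
    errors = Counter((t, p) for t, p in zip(y_true, y_pred) if t != p)
    return errors.most_common(n)


# Toy labels (illustrative only, not from the dataset above)
y_true = ["Treg", "Treg", "Treg", "CD4 T", "CD4 T", "CD8 T"]
y_pred = ["CD4 T", "CD4 T", "Treg", "Treg", "CD4 T", "CD8 T"]
confusions = top_confusions(y_true, y_pred, n=2)
# [(('Treg', 'CD4 T'), 2), (('CD4 T', 'Treg'), 1)]
```

Working with raw error counts rather than fractions keeps large classes from hiding behind small percentages.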
[33]:
# Save anndata with predicted annotations
adata_test.write_h5ad('adata_test_imbalanced_model.h5ad')
[34]:
import session_info
session_info.show()
[34]:
-----
anndata             0.10.8
matplotlib          3.9.2
numpy               1.25.0
pandas              2.2.3
scanpy              1.10.3
scparadise          0.3.2_beta
seaborn             0.13.2
session_info        1.0.0
-----
PIL                         10.4.0
anyio                       NA
arrow                       1.3.0
asciitree                   NA
asttokens                   NA
attr                        24.2.0
attrs                       24.2.0
awkward                     2.7.1
awkward_cpp                 NA
babel                       2.16.0
backports                   NA
certifi                     2024.08.30
cffi                        1.17.1
charset_normalizer          3.3.2
cloudpickle                 3.1.0
colorlog                    NA
comm                        0.2.2
cycler                      0.12.1
cython_runtime              NA
dask                        2024.8.0
dateutil                    2.9.0.post0
debugpy                     1.8.6
decorator                   5.1.1
defusedxml                  0.7.1
exceptiongroup              1.2.2
executing                   2.1.0
fastjsonschema              NA
fqdn                        NA
fsspec                      2023.6.0
h5py                        3.12.1
idna                        3.10
igraph                      0.11.6
imblearn                    0.12.3
importlib_metadata          NA
importlib_resources         NA
ipykernel                   6.29.5
isoduration                 NA
jaraco                      NA
jedi                        0.19.1
jinja2                      3.1.4
joblib                      1.4.2
json5                       0.9.25
jsonpointer                 3.0.0
jsonschema                  4.23.0
jsonschema_specifications   NA
jupyter_events              0.10.0
jupyter_server              2.14.2
jupyterlab_server           2.27.3
kiwisolver                  1.4.7
legacy_api_wrap             NA
leidenalg                   0.10.2
llvmlite                    0.43.0
markupsafe                  2.1.5
matplotlib_inline           0.1.7
more_itertools              10.5.0
mpl_toolkits                NA
mpmath                      1.3.0
msgpack                     1.1.0
mudata                      0.2.4
muon                        0.1.6
natsort                     8.4.0
nbformat                    5.10.4
numba                       0.60.0
numcodecs                   0.12.1
optuna                      4.0.0
overrides                   NA
packaging                   24.1
parso                       0.8.4
patsy                       0.5.6
pexpect                     4.9.0
platformdirs                4.3.6
plotly                      5.24.1
prometheus_client           NA
prompt_toolkit              3.0.48
psutil                      6.0.0
ptyprocess                  0.7.0
pure_eval                   0.2.3
pyarrow                     18.1.0
pycparser                   2.22
pydev_ipython               NA
pydevconsole                NA
pydevd                      3.1.0
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pydot                       3.0.3
pygments                    2.18.0
pynndescent                 0.5.13
pyparsing                   3.1.4
pythonjsonlogger            NA
pytorch_tabnet              NA
pytz                        2024.2
referencing                 NA
requests                    2.32.3
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rich                        NA
rpds                        NA
scipy                       1.13.1
send2trash                  NA
setuptools                  75.1.0
setuptools_scm              NA
shap                        0.46.0
six                         1.16.0
sklearn                     1.5.2
skmisc                      0.3.1
slicer                      NA
sniffio                     1.3.1
stack_data                  0.6.3
statsmodels                 0.14.3
sympy                       1.13.3
tblib                       3.0.0
texttable                   1.7.0
threadpoolctl               3.5.0
tlz                         0.12.1
tomli                       2.0.1
toolz                       0.12.1
torch                       2.4.1+cu121
torchgen                    NA
tornado                     6.4.1
tqdm                        4.66.5
traitlets                   5.14.3
triton                      3.0.0
typing_extensions           NA
umap                        0.5.6
uri_template                NA
urllib3                     1.26.20
wcwidth                     0.2.13
webcolors                   24.8.0
websocket                   1.8.0
yaml                        6.0.2
zarr                        2.18.2
zipp                        NA
zmq                         26.2.0
zoneinfo                    NA
-----
IPython             8.18.1
jupyter_client      8.6.3
jupyter_core        5.7.2
jupyterlab          4.2.5
-----
Python 3.9.19 (main, May  6 2024, 19:43:03) [GCC 11.2.0]
Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
-----
Session information updated at 2025-01-16 13:27