scAdam model train#
[1]:
# Python packages
import warnings
warnings.simplefilter('ignore')
import scanpy as sc
import scparadise
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
sc.set_figure_params(dpi = 120)
Recommendations for the training dataset#
We recommend the shifted logarithm normalization method:
sc.pp.normalize_total(adata, target_sum=None)
sc.pp.log1p(adata)
You can use any other normalization method, as long as you apply the same method to the test dataset.
The training dataset should contain all genes that you want to use for model training in adata_train.X.
We recommend removing all non-marker genes from adata_train.X. Removing such uninformative genes improves performance and model quality metrics.
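To make the recommended normalization concrete, here is a minimal pure-Python sketch of what the shifted logarithm does to a small count matrix. It assumes that target_sum=None scales every cell to the median of the per-cell total counts, as described in the scanpy documentation. The function name shifted_log and the toy matrix are illustrative only; real data should go through the scanpy calls.

```python
import math
import statistics

def shifted_log(counts):
    """Scale each cell (row) to the median library size, then apply log(1 + x)."""
    totals = [sum(row) for row in counts]
    # Assumption: mirrors target_sum=None, i.e. the median of per-cell totals
    target = statistics.median(totals)
    normalized = []
    for row, total in zip(counts, totals):
        scale = target / total if total > 0 else 0.0
        normalized.append([math.log1p(value * scale) for value in row])
    return normalized

# Toy matrix: 3 cells x 3 genes (hypothetical counts)
counts = [[10, 0, 10], [1, 1, 2], [30, 10, 0]]
norm = shifted_log(counts)
```

After this transformation every cell has the same total on the back-transformed scale, so differences in sequencing depth no longer dominate the gene values the model sees.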
Data preparation#
Here we download and preprocess a dataset from CELLxGENE:
Single cell RNA sequencing of oropharyngeal squamous cell carcinoma https://cellxgene.cziscience.com/collections/3c34e6f1-6827-47dd-8e19-9edcd461893f
[2]:
!wget https://datasets.cellxgene.cziscience.com/915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad
--2025-01-16 12:42:52-- https://datasets.cellxgene.cziscience.com/915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad
Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 52.85.49.125, 52.85.49.28, 52.85.49.24, ...
Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|52.85.49.125|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 981941987 (936M) [binary/octet-stream]
Saving to: ‘915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad’
915069db-1df2-49a1- 100%[===================>] 936.45M 4.91MB/s in 2m 16s
2025-01-16 12:45:08 (6.89 MB/s) - ‘915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad’ saved [981941987/981941987]
[3]:
# Load the downloaded AnnData object
adata = sc.read_h5ad('915069db-1df2-49a1-9a9c-2fbd0aa13c81.h5ad')
[4]:
# Get raw counts from adata.raw
adata = adata.raw.to_adata()
[5]:
# Convert var_names from ENSG codes to gene names
adata.var.set_index('feature_name', inplace=True)
adata.var_names_make_unique()
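Several Ensembl IDs can share one gene symbol, so switching the index to feature_name can create duplicate var_names. The sketch below illustrates the general idea behind making names unique by appending a numeric suffix to repeats. It is a simplified stand-in, not the anndata implementation; make_unique and the example gene list are hypothetical.

```python
from collections import Counter

def make_unique(names, join="-"):
    """Append a numeric suffix to repeated names, keeping the first occurrence unchanged."""
    seen = Counter()
    taken = set(names)
    unique = []
    for name in names:
        if seen[name] == 0:
            unique.append(name)
        else:
            candidate = f"{name}{join}{seen[name]}"
            # Skip suffixes that would collide with an existing name
            while candidate in taken:
                seen[name] += 1
                candidate = f"{name}{join}{seen[name]}"
            taken.add(candidate)
            unique.append(candidate)
        seen[name] += 1
    return unique

genes = ["TBCE", "ACTB", "TBCE", "TBCE"]
unique_genes = make_unique(genes)
```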
[6]:
# Create celltype_l1 and celltype_l2 annotation levels
adata.obs['celltype_l2'] = adata.obs['cell_type'].copy()
# Get current cluster names
cluster_list = adata.obs['celltype_l2'].unique()
# Make cluster annotation dictionary
annotation = {"epithelial": ['epithelial cell'],
              "B": ['B cell', 'plasma cell'],
              "T": ['CD4-positive, alpha-beta T cell', 'CD8-positive, alpha-beta T cell', 'mature alpha-beta T cell', 'regulatory T cell'],
              "NK": ['natural killer cell'],
              "myeloid": ['myeloid cell', 'plasmacytoid dendritic cell', 'mast cell'],
              "endothelial": ['endothelial cell', 'endothelial cell of lymphatic vessel'],
              "stromal": ['fibroblast', 'mural cell']}
# Invert the dictionary: map each celltype_l2 name to its celltype_l1 group
annotation_rev = {}
for i in cluster_list:
    for k in annotation:
        if i in annotation[k]:
            annotation_rev[i] = k
adata.obs["celltype_l1"] = [annotation_rev[i] for i in adata.obs['celltype_l2']]
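The dictionary inversion used for celltype_l1 can also be written as a single comprehension. The sketch below uses a shortened annotation dictionary and a hypothetical list of labels in place of adata.obs. It also checks for labels missing from the dictionary, since those would otherwise raise a KeyError during mapping.

```python
# Shortened version of the annotation dictionary (coarse group -> detailed labels)
annotation = {
    "B": ["B cell", "plasma cell"],
    "NK": ["natural killer cell"],
    "stromal": ["fibroblast", "mural cell"],
}

# Invert it: detailed label -> coarse group
annotation_rev = {fine: coarse for coarse, fines in annotation.items() for fine in fines}

# Hypothetical celltype_l2 labels standing in for adata.obs['celltype_l2']
labels_l2 = ["plasma cell", "fibroblast", "B cell", "natural killer cell"]

# Find unmapped labels before mapping, so the failure is explicit
missing = sorted(set(labels_l2) - set(annotation_rev))
labels_l1 = [annotation_rev[label] for label in labels_l2]
```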
[7]:
# Check annotations
sc.pl.embedding(adata,
basis = 'X_umap',
color = ['celltype_l1',
'celltype_l2'],
frameon = False,
ncols = 1)
[8]:
# Create adata_train and adata_test datasets (split by donor)
adata_test = adata[adata.obs['donor_id'].isin(['HN481', 'HN482'])].copy()
adata_train = adata[adata.obs['donor_id'].isin(['HN483', 'HN485', 'HN488', 'HN489', 'HN490', 'HN487', 'HN492', 'HN494'])].copy()
del adata
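When splitting by donor, it helps to verify the ID lists before subsetting. Python joins adjacent string literals, so a missing comma between two IDs silently produces one ID that matches no donor. The check below is a minimal sketch; check_split and the list of available donors are hypothetical, and in practice the available IDs would come from adata.obs['donor_id'].

```python
def check_split(train_ids, test_ids, available_ids):
    """Return the problems found in a donor-level train/test split."""
    problems = []
    overlap = set(train_ids) & set(test_ids)
    if overlap:
        problems.append(f"donors in both splits: {sorted(overlap)}")
    unknown = (set(train_ids) | set(test_ids)) - set(available_ids)
    if unknown:
        problems.append(f"donors not found in the data: {sorted(unknown)}")
    return problems

available = ['HN481', 'HN482', 'HN483', 'HN485', 'HN488']
test_ids = ['HN481', 'HN482']
good_train = ['HN483', 'HN485', 'HN488']
bad_train = ['HN483' 'HN485', 'HN488']  # missing comma -> 'HN483HN485'

good_report = check_split(good_train, test_ids, available)
bad_report = check_split(bad_train, test_ids, available)
```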
[9]:
# Normalize data, find highly variable features
for i in [adata_train, adata_test]:
    i.layers['counts'] = i.X.copy()
    sc.pp.normalize_total(i, target_sum=None)
    sc.pp.log1p(i)
    i.raw = i
    sc.pp.highly_variable_genes(i,
                                layer='counts',
                                flavor='seurat_v3',
                                n_top_genes=1200,
                                subset=True)
Balanced training#
We recommend balancing the dataset on the most detailed annotation level (celltype_l2 in this case). Balancing the training dataset increases the sensitivity (balanced accuracy), f1-score and geometric mean of the model, but leads to a slight decrease in precision.
[10]:
# Balance dataset based on most detailed annotation level
adata_balanced = scparadise.scnoah.balance(adata_train,
sample = 'donor_id',
celltype_l1 = 'celltype_l1',
celltype_l2 = 'celltype_l2')
Successfully undersampled cell types: epithelial cell, B cell, myeloid cell, CD4-positive, alpha-beta T cell, CD8-positive, alpha-beta T cell, plasma cell
Successfully oversampled cell types: regulatory T cell, fibroblast, mural cell, endothelial cell, natural killer cell, plasmacytoid dendritic cell, mast cell, mature alpha-beta T cell, endothelial cell of lymphatic vessel
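The idea of combined under- and oversampling can be sketched in a few lines. This is an illustration of the general technique, not the scParadise implementation: balance_indices is a hypothetical helper, and taking the mean class size as the target is an assumption. The cell counts in the training logs (49350 cells after balancing, 49349 before, 15 cell types) are consistent with that choice, but the actual procedure may differ.

```python
import random
from collections import Counter, defaultdict

def balance_indices(labels, seed=0):
    """Return cell indices resampled so that every class has the same size."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Assumption: the target size is the mean class size
    target = round(len(labels) / len(by_class))
    balanced = []
    for idxs in by_class.values():
        if len(idxs) >= target:
            # Undersample large classes without replacement
            balanced.extend(rng.sample(idxs, target))
        else:
            # Oversample small classes by drawing extra cells with replacement
            balanced.extend(idxs + rng.choices(idxs, k=target - len(idxs)))
    return balanced

# Hypothetical imbalanced labels
labels = ["epithelial cell"] * 80 + ["B cell"] * 15 + ["mast cell"] * 5
balanced = balance_indices(labels)
class_sizes = Counter(labels[i] for i in balanced)
```

Oversampling by duplication adds no new information about rare cell types; it only changes how much weight they carry during training.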
[11]:
# Train scadam model using adata_balanced dataset
scparadise.scadam.train(adata_balanced,
path = '', # path to save model
model_name = 'model_scadam_balanced', # folder name with model
celltype_l1 = 'celltype_l1', # First (less detailed) annotation level
celltype_l2 = 'celltype_l2', # Second (most detailed) annotation level
eval_metric = ['balanced_accuracy','accuracy'])
Successfully saved genes names for training model
Successfully saved dictionary of dataset annotations
Train dataset contains: 44415 cells, it is 90.0 % of input dataset
Test dataset contains: 4935 cells, it is 10.0 % of input dataset
Accelerator: cuda
Start training
epoch 0 | loss: 2.21299 | train_balanced_accuracy: 0.49625 | train_accuracy: 0.55639 | valid_balanced_accuracy: 0.48656 | valid_accuracy: 0.5461 | 0:00:02s
epoch 1 | loss: 1.14552 | train_balanced_accuracy: 0.72033 | train_accuracy: 0.74447 | valid_balanced_accuracy: 0.70999 | valid_accuracy: 0.73587 | 0:00:05s
epoch 2 | loss: 0.74196 | train_balanced_accuracy: 0.84973 | train_accuracy: 0.8624 | valid_balanced_accuracy: 0.83588 | valid_accuracy: 0.84985 | 0:00:08s
epoch 3 | loss: 0.55414 | train_balanced_accuracy: 0.90071 | train_accuracy: 0.91104 | valid_balanced_accuracy: 0.89921 | valid_accuracy: 0.90942 | 0:00:11s
epoch 4 | loss: 0.46574 | train_balanced_accuracy: 0.92385 | train_accuracy: 0.93116 | valid_balanced_accuracy: 0.91397 | valid_accuracy: 0.92158 | 0:00:14s
epoch 5 | loss: 0.42424 | train_balanced_accuracy: 0.93255 | train_accuracy: 0.94035 | valid_balanced_accuracy: 0.92515 | valid_accuracy: 0.93333 | 0:00:17s
epoch 6 | loss: 0.38279 | train_balanced_accuracy: 0.95089 | train_accuracy: 0.95306 | valid_balanced_accuracy: 0.94334 | valid_accuracy: 0.94498 | 0:00:19s
epoch 7 | loss: 0.365 | train_balanced_accuracy: 0.94994 | train_accuracy: 0.95516 | valid_balanced_accuracy: 0.94515 | valid_accuracy: 0.95086 | 0:00:22s
epoch 8 | loss: 0.33252 | train_balanced_accuracy: 0.95308 | train_accuracy: 0.95708 | valid_balanced_accuracy: 0.94601 | valid_accuracy: 0.95035 | 0:00:25s
epoch 9 | loss: 0.31807 | train_balanced_accuracy: 0.96494 | train_accuracy: 0.96749 | valid_balanced_accuracy: 0.9574 | valid_accuracy: 0.96039 | 0:00:28s
epoch 10 | loss: 0.30404 | train_balanced_accuracy: 0.97066 | train_accuracy: 0.97151 | valid_balanced_accuracy: 0.96183 | valid_accuracy: 0.96353 | 0:00:31s
epoch 11 | loss: 0.29295 | train_balanced_accuracy: 0.96804 | train_accuracy: 0.97109 | valid_balanced_accuracy: 0.96141 | valid_accuracy: 0.96464 | 0:00:34s
epoch 12 | loss: 0.29203 | train_balanced_accuracy: 0.97248 | train_accuracy: 0.97422 | valid_balanced_accuracy: 0.96377 | valid_accuracy: 0.96616 | 0:00:37s
epoch 13 | loss: 0.27764 | train_balanced_accuracy: 0.97068 | train_accuracy: 0.97313 | valid_balanced_accuracy: 0.96204 | valid_accuracy: 0.96525 | 0:00:40s
epoch 14 | loss: 0.26716 | train_balanced_accuracy: 0.97676 | train_accuracy: 0.97791 | valid_balanced_accuracy: 0.96719 | valid_accuracy: 0.969 | 0:00:43s
epoch 15 | loss: 0.2658 | train_balanced_accuracy: 0.97736 | train_accuracy: 0.9776 | valid_balanced_accuracy: 0.96948 | valid_accuracy: 0.96971 | 0:00:45s
epoch 16 | loss: 0.26173 | train_balanced_accuracy: 0.97664 | train_accuracy: 0.9786 | valid_balanced_accuracy: 0.96713 | valid_accuracy: 0.9696 | 0:00:48s
epoch 17 | loss: 0.25269 | train_balanced_accuracy: 0.97566 | train_accuracy: 0.97638 | valid_balanced_accuracy: 0.96591 | valid_accuracy: 0.96717 | 0:00:51s
epoch 18 | loss: 0.25638 | train_balanced_accuracy: 0.98054 | train_accuracy: 0.98126 | valid_balanced_accuracy: 0.96883 | valid_accuracy: 0.97052 | 0:00:54s
epoch 19 | loss: 0.24683 | train_balanced_accuracy: 0.98036 | train_accuracy: 0.98016 | valid_balanced_accuracy: 0.97149 | valid_accuracy: 0.97153 | 0:00:56s
epoch 20 | loss: 0.24669 | train_balanced_accuracy: 0.98134 | train_accuracy: 0.98176 | valid_balanced_accuracy: 0.97291 | valid_accuracy: 0.97356 | 0:00:59s
epoch 21 | loss: 0.2413 | train_balanced_accuracy: 0.97987 | train_accuracy: 0.98138 | valid_balanced_accuracy: 0.96868 | valid_accuracy: 0.97143 | 0:01:02s
epoch 22 | loss: 0.23411 | train_balanced_accuracy: 0.9823 | train_accuracy: 0.98285 | valid_balanced_accuracy: 0.96949 | valid_accuracy: 0.97143 | 0:01:05s
epoch 23 | loss: 0.23256 | train_balanced_accuracy: 0.98111 | train_accuracy: 0.98152 | valid_balanced_accuracy: 0.96911 | valid_accuracy: 0.97031 | 0:01:08s
epoch 24 | loss: 0.23049 | train_balanced_accuracy: 0.97992 | train_accuracy: 0.98013 | valid_balanced_accuracy: 0.96913 | valid_accuracy: 0.97011 | 0:01:11s
epoch 25 | loss: 0.22792 | train_balanced_accuracy: 0.98018 | train_accuracy: 0.98138 | valid_balanced_accuracy: 0.96714 | valid_accuracy: 0.97042 | 0:01:14s
epoch 26 | loss: 0.22866 | train_balanced_accuracy: 0.98175 | train_accuracy: 0.98292 | valid_balanced_accuracy: 0.96839 | valid_accuracy: 0.97123 | 0:01:16s
epoch 27 | loss: 0.22948 | train_balanced_accuracy: 0.98347 | train_accuracy: 0.98354 | valid_balanced_accuracy: 0.97165 | valid_accuracy: 0.97214 | 0:01:19s
epoch 28 | loss: 0.23231 | train_balanced_accuracy: 0.98406 | train_accuracy: 0.98454 | valid_balanced_accuracy: 0.97292 | valid_accuracy: 0.97457 | 0:01:22s
epoch 29 | loss: 0.2214 | train_balanced_accuracy: 0.98382 | train_accuracy: 0.98372 | valid_balanced_accuracy: 0.97213 | valid_accuracy: 0.97244 | 0:01:25s
epoch 30 | loss: 0.21979 | train_balanced_accuracy: 0.98496 | train_accuracy: 0.9847 | valid_balanced_accuracy: 0.97303 | valid_accuracy: 0.97345 | 0:01:27s
epoch 31 | loss: 0.21397 | train_balanced_accuracy: 0.98468 | train_accuracy: 0.98523 | valid_balanced_accuracy: 0.97364 | valid_accuracy: 0.97538 | 0:01:30s
epoch 32 | loss: 0.2146 | train_balanced_accuracy: 0.98435 | train_accuracy: 0.98476 | valid_balanced_accuracy: 0.97197 | valid_accuracy: 0.97386 | 0:01:33s
epoch 33 | loss: 0.21473 | train_balanced_accuracy: 0.98328 | train_accuracy: 0.9838 | valid_balanced_accuracy: 0.96836 | valid_accuracy: 0.97042 | 0:01:36s
epoch 34 | loss: 0.21234 | train_balanced_accuracy: 0.98507 | train_accuracy: 0.98472 | valid_balanced_accuracy: 0.97059 | valid_accuracy: 0.97102 | 0:01:39s
epoch 35 | loss: 0.21444 | train_balanced_accuracy: 0.98667 | train_accuracy: 0.98633 | valid_balanced_accuracy: 0.97435 | valid_accuracy: 0.97437 | 0:01:42s
epoch 36 | loss: 0.21378 | train_balanced_accuracy: 0.98626 | train_accuracy: 0.98602 | valid_balanced_accuracy: 0.97264 | valid_accuracy: 0.97335 | 0:01:44s
epoch 37 | loss: 0.21579 | train_balanced_accuracy: 0.98637 | train_accuracy: 0.98651 | valid_balanced_accuracy: 0.97226 | valid_accuracy: 0.97366 | 0:01:47s
epoch 38 | loss: 0.21361 | train_balanced_accuracy: 0.98471 | train_accuracy: 0.98512 | valid_balanced_accuracy: 0.97023 | valid_accuracy: 0.97275 | 0:01:50s
epoch 39 | loss: 0.20431 | train_balanced_accuracy: 0.98699 | train_accuracy: 0.98694 | valid_balanced_accuracy: 0.9727 | valid_accuracy: 0.97406 | 0:01:53s
epoch 40 | loss: 0.20723 | train_balanced_accuracy: 0.98677 | train_accuracy: 0.98683 | valid_balanced_accuracy: 0.97398 | valid_accuracy: 0.97487 | 0:01:55s
epoch 41 | loss: 0.20137 | train_balanced_accuracy: 0.98697 | train_accuracy: 0.98698 | valid_balanced_accuracy: 0.97357 | valid_accuracy: 0.97447 | 0:01:58s
Early stopping occurred at epoch 41 with best_epoch = 31 and best_valid_accuracy = 0.97538
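The log above stops at epoch 41 with the best epoch at 31, which is consistent with patience-based early stopping using a patience of 10 epochs. The patience value is inferred from the log, not taken from the scParadise source. A minimal sketch of that rule, with a hypothetical early_stop helper and made-up validation scores:

```python
def early_stop(scores, patience):
    """Return (stop_epoch, best_epoch, best_score) for a metric to be maximized."""
    best_epoch, best_score = 0, float("-inf")
    for epoch, score in enumerate(scores):
        if score > best_score:
            best_epoch, best_score = epoch, score
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs: stop and keep the best epoch
            return epoch, best_epoch, best_score
    return len(scores) - 1, best_epoch, best_score

# Hypothetical validation scores; the best value is reached at epoch 2
valid_scores = [0.5, 0.7, 0.9, 0.8, 0.85, 0.88]
stop_epoch, best_epoch, best_score = early_stop(valid_scores, patience=3)
```

The model weights kept after training are those of the best epoch, not of the last epoch run.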
Successfully saved training history and parameters
Successfully saved model at model_scadam_balanced/model.zip
Check model quality#
[13]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
path_model = 'model_scadam_balanced')
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Successfully loaded model
Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities
[14]:
# Add prediction status to adata_test
scparadise.scnoah.pred_status(adata_test,
celltype = 'celltype_l1',
pred_celltype = 'pred_celltype_l1',
key_added = 'pred_status_l1')
scparadise.scnoah.pred_status(adata_test,
celltype = 'celltype_l2',
pred_celltype = 'pred_celltype_l2',
key_added = 'pred_status_l2')
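Conceptually, a prediction status is a per-cell comparison of the observed and predicted labels. The sketch below shows that comparison on plain lists. The status names 'correct' and 'incorrect' and the helper prediction_status are assumptions for illustration; the values written by scparadise.scnoah.pred_status may be named differently.

```python
def prediction_status(observed, predicted):
    """Label each cell by whether its predicted type matches the observed one."""
    return ["correct" if o == p else "incorrect" for o, p in zip(observed, predicted)]

# Hypothetical labels for five cells
observed = ["B cell", "plasma cell", "regulatory T cell", "fibroblast", "B cell"]
predicted = ["B cell", "plasma cell", "CD4-positive, alpha-beta T cell", "fibroblast", "B cell"]

status = prediction_status(observed, predicted)
fraction_correct = status.count("correct") / len(status)
```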
[15]:
# Order cell type colors
celltype_list = ['celltype_l1','celltype_l2']
for i in celltype_list:
    celltype = np.unique(adata_test.obs[i]).tolist()
    adata_test.obs[i] = pd.Categorical(
        values=adata_test.obs[i], categories=celltype, ordered=True
    )
    adata_test.obs['pred_' + i] = pd.Categorical(
        values=adata_test.obs['pred_' + i], categories=celltype, ordered=True
    )
The left column represents observed cell type annotations.
The central column represents predicted cell type annotations.
The right column represents prediction probabilities.
[16]:
# Visualise predicted cell types levels and prediction probabilities
sc.pl.embedding(adata_test,
color=[
'celltype_l1',
'pred_celltype_l1',
'prob_celltype_l1',
'celltype_l2',
'pred_celltype_l2',
'prob_celltype_l2'
],
basis = 'X_umap',
frameon = False,
cmap = 'viridis',
legend_loc = 'on data',
legend_fontsize = 7,
legend_fontoutline = 1,
ncols = 3,
wspace = 0.1,
hspace = 0.1)
[17]:
# Visualise prediction status
sc.pl.embedding(adata_test,
color=[
'pred_status_l1',
'pred_status_l2'
],
basis = 'X_umap',
frameon = False,
cmap = 'viridis',
legend_loc = 'right margin',
legend_fontsize = 7,
legend_fontoutline = 1,
ncols = 3,
wspace = 0.1,
hspace = 0.1)
The probability and prediction status analysis indicates issues with the model in annotating T cell subtypes at the celltype_l2 annotation level. However, the model does not have problems with annotating other cell types, as confirmed by the quality analysis presented below.
Comparison of observed and predicted cell type annotations#
To compare the observed and predicted cell types, we use report_classif_full and conf_matrix from the scNoah module. More information about the metrics in scparadise.scnoah.report_classif_full is available in the scParadise documentation. More information about the confusion matrix (scparadise.scnoah.conf_matrix) is available here.
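For orientation, the per-class metrics in the report can be derived from one-vs-rest confusion counts using the standard definitions. The sketch below uses hypothetical counts and a hypothetical class_metrics helper; it leaves out index balanced accuracy. The geometric mean is taken as the square root of sensitivity times specificity, which matches the tabulated values, for example sqrt(0.984 × 0.995) ≈ 0.989 for B cells.

```python
import math

def class_metrics(tp, fp, fn, tn):
    """Standard one-vs-rest metrics from confusion counts for a single class."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)          # also called sensitivity
    specificity = tn / (tn + fp)
    f1 = 2 * precision * recall / (precision + recall)
    gmean = math.sqrt(recall * specificity)
    return {
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "f1": f1,
        "geometric_mean": gmean,
    }

# Hypothetical counts for one cell type in a test set of 1000 cells
m = class_metrics(tp=90, fp=10, fn=10, tn=890)
```

Balanced accuracy is the unweighted mean of the per-class recall values, which is why it equals the macro-averaged recall in the tables.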
[18]:
# First annotation level (celltype_l1)
df_l1 = scparadise.scnoah.report_classif_full(adata_test,
celltype = 'celltype_l1',
pred_celltype = 'pred_celltype_l1',
ndigits = 3)
df_l1
[18]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B | 0.982 | 0.984 | 0.995 | 0.983 | 0.989 | 0.978 | 3969 |
| NK | 0.835 | 0.951 | 0.998 | 0.889 | 0.974 | 0.944 | 223 |
| T | 0.936 | 0.972 | 0.981 | 0.954 | 0.977 | 0.953 | 4057 |
| endothelial | 0.994 | 0.992 | 1.0 | 0.993 | 0.996 | 0.991 | 509 |
| epithelial | 0.989 | 0.951 | 0.992 | 0.97 | 0.972 | 0.94 | 7505 |
| myeloid | 0.915 | 0.976 | 0.995 | 0.945 | 0.986 | 0.97 | 924 |
| stromal | 0.971 | 0.997 | 0.998 | 0.984 | 0.998 | 0.995 | 1113 |
| macro avg | 0.946 | 0.975 | 0.994 | 0.96 | 0.984 | 0.967 | |
| weighted avg | 0.969 | 0.968 | 0.991 | 0.968 | 0.98 | 0.958 | |
| Accuracy | 0.968 | | | | | | |
| Balanced accuracy | 0.975 | | | | | | |
[19]:
sns.set(font_scale = 0.7)
plt.figure(figsize = (3, 3))
scparadise.scnoah.conf_matrix(adata_test,
celltype = 'celltype_l1',
pred_celltype = 'pred_celltype_l1',
annot_kws = {"size":5},
linewidths = 0.1, linecolor = 'black',
fmt = ".2f",
ndigits_metrics = 4,
vmin = 0, vmax = 1)
[20]:
# Second annotation level (celltype_l2)
df_l2 = scparadise.scnoah.report_classif_full(adata_test,
celltype = 'celltype_l2',
pred_celltype = 'pred_celltype_l2',
ndigits = 3)
df_l2
[20]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B cell | 0.965 | 0.969 | 0.996 | 0.967 | 0.982 | 0.962 | 2050 |
| CD4-positive, alpha-beta T cell | 0.892 | 0.719 | 0.992 | 0.796 | 0.845 | 0.694 | 1554 |
| CD8-positive, alpha-beta T cell | 0.886 | 0.943 | 0.987 | 0.914 | 0.965 | 0.927 | 1742 |
| endothelial cell | 0.989 | 0.991 | 1.0 | 0.99 | 0.995 | 0.99 | 457 |
| endothelial cell of lymphatic vessel | 1.000 | 0.981 | 1.0 | 0.99 | 0.99 | 0.979 | 52 |
| epithelial cell | 0.989 | 0.954 | 0.993 | 0.971 | 0.973 | 0.944 | 7505 |
| fibroblast | 0.928 | 0.979 | 0.997 | 0.953 | 0.988 | 0.975 | 676 |
| mast cell | 0.986 | 0.98 | 1.0 | 0.983 | 0.99 | 0.978 | 148 |
| mature alpha-beta T cell | 0.629 | 0.88 | 0.998 | 0.733 | 0.937 | 0.868 | 75 |
| mural cell | 0.968 | 0.973 | 0.999 | 0.97 | 0.986 | 0.969 | 437 |
| myeloid cell | 0.890 | 0.971 | 0.995 | 0.929 | 0.983 | 0.964 | 726 |
| natural killer cell | 0.839 | 0.955 | 0.998 | 0.893 | 0.976 | 0.949 | 223 |
| plasma cell | 0.983 | 0.988 | 0.998 | 0.985 | 0.993 | 0.985 | 1919 |
| plasmacytoid dendritic cell | 0.875 | 0.98 | 1.0 | 0.925 | 0.99 | 0.978 | 50 |
| regulatory T cell | 0.596 | 0.824 | 0.978 | 0.692 | 0.898 | 0.793 | 686 |
| macro avg | 0.894 | 0.939 | 0.995 | 0.913 | 0.966 | 0.93 | |
| weighted avg | 0.943 | 0.937 | 0.993 | 0.938 | 0.964 | 0.926 | |
| Accuracy | 0.937 | | | | | | |
| Balanced accuracy | 0.939 | | | | | | |
[21]:
sns.set(font_scale = 0.7)
plt.figure(figsize = (5, 3))
scparadise.scnoah.conf_matrix(adata_test,
celltype = 'celltype_l2',
pred_celltype = 'pred_celltype_l2',
annot_kws = {"size":5},
linewidths = 0.1, linecolor = 'black',
fmt = ".2f",
ndigits_metrics = 4,
vmin = 0, vmax = 1)
[22]:
# Save anndata with predicted annotations
adata_test.write_h5ad('adata_test_balanced_model.h5ad')
Imbalanced training#
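When training on an imbalanced dataset, use ‘balanced_accuracy’ as the evaluation metric for early stopping. As the training logs show, the last metric listed in eval_metric is the one used for early stopping.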
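A toy example shows why plain accuracy can mislead on imbalanced data. A classifier that always predicts the majority cell type scores high accuracy while missing every rare cell, whereas balanced accuracy (the mean per-class recall) exposes the failure. The labels and helper functions below are hypothetical.

```python
from collections import defaultdict

def accuracy(y_true, y_pred):
    """Fraction of cells whose predicted label matches the observed label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def balanced_accuracy(y_true, y_pred):
    """Unweighted mean of per-class recall."""
    hits = defaultdict(int)
    totals = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        totals[t] += 1
        hits[t] += int(t == p)
    return sum(hits[c] / totals[c] for c in totals) / len(totals)

# 95 epithelial cells and 5 mast cells; the classifier always predicts 'epithelial'
y_true = ["epithelial"] * 95 + ["mast"] * 5
y_pred = ["epithelial"] * 100

acc = accuracy(y_true, y_pred)            # 0.95
bal_acc = balanced_accuracy(y_true, y_pred)  # 0.5
```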
[23]:
# Train scadam model using the imbalanced adata_train dataset
scparadise.scadam.train(adata_train,
path = '', # path to save model
model_name = 'model_scadam_imbalanced', # folder name with model
celltype_l1 = 'celltype_l1', # First (less detailed) annotation level
celltype_l2 = 'celltype_l2', # Second (most detailed) annotation level
eval_metric = ['accuracy','balanced_accuracy']) # For an imbalanced training dataset, we recommend balanced_accuracy for early stopping
Successfully saved genes names for training model
Successfully saved dictionary of dataset annotations
Train dataset contains: 44414 cells, it is 90.0 % of input dataset
Test dataset contains: 4935 cells, it is 10.0 % of input dataset
Accelerator: cuda
Start training
epoch 0 | loss: 1.89983 | train_accuracy: 0.6205 | train_balanced_accuracy: 0.34227 | valid_accuracy: 0.6157 | valid_balanced_accuracy: 0.33701 | 0:00:02s
epoch 1 | loss: 0.93062 | train_accuracy: 0.76397 | train_balanced_accuracy: 0.5379 | valid_accuracy: 0.7613 | valid_balanced_accuracy: 0.54393 | 0:00:05s
epoch 2 | loss: 0.63447 | train_accuracy: 0.89784 | train_balanced_accuracy: 0.6967 | valid_accuracy: 0.89656 | valid_balanced_accuracy: 0.6991 | 0:00:08s
epoch 3 | loss: 0.45799 | train_accuracy: 0.92658 | train_balanced_accuracy: 0.79117 | valid_accuracy: 0.92533 | valid_balanced_accuracy: 0.7936 | 0:00:11s
epoch 4 | loss: 0.38553 | train_accuracy: 0.93557 | train_balanced_accuracy: 0.82882 | valid_accuracy: 0.93222 | valid_balanced_accuracy: 0.82831 | 0:00:13s
epoch 5 | loss: 0.34978 | train_accuracy: 0.94098 | train_balanced_accuracy: 0.83302 | valid_accuracy: 0.94032 | valid_balanced_accuracy: 0.8334 | 0:00:16s
epoch 6 | loss: 0.32005 | train_accuracy: 0.94931 | train_balanced_accuracy: 0.8714 | valid_accuracy: 0.94438 | valid_balanced_accuracy: 0.8629 | 0:00:19s
epoch 7 | loss: 0.30517 | train_accuracy: 0.96018 | train_balanced_accuracy: 0.90504 | valid_accuracy: 0.95623 | valid_balanced_accuracy: 0.8828 | 0:00:22s
epoch 8 | loss: 0.29713 | train_accuracy: 0.95839 | train_balanced_accuracy: 0.91825 | valid_accuracy: 0.95319 | valid_balanced_accuracy: 0.89912 | 0:00:25s
epoch 9 | loss: 0.27967 | train_accuracy: 0.96636 | train_balanced_accuracy: 0.92368 | valid_accuracy: 0.9613 | valid_balanced_accuracy: 0.91722 | 0:00:27s
epoch 10 | loss: 0.26854 | train_accuracy: 0.96917 | train_balanced_accuracy: 0.95137 | valid_accuracy: 0.9618 | valid_balanced_accuracy: 0.94174 | 0:00:30s
epoch 11 | loss: 0.2596 | train_accuracy: 0.9695 | train_balanced_accuracy: 0.96104 | valid_accuracy: 0.9614 | valid_balanced_accuracy: 0.94668 | 0:00:33s
epoch 12 | loss: 0.2611 | train_accuracy: 0.96822 | train_balanced_accuracy: 0.95705 | valid_accuracy: 0.9614 | valid_balanced_accuracy: 0.94765 | 0:00:36s
epoch 13 | loss: 0.24682 | train_accuracy: 0.97206 | train_balanced_accuracy: 0.95925 | valid_accuracy: 0.96596 | valid_balanced_accuracy: 0.94677 | 0:00:39s
epoch 14 | loss: 0.24491 | train_accuracy: 0.97435 | train_balanced_accuracy: 0.96788 | valid_accuracy: 0.96535 | valid_balanced_accuracy: 0.95488 | 0:00:42s
epoch 15 | loss: 0.23757 | train_accuracy: 0.97454 | train_balanced_accuracy: 0.96045 | valid_accuracy: 0.96596 | valid_balanced_accuracy: 0.94597 | 0:00:45s
epoch 16 | loss: 0.23232 | train_accuracy: 0.97702 | train_balanced_accuracy: 0.96977 | valid_accuracy: 0.96657 | valid_balanced_accuracy: 0.9553 | 0:00:48s
epoch 17 | loss: 0.22917 | train_accuracy: 0.97673 | train_balanced_accuracy: 0.97094 | valid_accuracy: 0.96636 | valid_balanced_accuracy: 0.95947 | 0:00:51s
epoch 18 | loss: 0.22392 | train_accuracy: 0.97735 | train_balanced_accuracy: 0.97009 | valid_accuracy: 0.96575 | valid_balanced_accuracy: 0.95331 | 0:00:53s
epoch 19 | loss: 0.21909 | train_accuracy: 0.9789 | train_balanced_accuracy: 0.97495 | valid_accuracy: 0.96738 | valid_balanced_accuracy: 0.95752 | 0:00:56s
epoch 20 | loss: 0.22133 | train_accuracy: 0.97781 | train_balanced_accuracy: 0.9717 | valid_accuracy: 0.96525 | valid_balanced_accuracy: 0.95094 | 0:00:59s
epoch 21 | loss: 0.21602 | train_accuracy: 0.9779 | train_balanced_accuracy: 0.97577 | valid_accuracy: 0.96586 | valid_balanced_accuracy: 0.95856 | 0:01:01s
epoch 22 | loss: 0.20957 | train_accuracy: 0.97917 | train_balanced_accuracy: 0.97793 | valid_accuracy: 0.96748 | valid_balanced_accuracy: 0.96025 | 0:01:04s
epoch 23 | loss: 0.21309 | train_accuracy: 0.9806 | train_balanced_accuracy: 0.97859 | valid_accuracy: 0.96636 | valid_balanced_accuracy: 0.95804 | 0:01:07s
epoch 24 | loss: 0.2073 | train_accuracy: 0.97963 | train_balanced_accuracy: 0.97953 | valid_accuracy: 0.96768 | valid_balanced_accuracy: 0.96242 | 0:01:10s
epoch 25 | loss: 0.20919 | train_accuracy: 0.98106 | train_balanced_accuracy: 0.9762 | valid_accuracy: 0.96839 | valid_balanced_accuracy: 0.95242 | 0:01:13s
epoch 26 | loss: 0.20778 | train_accuracy: 0.97898 | train_balanced_accuracy: 0.9773 | valid_accuracy: 0.96575 | valid_balanced_accuracy: 0.95438 | 0:01:16s
epoch 27 | loss: 0.20592 | train_accuracy: 0.98023 | train_balanced_accuracy: 0.97647 | valid_accuracy: 0.96798 | valid_balanced_accuracy: 0.95622 | 0:01:19s
epoch 28 | loss: 0.20345 | train_accuracy: 0.98064 | train_balanced_accuracy: 0.97579 | valid_accuracy: 0.96758 | valid_balanced_accuracy: 0.95307 | 0:01:21s
epoch 29 | loss: 0.20338 | train_accuracy: 0.98195 | train_balanced_accuracy: 0.97994 | valid_accuracy: 0.9693 | valid_balanced_accuracy: 0.96008 | 0:01:24s
epoch 30 | loss: 0.19875 | train_accuracy: 0.98133 | train_balanced_accuracy: 0.97684 | valid_accuracy: 0.96727 | valid_balanced_accuracy: 0.95396 | 0:01:27s
epoch 31 | loss: 0.19886 | train_accuracy: 0.98129 | train_balanced_accuracy: 0.98089 | valid_accuracy: 0.96768 | valid_balanced_accuracy: 0.9582 | 0:01:29s
epoch 32 | loss: 0.19664 | train_accuracy: 0.98207 | train_balanced_accuracy: 0.98338 | valid_accuracy: 0.96778 | valid_balanced_accuracy: 0.96199 | 0:01:32s
epoch 33 | loss: 0.19808 | train_accuracy: 0.98324 | train_balanced_accuracy: 0.98151 | valid_accuracy: 0.96768 | valid_balanced_accuracy: 0.95489 | 0:01:35s
epoch 34 | loss: 0.19552 | train_accuracy: 0.98102 | train_balanced_accuracy: 0.98076 | valid_accuracy: 0.96555 | valid_balanced_accuracy: 0.95149 | 0:01:38s
Early stopping occurred at epoch 34 with best_epoch = 24 and best_valid_balanced_accuracy = 0.96242
Successfully saved training history and parameters
Successfully saved model at model_scadam_imbalanced/model.zip
Check model quality#
[24]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
path_model = 'model_scadam_imbalanced')
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Successfully loaded model
Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities
[25]:
# Add prediction status to adata_test
scparadise.scnoah.pred_status(adata_test,
celltype = 'celltype_l1',
pred_celltype = 'pred_celltype_l1',
key_added = 'pred_status_l1')
scparadise.scnoah.pred_status(adata_test,
celltype = 'celltype_l2',
pred_celltype = 'pred_celltype_l2',
key_added = 'pred_status_l2')
[26]:
# Order cell type colors
celltype_list = ['celltype_l1','celltype_l2']
for i in celltype_list:
    celltype = np.unique(adata_test.obs[i]).tolist()
    adata_test.obs[i] = pd.Categorical(
        values=adata_test.obs[i], categories=celltype, ordered=True
    )
    adata_test.obs['pred_' + i] = pd.Categorical(
        values=adata_test.obs['pred_' + i], categories=celltype, ordered=True
    )
[27]:
# Visualise predicted cell types levels and prediction probabilities
sc.pl.embedding(adata_test,
color=[
'celltype_l1',
'pred_celltype_l1',
'prob_celltype_l1',
'pred_status_l1',
'celltype_l2',
'pred_celltype_l2',
'prob_celltype_l2',
'pred_status_l2'
],
basis = 'X_umap',
frameon = False,
cmap = 'viridis',
legend_loc = 'on data',
legend_fontsize = 7,
legend_fontoutline = 1,
ncols = 4,
wspace = 0,
hspace = 0.1)
[28]:
# Visualise prediction status
sc.pl.embedding(adata_test,
color=[
'pred_status_l1',
'pred_status_l2'
],
basis = 'X_umap',
frameon = False,
cmap = 'viridis',
legend_loc = 'right margin',
legend_fontsize = 7,
legend_fontoutline = 1,
ncols = 3,
wspace = 0.1,
hspace = 0.1)
The probability and prediction status analysis indicates issues with the model in annotating T cell subtypes at the celltype_l2 annotation level. However, the model does not have problems with annotating other cell types, as confirmed by the quality analysis presented below.
Comparison of observed and predicted cell type annotations#
To compare the observed and predicted cell types, we use report_classif_full and conf_matrix from the scNoah module. More information about the metrics in scparadise.scnoah.report_classif_full is available in the scParadise documentation. More information about the confusion matrix (scparadise.scnoah.conf_matrix) is available here.
[29]:
# First annotation level (celltype_l1)
df_l1 = scparadise.scnoah.report_classif_full(adata_test,
celltype='celltype_l1',
pred_celltype='pred_celltype_l1')
df_l1
[29]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B | 0.9881 | 0.9829 | 0.9967 | 0.9855 | 0.9898 | 0.9783 | 3969 |
| NK | 0.7220 | 0.9552 | 0.9955 | 0.8224 | 0.9751 | 0.947 | 223 |
| T | 0.9732 | 0.9675 | 0.9924 | 0.9703 | 0.9799 | 0.9577 | 4057 |
| endothelial | 0.9980 | 0.9784 | 0.9999 | 0.9881 | 0.9891 | 0.9762 | 509 |
| epithelial | 0.9888 | 0.9859 | 0.9922 | 0.9873 | 0.989 | 0.9776 | 7505 |
| myeloid | 0.9825 | 0.9697 | 0.9991 | 0.976 | 0.9843 | 0.966 | 924 |
| stromal | 0.9832 | 0.9982 | 0.9989 | 0.9906 | 0.9985 | 0.997 | 1113 |
| macro avg | 0.9480 | 0.9768 | 0.9964 | 0.96 | 0.9865 | 0.9714 | |
| weighted avg | 0.9815 | 0.9805 | 0.9942 | 0.9808 | 0.9873 | 0.9735 | |
| Accuracy | 0.9805 | | | | | | |
| Balanced accuracy | 0.9768 | | | | | | |
[30]:
sns.set(font_scale = 0.7)
plt.figure(figsize = (3, 3))
scparadise.scnoah.conf_matrix(adata_test,
celltype = 'celltype_l1',
pred_celltype = 'pred_celltype_l1',
annot_kws = {"size":5},
linewidths = 0.1, linecolor = 'black',
fmt = ".2f",
ndigits_metrics = 4,
vmin = 0, vmax = 1)
[31]:
# Second annotation level (celltype_l2)
df_l2 = scparadise.scnoah.report_classif_full(adata_test,
celltype='celltype_l2',
pred_celltype='pred_celltype_l2')
df_l2
[31]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| B cell | 0.9819 | 0.9766 | 0.9977 | 0.9792 | 0.9871 | 0.9723 | 2050 |
| CD4-positive, alpha-beta T cell | 0.8686 | 0.8166 | 0.9885 | 0.8418 | 0.8985 | 0.7934 | 1554 |
| CD8-positive, alpha-beta T cell | 0.9166 | 0.9208 | 0.9912 | 0.9187 | 0.9553 | 0.9062 | 1742 |
| endothelial cell | 0.9956 | 0.9847 | 0.9999 | 0.9901 | 0.9923 | 0.9831 | 457 |
| endothelial cell of lymphatic vessel | 0.9273 | 0.9808 | 0.9998 | 0.9533 | 0.9902 | 0.9787 | 52 |
| epithelial cell | 0.9884 | 0.986 | 0.9919 | 0.9872 | 0.989 | 0.9775 | 7505 |
| fibroblast | 0.9844 | 0.9364 | 0.9994 | 0.9598 | 0.9674 | 0.93 | 676 |
| mast cell | 0.9728 | 0.9662 | 0.9998 | 0.9695 | 0.9829 | 0.9628 | 148 |
| mature alpha-beta T cell | 0.7901 | 0.8533 | 0.9991 | 0.8205 | 0.9233 | 0.8401 | 75 |
| mural cell | 0.8979 | 0.9863 | 0.9973 | 0.94 | 0.9917 | 0.9825 | 437 |
| myeloid cell | 0.9777 | 0.9656 | 0.9991 | 0.9716 | 0.9822 | 0.9615 | 726 |
| natural killer cell | 0.7245 | 0.9552 | 0.9955 | 0.824 | 0.9751 | 0.947 | 223 |
| plasma cell | 0.9854 | 0.9859 | 0.9983 | 0.9857 | 0.9921 | 0.983 | 1919 |
| plasmacytoid dendritic cell | 0.8704 | 0.94 | 0.9996 | 0.9038 | 0.9694 | 0.934 | 50 |
| regulatory T cell | 0.7528 | 0.7901 | 0.9899 | 0.771 | 0.8844 | 0.7665 | 686 |
| macro avg | 0.9090 | 0.9363 | 0.9965 | 0.9211 | 0.9654 | 0.9279 | |
| weighted avg | 0.9543 | 0.9531 | 0.9939 | 0.9533 | 0.9728 | 0.9439 | |
| Accuracy | 0.9531 | | | | | | |
| Balanced accuracy | 0.9363 | | | | | | |
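To see how the two models differ on the difficult T cell subtypes, the precision and recall values from the two celltype_l2 reports can be compared directly. The sketch below copies those values by hand from the tables; in a live session they would be read from the df_l2 data frames instead. For the two rarer subtypes, balancing raised recall and lowered precision, in line with the trade-off described in the Balanced training section, while CD4-positive T cells moved the other way.

```python
# (precision, recall) at celltype_l2, copied from the report tables
balanced = {
    "CD4-positive, alpha-beta T cell": (0.892, 0.719),
    "mature alpha-beta T cell": (0.629, 0.88),
    "regulatory T cell": (0.596, 0.824),
}
imbalanced = {
    "CD4-positive, alpha-beta T cell": (0.8686, 0.8166),
    "mature alpha-beta T cell": (0.7901, 0.8533),
    "regulatory T cell": (0.7528, 0.7901),
}

# Difference (balanced minus imbalanced) in precision and recall per subtype
delta = {
    name: (
        round(balanced[name][0] - imbalanced[name][0], 3),
        round(balanced[name][1] - imbalanced[name][1], 3),
    )
    for name in balanced
}

for name, (d_precision, d_recall) in delta.items():
    print(f"{name}: precision {d_precision:+.3f}, recall {d_recall:+.3f}")
```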
[32]:
sns.set(font_scale = 0.7)
plt.figure(figsize = (5, 3))
scparadise.scnoah.conf_matrix(adata_test,
celltype = 'celltype_l2',
pred_celltype = 'pred_celltype_l2',
annot_kws = {"size":5},
linewidths = 0.1, linecolor = 'black',
fmt = ".2f",
ndigits_metrics = 4,
vmin = 0, vmax = 1)
[33]:
# Save anndata with predicted annotations
adata_test.write_h5ad('adata_test_imbalanced_model.h5ad')
[34]:
import session_info
session_info.show()
[34]:
Session information
----- anndata 0.10.8 matplotlib 3.9.2 numpy 1.25.0 pandas 2.2.3 scanpy 1.10.3 scparadise 0.3.2_beta seaborn 0.13.2 session_info 1.0.0 -----
Modules imported as dependencies
PIL 10.4.0 anyio NA arrow 1.3.0 asciitree NA asttokens NA attr 24.2.0 attrs 24.2.0 awkward 2.7.1 awkward_cpp NA babel 2.16.0 backports NA certifi 2024.08.30 cffi 1.17.1 charset_normalizer 3.3.2 cloudpickle 3.1.0 colorlog NA comm 0.2.2 cycler 0.12.1 cython_runtime NA dask 2024.8.0 dateutil 2.9.0.post0 debugpy 1.8.6 decorator 5.1.1 defusedxml 0.7.1 exceptiongroup 1.2.2 executing 2.1.0 fastjsonschema NA fqdn NA fsspec 2023.6.0 h5py 3.12.1 idna 3.10 igraph 0.11.6 imblearn 0.12.3 importlib_metadata NA importlib_resources NA ipykernel 6.29.5 isoduration NA jaraco NA jedi 0.19.1 jinja2 3.1.4 joblib 1.4.2 json5 0.9.25 jsonpointer 3.0.0 jsonschema 4.23.0 jsonschema_specifications NA jupyter_events 0.10.0 jupyter_server 2.14.2 jupyterlab_server 2.27.3 kiwisolver 1.4.7 legacy_api_wrap NA leidenalg 0.10.2 llvmlite 0.43.0 markupsafe 2.1.5 matplotlib_inline 0.1.7 more_itertools 10.5.0 mpl_toolkits NA mpmath 1.3.0 msgpack 1.1.0 mudata 0.2.4 muon 0.1.6 natsort 8.4.0 nbformat 5.10.4 numba 0.60.0 numcodecs 0.12.1 optuna 4.0.0 overrides NA packaging 24.1 parso 0.8.4 patsy 0.5.6 pexpect 4.9.0 platformdirs 4.3.6 plotly 5.24.1 prometheus_client NA prompt_toolkit 3.0.48 psutil 6.0.0 ptyprocess 0.7.0 pure_eval 0.2.3 pyarrow 18.1.0 pycparser 2.22 pydev_ipython NA pydevconsole NA pydevd 3.1.0 pydevd_file_utils NA pydevd_plugins NA pydevd_tracing NA pydot 3.0.3 pygments 2.18.0 pynndescent 0.5.13 pyparsing 3.1.4 pythonjsonlogger NA pytorch_tabnet NA pytz 2024.2 referencing NA requests 2.32.3 rfc3339_validator 0.1.4 rfc3986_validator 0.1.1 rich NA rpds NA scipy 1.13.1 send2trash NA setuptools 75.1.0 setuptools_scm NA shap 0.46.0 six 1.16.0 sklearn 1.5.2 skmisc 0.3.1 slicer NA sniffio 1.3.1 stack_data 0.6.3 statsmodels 0.14.3 sympy 1.13.3 tblib 3.0.0 texttable 1.7.0 threadpoolctl 3.5.0 tlz 0.12.1 tomli 2.0.1 toolz 0.12.1 torch 2.4.1+cu121 torchgen NA tornado 6.4.1 tqdm 4.66.5 traitlets 5.14.3 triton 3.0.0 typing_extensions NA umap 0.5.6 uri_template NA urllib3 1.26.20 wcwidth 0.2.13 webcolors 24.8.0 websocket 1.8.0 yaml 6.0.2 zarr 2.18.2 zipp NA zmq 26.2.0 zoneinfo NA
----- IPython 8.18.1 jupyter_client 8.6.3 jupyter_core 5.7.2 jupyterlab 4.2.5 ----- Python 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0] Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35 ----- Session information updated at 2025-01-16 13:27