Managing large and complex training datasets#
Large, complex datasets are difficult to use for model training because of computational limits: simply opening such a dataset can require an enormous amount of RAM.
The time needed to select genes for model training also grows with dataset size.
At the same time, training on more data does not always improve model quality significantly.
Here, we present a method for training models on large, complex datasets with many donors despite these computational limits.
[1]:
# Python packages
import warnings
warnings.simplefilter('ignore')
import scanpy as sc
import scparadise
import numpy as np
import pandas as pd
import os
sc.set_figure_params(dpi = 120)
[2]:
# Create folder to save files
# single-nucleus RNA-seq (snRNA-seq) of human retina
# exist_ok = True avoids an error if the folder already exists
os.makedirs('snRNAseq_human_retina', exist_ok = True)
[3]:
# Download CELLxGENE dataset (snRNA-seq of human retina):
# https://cellxgene.cziscience.com/collections/4c6eaf5c-6d57-4c76-b1e9-60df8c655f1e
!wget https://datasets.cellxgene.cziscience.com/2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad
--2025-02-08 13:49:05-- https://datasets.cellxgene.cziscience.com/2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad
Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 52.85.49.24, 52.85.49.28, 52.85.49.17, ...
Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|52.85.49.24|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 37946797973 (35G) [binary/octet-stream]
Saving to: ‘2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad’
2e910e62-7eaf-4c06- 100%[===================>] 35.34G 25.1MB/s in 27m 20s
2025-02-08 14:13:58 (22.1 MB/s) - ‘2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad’ saved [37946797973/37946797973]
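Because the download is about 35 GB, it can be worth confirming that the file on disk is complete before opening it. The following is a minimal sketch, assuming the expected size is the `Length` value reported by wget above; `is_complete` is a hypothetical helper and not part of scparadise.

```python
import os

# Size in bytes reported by wget in the download log above
EXPECTED_BYTES = 37946797973

def is_complete(path, expected_bytes=EXPECTED_BYTES):
    """Return True if the file exists and has exactly the expected size."""
    return os.path.isfile(path) and os.path.getsize(path) == expected_bytes
```

Calling `is_complete('2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad')` before the next steps would catch an interrupted download early. A size check does not detect corrupted content; it only detects truncation.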
Train scAdam model using a fraction of the dataset#
The entire dataset contains 3,177,310 cells and 36,406 genes (35.34 GB), which is too large to open on a standard computer. Selecting genes for model training on a dataset of this size also requires significant computational power and time.
Therefore, the scParadise team recommends extracting a small portion of the dataset for the next steps.
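To make the idea of extracting a portion concrete, here is a minimal sketch of one plausible approach: a reproducible random sample that keeps each cell type's share of the data. It is an illustration on toy labels only. How `get_frac` actually samples is not described here, and `sample_cells` and its arguments are hypothetical names.

```python
import random
from collections import defaultdict

def sample_cells(cell_types, n_cells, seed=0):
    """Pick about n_cells indices, keeping each cell type's share of the data.

    cell_types: list with one label per cell (toy input).
    Every cell type keeps at least one cell so rare types are not lost.
    """
    rng = random.Random(seed)
    by_type = defaultdict(list)
    for idx, label in enumerate(cell_types):
        by_type[label].append(idx)

    fraction = n_cells / len(cell_types)
    chosen = []
    for label in sorted(by_type):
        indices = by_type[label]
        k = min(len(indices), max(1, round(len(indices) * fraction)))
        chosen.extend(rng.sample(indices, k))
    return sorted(chosen)

# Toy labels: 900 rods, 90 cones, 10 RPE cells
labels = ['rod'] * 900 + ['cone'] * 90 + ['RPE'] * 10
subset = sample_cells(labels, n_cells=100, seed=0)
```

Fixing the seed makes the subset reproducible, which is the role `random_state` plays in the cells of this tutorial.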
[4]:
# Obtain 25000 cells randomly
adata_fraction = scparadise.scnoah.get_frac(path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
fraction = 25000,
path_save = 'snRNAseq_human_retina',
celltype = 'cell_type',
random_state = 0)
[5]:
# Get raw counts from adata_fraction.raw
adata_fraction = adata_fraction.raw.to_adata()
# Replace variable names with gene names
adata_fraction.var.set_index('feature_name', inplace = True)
adata_fraction.var_names_make_unique()
# Normalize data
sc.pp.normalize_total(adata_fraction, target_sum = None)
sc.pp.log1p(adata_fraction)
adata_fraction.raw = adata_fraction
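The normalization step above can be illustrated on a toy count matrix. This is a sketch in plain Python of what `sc.pp.normalize_total` with `target_sum = None` followed by `sc.pp.log1p` computes, assuming the scanpy default of scaling each cell to the median total count; `normalize_log1p` is a hypothetical helper, not a scanpy function.

```python
import math
import statistics

def normalize_log1p(counts):
    """Scale each cell to the median library size, then apply log(1 + x).

    counts: list of rows, one row of raw counts per cell (toy input).
    """
    totals = [sum(row) for row in counts]
    # Median total count over cells that contain at least one count
    target = statistics.median(t for t in totals if t > 0)
    normalized = []
    for row, total in zip(counts, totals):
        scale = target / total if total > 0 else 0.0
        normalized.append([math.log1p(value * scale) for value in row])
    return normalized

# Three toy cells with library sizes 4, 40 and 4 (median = 4)
toy = [[1, 3], [10, 30], [2, 2]]
result = normalize_log1p(toy)
```

After scaling, the second cell has the same profile as the first, which shows why normalization removes differences in sequencing depth before marker genes are compared.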
[6]:
# Find genes for model training (marker genes of cell types)
lst_genes = []
annotations = ['majorclass', 'cell_type'] # annotation levels
for annotation in annotations:
    sc.tl.rank_genes_groups(adata_fraction,
                            groupby = annotation,
                            method = 't-test_overestim_var', pts = True)
    # Filter marker genes of cell types
    sc.tl.filter_rank_genes_groups(adata_fraction,
                                   min_fold_change = 1.0,
                                   min_in_group_fraction = 0.4,
                                   key_added = 'filtered_rank_genes_groups')
    # Create list of genes for model training
    for i in adata_fraction.obs[annotation].unique():
        df = sc.get.rank_genes_groups_df(adata_fraction, group = i, key = 'filtered_rank_genes_groups', pval_cutoff = 0.05)
        # Ratio of the fraction of expressing cells in the group to that in the reference
        df['pts_comparison'] = df['pct_nz_group']/df['pct_nz_reference']
        # Top 20 genes by log fold change and top 20 genes by expression-fraction ratio
        lst_genes.extend(df.sort_values(by = 'logfoldchanges', ascending = False).head(20)['names'].tolist())
        lst_genes.extend(df.sort_values(by = 'pts_comparison', ascending = False).head(20)['names'].tolist())
# Remove duplicates
lst_genes = np.unique(lst_genes).tolist()
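The selection rule used in the cell above (top genes by log fold change plus top genes by the ratio of expressing-cell fractions) can be seen more easily on toy data. The sketch below reuses the column names `names`, `logfoldchanges`, `pct_nz_group` and `pct_nz_reference`; the gene names and values are invented for illustration, and `pick_marker_genes` is a hypothetical helper.

```python
def pick_marker_genes(rows, top_n=20):
    """Return the union of the top genes by fold change and by detection ratio.

    rows: list of dicts with keys 'names', 'logfoldchanges',
    'pct_nz_group' and 'pct_nz_reference' (toy values).
    """
    def detection_ratio(row):
        # Guard against division by zero when no reference cell expresses the gene
        ref = row['pct_nz_reference']
        return row['pct_nz_group'] / ref if ref > 0 else float('inf')

    by_lfc = sorted(rows, key=lambda r: r['logfoldchanges'], reverse=True)[:top_n]
    by_ratio = sorted(rows, key=detection_ratio, reverse=True)[:top_n]
    return sorted({r['names'] for r in by_lfc + by_ratio})

toy = [
    {'names': 'geneA', 'logfoldchanges': 6.0, 'pct_nz_group': 0.95, 'pct_nz_reference': 0.50},
    {'names': 'geneB', 'logfoldchanges': 4.0, 'pct_nz_group': 0.90, 'pct_nz_reference': 0.10},
    {'names': 'geneC', 'logfoldchanges': 1.1, 'pct_nz_group': 0.99, 'pct_nz_reference': 0.98},
]
markers = pick_marker_genes(toy, top_n=1)
```

With `top_n=1`, `geneA` is selected for its fold change and `geneB` for its detection ratio, while the broadly expressed `geneC` is not selected. Combining both criteria favors genes that are both strongly and specifically expressed in a cell type.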
[7]:
# Subset genes for model training
adata_fraction = adata_fraction[:, lst_genes]
[8]:
# Alternative way to select genes for model training
# sc.pp.highly_variable_genes(adata_fraction,
# n_top_genes = 1000,
# subset = True)
# lst_genes = adata_fraction.var_names.tolist()
[9]:
# Balance dataset (undersample abundant cell types, oversample rare ones)
adata_balanced = scparadise.scnoah.balance(adata_fraction,
                                           celltype_l1 = annotations[0], # majorclass
                                           celltype_l2 = annotations[1], # cell_type
                                           sample = 'donor_id')
Successfully undersampled cell types: retinal rod cell, GABAergic amacrine cell, Mueller cell, OFF midget ganglion cell, ON midget ganglion cell, flat midget bipolar cell, glycinergic amacrine cell, retinal cone cell, invaginating midget bipolar cell
Successfully oversampled cell types: rod bipolar cell, H1 horizontal cell, diffuse bipolar 2 cell, amacrine cell, retinal bipolar neuron, diffuse bipolar 1 cell, diffuse bipolar 4 cell, diffuse bipolar 3b cell, retinal ganglion cell, giant bipolar cell, diffuse bipolar 3a cell, starburst amacrine cell, diffuse bipolar 6 cell, OFFx cell, astrocyte, H2 horizontal cell, OFF parasol ganglion cell, S cone cell, ON parasol ganglion cell, microglial cell, ON-blue cone bipolar cell, retinal pigment epithelial cell
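The output above reports which cell types were undersampled and which were oversampled. As a rough illustration of that idea, the sketch below resamples toy labels so that every class ends up with the same number of cells. It is not the implementation of `scparadise.scnoah.balance`: the target size (mean class size) is an assumption, the `sample = 'donor_id'` argument is ignored, and `balance_indices` is a hypothetical helper.

```python
import random
from collections import defaultdict

def balance_indices(labels, seed=0):
    """Resample so that every class ends up with the same number of cells.

    The target is the mean class size (an assumption for this sketch).
    Larger classes are undersampled without replacement; smaller classes
    are oversampled with replacement.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    target = round(len(labels) / len(by_class))
    balanced = []
    for label in sorted(by_class):
        indices = by_class[label]
        if len(indices) >= target:
            balanced.extend(rng.sample(indices, target))
        else:
            balanced.extend(rng.choices(indices, k=target))
    return balanced

# Toy labels: 80 rods, 15 cones, 5 RPE cells
labels = ['rod'] * 80 + ['cone'] * 15 + ['RPE'] * 5
balanced = balance_indices(labels)
```

Oversampling with replacement duplicates cells of rare types, so the balanced dataset contains repeated observations for those types.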
[10]:
# Train scAdam model using the balanced dataset (adata_balanced)
scparadise.scadam.train(adata_balanced,
path = 'snRNAseq_human_retina', # path to save model
model_name = 'model_scAdam', # folder name with model
celltype_l1 = 'celltype_l1', # previously: majorclass
celltype_l2 = 'celltype_l2', # previously: cell_type
eval_metric = ['balanced_accuracy', 'accuracy'])
Successfully saved genes names for training model
Successfully saved dictionary of dataset annotations
Train dataset contains: 22515 cells, it is 90.0 % of input dataset
Test dataset contains: 2502 cells, it is 10.0 % of input dataset
Accelerator: cuda
Start training
epoch 0 | loss: 2.81661 | train_balanced_accuracy: 0.07509 | train_accuracy: 0.23478 | valid_balanced_accuracy: 0.07484 | valid_accuracy: 0.23441 | 0:00:01s
epoch 1 | loss: 2.15053 | train_balanced_accuracy: 0.29301 | train_accuracy: 0.46274 | valid_balanced_accuracy: 0.29203 | valid_accuracy: 0.46203 | 0:00:02s
epoch 2 | loss: 1.40403 | train_balanced_accuracy: 0.53809 | train_accuracy: 0.6624 | valid_balanced_accuracy: 0.54454 | valid_accuracy: 0.66587 | 0:00:03s
epoch 3 | loss: 0.89481 | train_balanced_accuracy: 0.75518 | train_accuracy: 0.80697 | valid_balanced_accuracy: 0.76594 | valid_accuracy: 0.81095 | 0:00:04s
epoch 4 | loss: 0.66516 | train_balanced_accuracy: 0.81854 | train_accuracy: 0.85328 | valid_balanced_accuracy: 0.8175 | valid_accuracy: 0.84972 | 0:00:06s
epoch 5 | loss: 0.57448 | train_balanced_accuracy: 0.89062 | train_accuracy: 0.90517 | valid_balanced_accuracy: 0.88043 | valid_accuracy: 0.89728 | 0:00:07s
epoch 6 | loss: 0.50396 | train_balanced_accuracy: 0.92367 | train_accuracy: 0.9328 | valid_balanced_accuracy: 0.91688 | valid_accuracy: 0.92666 | 0:00:08s
epoch 7 | loss: 0.44609 | train_balanced_accuracy: 0.93663 | train_accuracy: 0.94126 | valid_balanced_accuracy: 0.92773 | valid_accuracy: 0.93385 | 0:00:09s
epoch 8 | loss: 0.38922 | train_balanced_accuracy: 0.94813 | train_accuracy: 0.95143 | valid_balanced_accuracy: 0.93811 | valid_accuracy: 0.94305 | 0:00:10s
epoch 9 | loss: 0.38246 | train_balanced_accuracy: 0.95886 | train_accuracy: 0.96103 | valid_balanced_accuracy: 0.94843 | valid_accuracy: 0.95324 | 0:00:11s
epoch 10 | loss: 0.35179 | train_balanced_accuracy: 0.96238 | train_accuracy: 0.96513 | valid_balanced_accuracy: 0.95295 | valid_accuracy: 0.95663 | 0:00:12s
epoch 11 | loss: 0.33521 | train_balanced_accuracy: 0.96765 | train_accuracy: 0.96915 | valid_balanced_accuracy: 0.95609 | valid_accuracy: 0.96023 | 0:00:14s
epoch 12 | loss: 0.32833 | train_balanced_accuracy: 0.9687 | train_accuracy: 0.96975 | valid_balanced_accuracy: 0.9598 | valid_accuracy: 0.96343 | 0:00:15s
epoch 13 | loss: 0.30299 | train_balanced_accuracy: 0.97468 | train_accuracy: 0.97524 | valid_balanced_accuracy: 0.96538 | valid_accuracy: 0.96783 | 0:00:16s
epoch 14 | loss: 0.29462 | train_balanced_accuracy: 0.97203 | train_accuracy: 0.97264 | valid_balanced_accuracy: 0.96418 | valid_accuracy: 0.96583 | 0:00:17s
epoch 15 | loss: 0.28393 | train_balanced_accuracy: 0.97724 | train_accuracy: 0.97742 | valid_balanced_accuracy: 0.97102 | valid_accuracy: 0.97182 | 0:00:18s
epoch 16 | loss: 0.28427 | train_balanced_accuracy: 0.98081 | train_accuracy: 0.98148 | valid_balanced_accuracy: 0.97726 | valid_accuracy: 0.97862 | 0:00:19s
epoch 17 | loss: 0.27928 | train_balanced_accuracy: 0.98125 | train_accuracy: 0.98123 | valid_balanced_accuracy: 0.97695 | valid_accuracy: 0.97722 | 0:00:20s
epoch 18 | loss: 0.2769 | train_balanced_accuracy: 0.98287 | train_accuracy: 0.98314 | valid_balanced_accuracy: 0.97864 | valid_accuracy: 0.97962 | 0:00:21s
epoch 19 | loss: 0.27064 | train_balanced_accuracy: 0.98377 | train_accuracy: 0.98417 | valid_balanced_accuracy: 0.97769 | valid_accuracy: 0.97922 | 0:00:22s
epoch 20 | loss: 0.26606 | train_balanced_accuracy: 0.98612 | train_accuracy: 0.98619 | valid_balanced_accuracy: 0.98071 | valid_accuracy: 0.98201 | 0:00:24s
epoch 21 | loss: 0.25256 | train_balanced_accuracy: 0.98787 | train_accuracy: 0.98794 | valid_balanced_accuracy: 0.97919 | valid_accuracy: 0.98062 | 0:00:25s
epoch 22 | loss: 0.25767 | train_balanced_accuracy: 0.98873 | train_accuracy: 0.98903 | valid_balanced_accuracy: 0.98121 | valid_accuracy: 0.98161 | 0:00:26s
epoch 23 | loss: 0.24242 | train_balanced_accuracy: 0.98856 | train_accuracy: 0.98867 | valid_balanced_accuracy: 0.98525 | valid_accuracy: 0.98581 | 0:00:27s
epoch 24 | loss: 0.23697 | train_balanced_accuracy: 0.9894 | train_accuracy: 0.98943 | valid_balanced_accuracy: 0.98418 | valid_accuracy: 0.98501 | 0:00:28s
epoch 25 | loss: 0.23557 | train_balanced_accuracy: 0.98993 | train_accuracy: 0.98998 | valid_balanced_accuracy: 0.98402 | valid_accuracy: 0.98541 | 0:00:29s
epoch 26 | loss: 0.22787 | train_balanced_accuracy: 0.98896 | train_accuracy: 0.98901 | valid_balanced_accuracy: 0.98383 | valid_accuracy: 0.98481 | 0:00:30s
epoch 27 | loss: 0.23513 | train_balanced_accuracy: 0.99035 | train_accuracy: 0.99038 | valid_balanced_accuracy: 0.98575 | valid_accuracy: 0.98621 | 0:00:32s
epoch 28 | loss: 0.23008 | train_balanced_accuracy: 0.98933 | train_accuracy: 0.9893 | valid_balanced_accuracy: 0.98308 | valid_accuracy: 0.98341 | 0:00:33s
epoch 29 | loss: 0.23302 | train_balanced_accuracy: 0.99037 | train_accuracy: 0.9903 | valid_balanced_accuracy: 0.98438 | valid_accuracy: 0.98521 | 0:00:34s
epoch 30 | loss: 0.22559 | train_balanced_accuracy: 0.98943 | train_accuracy: 0.98958 | valid_balanced_accuracy: 0.98176 | valid_accuracy: 0.98321 | 0:00:35s
epoch 31 | loss: 0.22367 | train_balanced_accuracy: 0.9913 | train_accuracy: 0.99132 | valid_balanced_accuracy: 0.98507 | valid_accuracy: 0.98541 | 0:00:36s
epoch 32 | loss: 0.21816 | train_balanced_accuracy: 0.99195 | train_accuracy: 0.99201 | valid_balanced_accuracy: 0.98679 | valid_accuracy: 0.98721 | 0:00:37s
epoch 33 | loss: 0.21323 | train_balanced_accuracy: 0.99217 | train_accuracy: 0.99216 | valid_balanced_accuracy: 0.98691 | valid_accuracy: 0.98721 | 0:00:38s
epoch 34 | loss: 0.21449 | train_balanced_accuracy: 0.99312 | train_accuracy: 0.99318 | valid_balanced_accuracy: 0.9879 | valid_accuracy: 0.98841 | 0:00:39s
epoch 35 | loss: 0.20823 | train_balanced_accuracy: 0.9933 | train_accuracy: 0.99329 | valid_balanced_accuracy: 0.98836 | valid_accuracy: 0.98821 | 0:00:40s
epoch 36 | loss: 0.21154 | train_balanced_accuracy: 0.99371 | train_accuracy: 0.99374 | valid_balanced_accuracy: 0.98619 | valid_accuracy: 0.98681 | 0:00:42s
epoch 37 | loss: 0.21443 | train_balanced_accuracy: 0.99321 | train_accuracy: 0.9932 | valid_balanced_accuracy: 0.98757 | valid_accuracy: 0.98741 | 0:00:43s
epoch 38 | loss: 0.20554 | train_balanced_accuracy: 0.99325 | train_accuracy: 0.9934 | valid_balanced_accuracy: 0.98595 | valid_accuracy: 0.98661 | 0:00:44s
epoch 39 | loss: 0.21128 | train_balanced_accuracy: 0.99406 | train_accuracy: 0.99405 | valid_balanced_accuracy: 0.98567 | valid_accuracy: 0.98681 | 0:00:45s
epoch 40 | loss: 0.20781 | train_balanced_accuracy: 0.99394 | train_accuracy: 0.994 | valid_balanced_accuracy: 0.98512 | valid_accuracy: 0.98641 | 0:00:46s
epoch 41 | loss: 0.19599 | train_balanced_accuracy: 0.99318 | train_accuracy: 0.99323 | valid_balanced_accuracy: 0.98457 | valid_accuracy: 0.98581 | 0:00:47s
epoch 42 | loss: 0.20263 | train_balanced_accuracy: 0.9929 | train_accuracy: 0.99289 | valid_balanced_accuracy: 0.98527 | valid_accuracy: 0.98641 | 0:00:48s
epoch 43 | loss: 0.19245 | train_balanced_accuracy: 0.99362 | train_accuracy: 0.99363 | valid_balanced_accuracy: 0.98647 | valid_accuracy: 0.98681 | 0:00:49s
epoch 44 | loss: 0.19874 | train_balanced_accuracy: 0.99437 | train_accuracy: 0.99436 | valid_balanced_accuracy: 0.98734 | valid_accuracy: 0.98821 | 0:00:51s
Early stopping occurred at epoch 44 with best_epoch = 34 and best_valid_accuracy = 0.98841
Successfully saved training history and parameters
Successfully saved model at snRNAseq_human_retina/model_scAdam/model.zip
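The log above stops at epoch 44 and reports epoch 34 as the best one, which is consistent with a patience of 10 epochs on validation accuracy. That value is an inference from the log, not a documented setting. The sketch below shows the general early-stopping rule on a toy score list; `early_stopping` is a hypothetical helper.

```python
def early_stopping(valid_scores, patience=10):
    """Return (stop_epoch, best_epoch) for a list of per-epoch validation scores.

    Training stops once `patience` epochs have passed without beating the
    best score seen so far. If that never happens, the last epoch is returned.
    """
    best_epoch, best_score = 0, float('-inf')
    for epoch, score in enumerate(valid_scores):
        if score > best_score:
            best_epoch, best_score = epoch, score
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch
    return len(valid_scores) - 1, best_epoch

# Toy validation accuracies with a short patience of 2 epochs
stop_epoch, best_epoch = early_stopping([0.5, 0.7, 0.8, 0.79, 0.78, 0.9], patience=2)
```

In the toy run the best score is reached at epoch 2 and training stops at epoch 4, so the later improvement at epoch 5 is never seen. This is the usual trade-off of early stopping: a shorter patience saves time but can miss late gains.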
Evaluation of model quality#
For model evaluation, we use another subset of 25,000 cells generated using a different random state.
[11]:
# Get test dataset for model quality evaluation
adata_test = scparadise.scnoah.get_frac(path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
fraction = 25000,
path_save = 'snRNAseq_human_retina',
celltype = 'cell_type',
random_state = 42)
[12]:
# Check common cells between test and training datasets
lst_train = adata_fraction.obs_names.tolist()
lst_test = adata_test.obs_names.tolist()
lst_train.extend(lst_test)
lst_train = np.unique(lst_train)
percent = round((2 * len(lst_test) - len(lst_train))/len(lst_test)*100, 5)
print(f"There are {percent} % common cells ({2 * len(lst_test) - len(lst_train)} cells) between the test and training datasets")
There are 0.836 % common cells (209 cells) between the test and training datasets
Less than 1% of cells are shared between the test and training datasets. This overlap is small enough to ignore, so we can proceed with testing the model's quality.
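The same overlap can be computed with a set intersection, and the observed value can be compared with what chance alone would give. The sketch below assumes that cell barcodes identify cells uniquely and that both subsets are drawn uniformly at random from all 3,177,310 cells, which may not hold if sampling is stratified by cell type. Both helper names are hypothetical.

```python
def shared_cells(train_ids, test_ids):
    """Return the number and percentage of test cells also present in training."""
    test_set = set(test_ids)
    common = set(train_ids) & test_set
    return len(common), 100 * len(common) / len(test_set)

def expected_overlap(n_total, n_train, n_test):
    """Expected number of shared cells for two independent uniform random subsets."""
    return n_train * n_test / n_total

# Toy barcodes: two of the four test cells also occur in the training set
n_common, percent = shared_cells(['c1', 'c2', 'c3', 'c4'], ['c3', 'c4', 'c5', 'c6'])

# Dataset sizes taken from this tutorial
expected = expected_overlap(3_177_310, 25_000, 25_000)
```

Under these assumptions the expected overlap is about 197 cells, so the 209 shared cells reported above are in line with random sampling.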
[13]:
# Apply the same preprocessing steps to the test dataset as used for training
# Get raw counts from adata_test.raw
adata_test = adata_test.raw.to_adata()
# Replace variable names with gene names
adata_test.var.set_index('feature_name', inplace = True)
adata_test.var_names_make_unique()
# Normalize data
sc.pp.normalize_total(adata_test, target_sum = None)
sc.pp.log1p(adata_test)
adata_test.raw = adata_test
[14]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
path_model = 'snRNAseq_human_retina/model_scAdam')
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Successfully loaded model
Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities
[15]:
## Check model quality
df_l1 = scparadise.scnoah.report_classif_full(adata_test,
celltype = 'majorclass',
pred_celltype = 'pred_celltype_l1')
df_l1
[15]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| AC | 0.9998 | 0.9964 | 1.0 | 0.9981 | 0.9982 | 0.996 | 4496 |
| Astrocyte | 1.0000 | 0.982 | 1.0 | 0.9909 | 0.991 | 0.9802 | 111 |
| BC | 0.9969 | 0.9994 | 0.9991 | 0.9982 | 0.9993 | 0.9986 | 5437 |
| Cone | 1.0000 | 0.999 | 1.0 | 0.9995 | 0.9995 | 0.9989 | 1000 |
| HC | 0.9984 | 1.0 | 1.0 | 0.9992 | 1.0 | 1.0 | 634 |
| MG | 0.9994 | 0.9983 | 1.0 | 0.9989 | 0.9991 | 0.9981 | 1744 |
| Microglia | 1.0000 | 0.9744 | 1.0 | 0.987 | 0.9871 | 0.9719 | 39 |
| RGC | 0.9981 | 0.9997 | 0.9997 | 0.9989 | 0.9997 | 0.9994 | 3144 |
| RPE | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7 |
| Rod | 0.9998 | 0.9999 | 0.9999 | 0.9998 | 0.9999 | 0.9998 | 8388 |
| macro avg | 0.9992 | 0.9949 | 0.9999 | 0.997 | 0.9974 | 0.9943 | |
| weighted avg | 0.9989 | 0.9989 | 0.9997 | 0.9989 | 0.9993 | 0.9985 | |
| Accuracy | 0.9989 | | | | | | |
| Balanced accuracy | 0.9949 | | | | | | |
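The columns of the report can be reproduced from true and predicted labels. The sketch below uses standard one-vs-rest definitions that appear to match the table: for AC, the geometric mean sqrt(0.9964 × 1.0) ≈ 0.9982, and the balanced accuracy equals the macro-average recall (0.9949). The weighting factor `alpha = 0.1` for the index balanced accuracy is an assumption that is consistent with the reported values. The function names are hypothetical and the labels are toy data.

```python
import math

def per_class_report(y_true, y_pred, label, alpha=0.1):
    """One-vs-rest metrics for a single class label."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == label and p == label for t, p in pairs)
    fp = sum(t != label and p == label for t, p in pairs)
    fn = sum(t == label and p != label for t, p in pairs)
    tn = len(pairs) - tp - fp - fn

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    specificity = tn / (tn + fp) if tn + fp else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    gmean = math.sqrt(recall * specificity)
    # Index balanced accuracy: squared geometric mean weighted by the
    # difference between recall and specificity
    iba = (1 + alpha * (recall - specificity)) * gmean ** 2
    return {'precision': precision, 'recall': recall, 'specificity': specificity,
            'f1-score': f1, 'geometric mean': gmean, 'index balanced accuracy': iba}

def balanced_accuracy(y_true, y_pred):
    """Mean recall over all classes present in y_true."""
    labels = sorted(set(y_true))
    return sum(per_class_report(y_true, y_pred, l)['recall'] for l in labels) / len(labels)

# Toy labels: one Rod cell is misclassified as Cone
y_true = ['Rod', 'Rod', 'Rod', 'Cone', 'Cone', 'RPE']
y_pred = ['Rod', 'Rod', 'Cone', 'Cone', 'Cone', 'RPE']
rod = per_class_report(y_true, y_pred, 'Rod')
bal_acc = balanced_accuracy(y_true, y_pred)
```

Balanced accuracy weights every class equally, so it reacts more strongly than plain accuracy to errors in rare cell types such as Microglia or RPE.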
[16]:
## Check model quality
df_l2 = scparadise.scnoah.report_classif_full(adata_test,
celltype = 'cell_type',
pred_celltype = 'pred_celltype_l2')
df_l2
[16]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | 0.9940 | 0.9881 | 0.9992 | 0.991 | 0.9936 | 0.9862 | 2855 |
| H1 horizontal cell | 0.9871 | 0.9907 | 0.9997 | 0.9889 | 0.9952 | 0.9896 | 540 |
| H2 horizontal cell | 0.9355 | 0.9255 | 0.9998 | 0.9305 | 0.9619 | 0.9184 | 94 |
| Mueller cell | 0.9994 | 0.9977 | 1.0 | 0.9986 | 0.9988 | 0.9974 | 1744 |
| OFF midget ganglion cell | 0.9146 | 0.896 | 0.9944 | 0.9052 | 0.9439 | 0.8822 | 1577 |
| OFF parasol ganglion cell | 0.9157 | 0.962 | 0.9997 | 0.9383 | 0.9807 | 0.9581 | 79 |
| OFFx cell | 0.9449 | 0.9836 | 0.9997 | 0.9639 | 0.9916 | 0.9817 | 122 |
| ON midget ganglion cell | 0.9259 | 0.9003 | 0.9964 | 0.9129 | 0.9471 | 0.8884 | 1193 |
| ON parasol ganglion cell | 0.8889 | 0.9796 | 0.9998 | 0.932 | 0.9896 | 0.9774 | 49 |
| ON-blue cone bipolar cell | 0.8750 | 0.913 | 0.9999 | 0.8936 | 0.9555 | 0.905 | 23 |
| S cone cell | 0.9054 | 1.0 | 0.9997 | 0.9504 | 0.9999 | 0.9997 | 67 |
| amacrine cell | 0.9731 | 0.9644 | 0.9995 | 0.9688 | 0.9818 | 0.9606 | 450 |
| astrocyte | 1.0000 | 0.982 | 1.0 | 0.9909 | 0.991 | 0.9802 | 111 |
| diffuse bipolar 1 cell | 0.9949 | 0.9874 | 0.9999 | 0.9912 | 0.9936 | 0.9861 | 397 |
| diffuse bipolar 2 cell | 0.9944 | 0.9869 | 0.9999 | 0.9906 | 0.9934 | 0.9855 | 536 |
| diffuse bipolar 3a cell | 0.9882 | 0.9767 | 0.9999 | 0.9825 | 0.9883 | 0.9744 | 172 |
| diffuse bipolar 3b cell | 0.9685 | 0.9964 | 0.9996 | 0.9823 | 0.998 | 0.9957 | 278 |
| diffuse bipolar 4 cell | 0.9871 | 0.9922 | 0.9998 | 0.9897 | 0.996 | 0.9913 | 387 |
| diffuse bipolar 6 cell | 0.9474 | 0.9863 | 0.9997 | 0.9664 | 0.993 | 0.9847 | 146 |
| flat midget bipolar cell | 0.9947 | 0.9929 | 0.9997 | 0.9938 | 0.9963 | 0.992 | 1134 |
| giant bipolar cell | 0.9608 | 0.9849 | 0.9997 | 0.9727 | 0.9923 | 0.9832 | 199 |
| glycinergic amacrine cell | 0.9769 | 0.9816 | 0.999 | 0.9792 | 0.9903 | 0.9789 | 1032 |
| invaginating midget bipolar cell | 0.9940 | 0.9868 | 0.9998 | 0.9904 | 0.9933 | 0.9853 | 835 |
| microglial cell | 0.9500 | 0.9744 | 0.9999 | 0.962 | 0.9871 | 0.9718 | 39 |
| retinal bipolar neuron | 0.9927 | 0.9879 | 0.9999 | 0.9903 | 0.9939 | 0.9866 | 412 |
| retinal cone cell | 1.0000 | 0.9914 | 1.0 | 0.9957 | 0.9957 | 0.9906 | 933 |
| retinal ganglion cell | 0.4304 | 0.5407 | 0.9929 | 0.4793 | 0.7327 | 0.5125 | 246 |
| retinal pigment epithelial cell | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7 |
| retinal rod cell | 0.9996 | 0.9999 | 0.9998 | 0.9998 | 0.9999 | 0.9997 | 8388 |
| rod bipolar cell | 0.9962 | 0.9962 | 0.9999 | 0.9962 | 0.9981 | 0.9957 | 796 |
| starburst amacrine cell | 0.9691 | 0.9874 | 0.9998 | 0.9782 | 0.9936 | 0.986 | 159 |
| macro avg | 0.9485 | 0.9624 | 0.9993 | 0.955 | 0.9795 | 0.9589 | |
| weighted avg | 0.9791 | 0.9778 | 0.9991 | 0.9784 | 0.988 | 0.9752 | |
| Accuracy | 0.9778 | | | | | | |
| Balanced accuracy | 0.9624 | | | | | | |
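The model performs well for every cell type except the retinal ganglion cell class (precision 0.4304, recall 0.5407, f1-score 0.4793).
To check whether this result depends on the sampled cells, you could generate another test dataset using a different random state.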
Iterative warm start training (optional)#
You can continue training on further subsets of the whole dataset to improve model generalization, although additional training rounds may also lead to overfitting.
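Each additional training subset may share some cells with the test dataset (random_state = 42), so an evaluation repeated after several warm-start rounds is slightly less independent than the first one. One way to keep the evaluation clean is to drop from the test set every cell that was used in any training round. This is a minimal sketch, assuming cell barcodes in `obs_names` identify cells uniquely; `remove_seen_cells` is a hypothetical helper, not part of scparadise.

```python
def remove_seen_cells(test_ids, training_rounds):
    """Drop from the test set every cell that appeared in any training subset."""
    seen = set()
    for round_ids in training_rounds:
        seen.update(round_ids)
    return [cell for cell in test_ids if cell not in seen]

# Toy barcodes: cells c2 and c4 were used in two different training rounds
test_ids = ['c1', 'c2', 'c3', 'c4']
rounds = [['c2', 'c9'], ['c4', 'c7']]
clean_test = remove_seen_cells(test_ids, rounds)
```

In practice the barcodes of each round would be collected from `adata_fraction.obs_names` inside the training loop, and the filtered list would be used to subset the test dataset before prediction.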
[17]:
# Start the range at 1: random_state = 0 was already used for the primary training of the model
for i in range(1, 4):
    # Obtain 25000 cells randomly
    adata_fraction = scparadise.scnoah.get_frac(path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
                                                fraction = 25000,
                                                path_save = 'snRNAseq_human_retina',
                                                celltype = 'cell_type',
                                                random_state = i)
    # Get raw counts from adata_fraction.raw
    adata_fraction = adata_fraction.raw.to_adata()
    # Replace variable names with gene names
    adata_fraction.var.set_index('feature_name', inplace = True)
    adata_fraction.var_names_make_unique()
    # Normalize data
    sc.pp.normalize_total(adata_fraction, target_sum = None)
    sc.pp.log1p(adata_fraction)
    adata_fraction.raw = adata_fraction
    # Subset genes for model training
    adata_fraction = adata_fraction[:, lst_genes]
    # Balance dataset
    adata_balanced = scparadise.scnoah.balance(adata_fraction,
                                               celltype_l1 = annotations[0], # majorclass
                                               celltype_l2 = annotations[1], # cell_type
                                               sample = 'donor_id')
    adata_balanced.raw = adata_balanced
    # Warm start requires second training dataset and path to pretrained model
    scparadise.scadam.warm_start(adata_balanced,
                                 path = 'snRNAseq_human_retina', # path to save model
                                 model_name = 'model_scAdam', # folder name with pretrained model
                                 celltype_l1 = 'celltype_l1', # previously: majorclass
                                 celltype_l2 = 'celltype_l2', # previously: cell_type
                                 eval_metric = ['balanced_accuracy', 'accuracy'])
Successfully undersampled cell types: retinal rod cell, GABAergic amacrine cell, Mueller cell, OFF midget ganglion cell, ON midget ganglion cell, flat midget bipolar cell, glycinergic amacrine cell, retinal cone cell, invaginating midget bipolar cell
Successfully oversampled cell types: rod bipolar cell, H1 horizontal cell, diffuse bipolar 2 cell, amacrine cell, retinal bipolar neuron, diffuse bipolar 1 cell, diffuse bipolar 4 cell, diffuse bipolar 3b cell, retinal ganglion cell, giant bipolar cell, diffuse bipolar 3a cell, starburst amacrine cell, diffuse bipolar 6 cell, OFFx cell, astrocyte, H2 horizontal cell, OFF parasol ganglion cell, S cone cell, ON parasol ganglion cell, microglial cell, ON-blue cone bipolar cell, retinal pigment epithelial cell
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Train dataset contains: 22515 cells, it is 90.0 % of input dataset
Test dataset contains: 2502 cells, it is 10.0 % of input dataset
Successfully loaded parameters
Accelerator: cuda
Start training
epoch 0 | loss: 0.27731 | train_balanced_accuracy: 0.98601 | train_accuracy: 0.98616 | valid_balanced_accuracy: 0.98477 | valid_accuracy: 0.98481 | 0:00:01s
epoch 1 | loss: 0.25079 | train_balanced_accuracy: 0.98945 | train_accuracy: 0.98938 | valid_balanced_accuracy: 0.98653 | valid_accuracy: 0.98661 | 0:00:02s
epoch 2 | loss: 0.23321 | train_balanced_accuracy: 0.99028 | train_accuracy: 0.99036 | valid_balanced_accuracy: 0.98633 | valid_accuracy: 0.98641 | 0:00:03s
epoch 3 | loss: 0.22734 | train_balanced_accuracy: 0.99085 | train_accuracy: 0.99107 | valid_balanced_accuracy: 0.98531 | valid_accuracy: 0.98581 | 0:00:05s
epoch 4 | loss: 0.21247 | train_balanced_accuracy: 0.99147 | train_accuracy: 0.99143 | valid_balanced_accuracy: 0.98704 | valid_accuracy: 0.98701 | 0:00:06s
epoch 5 | loss: 0.22125 | train_balanced_accuracy: 0.99225 | train_accuracy: 0.99225 | valid_balanced_accuracy: 0.98551 | valid_accuracy: 0.98601 | 0:00:07s
epoch 6 | loss: 0.22236 | train_balanced_accuracy: 0.9914 | train_accuracy: 0.99138 | valid_balanced_accuracy: 0.98537 | valid_accuracy: 0.98541 | 0:00:08s
epoch 7 | loss: 0.22278 | train_balanced_accuracy: 0.99242 | train_accuracy: 0.99254 | valid_balanced_accuracy: 0.98903 | valid_accuracy: 0.98901 | 0:00:09s
epoch 8 | loss: 0.20157 | train_balanced_accuracy: 0.99396 | train_accuracy: 0.994 | valid_balanced_accuracy: 0.98829 | valid_accuracy: 0.98821 | 0:00:11s
epoch 9 | loss: 0.21051 | train_balanced_accuracy: 0.99327 | train_accuracy: 0.99327 | valid_balanced_accuracy: 0.98776 | valid_accuracy: 0.98781 | 0:00:12s
epoch 10 | loss: 0.2045 | train_balanced_accuracy: 0.99525 | train_accuracy: 0.99531 | valid_balanced_accuracy: 0.98876 | valid_accuracy: 0.98881 | 0:00:14s
epoch 11 | loss: 0.20367 | train_balanced_accuracy: 0.99413 | train_accuracy: 0.99412 | valid_balanced_accuracy: 0.98828 | valid_accuracy: 0.98821 | 0:00:15s
epoch 12 | loss: 0.20131 | train_balanced_accuracy: 0.99447 | train_accuracy: 0.99465 | valid_balanced_accuracy: 0.98744 | valid_accuracy: 0.98741 | 0:00:16s
epoch 13 | loss: 0.19265 | train_balanced_accuracy: 0.99464 | train_accuracy: 0.99465 | valid_balanced_accuracy: 0.98832 | valid_accuracy: 0.98841 | 0:00:17s
epoch 14 | loss: 0.18869 | train_balanced_accuracy: 0.99439 | train_accuracy: 0.9944 | valid_balanced_accuracy: 0.98776 | valid_accuracy: 0.98761 | 0:00:19s
epoch 15 | loss: 0.19078 | train_balanced_accuracy: 0.99424 | train_accuracy: 0.99425 | valid_balanced_accuracy: 0.98777 | valid_accuracy: 0.98781 | 0:00:20s
epoch 16 | loss: 0.19392 | train_balanced_accuracy: 0.99505 | train_accuracy: 0.99505 | valid_balanced_accuracy: 0.98756 | valid_accuracy: 0.98761 | 0:00:21s
epoch 17 | loss: 0.19063 | train_balanced_accuracy: 0.99464 | train_accuracy: 0.99474 | valid_balanced_accuracy: 0.98805 | valid_accuracy: 0.98821 | 0:00:22s
Early stopping occurred at epoch 17 with best_epoch = 7 and best_valid_accuracy = 0.98901
Successfully saved training history and parameters
Successfully saved model at snRNAseq_human_retina/model_scAdam/model.zip
Successfully undersampled cell types: retinal rod cell, GABAergic amacrine cell, Mueller cell, OFF midget ganglion cell, ON midget ganglion cell, flat midget bipolar cell, glycinergic amacrine cell, retinal cone cell, invaginating midget bipolar cell
Successfully oversampled cell types: rod bipolar cell, H1 horizontal cell, diffuse bipolar 2 cell, amacrine cell, retinal bipolar neuron, diffuse bipolar 1 cell, diffuse bipolar 4 cell, diffuse bipolar 3b cell, retinal ganglion cell, giant bipolar cell, diffuse bipolar 3a cell, starburst amacrine cell, diffuse bipolar 6 cell, OFFx cell, astrocyte, H2 horizontal cell, OFF parasol ganglion cell, S cone cell, ON parasol ganglion cell, microglial cell, ON-blue cone bipolar cell, retinal pigment epithelial cell
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Train dataset contains: 22515 cells, it is 90.0 % of input dataset
Test dataset contains: 2502 cells, it is 10.0 % of input dataset
Successfully loaded parameters
Accelerator: cuda
Start training
epoch 0 | loss: 0.25691 | train_balanced_accuracy: 0.98874 | train_accuracy: 0.98885 | valid_balanced_accuracy: 0.98564 | valid_accuracy: 0.98581 | 0:00:01s
epoch 1 | loss: 0.22516 | train_balanced_accuracy: 0.99112 | train_accuracy: 0.99116 | valid_balanced_accuracy: 0.98905 | valid_accuracy: 0.98901 | 0:00:02s
epoch 2 | loss: 0.21182 | train_balanced_accuracy: 0.99143 | train_accuracy: 0.99145 | valid_balanced_accuracy: 0.98918 | valid_accuracy: 0.98921 | 0:00:03s
epoch 3 | loss: 0.21638 | train_balanced_accuracy: 0.99258 | train_accuracy: 0.99272 | valid_balanced_accuracy: 0.99106 | valid_accuracy: 0.99121 | 0:00:04s
epoch 4 | loss: 0.20085 | train_balanced_accuracy: 0.99276 | train_accuracy: 0.99285 | valid_balanced_accuracy: 0.98915 | valid_accuracy: 0.98941 | 0:00:06s
epoch 5 | loss: 0.21065 | train_balanced_accuracy: 0.99298 | train_accuracy: 0.99309 | valid_balanced_accuracy: 0.98885 | valid_accuracy: 0.98901 | 0:00:07s
epoch 6 | loss: 0.20995 | train_balanced_accuracy: 0.99409 | train_accuracy: 0.99412 | valid_balanced_accuracy: 0.99084 | valid_accuracy: 0.99081 | 0:00:08s
epoch 7 | loss: 0.20724 | train_balanced_accuracy: 0.99507 | train_accuracy: 0.99507 | valid_balanced_accuracy: 0.98924 | valid_accuracy: 0.98981 | 0:00:09s
epoch 8 | loss: 0.19699 | train_balanced_accuracy: 0.99554 | train_accuracy: 0.99556 | valid_balanced_accuracy: 0.99105 | valid_accuracy: 0.99121 | 0:00:10s
epoch 9 | loss: 0.19933 | train_balanced_accuracy: 0.99525 | train_accuracy: 0.99531 | valid_balanced_accuracy: 0.99116 | valid_accuracy: 0.99121 | 0:00:12s
epoch 10 | loss: 0.18906 | train_balanced_accuracy: 0.99584 | train_accuracy: 0.99587 | valid_balanced_accuracy: 0.99185 | valid_accuracy: 0.99201 | 0:00:13s
epoch 11 | loss: 0.18892 | train_balanced_accuracy: 0.9958 | train_accuracy: 0.99583 | valid_balanced_accuracy: 0.99345 | valid_accuracy: 0.99361 | 0:00:14s
epoch 12 | loss: 0.19035 | train_balanced_accuracy: 0.99567 | train_accuracy: 0.99567 | valid_balanced_accuracy: 0.99225 | valid_accuracy: 0.99261 | 0:00:15s
epoch 13 | loss: 0.18244 | train_balanced_accuracy: 0.996 | train_accuracy: 0.99607 | valid_balanced_accuracy: 0.99154 | valid_accuracy: 0.99201 | 0:00:16s
epoch 14 | loss: 0.18461 | train_balanced_accuracy: 0.99528 | train_accuracy: 0.99531 | valid_balanced_accuracy: 0.98973 | valid_accuracy: 0.99041 | 0:00:18s
epoch 15 | loss: 0.18248 | train_balanced_accuracy: 0.9962 | train_accuracy: 0.99622 | valid_balanced_accuracy: 0.99272 | valid_accuracy: 0.99341 | 0:00:19s
epoch 16 | loss: 0.17748 | train_balanced_accuracy: 0.99662 | train_accuracy: 0.9966 | valid_balanced_accuracy: 0.99143 | valid_accuracy: 0.99201 | 0:00:20s
epoch 17 | loss: 0.18126 | train_balanced_accuracy: 0.99735 | train_accuracy: 0.99736 | valid_balanced_accuracy: 0.99074 | valid_accuracy: 0.99121 | 0:00:21s
epoch 18 | loss: 0.18514 | train_balanced_accuracy: 0.99689 | train_accuracy: 0.99691 | valid_balanced_accuracy: 0.99216 | valid_accuracy: 0.99221 | 0:00:22s
epoch 19 | loss: 0.18461 | train_balanced_accuracy: 0.99648 | train_accuracy: 0.99649 | valid_balanced_accuracy: 0.99345 | valid_accuracy: 0.9938 | 0:00:24s
epoch 20 | loss: 0.17847 | train_balanced_accuracy: 0.99637 | train_accuracy: 0.99638 | valid_balanced_accuracy: 0.9924 | valid_accuracy: 0.99241 | 0:00:25s
epoch 21 | loss: 0.17701 | train_balanced_accuracy: 0.99728 | train_accuracy: 0.99729 | valid_balanced_accuracy: 0.99336 | valid_accuracy: 0.99341 | 0:00:26s
epoch 22 | loss: 0.17557 | train_balanced_accuracy: 0.99737 | train_accuracy: 0.99738 | valid_balanced_accuracy: 0.99236 | valid_accuracy: 0.99241 | 0:00:27s
epoch 23 | loss: 0.17332 | train_balanced_accuracy: 0.99723 | train_accuracy: 0.99727 | valid_balanced_accuracy: 0.99256 | valid_accuracy: 0.99261 | 0:00:28s
epoch 24 | loss: 0.18068 | train_balanced_accuracy: 0.9969 | train_accuracy: 0.99696 | valid_balanced_accuracy: 0.99246 | valid_accuracy: 0.99281 | 0:00:29s
epoch 25 | loss: 0.17254 | train_balanced_accuracy: 0.9975 | train_accuracy: 0.99756 | valid_balanced_accuracy: 0.99234 | valid_accuracy: 0.99281 | 0:00:31s
epoch 26 | loss: 0.16806 | train_balanced_accuracy: 0.99776 | train_accuracy: 0.99776 | valid_balanced_accuracy: 0.99225 | valid_accuracy: 0.99241 | 0:00:32s
epoch 27 | loss: 0.17744 | train_balanced_accuracy: 0.99814 | train_accuracy: 0.99813 | valid_balanced_accuracy: 0.99426 | valid_accuracy: 0.9942 | 0:00:33s
epoch 28 | loss: 0.17337 | train_balanced_accuracy: 0.99796 | train_accuracy: 0.99802 | valid_balanced_accuracy: 0.99416 | valid_accuracy: 0.9942 | 0:00:34s
epoch 29 | loss: 0.17079 | train_balanced_accuracy: 0.99802 | train_accuracy: 0.99809 | valid_balanced_accuracy: 0.99396 | valid_accuracy: 0.994 | 0:00:35s
epoch 30 | loss: 0.17553 | train_balanced_accuracy: 0.99818 | train_accuracy: 0.99818 | valid_balanced_accuracy: 0.99407 | valid_accuracy: 0.994 | 0:00:37s
epoch 31 | loss: 0.16556 | train_balanced_accuracy: 0.99836 | train_accuracy: 0.99833 | valid_balanced_accuracy: 0.99316 | valid_accuracy: 0.99321 | 0:00:38s
epoch 32 | loss: 0.16672 | train_balanced_accuracy: 0.99814 | train_accuracy: 0.99813 | valid_balanced_accuracy: 0.99247 | valid_accuracy: 0.99241 | 0:00:39s
epoch 33 | loss: 0.15766 | train_balanced_accuracy: 0.99798 | train_accuracy: 0.99798 | valid_balanced_accuracy: 0.99436 | valid_accuracy: 0.9944 | 0:00:40s
epoch 34 | loss: 0.16557 | train_balanced_accuracy: 0.99752 | train_accuracy: 0.99751 | valid_balanced_accuracy: 0.99455 | valid_accuracy: 0.9946 | 0:00:41s
epoch 35 | loss: 0.16067 | train_balanced_accuracy: 0.99746 | train_accuracy: 0.99742 | valid_balanced_accuracy: 0.99307 | valid_accuracy: 0.99301 | 0:00:42s
epoch 36 | loss: 0.16737 | train_balanced_accuracy: 0.99772 | train_accuracy: 0.99776 | valid_balanced_accuracy: 0.99256 | valid_accuracy: 0.99261 | 0:00:44s
epoch 37 | loss: 0.16903 | train_balanced_accuracy: 0.99801 | train_accuracy: 0.99811 | valid_balanced_accuracy: 0.99305 | valid_accuracy: 0.99321 | 0:00:45s
epoch 38 | loss: 0.166 | train_balanced_accuracy: 0.99818 | train_accuracy: 0.99818 | valid_balanced_accuracy: 0.99307 | valid_accuracy: 0.99301 | 0:00:46s
epoch 39 | loss: 0.16403 | train_balanced_accuracy: 0.99852 | train_accuracy: 0.99851 | valid_balanced_accuracy: 0.99426 | valid_accuracy: 0.9942 | 0:00:47s
epoch 40 | loss: 0.16122 | train_balanced_accuracy: 0.9984 | train_accuracy: 0.9984 | valid_balanced_accuracy: 0.99446 | valid_accuracy: 0.9944 | 0:00:48s
epoch 41 | loss: 0.16069 | train_balanced_accuracy: 0.99677 | train_accuracy: 0.99676 | valid_balanced_accuracy: 0.99356 | valid_accuracy: 0.99361 | 0:00:49s
epoch 42 | loss: 0.1639 | train_balanced_accuracy: 0.99764 | train_accuracy: 0.99765 | valid_balanced_accuracy: 0.99216 | valid_accuracy: 0.99221 | 0:00:51s
epoch 43 | loss: 0.15643 | train_balanced_accuracy: 0.99848 | train_accuracy: 0.99847 | valid_balanced_accuracy: 0.99407 | valid_accuracy: 0.994 | 0:00:52s
epoch 44 | loss: 0.16107 | train_balanced_accuracy: 0.99853 | train_accuracy: 0.99853 | valid_balanced_accuracy: 0.99427 | valid_accuracy: 0.9942 | 0:00:53s
Early stopping occurred at epoch 44 with best_epoch = 34 and best_valid_accuracy = 0.9946
Successfully saved training history and parameters
Successfully saved model at snRNAseq_human_retina/model_scAdam/model.zip
Successfully undersampled cell types: retinal rod cell, GABAergic amacrine cell, Mueller cell, OFF midget ganglion cell, ON midget ganglion cell, flat midget bipolar cell, glycinergic amacrine cell, retinal cone cell, invaginating midget bipolar cell
Successfully oversampled cell types: rod bipolar cell, H1 horizontal cell, diffuse bipolar 2 cell, amacrine cell, retinal bipolar neuron, diffuse bipolar 1 cell, diffuse bipolar 4 cell, diffuse bipolar 3b cell, retinal ganglion cell, giant bipolar cell, diffuse bipolar 3a cell, starburst amacrine cell, diffuse bipolar 6 cell, OFFx cell, astrocyte, H2 horizontal cell, OFF parasol ganglion cell, S cone cell, ON parasol ganglion cell, microglial cell, ON-blue cone bipolar cell, retinal pigment epithelial cell
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Train dataset contains: 22515 cells, it is 90.0 % of input dataset
Test dataset contains: 2502 cells, it is 10.0 % of input dataset
Successfully loaded parameters
Accelerator: cuda
Start training
epoch 0 | loss: 0.23292 | train_balanced_accuracy: 0.98984 | train_accuracy: 0.99007 | valid_balanced_accuracy: 0.98956 | valid_accuracy: 0.98981 | 0:00:01s
epoch 1 | loss: 0.20013 | train_balanced_accuracy: 0.99118 | train_accuracy: 0.99114 | valid_balanced_accuracy: 0.98888 | valid_accuracy: 0.98881 | 0:00:02s
epoch 2 | loss: 0.19169 | train_balanced_accuracy: 0.9925 | train_accuracy: 0.99258 | valid_balanced_accuracy: 0.99034 | valid_accuracy: 0.99101 | 0:00:03s
epoch 3 | loss: 0.19631 | train_balanced_accuracy: 0.99199 | train_accuracy: 0.99212 | valid_balanced_accuracy: 0.98835 | valid_accuracy: 0.98901 | 0:00:04s
epoch 4 | loss: 0.18296 | train_balanced_accuracy: 0.99278 | train_accuracy: 0.99274 | valid_balanced_accuracy: 0.99016 | valid_accuracy: 0.99041 | 0:00:06s
epoch 5 | loss: 0.18494 | train_balanced_accuracy: 0.99338 | train_accuracy: 0.99336 | valid_balanced_accuracy: 0.99138 | valid_accuracy: 0.99121 | 0:00:07s
epoch 6 | loss: 0.19315 | train_balanced_accuracy: 0.99372 | train_accuracy: 0.99369 | valid_balanced_accuracy: 0.99142 | valid_accuracy: 0.99121 | 0:00:08s
epoch 7 | loss: 0.18335 | train_balanced_accuracy: 0.99505 | train_accuracy: 0.99505 | valid_balanced_accuracy: 0.99206 | valid_accuracy: 0.99281 | 0:00:09s
epoch 8 | loss: 0.17319 | train_balanced_accuracy: 0.99473 | train_accuracy: 0.99469 | valid_balanced_accuracy: 0.99158 | valid_accuracy: 0.99141 | 0:00:11s
epoch 9 | loss: 0.18107 | train_balanced_accuracy: 0.99536 | train_accuracy: 0.99534 | valid_balanced_accuracy: 0.99324 | valid_accuracy: 0.99361 | 0:00:12s
epoch 10 | loss: 0.17673 | train_balanced_accuracy: 0.99518 | train_accuracy: 0.9952 | valid_balanced_accuracy: 0.99076 | valid_accuracy: 0.99101 | 0:00:13s
epoch 11 | loss: 0.17895 | train_balanced_accuracy: 0.99539 | train_accuracy: 0.99536 | valid_balanced_accuracy: 0.99329 | valid_accuracy: 0.99361 | 0:00:14s
epoch 12 | loss: 0.17441 | train_balanced_accuracy: 0.99641 | train_accuracy: 0.99638 | valid_balanced_accuracy: 0.99197 | valid_accuracy: 0.99181 | 0:00:16s
epoch 13 | loss: 0.16971 | train_balanced_accuracy: 0.99671 | train_accuracy: 0.99669 | valid_balanced_accuracy: 0.9918 | valid_accuracy: 0.99221 | 0:00:17s
epoch 14 | loss: 0.16967 | train_balanced_accuracy: 0.99609 | train_accuracy: 0.99609 | valid_balanced_accuracy: 0.99242 | valid_accuracy: 0.99241 | 0:00:18s
epoch 15 | loss: 0.17293 | train_balanced_accuracy: 0.99586 | train_accuracy: 0.99585 | valid_balanced_accuracy: 0.9936 | valid_accuracy: 0.994 | 0:00:19s
epoch 16 | loss: 0.16655 | train_balanced_accuracy: 0.99631 | train_accuracy: 0.99629 | valid_balanced_accuracy: 0.99203 | valid_accuracy: 0.99201 | 0:00:20s
epoch 17 | loss: 0.17146 | train_balanced_accuracy: 0.99697 | train_accuracy: 0.99694 | valid_balanced_accuracy: 0.99322 | valid_accuracy: 0.99321 | 0:00:22s
epoch 18 | loss: 0.16555 | train_balanced_accuracy: 0.99702 | train_accuracy: 0.997 | valid_balanced_accuracy: 0.99322 | valid_accuracy: 0.99321 | 0:00:23s
epoch 19 | loss: 0.17055 | train_balanced_accuracy: 0.99705 | train_accuracy: 0.99702 | valid_balanced_accuracy: 0.99422 | valid_accuracy: 0.9942 | 0:00:24s
epoch 20 | loss: 0.17196 | train_balanced_accuracy: 0.99715 | train_accuracy: 0.99714 | valid_balanced_accuracy: 0.99342 | valid_accuracy: 0.99341 | 0:00:25s
epoch 21 | loss: 0.17013 | train_balanced_accuracy: 0.99717 | train_accuracy: 0.99714 | valid_balanced_accuracy: 0.99221 | valid_accuracy: 0.99261 | 0:00:26s
epoch 22 | loss: 0.16264 | train_balanced_accuracy: 0.99747 | train_accuracy: 0.99745 | valid_balanced_accuracy: 0.99161 | valid_accuracy: 0.99201 | 0:00:28s
epoch 23 | loss: 0.16108 | train_balanced_accuracy: 0.99721 | train_accuracy: 0.99718 | valid_balanced_accuracy: 0.99337 | valid_accuracy: 0.99321 | 0:00:29s
epoch 24 | loss: 0.17589 | train_balanced_accuracy: 0.99789 | train_accuracy: 0.99787 | valid_balanced_accuracy: 0.99201 | valid_accuracy: 0.99241 | 0:00:30s
epoch 25 | loss: 0.16501 | train_balanced_accuracy: 0.99748 | train_accuracy: 0.99747 | valid_balanced_accuracy: 0.99196 | valid_accuracy: 0.99221 | 0:00:31s
epoch 26 | loss: 0.15835 | train_balanced_accuracy: 0.99729 | train_accuracy: 0.99725 | valid_balanced_accuracy: 0.99171 | valid_accuracy: 0.99181 | 0:00:33s
epoch 27 | loss: 0.16955 | train_balanced_accuracy: 0.99746 | train_accuracy: 0.99745 | valid_balanced_accuracy: 0.99203 | valid_accuracy: 0.99201 | 0:00:34s
epoch 28 | loss: 0.16966 | train_balanced_accuracy: 0.99749 | train_accuracy: 0.99747 | valid_balanced_accuracy: 0.99188 | valid_accuracy: 0.99221 | 0:00:35s
epoch 29 | loss: 0.16197 | train_balanced_accuracy: 0.99788 | train_accuracy: 0.99787 | valid_balanced_accuracy: 0.99305 | valid_accuracy: 0.99341 | 0:00:36s
Early stopping occurred at epoch 29 with best_epoch = 19 and best_valid_accuracy = 0.9942
Successfully saved training history and parameters
Successfully saved model at snRNAseq_human_retina/model_scAdam/model.zip
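Both training logs above stop exactly 10 epochs after the best validation accuracy (epoch 44 with best epoch 34, and epoch 29 with best epoch 19). This is consistent with patience-based early stopping. The sketch below replays the validation accuracies of the second run through such a rule. The patience value of 10 is inferred from the logs, and `find_early_stop` is a hypothetical helper written for illustration; it is not scparadise's own implementation.

```python
def find_early_stop(valid_scores, patience=10):
    """Return (stop_epoch, best_epoch, best_score) for patience-based early stopping.

    Training stops once `patience` epochs have passed without a strict
    improvement over the best validation score seen so far.
    """
    best_epoch, best_score = 0, float("-inf")
    for epoch, score in enumerate(valid_scores):
        if score > best_score:
            best_epoch, best_score = epoch, score
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch, best_score
    # No early stop triggered: training ran through all epochs
    return len(valid_scores) - 1, best_epoch, best_score


# valid_accuracy values copied from the second training log (epochs 0-29)
valid_accuracy = [
    0.98981, 0.98881, 0.99101, 0.98901, 0.99041, 0.99121, 0.99121, 0.99281,
    0.99141, 0.99361, 0.99101, 0.99361, 0.99181, 0.99221, 0.99241, 0.994,
    0.99201, 0.99321, 0.99321, 0.9942, 0.99341, 0.99261, 0.99201, 0.99321,
    0.99241, 0.99221, 0.99181, 0.99201, 0.99221, 0.99341,
]

stop_epoch, best_epoch, best_score = find_early_stop(valid_accuracy, patience=10)
print(f"stopped at epoch {stop_epoch}, best_epoch = {best_epoch}, "
      f"best_valid_accuracy = {best_score}")
```

With these inputs the rule stops at epoch 29 with best epoch 19, which matches the log line above.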
[18]:
# Predict cell types in the test dataset using the trained model
adata_test = scparadise.scadam.predict(adata_test,
                                       path_model = 'snRNAseq_human_retina/model_scAdam')
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Successfully loaded model
Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities
[19]:
# Check model quality at the first annotation level (majorclass)
df_warm_start_l1 = scparadise.scnoah.report_classif_full(adata_test,
                                                         celltype='majorclass',
                                                         pred_celltype='pred_celltype_l1')
df_warm_start_l1
[19]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| AC | 0.9993 | 0.9949 | 0.9999 | 0.9971 | 0.9974 | 0.9942 | 4496 |
| Astrocyte | 0.9909 | 0.982 | 1.0 | 0.9864 | 0.9909 | 0.9802 | 111 |
| BC | 0.9971 | 0.9996 | 0.9992 | 0.9983 | 0.9994 | 0.9989 | 5437 |
| Cone | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1000 |
| HC | 0.9953 | 1.0 | 0.9999 | 0.9976 | 0.9999 | 0.9999 | 634 |
| MG | 1.0000 | 0.9989 | 1.0 | 0.9994 | 0.9994 | 0.9987 | 1744 |
| Microglia | 0.9500 | 0.9744 | 0.9999 | 0.962 | 0.9871 | 0.9718 | 39 |
| RGC | 0.9978 | 0.9997 | 0.9997 | 0.9987 | 0.9997 | 0.9994 | 3144 |
| RPE | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7 |
| Rod | 0.9996 | 0.9995 | 0.9998 | 0.9996 | 0.9997 | 0.9993 | 8388 |
| macro avg | 0.9930 | 0.9949 | 0.9998 | 0.9939 | 0.9973 | 0.9942 | |
| weighted avg | 0.9986 | 0.9986 | 0.9997 | 0.9986 | 0.9991 | 0.9982 | |
| Accuracy | 0.9986 | | | | | | |
| Balanced accuracy | 0.9949 | | | | | | |
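The summary rows of this report can be cross-checked from the per-class rows: balanced accuracy is the unweighted (macro) mean of per-class recall, while overall accuracy equals the recall weighted by the number of cells in each class. A minimal sketch, using the rounded recall values and cell counts copied from the table above (so results agree only to about four decimals):

```python
# (recall, number of cells) per major class, copied from the report above
report = {
    "AC": (0.9949, 4496),
    "Astrocyte": (0.982, 111),
    "BC": (0.9996, 5437),
    "Cone": (1.0, 1000),
    "HC": (1.0, 634),
    "MG": (0.9989, 1744),
    "Microglia": (0.9744, 39),
    "RGC": (0.9997, 3144),
    "RPE": (1.0, 7),
    "Rod": (0.9995, 8388),
}

total_cells = sum(n for _, n in report.values())

# Macro average: every class counts equally (this is the balanced accuracy)
macro_recall = sum(r for r, _ in report.values()) / len(report)

# Weighted average: classes count in proportion to their size (this is the accuracy)
weighted_recall = sum(r * n for r, n in report.values()) / total_cells

print(f"cells: {total_cells}")
print(f"macro recall / balanced accuracy: {macro_recall:.4f}")
print(f"weighted recall / accuracy:       {weighted_recall:.4f}")
```

Because rare classes such as Microglia (39 cells) and RPE (7 cells) weigh as much as Rod (8388 cells) in the macro average, balanced accuracy is the more sensitive indicator of how well small cell populations are annotated.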
[20]:
# Check model quality at the second annotation level (cell_type)
df_warm_start_l2 = scparadise.scnoah.report_classif_full(adata_test,
                                                         celltype='cell_type',
                                                         pred_celltype='pred_celltype_l2')
df_warm_start_l2
[20]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | 0.9965 | 0.9832 | 0.9995 | 0.9898 | 0.9913 | 0.9811 | 2855 |
| H1 horizontal cell | 0.9981 | 0.9944 | 1.0 | 0.9963 | 0.9972 | 0.9939 | 540 |
| H2 horizontal cell | 0.9490 | 0.9894 | 0.9998 | 0.9688 | 0.9946 | 0.9881 | 94 |
| Mueller cell | 1.0000 | 0.9977 | 1.0 | 0.9989 | 0.9989 | 0.9975 | 1744 |
| OFF midget ganglion cell | 0.9604 | 0.877 | 0.9976 | 0.9168 | 0.9353 | 0.8643 | 1577 |
| OFF parasol ganglion cell | 0.9630 | 0.9873 | 0.9999 | 0.975 | 0.9936 | 0.986 | 79 |
| OFFx cell | 0.9918 | 0.9918 | 1.0 | 0.9918 | 0.9959 | 0.991 | 122 |
| ON midget ganglion cell | 0.9280 | 0.9405 | 0.9963 | 0.9342 | 0.968 | 0.9318 | 1193 |
| ON parasol ganglion cell | 0.9800 | 1.0 | 1.0 | 0.9899 | 1.0 | 1.0 | 49 |
| ON-blue cone bipolar cell | 0.8800 | 0.9565 | 0.9999 | 0.9167 | 0.978 | 0.9523 | 23 |
| S cone cell | 0.9437 | 1.0 | 0.9998 | 0.971 | 0.9999 | 0.9999 | 67 |
| amacrine cell | 0.9586 | 0.9778 | 0.9992 | 0.9681 | 0.9884 | 0.9749 | 450 |
| astrocyte | 1.0000 | 0.982 | 1.0 | 0.9909 | 0.991 | 0.9802 | 111 |
| diffuse bipolar 1 cell | 0.9900 | 1.0 | 0.9998 | 0.995 | 0.9999 | 0.9999 | 397 |
| diffuse bipolar 2 cell | 0.9926 | 0.9944 | 0.9998 | 0.9935 | 0.9971 | 0.9937 | 536 |
| diffuse bipolar 3a cell | 1.0000 | 0.9884 | 1.0 | 0.9942 | 0.9942 | 0.9872 | 172 |
| diffuse bipolar 3b cell | 0.9858 | 1.0 | 0.9998 | 0.9929 | 0.9999 | 0.9999 | 278 |
| diffuse bipolar 4 cell | 0.9948 | 0.9897 | 0.9999 | 0.9922 | 0.9948 | 0.9886 | 387 |
| diffuse bipolar 6 cell | 0.9796 | 0.9863 | 0.9999 | 0.9829 | 0.9931 | 0.9848 | 146 |
| flat midget bipolar cell | 0.9965 | 0.9947 | 0.9998 | 0.9956 | 0.9973 | 0.994 | 1134 |
| giant bipolar cell | 0.9850 | 0.9899 | 0.9999 | 0.9875 | 0.9949 | 0.9888 | 199 |
| glycinergic amacrine cell | 0.9715 | 0.9893 | 0.9987 | 0.9803 | 0.994 | 0.9872 | 1032 |
| invaginating midget bipolar cell | 0.9976 | 0.994 | 0.9999 | 0.9958 | 0.997 | 0.9933 | 835 |
| microglial cell | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 39 |
| retinal bipolar neuron | 0.9903 | 0.9927 | 0.9998 | 0.9915 | 0.9963 | 0.9919 | 412 |
| retinal cone cell | 1.0000 | 0.9957 | 1.0 | 0.9979 | 0.9979 | 0.9953 | 933 |
| retinal ganglion cell | 0.4879 | 0.7358 | 0.9923 | 0.5867 | 0.8545 | 0.7114 | 246 |
| retinal pigment epithelial cell | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7 |
| retinal rod cell | 0.9995 | 0.9998 | 0.9998 | 0.9996 | 0.9998 | 0.9995 | 8388 |
| rod bipolar cell | 0.9962 | 0.995 | 0.9999 | 0.9956 | 0.9974 | 0.9944 | 796 |
| starburst amacrine cell | 0.9812 | 0.9874 | 0.9999 | 0.9843 | 0.9936 | 0.9861 | 159 |
| macro avg | 0.9644 | 0.9778 | 0.9994 | 0.9701 | 0.9882 | 0.9754 | |
| weighted avg | 0.9844 | 0.982 | 0.9994 | 0.9828 | 0.9904 | 0.9798 | |
| Accuracy | 0.9820 | | | | | | |
| Balanced accuracy | 0.9778 | | | | | | |
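The derived columns of this report can be recomputed from precision, recall, and specificity. The sketch below does so for the retinal ganglion cell row. It assumes the report follows the imbalanced-learn conventions (geometric mean of sensitivity and specificity, and index balanced accuracy with `alpha = 0.1` applied to the squared geometric mean); this is an assumption that happens to reproduce the table values, not a statement about scparadise internals.

```python
import math

# Rounded values copied from the "retinal ganglion cell" row above
precision, recall, specificity = 0.4879, 0.7358, 0.9923

# F1-score: harmonic mean of precision and recall
f1 = 2 * precision * recall / (precision + recall)

# Geometric mean of sensitivity (recall) and specificity
gmean = math.sqrt(recall * specificity)

# Index balanced accuracy: squared geometric mean, scaled by the
# dominance (recall - specificity) with weight alpha (assumed 0.1)
alpha = 0.1
iba = (1 + alpha * (recall - specificity)) * gmean ** 2

print(f"f1-score: {f1:.4f}")
print(f"geometric mean: {gmean:.4f}")
print(f"index balanced accuracy: {iba:.4f}")
```

The low precision (0.4879) drags the F1-score down to about 0.59 even though specificity stays above 0.99, which is why this class stands out in the report.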
[22]:
# Show all rows, then compare the default and warm start reports side by side
pd.set_option('display.max_rows', 100)
df_l2.compare(df_warm_start_l2, keep_equal=True, align_axis = 0, result_names=('default', 'warm start'))
[22]:
| | | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | default | 0.9940 | 0.9881 | 0.9992 | 0.991 | 0.9936 | 0.9862 |
| | warm start | 0.9965 | 0.9832 | 0.9995 | 0.9898 | 0.9913 | 0.9811 |
| H1 horizontal cell | default | 0.9871 | 0.9907 | 0.9997 | 0.9889 | 0.9952 | 0.9896 |
| | warm start | 0.9981 | 0.9944 | 1.0 | 0.9963 | 0.9972 | 0.9939 |
| H2 horizontal cell | default | 0.9355 | 0.9255 | 0.9998 | 0.9305 | 0.9619 | 0.9184 |
| | warm start | 0.9490 | 0.9894 | 0.9998 | 0.9688 | 0.9946 | 0.9881 |
| Mueller cell | default | 0.9994 | 0.9977 | 1.0 | 0.9986 | 0.9988 | 0.9974 |
| | warm start | 1.0000 | 0.9977 | 1.0 | 0.9989 | 0.9989 | 0.9975 |
| OFF midget ganglion cell | default | 0.9146 | 0.896 | 0.9944 | 0.9052 | 0.9439 | 0.8822 |
| | warm start | 0.9604 | 0.877 | 0.9976 | 0.9168 | 0.9353 | 0.8643 |
| OFF parasol ganglion cell | default | 0.9157 | 0.962 | 0.9997 | 0.9383 | 0.9807 | 0.9581 |
| | warm start | 0.9630 | 0.9873 | 0.9999 | 0.975 | 0.9936 | 0.986 |
| OFFx cell | default | 0.9449 | 0.9836 | 0.9997 | 0.9639 | 0.9916 | 0.9817 |
| | warm start | 0.9918 | 0.9918 | 1.0 | 0.9918 | 0.9959 | 0.991 |
| ON midget ganglion cell | default | 0.9259 | 0.9003 | 0.9964 | 0.9129 | 0.9471 | 0.8884 |
| | warm start | 0.9280 | 0.9405 | 0.9963 | 0.9342 | 0.968 | 0.9318 |
| ON parasol ganglion cell | default | 0.8889 | 0.9796 | 0.9998 | 0.932 | 0.9896 | 0.9774 |
| | warm start | 0.9800 | 1.0 | 1.0 | 0.9899 | 1.0 | 1.0 |
| ON-blue cone bipolar cell | default | 0.8750 | 0.913 | 0.9999 | 0.8936 | 0.9555 | 0.905 |
| | warm start | 0.8800 | 0.9565 | 0.9999 | 0.9167 | 0.978 | 0.9523 |
| S cone cell | default | 0.9054 | 1.0 | 0.9997 | 0.9504 | 0.9999 | 0.9997 |
| | warm start | 0.9437 | 1.0 | 0.9998 | 0.971 | 0.9999 | 0.9999 |
| amacrine cell | default | 0.9731 | 0.9644 | 0.9995 | 0.9688 | 0.9818 | 0.9606 |
| | warm start | 0.9586 | 0.9778 | 0.9992 | 0.9681 | 0.9884 | 0.9749 |
| diffuse bipolar 1 cell | default | 0.9949 | 0.9874 | 0.9999 | 0.9912 | 0.9936 | 0.9861 |
| | warm start | 0.9900 | 1.0 | 0.9998 | 0.995 | 0.9999 | 0.9999 |
| diffuse bipolar 2 cell | default | 0.9944 | 0.9869 | 0.9999 | 0.9906 | 0.9934 | 0.9855 |
| | warm start | 0.9926 | 0.9944 | 0.9998 | 0.9935 | 0.9971 | 0.9937 |
| diffuse bipolar 3a cell | default | 0.9882 | 0.9767 | 0.9999 | 0.9825 | 0.9883 | 0.9744 |
| | warm start | 1.0000 | 0.9884 | 1.0 | 0.9942 | 0.9942 | 0.9872 |
| diffuse bipolar 3b cell | default | 0.9685 | 0.9964 | 0.9996 | 0.9823 | 0.998 | 0.9957 |
| | warm start | 0.9858 | 1.0 | 0.9998 | 0.9929 | 0.9999 | 0.9999 |
| diffuse bipolar 4 cell | default | 0.9871 | 0.9922 | 0.9998 | 0.9897 | 0.996 | 0.9913 |
| | warm start | 0.9948 | 0.9897 | 0.9999 | 0.9922 | 0.9948 | 0.9886 |
| diffuse bipolar 6 cell | default | 0.9474 | 0.9863 | 0.9997 | 0.9664 | 0.993 | 0.9847 |
| | warm start | 0.9796 | 0.9863 | 0.9999 | 0.9829 | 0.9931 | 0.9848 |
| flat midget bipolar cell | default | 0.9947 | 0.9929 | 0.9997 | 0.9938 | 0.9963 | 0.992 |
| | warm start | 0.9965 | 0.9947 | 0.9998 | 0.9956 | 0.9973 | 0.994 |
| giant bipolar cell | default | 0.9608 | 0.9849 | 0.9997 | 0.9727 | 0.9923 | 0.9832 |
| | warm start | 0.9850 | 0.9899 | 0.9999 | 0.9875 | 0.9949 | 0.9888 |
| glycinergic amacrine cell | default | 0.9769 | 0.9816 | 0.999 | 0.9792 | 0.9903 | 0.9789 |
| | warm start | 0.9715 | 0.9893 | 0.9987 | 0.9803 | 0.994 | 0.9872 |
| invaginating midget bipolar cell | default | 0.9940 | 0.9868 | 0.9998 | 0.9904 | 0.9933 | 0.9853 |
| | warm start | 0.9976 | 0.994 | 0.9999 | 0.9958 | 0.997 | 0.9933 |
| microglial cell | default | 0.9500 | 0.9744 | 0.9999 | 0.962 | 0.9871 | 0.9718 |
| | warm start | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
| retinal bipolar neuron | default | 0.9927 | 0.9879 | 0.9999 | 0.9903 | 0.9939 | 0.9866 |
| | warm start | 0.9903 | 0.9927 | 0.9998 | 0.9915 | 0.9963 | 0.9919 |
| retinal cone cell | default | 1.0000 | 0.9914 | 1.0 | 0.9957 | 0.9957 | 0.9906 |
| | warm start | 1.0000 | 0.9957 | 1.0 | 0.9979 | 0.9979 | 0.9953 |
| retinal ganglion cell | default | 0.4304 | 0.5407 | 0.9929 | 0.4793 | 0.7327 | 0.5125 |
| | warm start | 0.4879 | 0.7358 | 0.9923 | 0.5867 | 0.8545 | 0.7114 |
| retinal rod cell | default | 0.9996 | 0.9999 | 0.9998 | 0.9998 | 0.9999 | 0.9997 |
| | warm start | 0.9995 | 0.9998 | 0.9998 | 0.9996 | 0.9998 | 0.9995 |
| rod bipolar cell | default | 0.9962 | 0.9962 | 0.9999 | 0.9962 | 0.9981 | 0.9957 |
| | warm start | 0.9962 | 0.995 | 0.9999 | 0.9956 | 0.9974 | 0.9944 |
| starburst amacrine cell | default | 0.9691 | 0.9874 | 0.9998 | 0.9782 | 0.9936 | 0.986 |
| | warm start | 0.9812 | 0.9874 | 0.9999 | 0.9843 | 0.9936 | 0.9861 |
| macro avg | default | 0.9485 | 0.9624 | 0.9993 | 0.955 | 0.9795 | 0.9589 |
| | warm start | 0.9644 | 0.9778 | 0.9994 | 0.9701 | 0.9882 | 0.9754 |
| weighted avg | default | 0.9791 | 0.9778 | 0.9991 | 0.9784 | 0.988 | 0.9752 |
| | warm start | 0.9844 | 0.982 | 0.9994 | 0.9828 | 0.9904 | 0.9798 |
| Accuracy | default | 0.9778 | | | | | |
| | warm start | 0.9820 | | | | | |
| Balanced accuracy | default | 0.9624 | | | | | |
| | warm start | 0.9778 | | | | | |
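Iterative warm start training improved every summary metric (the macro avg, weighted avg, Accuracy, and Balanced accuracy rows). The largest gain was for the retinal ganglion cell, the class the default model annotated worst: sensitivity rose by 19.5 percentage points (0.5407 to 0.7358) and precision by about 5.8 percentage points (0.4304 to 0.4879). Not every per-class metric improved, however; for example, the sensitivity for the OFF midget ganglion cell dropped from 0.896 to 0.877.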
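The gain for the retinal ganglion cell can be recomputed directly from the comparison table as a difference in percentage points. A minimal sketch, with the metric values copied from the "default" and "warm start" rows above (the dictionary names are illustrative only):

```python
# Metrics for the retinal ganglion cell, copied from the comparison table
default = {"precision": 0.4304, "recall/sensitivity": 0.5407, "f1-score": 0.4793}
warm_start = {"precision": 0.4879, "recall/sensitivity": 0.7358, "f1-score": 0.5867}

# Absolute change in percentage points (not a relative percentage change)
gains = {
    metric: round((warm_start[metric] - default[metric]) * 100, 2)
    for metric in default
}

for metric, gain in gains.items():
    print(f"{metric}: {gain:+.2f} percentage points")
```

Reporting the change in percentage points avoids ambiguity: relative to the default model's sensitivity of 0.5407, the same improvement would be a much larger relative increase.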
[21]:
import session_info
session_info.show()
[21]:
Session information:
----- anndata 0.10.8 numpy 1.25.0 pandas 2.2.3 scanpy 1.10.3 scparadise 0.4.0_beta session_info 1.0.0 -----
Modules imported as dependencies:
PIL 10.4.0 anyio NA arrow 1.3.0 asciitree NA asttokens NA attr 24.2.0 attrs 24.2.0 awkward 2.7.1 awkward_cpp NA babel 2.16.0 backports NA certifi 2024.08.30 cffi 1.17.1 charset_normalizer 3.3.2 cloudpickle 3.1.0 colorlog NA comm 0.2.2 cycler 0.12.1 cython_runtime NA dask 2024.8.0 dateutil 2.9.0.post0 debugpy 1.8.6 decorator 5.1.1 defusedxml 0.7.1 exceptiongroup 1.2.2 executing 2.1.0 fastjsonschema NA fqdn NA fsspec 2023.6.0 h5py 3.12.1 idna 3.10 igraph 0.11.6 imblearn 0.12.3 importlib_metadata NA importlib_resources NA ipykernel 6.29.5 isoduration NA jaraco NA jedi 0.19.1 jinja2 3.1.4 joblib 1.4.2 json5 0.9.25 jsonpointer 3.0.0 jsonschema 4.23.0 jsonschema_specifications NA jupyter_events 0.10.0 jupyter_server 2.14.2 jupyterlab_server 2.27.3 kiwisolver 1.4.7 legacy_api_wrap NA leidenalg 0.10.2 llvmlite 0.43.0 markupsafe 2.1.5 matplotlib 3.9.2 matplotlib_inline 0.1.7 more_itertools 10.5.0 mpl_toolkits NA mpmath 1.3.0 msgpack 1.1.0 mudata 0.2.4 muon 0.1.6 natsort 8.4.0 nbformat 5.10.4 numba 0.60.0 numcodecs 0.12.1 optuna 4.0.0 overrides NA packaging 24.1 parso 0.8.4 patsy 0.5.6 pexpect 4.9.0 platformdirs 4.3.6 plotly 5.24.1 prometheus_client NA prompt_toolkit 3.0.48 psutil 6.0.0 ptyprocess 0.7.0 pure_eval 0.2.3 pyarrow 18.1.0 pycparser 2.22 pydev_ipython NA pydevconsole NA pydevd 3.1.0 pydevd_file_utils NA pydevd_plugins NA pydevd_tracing NA pydot 3.0.3 pygments 2.18.0 pynndescent 0.5.13 pyparsing 3.1.4 pythonjsonlogger NA pytorch_tabnet NA pytz 2024.2 referencing NA requests 2.32.3 rfc3339_validator 0.1.4 rfc3986_validator 0.1.1 rich NA rpds NA scipy 1.13.1 seaborn 0.13.2 send2trash NA setuptools 75.1.0 setuptools_scm NA shap 0.46.0 six 1.16.0 sklearn 1.5.2 slicer NA sniffio 1.3.1 stack_data 0.6.3 statsmodels 0.14.3 sympy 1.13.3 tblib 3.0.0 texttable 1.7.0 threadpoolctl 3.5.0 tlz 0.12.1 tomli 2.0.1 toolz 0.12.1 torch 2.4.1+cu121 torchgen NA tornado 6.4.1 tqdm 4.66.5 traitlets 5.14.3 triton 3.0.0 typing_extensions NA umap 0.5.6 uri_template NA urllib3 1.26.20 wcwidth 0.2.13 webcolors 24.8.0 websocket 1.8.0 yaml 6.0.2 zarr 2.18.2 zipp NA zmq 26.2.0 zoneinfo NA
----- IPython 8.18.1 jupyter_client 8.6.3 jupyter_core 5.7.2 jupyterlab 4.2.5 ----- Python 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0] Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35 ----- Session information updated at 2025-02-08 17:59