Managing large and complex training datasets

Large and complex datasets are difficult to use for training models due to computational limitations (enormous amounts of RAM are required just to open such datasets).

Additionally, the time needed to select genes for model training grows with dataset size.

Meanwhile, increasing the amount of data used for training does not always significantly improve model quality.

Here, we present a method to overcome computational limits for training models on large, complex datasets with many donors.

[1]:
# Python packages
import warnings
warnings.simplefilter('ignore')

import scanpy as sc
import scparadise
import numpy as np
import pandas as pd
import os

sc.set_figure_params(dpi = 120)
[2]:
# Create folder to save files
# single-nucleus RNA-seq of human retina
os.makedirs('snRNAseq_human_retina', exist_ok = True)
[3]:
# Download CELLxGENE dataset (snRNA-seq of human retina):
# https://cellxgene.cziscience.com/collections/4c6eaf5c-6d57-4c76-b1e9-60df8c655f1e
!wget https://datasets.cellxgene.cziscience.com/2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad
--2025-02-08 13:49:05--  https://datasets.cellxgene.cziscience.com/2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad
Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 52.85.49.24, 52.85.49.28, 52.85.49.17, ...
Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|52.85.49.24|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 37946797973 (35G) [binary/octet-stream]
Saving to: ‘2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad’

2e910e62-7eaf-4c06- 100%[===================>]  35.34G  25.1MB/s    in 27m 20s

2025-02-08 14:13:58 (22.1 MB/s) - ‘2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad’ saved [37946797973/37946797973]

Train scAdam model using a fraction of the dataset

The entire dataset contains 3,177,310 cells and 36,406 genes (35.34 GB on disk). It is too large to open on a standard computer. Additionally, selecting genes for training a model on such a large dataset requires significant computational power and time.
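A rough back-of-the-envelope estimate shows why a dataset of this size is hard to open. The sketch below assumes a dense float32 matrix (4 bytes per value), which would need over 400 GiB. The 5 % non-zero fraction and the index sizes used for the sparse estimate are purely illustrative assumptions, not measured properties of this dataset.

```python
# Rough memory estimate for the full expression matrix.
n_cells = 3_177_310
n_genes = 36_406
bytes_per_value = 4  # float32 (assumption)

dense_bytes = n_cells * n_genes * bytes_per_value
dense_gib = dense_bytes / 1024**3

# CSR-style sparse storage: one value and one column index per non-zero
# entry, plus one row pointer per cell (+1).
nonzero_fraction = 0.05   # illustrative assumption, not measured
index_bytes = 4           # assumed 4-byte column indices
pointer_bytes = 8         # assumed 8-byte row pointers
n_nonzero = int(n_cells * n_genes * nonzero_fraction)
sparse_bytes = n_nonzero * (bytes_per_value + index_bytes) + (n_cells + 1) * pointer_bytes
sparse_gib = sparse_bytes / 1024**3

print(f"Dense float32 matrix: {dense_gib:,.0f} GiB")
print(f"Sparse matrix at {nonzero_fraction:.0%} non-zeros: {sparse_gib:,.0f} GiB")
```

Even under the optimistic sparse assumption, the matrix is far larger than the RAM of a typical workstation, which motivates working with a subset.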

Therefore, the scParadise team recommends that you extract a small portion of the dataset for further steps.
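The idea of drawing a small, reproducible subset can be sketched in plain Python. This is only an illustration of proportional stratified sampling with a fixed seed. Whether `get_frac` works this way internally is an assumption, and `stratified_sample` and the toy labels are hypothetical names.

```python
import random
from collections import defaultdict

def stratified_sample(labels, n_total, seed=0):
    """Pick about n_total indices, keeping each label's share of the data."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    frac = n_total / len(labels)
    chosen = []
    for label in sorted(by_label):
        members = by_label[label]
        # Keep at least one cell of every (possibly rare) cell type
        k = max(1, round(len(members) * frac))
        chosen.extend(rng.sample(members, min(k, len(members))))
    return sorted(chosen)

# Toy annotation vector standing in for adata.obs['cell_type']
labels = ['rod'] * 800 + ['cone'] * 150 + ['microglia'] * 50
subset = stratified_sample(labels, n_total=100, seed=0)
print(len(subset))
```

Using the same seed always returns the same cells, while a different seed yields a different subset of the same size.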

[4]:
# Obtain 25000 cells randomly
adata_fraction = scparadise.scnoah.get_frac(path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
                                            fraction = 25000,
                                            path_save = 'snRNAseq_human_retina',
                                            celltype = 'cell_type',
                                            random_state = 0)
[5]:
# Get raw counts from adata_fraction.raw
adata_fraction = adata_fraction.raw.to_adata()
# Replace variable names with gene names
adata_fraction.var.set_index('feature_name', inplace = True)
adata_fraction.var_names_make_unique()
# Normalize data
sc.pp.normalize_total(adata_fraction, target_sum = None)
sc.pp.log1p(adata_fraction)
adata_fraction.raw = adata_fraction
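With `target_sum = None`, `sc.pp.normalize_total` scales every cell to the median of the total counts per cell, and `sc.pp.log1p` then applies log(1 + x). The following pure-Python sketch mirrors that arithmetic on a tiny count matrix. The function name and the toy counts are hypothetical, and ignoring empty cells when taking the median is an assumption of this sketch.

```python
import math
from statistics import median

def normalize_total_log1p(counts):
    """counts: list of per-cell lists of raw counts (cells x genes)."""
    totals = [sum(cell) for cell in counts]
    # Target library size: median total count over non-empty cells
    target = median(t for t in totals if t > 0)
    normalized = []
    for cell, total in zip(counts, totals):
        scale = target / total if total > 0 else 0.0
        normalized.append([math.log1p(value * scale) for value in cell])
    return normalized, target

raw = [
    [10, 0, 30],  # total 40
    [1, 1, 2],    # total 4
    [5, 5, 10],   # total 20
]
log_norm, target = normalize_total_log1p(raw)
print(target)
```

After scaling (and before the log transform) every cell has the same total, so differences in sequencing depth no longer dominate the comparison between cells.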
[6]:
# Find genes for model training (marker genes of cell types)
lst_genes = []
annotations = ['majorclass', 'cell_type'] # annotation levels
for annotation in annotations:
    sc.tl.rank_genes_groups(adata_fraction,
                            groupby = annotation,
                            method = 't-test_overestim_var', pts = True)
    # Filter marker genes of cell types
    sc.tl.filter_rank_genes_groups(adata_fraction,
                                   min_fold_change = 1.0,
                                   min_in_group_fraction = 0.4,
                                   key_added = 'filtered_rank_genes_groups')
    # Create list of genes for model training

    for i in adata_fraction.obs[annotation].unique():
        df = sc.get.rank_genes_groups_df(adata_fraction, group = i, key = 'filtered_rank_genes_groups', pval_cutoff = 0.05)
        df['pts_comparison'] = df['pct_nz_group']/df['pct_nz_reference']
        lst_genes.extend(df.sort_values(by = 'logfoldchanges', ascending = False).head(20)['names'].tolist())
        lst_genes.extend(df.sort_values(by = 'pts_comparison', ascending = False).head(20)['names'].tolist())
# Remove duplicates
lst_genes = np.unique(lst_genes).tolist()
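The selection logic of the cell above (top genes by log fold change, plus top genes by the ratio of detection fractions, then duplicates removed) can be illustrated on a toy table without scanpy. The gene names and values are hypothetical, and the guard against a zero reference fraction is an addition of this sketch.

```python
def select_markers(rows, n_top=20):
    """rows: one dict per gene with keys 'names', 'logfoldchanges',
    'pct_nz_group' and 'pct_nz_reference'."""
    def pct_ratio(row):
        # Guard against genes never detected in the reference cells
        ref = row['pct_nz_reference']
        return row['pct_nz_group'] / ref if ref > 0 else float('inf')

    by_lfc = sorted(rows, key=lambda r: r['logfoldchanges'], reverse=True)[:n_top]
    by_pct = sorted(rows, key=pct_ratio, reverse=True)[:n_top]
    # Union of both rankings, duplicates removed, sorted for reproducibility
    return sorted({r['names'] for r in by_lfc + by_pct})

toy = [
    {'names': 'GENE_A', 'logfoldchanges': 5.0, 'pct_nz_group': 0.9, 'pct_nz_reference': 0.6},
    {'names': 'GENE_B', 'logfoldchanges': 1.5, 'pct_nz_group': 0.5, 'pct_nz_reference': 0.01},
    {'names': 'GENE_C', 'logfoldchanges': 2.0, 'pct_nz_group': 0.8, 'pct_nz_reference': 0.7},
]
markers = select_markers(toy, n_top=1)
print(markers)
```

The two rankings complement each other: `GENE_A` wins on fold change, while `GENE_B` is picked because it is almost never detected outside the group.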
[7]:
# Subset genes for model training
adata_fraction = adata_fraction[:, lst_genes]
[8]:
# Alternative way to select genes for model training
# sc.pp.highly_variable_genes(adata_fraction,
#                             n_top_genes = 1000,
#                             subset = True)
# lst_genes = adata_fraction.var_names.tolist()
[9]:
adata_balanced = scparadise.scnoah.balance(adata_fraction,
                                           celltype_l1 = annotations[0], # majorclass
                                           celltype_l2 = annotations[1], # cell_type
                                           sample = 'donor_id')
Successfully undersampled cell types: retinal rod cell, GABAergic amacrine cell, Mueller cell, OFF midget ganglion cell, ON midget ganglion cell, flat midget bipolar cell, glycinergic amacrine cell, retinal cone cell, invaginating midget bipolar cell

Successfully oversampled cell types: rod bipolar cell, H1 horizontal cell, diffuse bipolar 2 cell, amacrine cell, retinal bipolar neuron, diffuse bipolar 1 cell, diffuse bipolar 4 cell, diffuse bipolar 3b cell, retinal ganglion cell, giant bipolar cell, diffuse bipolar 3a cell, starburst amacrine cell, diffuse bipolar 6 cell, OFFx cell, astrocyte, H2 horizontal cell, OFF parasol ganglion cell, S cone cell, ON parasol ganglion cell, microglial cell, ON-blue cone bipolar cell, retinal pigment epithelial cell
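The output above reports that abundant cell types were undersampled and rare ones oversampled. A minimal sketch of that general idea is shown below. It is not the algorithm used by `scparadise.scnoah.balance` (whose target size and handling of `donor_id` are not described here), and `balance_indices` and the target of 50 cells are hypothetical.

```python
import random
from collections import Counter, defaultdict

def balance_indices(labels, target, seed=0):
    """Return indices such that every label occurs exactly `target` times."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    balanced = []
    for label in sorted(by_label):
        members = by_label[label]
        if len(members) >= target:
            # Undersample: draw without replacement
            balanced.extend(rng.sample(members, target))
        else:
            # Oversample: keep every cell, then draw extras with replacement
            extra = rng.choices(members, k=target - len(members))
            balanced.extend(members + extra)
    return balanced

labels = ['retinal rod cell'] * 500 + ['astrocyte'] * 20 + ['microglial cell'] * 5
picked = balance_indices(labels, target=50)
counts = Counter(labels[i] for i in picked)
print(counts)
```

Balancing keeps very abundant classes such as rods from dominating the loss, at the cost of repeating cells from rare classes.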
[10]:
# Train scAdam model using the balanced dataset (adata_balanced)
scparadise.scadam.train(adata_balanced,
                        path = 'snRNAseq_human_retina', # path to save model
                        model_name = 'model_scAdam', # folder name with model
                        celltype_l1 = 'celltype_l1', # previously: majorclass
                        celltype_l2 = 'celltype_l2', # previously: cell_type
                        eval_metric = ['balanced_accuracy', 'accuracy'])
Successfully saved genes names for training model

Successfully saved dictionary of dataset annotations

Train dataset contains: 22515 cells, it is 90.0 % of input dataset
Test dataset contains: 2502 cells, it is 10.0 % of input dataset

Accelerator: cuda
Start training
epoch 0  | loss: 2.81661 | train_balanced_accuracy: 0.07509 | train_accuracy: 0.23478 | valid_balanced_accuracy: 0.07484 | valid_accuracy: 0.23441 |  0:00:01s
epoch 1  | loss: 2.15053 | train_balanced_accuracy: 0.29301 | train_accuracy: 0.46274 | valid_balanced_accuracy: 0.29203 | valid_accuracy: 0.46203 |  0:00:02s
epoch 2  | loss: 1.40403 | train_balanced_accuracy: 0.53809 | train_accuracy: 0.6624  | valid_balanced_accuracy: 0.54454 | valid_accuracy: 0.66587 |  0:00:03s
epoch 3  | loss: 0.89481 | train_balanced_accuracy: 0.75518 | train_accuracy: 0.80697 | valid_balanced_accuracy: 0.76594 | valid_accuracy: 0.81095 |  0:00:04s
epoch 4  | loss: 0.66516 | train_balanced_accuracy: 0.81854 | train_accuracy: 0.85328 | valid_balanced_accuracy: 0.8175  | valid_accuracy: 0.84972 |  0:00:06s
epoch 5  | loss: 0.57448 | train_balanced_accuracy: 0.89062 | train_accuracy: 0.90517 | valid_balanced_accuracy: 0.88043 | valid_accuracy: 0.89728 |  0:00:07s
epoch 6  | loss: 0.50396 | train_balanced_accuracy: 0.92367 | train_accuracy: 0.9328  | valid_balanced_accuracy: 0.91688 | valid_accuracy: 0.92666 |  0:00:08s
epoch 7  | loss: 0.44609 | train_balanced_accuracy: 0.93663 | train_accuracy: 0.94126 | valid_balanced_accuracy: 0.92773 | valid_accuracy: 0.93385 |  0:00:09s
epoch 8  | loss: 0.38922 | train_balanced_accuracy: 0.94813 | train_accuracy: 0.95143 | valid_balanced_accuracy: 0.93811 | valid_accuracy: 0.94305 |  0:00:10s
epoch 9  | loss: 0.38246 | train_balanced_accuracy: 0.95886 | train_accuracy: 0.96103 | valid_balanced_accuracy: 0.94843 | valid_accuracy: 0.95324 |  0:00:11s
epoch 10 | loss: 0.35179 | train_balanced_accuracy: 0.96238 | train_accuracy: 0.96513 | valid_balanced_accuracy: 0.95295 | valid_accuracy: 0.95663 |  0:00:12s
epoch 11 | loss: 0.33521 | train_balanced_accuracy: 0.96765 | train_accuracy: 0.96915 | valid_balanced_accuracy: 0.95609 | valid_accuracy: 0.96023 |  0:00:14s
epoch 12 | loss: 0.32833 | train_balanced_accuracy: 0.9687  | train_accuracy: 0.96975 | valid_balanced_accuracy: 0.9598  | valid_accuracy: 0.96343 |  0:00:15s
epoch 13 | loss: 0.30299 | train_balanced_accuracy: 0.97468 | train_accuracy: 0.97524 | valid_balanced_accuracy: 0.96538 | valid_accuracy: 0.96783 |  0:00:16s
epoch 14 | loss: 0.29462 | train_balanced_accuracy: 0.97203 | train_accuracy: 0.97264 | valid_balanced_accuracy: 0.96418 | valid_accuracy: 0.96583 |  0:00:17s
epoch 15 | loss: 0.28393 | train_balanced_accuracy: 0.97724 | train_accuracy: 0.97742 | valid_balanced_accuracy: 0.97102 | valid_accuracy: 0.97182 |  0:00:18s
epoch 16 | loss: 0.28427 | train_balanced_accuracy: 0.98081 | train_accuracy: 0.98148 | valid_balanced_accuracy: 0.97726 | valid_accuracy: 0.97862 |  0:00:19s
epoch 17 | loss: 0.27928 | train_balanced_accuracy: 0.98125 | train_accuracy: 0.98123 | valid_balanced_accuracy: 0.97695 | valid_accuracy: 0.97722 |  0:00:20s
epoch 18 | loss: 0.2769  | train_balanced_accuracy: 0.98287 | train_accuracy: 0.98314 | valid_balanced_accuracy: 0.97864 | valid_accuracy: 0.97962 |  0:00:21s
epoch 19 | loss: 0.27064 | train_balanced_accuracy: 0.98377 | train_accuracy: 0.98417 | valid_balanced_accuracy: 0.97769 | valid_accuracy: 0.97922 |  0:00:22s
epoch 20 | loss: 0.26606 | train_balanced_accuracy: 0.98612 | train_accuracy: 0.98619 | valid_balanced_accuracy: 0.98071 | valid_accuracy: 0.98201 |  0:00:24s
epoch 21 | loss: 0.25256 | train_balanced_accuracy: 0.98787 | train_accuracy: 0.98794 | valid_balanced_accuracy: 0.97919 | valid_accuracy: 0.98062 |  0:00:25s
epoch 22 | loss: 0.25767 | train_balanced_accuracy: 0.98873 | train_accuracy: 0.98903 | valid_balanced_accuracy: 0.98121 | valid_accuracy: 0.98161 |  0:00:26s
epoch 23 | loss: 0.24242 | train_balanced_accuracy: 0.98856 | train_accuracy: 0.98867 | valid_balanced_accuracy: 0.98525 | valid_accuracy: 0.98581 |  0:00:27s
epoch 24 | loss: 0.23697 | train_balanced_accuracy: 0.9894  | train_accuracy: 0.98943 | valid_balanced_accuracy: 0.98418 | valid_accuracy: 0.98501 |  0:00:28s
epoch 25 | loss: 0.23557 | train_balanced_accuracy: 0.98993 | train_accuracy: 0.98998 | valid_balanced_accuracy: 0.98402 | valid_accuracy: 0.98541 |  0:00:29s
epoch 26 | loss: 0.22787 | train_balanced_accuracy: 0.98896 | train_accuracy: 0.98901 | valid_balanced_accuracy: 0.98383 | valid_accuracy: 0.98481 |  0:00:30s
epoch 27 | loss: 0.23513 | train_balanced_accuracy: 0.99035 | train_accuracy: 0.99038 | valid_balanced_accuracy: 0.98575 | valid_accuracy: 0.98621 |  0:00:32s
epoch 28 | loss: 0.23008 | train_balanced_accuracy: 0.98933 | train_accuracy: 0.9893  | valid_balanced_accuracy: 0.98308 | valid_accuracy: 0.98341 |  0:00:33s
epoch 29 | loss: 0.23302 | train_balanced_accuracy: 0.99037 | train_accuracy: 0.9903  | valid_balanced_accuracy: 0.98438 | valid_accuracy: 0.98521 |  0:00:34s
epoch 30 | loss: 0.22559 | train_balanced_accuracy: 0.98943 | train_accuracy: 0.98958 | valid_balanced_accuracy: 0.98176 | valid_accuracy: 0.98321 |  0:00:35s
epoch 31 | loss: 0.22367 | train_balanced_accuracy: 0.9913  | train_accuracy: 0.99132 | valid_balanced_accuracy: 0.98507 | valid_accuracy: 0.98541 |  0:00:36s
epoch 32 | loss: 0.21816 | train_balanced_accuracy: 0.99195 | train_accuracy: 0.99201 | valid_balanced_accuracy: 0.98679 | valid_accuracy: 0.98721 |  0:00:37s
epoch 33 | loss: 0.21323 | train_balanced_accuracy: 0.99217 | train_accuracy: 0.99216 | valid_balanced_accuracy: 0.98691 | valid_accuracy: 0.98721 |  0:00:38s
epoch 34 | loss: 0.21449 | train_balanced_accuracy: 0.99312 | train_accuracy: 0.99318 | valid_balanced_accuracy: 0.9879  | valid_accuracy: 0.98841 |  0:00:39s
epoch 35 | loss: 0.20823 | train_balanced_accuracy: 0.9933  | train_accuracy: 0.99329 | valid_balanced_accuracy: 0.98836 | valid_accuracy: 0.98821 |  0:00:40s
epoch 36 | loss: 0.21154 | train_balanced_accuracy: 0.99371 | train_accuracy: 0.99374 | valid_balanced_accuracy: 0.98619 | valid_accuracy: 0.98681 |  0:00:42s
epoch 37 | loss: 0.21443 | train_balanced_accuracy: 0.99321 | train_accuracy: 0.9932  | valid_balanced_accuracy: 0.98757 | valid_accuracy: 0.98741 |  0:00:43s
epoch 38 | loss: 0.20554 | train_balanced_accuracy: 0.99325 | train_accuracy: 0.9934  | valid_balanced_accuracy: 0.98595 | valid_accuracy: 0.98661 |  0:00:44s
epoch 39 | loss: 0.21128 | train_balanced_accuracy: 0.99406 | train_accuracy: 0.99405 | valid_balanced_accuracy: 0.98567 | valid_accuracy: 0.98681 |  0:00:45s
epoch 40 | loss: 0.20781 | train_balanced_accuracy: 0.99394 | train_accuracy: 0.994   | valid_balanced_accuracy: 0.98512 | valid_accuracy: 0.98641 |  0:00:46s
epoch 41 | loss: 0.19599 | train_balanced_accuracy: 0.99318 | train_accuracy: 0.99323 | valid_balanced_accuracy: 0.98457 | valid_accuracy: 0.98581 |  0:00:47s
epoch 42 | loss: 0.20263 | train_balanced_accuracy: 0.9929  | train_accuracy: 0.99289 | valid_balanced_accuracy: 0.98527 | valid_accuracy: 0.98641 |  0:00:48s
epoch 43 | loss: 0.19245 | train_balanced_accuracy: 0.99362 | train_accuracy: 0.99363 | valid_balanced_accuracy: 0.98647 | valid_accuracy: 0.98681 |  0:00:49s
epoch 44 | loss: 0.19874 | train_balanced_accuracy: 0.99437 | train_accuracy: 0.99436 | valid_balanced_accuracy: 0.98734 | valid_accuracy: 0.98821 |  0:00:51s

Early stopping occurred at epoch 44 with best_epoch = 34 and best_valid_accuracy = 0.98841

Successfully saved training history and parameters
Successfully saved model at snRNAseq_human_retina/model_scAdam/model.zip
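In the log above, training stops at epoch 44 while the best validation accuracy was reached at epoch 34, which is consistent with an early-stopping patience of 10 epochs. The patience value is inferred from the log rather than taken from the scAdam documentation. A minimal sketch of this stopping rule, with hypothetical scores:

```python
def early_stopping(scores, patience=10):
    """Return (stop_epoch, best_epoch, best_score) for a list of
    per-epoch validation scores (higher is better)."""
    best_epoch, best_score = 0, float('-inf')
    for epoch, score in enumerate(scores):
        if score > best_score:
            best_epoch, best_score = epoch, score
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` consecutive epochs
            return epoch, best_epoch, best_score
    return len(scores) - 1, best_epoch, best_score

# Hypothetical validation accuracies: improvement stops after epoch 2
scores = [0.5, 0.6, 0.7] + [0.65] * 10
stop_epoch, best_epoch, best_score = early_stopping(scores, patience=10)
print(stop_epoch, best_epoch, best_score)
```

The weights from the best epoch, not the last one, are the ones worth keeping, which is why the log reports `best_epoch` separately.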

Evaluation of model quality

For model evaluation, we use another subset of 25,000 cells generated using a different random state.

[11]:
# Get test dataset for model quality evaluation
adata_test = scparadise.scnoah.get_frac(path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
                                        fraction = 25000,
                                        path_save = 'snRNAseq_human_retina',
                                        celltype = 'cell_type',
                                        random_state = 42)
[12]:
# Check for cells shared between the test and training datasets
common_cells = set(adata_fraction.obs_names) & set(adata_test.obs_names)
percent = round(len(common_cells) / adata_test.n_obs * 100, 5)
print(f"There are {percent} % common cells ({len(common_cells)} cells) between the test and training datasets")
There are 0.836 % common cells (209 cells) between the test and training datasets

Fewer than 1% of cells are shared between the test dataset and the training dataset. This overlap is small enough to ignore, so we can proceed with evaluating the model's quality.
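An overlap of this size is what chance alone would produce. Assuming two independent simple random draws of 25,000 cells from 3,177,310 (an assumption: `get_frac` may sample per cell type, which would shift the number somewhat), the expected overlap is about 197 cells, or 0.79 %, close to the 209 cells observed.

```python
n_total = 3_177_310  # cells in the full dataset
n_draw = 25_000      # cells in each subset

# Each test cell has probability n_draw / n_total of also being in the training draw
expected_common = n_draw * n_draw / n_total
expected_percent = expected_common / n_draw * 100

print(f"Expected overlap: {expected_common:.1f} cells ({expected_percent:.3f} %)")
```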

[13]:
# Apply the same preprocessing steps to the test dataset as used for training
# Get raw counts from adata_test.raw
adata_test = adata_test.raw.to_adata()

# Replace variable names with gene names
adata_test.var.set_index('feature_name', inplace = True)
adata_test.var_names_make_unique()

# Normalize data
sc.pp.normalize_total(adata_test, target_sum = None)
sc.pp.log1p(adata_test)
adata_test.raw = adata_test
[14]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
                                       path_model = 'snRNAseq_human_retina/model_scAdam')
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities
[15]:
## Check model quality
df_l1 = scparadise.scnoah.report_classif_full(adata_test,
                                              celltype = 'majorclass',
                                              pred_celltype = 'pred_celltype_l1')
df_l1
[15]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| AC | 0.9998 | 0.9964 | 1.0 | 0.9981 | 0.9982 | 0.996 | 4496 |
| Astrocyte | 1.0000 | 0.982 | 1.0 | 0.9909 | 0.991 | 0.9802 | 111 |
| BC | 0.9969 | 0.9994 | 0.9991 | 0.9982 | 0.9993 | 0.9986 | 5437 |
| Cone | 1.0000 | 0.999 | 1.0 | 0.9995 | 0.9995 | 0.9989 | 1000 |
| HC | 0.9984 | 1.0 | 1.0 | 0.9992 | 1.0 | 1.0 | 634 |
| MG | 0.9994 | 0.9983 | 1.0 | 0.9989 | 0.9991 | 0.9981 | 1744 |
| Microglia | 1.0000 | 0.9744 | 1.0 | 0.987 | 0.9871 | 0.9719 | 39 |
| RGC | 0.9981 | 0.9997 | 0.9997 | 0.9989 | 0.9997 | 0.9994 | 3144 |
| RPE | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7 |
| Rod | 0.9998 | 0.9999 | 0.9999 | 0.9998 | 0.9999 | 0.9998 | 8388 |
| macro avg | 0.9992 | 0.9949 | 0.9999 | 0.997 | 0.9974 | 0.9943 | |
| weighted avg | 0.9989 | 0.9989 | 0.9997 | 0.9989 | 0.9993 | 0.9985 | |

Accuracy 0.9989
Balanced accuracy 0.9949
[16]:
## Check model quality
df_l2 = scparadise.scnoah.report_classif_full(adata_test,
                                              celltype = 'cell_type',
                                              pred_celltype = 'pred_celltype_l2')
df_l2
[16]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | 0.9940 | 0.9881 | 0.9992 | 0.991 | 0.9936 | 0.9862 | 2855 |
| H1 horizontal cell | 0.9871 | 0.9907 | 0.9997 | 0.9889 | 0.9952 | 0.9896 | 540 |
| H2 horizontal cell | 0.9355 | 0.9255 | 0.9998 | 0.9305 | 0.9619 | 0.9184 | 94 |
| Mueller cell | 0.9994 | 0.9977 | 1.0 | 0.9986 | 0.9988 | 0.9974 | 1744 |
| OFF midget ganglion cell | 0.9146 | 0.896 | 0.9944 | 0.9052 | 0.9439 | 0.8822 | 1577 |
| OFF parasol ganglion cell | 0.9157 | 0.962 | 0.9997 | 0.9383 | 0.9807 | 0.9581 | 79 |
| OFFx cell | 0.9449 | 0.9836 | 0.9997 | 0.9639 | 0.9916 | 0.9817 | 122 |
| ON midget ganglion cell | 0.9259 | 0.9003 | 0.9964 | 0.9129 | 0.9471 | 0.8884 | 1193 |
| ON parasol ganglion cell | 0.8889 | 0.9796 | 0.9998 | 0.932 | 0.9896 | 0.9774 | 49 |
| ON-blue cone bipolar cell | 0.8750 | 0.913 | 0.9999 | 0.8936 | 0.9555 | 0.905 | 23 |
| S cone cell | 0.9054 | 1.0 | 0.9997 | 0.9504 | 0.9999 | 0.9997 | 67 |
| amacrine cell | 0.9731 | 0.9644 | 0.9995 | 0.9688 | 0.9818 | 0.9606 | 450 |
| astrocyte | 1.0000 | 0.982 | 1.0 | 0.9909 | 0.991 | 0.9802 | 111 |
| diffuse bipolar 1 cell | 0.9949 | 0.9874 | 0.9999 | 0.9912 | 0.9936 | 0.9861 | 397 |
| diffuse bipolar 2 cell | 0.9944 | 0.9869 | 0.9999 | 0.9906 | 0.9934 | 0.9855 | 536 |
| diffuse bipolar 3a cell | 0.9882 | 0.9767 | 0.9999 | 0.9825 | 0.9883 | 0.9744 | 172 |
| diffuse bipolar 3b cell | 0.9685 | 0.9964 | 0.9996 | 0.9823 | 0.998 | 0.9957 | 278 |
| diffuse bipolar 4 cell | 0.9871 | 0.9922 | 0.9998 | 0.9897 | 0.996 | 0.9913 | 387 |
| diffuse bipolar 6 cell | 0.9474 | 0.9863 | 0.9997 | 0.9664 | 0.993 | 0.9847 | 146 |
| flat midget bipolar cell | 0.9947 | 0.9929 | 0.9997 | 0.9938 | 0.9963 | 0.992 | 1134 |
| giant bipolar cell | 0.9608 | 0.9849 | 0.9997 | 0.9727 | 0.9923 | 0.9832 | 199 |
| glycinergic amacrine cell | 0.9769 | 0.9816 | 0.999 | 0.9792 | 0.9903 | 0.9789 | 1032 |
| invaginating midget bipolar cell | 0.9940 | 0.9868 | 0.9998 | 0.9904 | 0.9933 | 0.9853 | 835 |
| microglial cell | 0.9500 | 0.9744 | 0.9999 | 0.962 | 0.9871 | 0.9718 | 39 |
| retinal bipolar neuron | 0.9927 | 0.9879 | 0.9999 | 0.9903 | 0.9939 | 0.9866 | 412 |
| retinal cone cell | 1.0000 | 0.9914 | 1.0 | 0.9957 | 0.9957 | 0.9906 | 933 |
| retinal ganglion cell | 0.4304 | 0.5407 | 0.9929 | 0.4793 | 0.7327 | 0.5125 | 246 |
| retinal pigment epithelial cell | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7 |
| retinal rod cell | 0.9996 | 0.9999 | 0.9998 | 0.9998 | 0.9999 | 0.9997 | 8388 |
| rod bipolar cell | 0.9962 | 0.9962 | 0.9999 | 0.9962 | 0.9981 | 0.9957 | 796 |
| starburst amacrine cell | 0.9691 | 0.9874 | 0.9998 | 0.9782 | 0.9936 | 0.986 | 159 |
| macro avg | 0.9485 | 0.9624 | 0.9993 | 0.955 | 0.9795 | 0.9589 | |
| weighted avg | 0.9791 | 0.9778 | 0.9991 | 0.9784 | 0.988 | 0.9752 | |

Accuracy 0.9778
Balanced accuracy 0.9624
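The per-class columns of these reports can be derived from the four confusion counts of a one-vs-rest comparison. The sketch below assumes that the last two metric columns are the geometric mean of sensitivity and specificity and the index balanced accuracy (IBA) with a weighting factor of 0.1. That assumption is consistent with the reported values: for astrocytes, a recall of 0.982 and a specificity of 1.0 give 0.991 and 0.9802. The confusion counts used here are hypothetical.

```python
import math

def class_metrics(tp, fp, fn, tn, alpha=0.1):
    """One-vs-rest metrics for a single cell type."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)        # sensitivity
    specificity = tn / (tn + fp)
    f1 = 2 * precision * recall / (precision + recall)
    g_mean = math.sqrt(recall * specificity)
    # IBA: squared geometric mean, weighted by the dominance (recall - specificity)
    iba = (1 + alpha * (recall - specificity)) * recall * specificity
    return {
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1-score': f1,
        'geometric mean': g_mean,
        'index balanced accuracy': iba,
    }

# Hypothetical confusion counts for one cell type
m = class_metrics(tp=90, fp=5, fn=10, tn=895)
for name, value in m.items():
    print(f"{name}: {value:.4f}")
```

Specificity stays close to 1 for almost every class because true negatives dominate in a multi-class setting, so precision and recall are the more informative columns for rare cell types.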

The model performs well for every cell type except the generic retinal ganglion cell label (precision 0.4304, recall 0.5407).

You could try using a different random state to generate another test dataset.

Iterative warm start training (optional)

You may continue training on further random subsets of the whole dataset to improve model generalization, but this may also lead to overfitting.
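Conceptually, a warm start means that training on a new subset begins from the previously learned weights instead of from scratch. The toy example below shows this with a one-variable linear model fitted by gradient descent. It is unrelated to the scAdam architecture, and all names and data points are hypothetical.

```python
def gd_fit(data, w=0.0, b=0.0, lr=0.05, epochs=20):
    """Fit y ~ w*x + b by gradient descent, starting from the given (w, b)."""
    n = len(data)
    for _ in range(epochs):
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b

def mse(data, w, b):
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)

# Two small "subsets" drawn from the same relationship y = 2x + 1
subset_1 = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]
subset_2 = [(0.5, 2.0), (1.5, 4.0), (3.0, 7.0)]

w, b = gd_fit(subset_1)                # initial training
loss_before = mse(subset_2, w, b)
w, b = gd_fit(subset_2, w=w, b=b)      # warm start: continue from learned weights
loss_after = mse(subset_2, w, b)

print(f"loss on subset 2 before: {loss_before:.4f}, after: {loss_after:.4f}")
```

Each additional subset refines the same parameters, so the model effectively sees more of the data without ever loading all of it at once.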

[17]:
# Keep the lower bound of the range at 1 or higher: random_state = 0 was already used for the initial training of the model
# (also avoid 42, which was used to draw the test dataset)
for i in range(1, 4):
    # Obtain 25000 cells randomly
    adata_fraction = scparadise.scnoah.get_frac(path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
                                                fraction = 25000,
                                                path_save = 'snRNAseq_human_retina',
                                                celltype = 'cell_type',
                                                random_state = i)
    # Get raw counts from adata_fraction.raw
    adata_fraction = adata_fraction.raw.to_adata()

    # Replace variable names with gene names
    adata_fraction.var.set_index('feature_name', inplace=True)
    adata_fraction.var_names_make_unique()

    # Normalize data
    sc.pp.normalize_total(adata_fraction, target_sum = None)
    sc.pp.log1p(adata_fraction)
    adata_fraction.raw = adata_fraction

    # Subset genes for model training
    adata_fraction = adata_fraction[:, lst_genes]
    # Balance dataset
    adata_balanced = scparadise.scnoah.balance(adata_fraction,
                                               celltype_l1 = annotations[0], # majorclass
                                               celltype_l2 = annotations[1], # cell_type
                                               sample = 'donor_id')
    adata_balanced.raw = adata_balanced
    # Warm start requires second training dataset and path to pretrained model
    scparadise.scadam.warm_start(adata_balanced,
                                 path = 'snRNAseq_human_retina', # path to save model
                                 model_name = 'model_scAdam', # folder name with pretrained model
                                 celltype_l1 = 'celltype_l1', # previously: majorclass
                                 celltype_l2 = 'celltype_l2', # previously: cell_type
                                 eval_metric = ['balanced_accuracy', 'accuracy'])
Successfully undersampled cell types: retinal rod cell, GABAergic amacrine cell, Mueller cell, OFF midget ganglion cell, ON midget ganglion cell, flat midget bipolar cell, glycinergic amacrine cell, retinal cone cell, invaginating midget bipolar cell

Successfully oversampled cell types: rod bipolar cell, H1 horizontal cell, diffuse bipolar 2 cell, amacrine cell, retinal bipolar neuron, diffuse bipolar 1 cell, diffuse bipolar 4 cell, diffuse bipolar 3b cell, retinal ganglion cell, giant bipolar cell, diffuse bipolar 3a cell, starburst amacrine cell, diffuse bipolar 6 cell, OFFx cell, astrocyte, H2 horizontal cell, OFF parasol ganglion cell, S cone cell, ON parasol ganglion cell, microglial cell, ON-blue cone bipolar cell, retinal pigment epithelial cell
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Train dataset contains: 22515 cells, it is 90.0 % of input dataset
Test dataset contains: 2502 cells, it is 10.0 % of input dataset

Successfully loaded parameters

Accelerator: cuda
Start training
epoch 0  | loss: 0.27731 | train_balanced_accuracy: 0.98601 | train_accuracy: 0.98616 | valid_balanced_accuracy: 0.98477 | valid_accuracy: 0.98481 |  0:00:01s
epoch 1  | loss: 0.25079 | train_balanced_accuracy: 0.98945 | train_accuracy: 0.98938 | valid_balanced_accuracy: 0.98653 | valid_accuracy: 0.98661 |  0:00:02s
epoch 2  | loss: 0.23321 | train_balanced_accuracy: 0.99028 | train_accuracy: 0.99036 | valid_balanced_accuracy: 0.98633 | valid_accuracy: 0.98641 |  0:00:03s
epoch 3  | loss: 0.22734 | train_balanced_accuracy: 0.99085 | train_accuracy: 0.99107 | valid_balanced_accuracy: 0.98531 | valid_accuracy: 0.98581 |  0:00:05s
epoch 4  | loss: 0.21247 | train_balanced_accuracy: 0.99147 | train_accuracy: 0.99143 | valid_balanced_accuracy: 0.98704 | valid_accuracy: 0.98701 |  0:00:06s
epoch 5  | loss: 0.22125 | train_balanced_accuracy: 0.99225 | train_accuracy: 0.99225 | valid_balanced_accuracy: 0.98551 | valid_accuracy: 0.98601 |  0:00:07s
epoch 6  | loss: 0.22236 | train_balanced_accuracy: 0.9914  | train_accuracy: 0.99138 | valid_balanced_accuracy: 0.98537 | valid_accuracy: 0.98541 |  0:00:08s
epoch 7  | loss: 0.22278 | train_balanced_accuracy: 0.99242 | train_accuracy: 0.99254 | valid_balanced_accuracy: 0.98903 | valid_accuracy: 0.98901 |  0:00:09s
epoch 8  | loss: 0.20157 | train_balanced_accuracy: 0.99396 | train_accuracy: 0.994   | valid_balanced_accuracy: 0.98829 | valid_accuracy: 0.98821 |  0:00:11s
epoch 9  | loss: 0.21051 | train_balanced_accuracy: 0.99327 | train_accuracy: 0.99327 | valid_balanced_accuracy: 0.98776 | valid_accuracy: 0.98781 |  0:00:12s
epoch 10 | loss: 0.2045  | train_balanced_accuracy: 0.99525 | train_accuracy: 0.99531 | valid_balanced_accuracy: 0.98876 | valid_accuracy: 0.98881 |  0:00:14s
epoch 11 | loss: 0.20367 | train_balanced_accuracy: 0.99413 | train_accuracy: 0.99412 | valid_balanced_accuracy: 0.98828 | valid_accuracy: 0.98821 |  0:00:15s
epoch 12 | loss: 0.20131 | train_balanced_accuracy: 0.99447 | train_accuracy: 0.99465 | valid_balanced_accuracy: 0.98744 | valid_accuracy: 0.98741 |  0:00:16s
epoch 13 | loss: 0.19265 | train_balanced_accuracy: 0.99464 | train_accuracy: 0.99465 | valid_balanced_accuracy: 0.98832 | valid_accuracy: 0.98841 |  0:00:17s
epoch 14 | loss: 0.18869 | train_balanced_accuracy: 0.99439 | train_accuracy: 0.9944  | valid_balanced_accuracy: 0.98776 | valid_accuracy: 0.98761 |  0:00:19s
epoch 15 | loss: 0.19078 | train_balanced_accuracy: 0.99424 | train_accuracy: 0.99425 | valid_balanced_accuracy: 0.98777 | valid_accuracy: 0.98781 |  0:00:20s
epoch 16 | loss: 0.19392 | train_balanced_accuracy: 0.99505 | train_accuracy: 0.99505 | valid_balanced_accuracy: 0.98756 | valid_accuracy: 0.98761 |  0:00:21s
epoch 17 | loss: 0.19063 | train_balanced_accuracy: 0.99464 | train_accuracy: 0.99474 | valid_balanced_accuracy: 0.98805 | valid_accuracy: 0.98821 |  0:00:22s

Early stopping occurred at epoch 17 with best_epoch = 7 and best_valid_accuracy = 0.98901

Successfully saved training history and parameters
Successfully saved model at snRNAseq_human_retina/model_scAdam/model.zip
Successfully undersampled cell types: retinal rod cell, GABAergic amacrine cell, Mueller cell, OFF midget ganglion cell, ON midget ganglion cell, flat midget bipolar cell, glycinergic amacrine cell, retinal cone cell, invaginating midget bipolar cell

Successfully oversampled cell types: rod bipolar cell, H1 horizontal cell, diffuse bipolar 2 cell, amacrine cell, retinal bipolar neuron, diffuse bipolar 1 cell, diffuse bipolar 4 cell, diffuse bipolar 3b cell, retinal ganglion cell, giant bipolar cell, diffuse bipolar 3a cell, starburst amacrine cell, diffuse bipolar 6 cell, OFFx cell, astrocyte, H2 horizontal cell, OFF parasol ganglion cell, S cone cell, ON parasol ganglion cell, microglial cell, ON-blue cone bipolar cell, retinal pigment epithelial cell
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Train dataset contains: 22515 cells, it is 90.0 % of input dataset
Test dataset contains: 2502 cells, it is 10.0 % of input dataset

Successfully loaded parameters

Accelerator: cuda
Start training
epoch 0  | loss: 0.25691 | train_balanced_accuracy: 0.98874 | train_accuracy: 0.98885 | valid_balanced_accuracy: 0.98564 | valid_accuracy: 0.98581 |  0:00:01s
epoch 1  | loss: 0.22516 | train_balanced_accuracy: 0.99112 | train_accuracy: 0.99116 | valid_balanced_accuracy: 0.98905 | valid_accuracy: 0.98901 |  0:00:02s
epoch 2  | loss: 0.21182 | train_balanced_accuracy: 0.99143 | train_accuracy: 0.99145 | valid_balanced_accuracy: 0.98918 | valid_accuracy: 0.98921 |  0:00:03s
epoch 3  | loss: 0.21638 | train_balanced_accuracy: 0.99258 | train_accuracy: 0.99272 | valid_balanced_accuracy: 0.99106 | valid_accuracy: 0.99121 |  0:00:04s
epoch 4  | loss: 0.20085 | train_balanced_accuracy: 0.99276 | train_accuracy: 0.99285 | valid_balanced_accuracy: 0.98915 | valid_accuracy: 0.98941 |  0:00:06s
epoch 5  | loss: 0.21065 | train_balanced_accuracy: 0.99298 | train_accuracy: 0.99309 | valid_balanced_accuracy: 0.98885 | valid_accuracy: 0.98901 |  0:00:07s
epoch 6  | loss: 0.20995 | train_balanced_accuracy: 0.99409 | train_accuracy: 0.99412 | valid_balanced_accuracy: 0.99084 | valid_accuracy: 0.99081 |  0:00:08s
epoch 7  | loss: 0.20724 | train_balanced_accuracy: 0.99507 | train_accuracy: 0.99507 | valid_balanced_accuracy: 0.98924 | valid_accuracy: 0.98981 |  0:00:09s
epoch 8  | loss: 0.19699 | train_balanced_accuracy: 0.99554 | train_accuracy: 0.99556 | valid_balanced_accuracy: 0.99105 | valid_accuracy: 0.99121 |  0:00:10s
epoch 9  | loss: 0.19933 | train_balanced_accuracy: 0.99525 | train_accuracy: 0.99531 | valid_balanced_accuracy: 0.99116 | valid_accuracy: 0.99121 |  0:00:12s
epoch 10 | loss: 0.18906 | train_balanced_accuracy: 0.99584 | train_accuracy: 0.99587 | valid_balanced_accuracy: 0.99185 | valid_accuracy: 0.99201 |  0:00:13s
epoch 11 | loss: 0.18892 | train_balanced_accuracy: 0.9958  | train_accuracy: 0.99583 | valid_balanced_accuracy: 0.99345 | valid_accuracy: 0.99361 |  0:00:14s
epoch 12 | loss: 0.19035 | train_balanced_accuracy: 0.99567 | train_accuracy: 0.99567 | valid_balanced_accuracy: 0.99225 | valid_accuracy: 0.99261 |  0:00:15s
epoch 13 | loss: 0.18244 | train_balanced_accuracy: 0.996   | train_accuracy: 0.99607 | valid_balanced_accuracy: 0.99154 | valid_accuracy: 0.99201 |  0:00:16s
epoch 14 | loss: 0.18461 | train_balanced_accuracy: 0.99528 | train_accuracy: 0.99531 | valid_balanced_accuracy: 0.98973 | valid_accuracy: 0.99041 |  0:00:18s
epoch 15 | loss: 0.18248 | train_balanced_accuracy: 0.9962  | train_accuracy: 0.99622 | valid_balanced_accuracy: 0.99272 | valid_accuracy: 0.99341 |  0:00:19s
epoch 16 | loss: 0.17748 | train_balanced_accuracy: 0.99662 | train_accuracy: 0.9966  | valid_balanced_accuracy: 0.99143 | valid_accuracy: 0.99201 |  0:00:20s
epoch 17 | loss: 0.18126 | train_balanced_accuracy: 0.99735 | train_accuracy: 0.99736 | valid_balanced_accuracy: 0.99074 | valid_accuracy: 0.99121 |  0:00:21s
epoch 18 | loss: 0.18514 | train_balanced_accuracy: 0.99689 | train_accuracy: 0.99691 | valid_balanced_accuracy: 0.99216 | valid_accuracy: 0.99221 |  0:00:22s
epoch 19 | loss: 0.18461 | train_balanced_accuracy: 0.99648 | train_accuracy: 0.99649 | valid_balanced_accuracy: 0.99345 | valid_accuracy: 0.9938  |  0:00:24s
epoch 20 | loss: 0.17847 | train_balanced_accuracy: 0.99637 | train_accuracy: 0.99638 | valid_balanced_accuracy: 0.9924  | valid_accuracy: 0.99241 |  0:00:25s
epoch 21 | loss: 0.17701 | train_balanced_accuracy: 0.99728 | train_accuracy: 0.99729 | valid_balanced_accuracy: 0.99336 | valid_accuracy: 0.99341 |  0:00:26s
epoch 22 | loss: 0.17557 | train_balanced_accuracy: 0.99737 | train_accuracy: 0.99738 | valid_balanced_accuracy: 0.99236 | valid_accuracy: 0.99241 |  0:00:27s
epoch 23 | loss: 0.17332 | train_balanced_accuracy: 0.99723 | train_accuracy: 0.99727 | valid_balanced_accuracy: 0.99256 | valid_accuracy: 0.99261 |  0:00:28s
epoch 24 | loss: 0.18068 | train_balanced_accuracy: 0.9969  | train_accuracy: 0.99696 | valid_balanced_accuracy: 0.99246 | valid_accuracy: 0.99281 |  0:00:29s
epoch 25 | loss: 0.17254 | train_balanced_accuracy: 0.9975  | train_accuracy: 0.99756 | valid_balanced_accuracy: 0.99234 | valid_accuracy: 0.99281 |  0:00:31s
epoch 26 | loss: 0.16806 | train_balanced_accuracy: 0.99776 | train_accuracy: 0.99776 | valid_balanced_accuracy: 0.99225 | valid_accuracy: 0.99241 |  0:00:32s
epoch 27 | loss: 0.17744 | train_balanced_accuracy: 0.99814 | train_accuracy: 0.99813 | valid_balanced_accuracy: 0.99426 | valid_accuracy: 0.9942  |  0:00:33s
epoch 28 | loss: 0.17337 | train_balanced_accuracy: 0.99796 | train_accuracy: 0.99802 | valid_balanced_accuracy: 0.99416 | valid_accuracy: 0.9942  |  0:00:34s
epoch 29 | loss: 0.17079 | train_balanced_accuracy: 0.99802 | train_accuracy: 0.99809 | valid_balanced_accuracy: 0.99396 | valid_accuracy: 0.994   |  0:00:35s
epoch 30 | loss: 0.17553 | train_balanced_accuracy: 0.99818 | train_accuracy: 0.99818 | valid_balanced_accuracy: 0.99407 | valid_accuracy: 0.994   |  0:00:37s
epoch 31 | loss: 0.16556 | train_balanced_accuracy: 0.99836 | train_accuracy: 0.99833 | valid_balanced_accuracy: 0.99316 | valid_accuracy: 0.99321 |  0:00:38s
epoch 32 | loss: 0.16672 | train_balanced_accuracy: 0.99814 | train_accuracy: 0.99813 | valid_balanced_accuracy: 0.99247 | valid_accuracy: 0.99241 |  0:00:39s
epoch 33 | loss: 0.15766 | train_balanced_accuracy: 0.99798 | train_accuracy: 0.99798 | valid_balanced_accuracy: 0.99436 | valid_accuracy: 0.9944  |  0:00:40s
epoch 34 | loss: 0.16557 | train_balanced_accuracy: 0.99752 | train_accuracy: 0.99751 | valid_balanced_accuracy: 0.99455 | valid_accuracy: 0.9946  |  0:00:41s
epoch 35 | loss: 0.16067 | train_balanced_accuracy: 0.99746 | train_accuracy: 0.99742 | valid_balanced_accuracy: 0.99307 | valid_accuracy: 0.99301 |  0:00:42s
epoch 36 | loss: 0.16737 | train_balanced_accuracy: 0.99772 | train_accuracy: 0.99776 | valid_balanced_accuracy: 0.99256 | valid_accuracy: 0.99261 |  0:00:44s
epoch 37 | loss: 0.16903 | train_balanced_accuracy: 0.99801 | train_accuracy: 0.99811 | valid_balanced_accuracy: 0.99305 | valid_accuracy: 0.99321 |  0:00:45s
epoch 38 | loss: 0.166   | train_balanced_accuracy: 0.99818 | train_accuracy: 0.99818 | valid_balanced_accuracy: 0.99307 | valid_accuracy: 0.99301 |  0:00:46s
epoch 39 | loss: 0.16403 | train_balanced_accuracy: 0.99852 | train_accuracy: 0.99851 | valid_balanced_accuracy: 0.99426 | valid_accuracy: 0.9942  |  0:00:47s
epoch 40 | loss: 0.16122 | train_balanced_accuracy: 0.9984  | train_accuracy: 0.9984  | valid_balanced_accuracy: 0.99446 | valid_accuracy: 0.9944  |  0:00:48s
epoch 41 | loss: 0.16069 | train_balanced_accuracy: 0.99677 | train_accuracy: 0.99676 | valid_balanced_accuracy: 0.99356 | valid_accuracy: 0.99361 |  0:00:49s
epoch 42 | loss: 0.1639  | train_balanced_accuracy: 0.99764 | train_accuracy: 0.99765 | valid_balanced_accuracy: 0.99216 | valid_accuracy: 0.99221 |  0:00:51s
epoch 43 | loss: 0.15643 | train_balanced_accuracy: 0.99848 | train_accuracy: 0.99847 | valid_balanced_accuracy: 0.99407 | valid_accuracy: 0.994   |  0:00:52s
epoch 44 | loss: 0.16107 | train_balanced_accuracy: 0.99853 | train_accuracy: 0.99853 | valid_balanced_accuracy: 0.99427 | valid_accuracy: 0.9942  |  0:00:53s

Early stopping occurred at epoch 44 with best_epoch = 34 and best_valid_accuracy = 0.9946

Successfully saved training history and parameters
Successfully saved model at snRNAseq_human_retina/model_scAdam/model.zip
Successfully undersampled cell types: retinal rod cell, GABAergic amacrine cell, Mueller cell, OFF midget ganglion cell, ON midget ganglion cell, flat midget bipolar cell, glycinergic amacrine cell, retinal cone cell, invaginating midget bipolar cell

Successfully oversampled cell types: rod bipolar cell, H1 horizontal cell, diffuse bipolar 2 cell, amacrine cell, retinal bipolar neuron, diffuse bipolar 1 cell, diffuse bipolar 4 cell, diffuse bipolar 3b cell, retinal ganglion cell, giant bipolar cell, diffuse bipolar 3a cell, starburst amacrine cell, diffuse bipolar 6 cell, OFFx cell, astrocyte, H2 horizontal cell, OFF parasol ganglion cell, S cone cell, ON parasol ganglion cell, microglial cell, ON-blue cone bipolar cell, retinal pigment epithelial cell
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Train dataset contains: 22515 cells, it is 90.0 % of input dataset
Test dataset contains: 2502 cells, it is 10.0 % of input dataset

Successfully loaded parameters

Accelerator: cuda
Start training
epoch 0  | loss: 0.23292 | train_balanced_accuracy: 0.98984 | train_accuracy: 0.99007 | valid_balanced_accuracy: 0.98956 | valid_accuracy: 0.98981 |  0:00:01s
epoch 1  | loss: 0.20013 | train_balanced_accuracy: 0.99118 | train_accuracy: 0.99114 | valid_balanced_accuracy: 0.98888 | valid_accuracy: 0.98881 |  0:00:02s
epoch 2  | loss: 0.19169 | train_balanced_accuracy: 0.9925  | train_accuracy: 0.99258 | valid_balanced_accuracy: 0.99034 | valid_accuracy: 0.99101 |  0:00:03s
epoch 3  | loss: 0.19631 | train_balanced_accuracy: 0.99199 | train_accuracy: 0.99212 | valid_balanced_accuracy: 0.98835 | valid_accuracy: 0.98901 |  0:00:04s
epoch 4  | loss: 0.18296 | train_balanced_accuracy: 0.99278 | train_accuracy: 0.99274 | valid_balanced_accuracy: 0.99016 | valid_accuracy: 0.99041 |  0:00:06s
epoch 5  | loss: 0.18494 | train_balanced_accuracy: 0.99338 | train_accuracy: 0.99336 | valid_balanced_accuracy: 0.99138 | valid_accuracy: 0.99121 |  0:00:07s
epoch 6  | loss: 0.19315 | train_balanced_accuracy: 0.99372 | train_accuracy: 0.99369 | valid_balanced_accuracy: 0.99142 | valid_accuracy: 0.99121 |  0:00:08s
epoch 7  | loss: 0.18335 | train_balanced_accuracy: 0.99505 | train_accuracy: 0.99505 | valid_balanced_accuracy: 0.99206 | valid_accuracy: 0.99281 |  0:00:09s
epoch 8  | loss: 0.17319 | train_balanced_accuracy: 0.99473 | train_accuracy: 0.99469 | valid_balanced_accuracy: 0.99158 | valid_accuracy: 0.99141 |  0:00:11s
epoch 9  | loss: 0.18107 | train_balanced_accuracy: 0.99536 | train_accuracy: 0.99534 | valid_balanced_accuracy: 0.99324 | valid_accuracy: 0.99361 |  0:00:12s
epoch 10 | loss: 0.17673 | train_balanced_accuracy: 0.99518 | train_accuracy: 0.9952  | valid_balanced_accuracy: 0.99076 | valid_accuracy: 0.99101 |  0:00:13s
epoch 11 | loss: 0.17895 | train_balanced_accuracy: 0.99539 | train_accuracy: 0.99536 | valid_balanced_accuracy: 0.99329 | valid_accuracy: 0.99361 |  0:00:14s
epoch 12 | loss: 0.17441 | train_balanced_accuracy: 0.99641 | train_accuracy: 0.99638 | valid_balanced_accuracy: 0.99197 | valid_accuracy: 0.99181 |  0:00:16s
epoch 13 | loss: 0.16971 | train_balanced_accuracy: 0.99671 | train_accuracy: 0.99669 | valid_balanced_accuracy: 0.9918  | valid_accuracy: 0.99221 |  0:00:17s
epoch 14 | loss: 0.16967 | train_balanced_accuracy: 0.99609 | train_accuracy: 0.99609 | valid_balanced_accuracy: 0.99242 | valid_accuracy: 0.99241 |  0:00:18s
epoch 15 | loss: 0.17293 | train_balanced_accuracy: 0.99586 | train_accuracy: 0.99585 | valid_balanced_accuracy: 0.9936  | valid_accuracy: 0.994   |  0:00:19s
epoch 16 | loss: 0.16655 | train_balanced_accuracy: 0.99631 | train_accuracy: 0.99629 | valid_balanced_accuracy: 0.99203 | valid_accuracy: 0.99201 |  0:00:20s
epoch 17 | loss: 0.17146 | train_balanced_accuracy: 0.99697 | train_accuracy: 0.99694 | valid_balanced_accuracy: 0.99322 | valid_accuracy: 0.99321 |  0:00:22s
epoch 18 | loss: 0.16555 | train_balanced_accuracy: 0.99702 | train_accuracy: 0.997   | valid_balanced_accuracy: 0.99322 | valid_accuracy: 0.99321 |  0:00:23s
epoch 19 | loss: 0.17055 | train_balanced_accuracy: 0.99705 | train_accuracy: 0.99702 | valid_balanced_accuracy: 0.99422 | valid_accuracy: 0.9942  |  0:00:24s
epoch 20 | loss: 0.17196 | train_balanced_accuracy: 0.99715 | train_accuracy: 0.99714 | valid_balanced_accuracy: 0.99342 | valid_accuracy: 0.99341 |  0:00:25s
epoch 21 | loss: 0.17013 | train_balanced_accuracy: 0.99717 | train_accuracy: 0.99714 | valid_balanced_accuracy: 0.99221 | valid_accuracy: 0.99261 |  0:00:26s
epoch 22 | loss: 0.16264 | train_balanced_accuracy: 0.99747 | train_accuracy: 0.99745 | valid_balanced_accuracy: 0.99161 | valid_accuracy: 0.99201 |  0:00:28s
epoch 23 | loss: 0.16108 | train_balanced_accuracy: 0.99721 | train_accuracy: 0.99718 | valid_balanced_accuracy: 0.99337 | valid_accuracy: 0.99321 |  0:00:29s
epoch 24 | loss: 0.17589 | train_balanced_accuracy: 0.99789 | train_accuracy: 0.99787 | valid_balanced_accuracy: 0.99201 | valid_accuracy: 0.99241 |  0:00:30s
epoch 25 | loss: 0.16501 | train_balanced_accuracy: 0.99748 | train_accuracy: 0.99747 | valid_balanced_accuracy: 0.99196 | valid_accuracy: 0.99221 |  0:00:31s
epoch 26 | loss: 0.15835 | train_balanced_accuracy: 0.99729 | train_accuracy: 0.99725 | valid_balanced_accuracy: 0.99171 | valid_accuracy: 0.99181 |  0:00:33s
epoch 27 | loss: 0.16955 | train_balanced_accuracy: 0.99746 | train_accuracy: 0.99745 | valid_balanced_accuracy: 0.99203 | valid_accuracy: 0.99201 |  0:00:34s
epoch 28 | loss: 0.16966 | train_balanced_accuracy: 0.99749 | train_accuracy: 0.99747 | valid_balanced_accuracy: 0.99188 | valid_accuracy: 0.99221 |  0:00:35s
epoch 29 | loss: 0.16197 | train_balanced_accuracy: 0.99788 | train_accuracy: 0.99787 | valid_balanced_accuracy: 0.99305 | valid_accuracy: 0.99341 |  0:00:36s

Early stopping occurred at epoch 29 with best_epoch = 19 and best_valid_accuracy = 0.9942

Successfully saved training history and parameters
Successfully saved model at snRNAseq_human_retina/model_scAdam/model.zip
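The early-stopping messages in the logs above are consistent with a patience-based rule on the validation accuracy: training stops once `valid_accuracy` has not improved for 10 consecutive epochs (best epoch 34, stop at 44; best epoch 19, stop at 29). The sketch below replays the `valid_accuracy` values of the second run through such a rule. The helper name `replay_early_stopping` is hypothetical, and the patience of 10 is an assumption inferred from the log, not a documented scParadise setting.

```python
def replay_early_stopping(scores, patience=10):
    """Return (stop_epoch, best_epoch, best_score) for a patience-based rule.

    Training stops at the first epoch where the score has gone `patience`
    epochs without a strict improvement. patience=10 is an assumption
    inferred from the training log, not a confirmed library default.
    """
    best_epoch, best_score = 0, float('-inf')
    for epoch, score in enumerate(scores):
        if score > best_score:
            best_epoch, best_score = epoch, score
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch, best_score
    # Patience was never exhausted: training ran through every epoch
    return len(scores) - 1, best_epoch, best_score


# valid_accuracy for epochs 0-29, copied from the second training log above
valid_accuracy = [
    0.98981, 0.98881, 0.99101, 0.98901, 0.99041, 0.99121, 0.99121, 0.99281,
    0.99141, 0.99361, 0.99101, 0.99361, 0.99181, 0.99221, 0.99241, 0.994,
    0.99201, 0.99321, 0.99321, 0.9942, 0.99341, 0.99261, 0.99201, 0.99321,
    0.99241, 0.99221, 0.99181, 0.99201, 0.99221, 0.99341,
]

stop_epoch, best_epoch, best_score = replay_early_stopping(valid_accuracy)
print(f'Early stopping at epoch {stop_epoch} with best_epoch = {best_epoch} '
      f'and best_valid_accuracy = {best_score}')
```

Under these assumptions the replay stops at epoch 29 with the best epoch at 19, matching the message in the log.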
[18]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(adata_test,
                                       path_model = 'snRNAseq_human_retina/model_scAdam')
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities
[19]:
## Check model quality
df_warm_start_l1 = scparadise.scnoah.report_classif_full(adata_test,
                                                         celltype='majorclass',
                                                         pred_celltype='pred_celltype_l1')
df_warm_start_l1
[19]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| AC | 0.9993 | 0.9949 | 0.9999 | 0.9971 | 0.9974 | 0.9942 | 4496 |
| Astrocyte | 0.9909 | 0.982 | 1.0 | 0.9864 | 0.9909 | 0.9802 | 111 |
| BC | 0.9971 | 0.9996 | 0.9992 | 0.9983 | 0.9994 | 0.9989 | 5437 |
| Cone | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1000 |
| HC | 0.9953 | 1.0 | 0.9999 | 0.9976 | 0.9999 | 0.9999 | 634 |
| MG | 1.0000 | 0.9989 | 1.0 | 0.9994 | 0.9994 | 0.9987 | 1744 |
| Microglia | 0.9500 | 0.9744 | 0.9999 | 0.962 | 0.9871 | 0.9718 | 39 |
| RGC | 0.9978 | 0.9997 | 0.9997 | 0.9987 | 0.9997 | 0.9994 | 3144 |
| RPE | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7 |
| Rod | 0.9996 | 0.9995 | 0.9998 | 0.9996 | 0.9997 | 0.9993 | 8388 |
| macro avg | 0.9930 | 0.9949 | 0.9998 | 0.9939 | 0.9973 | 0.9942 | |
| weighted avg | 0.9986 | 0.9986 | 0.9997 | 0.9986 | 0.9991 | 0.9982 | |
| Accuracy | 0.9986 | | | | | | |
| Balanced accuracy | 0.9949 | | | | | | |
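The two summary rows of the report differ only in how the per-class recall is averaged. As a minimal sketch using the values from the majorclass report above, balanced accuracy is the unweighted mean of per-class recall, while accuracy is the recall weighted by the number of cells in each class. The variable names are illustrative and not part of scParadise.

```python
# Per-class recall/sensitivity copied from the majorclass report above
recall_per_class = {
    'AC': 0.9949, 'Astrocyte': 0.982, 'BC': 0.9996, 'Cone': 1.0, 'HC': 1.0,
    'MG': 0.9989, 'Microglia': 0.9744, 'RGC': 0.9997, 'RPE': 1.0, 'Rod': 0.9995,
}
# Cell counts from the 'number of cells' column
n_cells = {
    'AC': 4496, 'Astrocyte': 111, 'BC': 5437, 'Cone': 1000, 'HC': 634,
    'MG': 1744, 'Microglia': 39, 'RGC': 3144, 'RPE': 7, 'Rod': 8388,
}

# Balanced accuracy: every class counts equally, regardless of its size
balanced_accuracy = sum(recall_per_class.values()) / len(recall_per_class)

# Accuracy: recall weighted by class size, so large classes dominate
total_cells = sum(n_cells.values())
accuracy = sum(recall_per_class[c] * n_cells[c] for c in n_cells) / total_cells

print(f'Balanced accuracy: {balanced_accuracy:.4f}')
print(f'Accuracy: {accuracy:.4f}')
```

Because the 7 RPE cells weigh as much as the 8388 rod cells in the unweighted mean, balanced accuracy is the more sensitive indicator of problems in rare cell types.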
[20]:
## Check model quality
df_warm_start_l2 = scparadise.scnoah.report_classif_full(adata_test,
                                                         celltype='cell_type',
                                                         pred_celltype='pred_celltype_l2')
df_warm_start_l2
[20]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | 0.9965 | 0.9832 | 0.9995 | 0.9898 | 0.9913 | 0.9811 | 2855 |
| H1 horizontal cell | 0.9981 | 0.9944 | 1.0 | 0.9963 | 0.9972 | 0.9939 | 540 |
| H2 horizontal cell | 0.9490 | 0.9894 | 0.9998 | 0.9688 | 0.9946 | 0.9881 | 94 |
| Mueller cell | 1.0000 | 0.9977 | 1.0 | 0.9989 | 0.9989 | 0.9975 | 1744 |
| OFF midget ganglion cell | 0.9604 | 0.877 | 0.9976 | 0.9168 | 0.9353 | 0.8643 | 1577 |
| OFF parasol ganglion cell | 0.9630 | 0.9873 | 0.9999 | 0.975 | 0.9936 | 0.986 | 79 |
| OFFx cell | 0.9918 | 0.9918 | 1.0 | 0.9918 | 0.9959 | 0.991 | 122 |
| ON midget ganglion cell | 0.9280 | 0.9405 | 0.9963 | 0.9342 | 0.968 | 0.9318 | 1193 |
| ON parasol ganglion cell | 0.9800 | 1.0 | 1.0 | 0.9899 | 1.0 | 1.0 | 49 |
| ON-blue cone bipolar cell | 0.8800 | 0.9565 | 0.9999 | 0.9167 | 0.978 | 0.9523 | 23 |
| S cone cell | 0.9437 | 1.0 | 0.9998 | 0.971 | 0.9999 | 0.9999 | 67 |
| amacrine cell | 0.9586 | 0.9778 | 0.9992 | 0.9681 | 0.9884 | 0.9749 | 450 |
| astrocyte | 1.0000 | 0.982 | 1.0 | 0.9909 | 0.991 | 0.9802 | 111 |
| diffuse bipolar 1 cell | 0.9900 | 1.0 | 0.9998 | 0.995 | 0.9999 | 0.9999 | 397 |
| diffuse bipolar 2 cell | 0.9926 | 0.9944 | 0.9998 | 0.9935 | 0.9971 | 0.9937 | 536 |
| diffuse bipolar 3a cell | 1.0000 | 0.9884 | 1.0 | 0.9942 | 0.9942 | 0.9872 | 172 |
| diffuse bipolar 3b cell | 0.9858 | 1.0 | 0.9998 | 0.9929 | 0.9999 | 0.9999 | 278 |
| diffuse bipolar 4 cell | 0.9948 | 0.9897 | 0.9999 | 0.9922 | 0.9948 | 0.9886 | 387 |
| diffuse bipolar 6 cell | 0.9796 | 0.9863 | 0.9999 | 0.9829 | 0.9931 | 0.9848 | 146 |
| flat midget bipolar cell | 0.9965 | 0.9947 | 0.9998 | 0.9956 | 0.9973 | 0.994 | 1134 |
| giant bipolar cell | 0.9850 | 0.9899 | 0.9999 | 0.9875 | 0.9949 | 0.9888 | 199 |
| glycinergic amacrine cell | 0.9715 | 0.9893 | 0.9987 | 0.9803 | 0.994 | 0.9872 | 1032 |
| invaginating midget bipolar cell | 0.9976 | 0.994 | 0.9999 | 0.9958 | 0.997 | 0.9933 | 835 |
| microglial cell | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 39 |
| retinal bipolar neuron | 0.9903 | 0.9927 | 0.9998 | 0.9915 | 0.9963 | 0.9919 | 412 |
| retinal cone cell | 1.0000 | 0.9957 | 1.0 | 0.9979 | 0.9979 | 0.9953 | 933 |
| retinal ganglion cell | 0.4879 | 0.7358 | 0.9923 | 0.5867 | 0.8545 | 0.7114 | 246 |
| retinal pigment epithelial cell | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7 |
| retinal rod cell | 0.9995 | 0.9998 | 0.9998 | 0.9996 | 0.9998 | 0.9995 | 8388 |
| rod bipolar cell | 0.9962 | 0.995 | 0.9999 | 0.9956 | 0.9974 | 0.9944 | 796 |
| starburst amacrine cell | 0.9812 | 0.9874 | 0.9999 | 0.9843 | 0.9936 | 0.9861 | 159 |
| macro avg | 0.9644 | 0.9778 | 0.9994 | 0.9701 | 0.9882 | 0.9754 | |
| weighted avg | 0.9844 | 0.982 | 0.9994 | 0.9828 | 0.9904 | 0.9798 | |
| Accuracy | 0.9820 | | | | | | |
| Balanced accuracy | 0.9778 | | | | | | |
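The derived columns of the report can be recomputed from precision, recall, and specificity. The sketch below does this for the retinal ganglion cell row. It assumes the standard F1 definition, the geometric mean of sensitivity and specificity, and the index balanced accuracy as defined in imbalanced-learn with a dominance weight `alpha = 0.1`. The helper name `derived_metrics` is hypothetical, and `alpha = 0.1` is an assumption chosen because it reproduces the reported values; it is not confirmed from the scParadise source.

```python
import math


def derived_metrics(precision, recall, specificity, alpha=0.1):
    """Recompute f1-score, geometric mean and index balanced accuracy.

    alpha=0.1 is an assumed dominance weight (see the text above).
    """
    # Harmonic mean of precision and recall
    f1 = 2 * precision * recall / (precision + recall)
    # Geometric mean of sensitivity (recall) and specificity
    gmean = math.sqrt(recall * specificity)
    # Squared geometric mean, scaled by the recall/specificity dominance
    iba = (1 + alpha * (recall - specificity)) * gmean ** 2
    return f1, gmean, iba


# Retinal ganglion cell row of the cell_type report above
f1, gmean, iba = derived_metrics(precision=0.4879, recall=0.7358,
                                 specificity=0.9923)
print(f'f1-score: {f1:.4f}')
print(f'geometric mean: {gmean:.4f}')
print(f'index balanced accuracy: {iba:.4f}')
```

The low f1-score of this class is driven by its precision: specificity stays above 0.99 because the class contains only 246 of the 25,000 test cells, so even a small false-positive rate yields many false positives relative to the true ones.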
[22]:
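# Show all rows of the comparison table
pd.set_option('display.max_rows', 100)
# Compare the default and warm start reports, stacking both models row-wise per cell type
df_l2.compare(df_warm_start_l2, keep_equal=True, align_axis = 0, result_names=('default', 'warm start'))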
[22]:
| | | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | default | 0.9940 | 0.9881 | 0.9992 | 0.991 | 0.9936 | 0.9862 |
| | warm start | 0.9965 | 0.9832 | 0.9995 | 0.9898 | 0.9913 | 0.9811 |
| H1 horizontal cell | default | 0.9871 | 0.9907 | 0.9997 | 0.9889 | 0.9952 | 0.9896 |
| | warm start | 0.9981 | 0.9944 | 1.0 | 0.9963 | 0.9972 | 0.9939 |
| H2 horizontal cell | default | 0.9355 | 0.9255 | 0.9998 | 0.9305 | 0.9619 | 0.9184 |
| | warm start | 0.9490 | 0.9894 | 0.9998 | 0.9688 | 0.9946 | 0.9881 |
| Mueller cell | default | 0.9994 | 0.9977 | 1.0 | 0.9986 | 0.9988 | 0.9974 |
| | warm start | 1.0000 | 0.9977 | 1.0 | 0.9989 | 0.9989 | 0.9975 |
| OFF midget ganglion cell | default | 0.9146 | 0.896 | 0.9944 | 0.9052 | 0.9439 | 0.8822 |
| | warm start | 0.9604 | 0.877 | 0.9976 | 0.9168 | 0.9353 | 0.8643 |
| OFF parasol ganglion cell | default | 0.9157 | 0.962 | 0.9997 | 0.9383 | 0.9807 | 0.9581 |
| | warm start | 0.9630 | 0.9873 | 0.9999 | 0.975 | 0.9936 | 0.986 |
| OFFx cell | default | 0.9449 | 0.9836 | 0.9997 | 0.9639 | 0.9916 | 0.9817 |
| | warm start | 0.9918 | 0.9918 | 1.0 | 0.9918 | 0.9959 | 0.991 |
| ON midget ganglion cell | default | 0.9259 | 0.9003 | 0.9964 | 0.9129 | 0.9471 | 0.8884 |
| | warm start | 0.9280 | 0.9405 | 0.9963 | 0.9342 | 0.968 | 0.9318 |
| ON parasol ganglion cell | default | 0.8889 | 0.9796 | 0.9998 | 0.932 | 0.9896 | 0.9774 |
| | warm start | 0.9800 | 1.0 | 1.0 | 0.9899 | 1.0 | 1.0 |
| ON-blue cone bipolar cell | default | 0.8750 | 0.913 | 0.9999 | 0.8936 | 0.9555 | 0.905 |
| | warm start | 0.8800 | 0.9565 | 0.9999 | 0.9167 | 0.978 | 0.9523 |
| S cone cell | default | 0.9054 | 1.0 | 0.9997 | 0.9504 | 0.9999 | 0.9997 |
| | warm start | 0.9437 | 1.0 | 0.9998 | 0.971 | 0.9999 | 0.9999 |
| amacrine cell | default | 0.9731 | 0.9644 | 0.9995 | 0.9688 | 0.9818 | 0.9606 |
| | warm start | 0.9586 | 0.9778 | 0.9992 | 0.9681 | 0.9884 | 0.9749 |
| diffuse bipolar 1 cell | default | 0.9949 | 0.9874 | 0.9999 | 0.9912 | 0.9936 | 0.9861 |
| | warm start | 0.9900 | 1.0 | 0.9998 | 0.995 | 0.9999 | 0.9999 |
| diffuse bipolar 2 cell | default | 0.9944 | 0.9869 | 0.9999 | 0.9906 | 0.9934 | 0.9855 |
| | warm start | 0.9926 | 0.9944 | 0.9998 | 0.9935 | 0.9971 | 0.9937 |
| diffuse bipolar 3a cell | default | 0.9882 | 0.9767 | 0.9999 | 0.9825 | 0.9883 | 0.9744 |
| | warm start | 1.0000 | 0.9884 | 1.0 | 0.9942 | 0.9942 | 0.9872 |
| diffuse bipolar 3b cell | default | 0.9685 | 0.9964 | 0.9996 | 0.9823 | 0.998 | 0.9957 |
| | warm start | 0.9858 | 1.0 | 0.9998 | 0.9929 | 0.9999 | 0.9999 |
| diffuse bipolar 4 cell | default | 0.9871 | 0.9922 | 0.9998 | 0.9897 | 0.996 | 0.9913 |
| | warm start | 0.9948 | 0.9897 | 0.9999 | 0.9922 | 0.9948 | 0.9886 |
| diffuse bipolar 6 cell | default | 0.9474 | 0.9863 | 0.9997 | 0.9664 | 0.993 | 0.9847 |
| | warm start | 0.9796 | 0.9863 | 0.9999 | 0.9829 | 0.9931 | 0.9848 |
| flat midget bipolar cell | default | 0.9947 | 0.9929 | 0.9997 | 0.9938 | 0.9963 | 0.992 |
| | warm start | 0.9965 | 0.9947 | 0.9998 | 0.9956 | 0.9973 | 0.994 |
| giant bipolar cell | default | 0.9608 | 0.9849 | 0.9997 | 0.9727 | 0.9923 | 0.9832 |
| | warm start | 0.9850 | 0.9899 | 0.9999 | 0.9875 | 0.9949 | 0.9888 |
| glycinergic amacrine cell | default | 0.9769 | 0.9816 | 0.999 | 0.9792 | 0.9903 | 0.9789 |
| | warm start | 0.9715 | 0.9893 | 0.9987 | 0.9803 | 0.994 | 0.9872 |
| invaginating midget bipolar cell | default | 0.9940 | 0.9868 | 0.9998 | 0.9904 | 0.9933 | 0.9853 |
| | warm start | 0.9976 | 0.994 | 0.9999 | 0.9958 | 0.997 | 0.9933 |
| microglial cell | default | 0.9500 | 0.9744 | 0.9999 | 0.962 | 0.9871 | 0.9718 |
| | warm start | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
| retinal bipolar neuron | default | 0.9927 | 0.9879 | 0.9999 | 0.9903 | 0.9939 | 0.9866 |
| | warm start | 0.9903 | 0.9927 | 0.9998 | 0.9915 | 0.9963 | 0.9919 |
| retinal cone cell | default | 1.0000 | 0.9914 | 1.0 | 0.9957 | 0.9957 | 0.9906 |
| | warm start | 1.0000 | 0.9957 | 1.0 | 0.9979 | 0.9979 | 0.9953 |
| retinal ganglion cell | default | 0.4304 | 0.5407 | 0.9929 | 0.4793 | 0.7327 | 0.5125 |
| | warm start | 0.4879 | 0.7358 | 0.9923 | 0.5867 | 0.8545 | 0.7114 |
| retinal rod cell | default | 0.9996 | 0.9999 | 0.9998 | 0.9998 | 0.9999 | 0.9997 |
| | warm start | 0.9995 | 0.9998 | 0.9998 | 0.9996 | 0.9998 | 0.9995 |
| rod bipolar cell | default | 0.9962 | 0.9962 | 0.9999 | 0.9962 | 0.9981 | 0.9957 |
| | warm start | 0.9962 | 0.995 | 0.9999 | 0.9956 | 0.9974 | 0.9944 |
| starburst amacrine cell | default | 0.9691 | 0.9874 | 0.9998 | 0.9782 | 0.9936 | 0.986 |
| | warm start | 0.9812 | 0.9874 | 0.9999 | 0.9843 | 0.9936 | 0.9861 |
| macro avg | default | 0.9485 | 0.9624 | 0.9993 | 0.955 | 0.9795 | 0.9589 |
| | warm start | 0.9644 | 0.9778 | 0.9994 | 0.9701 | 0.9882 | 0.9754 |
| weighted avg | default | 0.9791 | 0.9778 | 0.9991 | 0.9784 | 0.988 | 0.9752 |
| | warm start | 0.9844 | 0.982 | 0.9994 | 0.9828 | 0.9904 | 0.9798 |
| Accuracy | default | 0.9778 | | | | | |
| | warm start | 0.9820 | | | | | |
| Balanced accuracy | default | 0.9624 | | | | | |
| | warm start | 0.9778 | | | | | |
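The comparison table has fewer rows and columns than the two reports it is built from. With its default `keep_shape=False`, `DataFrame.compare` drops every row and every column in which the two frames are identical. This is consistent with the astrocyte and retinal pigment epithelial cell rows and the number of cells column being absent above. A minimal sketch on toy data (the frames `df_default` and `df_warm` and their values are illustrative stand-ins, not the real reports):

```python
import pandas as pd

# Toy stand-ins for the default and warm start reports (illustrative values)
df_default = pd.DataFrame(
    {'precision': [0.90, 1.00, 0.50],
     'recall': [0.80, 1.00, 0.60],
     'number of cells': [10, 5, 20]},
    index=['type A', 'type B', 'type C'])
df_warm = pd.DataFrame(
    {'precision': [0.95, 1.00, 0.50],
     'recall': [0.80, 1.00, 0.70],
     'number of cells': [10, 5, 20]},
    index=['type A', 'type B', 'type C'])

# keep_equal=True shows unchanged values instead of NaN;
# align_axis=0 stacks the two frames row-wise under each original index label;
# result_names labels the stacked rows (requires pandas >= 1.5)
comparison = df_default.compare(df_warm, keep_equal=True, align_axis=0,
                                result_names=('default', 'warm start'))
print(comparison)
```

Here 'type B' and 'number of cells' are identical in both frames, so they do not appear in the result. Passing `keep_shape=True` would retain them.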

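Iterative warm start training improved every summary metric (rows macro avg, weighted avg, Accuracy, and Balanced accuracy). A few per-class values decreased slightly (for example, sensitivity for the OFF midget ganglion cell fell from 0.896 to 0.877). For the retinal ganglion cell, the class with the lowest scores, sensitivity increased by 19.5 percentage points (0.5407 to 0.7358) and precision by 5.75 percentage points (0.4304 to 0.4879).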
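The improvement for the retinal ganglion cell can be expressed either as an absolute difference (percentage points) or as a change relative to the default model. The short sketch below computes both from the comparison table. The variable names are illustrative.

```python
# Retinal ganglion cell metrics copied from the comparison table
default = {'precision': 0.4304, 'recall/sensitivity': 0.5407}
warm_start = {'precision': 0.4879, 'recall/sensitivity': 0.7358}

changes = {}
for metric, before in default.items():
    after = warm_start[metric]
    # Absolute difference, expressed in percentage points
    points = (after - before) * 100
    # Change relative to the default model, in percent
    relative = (after - before) / before * 100
    changes[metric] = (points, relative)
    print(f'{metric}: +{points:.2f} percentage points '
          f'(+{relative:.1f} % relative to default)')
```

The two views differ substantially for low baseline values: a gain of about 19.5 percentage points in sensitivity corresponds to a relative improvement of roughly 36 %.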

[21]:
import session_info
session_info.show()
[21]:
-----
anndata             0.10.8
numpy               1.25.0
pandas              2.2.3
scanpy              1.10.3
scparadise          0.4.0_beta
session_info        1.0.0
-----
Modules imported as dependencies:
PIL                         10.4.0
anyio                       NA
arrow                       1.3.0
asciitree                   NA
asttokens                   NA
attr                        24.2.0
attrs                       24.2.0
awkward                     2.7.1
awkward_cpp                 NA
babel                       2.16.0
backports                   NA
certifi                     2024.08.30
cffi                        1.17.1
charset_normalizer          3.3.2
cloudpickle                 3.1.0
colorlog                    NA
comm                        0.2.2
cycler                      0.12.1
cython_runtime              NA
dask                        2024.8.0
dateutil                    2.9.0.post0
debugpy                     1.8.6
decorator                   5.1.1
defusedxml                  0.7.1
exceptiongroup              1.2.2
executing                   2.1.0
fastjsonschema              NA
fqdn                        NA
fsspec                      2023.6.0
h5py                        3.12.1
idna                        3.10
igraph                      0.11.6
imblearn                    0.12.3
importlib_metadata          NA
importlib_resources         NA
ipykernel                   6.29.5
isoduration                 NA
jaraco                      NA
jedi                        0.19.1
jinja2                      3.1.4
joblib                      1.4.2
json5                       0.9.25
jsonpointer                 3.0.0
jsonschema                  4.23.0
jsonschema_specifications   NA
jupyter_events              0.10.0
jupyter_server              2.14.2
jupyterlab_server           2.27.3
kiwisolver                  1.4.7
legacy_api_wrap             NA
leidenalg                   0.10.2
llvmlite                    0.43.0
markupsafe                  2.1.5
matplotlib                  3.9.2
matplotlib_inline           0.1.7
more_itertools              10.5.0
mpl_toolkits                NA
mpmath                      1.3.0
msgpack                     1.1.0
mudata                      0.2.4
muon                        0.1.6
natsort                     8.4.0
nbformat                    5.10.4
numba                       0.60.0
numcodecs                   0.12.1
optuna                      4.0.0
overrides                   NA
packaging                   24.1
parso                       0.8.4
patsy                       0.5.6
pexpect                     4.9.0
platformdirs                4.3.6
plotly                      5.24.1
prometheus_client           NA
prompt_toolkit              3.0.48
psutil                      6.0.0
ptyprocess                  0.7.0
pure_eval                   0.2.3
pyarrow                     18.1.0
pycparser                   2.22
pydev_ipython               NA
pydevconsole                NA
pydevd                      3.1.0
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pydot                       3.0.3
pygments                    2.18.0
pynndescent                 0.5.13
pyparsing                   3.1.4
pythonjsonlogger            NA
pytorch_tabnet              NA
pytz                        2024.2
referencing                 NA
requests                    2.32.3
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rich                        NA
rpds                        NA
scipy                       1.13.1
seaborn                     0.13.2
send2trash                  NA
setuptools                  75.1.0
setuptools_scm              NA
shap                        0.46.0
six                         1.16.0
sklearn                     1.5.2
slicer                      NA
sniffio                     1.3.1
stack_data                  0.6.3
statsmodels                 0.14.3
sympy                       1.13.3
tblib                       3.0.0
texttable                   1.7.0
threadpoolctl               3.5.0
tlz                         0.12.1
tomli                       2.0.1
toolz                       0.12.1
torch                       2.4.1+cu121
torchgen                    NA
tornado                     6.4.1
tqdm                        4.66.5
traitlets                   5.14.3
triton                      3.0.0
typing_extensions           NA
umap                        0.5.6
uri_template                NA
urllib3                     1.26.20
wcwidth                     0.2.13
webcolors                   24.8.0
websocket                   1.8.0
yaml                        6.0.2
zarr                        2.18.2
zipp                        NA
zmq                         26.2.0
zoneinfo                    NA
-----
IPython             8.18.1
jupyter_client      8.6.3
jupyter_core        5.7.2
jupyterlab          4.2.5
-----
Python 3.9.19 (main, May  6 2024, 19:43:03) [GCC 11.2.0]
Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
-----
Session information updated at 2025-02-08 17:59