Managing large and complex training datasets

Large, complex datasets are difficult to use for model training because of computational limits: simply opening such a dataset can require an enormous amount of RAM.

The time needed to select genes for model training also grows with dataset size.

At the same time, adding more training data does not always improve model quality significantly.

Here, we present a way to work around these limits when training models on large, complex datasets with many donors.

[1]:
# Python packages
import warnings
warnings.simplefilter('ignore')

import scanpy as sc
import scparadise
import numpy as np
import pandas as pd
import os

sc.set_figure_params(dpi = 120)
[2]:
# Create folder to save files
# Dataset: single-nucleus RNA-seq (snRNA-seq) of human retina
os.makedirs('snRNAseq_human_retina', exist_ok=True)
[3]:
# Download CELLxGENE dataset (snRNA-seq of human retina):
# https://cellxgene.cziscience.com/collections/4c6eaf5c-6d57-4c76-b1e9-60df8c655f1e
!wget https://datasets.cellxgene.cziscience.com/2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad
--2025-02-08 13:49:05--  https://datasets.cellxgene.cziscience.com/2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad
Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 52.85.49.24, 52.85.49.28, 52.85.49.17, ...
Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|52.85.49.24|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 37946797973 (35G) [binary/octet-stream]
Saving to: ‘2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad’

2e910e62-7eaf-4c06- 100%[===================>]  35.34G  25.1MB/s    in 27m 20s

2025-02-08 14:13:58 (22.1 MB/s) - ‘2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad’ saved [37946797973/37946797973]

[4]:
# Inspect metadata in backed mode (the expression matrix stays on disk)
adata = sc.read_h5ad('2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad', backed = 'r')
adata
[4]:
AnnData object with n_obs × n_vars = 3177310 × 36406 backed at '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad'
    obs: 'reference_genome', 'gene_annotation_version', 'alignment_software', 'intronic_reads_counted', 'donor_id', 'donor_age', 'self_reported_ethnicity_ontology_term_id', 'donor_cause_of_death', 'donor_living_at_sample_collection', 'organism_ontology_term_id', 'sample_id', 'sample_preservation_method', 'tissue_ontology_term_id', 'development_stage_ontology_term_id', 'sample_collection_method', 'tissue_source', 'tissue_type', 'suspension_derivation_process', 'suspension_dissociation_reagent', 'suspension_enriched_cell_types', 'suspension_enrichment_factors', 'suspension_uuid', 'suspension_type', 'tissue_handling_interval', 'library_id', 'assay_ontology_term_id', 'sequenced_fragment', 'institute', 'library_id_repository', 'sequencing_platform', 'is_primary_data', 'cell_type_ontology_term_id', 'author_cell_type', 'disease_ontology_term_id', 'sex_ontology_term_id', 'majorclass', 'AC_subclass', 'AC_cluster', 'AC_celltype_number', 'BC_subclass', 'RGC_cluster', 'RGC_celltype_number', 'study_name', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'sampleid', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type'
    uns: 'citation', 'default_embedding', 'schema_reference', 'schema_version', 'title'
    obsm: 'X_scVI', 'X_umap'

Train a scAdam model using a fraction of the dataset

The entire dataset contains 3,177,310 cells and 36,406 genes (35.34 GB), which is too large to open on a standard computer. In addition, selecting genes and training a model on a dataset of this size require substantial computational time and resources.

Therefore, the scParadise team recommends extracting a smaller subset of the dataset for scAdam model training.
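To make the size problem concrete, the arithmetic below estimates how much memory the full matrix would need if every value were held densely. The cell and gene counts and the file size come from the outputs above; the 4-byte value size (float32) is an assumption, and the real in-memory footprint of a sparse matrix depends on the number of non-zero entries, which is not reported here.

```python
# Rough memory estimate for holding the full expression matrix densely in RAM.
n_cells = 3_177_310
n_genes = 36_406
bytes_per_value = 4  # assumption: float32 values

dense_bytes = n_cells * n_genes * bytes_per_value
dense_gib = dense_bytes / 2**30

file_bytes = 37_946_797_973  # file size reported by wget above
file_gib = file_bytes / 2**30

print(f"Dense matrix: {dense_gib:,.1f} GiB")
print(f"File on disk: {file_gib:,.2f} GiB")
```

Under this assumption a dense copy would need more than 400 GiB, roughly twelve times the size of the file on disk, which is why the notebook only reads metadata in backed mode and samples a subset for training.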

[5]:
# Obtain 25000 cells from dataset
adata_fraction = scparadise.scnoah.get_frac(
    path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
    fraction = 25000,
    stratify = 'cell_type', # keep the same cell type proportions as in the full dataset
    random_state = 0
)
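The idea behind the `stratify` argument can be illustrated with a minimal standard-library sketch. This is not how `scparadise.scnoah.get_frac` is implemented; the function `stratified_sample`, the toy labels, and the "at least one cell per label" rule are all assumptions made for illustration.

```python
import random
from collections import Counter, defaultdict

def stratified_sample(labels, n_total, seed=0):
    """Return sorted indices of roughly n_total items, keeping label proportions."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    picked = []
    for label, indices in by_label.items():
        # Proportional allocation; keep at least one item per label (illustrative choice)
        k = max(1, round(n_total * len(indices) / len(labels)))
        picked.extend(rng.sample(indices, min(k, len(indices))))
    return sorted(picked)

# Toy labels: 80% rod, 15% bipolar, 5% cone
labels = ["rod"] * 800 + ["bipolar"] * 150 + ["cone"] * 50
subset = stratified_sample(labels, n_total=100, seed=0)
counts = Counter(labels[i] for i in subset)
print(counts)
```

Changing `seed` gives a different subset with the same label proportions, which is the property the notebook relies on when it later draws further subsets with other random states.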
[6]:
# Get raw counts from adata_fraction.raw
adata_fraction = adata_fraction.raw.to_adata()
# Replace variable names with gene names
adata_fraction.var.set_index('feature_name', inplace = True)
adata_fraction.var_names = adata_fraction.var_names.astype('str')
adata_fraction.var_names_make_unique()
# Normalize data
sc.pp.normalize_total(adata_fraction, target_sum = None)
sc.pp.log1p(adata_fraction)
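As a rough illustration of the two normalization calls above: with `target_sum = None`, scanpy scales each cell to the median of the per-cell total counts, and `log1p` then applies log(1 + x). The sketch below reproduces that arithmetic on a made-up 3 x 3 count matrix using only the standard library; it is not scanpy's implementation.

```python
import math
from statistics import median

# Toy raw counts: rows are cells, columns are genes (values invented for illustration)
counts = [
    [4, 0, 6],     # total 10
    [10, 10, 0],   # total 20
    [0, 30, 10],   # total 40
]

totals = [sum(row) for row in counts]
target = median(totals)  # median of per-cell totals, used when target_sum is None

# Scale every cell so that its total equals the target, then apply log(1 + x)
normalized = [[value * target / total for value in row]
              for row, total in zip(counts, totals)]
logged = [[math.log1p(value) for value in row] for row in normalized]

for row in normalized:
    print(row, "sum =", sum(row))
```

Because the target is the median of the subset being normalized, two different subsets can be scaled to slightly different totals; passing an explicit `target_sum` would make the scaling identical across subsets.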
[7]:
# Select genes for model training (marker genes of cell types)
lst_genes = []
lst_annotations = ['majorclass', 'cell_type'] # annotation levels
for annotation in lst_annotations:
    sc.tl.rank_genes_groups(adata_fraction,
                            groupby = annotation,
                            method = 't-test_overestim_var', pts = True)
    # Filter marker genes of cell types
    sc.tl.filter_rank_genes_groups(adata_fraction,
                                   min_fold_change = 1.0,
                                   min_in_group_fraction = 0.4,
                                   key_added = 'filtered_rank_genes_groups')
    # Create list of genes for model training

    for i in adata_fraction.obs[annotation].unique():
        df = sc.get.rank_genes_groups_df(adata_fraction, group = i, key = 'filtered_rank_genes_groups', pval_cutoff = 0.05)
        df['pts_comparison'] = df['pct_nz_group']/df['pct_nz_reference']
        lst_genes.extend(df.sort_values(by = 'logfoldchanges', ascending = False).head(20)['names'].tolist())
        lst_genes.extend(df.sort_values(by = 'pts_comparison', ascending = False).head(20)['names'].tolist())
# Remove duplicates
lst_genes = np.unique(lst_genes).tolist()
print(f'Number of selected features: {len(lst_genes)}')
Number of selected features: 595
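The selection rule used in the loop above (top genes by log fold change, plus top genes by the ratio of detection rates, with duplicates removed) can be sketched on a toy table. The helper `select_marker_genes`, the gene names, and all values below are hypothetical; the handling of a zero reference fraction is an illustrative choice, not something the notebook does.

```python
def select_marker_genes(rows, top_n=20):
    """Union of the top_n genes by log fold change and by detection-rate ratio."""
    def ratio(row):
        # Fraction of expressing cells in the group vs. the reference;
        # a zero reference fraction is treated as an infinite ratio
        if row["pct_nz_reference"] == 0:
            return float("inf")
        return row["pct_nz_group"] / row["pct_nz_reference"]

    by_lfc = sorted(rows, key=lambda r: r["logfoldchanges"], reverse=True)[:top_n]
    by_ratio = sorted(rows, key=ratio, reverse=True)[:top_n]
    # A set removes genes that were picked by both criteria
    return sorted({r["names"] for r in by_lfc + by_ratio})

# Toy marker table for one cell type (all values invented)
rows = [
    {"names": "gene_a", "logfoldchanges": 6.0, "pct_nz_group": 0.95, "pct_nz_reference": 0.30},
    {"names": "gene_b", "logfoldchanges": 4.0, "pct_nz_group": 0.80, "pct_nz_reference": 0.05},
    {"names": "gene_c", "logfoldchanges": 5.0, "pct_nz_group": 0.90, "pct_nz_reference": 0.45},
    {"names": "gene_d", "logfoldchanges": 3.0, "pct_nz_group": 0.60, "pct_nz_reference": 0.04},
    {"names": "gene_e", "logfoldchanges": 2.0, "pct_nz_group": 0.60, "pct_nz_reference": 0.40},
]

selected = select_marker_genes(rows, top_n=2)
print(selected)
```

The two criteria pick different genes here: the fold-change ranking favors strongly expressed markers, while the detection-rate ratio favors genes that are rarely detected outside the group.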
[8]:
# Subset genes for model training
adata_fraction = adata_fraction[:, lst_genes]
[9]:
# Alternative way to select genes for model training
# sc.pp.highly_variable_genes(adata_fraction,
#                             n_top_genes = 1000,
#                             subset = True)
# lst_genes = adata_fraction.var_names.tolist()
[9]:
# Balance dataset
adata_balanced = scparadise.scnoah.balance(
    adata_fraction,
    celltype_keys = lst_annotations
)
[10]:
# Train scAdam model using the balanced dataset (adata_balanced)
scparadise.scadam.train(
    adata_balanced,
    path = 'snRNAseq_human_retina', # path to save model
    model_name = 'Human_Retina_scAdam', # folder name with model
    celltype_keys = lst_annotations,
    eval_metric = ['accuracy', 'balanced_accuracy']
)
Device: cuda
Number of features: 595
Label hierarchy: majorclass → cell_type
Annotation levels weights using strategy 'linear_offset':
  majorclass: 10 cell types, 0.4 relative weight
  cell_type: 31 cell types, 0.6 relative weight

Dataset split:
Train dataset contains: 20716 cells, it is 80.0 % of input dataset
Validation dataset contains: 5180 cells, it is 20.0 % of input dataset
Unsupervised pretraining: 100%|███████████████████████████████████████████████████████████████████████████| 50/50 [02:07<00:00,  2.56s/it]
Training scAdam model:   9%|██████▉                                                                      | 18/200 [00:48<08:06,  2.67s/it]
Early stopping triggered! Best score: 0.9924
Training completed!
Fitting unknown cells detector

UnknownCellDetector fitted successfully!
Model saved to snRNAseq_human_retina/Human_Retina_scAdam
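The log above shows training stopping at epoch 18 of 200 once early stopping was triggered. The general mechanism can be sketched as follows; the patience value, the validation scores, and the function name are hypothetical, and the stopping rule actually used by scAdam is not shown in this notebook.

```python
def run_with_early_stopping(validation_scores, patience=3):
    """Return (best_epoch, best_score, stop_epoch) for a sequence of validation scores."""
    best_score = float("-inf")
    best_epoch = -1
    for epoch, score in enumerate(validation_scores):
        if score > best_score:
            # New best validation score: remember it and keep training
            best_score, best_epoch = score, epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs in a row: stop here
            return best_epoch, best_score, epoch
    return best_epoch, best_score, len(validation_scores) - 1

# Invented validation scores, one per epoch
scores = [0.90, 0.95, 0.97, 0.99, 0.985, 0.98, 0.988, 0.97]
result = run_with_early_stopping(scores, patience=3)
print(result)
```

In this toy run the best score is reached at epoch 3, and training stops three epochs later without using the remaining epochs.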

Evaluation of model quality

For model evaluation, we use another subset of 25,000 cells generated using a different random state.

[11]:
# Get test dataset for model quality evaluation
adata_test = scparadise.scnoah.get_frac(
    path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
    fraction = 25000,
    stratify = 'cell_type',
    random_state = 42
)
[12]:
# Check common cells between test and training datasets
common_cells = set(adata_fraction.obs_names) & set(adata_test.obs_names)
n_common = len(common_cells)
percent = round(n_common / adata_test.n_obs * 100, 5)
print(f"There are {percent} % common cells ({n_common} cells) between the test and training datasets")
There are 0.836 % common cells (209 cells) between the test and training datasets

Fewer than 1% of the test cells also appear in the training dataset. An overlap this small can be ignored, so we can proceed with evaluating model quality.
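The observed overlap is close to what chance alone predicts. Assuming the two subsets are drawn independently and without replacement from the same 3,177,310 cells, the expected number of shared cells is n * n / N; with proportional stratification the same expectation holds, up to rounding within each cell type. The check below is a sketch under that independence assumption.

```python
n_total = 3_177_310   # cells in the full dataset
n_sample = 25_000     # cells in each subset
observed_overlap = 209  # value printed by the cell above

# Each test cell has probability n_sample / n_total of also being in the training subset
expected_overlap = n_sample * n_sample / n_total
expected_percent = expected_overlap / n_sample * 100

print(f"Expected overlap: {expected_overlap:.1f} cells ({expected_percent:.2f} %)")
print(f"Observed overlap: {observed_overlap} cells")
```

An expected overlap of just under 200 cells is in line with the 209 cells observed, so the shared cells are a by-product of random sampling rather than a sign that the two random states produce related subsets.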

[13]:
# Apply the same preprocessing steps to the test dataset as used for training
# Get raw counts from adata_test.raw
adata_test = adata_test.raw.to_adata()

# Replace variable names with gene names
adata_test.var.set_index('feature_name', inplace = True)
adata_test.var_names = adata_test.var_names.astype('str')
adata_test.var_names_make_unique()

# Normalize data
sc.pp.normalize_total(adata_test, target_sum = None)
sc.pp.log1p(adata_test)
adata_test.raw = adata_test
[14]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(
    adata_test,
    path_model = 'snRNAseq_human_retina/Human_Retina_scAdam'
)
scAdam model with unknown detector loaded from snRNAseq_human_retina/Human_Retina_scAdam
Gene alignment:
  Model features: 595
  Matched features: 595 (100.0%)
Predicting: 100%|████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:00<00:00, 151.49it/s]
Added cell type column: pred_celltype_l1
Added probabilities column: pred_celltype_l1_probability
Added cell type column: pred_celltype_l2
Added probabilities column: pred_celltype_l2_probability
[15]:
# Check model quality (level 1: majorclass)
df_l1 = scparadise.scnoah.report_classif_full(
    adata_test,
    celltype = 'majorclass',
    pred_celltype = 'pred_celltype_l1'
)
df_l1
[15]:
                   precision  recall/sensitivity  specificity  f1-score  geometric mean  index balanced accuracy  number of cells
AC                 0.9989     0.9978              0.9998       0.9983    0.9988          0.9973                   4496
Astrocyte          1.0000     0.9820              1.0000       0.9909    0.9910          0.9802                   111
BC                 0.9993     0.9998              0.9998       0.9995    0.9998          0.9996                   5437
Cone               0.9990     1.0000              1.0000       0.9995    1.0000          1.0000                   1000
HC                 0.9953     0.9984              0.9999       0.9969    0.9991          0.9982                   634
MG                 0.9994     0.9994              1.0000       0.9994    0.9997          0.9993                   1744
Microglia          1.0000     0.9744              1.0000       0.9870    0.9871          0.9719                   39
RGC                0.9978     0.9990              0.9997       0.9984    0.9994          0.9987                   3144
RPE                0.8750     1.0000              1.0000       0.9333    1.0000          1.0000                   7
Rod                0.9999     0.9995              0.9999       0.9997    0.9997          0.9994                   8388
macro avg          0.9865     0.9950              0.9999       0.9903    0.9975          0.9945
weighted avg       0.9991     0.9991              0.9998       0.9991    0.9995          0.9988
Accuracy           0.9991
Balanced accuracy  0.9950
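The columns of this report can be reproduced from one-vs-rest confusion counts. The sketch below does so for the RPE row. The counts are inferred rather than reported: 7 RPE cells with recall 1.0 give 7 true positives and no false negatives, and a precision of 0.875 then implies 8 predicted RPE cells, i.e. 1 false positive. The index balanced accuracy formula with alpha = 0.1 follows the imbalanced-learn convention; that the report uses this convention is an assumption based on the numbers matching.

```python
import math

def per_class_metrics(tp, fp, fn, tn, alpha=0.1):
    """One-vs-rest metrics for a single class from its confusion counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)            # also called sensitivity
    specificity = tn / (tn + fp)
    f1 = 2 * precision * recall / (precision + recall)
    geometric_mean = math.sqrt(recall * specificity)
    # Index balanced accuracy: squared geometric mean weighted by the dominance term
    iba = (1 + alpha * (recall - specificity)) * geometric_mean ** 2
    return {
        "precision": precision,
        "recall/sensitivity": recall,
        "specificity": specificity,
        "f1-score": f1,
        "geometric mean": geometric_mean,
        "index balanced accuracy": iba,
    }

n_test = 25_000            # cells in the test dataset
tp, fn, fp = 7, 0, 1       # inferred counts for the RPE class
tn = n_test - tp - fn - fp

rpe = per_class_metrics(tp, fp, fn, tn)
for name, value in rpe.items():
    print(f"{name}: {value:.4f}")
```

A single false positive is enough to lower precision to 0.875 for a class with only 7 cells, so metrics for very small classes should be read with their cell counts in mind.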
[16]:
# Check model quality (level 2: cell_type)
df_l2 = scparadise.scnoah.report_classif_full(
    adata_test,
    celltype = 'cell_type',
    pred_celltype = 'pred_celltype_l2'
)
df_l2
[16]:
                                  precision  recall/sensitivity  specificity  f1-score  geometric mean  index balanced accuracy  number of cells
GABAergic amacrine cell           0.9944     0.9968              0.9993       0.9956    0.9981          0.9959                   2855
H1 horizontal cell                0.9963     0.9870              0.9999       0.9916    0.9935          0.9857                   540
H2 horizontal cell                0.9192     0.9681              0.9997       0.9430    0.9838          0.9647                   94
Mueller cell                      0.9994     1.0000              1.0000       0.9997    1.0000          1.0000                   1744
OFF midget ganglion cell          0.8699     0.9708              0.9902       0.9176    0.9805          0.9595                   1577
OFF parasol ganglion cell         0.9740     0.9494              0.9999       0.9615    0.9743          0.9445                   79
OFFx cell                         0.9835     0.9754              0.9999       0.9794    0.9876          0.9729                   122
ON midget ganglion cell           0.9652     0.8357              0.9985       0.8958    0.9135          0.8209                   1193
ON parasol ganglion cell          1.0000     0.9796              1.0000       0.9897    0.9897          0.9776                   49
ON-blue cone bipolar cell         1.0000     0.7826              1.0000       0.8780    0.8847          0.7656                   23
S cone cell                       0.9571     1.0000              0.9999       0.9781    0.9999          0.9999                   67
amacrine cell                     0.9864     0.9667              0.9998       0.9764    0.9831          0.9632                   450
astrocyte                         1.0000     0.9820              1.0000       0.9909    0.9910          0.9802                   111
diffuse bipolar 1 cell            0.9975     0.9874              1.0000       0.9924    0.9937          0.9861                   397
diffuse bipolar 2 cell            0.9963     0.9944              0.9999       0.9953    0.9972          0.9938                   536
diffuse bipolar 3a cell           1.0000     0.9942              1.0000       0.9971    0.9971          0.9936                   172
diffuse bipolar 3b cell           0.9964     1.0000              1.0000       0.9982    1.0000          1.0000                   278
diffuse bipolar 4 cell            0.9949     1.0000              0.9999       0.9974    1.0000          0.9999                   387
diffuse bipolar 6 cell            0.9931     0.9863              1.0000       0.9897    0.9931          0.9849                   146
flat midget bipolar cell          0.9956     0.9974              0.9998       0.9965    0.9986          0.9969                   1134
giant bipolar cell                0.9706     0.9950              0.9998       0.9826    0.9974          0.9943                   199
glycinergic amacrine cell         0.9818     0.9913              0.9992       0.9865    0.9952          0.9897                   1032
invaginating midget bipolar cell  0.9964     0.9976              0.9999       0.9970    0.9987          0.9973                   835
microglial cell                   1.0000     0.9487              1.0000       0.9737    0.9740          0.9439                   39
retinal bipolar neuron            0.9976     0.9951              1.0000       0.9964    0.9975          0.9946                   412
retinal cone cell                 0.9989     0.9968              1.0000       0.9979    0.9984          0.9964                   933
retinal ganglion cell             0.5152     0.4837              0.9955       0.4990    0.6939          0.4569                   246
retinal pigment epithelial cell   1.0000     1.0000              1.0000       1.0000    1.0000          1.0000                   7
retinal rod cell                  0.9999     0.9993              0.9999       0.9996    0.9996          0.9992                   8388
rod bipolar cell                  0.9975     1.0000              0.9999       0.9987    1.0000          0.9999                   796
starburst amacrine cell           1.0000     0.9686              1.0000       0.9840    0.9842          0.9655                   159
macro avg                         0.9702     0.9590              0.9994       0.9639    0.9774          0.9556
weighted avg                      0.9820     0.9815              0.9991       0.9813    0.9897          0.9793
Accuracy                          0.9815
Balanced accuracy                 0.9590

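The model performs well for most cell types. The clear exception is the retinal ganglion cell class, where precision and recall are both close to 0.5.

You could also generate another test dataset with a different random state to check whether these results hold.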
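One weak class affects the summary rows very differently. Balanced accuracy is the unweighted mean of per-class recall, so each of the 31 classes counts equally, whereas accuracy equals the recall average weighted by cell count. The arithmetic below uses the retinal ganglion cell row (246 of 25,000 test cells, recall 0.4837) to estimate how much that single class pulls each summary down; the variable names are illustrative.

```python
n_classes = 31        # cell types at the cell_type level
n_cells = 25_000      # cells in the test dataset
class_cells = 246     # retinal ganglion cells in the test dataset
class_recall = 0.4837

shortfall = 1 - class_recall  # how far this class is from perfect recall

# Contribution of this class's shortfall to each summary metric
macro_loss = shortfall / n_classes                  # balanced accuracy (macro recall)
weighted_loss = shortfall * class_cells / n_cells   # accuracy (weighted recall)

print(f"Loss in balanced accuracy: {macro_loss:.4f}")
print(f"Loss in accuracy:          {weighted_loss:.4f}")
```

The same class costs roughly three times more in balanced accuracy than in accuracy, which is why balanced accuracy is the more sensitive indicator when a small class is predicted poorly.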

Iterative warm start training (optional)

You can continue training the model on additional subsets of the full dataset to improve its generalization.

[18]:
# Start the range at 1: random_state 0 was already used to draw the initial training subset
for i in range(1, 4):
    # Obtain 25000 cells randomly
    adata_fraction = scparadise.scnoah.get_frac(
        path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
        fraction = 25000,
        stratify = 'cell_type',
        random_state = i
    )

    # Get raw counts from adata_fraction.raw
    adata_fraction = adata_fraction.raw.to_adata()

    # Replace variable names with gene names
    adata_fraction.var.set_index('feature_name', inplace=True)
    adata_fraction.var_names = adata_fraction.var_names.astype('str')
    adata_fraction.var_names_make_unique()

    # Normalize data
    sc.pp.normalize_total(adata_fraction, target_sum = None)
    sc.pp.log1p(adata_fraction)
    adata_fraction.raw = adata_fraction

    # Subset genes for model training
    adata_fraction = adata_fraction[:, lst_genes]
    # Balance dataset
    adata_balanced = scparadise.scnoah.balance(
        adata_fraction,
        celltype_keys = lst_annotations
    )
    adata_balanced.raw = adata_balanced
    # Warm start requires a new training dataset and the path to the pretrained model
    scparadise.scadam.warm_start(
        adata_balanced,
        path = 'snRNAseq_human_retina', # path to save model
        path_model = 'snRNAseq_human_retina/Human_Retina_scAdam', # path to the pretrained model
        model_name = 'Human_Retina_scAdam', # same folder name, so the pretrained model is overwritten
        celltype_keys = lst_annotations,
        eval_metric = ['accuracy', 'balanced_accuracy']
    )
scAdam model with unknown detector loaded from snRNAseq_human_retina/Human_Retina_scAdam
Device: cuda
Gene alignment:
 Number of features: 595
 Matched features: 595 (100.0%)
Label hierarchy: majorclass → cell_type
Annotation levels weights using strategy 'linear_offset':
  majorclass: 10 cell types, 0.4 relative weight
  cell_type: 31 cell types, 0.6 relative weight

Dataset split:
Train dataset contains: 20716 cells, it is 80.0 % of input dataset
Validation dataset contains: 5180 cells, it is 20.0 % of input dataset
Warm-start model fine-tuning:  42%|███████████████████▎                          | 42/100 [02:01<02:47,  2.89s/it]
Early stopping triggered! Best score: 0.9943
Fitting unknown cells detector

UnknownCellDetector fitted successfully!
Model saved to snRNAseq_human_retina/Human_Retina_scAdam
scAdam model with unknown detector loaded from snRNAseq_human_retina/Human_Retina_scAdam
Device: cuda
Gene alignment:
 Number of features: 595
 Matched features: 595 (100.0%)
Label hierarchy: majorclass → cell_type
Annotation levels weights using strategy 'linear_offset':
  majorclass: 10 cell types, 0.4 relative weight
  cell_type: 31 cell types, 0.6 relative weight

Dataset split:
Train dataset contains: 20716 cells, it is 80.0 % of input dataset
Validation dataset contains: 5180 cells, it is 20.0 % of input dataset
Warm-start model fine-tuning:  22%|██████████                                    | 22/100 [01:07<03:57,  3.05s/it]
Early stopping triggered! Best score: 0.9932
Fitting unknown cells detector

UnknownCellDetector fitted successfully!
Model saved to snRNAseq_human_retina/Human_Retina_scAdam
scAdam model with unknown detector loaded from snRNAseq_human_retina/Human_Retina_scAdam
Device: cuda
Gene alignment:
 Number of features: 595
 Matched features: 595 (100.0%)
Label hierarchy: majorclass → cell_type
Annotation levels weights using strategy 'linear_offset':
  majorclass: 10 cell types, 0.4 relative weight
  cell_type: 31 cell types, 0.6 relative weight

Dataset split:
Train dataset contains: 20716 cells, it is 80.0 % of input dataset
Validation dataset contains: 5180 cells, it is 20.0 % of input dataset
Warm-start model fine-tuning:  14%|██████▍                                       | 14/100 [00:34<03:34,  2.50s/it]
Early stopping triggered! Best score: 0.9953
Fitting unknown cells detector

UnknownCellDetector fitted successfully!
Model saved to snRNAseq_human_retina/Human_Retina_scAdam
[19]:
# Predict cell types using the warm-started model
adata_test = scparadise.scadam.predict(
    adata_test,
    path_model = 'snRNAseq_human_retina/Human_Retina_scAdam'
)
scAdam model with unknown detector loaded from snRNAseq_human_retina/Human_Retina_scAdam
Gene alignment:
  Model features: 595
  Matched features: 595 (100.0%)
Predicting: 100%|████████████████████████████████████████████████████████████████| 98/98 [00:00<00:00, 133.10it/s]
Added cell type column: pred_celltype_l1
Added probabilities column: pred_celltype_l1_probability
Added cell type column: pred_celltype_l2
Added probabilities column: pred_celltype_l2_probability
[20]:
# Check model quality after warm start (level 1: majorclass)
df_warm_start_l1 = scparadise.scnoah.report_classif_full(
    adata_test,
    celltype='majorclass',
    pred_celltype='pred_celltype_l1'
)
df_warm_start_l1
[20]:
                   precision  recall/sensitivity  specificity  f1-score  geometric mean  index balanced accuracy  number of cells
AC                 0.9987     0.9998              0.9997       0.9992    0.9997          0.9995                   4496
Astrocyte          0.9909     0.9820              1.0000       0.9864    0.9909          0.9802                   111
BC                 0.9998     0.9996              0.9999       0.9997    0.9998          0.9995                   5437
Cone               1.0000     0.9990              1.0000       0.9995    0.9995          0.9989                   1000
HC                 1.0000     1.0000              1.0000       1.0000    1.0000          1.0000                   634
MG                 1.0000     1.0000              1.0000       1.0000    1.0000          1.0000                   1744
Microglia          1.0000     0.9744              1.0000       0.9870    0.9871          0.9719                   39
RGC                0.9990     0.9994              0.9999       0.9992    0.9996          0.9992                   3144
RPE                1.0000     1.0000              1.0000       1.0000    1.0000          1.0000                   7
Rod                0.9998     0.9995              0.9999       0.9996    0.9997          0.9994                   8388
macro avg          0.9988     0.9954              0.9999       0.9971    0.9976          0.9949
weighted avg       0.9995     0.9995              0.9999       0.9995    0.9997          0.9993
Accuracy           0.9995
Balanced accuracy  0.9954
[21]:
# Check model quality after warm start (level 2: cell_type)
df_warm_start_l2 = scparadise.scnoah.report_classif_full(
    adata_test,
    celltype='cell_type',
    pred_celltype='pred_celltype_l2'
)
df_warm_start_l2
[21]:
                                  precision  recall/sensitivity  specificity  f1-score  geometric mean  index balanced accuracy  number of cells
GABAergic amacrine cell           0.9965     0.9954              0.9995       0.9960    0.9975          0.9946                   2855
H1 horizontal cell                0.9981     0.9944              1.0000       0.9963    0.9972          0.9939                   540
H2 horizontal cell                0.9688     0.9894              0.9999       0.9789    0.9946          0.9882                   94
Mueller cell                      1.0000     1.0000              1.0000       1.0000    1.0000          1.0000                   1744
OFF midget ganglion cell          0.9290     0.9461              0.9951       0.9375    0.9703          0.9369                   1577
OFF parasol ganglion cell         0.9630     0.9873              0.9999       0.9750    0.9936          0.9860                   79
OFFx cell                         0.9918     0.9918              1.0000       0.9918    0.9959          0.9910                   122
ON midget ganglion cell           0.9587     0.9346              0.9980       0.9465    0.9658          0.9268                   1193
ON parasol ganglion cell          1.0000     0.9796              1.0000       0.9897    0.9897          0.9776                   49
ON-blue cone bipolar cell         1.0000     0.8696              1.0000       0.9302    0.9325          0.8582                   23
S cone cell                       0.9571     1.0000              0.9999       0.9781    0.9999          0.9999                   67
amacrine cell                     0.9822     0.9800              0.9997       0.9811    0.9898          0.9778                   450
astrocyte                         0.9909     0.9820              1.0000       0.9864    0.9909          0.9802                   111
diffuse bipolar 1 cell            0.9949     0.9924              0.9999       0.9937    0.9962          0.9916                   397
diffuse bipolar 2 cell            0.9925     0.9925              0.9998       0.9925    0.9962          0.9917                   536
diffuse bipolar 3a cell           0.9942     0.9942              1.0000       0.9942    0.9971          0.9936                   172
diffuse bipolar 3b cell           0.9858     0.9964              0.9998       0.9911    0.9981          0.9959                   278
diffuse bipolar 4 cell            0.9948     0.9948              0.9999       0.9948    0.9974          0.9942                   387
diffuse bipolar 6 cell            0.9863     0.9863              0.9999       0.9863    0.9931          0.9849                   146
flat midget bipolar cell          0.9982     0.9956              0.9999       0.9969    0.9978          0.9951                   1134
giant bipolar cell                0.9851     1.0000              0.9999       0.9925    0.9999          0.9999                   199
glycinergic amacrine cell         0.9836     0.9903              0.9993       0.9870    0.9948          0.9887                   1032
invaginating midget bipolar cell  0.9964     0.9988              0.9999       0.9976    0.9993          0.9986                   835
microglial cell                   1.0000     0.9744              1.0000       0.9870    0.9871          0.9719                   39
retinal bipolar neuron            0.9976     0.9927              1.0000       0.9951    0.9963          0.9920                   412
retinal cone cell                 1.0000     0.9957              1.0000       0.9979    0.9979          0.9953                   933
retinal ganglion cell             0.5830     0.5854              0.9958       0.5842    0.7635          0.5590                   246
retinal pigment epithelial cell   1.0000     1.0000              1.0000       1.0000    1.0000          1.0000                   7
retinal rod cell                  0.9998     0.9996              0.9999       0.9997    0.9998          0.9995                   8388
rod bipolar cell                  0.9987     0.9987              1.0000       0.9987    0.9994          0.9986                   796
starburst amacrine cell           0.9812     0.9874              0.9999       0.9843    0.9936          0.9861                   159
macro avg                         0.9745     0.9718              0.9995       0.9729    0.9847          0.9693
weighted avg                      0.9864     0.9863              0.9994       0.9863    0.9925          0.9847
Accuracy                          0.9863
Balanced accuracy                 0.9718
[22]:
# Show all rows, then compare level 2 metrics before (default) and after warm start
pd.set_option('display.max_rows', 100)
df_l2.compare(df_warm_start_l2, keep_equal=True, align_axis = 0, result_names=('default', 'warm start'))
[22]:
                                              precision  recall/sensitivity  specificity  f1-score  geometric mean  index balanced accuracy
GABAergic amacrine cell           default     0.9934     0.9961              0.9991       0.9948    0.9976          0.9950
                                  warm start  0.9965     0.9954              0.9995       0.9960    0.9975          0.9946
H1 horizontal cell                default     0.9926     0.9944              0.9998       0.9935    0.9971          0.9937
                                  warm start  0.9981     0.9944              1.0000       0.9963    0.9972          0.9939
H2 horizontal cell                default     0.9677     0.9574              0.9999       0.9626    0.9784          0.9533
                                  warm start  0.9688     0.9894              0.9999       0.9789    0.9946          0.9882
Mueller cell                      default     0.9994     0.9994              1.0000       0.9994    0.9997          0.9993
                                  warm start  1.0000     1.0000              1.0000       1.0000    1.0000          1.0000
OFF midget ganglion cell          default     0.9079     0.9562              0.9935       0.9314    0.9747          0.9465
                                  warm start  0.9290     0.9461              0.9951       0.9375    0.9703          0.9369
OFF parasol ganglion cell         default     0.9487     0.9367              0.9998       0.9427    0.9678          0.9306
                                  warm start  0.9630     0.9873              0.9999       0.9750    0.9936          0.9860
OFFx cell                         default     0.9835     0.9754              0.9999       0.9794    0.9876          0.9729
                                  warm start  0.9918     0.9918              1.0000       0.9918    0.9959          0.9910
ON midget ganglion cell           default     0.9521     0.9162              0.9977       0.9338    0.9561          0.9066
                                  warm start  0.9587     0.9346              0.9980       0.9465    0.9658          0.9268
ON parasol ganglion cell          default     0.9796     0.9796              1.0000       0.9796    0.9897          0.9776
                                  warm start  1.0000     0.9796              1.0000       0.9897    0.9897          0.9776
S cone cell                       default     0.9844     0.9403              1.0000       0.9618    0.9697          0.9347
                                  warm start  0.9571     1.0000              0.9999       0.9781    0.9999          0.9999
amacrine cell                     default     0.9863     0.9622              0.9998       0.9741    0.9808          0.9584
                                  warm start  0.9822     0.9800              0.9997       0.9811    0.9898          0.9778
astrocyte                         default     1.0000     0.9820              1.0000       0.9909    0.9910          0.9802
                                  warm start  0.9909     0.9820              1.0000       0.9864    0.9909          0.9802
diffuse bipolar 1 cell            default     0.9975     0.9874              1.0000       0.9924    0.9937          0.9861
                                  warm start  0.9949     0.9924              0.9999       0.9937    0.9962          0.9916
diffuse bipolar 2 cell            default     0.9926     0.9963              0.9998       0.9944    0.9981          0.9958
                                  warm start  0.9925     0.9925              0.9998       0.9925    0.9962          0.9917
diffuse bipolar 3a cell           default     1.0000     0.9826              1.0000       0.9912    0.9912          0.9808
                                  warm start  0.9942     0.9942              1.0000       0.9942    0.9971          0.9936
diffuse bipolar 3b cell           default     0.9964     0.9928              1.0000       0.9946    0.9964          0.9921
                                  warm start  0.9858     0.9964              0.9998       0.9911    0.9981          0.9959
diffuse bipolar 4 cell            default     0.9923     1.0000              0.9999       0.9961    0.9999          0.9999
                                  warm start  0.9948     0.9948              0.9999       0.9948    0.9974          0.9942
diffuse bipolar 6 cell            default     0.9931     0.9863              1.0000       0.9897    0.9931          0.9849
                                  warm start  0.9863     0.9863              0.9999       0.9863    0.9931          0.9849
flat midget bipolar cell          default     0.9956     0.9974              0.9998       0.9965    0.9986          0.9969
                                  warm start  0.9982     0.9956              0.9999       0.9969    0.9978          0.9951
giant bipolar cell                default     0.9802     0.9950              0.9998       0.9875    0.9974          0.9943
                                  warm start  0.9851     1.0000              0.9999       0.9925    0.9999          0.9999
glycinergic amacrine cell         default     0.9790     0.9922              0.9991       0.9856    0.9957          0.9907
                                  warm start  0.9836     0.9903              0.9993       0.9870    0.9948          0.9887
invaginating midget bipolar cell  default     0.9964     0.9964              0.9999       0.9964    0.9981          0.9959
                                  warm start  0.9964     0.9988              0.9999       0.9976    0.9993          0.9986
microglial cell                   default     1.0000     0.9487              1.0000       0.9737    0.9740          0.9439
                                  warm start  1.0000     0.9744              1.0000       0.9870    0.9871          0.9719
retinal bipolar neuron            default     0.9951     0.9951              0.9999       0.9951    0.9975          0.9946
                                  warm start  0.9976     0.9927              1.0000       0.9951    0.9963          0.9920
retinal cone cell                 default     0.9947     0.9989              0.9998       0.9968    0.9994          0.9986
                                  warm start  1.0000     0.9957              1.0000       0.9979    0.9979          0.9953
retinal ganglion cell             default     0.5749     0.4837              0.9964       0.5254    0.6943          0.4573
                                  warm start  0.5830     0.5854              0.9958       0.5842    0.7635          0.5590
retinal rod cell                  default     0.9999     0.9995              0.9999       0.9997    0.9997          0.9994
                                  warm start  0.9998     0.9996              0.9999       0.9997    0.9998          0.9995
rod bipolar cell                  default     0.9962     0.9987              0.9999       0.9975    0.9993          0.9985
                                  warm start  0.9987     0.9987              1.0000       0.9987    0.9994          0.9986
starburst amacrine cell           default     1.0000     0.9811              1.0000       0.9905    0.9905          0.9793
                                  warm start  0.9812     0.9874              0.9999       0.9843    0.9936          0.9861
macro avg                         default     0.9735     0.9612              0.9995       0.9670    0.9787          0.9579
                                  warm start  0.9745     0.9718              0.9995       0.9729    0.9847          0.9693
weighted avg                      default     0.9839     0.9843              0.9992       0.9840    0.9913          0.9824
                                  warm start  0.9864     0.9863              0.9994       0.9863    0.9925          0.9847
Accuracy                          default     0.9843
                                  warm start  0.9863
Balanced accuracy                 default     0.9612
                                  warm start  0.9718

Iterative warm start training improved the summary rows (macro average, weighted average, accuracy, and balanced accuracy) on every metric except macro-average specificity, which stayed at 0.9995. For the retinal ganglion cell class, sensitivity rose by 10.17 percentage points (0.4837 to 0.5854) and precision by 0.81 percentage points (0.5749 to 0.5830).
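The gains quoted for the retinal ganglion cell class are absolute differences between two proportions. The short calculation below reproduces them from the comparison table and, for contrast, also shows the relative change in sensitivity; the variable names are illustrative.

```python
# Retinal ganglion cell row of the comparison table
default_recall, warm_recall = 0.4837, 0.5854
default_precision, warm_precision = 0.5749, 0.5830

# Absolute change, expressed in percentage points
recall_gain_pp = (warm_recall - default_recall) * 100
precision_gain_pp = (warm_precision - default_precision) * 100

# Relative change, expressed as a percentage of the starting value
recall_gain_relative = (warm_recall - default_recall) / default_recall * 100

print(f"Sensitivity: +{recall_gain_pp:.2f} percentage points")
print(f"Precision:   +{precision_gain_pp:.2f} percentage points")
print(f"Sensitivity, relative change: +{recall_gain_relative:.1f} %")
```

Stating which of the two measures is meant avoids ambiguity: the same improvement in sensitivity is about 10 percentage points in absolute terms but about 21% relative to the starting value.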

[23]:
import session_info
session_info.show()
[23]:
-----
anndata             0.11.4
numpy               1.26.4
pandas              2.3.3
scanpy              1.11.4
scparadise          1.0.0
session_info        v1.0.1
-----
PIL                         11.3.0
anyio                       NA
arrow                       1.3.0
asttokens                   NA
attr                        25.3.0
attrs                       25.3.0
babel                       2.17.0
backports                   NA
certifi                     2025.08.03
cffi                        1.17.1
charset_normalizer          3.4.3
cloudpickle                 3.1.1
colorlog                    NA
comm                        0.2.3
cupy                        13.6.0
cupy_backends               NA
cupyx                       NA
cycler                      0.12.1
cython_runtime              NA
dask                        2025.10.0
dateutil                    2.9.0.post0
debugpy                     1.8.16
decorator                   5.2.1
defusedxml                  0.7.1
exceptiongroup              1.3.0
executing                   2.2.1
fastjsonschema              NA
fastrlock                   0.8.3
fqdn                        NA
fsspec                      2025.7.0
h5py                        3.14.0
idna                        3.10
igraph                      0.11.9
imblearn                    0.14.0
importlib_metadata          NA
ipykernel                   6.30.1
ipywidgets                  8.1.8
isoduration                 NA
jaraco                      NA
jedi                        0.19.2
jinja2                      3.1.6
joblib                      1.5.2
json5                       0.12.1
jsonpointer                 3.0.0
jsonschema                  4.25.1
jsonschema_specifications   NA
jupyter_events              0.12.0
jupyter_server              2.17.0
jupyterlab_server           2.28.0
kiwisolver                  1.4.9
lark                        1.2.2
lazy_loader                 0.4
legacy_api_wrap             NA
leidenalg                   0.10.2
llvmlite                    0.44.0
louvain                     0.8.2
markupsafe                  3.0.2
matplotlib                  3.10.6
matplotlib_inline           0.1.7
more_itertools              10.3.0
mpl_toolkits                NA
mpmath                      1.3.0
mudata                      0.3.2
muon                        0.1.7
natsort                     8.4.0
nbformat                    5.10.4
numba                       0.61.2
numexpr                     2.14.1
optuna                      4.5.0
overrides                   NA
packaging                   25.0
parso                       0.8.5
patsy                       1.0.1
pkg_resources               NA
platformdirs                4.4.0
plottable                   0.1.5
prometheus_client           NA
prompt_toolkit              3.0.52
psutil                      7.0.0
pure_eval                   0.2.3
pyarrow                     22.0.0
pycparser                   2.22
pydev_ipython               NA
pydevconsole                NA
pydevd                      3.2.3
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pydot                       4.0.1
pygments                    2.19.2
pynndescent                 0.5.13
pyparsing                   3.2.3
pythonjsonlogger            NA
pytorch_tabnet              NA
pytz                        2025.2
referencing                 NA
requests                    2.32.5
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rfc3987_syntax              NA
rpds                        NA
scipy                       1.12.0
seaborn                     0.13.2
send2trash                  NA
shap                        0.48.0
simplejson                  3.20.1
six                         1.17.0
skimage                     0.25.2
sklearn                     1.7.1
slicer                      NA
sniffio                     1.3.1
sparse                      0.17.0
stack_data                  0.6.3
statsmodels                 0.14.5
sympy                       1.14.0
texttable                   1.7.0
threadpoolctl               3.6.0
tlz                         1.0.0
toolz                       1.0.0
torch                       2.8.0+cu128
torchgen                    NA
tornado                     6.5.2
tqdm                        4.64.1
traitlets                   5.14.3
triton                      3.4.0
typing_extensions           NA
umap                        0.5.9.post2
uri_template                NA
urllib3                     2.5.0
wcwidth                     0.2.13
webcolors                   NA
websocket                   1.8.0
xarray                      2025.6.1
yaml                        6.0.2
zipp                        NA
zmq                         27.0.2
zoneinfo                    NA
-----
IPython             8.37.0
jupyter_client      8.6.3
jupyter_core        5.8.1
jupyterlab          4.5.1
-----
Python 3.10.18 (main, Jun  5 2025, 13:14:17) [GCC 11.2.0]
Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.35
-----
Session information updated at 2026-02-03 19:56