Managing large and complex training datasets
Large, complex datasets are difficult to use for model training because of computational limits: simply opening such a dataset can require an enormous amount of RAM.
The time needed to select genes for model training also grows with dataset size.
Meanwhile, increasing the amount of data used for training does not always significantly improve model quality.
Here, we present a method to overcome computational limits for training models on large, complex datasets with many donors.
[1]:
# Python packages
import warnings
warnings.simplefilter('ignore')
import scanpy as sc
import scparadise
import numpy as np
import pandas as pd
import os
sc.set_figure_params(dpi = 120)
[2]:
# Create folder to save files
# Single-nucleus RNA-seq (snRNA-seq) of human retina
os.makedirs('snRNAseq_human_retina', exist_ok=True)
[3]:
# Download CELLxGENE dataset (snRNA-seq of human retina):
# https://cellxgene.cziscience.com/collections/4c6eaf5c-6d57-4c76-b1e9-60df8c655f1e
!wget https://datasets.cellxgene.cziscience.com/2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad
--2025-02-08 13:49:05-- https://datasets.cellxgene.cziscience.com/2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad
Resolving datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)... 52.85.49.24, 52.85.49.28, 52.85.49.17, ...
Connecting to datasets.cellxgene.cziscience.com (datasets.cellxgene.cziscience.com)|52.85.49.24|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 37946797973 (35G) [binary/octet-stream]
Saving to: ‘2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad’
2e910e62-7eaf-4c06- 100%[===================>] 35.34G 25.1MB/s in 27m 20s
2025-02-08 14:13:58 (22.1 MB/s) - ‘2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad’ saved [37946797973/37946797973]
[4]:
# Check metadata using backed mode
adata = sc.read_h5ad('2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad', backed = 'r')
adata
[4]:
AnnData object with n_obs × n_vars = 3177310 × 36406 backed at '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad'
obs: 'reference_genome', 'gene_annotation_version', 'alignment_software', 'intronic_reads_counted', 'donor_id', 'donor_age', 'self_reported_ethnicity_ontology_term_id', 'donor_cause_of_death', 'donor_living_at_sample_collection', 'organism_ontology_term_id', 'sample_id', 'sample_preservation_method', 'tissue_ontology_term_id', 'development_stage_ontology_term_id', 'sample_collection_method', 'tissue_source', 'tissue_type', 'suspension_derivation_process', 'suspension_dissociation_reagent', 'suspension_enriched_cell_types', 'suspension_enrichment_factors', 'suspension_uuid', 'suspension_type', 'tissue_handling_interval', 'library_id', 'assay_ontology_term_id', 'sequenced_fragment', 'institute', 'library_id_repository', 'sequencing_platform', 'is_primary_data', 'cell_type_ontology_term_id', 'author_cell_type', 'disease_ontology_term_id', 'sex_ontology_term_id', 'majorclass', 'AC_subclass', 'AC_cluster', 'AC_celltype_number', 'BC_subclass', 'RGC_cluster', 'RGC_celltype_number', 'study_name', 'nCount_RNA', 'nFeature_RNA', 'percent.mt', 'sampleid', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
var: 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type'
uns: 'citation', 'default_embedding', 'schema_reference', 'schema_version', 'title'
obsm: 'X_scVI', 'X_umap'
Train a scAdam model using a fraction of the dataset
The entire dataset contains 3,177,310 cells and 36,406 genes (35.34 GB), which is too large to open on a standard computer. In addition, selecting genes and training a model on a dataset of this size require substantial computational time and resources.
Therefore, the scParadise team recommends extracting a smaller subset of the dataset for scAdam model training.
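A rough estimate shows why the full matrix cannot simply be loaded into memory. The sketch below assumes a dense float32 matrix (4 bytes per value). The file on disk is much smaller than this estimate, which implies sparse or compressed storage, so the figure is an illustrative upper bound rather than a measurement.

```python
# Back-of-the-envelope memory estimate for a dense expression matrix.
# Assumption: float32 values (4 bytes each); real usage depends on the storage format.
n_cells = 3_177_310
n_genes = 36_406
bytes_per_value = 4

dense_bytes = n_cells * n_genes * bytes_per_value
dense_gib = dense_bytes / 1024**3

# The same estimate for a 25,000-cell subset with all genes kept
subset_gib = 25_000 * n_genes * bytes_per_value / 1024**3

print(f'Full dataset, dense float32: {dense_gib:.1f} GiB')
print(f'25,000-cell subset, dense float32: {subset_gib:.2f} GiB')
```

Under these assumptions the subset is roughly two orders of magnitude smaller than the full matrix, which is what makes gene selection and training practical on a single machine.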
[5]:
# Obtain 25000 cells from dataset
adata_fraction = scparadise.scnoah.get_frac(
path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
fraction = 25000,
stratify = 'cell_type', # Use to get the same ratio as in the full dataset
random_state = 0
)
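The idea behind `stratify` can be illustrated with a small, self-contained example. The helper `stratified_sample` below is hypothetical and is not scParadise's implementation; it only shows how a proportional draw keeps the ratio of cell types close to that of the full dataset.

```python
import random
from collections import defaultdict

def stratified_sample(labels, n_total, seed=0):
    """Return sorted indices of about n_total items, keeping label proportions."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    picked = []
    for label, indices in by_label.items():
        # Proportional allocation; keep at least one item per label
        k = max(1, round(n_total * len(indices) / len(labels)))
        picked.extend(rng.sample(indices, min(k, len(indices))))
    return sorted(picked)

# Toy annotation: 80 % Rod, 15 % BC, 5 % RPE
labels = ['Rod'] * 800 + ['BC'] * 150 + ['RPE'] * 50
subset = stratified_sample(labels, n_total=100, seed=0)
counts = {lab: sum(labels[i] == lab for i in subset) for lab in set(labels)}
print(counts)
```

Because of rounding and the one-item minimum, the result can differ slightly from the requested size when some labels are very rare.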
[6]:
# Get raw counts from adata_fraction.raw
adata_fraction = adata_fraction.raw.to_adata()
# Replace variable names with gene names
adata_fraction.var.set_index('feature_name', inplace = True)
adata_fraction.var_names = adata_fraction.var_names.astype('str')
adata_fraction.var_names_make_unique()
# Normalize data
sc.pp.normalize_total(adata_fraction, target_sum = None)
sc.pp.log1p(adata_fraction)
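With `target_sum = None`, `sc.pp.normalize_total` scales every cell to the median of the total counts across cells, and `sc.pp.log1p` then applies log(1 + x). The NumPy sketch below reproduces that arithmetic on a toy matrix. It is an illustration rather than scanpy's implementation, and it assumes a small dense matrix in which every cell has non-zero counts.

```python
import numpy as np

# Toy count matrix: 3 cells (rows) x 3 genes (columns)
counts = np.array([[1., 3., 0.],
                   [2., 2., 4.],
                   [0., 10., 10.]])

totals = counts.sum(axis=1)      # total counts per cell
target = np.median(totals)       # target_sum = None -> median of the totals

# Scale each cell so that its counts sum to the target, then log-transform
normalized = counts / totals[:, None] * target
log_normalized = np.log1p(normalized)

print(normalized.sum(axis=1))
print(log_normalized.round(3))
```

After scaling, every cell has the same total, so differences in sequencing depth no longer dominate the comparison between cells.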
[7]:
# Select genes for model training (marker genes of cell types)
lst_genes = []
lst_annotations = ['majorclass', 'cell_type'] # annotation levels
for annotation in lst_annotations:
    sc.tl.rank_genes_groups(adata_fraction,
                            groupby = annotation,
                            method = 't-test_overestim_var', pts = True)
    # Filter marker genes of cell types
    sc.tl.filter_rank_genes_groups(adata_fraction,
                                   min_fold_change = 1.0,
                                   min_in_group_fraction = 0.4,
                                   key_added = 'filtered_rank_genes_groups')
    # Create list of genes for model training
    for i in adata_fraction.obs[annotation].unique():
        df = sc.get.rank_genes_groups_df(adata_fraction, group = i, key = 'filtered_rank_genes_groups', pval_cutoff = 0.05)
        # Ratio of detection fractions: cells in the group vs. reference cells
        df['pts_comparizon'] = df['pct_nz_group']/df['pct_nz_reference']
        # Keep the top 20 genes by log fold change and the top 20 by detection ratio
        lst_genes.extend(df.sort_values(by = 'logfoldchanges', ascending = False).head(20)['names'].tolist())
        lst_genes.extend(df.sort_values(by = 'pts_comparizon', ascending = False).head(20)['names'].tolist())
# Remove duplicates
lst_genes = np.unique(lst_genes).tolist()
print(f'Number of selected features: {len(lst_genes)}')
Number of selected features: 595
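The selection step combines two rankings per cell type: the largest log fold changes and the largest ratio of detection fractions. A toy example in plain Python makes the logic explicit. The gene names and values are made up, and `top_k` is a hypothetical helper used only for this illustration.

```python
# Toy marker table for one cell type (values are invented for illustration)
markers = [
    {'names': 'GENE_A', 'logfoldchanges': 5.0, 'pct_nz_group': 0.90, 'pct_nz_reference': 0.30},
    {'names': 'GENE_B', 'logfoldchanges': 4.0, 'pct_nz_group': 0.50, 'pct_nz_reference': 0.01},
    {'names': 'GENE_C', 'logfoldchanges': 1.5, 'pct_nz_group': 0.80, 'pct_nz_reference': 0.02},
    {'names': 'GENE_D', 'logfoldchanges': 3.0, 'pct_nz_group': 0.60, 'pct_nz_reference': 0.50},
]

def top_k(rows, key, k):
    """Return the names of the k rows with the largest key value."""
    ranked = sorted(rows, key=key, reverse=True)
    return [row['names'] for row in ranked[:k]]

by_fold_change = top_k(markers, lambda r: r['logfoldchanges'], k=2)
by_detection = top_k(markers, lambda r: r['pct_nz_group'] / r['pct_nz_reference'], k=2)

# Union of both rankings, duplicates removed
selected = sorted(set(by_fold_change) | set(by_detection))
print(selected)
```

Genes that appear in both rankings, or in several cell types, are counted once. This is one reason the final list (595 genes) is far shorter than the upper bound of 40 genes for each of the 41 annotated groups.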
[8]:
# Subset genes for model training
adata_fraction = adata_fraction[:, lst_genes]
[9]:
# Alternative way to select genes for model training
# sc.pp.highly_variable_genes(adata_fraction,
# n_top_genes = 1000,
# subset = True)
# lst_genes = adata_fraction.var_names.tolist()
[9]:
# Balance dataset
adata_balanced = scparadise.scnoah.balance(
adata_fraction,
celltype_keys = lst_annotations
)
[10]:
# Train scAdam model using the balanced dataset (adata_balanced)
scparadise.scadam.train(
adata_balanced,
path = 'snRNAseq_human_retina', # path to save model
model_name = 'Human_Retina_scAdam', # folder name with model
celltype_keys = lst_annotations,
eval_metric = ['accuracy', 'balanced_accuracy']
)
Device: cuda
Number of features: 595
Label hierarchy: majorclass → cell_type
Annotation levels weights using strategy 'linear_offset':
majorclass: 10 cell types, 0.4 relative weight
cell_type: 31 cell types, 0.6 relative weight
Dataset split:
Train dataset contains: 20716 cells, it is 80.0 % of input dataset
Validation dataset contains: 5180 cells, it is 20.0 % of input dataset
Unsupervised pretraining: 100%|███████████████████████████████████████████████████████████████████████████| 50/50 [02:07<00:00, 2.56s/it]
Training scAdam model: 9%|██████▉ | 18/200 [00:48<08:06, 2.67s/it]
Early stopping triggered! Best score: 0.9924
Training completed!
Fitting unknown cells detector
UnknownCellDetector fitted successfully!
Model saved to snRNAseq_human_retina/Human_Retina_scAdam
Evaluation of model quality
For model evaluation, we use another subset of 25,000 cells generated using a different random state.
[11]:
# Get test dataset for model quality evaluation
adata_test = scparadise.scnoah.get_frac(
path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
fraction = 25000,
stratify = 'cell_type',
random_state = 42
)
[12]:
# Check common cells between test and training datasets
lst_train = adata_fraction.obs_names.tolist()
lst_test = adata_test.obs_names.tolist()
lst_train.extend(lst_test)
lst_train = np.unique(lst_train)
percent = round((2 * len(lst_test) - len(lst_train))/len(lst_test)*100, 5)
print(f"There are {percent} % common cells ({2 * len(lst_test) - len(lst_train)} cells) between the test and training datasets")
There are 0.836 % common cells (209 cells) between the test and training datasets
Fewer than 1% of cells are shared between the test and training datasets. This overlap is small enough to ignore, so we can proceed with evaluating the model.
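The same overlap can be computed with a set intersection and compared with what chance alone would produce. The sketch below uses toy barcodes, and the expected value assumes two independent draws in which every cell is equally likely to be picked; that assumption is a simplification of the stratified sampling used here.

```python
# Toy cell barcodes; in the notebook these would be the obs_names of the two subsets
train_cells = ['c1', 'c2', 'c3', 'c4', 'c5']
test_cells = ['c4', 'c5', 'c6', 'c7', 'c8']

shared = set(train_cells) & set(test_cells)
percent_shared = len(shared) / len(test_cells) * 100
print(f'{len(shared)} shared cells ({percent_shared:.1f} %)')

# Expected overlap of two independent random draws of n cells out of N
# (assumption: independent draws, every cell equally likely)
n, N = 25_000, 3_177_310
expected_shared = n * n / N
expected_percent = expected_shared / n * 100
print(f'Expected by chance: {expected_shared:.0f} cells ({expected_percent:.2f} %)')
```

An observed overlap of 209 cells is close to this chance expectation, which is consistent with the two subsets being drawn independently.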
[13]:
# Apply the same preprocessing steps to the test dataset as used for training
# Get raw counts from adata_test.raw
adata_test = adata_test.raw.to_adata()
# Replace variable names with gene names
adata_test.var.set_index('feature_name', inplace = True)
adata_test.var_names = adata_test.var_names.astype('str')
adata_test.var_names_make_unique()
# Normalize data
sc.pp.normalize_total(adata_test, target_sum = None)
sc.pp.log1p(adata_test)
adata_test.raw = adata_test
[14]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(
adata_test,
path_model = 'snRNAseq_human_retina/Human_Retina_scAdam'
)
scAdam model with unknown detector loaded from snRNAseq_human_retina/Human_Retina_scAdam
Gene alignment:
Model features: 595
Matched features: 595 (100.0%)
Predicting: 100%|████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:00<00:00, 151.49it/s]
Added cell type column: pred_celltype_l1
Added probabilities column: pred_celltype_l1_probability
Added cell type column: pred_celltype_l2
Added probabilities column: pred_celltype_l2_probability
[15]:
## Check model quality
df_l1 = scparadise.scnoah.report_classif_full(
adata_test,
celltype = 'majorclass',
pred_celltype = 'pred_celltype_l1'
)
df_l1
[15]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| AC | 0.9989 | 0.9978 | 0.9998 | 0.9983 | 0.9988 | 0.9973 | 4496 |
| Astrocyte | 1.0000 | 0.982 | 1.0 | 0.9909 | 0.991 | 0.9802 | 111 |
| BC | 0.9993 | 0.9998 | 0.9998 | 0.9995 | 0.9998 | 0.9996 | 5437 |
| Cone | 0.9990 | 1.0 | 1.0 | 0.9995 | 1.0 | 1.0 | 1000 |
| HC | 0.9953 | 0.9984 | 0.9999 | 0.9969 | 0.9991 | 0.9982 | 634 |
| MG | 0.9994 | 0.9994 | 1.0 | 0.9994 | 0.9997 | 0.9993 | 1744 |
| Microglia | 1.0000 | 0.9744 | 1.0 | 0.987 | 0.9871 | 0.9719 | 39 |
| RGC | 0.9978 | 0.999 | 0.9997 | 0.9984 | 0.9994 | 0.9987 | 3144 |
| RPE | 0.8750 | 1.0 | 1.0 | 0.9333 | 1.0 | 1.0 | 7 |
| Rod | 0.9999 | 0.9995 | 0.9999 | 0.9997 | 0.9997 | 0.9994 | 8388 |
| macro avg | 0.9865 | 0.995 | 0.9999 | 0.9903 | 0.9975 | 0.9945 | |
| weighted avg | 0.9991 | 0.9991 | 0.9998 | 0.9991 | 0.9995 | 0.9988 | |
| Accuracy | 0.9991 | | | | | | |
| Balanced accuracy | 0.9950 | | | | | | |
[16]:
## Check model quality
df_l2 = scparadise.scnoah.report_classif_full(
adata_test,
celltype = 'cell_type',
pred_celltype = 'pred_celltype_l2'
)
df_l2
[16]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | 0.9944 | 0.9968 | 0.9993 | 0.9956 | 0.9981 | 0.9959 | 2855 |
| H1 horizontal cell | 0.9963 | 0.987 | 0.9999 | 0.9916 | 0.9935 | 0.9857 | 540 |
| H2 horizontal cell | 0.9192 | 0.9681 | 0.9997 | 0.943 | 0.9838 | 0.9647 | 94 |
| Mueller cell | 0.9994 | 1.0 | 1.0 | 0.9997 | 1.0 | 1.0 | 1744 |
| OFF midget ganglion cell | 0.8699 | 0.9708 | 0.9902 | 0.9176 | 0.9805 | 0.9595 | 1577 |
| OFF parasol ganglion cell | 0.9740 | 0.9494 | 0.9999 | 0.9615 | 0.9743 | 0.9445 | 79 |
| OFFx cell | 0.9835 | 0.9754 | 0.9999 | 0.9794 | 0.9876 | 0.9729 | 122 |
| ON midget ganglion cell | 0.9652 | 0.8357 | 0.9985 | 0.8958 | 0.9135 | 0.8209 | 1193 |
| ON parasol ganglion cell | 1.0000 | 0.9796 | 1.0 | 0.9897 | 0.9897 | 0.9776 | 49 |
| ON-blue cone bipolar cell | 1.0000 | 0.7826 | 1.0 | 0.878 | 0.8847 | 0.7656 | 23 |
| S cone cell | 0.9571 | 1.0 | 0.9999 | 0.9781 | 0.9999 | 0.9999 | 67 |
| amacrine cell | 0.9864 | 0.9667 | 0.9998 | 0.9764 | 0.9831 | 0.9632 | 450 |
| astrocyte | 1.0000 | 0.982 | 1.0 | 0.9909 | 0.991 | 0.9802 | 111 |
| diffuse bipolar 1 cell | 0.9975 | 0.9874 | 1.0 | 0.9924 | 0.9937 | 0.9861 | 397 |
| diffuse bipolar 2 cell | 0.9963 | 0.9944 | 0.9999 | 0.9953 | 0.9972 | 0.9938 | 536 |
| diffuse bipolar 3a cell | 1.0000 | 0.9942 | 1.0 | 0.9971 | 0.9971 | 0.9936 | 172 |
| diffuse bipolar 3b cell | 0.9964 | 1.0 | 1.0 | 0.9982 | 1.0 | 1.0 | 278 |
| diffuse bipolar 4 cell | 0.9949 | 1.0 | 0.9999 | 0.9974 | 1.0 | 0.9999 | 387 |
| diffuse bipolar 6 cell | 0.9931 | 0.9863 | 1.0 | 0.9897 | 0.9931 | 0.9849 | 146 |
| flat midget bipolar cell | 0.9956 | 0.9974 | 0.9998 | 0.9965 | 0.9986 | 0.9969 | 1134 |
| giant bipolar cell | 0.9706 | 0.995 | 0.9998 | 0.9826 | 0.9974 | 0.9943 | 199 |
| glycinergic amacrine cell | 0.9818 | 0.9913 | 0.9992 | 0.9865 | 0.9952 | 0.9897 | 1032 |
| invaginating midget bipolar cell | 0.9964 | 0.9976 | 0.9999 | 0.997 | 0.9987 | 0.9973 | 835 |
| microglial cell | 1.0000 | 0.9487 | 1.0 | 0.9737 | 0.974 | 0.9439 | 39 |
| retinal bipolar neuron | 0.9976 | 0.9951 | 1.0 | 0.9964 | 0.9975 | 0.9946 | 412 |
| retinal cone cell | 0.9989 | 0.9968 | 1.0 | 0.9979 | 0.9984 | 0.9964 | 933 |
| retinal ganglion cell | 0.5152 | 0.4837 | 0.9955 | 0.499 | 0.6939 | 0.4569 | 246 |
| retinal pigment epithelial cell | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7 |
| retinal rod cell | 0.9999 | 0.9993 | 0.9999 | 0.9996 | 0.9996 | 0.9992 | 8388 |
| rod bipolar cell | 0.9975 | 1.0 | 0.9999 | 0.9987 | 1.0 | 0.9999 | 796 |
| starburst amacrine cell | 1.0000 | 0.9686 | 1.0 | 0.984 | 0.9842 | 0.9655 | 159 |
| macro avg | 0.9702 | 0.959 | 0.9994 | 0.9639 | 0.9774 | 0.9556 | |
| weighted avg | 0.9820 | 0.9815 | 0.9991 | 0.9813 | 0.9897 | 0.9793 | |
| Accuracy | 0.9815 | | | | | | |
| Balanced accuracy | 0.9590 | | | | | | |
The model performs well for every cell type except the retinal ganglion cell (F1-score 0.499).
To check whether the evaluation depends on the sampled cells, you can generate another test dataset with a different random state.
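The derived columns of the report can be recomputed from precision, recall, and specificity, which helps when reading a weak row such as the retinal ganglion cell. The sketch below uses the standard F1 and geometric mean formulas. For the index balanced accuracy it assumes the imbalanced-learn definition with alpha = 0.1; this is an assumption, although it reproduces the reported value.

```python
import math

def f1_score(precision, recall):
    # Harmonic mean of precision and recall
    return 2 * precision * recall / (precision + recall)

def geometric_mean(recall, specificity):
    # Square root of sensitivity times specificity
    return math.sqrt(recall * specificity)

def index_balanced_accuracy(recall, specificity, alpha=0.1):
    # Assumption: imbalanced-learn style IBA, weighting the squared
    # geometric mean by the dominance (recall - specificity)
    dominance = recall - specificity
    return (1 + alpha * dominance) * recall * specificity

# Values reported for 'retinal ganglion cell' in the cell_type report
precision, recall, specificity = 0.5152, 0.4837, 0.9955

f1 = f1_score(precision, recall)
gmean = geometric_mean(recall, specificity)
iba = index_balanced_accuracy(recall, specificity)
print(f'f1 = {f1:.4f}, geometric mean = {gmean:.4f}, IBA = {iba:.4f}')
```

The high specificity shows that few other cells are mislabeled as retinal ganglion cells relative to the size of the dataset, while the low recall shows that about half of the true retinal ganglion cells receive a different label.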
Iterative warm start training (optional)
You can fine-tune the model on additional subsets of the full dataset to improve its generalization.
[18]:
# Start the range at 1: random_state = 0 was already used for the primary training of the model
for i in range(1, 4):
    # Obtain 25000 cells randomly
    adata_fraction = scparadise.scnoah.get_frac(
        path = '2e910e62-7eaf-4c06-80cb-8918e3eea16e.h5ad',
        fraction = 25000,
        stratify = 'cell_type',
        random_state = i
    )
    # Get raw counts from adata_fraction.raw
    adata_fraction = adata_fraction.raw.to_adata()
    # Replace variable names with gene names
    adata_fraction.var.set_index('feature_name', inplace=True)
    adata_fraction.var_names = adata_fraction.var_names.astype('str')
    adata_fraction.var_names_make_unique()
    # Normalize data
    sc.pp.normalize_total(adata_fraction, target_sum = None)
    sc.pp.log1p(adata_fraction)
    adata_fraction.raw = adata_fraction
    # Subset genes for model training
    adata_fraction = adata_fraction[:, lst_genes]
    # Balance dataset
    adata_balanced = scparadise.scnoah.balance(
        adata_fraction,
        celltype_keys = lst_annotations
    )
    adata_balanced.raw = adata_balanced
    # Warm start requires second training dataset and path to pretrained model
    scparadise.scadam.warm_start(
        adata_balanced,
        path = 'snRNAseq_human_retina', # path to save model
        path_model = 'snRNAseq_human_retina/Human_Retina_scAdam', # folder name with pretrained model
        model_name = 'Human_Retina_scAdam',
        celltype_keys = lst_annotations,
        eval_metric = ['accuracy', 'balanced_accuracy']
    )
scAdam model with unknown detector loaded from snRNAseq_human_retina/Human_Retina_scAdam
Device: cuda
Gene alignment:
Number of features: 595
Matched features: 595 (100.0%)
Label hierarchy: majorclass → cell_type
Annotation levels weights using strategy 'linear_offset':
majorclass: 10 cell types, 0.4 relative weight
cell_type: 31 cell types, 0.6 relative weight
Dataset split:
Train dataset contains: 20716 cells, it is 80.0 % of input dataset
Validation dataset contains: 5180 cells, it is 20.0 % of input dataset
Warm-start model fine-tuning: 42%|███████████████████▎ | 42/100 [02:01<02:47, 2.89s/it]
Early stopping triggered! Best score: 0.9943
Fitting unknown cells detector
UnknownCellDetector fitted successfully!
Model saved to snRNAseq_human_retina/Human_Retina_scAdam
scAdam model with unknown detector loaded from snRNAseq_human_retina/Human_Retina_scAdam
Device: cuda
Gene alignment:
Number of features: 595
Matched features: 595 (100.0%)
Label hierarchy: majorclass → cell_type
Annotation levels weights using strategy 'linear_offset':
majorclass: 10 cell types, 0.4 relative weight
cell_type: 31 cell types, 0.6 relative weight
Dataset split:
Train dataset contains: 20716 cells, it is 80.0 % of input dataset
Validation dataset contains: 5180 cells, it is 20.0 % of input dataset
Warm-start model fine-tuning: 22%|██████████ | 22/100 [01:07<03:57, 3.05s/it]
Early stopping triggered! Best score: 0.9932
Fitting unknown cells detector
UnknownCellDetector fitted successfully!
Model saved to snRNAseq_human_retina/Human_Retina_scAdam
scAdam model with unknown detector loaded from snRNAseq_human_retina/Human_Retina_scAdam
Device: cuda
Gene alignment:
Number of features: 595
Matched features: 595 (100.0%)
Label hierarchy: majorclass → cell_type
Annotation levels weights using strategy 'linear_offset':
majorclass: 10 cell types, 0.4 relative weight
cell_type: 31 cell types, 0.6 relative weight
Dataset split:
Train dataset contains: 20716 cells, it is 80.0 % of input dataset
Validation dataset contains: 5180 cells, it is 20.0 % of input dataset
Warm-start model fine-tuning: 14%|██████▍ | 14/100 [00:34<03:34, 2.50s/it]
Early stopping triggered! Best score: 0.9953
Fitting unknown cells detector
UnknownCellDetector fitted successfully!
Model saved to snRNAseq_human_retina/Human_Retina_scAdam
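Each warm-start subset is drawn from the same file as the test dataset, so a few test cells may also appear in these subsets. If a strictly held-out test set is needed, one option (not part of the workflow above) is to drop test barcodes before balancing. The helper `drop_test_cells` is hypothetical, and the estimate assumes independent draws in which every cell is equally likely.

```python
def drop_test_cells(train_barcodes, test_barcodes):
    """Return training barcodes that do not occur in the test set."""
    test_set = set(test_barcodes)
    return [barcode for barcode in train_barcodes if barcode not in test_set]

# Toy barcodes; in the notebook these would come from obs_names, e.g.
# adata_fraction = adata_fraction[~adata_fraction.obs_names.isin(adata_test.obs_names)]
train_barcodes = ['c1', 'c2', 'c3', 'c4']
test_barcodes = ['c3', 'c9']
kept = drop_test_cells(train_barcodes, test_barcodes)
print(kept)

# Rough fraction of test cells seen at least once across 4 training subsets
# (random_state 0 to 3), assuming independent uniform draws of n cells out of N
n, N, n_subsets = 25_000, 3_177_310, 4
seen_fraction = 1 - (1 - n / N) ** n_subsets
print(f'About {seen_fraction * 100:.1f} % of test cells would be seen during training')
```

With these numbers the overlap stays at a few percent, so filtering is a refinement rather than a requirement.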
[19]:
# Predict cell types using trained model
adata_test = scparadise.scadam.predict(
adata_test,
path_model = 'snRNAseq_human_retina/Human_Retina_scAdam'
)
scAdam model with unknown detector loaded from snRNAseq_human_retina/Human_Retina_scAdam
Gene alignment:
Model features: 595
Matched features: 595 (100.0%)
Predicting: 100%|████████████████████████████████████████████████████████████████| 98/98 [00:00<00:00, 133.10it/s]
Added cell type column: pred_celltype_l1
Added probabilities column: pred_celltype_l1_probability
Added cell type column: pred_celltype_l2
Added probabilities column: pred_celltype_l2_probability
[20]:
## Check model quality
df_warm_start_l1 = scparadise.scnoah.report_classif_full(
adata_test,
celltype='majorclass',
pred_celltype='pred_celltype_l1'
)
df_warm_start_l1
[20]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| AC | 0.9987 | 0.9998 | 0.9997 | 0.9992 | 0.9997 | 0.9995 | 4496 |
| Astrocyte | 0.9909 | 0.982 | 1.0 | 0.9864 | 0.9909 | 0.9802 | 111 |
| BC | 0.9998 | 0.9996 | 0.9999 | 0.9997 | 0.9998 | 0.9995 | 5437 |
| Cone | 1.0000 | 0.999 | 1.0 | 0.9995 | 0.9995 | 0.9989 | 1000 |
| HC | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 634 |
| MG | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1744 |
| Microglia | 1.0000 | 0.9744 | 1.0 | 0.987 | 0.9871 | 0.9719 | 39 |
| RGC | 0.9990 | 0.9994 | 0.9999 | 0.9992 | 0.9996 | 0.9992 | 3144 |
| RPE | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7 |
| Rod | 0.9998 | 0.9995 | 0.9999 | 0.9996 | 0.9997 | 0.9994 | 8388 |
| macro avg | 0.9988 | 0.9954 | 0.9999 | 0.9971 | 0.9976 | 0.9949 | |
| weighted avg | 0.9995 | 0.9995 | 0.9999 | 0.9995 | 0.9997 | 0.9993 | |
| Accuracy | 0.9995 | | | | | | |
| Balanced accuracy | 0.9954 | | | | | | |
[21]:
## Check model quality
df_warm_start_l2 = scparadise.scnoah.report_classif_full(
adata_test,
celltype='cell_type',
pred_celltype='pred_celltype_l2'
)
df_warm_start_l2
[21]:
| | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy | number of cells |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | 0.9965 | 0.9954 | 0.9995 | 0.996 | 0.9975 | 0.9946 | 2855 |
| H1 horizontal cell | 0.9981 | 0.9944 | 1.0 | 0.9963 | 0.9972 | 0.9939 | 540 |
| H2 horizontal cell | 0.9688 | 0.9894 | 0.9999 | 0.9789 | 0.9946 | 0.9882 | 94 |
| Mueller cell | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1744 |
| OFF midget ganglion cell | 0.9290 | 0.9461 | 0.9951 | 0.9375 | 0.9703 | 0.9369 | 1577 |
| OFF parasol ganglion cell | 0.9630 | 0.9873 | 0.9999 | 0.975 | 0.9936 | 0.986 | 79 |
| OFFx cell | 0.9918 | 0.9918 | 1.0 | 0.9918 | 0.9959 | 0.991 | 122 |
| ON midget ganglion cell | 0.9587 | 0.9346 | 0.998 | 0.9465 | 0.9658 | 0.9268 | 1193 |
| ON parasol ganglion cell | 1.0000 | 0.9796 | 1.0 | 0.9897 | 0.9897 | 0.9776 | 49 |
| ON-blue cone bipolar cell | 1.0000 | 0.8696 | 1.0 | 0.9302 | 0.9325 | 0.8582 | 23 |
| S cone cell | 0.9571 | 1.0 | 0.9999 | 0.9781 | 0.9999 | 0.9999 | 67 |
| amacrine cell | 0.9822 | 0.98 | 0.9997 | 0.9811 | 0.9898 | 0.9778 | 450 |
| astrocyte | 0.9909 | 0.982 | 1.0 | 0.9864 | 0.9909 | 0.9802 | 111 |
| diffuse bipolar 1 cell | 0.9949 | 0.9924 | 0.9999 | 0.9937 | 0.9962 | 0.9916 | 397 |
| diffuse bipolar 2 cell | 0.9925 | 0.9925 | 0.9998 | 0.9925 | 0.9962 | 0.9917 | 536 |
| diffuse bipolar 3a cell | 0.9942 | 0.9942 | 1.0 | 0.9942 | 0.9971 | 0.9936 | 172 |
| diffuse bipolar 3b cell | 0.9858 | 0.9964 | 0.9998 | 0.9911 | 0.9981 | 0.9959 | 278 |
| diffuse bipolar 4 cell | 0.9948 | 0.9948 | 0.9999 | 0.9948 | 0.9974 | 0.9942 | 387 |
| diffuse bipolar 6 cell | 0.9863 | 0.9863 | 0.9999 | 0.9863 | 0.9931 | 0.9849 | 146 |
| flat midget bipolar cell | 0.9982 | 0.9956 | 0.9999 | 0.9969 | 0.9978 | 0.9951 | 1134 |
| giant bipolar cell | 0.9851 | 1.0 | 0.9999 | 0.9925 | 0.9999 | 0.9999 | 199 |
| glycinergic amacrine cell | 0.9836 | 0.9903 | 0.9993 | 0.987 | 0.9948 | 0.9887 | 1032 |
| invaginating midget bipolar cell | 0.9964 | 0.9988 | 0.9999 | 0.9976 | 0.9993 | 0.9986 | 835 |
| microglial cell | 1.0000 | 0.9744 | 1.0 | 0.987 | 0.9871 | 0.9719 | 39 |
| retinal bipolar neuron | 0.9976 | 0.9927 | 1.0 | 0.9951 | 0.9963 | 0.992 | 412 |
| retinal cone cell | 1.0000 | 0.9957 | 1.0 | 0.9979 | 0.9979 | 0.9953 | 933 |
| retinal ganglion cell | 0.5830 | 0.5854 | 0.9958 | 0.5842 | 0.7635 | 0.559 | 246 |
| retinal pigment epithelial cell | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 7 |
| retinal rod cell | 0.9998 | 0.9996 | 0.9999 | 0.9997 | 0.9998 | 0.9995 | 8388 |
| rod bipolar cell | 0.9987 | 0.9987 | 1.0 | 0.9987 | 0.9994 | 0.9986 | 796 |
| starburst amacrine cell | 0.9812 | 0.9874 | 0.9999 | 0.9843 | 0.9936 | 0.9861 | 159 |
| macro avg | 0.9745 | 0.9718 | 0.9995 | 0.9729 | 0.9847 | 0.9693 | |
| weighted avg | 0.9864 | 0.9863 | 0.9994 | 0.9863 | 0.9925 | 0.9847 | |
| Accuracy | 0.9863 | | | | | | |
| Balanced accuracy | 0.9718 | | | | | | |
[22]:
pd.set_option('display.max_rows', 100)
df_l2.compare(df_warm_start_l2, keep_equal=True, align_axis = 0, result_names=('default', 'warm start'))
[22]:
| | | precision | recall/sensitivity | specificity | f1-score | geometric mean | index balanced accuracy |
|---|---|---|---|---|---|---|---|
| GABAergic amacrine cell | default | 0.9934 | 0.9961 | 0.9991 | 0.9948 | 0.9976 | 0.995 |
| | warm start | 0.9965 | 0.9954 | 0.9995 | 0.996 | 0.9975 | 0.9946 |
| H1 horizontal cell | default | 0.9926 | 0.9944 | 0.9998 | 0.9935 | 0.9971 | 0.9937 |
| | warm start | 0.9981 | 0.9944 | 1.0 | 0.9963 | 0.9972 | 0.9939 |
| H2 horizontal cell | default | 0.9677 | 0.9574 | 0.9999 | 0.9626 | 0.9784 | 0.9533 |
| | warm start | 0.9688 | 0.9894 | 0.9999 | 0.9789 | 0.9946 | 0.9882 |
| Mueller cell | default | 0.9994 | 0.9994 | 1.0 | 0.9994 | 0.9997 | 0.9993 |
| | warm start | 1.0000 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
| OFF midget ganglion cell | default | 0.9079 | 0.9562 | 0.9935 | 0.9314 | 0.9747 | 0.9465 |
| | warm start | 0.9290 | 0.9461 | 0.9951 | 0.9375 | 0.9703 | 0.9369 |
| OFF parasol ganglion cell | default | 0.9487 | 0.9367 | 0.9998 | 0.9427 | 0.9678 | 0.9306 |
| | warm start | 0.9630 | 0.9873 | 0.9999 | 0.975 | 0.9936 | 0.986 |
| OFFx cell | default | 0.9835 | 0.9754 | 0.9999 | 0.9794 | 0.9876 | 0.9729 |
| | warm start | 0.9918 | 0.9918 | 1.0 | 0.9918 | 0.9959 | 0.991 |
| ON midget ganglion cell | default | 0.9521 | 0.9162 | 0.9977 | 0.9338 | 0.9561 | 0.9066 |
| | warm start | 0.9587 | 0.9346 | 0.998 | 0.9465 | 0.9658 | 0.9268 |
| ON parasol ganglion cell | default | 0.9796 | 0.9796 | 1.0 | 0.9796 | 0.9897 | 0.9776 |
| | warm start | 1.0000 | 0.9796 | 1.0 | 0.9897 | 0.9897 | 0.9776 |
| S cone cell | default | 0.9844 | 0.9403 | 1.0 | 0.9618 | 0.9697 | 0.9347 |
| | warm start | 0.9571 | 1.0 | 0.9999 | 0.9781 | 0.9999 | 0.9999 |
| amacrine cell | default | 0.9863 | 0.9622 | 0.9998 | 0.9741 | 0.9808 | 0.9584 |
| | warm start | 0.9822 | 0.98 | 0.9997 | 0.9811 | 0.9898 | 0.9778 |
| astrocyte | default | 1.0000 | 0.982 | 1.0 | 0.9909 | 0.991 | 0.9802 |
| | warm start | 0.9909 | 0.982 | 1.0 | 0.9864 | 0.9909 | 0.9802 |
| diffuse bipolar 1 cell | default | 0.9975 | 0.9874 | 1.0 | 0.9924 | 0.9937 | 0.9861 |
| | warm start | 0.9949 | 0.9924 | 0.9999 | 0.9937 | 0.9962 | 0.9916 |
| diffuse bipolar 2 cell | default | 0.9926 | 0.9963 | 0.9998 | 0.9944 | 0.9981 | 0.9958 |
| | warm start | 0.9925 | 0.9925 | 0.9998 | 0.9925 | 0.9962 | 0.9917 |
| diffuse bipolar 3a cell | default | 1.0000 | 0.9826 | 1.0 | 0.9912 | 0.9912 | 0.9808 |
| | warm start | 0.9942 | 0.9942 | 1.0 | 0.9942 | 0.9971 | 0.9936 |
| diffuse bipolar 3b cell | default | 0.9964 | 0.9928 | 1.0 | 0.9946 | 0.9964 | 0.9921 |
| | warm start | 0.9858 | 0.9964 | 0.9998 | 0.9911 | 0.9981 | 0.9959 |
| diffuse bipolar 4 cell | default | 0.9923 | 1.0 | 0.9999 | 0.9961 | 0.9999 | 0.9999 |
| | warm start | 0.9948 | 0.9948 | 0.9999 | 0.9948 | 0.9974 | 0.9942 |
| diffuse bipolar 6 cell | default | 0.9931 | 0.9863 | 1.0 | 0.9897 | 0.9931 | 0.9849 |
| | warm start | 0.9863 | 0.9863 | 0.9999 | 0.9863 | 0.9931 | 0.9849 |
| flat midget bipolar cell | default | 0.9956 | 0.9974 | 0.9998 | 0.9965 | 0.9986 | 0.9969 |
| | warm start | 0.9982 | 0.9956 | 0.9999 | 0.9969 | 0.9978 | 0.9951 |
| giant bipolar cell | default | 0.9802 | 0.995 | 0.9998 | 0.9875 | 0.9974 | 0.9943 |
| | warm start | 0.9851 | 1.0 | 0.9999 | 0.9925 | 0.9999 | 0.9999 |
| glycinergic amacrine cell | default | 0.9790 | 0.9922 | 0.9991 | 0.9856 | 0.9957 | 0.9907 |
| | warm start | 0.9836 | 0.9903 | 0.9993 | 0.987 | 0.9948 | 0.9887 |
| invaginating midget bipolar cell | default | 0.9964 | 0.9964 | 0.9999 | 0.9964 | 0.9981 | 0.9959 |
| | warm start | 0.9964 | 0.9988 | 0.9999 | 0.9976 | 0.9993 | 0.9986 |
| microglial cell | default | 1.0000 | 0.9487 | 1.0 | 0.9737 | 0.974 | 0.9439 |
| | warm start | 1.0000 | 0.9744 | 1.0 | 0.987 | 0.9871 | 0.9719 |
| retinal bipolar neuron | default | 0.9951 | 0.9951 | 0.9999 | 0.9951 | 0.9975 | 0.9946 |
| | warm start | 0.9976 | 0.9927 | 1.0 | 0.9951 | 0.9963 | 0.992 |
| retinal cone cell | default | 0.9947 | 0.9989 | 0.9998 | 0.9968 | 0.9994 | 0.9986 |
| | warm start | 1.0000 | 0.9957 | 1.0 | 0.9979 | 0.9979 | 0.9953 |
| retinal ganglion cell | default | 0.5749 | 0.4837 | 0.9964 | 0.5254 | 0.6943 | 0.4573 |
| | warm start | 0.5830 | 0.5854 | 0.9958 | 0.5842 | 0.7635 | 0.559 |
| retinal rod cell | default | 0.9999 | 0.9995 | 0.9999 | 0.9997 | 0.9997 | 0.9994 |
| | warm start | 0.9998 | 0.9996 | 0.9999 | 0.9997 | 0.9998 | 0.9995 |
| rod bipolar cell | default | 0.9962 | 0.9987 | 0.9999 | 0.9975 | 0.9993 | 0.9985 |
| | warm start | 0.9987 | 0.9987 | 1.0 | 0.9987 | 0.9994 | 0.9986 |
| starburst amacrine cell | default | 1.0000 | 0.9811 | 1.0 | 0.9905 | 0.9905 | 0.9793 |
| | warm start | 0.9812 | 0.9874 | 0.9999 | 0.9843 | 0.9936 | 0.9861 |
| macro avg | default | 0.9735 | 0.9612 | 0.9995 | 0.967 | 0.9787 | 0.9579 |
| | warm start | 0.9745 | 0.9718 | 0.9995 | 0.9729 | 0.9847 | 0.9693 |
| weighted avg | default | 0.9839 | 0.9843 | 0.9992 | 0.984 | 0.9913 | 0.9824 |
| | warm start | 0.9864 | 0.9863 | 0.9994 | 0.9863 | 0.9925 | 0.9847 |
| Accuracy | default | 0.9843 | | | | | |
| | warm start | 0.9863 | | | | | |
| Balanced accuracy | default | 0.9612 | | | | | |
| | warm start | 0.9718 | | | | | |
Iterative warm start training improved nearly every summary metric (macro average, weighted average, accuracy, and balanced accuracy rows); only the macro-average specificity was unchanged (0.9995). For the retinal ganglion cell, sensitivity increased by 10.17 percentage points (0.4837 to 0.5854) and precision by 0.81 percentage points (0.5749 to 0.5830).
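The change for a single cell type can be read directly from the comparison table as a difference in percentage points. The short example below recomputes it for the retinal ganglion cell, with the values copied from the table.

```python
# Values for 'retinal ganglion cell' taken from the comparison table
default = {'precision': 0.5749, 'recall/sensitivity': 0.4837}
warm_start = {'precision': 0.5830, 'recall/sensitivity': 0.5854}

# Absolute change in percentage points (warm start minus default)
delta_pp = {
    metric: round((warm_start[metric] - default[metric]) * 100, 2)
    for metric in default
}
for metric, delta in delta_pp.items():
    print(f'{metric}: {delta:+.2f} percentage points')
```

Reporting the absolute difference avoids ambiguity with relative change, which would be considerably larger for a metric that starts near 0.5.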
[23]:
import session_info
session_info.show()
[23]:
Session information:
-----
anndata 0.11.4
numpy 1.26.4
pandas 2.3.3
scanpy 1.11.4
scparadise 1.0.0
session_info v1.0.1
-----
Modules imported as dependencies:
PIL 11.3.0 anyio NA arrow 1.3.0 asttokens NA attr 25.3.0 attrs 25.3.0 babel 2.17.0 backports NA certifi 2025.08.03 cffi 1.17.1 charset_normalizer 3.4.3 cloudpickle 3.1.1 colorlog NA comm 0.2.3 cupy 13.6.0 cupy_backends NA cupyx NA cycler 0.12.1 cython_runtime NA dask 2025.10.0 dateutil 2.9.0.post0 debugpy 1.8.16 decorator 5.2.1 defusedxml 0.7.1 exceptiongroup 1.3.0 executing 2.2.1 fastjsonschema NA fastrlock 0.8.3 fqdn NA fsspec 2025.7.0 h5py 3.14.0 idna 3.10 igraph 0.11.9 imblearn 0.14.0 importlib_metadata NA ipykernel 6.30.1 ipywidgets 8.1.8 isoduration NA jaraco NA jedi 0.19.2 jinja2 3.1.6 joblib 1.5.2 json5 0.12.1 jsonpointer 3.0.0 jsonschema 4.25.1 jsonschema_specifications NA jupyter_events 0.12.0 jupyter_server 2.17.0 jupyterlab_server 2.28.0 kiwisolver 1.4.9 lark 1.2.2 lazy_loader 0.4 legacy_api_wrap NA leidenalg 0.10.2 llvmlite 0.44.0 louvain 0.8.2 markupsafe 3.0.2 matplotlib 3.10.6 matplotlib_inline 0.1.7 more_itertools 10.3.0 mpl_toolkits NA mpmath 1.3.0 mudata 0.3.2 muon 0.1.7 natsort 8.4.0 nbformat 5.10.4 numba 0.61.2 numexpr 2.14.1 optuna 4.5.0 overrides NA packaging 25.0 parso 0.8.5 patsy 1.0.1 pkg_resources NA platformdirs 4.4.0 plottable 0.1.5 prometheus_client NA prompt_toolkit 3.0.52 psutil 7.0.0 pure_eval 0.2.3 pyarrow 22.0.0 pycparser 2.22 pydev_ipython NA pydevconsole NA pydevd 3.2.3 pydevd_file_utils NA pydevd_plugins NA pydevd_tracing NA pydot 4.0.1 pygments 2.19.2 pynndescent 0.5.13 pyparsing 3.2.3 pythonjsonlogger NA pytorch_tabnet NA pytz 2025.2 referencing NA requests 2.32.5 rfc3339_validator 0.1.4 rfc3986_validator 0.1.1 rfc3987_syntax NA rpds NA scipy 1.12.0 seaborn 0.13.2 send2trash NA shap 0.48.0 simplejson 3.20.1 six 1.17.0 skimage 0.25.2 sklearn 1.7.1 slicer NA sniffio 1.3.1 sparse 0.17.0 stack_data 0.6.3 statsmodels 0.14.5 sympy 1.14.0 texttable 1.7.0 threadpoolctl 3.6.0 tlz 1.0.0 toolz 1.0.0 torch 2.8.0+cu128 torchgen NA tornado 6.5.2 tqdm 4.64.1 traitlets 5.14.3 triton 3.4.0 typing_extensions NA umap 0.5.9.post2 uri_template NA 
urllib3 2.5.0 wcwidth 0.2.13 webcolors NA websocket 1.8.0 xarray 2025.6.1 yaml 6.0.2 zipp NA zmq 27.0.2 zoneinfo NA
-----
IPython 8.37.0
jupyter_client 8.6.3
jupyter_core 5.8.1
jupyterlab 4.5.1
-----
Python 3.10.18 (main, Jun 5 2025, 13:14:17) [GCC 11.2.0]
Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.35
-----
Session information updated at 2026-02-03 19:56