scEve model training#

[1]:
# Python packages
import warnings
warnings.simplefilter('ignore')

import scanpy as sc
import scparadise
import muon as mu
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

sc.set_figure_params(dpi = 120)

Recommendations about training dataset#

We recommend shifted logarithm data normalization method for gene expression data: sc.pp.normalize_total(adata, target_sum=None) sc.pp.log1p(adata)

But you can use any other method of gene expression data normalzation (Use the same normalization method for test dataset). Training dataset sould contain all genes that you want to use for model training in mdata_train.mod[‘rna’].X We recommend removing all non-marker genes from mdata_train.mod[‘rna’].X (removing such useless genes may improve performance and model quality metrics).

To demonstrate the scEve model training and evaluation, we used eight samples from the CITE-seq 3’ PBMC dataset. The genes were pre-sorted, and the data were normalized. Details of the dataset processing are available in the PBMC_3p_dataset script.

[2]:
!wget https://figshare.com/ndownloader/files/51749906
--2025-01-17 16:32:29--  https://figshare.com/ndownloader/files/51749906
Resolving figshare.com (figshare.com)... 52.48.154.230, 99.80.157.3, 2a05:d018:1f4:d003:cf11:7831:3c7b:f409, ...
Connecting to figshare.com (figshare.com)|52.48.154.230|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/51749906/mdata_pbmc.h5mu?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIYCQYOYV5JSSROOA/20250117/eu-west-1/s3/aws4_request&X-Amz-Date=20250117T133229Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=b83c326d243589ceacade0a1bbc8ac1a143941f9ffa5e3cc9e17b43b054f7600 [following]
--2025-01-17 16:32:29--  https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/51749906/mdata_pbmc.h5mu?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIYCQYOYV5JSSROOA/20250117/eu-west-1/s3/aws4_request&X-Amz-Date=20250117T133229Z&X-Amz-Expires=10&X-Amz-SignedHeaders=host&X-Amz-Signature=b83c326d243589ceacade0a1bbc8ac1a143941f9ffa5e3cc9e17b43b054f7600
Resolving s3-eu-west-1.amazonaws.com (s3-eu-west-1.amazonaws.com)... 3.5.70.0, 52.218.36.178, 52.218.93.51, ...
Connecting to s3-eu-west-1.amazonaws.com (s3-eu-west-1.amazonaws.com)|3.5.70.0|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 318572835 (304M) [application/octet-stream]
Saving to: ‘51749906.1’

51749906.1          100%[===================>] 303.81M  1.79MB/s    in 3m 12s

2025-01-17 16:35:41 (1.59 MB/s) - ‘51749906.1’ saved [318572835/318572835]

[2]:
# Load prepared for training anndata object
mdata = mu.read_h5mu('51749906')
[3]:
# Create mdata_train dataset (we used 4 samples for training model)
mdata_train = mdata[mdata.obs['orig.ident'].isin(['P2_0', 'P3_0', 'P4_0', 'P5_0', 'P6_0', 'P7_0', 'P8_0'])].copy()
# Create mdata_test dataset (we used 1 sample for test model quality)
mdata_test = mdata[mdata.obs['orig.ident'].isin(['P1_0'])].copy()
[4]:
# Extract RNA expression data from mdata_test for future predictions
adata_test_rna = mdata_test.mod['rna']

Default scEve model training#

[5]:
# Train default scEve model using mdata_train dataset
scparadise.sceve.train(mdata_train,
                       path = '',
                       rna_modality_name = 'rna',
                       second_modality_name = 'adt',
                       detailed_annotation = 'celltype_l3',
                       model_name = 'model_scEve',
                       max_epochs= 200,
                       eval_metric=['rmse'])
Successfully saved genes names for training model

Successfully saved proteins names for training model

Train dataset contains: 42029 cells, it is 90.0 % of input dataset
Test dataset contains: 4670 cells, it is 10.0 % of input dataset

Accelerator: cuda
Start training
epoch 0  | loss: 0.53146 | train_rmse: 0.4486500024795532| valid_rmse: 0.4459399878978729|  0:00:02s
epoch 1  | loss: 0.13884 | train_rmse: 0.4063900113105774| valid_rmse: 0.4043999910354614|  0:00:04s
epoch 2  | loss: 0.11796 | train_rmse: 0.39267000555992126| valid_rmse: 0.39061999320983887|  0:00:07s
epoch 3  | loss: 0.11099 | train_rmse: 0.3783400058746338| valid_rmse: 0.3764500021934509|  0:00:09s
epoch 4  | loss: 0.10588 | train_rmse: 0.368369996547699| valid_rmse: 0.3665800094604492|  0:00:12s
epoch 5  | loss: 0.10326 | train_rmse: 0.3664500117301941| valid_rmse: 0.3647499978542328|  0:00:14s
epoch 6  | loss: 0.10206 | train_rmse: 0.3655799925327301| valid_rmse: 0.3637000024318695|  0:00:17s
epoch 7  | loss: 0.10091 | train_rmse: 0.36375999450683594| valid_rmse: 0.3625600039958954|  0:00:19s
epoch 8  | loss: 0.10006 | train_rmse: 0.3627200126647949| valid_rmse: 0.3616600036621094|  0:00:21s
epoch 9  | loss: 0.09919 | train_rmse: 0.35940998792648315| valid_rmse: 0.358599990606308|  0:00:24s
epoch 10 | loss: 0.09771 | train_rmse: 0.35673001408576965| valid_rmse: 0.35558000206947327|  0:00:26s
epoch 11 | loss: 0.09655 | train_rmse: 0.353410005569458| valid_rmse: 0.3524399995803833|  0:00:28s
epoch 12 | loss: 0.09425 | train_rmse: 0.35016998648643494| valid_rmse: 0.35001999139785767|  0:00:31s
epoch 13 | loss: 0.09285 | train_rmse: 0.3470500111579895| valid_rmse: 0.3462499976158142|  0:00:33s
epoch 14 | loss: 0.09187 | train_rmse: 0.3456900119781494| valid_rmse: 0.34477001428604126|  0:00:36s
epoch 15 | loss: 0.09123 | train_rmse: 0.34325000643730164| valid_rmse: 0.34237998723983765|  0:00:38s
epoch 16 | loss: 0.09047 | train_rmse: 0.3420799970626831| valid_rmse: 0.3412800133228302|  0:00:41s
epoch 17 | loss: 0.08992 | train_rmse: 0.3414599895477295| valid_rmse: 0.34081000089645386|  0:00:43s
epoch 18 | loss: 0.08885 | train_rmse: 0.3413900136947632| valid_rmse: 0.34060999751091003|  0:00:45s
epoch 19 | loss: 0.08886 | train_rmse: 0.3397200107574463| valid_rmse: 0.3389599919319153|  0:00:48s
epoch 20 | loss: 0.08825 | train_rmse: 0.3388800024986267| valid_rmse: 0.3383600115776062|  0:00:50s
epoch 21 | loss: 0.08755 | train_rmse: 0.3383300006389618| valid_rmse: 0.337660014629364|  0:00:52s
epoch 22 | loss: 0.08742 | train_rmse: 0.3378700017929077| valid_rmse: 0.3372400104999542|  0:00:54s
epoch 23 | loss: 0.08701 | train_rmse: 0.3373599946498871| valid_rmse: 0.3366599977016449|  0:00:57s
epoch 24 | loss: 0.08693 | train_rmse: 0.3369100093841553| valid_rmse: 0.33601999282836914|  0:00:59s
epoch 25 | loss: 0.08682 | train_rmse: 0.3359600007534027| valid_rmse: 0.335099995136261|  0:01:02s
epoch 26 | loss: 0.08619 | train_rmse: 0.3365600109100342| valid_rmse: 0.33574000000953674|  0:01:04s
epoch 27 | loss: 0.08569 | train_rmse: 0.3354099988937378| valid_rmse: 0.334989994764328|  0:01:07s
epoch 28 | loss: 0.08553 | train_rmse: 0.33535999059677124| valid_rmse: 0.33465999364852905|  0:01:09s
epoch 29 | loss: 0.08541 | train_rmse: 0.334879994392395| valid_rmse: 0.3341499865055084|  0:01:12s
epoch 30 | loss: 0.08545 | train_rmse: 0.33417999744415283| valid_rmse: 0.3336699903011322|  0:01:14s
epoch 31 | loss: 0.08499 | train_rmse: 0.3342899978160858| valid_rmse: 0.33375999331474304|  0:01:16s
epoch 32 | loss: 0.08511 | train_rmse: 0.3347800076007843| valid_rmse: 0.3347800076007843|  0:01:18s
epoch 33 | loss: 0.08488 | train_rmse: 0.3332099914550781| valid_rmse: 0.332940012216568|  0:01:21s
epoch 34 | loss: 0.0847  | train_rmse: 0.3334299921989441| valid_rmse: 0.3336299955844879|  0:01:23s
epoch 35 | loss: 0.08487 | train_rmse: 0.3341600000858307| valid_rmse: 0.334089994430542|  0:01:25s
epoch 36 | loss: 0.08426 | train_rmse: 0.3330099880695343| valid_rmse: 0.33292001485824585|  0:01:28s
epoch 37 | loss: 0.08418 | train_rmse: 0.33250999450683594| valid_rmse: 0.33278998732566833|  0:01:30s
epoch 38 | loss: 0.08384 | train_rmse: 0.33191999793052673| valid_rmse: 0.3322499990463257|  0:01:33s
epoch 39 | loss: 0.08425 | train_rmse: 0.3325299918651581| valid_rmse: 0.3325200080871582|  0:01:35s
epoch 40 | loss: 0.08367 | train_rmse: 0.33149001002311707| valid_rmse: 0.33208000659942627|  0:01:38s
epoch 41 | loss: 0.08357 | train_rmse: 0.33083000779151917| valid_rmse: 0.3317300081253052|  0:01:40s
epoch 42 | loss: 0.08335 | train_rmse: 0.33048999309539795| valid_rmse: 0.3307799994945526|  0:01:43s
epoch 43 | loss: 0.08321 | train_rmse: 0.3299199938774109| valid_rmse: 0.331169992685318|  0:01:45s
epoch 44 | loss: 0.08272 | train_rmse: 0.3295699954032898| valid_rmse: 0.3304300010204315|  0:01:47s
epoch 45 | loss: 0.08299 | train_rmse: 0.32962000370025635| valid_rmse: 0.33125999569892883|  0:01:50s
epoch 46 | loss: 0.08257 | train_rmse: 0.32951998710632324| valid_rmse: 0.3300800025463104|  0:01:52s
epoch 47 | loss: 0.0825  | train_rmse: 0.3283900022506714| valid_rmse: 0.3296999931335449|  0:01:54s
epoch 48 | loss: 0.08234 | train_rmse: 0.32785001397132874| valid_rmse: 0.3297500014305115|  0:01:57s
epoch 49 | loss: 0.08195 | train_rmse: 0.3274799883365631| valid_rmse: 0.3289400041103363|  0:01:59s
epoch 50 | loss: 0.08172 | train_rmse: 0.32708001136779785| valid_rmse: 0.3288300037384033|  0:02:02s
epoch 51 | loss: 0.08173 | train_rmse: 0.32719001173973083| valid_rmse: 0.329259991645813|  0:02:04s
epoch 52 | loss: 0.08145 | train_rmse: 0.32655999064445496| valid_rmse: 0.3282800018787384|  0:02:06s
epoch 53 | loss: 0.0816  | train_rmse: 0.32506999373435974| valid_rmse: 0.32747000455856323|  0:02:09s
epoch 54 | loss: 0.08108 | train_rmse: 0.325080007314682| valid_rmse: 0.32767000794410706|  0:02:11s
epoch 55 | loss: 0.08132 | train_rmse: 0.3252499997615814| valid_rmse: 0.32822999358177185|  0:02:13s
epoch 56 | loss: 0.08052 | train_rmse: 0.32460999488830566| valid_rmse: 0.327129989862442|  0:02:15s
epoch 57 | loss: 0.08054 | train_rmse: 0.32440000772476196| valid_rmse: 0.3276199996471405|  0:02:18s
epoch 58 | loss: 0.08016 | train_rmse: 0.3232499957084656| valid_rmse: 0.32635000348091125|  0:02:20s
epoch 59 | loss: 0.08072 | train_rmse: 0.323419988155365| valid_rmse: 0.32701000571250916|  0:02:23s
epoch 60 | loss: 0.07995 | train_rmse: 0.3239400088787079| valid_rmse: 0.3266499936580658|  0:02:25s
epoch 61 | loss: 0.08034 | train_rmse: 0.32328999042510986| valid_rmse: 0.32677000761032104|  0:02:28s
epoch 62 | loss: 0.07997 | train_rmse: 0.32343000173568726| valid_rmse: 0.32728999853134155|  0:02:30s
epoch 63 | loss: 0.08    | train_rmse: 0.32295000553131104| valid_rmse: 0.3264400064945221|  0:02:33s
epoch 64 | loss: 0.07988 | train_rmse: 0.32293999195098877| valid_rmse: 0.3261300027370453|  0:02:35s
epoch 65 | loss: 0.08    | train_rmse: 0.32322999835014343| valid_rmse: 0.32594001293182373|  0:02:38s
epoch 66 | loss: 0.07987 | train_rmse: 0.32232001423835754| valid_rmse: 0.32659000158309937|  0:02:40s
epoch 67 | loss: 0.07948 | train_rmse: 0.322409987449646| valid_rmse: 0.32701998949050903|  0:02:42s
epoch 68 | loss: 0.07947 | train_rmse: 0.32297998666763306| valid_rmse: 0.3279300034046173|  0:02:45s
epoch 69 | loss: 0.07921 | train_rmse: 0.3224399983882904| valid_rmse: 0.32642999291419983|  0:02:47s
epoch 70 | loss: 0.07906 | train_rmse: 0.32190001010894775| valid_rmse: 0.3264099955558777|  0:02:49s
epoch 71 | loss: 0.07919 | train_rmse: 0.32137998938560486| valid_rmse: 0.3259899914264679|  0:02:52s
epoch 72 | loss: 0.07903 | train_rmse: 0.3209800124168396| valid_rmse: 0.32552000880241394|  0:02:54s
epoch 73 | loss: 0.07915 | train_rmse: 0.321289986371994| valid_rmse: 0.32607999444007874|  0:02:57s
epoch 74 | loss: 0.07894 | train_rmse: 0.322050005197525| valid_rmse: 0.32642999291419983|  0:02:59s
epoch 75 | loss: 0.07894 | train_rmse: 0.3204300105571747| valid_rmse: 0.3256799876689911|  0:03:02s
epoch 76 | loss: 0.07842 | train_rmse: 0.3207100033760071| valid_rmse: 0.32541999220848083|  0:03:04s
epoch 77 | loss: 0.07883 | train_rmse: 0.32030001282691956| valid_rmse: 0.32618001103401184|  0:03:06s
epoch 78 | loss: 0.07852 | train_rmse: 0.31981998682022095| valid_rmse: 0.32521000504493713|  0:03:09s
epoch 79 | loss: 0.07869 | train_rmse: 0.31960999965667725| valid_rmse: 0.3255699872970581|  0:03:11s
epoch 80 | loss: 0.07821 | train_rmse: 0.3202899992465973| valid_rmse: 0.325080007314682|  0:03:13s
epoch 81 | loss: 0.07813 | train_rmse: 0.31940001249313354| valid_rmse: 0.3244599997997284|  0:03:16s
epoch 82 | loss: 0.0786  | train_rmse: 0.31970998644828796| valid_rmse: 0.3248800039291382|  0:03:18s
epoch 83 | loss: 0.07832 | train_rmse: 0.3192799985408783| valid_rmse: 0.3256700038909912|  0:03:21s
epoch 84 | loss: 0.07825 | train_rmse: 0.3197999894618988| valid_rmse: 0.3268899917602539|  0:03:23s
epoch 85 | loss: 0.07811 | train_rmse: 0.31944000720977783| valid_rmse: 0.32541999220848083|  0:03:25s
epoch 86 | loss: 0.07785 | train_rmse: 0.31863999366760254| valid_rmse: 0.3244599997997284|  0:03:28s
epoch 87 | loss: 0.07799 | train_rmse: 0.3187499940395355| valid_rmse: 0.32502999901771545|  0:03:30s
epoch 88 | loss: 0.07759 | train_rmse: 0.3181999921798706| valid_rmse: 0.32499998807907104|  0:03:32s
epoch 89 | loss: 0.07779 | train_rmse: 0.3186799883842468| valid_rmse: 0.3246699869632721|  0:03:35s
epoch 90 | loss: 0.07763 | train_rmse: 0.3192799985408783| valid_rmse: 0.3260299861431122|  0:03:37s
epoch 91 | loss: 0.0779  | train_rmse: 0.3180699944496155| valid_rmse: 0.32455000281333923|  0:03:39s
epoch 92 | loss: 0.07743 | train_rmse: 0.31769999861717224| valid_rmse: 0.3244900107383728|  0:03:42s
epoch 93 | loss: 0.07761 | train_rmse: 0.31821000576019287| valid_rmse: 0.3249100148677826|  0:03:44s
epoch 94 | loss: 0.07779 | train_rmse: 0.3179599940776825| valid_rmse: 0.32592999935150146|  0:03:46s
epoch 95 | loss: 0.07747 | train_rmse: 0.31769001483917236| valid_rmse: 0.324070006608963|  0:03:49s
epoch 96 | loss: 0.07725 | train_rmse: 0.3177199959754944| valid_rmse: 0.3257000148296356|  0:03:51s
epoch 97 | loss: 0.07739 | train_rmse: 0.3174799978733063| valid_rmse: 0.324970006942749|  0:03:53s
epoch 98 | loss: 0.07707 | train_rmse: 0.3181999921798706| valid_rmse: 0.32499998807907104|  0:03:56s
epoch 99 | loss: 0.07693 | train_rmse: 0.3183000087738037| valid_rmse: 0.32572999596595764|  0:03:58s
epoch 100| loss: 0.07726 | train_rmse: 0.31727999448776245| valid_rmse: 0.32603999972343445|  0:04:01s
epoch 101| loss: 0.07723 | train_rmse: 0.31654998660087585| valid_rmse: 0.3251500129699707|  0:04:03s
epoch 102| loss: 0.0772  | train_rmse: 0.3163500130176544| valid_rmse: 0.32433000206947327|  0:04:05s
epoch 103| loss: 0.07685 | train_rmse: 0.3165000081062317| valid_rmse: 0.32429999113082886|  0:04:07s
epoch 104| loss: 0.07662 | train_rmse: 0.3163900077342987| valid_rmse: 0.32471999526023865|  0:04:10s
epoch 105| loss: 0.07685 | train_rmse: 0.316540002822876| valid_rmse: 0.3248800039291382|  0:04:12s

Early stopping occurred at epoch 105 with best_epoch = 95 and best_valid_rmse = 0.324070006608963
Successfully saved model at model_scEve/model.zip

Check model quality#

[6]:
# Predict surface proteins using scEve model
mdata_pred = scparadise.sceve.predict(adata_test_rna,
                                      path_model = 'model_scEve',
                                      rna_modality_name='rna', # default value
                                      second_modality_name='adt', # default value
                                      return_mdata = True # default value
                                      )
Successfully loaded list of genes used for training model

Successfully loaded list of features used for training model

Successfully loaded model

[7]:
# Add umap coordinates
mdata_pred.obsm['X_umap'] = adata_test_rna.obsm['X_umap']
[8]:
# List of proteins for visualization
proteins = ['adt_CLEC12A', 'adt_CD4-1', 'adt_CD8', 'adt_IgM', 'adt_CD133-2']

# Create list of predicted proteins
proteins_pred = []
for i in proteins:
    i = i + '_pred'
    proteins_pred.append(i)
[9]:
# Proteins visualization
proteins_dict = {
    "Observed": proteins,
    "Predicted": proteins_pred
}
# Choose a color map for protein abundance visualization
cmap = 'bwr'
# Make Axes
# Number of needed rows and columns (based on the row with the most columns)
nrow = len(proteins_dict)
ncol = max([len(vs) for vs in proteins_dict.values()])
fig, axs = plt.subplots(nrow, ncol, figsize=(3 * ncol, 3 * nrow))
# Plot expression for every marker on the corresponding Axes object
for row_idx, (prots, markers) in enumerate(proteins_dict.items()):
    col_idx = 0
    for marker in markers:
        ax = axs[row_idx, col_idx]
        if prots == 'Observed':
            mu.pl.embedding(mdata_test, basis = 'X_umap', color=marker, cmap = cmap, ax=ax, add_outline = True, show=False, frameon=False, s=10)
        else:
            mu.pl.embedding(mdata_pred, basis = 'X_umap', color=marker, cmap = cmap, ax=ax, add_outline = True, show=False, frameon=False, s=10)
        # Add cell type as row label - here we simply add it as ylabel of
        # the first Axes object in the row
        if col_idx == 0:
            # We disabled axis drawing in UMAP to have plots without background and border
            # so we need to re-enable axis to plot the ylabel
            ax.axis("on")
            ax.tick_params(
                top="off",
                bottom="off",
                left="off",
                right="off",
                labelleft="on",
                labelbottom="off",
            )
            ax.set_ylabel(prots + "\n", rotation=90, fontsize=14)
            ax.set(frame_on=False)
        col_idx += 1
    # Remove unused column Axes in the current row
    while col_idx < ncol:
        axs[row_idx, col_idx].remove()
        col_idx += 1
# Alignment within the Figure
fig.tight_layout()
../../../_images/tutorials_notebooks_scEve_scEve_train_14_0.png

The obtained model reliably predicts the main protein markers. However, there are proteins such as CD133, which the default model cannot predict.

[10]:
# Check prediction results
df = scparadise.scnoah.report_reg(adata_prot = mdata_test.mod['adt'],
                                  adata_pred_prot = mdata_pred.mod['adt'])
df
[10]:
EVS r2_score RMSE MedianAE MeanAE
score 0.267 0.227 0.333 0.202 0.251
EVS/r2_score higher value - better prediction
RMSE/MedianAE/MeanAE lower value - better prediction
[11]:
scparadise.scnoah.regres_status(adata_prot = mdata_test.mod['adt'],
                                adata_pred_prot = mdata_pred.mod['adt'],
                                metric='RMSE')
scparadise.scnoah.regres_status(adata_prot = mdata_test.mod['adt'],
                                adata_pred_prot = mdata_pred.mod['adt'],
                                metric='MeanAE')
scparadise.scnoah.regres_status(adata_prot = mdata_test.mod['adt'],
                                adata_pred_prot = mdata_pred.mod['adt'],
                                metric='MedianAE')
scparadise.scnoah.regres_status(adata_prot = mdata_test.mod['adt'],
                                adata_pred_prot = mdata_pred.mod['adt'],
                                metric='EVS')
scparadise.scnoah.regres_status(adata_prot = mdata_test.mod['adt'],
                                adata_pred_prot = mdata_pred.mod['adt'],
                                metric='r2_score')
[12]:
for i in ['regres_status_RMSE', 'regres_status_MeanAE', 'regres_status_MedianAE', 'regres_status_EVS', 'regres_status_r2_score']:
    mdata_pred.obs[i] = mdata_pred.mod['adt'].obs[i]
[13]:
# Visualize metrics
mu.pl.embedding(mdata_pred,
                color=[
                    'regres_status_RMSE',
                    'regres_status_MeanAE',
                    'regres_status_MedianAE',
                    'regres_status_EVS',
                    'regres_status_r2_score'
                ],
                basis = 'X_umap',
                frameon = False,
                add_outline = True,
                cmap = 'bwr',
                ncols = 5,
                wspace = 0.15,
                hspace = 0.1)
../../../_images/tutorials_notebooks_scEve_scEve_train_19_0.png

Based on a per cell type regression metrics score, it is possible to draw conclusions about the uniformity of protein prediction across different cell types in the dataset.

[14]:
df_corr = pd.DataFrame(columns=['Pearson coef', 'p-value'])
for i in mdata_test.mod['adt'].var_names.tolist():
    person_coef = scparadise.scnoah.pearson_coef_prot(adata_prot = mdata_test.mod['adt'],
                                                      adata_pred_prot = mdata_pred.mod['adt'],
                                                      protein = i,
                                                      protein_pred = i + '_pred',
                                                      print_res = False)
    df_corr.loc[i] = [person_coef['Pearson coefficient'], person_coef['p-value']]
df_corr['p-value'] = df_corr['p-value'].astype('float64')
[15]:
df_corr.head(5)
[15]:
Pearson coef p-value
adt_CD39 0.859 0.000000e+00
adt_CD107a 0.644 0.000000e+00
adt_CD62P 0.800 0.000000e+00
adt_TCR-2 0.502 0.000000e+00
adt_CD30 0.257 5.423000e-97
[16]:
df_high_corr = df_corr[df_corr['Pearson coef'] >= 0.5].copy()
print('Number of proteins with Pearson coefficient >= 0.5:', len(df_high_corr))
df_high_corr.sort_values(by = 'Pearson coef', ascending=False).head(5)
Number of proteins with Pearson coefficient >= 0.5: 94
[16]:
Pearson coef p-value
adt_CLEC12A 0.955 0.0
adt_CD64 0.933 0.0
adt_CD11b-2 0.929 0.0
adt_CD93 0.913 0.0
adt_CD8a 0.907 0.0
[17]:
# Save mdata with imputed proteins
mdata_pred.write_h5mu('mdata_pred_untuned.h5mu')

Hyperparameters tuning#

[18]:
scparadise.sceve.hyperparameter_tuning(mdata_train,
                                       path = '',
                                       rna_modality_name = 'rna',
                                       second_modality_name = 'adt',
                                       detailed_annotation = 'celltype_l3',
                                       random_state = 42,
                                       model_name = 'model_scEve_tuning',
                                       eval_metric = ['rmse'])
Successfully saved genes names for training model

Successfully saved proteins names for training model

Accelerator: cuda

[I 2025-01-17 16:52:02,880] A new study created in RDB with name: model_scEve_tuning
Fold 1:
Stop training because you reached max_epochs = 100 with best_epoch = 98 and best_valid_rmse = 0.3299599885940552

Fold 2:
Stop training because you reached max_epochs = 100 with best_epoch = 98 and best_valid_rmse = 0.33000999689102173

Fold 3:
Stop training because you reached max_epochs = 100 with best_epoch = 99 and best_valid_rmse = 0.3275200128555298

Fold 4:
Stop training because you reached max_epochs = 100 with best_epoch = 96 and best_valid_rmse = 0.3295400142669678
[I 2025-01-17 17:05:42,375] Trial 0 finished with value: 0.32925945520401 and parameters: {'n_d': 8, 'n_a': 8, 'n_steps': 3, 'n_shared': 2, 'cat_emb_dim': 1, 'n_independent': 1, 'gamma': 1.3, 'momentum': 0.02, 'lr': 0.01, 'mask_type': 'entmax', 'lambda_sparse': 0.001, 'patience': 10, 'max_epochs': 100, 'virtual_batch_size': 128, 'batch_size': 1024}. Best is trial 0 with value: 0.32925945520401.

Fold 1:
Stop training because you reached max_epochs = 45 with best_epoch = 31 and best_valid_rmse = 0.37463998794555664

Fold 2:
Stop training because you reached max_epochs = 45 with best_epoch = 44 and best_valid_rmse = 0.3817799985408783

Fold 3:

Early stopping occurred at epoch 25 with best_epoch = 5 and best_valid_rmse = 0.4211199879646301

Fold 4:

Early stopping occurred at epoch 35 with best_epoch = 15 and best_valid_rmse = 0.3957499861717224
[I 2025-01-17 17:11:42,480] Trial 1 finished with value: 0.3933216407895088 and parameters: {'n_d': 52, 'n_a': 124, 'n_steps': 8, 'n_shared': 6, 'cat_emb_dim': 2, 'n_independent': 2, 'gamma': 1.0580836121681996, 'momentum': 0.34780869685222476, 'lr': 0.3005973943704301, 'mask_type': 'entmax', 'lambda_sparse': 0.04147225000481637, 'patience': 20, 'max_epochs': 45, 'virtual_batch_size': 1024, 'batch_size': 3072}. Best is trial 0 with value: 0.32925945520401.

Fold 1:
Stop training because you reached max_epochs = 15 with best_epoch = 11 and best_valid_rmse = 0.3627200126647949

Fold 2:

Early stopping occurred at epoch 6 with best_epoch = 1 and best_valid_rmse = 0.3785000145435333

Fold 3:
Stop training because you reached max_epochs = 15 with best_epoch = 13 and best_valid_rmse = 0.3506999909877777

Fold 4:
Stop training because you reached max_epochs = 15 with best_epoch = 14 and best_valid_rmse = 0.3674600124359131
[I 2025-01-17 17:14:32,983] Trial 2 finished with value: 0.36484409123659134 and parameters: {'n_d': 44, 'n_a': 80, 'n_steps': 2, 'n_shared': 3, 'cat_emb_dim': 4, 'n_independent': 5, 'gamma': 1.7851759613930136, 'momentum': 0.08787277504176029, 'lr': 0.25716579576296444, 'mask_type': 'entmax', 'lambda_sparse': 0.004362599362560562, 'patience': 5, 'max_epochs': 15, 'virtual_batch_size': 256, 'batch_size': 512}. Best is trial 0 with value: 0.32925945520401.

Fold 1:

Early stopping occurred at epoch 23 with best_epoch = 3 and best_valid_rmse = 0.36553001403808594
[I 2025-01-17 17:17:58,554] Trial 3 pruned.
Fold 1:
Stop training because you reached max_epochs = 155 with best_epoch = 151 and best_valid_rmse = 0.36671000719070435
[I 2025-01-17 17:22:55,792] Trial 4 pruned.
Fold 1:
Stop training because you reached max_epochs = 95 with best_epoch = 88 and best_valid_rmse = 0.3305099904537201

Fold 2:
Stop training because you reached max_epochs = 95 with best_epoch = 94 and best_valid_rmse = 0.3339200019836426

Fold 3:
Stop training because you reached max_epochs = 95 with best_epoch = 93 and best_valid_rmse = 0.3348200023174286

Fold 4:
Stop training because you reached max_epochs = 95 with best_epoch = 87 and best_valid_rmse = 0.33608001470565796
[I 2025-01-17 17:34:03,969] Trial 5 finished with value: 0.3338331952691078 and parameters: {'n_d': 100, 'n_a': 16, 'n_steps': 4, 'n_shared': 2, 'cat_emb_dim': 9, 'n_independent': 7, 'gamma': 1.3308980248526492, 'momentum': 0.03478775661154922, 'lr': 0.15556006262565952, 'mask_type': 'sparsemax', 'lambda_sparse': 0.005257127272786869, 'patience': 20, 'max_epochs': 95, 'virtual_batch_size': 512, 'batch_size': 2560}. Best is trial 0 with value: 0.32925945520401.

Fold 1:
Stop training because you reached max_epochs = 50 with best_epoch = 49 and best_valid_rmse = 0.36195001006126404

Fold 2:
Stop training because you reached max_epochs = 50 with best_epoch = 48 and best_valid_rmse = 0.3514400124549866

Fold 3:
Stop training because you reached max_epochs = 50 with best_epoch = 49 and best_valid_rmse = 0.3546600043773651
[I 2025-01-17 17:37:13,539] Trial 6 pruned.
Fold 1:

Early stopping occurred at epoch 29 with best_epoch = 9 and best_valid_rmse = 0.40435999631881714
[I 2025-01-17 17:40:02,083] Trial 7 pruned.
Fold 1:

Early stopping occurred at epoch 29 with best_epoch = 24 and best_valid_rmse = 0.3323099911212921

Fold 2:

Early stopping occurred at epoch 46 with best_epoch = 41 and best_valid_rmse = 0.3248000144958496

Fold 3:

Early stopping occurred at epoch 35 with best_epoch = 30 and best_valid_rmse = 0.3292500078678131

Fold 4:

Early stopping occurred at epoch 29 with best_epoch = 24 and best_valid_rmse = 0.3367899954319
[I 2025-01-17 17:44:16,057] Trial 8 finished with value: 0.3307855799794197 and parameters: {'n_d': 48, 'n_a': 124, 'n_steps': 4, 'n_shared': 6, 'cat_emb_dim': 8, 'n_independent': 4, 'gamma': 1.9717820827209607, 'momentum': 0.3853544450274234, 'lr': 0.12596596968309953, 'mask_type': 'entmax', 'lambda_sparse': 0.000587186348818331, 'patience': 5, 'max_epochs': 125, 'virtual_batch_size': 1024, 'batch_size': 2048}. Best is trial 0 with value: 0.32925945520401.

Fold 1:

Early stopping occurred at epoch 32 with best_epoch = 27 and best_valid_rmse = 0.39447999000549316

Fold 2:

Early stopping occurred at epoch 11 with best_epoch = 6 and best_valid_rmse = 0.4113999903202057

Fold 3:

Early stopping occurred at epoch 19 with best_epoch = 14 and best_valid_rmse = 0.40380001068115234
[I 2025-01-17 17:46:39,620] Trial 9 pruned.
Fold 1:
Stop training because you reached max_epochs = 95 with best_epoch = 86 and best_valid_rmse = 0.3455600142478943

Fold 2:
Stop training because you reached max_epochs = 95 with best_epoch = 93 and best_valid_rmse = 0.33882999420166016

Fold 3:

Early stopping occurred at epoch 28 with best_epoch = 18 and best_valid_rmse = 0.35934001207351685
[I 2025-01-17 18:12:11,223] Trial 10 pruned.
Fold 1:

Early stopping occurred at epoch 47 with best_epoch = 37 and best_valid_rmse = 0.31428998708724976

Fold 2:

Early stopping occurred at epoch 46 with best_epoch = 36 and best_valid_rmse = 0.31384000182151794

Fold 3:

Early stopping occurred at epoch 52 with best_epoch = 42 and best_valid_rmse = 0.3099200129508972

Fold 4:

Early stopping occurred at epoch 51 with best_epoch = 41 and best_valid_rmse = 0.3115200102329254
[I 2025-01-17 18:19:44,308] Trial 11 finished with value: 0.31239335983991623 and parameters: {'n_d': 124, 'n_a': 100, 'n_steps': 1, 'n_shared': 7, 'cat_emb_dim': 7, 'n_independent': 4, 'gamma': 1.9553588069720769, 'momentum': 0.19279503536625844, 'lr': 0.012787030674309235, 'mask_type': 'entmax', 'lambda_sparse': 0.0004894266844926874, 'patience': 10, 'max_epochs': 130, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 11 with value: 0.31239335983991623.

Fold 1:

Early stopping occurred at epoch 37 with best_epoch = 27 and best_valid_rmse = 0.39107999205589294
[I 2025-01-17 18:21:05,889] Trial 12 pruned.
Fold 1:

Early stopping occurred at epoch 65 with best_epoch = 50 and best_valid_rmse = 0.3126699924468994

Fold 2:

Early stopping occurred at epoch 65 with best_epoch = 50 and best_valid_rmse = 0.3119400143623352

Fold 3:

Early stopping occurred at epoch 55 with best_epoch = 40 and best_valid_rmse = 0.3091199994087219

Fold 4:

Early stopping occurred at epoch 72 with best_epoch = 57 and best_valid_rmse = 0.31066998839378357
[I 2025-01-17 18:30:14,392] Trial 13 finished with value: 0.31110139936208725 and parameters: {'n_d': 124, 'n_a': 96, 'n_steps': 1, 'n_shared': 8, 'cat_emb_dim': 6, 'n_independent': 3, 'gamma': 1.5056061581384277, 'momentum': 0.17311165394985586, 'lr': 0.012061738656777975, 'mask_type': 'entmax', 'lambda_sparse': 0.00011387937858278137, 'patience': 15, 'max_epochs': 200, 'virtual_batch_size': 128, 'batch_size': 1024}. Best is trial 13 with value: 0.31110139936208725.

Fold 1:

Early stopping occurred at epoch 70 with best_epoch = 55 and best_valid_rmse = 0.31435999274253845

Fold 2:

Early stopping occurred at epoch 52 with best_epoch = 37 and best_valid_rmse = 0.3133600056171417

Fold 3:

Early stopping occurred at epoch 48 with best_epoch = 33 and best_valid_rmse = 0.30935001373291016

Fold 4:

Early stopping occurred at epoch 48 with best_epoch = 33 and best_valid_rmse = 0.3156299889087677
[I 2025-01-17 18:39:55,944] Trial 14 finished with value: 0.31317276507616043 and parameters: {'n_d': 128, 'n_a': 96, 'n_steps': 1, 'n_shared': 8, 'cat_emb_dim': 6, 'n_independent': 4, 'gamma': 1.9652313348965311, 'momentum': 0.1944860626629508, 'lr': 0.06752504258116275, 'mask_type': 'entmax', 'lambda_sparse': 0.00012530996305046304, 'patience': 15, 'max_epochs': 200, 'virtual_batch_size': 128, 'batch_size': 640}. Best is trial 13 with value: 0.31110139936208725.

Fold 1:

Early stopping occurred at epoch 39 with best_epoch = 24 and best_valid_rmse = 0.43498000502586365

Fold 2:

Early stopping occurred at epoch 37 with best_epoch = 22 and best_valid_rmse = 0.4336000084877014

Fold 3:

Early stopping occurred at epoch 36 with best_epoch = 21 and best_valid_rmse = 0.4572800099849701
[I 2025-01-17 18:50:45,344] Trial 15 pruned.
Fold 1:

Early stopping occurred at epoch 68 with best_epoch = 53 and best_valid_rmse = 0.3137100040912628

Fold 2:

Early stopping occurred at epoch 47 with best_epoch = 32 and best_valid_rmse = 0.31321999430656433

Fold 3:

Early stopping occurred at epoch 57 with best_epoch = 42 and best_valid_rmse = 0.3121199905872345
[I 2025-01-17 18:57:16,413] Trial 16 pruned.
Fold 1:
Stop training because you reached max_epochs = 60 with best_epoch = 52 and best_valid_rmse = 0.31352999806404114

Fold 2:
Stop training because you reached max_epochs = 60 with best_epoch = 58 and best_valid_rmse = 0.3145599961280823

Fold 3:

Early stopping occurred at epoch 44 with best_epoch = 34 and best_valid_rmse = 0.3134700059890747
[I 2025-01-17 19:06:37,371] Trial 17 pruned.
Fold 1:
Stop training because you reached max_epochs = 70 with best_epoch = 69 and best_valid_rmse = 0.31648001074790955

Fold 2:
Stop training because you reached max_epochs = 70 with best_epoch = 69 and best_valid_rmse = 0.3125300109386444

Fold 3:
Stop training because you reached max_epochs = 70 with best_epoch = 68 and best_valid_rmse = 0.3110800087451935

Fold 4:
Stop training because you reached max_epochs = 70 with best_epoch = 69 and best_valid_rmse = 0.31310001015663147
[I 2025-01-17 19:16:01,806] Trial 18 finished with value: 0.3132946342229843 and parameters: {'n_d': 116, 'n_a': 108, 'n_steps': 2, 'n_shared': 7, 'cat_emb_dim': 7, 'n_independent': 6, 'gamma': 1.6605885540358076, 'momentum': 0.08465761894426632, 'lr': 0.006659096439311225, 'mask_type': 'entmax', 'lambda_sparse': 0.000245762009210572, 'patience': 15, 'max_epochs': 70, 'virtual_batch_size': 256, 'batch_size': 1536}. Best is trial 13 with value: 0.31110139936208725.

Fold 1:

Early stopping occurred at epoch 28 with best_epoch = 18 and best_valid_rmse = 0.4043099880218506

Fold 2:

Early stopping occurred at epoch 32 with best_epoch = 22 and best_valid_rmse = 0.399619996547699

Fold 3:

Early stopping occurred at epoch 13 with best_epoch = 3 and best_valid_rmse = 0.4361000061035156
[I 2025-01-17 19:19:43,305] Trial 19 pruned.
Fold 1:

Early stopping occurred at epoch 25 with best_epoch = 10 and best_valid_rmse = 0.34808000922203064

Fold 2:

Early stopping occurred at epoch 26 with best_epoch = 11 and best_valid_rmse = 0.33204999566078186

Fold 3:

Early stopping occurred at epoch 17 with best_epoch = 2 and best_valid_rmse = 0.37351998686790466
[I 2025-01-17 19:24:12,753] Trial 20 pruned.
Fold 1:

Early stopping occurred at epoch 51 with best_epoch = 36 and best_valid_rmse = 0.3128400146961212

Fold 2:

Early stopping occurred at epoch 47 with best_epoch = 32 and best_valid_rmse = 0.3145500123500824

Fold 3:

Early stopping occurred at epoch 55 with best_epoch = 40 and best_valid_rmse = 0.3109799921512604
[I 2025-01-17 19:31:06,971] Trial 21 pruned.
Fold 1:

Early stopping occurred at epoch 40 with best_epoch = 25 and best_valid_rmse = 0.3174099922180176
[I 2025-01-17 19:32:54,103] Trial 22 pruned.
Fold 1:

Early stopping occurred at epoch 56 with best_epoch = 41 and best_valid_rmse = 0.3144499957561493
[I 2025-01-17 19:36:11,228] Trial 23 pruned.
Fold 1:

Early stopping occurred at epoch 43 with best_epoch = 33 and best_valid_rmse = 0.32545000314712524

Fold 2:

Early stopping occurred at epoch 28 with best_epoch = 18 and best_valid_rmse = 0.32385000586509705

Fold 3:

Early stopping occurred at epoch 42 with best_epoch = 32 and best_valid_rmse = 0.32506000995635986

Fold 4:

Early stopping occurred at epoch 33 with best_epoch = 23 and best_valid_rmse = 0.31951001286506653
[I 2025-01-17 19:44:44,776] Trial 24 finished with value: 0.32346611469984055 and parameters: {'n_d': 108, 'n_a': 76, 'n_steps': 1, 'n_shared': 9, 'cat_emb_dim': 5, 'n_independent': 2, 'gamma': 1.8976976418453342, 'momentum': 0.2476008369087243, 'lr': 0.11304921388017593, 'mask_type': 'entmax', 'lambda_sparse': 0.00020158675139925135, 'patience': 10, 'max_epochs': 115, 'virtual_batch_size': 128, 'batch_size': 384}. Best is trial 13 with value: 0.31110139936208725.

Fold 1:

Early stopping occurred at epoch 52 with best_epoch = 37 and best_valid_rmse = 0.31279000639915466

Fold 2:

Early stopping occurred at epoch 50 with best_epoch = 35 and best_valid_rmse = 0.31033000349998474

Fold 3:

Early stopping occurred at epoch 61 with best_epoch = 46 and best_valid_rmse = 0.30956000089645386

Fold 4:

Early stopping occurred at epoch 53 with best_epoch = 38 and best_valid_rmse = 0.3111400008201599
[I 2025-01-17 19:57:34,057] Trial 25 finished with value: 0.3109540268778801 and parameters: {'n_d': 120, 'n_a': 116, 'n_steps': 3, 'n_shared': 7, 'cat_emb_dim': 4, 'n_independent': 3, 'gamma': 1.9271652865979079, 'momentum': 0.09451106748439808, 'lr': 0.035189636718348266, 'mask_type': 'entmax', 'lambda_sparse': 0.00010399249986642899, 'patience': 15, 'max_epochs': 160, 'virtual_batch_size': 128, 'batch_size': 768}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 45 with best_epoch = 35 and best_valid_rmse = 0.3140299916267395

Fold 2:

Early stopping occurred at epoch 47 with best_epoch = 37 and best_valid_rmse = 0.3128800094127655

Fold 3:

Early stopping occurred at epoch 55 with best_epoch = 45 and best_valid_rmse = 0.3080100119113922

Fold 4:

Early stopping occurred at epoch 50 with best_epoch = 40 and best_valid_rmse = 0.3113499879837036
[I 2025-01-17 20:08:08,995] Trial 26 finished with value: 0.31156813353300095 and parameters: {'n_d': 104, 'n_a': 128, 'n_steps': 3, 'n_shared': 6, 'cat_emb_dim': 4, 'n_independent': 3, 'gamma': 1.409419715595113, 'momentum': 0.07434868324939609, 'lr': 0.033334087780441383, 'mask_type': 'entmax', 'lambda_sparse': 0.0008509087185274581, 'patience': 10, 'max_epochs': 155, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 46 with best_epoch = 31 and best_valid_rmse = 0.33208000659942627
[I 2025-01-17 20:10:16,000] Trial 27 pruned.
Fold 1:

Early stopping occurred at epoch 50 with best_epoch = 40 and best_valid_rmse = 0.31720998883247375
[I 2025-01-17 20:11:49,697] Trial 28 pruned.
Fold 1:

Early stopping occurred at epoch 22 with best_epoch = 17 and best_valid_rmse = 0.35521000623703003
[I 2025-01-17 20:13:05,813] Trial 29 pruned.
Fold 1:

Early stopping occurred at epoch 11 with best_epoch = 1 and best_valid_rmse = 0.8418700098991394

Fold 2:

Early stopping occurred at epoch 11 with best_epoch = 1 and best_valid_rmse = 0.8410199880599976

Fold 3:

Early stopping occurred at epoch 14 with best_epoch = 4 and best_valid_rmse = 0.893310010433197
[I 2025-01-17 20:15:05,781] Trial 30 pruned.
Fold 1:

Early stopping occurred at epoch 44 with best_epoch = 34 and best_valid_rmse = 0.31251999735832214

Fold 2:

Early stopping occurred at epoch 44 with best_epoch = 34 and best_valid_rmse = 0.31273001432418823

Fold 3:

Early stopping occurred at epoch 48 with best_epoch = 38 and best_valid_rmse = 0.30935999751091003

Fold 4:

Early stopping occurred at epoch 41 with best_epoch = 31 and best_valid_rmse = 0.3119100034236908
[I 2025-01-17 20:23:02,757] Trial 31 finished with value: 0.31163232028484344 and parameters: {'n_d': 116, 'n_a': 104, 'n_steps': 2, 'n_shared': 7, 'cat_emb_dim': 5, 'n_independent': 2, 'gamma': 1.5915815793023236, 'momentum': 0.15047253324296359, 'lr': 0.035342220128138174, 'mask_type': 'entmax', 'lambda_sparse': 0.0007783724519171544, 'patience': 10, 'max_epochs': 140, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 41 with best_epoch = 31 and best_valid_rmse = 0.31516000628471375
[I 2025-01-17 20:24:54,859] Trial 32 pruned.
Fold 1:

Early stopping occurred at epoch 42 with best_epoch = 27 and best_valid_rmse = 0.31665998697280884
[I 2025-01-17 20:27:14,675] Trial 33 pruned.
Fold 1:

Early stopping occurred at epoch 23 with best_epoch = 13 and best_valid_rmse = 0.355569988489151
[I 2025-01-17 20:28:18,666] Trial 34 pruned.
Fold 1:

Early stopping occurred at epoch 61 with best_epoch = 56 and best_valid_rmse = 0.3236300051212311

Fold 2:

Early stopping occurred at epoch 56 with best_epoch = 51 and best_valid_rmse = 0.3284299969673157

Fold 3:

Early stopping occurred at epoch 50 with best_epoch = 45 and best_valid_rmse = 0.32576999068260193

Fold 4:

Early stopping occurred at epoch 54 with best_epoch = 49 and best_valid_rmse = 0.3226499855518341
[I 2025-01-17 20:33:30,954] Trial 35 finished with value: 0.32511890679597855 and parameters: {'n_d': 104, 'n_a': 108, 'n_steps': 3, 'n_shared': 7, 'cat_emb_dim': 2, 'n_independent': 3, 'gamma': 1.2780378599728572, 'momentum': 0.08156260546491245, 'lr': 0.13521810025705489, 'mask_type': 'entmax', 'lambda_sparse': 0.00036461475325515636, 'patience': 5, 'max_epochs': 160, 'virtual_batch_size': 1024, 'batch_size': 3072}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:
Stop training because you reached max_epochs = 5 with best_epoch = 3 and best_valid_rmse = 0.38019999861717224
[I 2025-01-17 20:33:50,728] Trial 36 pruned.
Fold 1:

Early stopping occurred at epoch 69 with best_epoch = 59 and best_valid_rmse = 0.32343000173568726
[I 2025-01-17 20:36:20,227] Trial 37 pruned.
Fold 1:

Early stopping occurred at epoch 42 with best_epoch = 22 and best_valid_rmse = 0.38291001319885254
[I 2025-01-17 20:38:14,775] Trial 38 pruned.
Fold 1:

Early stopping occurred at epoch 48 with best_epoch = 33 and best_valid_rmse = 0.32795000076293945

Fold 2:

Early stopping occurred at epoch 43 with best_epoch = 28 and best_valid_rmse = 0.3304400146007538

Fold 3:

Early stopping occurred at epoch 28 with best_epoch = 13 and best_valid_rmse = 0.332720011472702
[I 2025-01-17 20:45:34,915] Trial 39 pruned.
Fold 1:
Stop training because you reached max_epochs = 30 with best_epoch = 29 and best_valid_rmse = 0.3562000095844269

Fold 2:
Stop training because you reached max_epochs = 30 with best_epoch = 29 and best_valid_rmse = 0.3526400029659271

Fold 3:
Stop training because you reached max_epochs = 30 with best_epoch = 29 and best_valid_rmse = 0.3479900062084198
[I 2025-01-17 20:48:49,130] Trial 40 pruned.
Fold 1:

Early stopping occurred at epoch 43 with best_epoch = 33 and best_valid_rmse = 0.3130500018596649

Fold 2:

Early stopping occurred at epoch 40 with best_epoch = 30 and best_valid_rmse = 0.3118399977684021

Fold 3:

Early stopping occurred at epoch 40 with best_epoch = 30 and best_valid_rmse = 0.30928999185562134

Fold 4:

Early stopping occurred at epoch 43 with best_epoch = 33 and best_valid_rmse = 0.311599999666214
[I 2025-01-17 20:55:15,040] Trial 41 finished with value: 0.3114461377263069 and parameters: {'n_d': 120, 'n_a': 100, 'n_steps': 1, 'n_shared': 7, 'cat_emb_dim': 7, 'n_independent': 4, 'gamma': 1.9240960238971727, 'momentum': 0.16946023978478975, 'lr': 0.020577101090279867, 'mask_type': 'entmax', 'lambda_sparse': 0.00048577257010938574, 'patience': 10, 'max_epochs': 135, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:
Stop training because you reached max_epochs = 140 with best_epoch = 138 and best_valid_rmse = 0.3411499857902527
[I 2025-01-17 21:01:39,757] Trial 42 pruned.
Fold 1:

Early stopping occurred at epoch 50 with best_epoch = 40 and best_valid_rmse = 0.31332001090049744

Fold 2:

Early stopping occurred at epoch 45 with best_epoch = 35 and best_valid_rmse = 0.31498000025749207

Fold 3:

Early stopping occurred at epoch 47 with best_epoch = 37 and best_valid_rmse = 0.3113200068473816
[I 2025-01-17 21:07:11,123] Trial 43 pruned.
Fold 1:

Early stopping occurred at epoch 64 with best_epoch = 54 and best_valid_rmse = 0.3145799934864044

Fold 2:

Early stopping occurred at epoch 52 with best_epoch = 42 and best_valid_rmse = 0.31398001313209534

Fold 3:

Early stopping occurred at epoch 57 with best_epoch = 47 and best_valid_rmse = 0.3110400140285492

Fold 4:

Early stopping occurred at epoch 52 with best_epoch = 42 and best_valid_rmse = 0.31259000301361084
[I 2025-01-17 21:18:17,807] Trial 44 finished with value: 0.31304652988910675 and parameters: {'n_d': 36, 'n_a': 112, 'n_steps': 2, 'n_shared': 7, 'cat_emb_dim': 4, 'n_independent': 3, 'gamma': 1.7679192099229608, 'momentum': 0.06963568174140487, 'lr': 0.024193295714677238, 'mask_type': 'entmax', 'lambda_sparse': 0.000430858318892026, 'patience': 10, 'max_epochs': 85, 'virtual_batch_size': 128, 'batch_size': 768}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 68 with best_epoch = 58 and best_valid_rmse = 0.3139899969100952

Fold 2:

Early stopping occurred at epoch 54 with best_epoch = 44 and best_valid_rmse = 0.31345999240875244

Fold 3:

Early stopping occurred at epoch 60 with best_epoch = 50 and best_valid_rmse = 0.31108999252319336

Fold 4:

Early stopping occurred at epoch 59 with best_epoch = 49 and best_valid_rmse = 0.31213998794555664
[I 2025-01-17 21:29:47,389] Trial 45 finished with value: 0.31266724318265915 and parameters: {'n_d': 124, 'n_a': 68, 'n_steps': 3, 'n_shared': 5, 'cat_emb_dim': 6, 'n_independent': 2, 'gamma': 1.5402837588519407, 'momentum': 0.03730794393685083, 'lr': 0.019737519282747036, 'mask_type': 'sparsemax', 'lambda_sparse': 0.0005779526562777366, 'patience': 10, 'max_epochs': 120, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 67 with best_epoch = 47 and best_valid_rmse = 0.3152799904346466
[I 2025-01-17 21:32:14,222] Trial 46 pruned.
Fold 1:

Early stopping occurred at epoch 71 with best_epoch = 66 and best_valid_rmse = 0.3212999999523163

Fold 2:

Early stopping occurred at epoch 74 with best_epoch = 69 and best_valid_rmse = 0.31999000906944275

Fold 3:

Early stopping occurred at epoch 85 with best_epoch = 80 and best_valid_rmse = 0.31657999753952026

Fold 4:

Early stopping occurred at epoch 84 with best_epoch = 79 and best_valid_rmse = 0.3186199963092804
[I 2025-01-17 21:39:23,031] Trial 47 finished with value: 0.31911955028772354 and parameters: {'n_d': 128, 'n_a': 24, 'n_steps': 4, 'n_shared': 6, 'cat_emb_dim': 5, 'n_independent': 4, 'gamma': 1.941036977536374, 'momentum': 0.230892479281657, 'lr': 0.08879287406163684, 'mask_type': 'entmax', 'lambda_sparse': 0.0016479132421608094, 'patience': 5, 'max_epochs': 170, 'virtual_batch_size': 1024, 'batch_size': 4096}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 101 with best_epoch = 86 and best_valid_rmse = 0.33847999572753906
[I 2025-01-17 21:42:04,143] Trial 48 pruned.
Fold 1:

Early stopping occurred at epoch 42 with best_epoch = 32 and best_valid_rmse = 0.3332099914550781

Fold 2:

Early stopping occurred at epoch 42 with best_epoch = 32 and best_valid_rmse = 0.3266899883747101

Fold 3:

Early stopping occurred at epoch 56 with best_epoch = 46 and best_valid_rmse = 0.3203299939632416

Fold 4:

Early stopping occurred at epoch 47 with best_epoch = 37 and best_valid_rmse = 0.3299599885940552
[I 2025-01-17 21:48:32,764] Trial 49 finished with value: 0.32754722982645035 and parameters: {'n_d': 120, 'n_a': 100, 'n_steps': 1, 'n_shared': 9, 'cat_emb_dim': 6, 'n_independent': 9, 'gamma': 1.309301195203155, 'momentum': 0.13294222151129065, 'lr': 0.13196293767589895, 'mask_type': 'sparsemax', 'lambda_sparse': 0.0003361826565068727, 'patience': 10, 'max_epochs': 135, 'virtual_batch_size': 256, 'batch_size': 1280}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 62 with best_epoch = 47 and best_valid_rmse = 0.3986699879169464
[I 2025-01-17 21:52:59,741] Trial 50 pruned.
Fold 1:

Early stopping occurred at epoch 47 with best_epoch = 37 and best_valid_rmse = 0.3134700059890747

Fold 2:

Early stopping occurred at epoch 35 with best_epoch = 25 and best_valid_rmse = 0.31376999616622925

Fold 3:

Early stopping occurred at epoch 53 with best_epoch = 43 and best_valid_rmse = 0.3082500100135803

Fold 4:

Early stopping occurred at epoch 45 with best_epoch = 35 and best_valid_rmse = 0.3119800090789795
[I 2025-01-17 21:59:53,471] Trial 51 finished with value: 0.3118675649166107 and parameters: {'n_d': 124, 'n_a': 100, 'n_steps': 1, 'n_shared': 7, 'cat_emb_dim': 7, 'n_independent': 4, 'gamma': 1.8525307743638406, 'momentum': 0.21821754110968022, 'lr': 0.014746162800902695, 'mask_type': 'entmax', 'lambda_sparse': 0.00046690348627289414, 'patience': 10, 'max_epochs': 125, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 53 with best_epoch = 43 and best_valid_rmse = 0.3140000104904175

Fold 2:

Early stopping occurred at epoch 55 with best_epoch = 45 and best_valid_rmse = 0.3132599890232086

Fold 3:

Early stopping occurred at epoch 39 with best_epoch = 29 and best_valid_rmse = 0.3112899959087372
[I 2025-01-17 22:05:21,966] Trial 52 pruned.
Fold 1:

Early stopping occurred at epoch 53 with best_epoch = 43 and best_valid_rmse = 0.31529000401496887
[I 2025-01-17 22:07:53,084] Trial 53 pruned.
Fold 1:

Early stopping occurred at epoch 46 with best_epoch = 36 and best_valid_rmse = 0.3131200075149536

Fold 2:

Early stopping occurred at epoch 40 with best_epoch = 30 and best_valid_rmse = 0.3121100068092346

Fold 3:

Early stopping occurred at epoch 52 with best_epoch = 42 and best_valid_rmse = 0.3100399971008301
[I 2025-01-17 22:13:16,823] Trial 54 pruned.
Fold 1:

Early stopping occurred at epoch 50 with best_epoch = 40 and best_valid_rmse = 0.3143500089645386
[I 2025-01-17 22:15:58,843] Trial 55 pruned.
Fold 1:

Early stopping occurred at epoch 58 with best_epoch = 43 and best_valid_rmse = 0.330269992351532
[I 2025-01-17 22:18:15,848] Trial 56 pruned.
Fold 1:
Stop training because you reached max_epochs = 145 with best_epoch = 141 and best_valid_rmse = 0.31850001215934753
[I 2025-01-17 22:27:05,014] Trial 57 pruned.
Fold 1:

Early stopping occurred at epoch 62 with best_epoch = 47 and best_valid_rmse = 0.31398001313209534

Fold 2:

Early stopping occurred at epoch 62 with best_epoch = 47 and best_valid_rmse = 0.3131900131702423

Fold 3:

Early stopping occurred at epoch 66 with best_epoch = 51 and best_valid_rmse = 0.311710000038147

Fold 4:

Early stopping occurred at epoch 47 with best_epoch = 32 and best_valid_rmse = 0.31244000792503357
[I 2025-01-17 22:35:11,627] Trial 58 finished with value: 0.31283043324947357 and parameters: {'n_d': 116, 'n_a': 40, 'n_steps': 1, 'n_shared': 7, 'cat_emb_dim': 9, 'n_independent': 4, 'gamma': 1.935139994638237, 'momentum': 0.10067692553062513, 'lr': 0.06557051712247991, 'mask_type': 'entmax', 'lambda_sparse': 0.0008801655911436464, 'patience': 15, 'max_epochs': 190, 'virtual_batch_size': 128, 'batch_size': 1152}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 25 with best_epoch = 20 and best_valid_rmse = 0.32172998785972595
[I 2025-01-17 22:36:47,692] Trial 59 pruned.
Fold 1:

Early stopping occurred at epoch 41 with best_epoch = 31 and best_valid_rmse = 0.3716300129890442

Fold 2:

Early stopping occurred at epoch 16 with best_epoch = 6 and best_valid_rmse = 0.37599000334739685

Fold 3:

Early stopping occurred at epoch 28 with best_epoch = 18 and best_valid_rmse = 0.3747499883174896
[I 2025-01-17 22:40:56,769] Trial 60 pruned.
Fold 1:

Early stopping occurred at epoch 36 with best_epoch = 26 and best_valid_rmse = 0.3140699863433838

Fold 2:

Early stopping occurred at epoch 50 with best_epoch = 40 and best_valid_rmse = 0.31248998641967773

Fold 3:

Early stopping occurred at epoch 37 with best_epoch = 27 and best_valid_rmse = 0.3093299865722656

Fold 4:

Early stopping occurred at epoch 39 with best_epoch = 29 and best_valid_rmse = 0.31025999784469604
[I 2025-01-17 22:47:13,185] Trial 61 finished with value: 0.3115405961871147 and parameters: {'n_d': 128, 'n_a': 100, 'n_steps': 1, 'n_shared': 7, 'cat_emb_dim': 7, 'n_independent': 3, 'gamma': 1.9636363028093444, 'momentum': 0.19045317989556418, 'lr': 0.048233376920610235, 'mask_type': 'entmax', 'lambda_sparse': 0.0005326671020374001, 'patience': 10, 'max_epochs': 135, 'virtual_batch_size': 128, 'batch_size': 768}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 32 with best_epoch = 22 and best_valid_rmse = 0.3143500089645386

Fold 2:

Early stopping occurred at epoch 37 with best_epoch = 27 and best_valid_rmse = 0.3116399943828583

Fold 3:

Early stopping occurred at epoch 41 with best_epoch = 31 and best_valid_rmse = 0.31025999784469604
[I 2025-01-17 22:51:56,778] Trial 62 pruned.
Fold 1:

Early stopping occurred at epoch 54 with best_epoch = 44 and best_valid_rmse = 0.31527000665664673
[I 2025-01-17 22:54:06,183] Trial 63 pruned.
Fold 1:

Early stopping occurred at epoch 56 with best_epoch = 46 and best_valid_rmse = 0.32611000537872314

Fold 2:

Early stopping occurred at epoch 67 with best_epoch = 57 and best_valid_rmse = 0.3273000121116638

Fold 3:

Early stopping occurred at epoch 40 with best_epoch = 30 and best_valid_rmse = 0.3512200117111206
[I 2025-01-17 23:16:45,852] Trial 64 pruned.
Fold 1:

Early stopping occurred at epoch 93 with best_epoch = 78 and best_valid_rmse = 0.31240999698638916

Fold 2:

Early stopping occurred at epoch 95 with best_epoch = 80 and best_valid_rmse = 0.3112100064754486

Fold 3:

Early stopping occurred at epoch 95 with best_epoch = 80 and best_valid_rmse = 0.3099299967288971

Fold 4:

Early stopping occurred at epoch 115 with best_epoch = 100 and best_valid_rmse = 0.31108999252319336
[I 2025-01-17 23:25:43,740] Trial 65 finished with value: 0.31116003543138504 and parameters: {'n_d': 112, 'n_a': 8, 'n_steps': 2, 'n_shared': 7, 'cat_emb_dim': 8, 'n_independent': 1, 'gamma': 1.9648114107373051, 'momentum': 0.07679069890651134, 'lr': 0.03336528165981404, 'mask_type': 'entmax', 'lambda_sparse': 0.00026801848779804216, 'patience': 15, 'max_epochs': 145, 'virtual_batch_size': 512, 'batch_size': 2560}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 78 with best_epoch = 63 and best_valid_rmse = 0.3145900070667267

Fold 2:

Early stopping occurred at epoch 90 with best_epoch = 75 and best_valid_rmse = 0.3114500045776367

Fold 3:

Early stopping occurred at epoch 83 with best_epoch = 68 and best_valid_rmse = 0.3105100095272064

Fold 4:

Early stopping occurred at epoch 89 with best_epoch = 74 and best_valid_rmse = 0.31084001064300537
[I 2025-01-17 23:34:24,346] Trial 66 finished with value: 0.31184808909893036 and parameters: {'n_d': 104, 'n_a': 32, 'n_steps': 3, 'n_shared': 9, 'cat_emb_dim': 8, 'n_independent': 1, 'gamma': 1.8034330812334225, 'momentum': 0.0776020301563344, 'lr': 0.048180525050358886, 'mask_type': 'entmax', 'lambda_sparse': 0.00011236069749271203, 'patience': 15, 'max_epochs': 145, 'virtual_batch_size': 512, 'batch_size': 2560}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 79 with best_epoch = 64 and best_valid_rmse = 0.3165999948978424

Fold 2:

Early stopping occurred at epoch 97 with best_epoch = 82 and best_valid_rmse = 0.31536000967025757

Fold 3:

Early stopping occurred at epoch 87 with best_epoch = 72 and best_valid_rmse = 0.31066998839378357

Fold 4:

Early stopping occurred at epoch 83 with best_epoch = 68 and best_valid_rmse = 0.31363001465797424
[I 2025-01-17 23:42:03,676] Trial 67 finished with value: 0.3140674829483032 and parameters: {'n_d': 112, 'n_a': 56, 'n_steps': 2, 'n_shared': 6, 'cat_emb_dim': 9, 'n_independent': 1, 'gamma': 1.9608789381222895, 'momentum': 0.042621156143553364, 'lr': 0.03293307513000275, 'mask_type': 'sparsemax', 'lambda_sparse': 0.00013727645308372836, 'patience': 15, 'max_epochs': 160, 'virtual_batch_size': 512, 'batch_size': 2560}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 76 with best_epoch = 61 and best_valid_rmse = 0.40838998556137085
[I 2025-01-17 23:44:18,296] Trial 68 pruned.
Fold 1:

Early stopping occurred at epoch 61 with best_epoch = 46 and best_valid_rmse = 0.3183000087738037

Fold 2:

Early stopping occurred at epoch 75 with best_epoch = 60 and best_valid_rmse = 0.31606000661849976

Fold 3:

Early stopping occurred at epoch 99 with best_epoch = 84 and best_valid_rmse = 0.3153899908065796
[I 2025-01-17 23:49:56,274] Trial 69 pruned.
Fold 1:

Early stopping occurred at epoch 84 with best_epoch = 69 and best_valid_rmse = 0.31731998920440674
[I 2025-01-17 23:51:32,903] Trial 70 pruned.
Fold 1:

Early stopping occurred at epoch 119 with best_epoch = 104 and best_valid_rmse = 0.31485000252723694

Fold 2:

Early stopping occurred at epoch 89 with best_epoch = 74 and best_valid_rmse = 0.3123300075531006

Fold 3:

Early stopping occurred at epoch 98 with best_epoch = 83 and best_valid_rmse = 0.31226998567581177
[I 2025-01-17 23:59:28,461] Trial 71 pruned.
Fold 1:

Early stopping occurred at epoch 119 with best_epoch = 99 and best_valid_rmse = 0.31283000111579895

Fold 2:

Early stopping occurred at epoch 110 with best_epoch = 90 and best_valid_rmse = 0.31185999512672424

Fold 3:

Early stopping occurred at epoch 116 with best_epoch = 96 and best_valid_rmse = 0.30906999111175537

Fold 4:

Early stopping occurred at epoch 125 with best_epoch = 105 and best_valid_rmse = 0.31220000982284546
[I 2025-01-18 00:11:48,368] Trial 72 finished with value: 0.3114901930093765 and parameters: {'n_d': 96, 'n_a': 28, 'n_steps': 3, 'n_shared': 10, 'cat_emb_dim': 8, 'n_independent': 1, 'gamma': 1.8796409103591705, 'momentum': 0.09222308586682235, 'lr': 0.03228376745929876, 'mask_type': 'entmax', 'lambda_sparse': 0.00011687103826460015, 'patience': 20, 'max_epochs': 140, 'virtual_batch_size': 512, 'batch_size': 2560}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 83 with best_epoch = 63 and best_valid_rmse = 0.3130599856376648

Fold 2:

Early stopping occurred at epoch 104 with best_epoch = 84 and best_valid_rmse = 0.3136099874973297

Fold 3:

Early stopping occurred at epoch 93 with best_epoch = 73 and best_valid_rmse = 0.3093700110912323
[I 2025-01-18 00:18:43,387] Trial 73 pruned.
Fold 1:

Early stopping occurred at epoch 115 with best_epoch = 95 and best_valid_rmse = 0.3373199999332428
[I 2025-01-18 00:21:57,639] Trial 74 pruned.
Fold 1:
Stop training because you reached max_epochs = 110 with best_epoch = 108 and best_valid_rmse = 0.31345000863075256

Fold 2:
Stop training because you reached max_epochs = 110 with best_epoch = 105 and best_valid_rmse = 0.31240999698638916

Fold 3:
Stop training because you reached max_epochs = 110 with best_epoch = 107 and best_valid_rmse = 0.30952998995780945

Fold 4:
Stop training because you reached max_epochs = 110 with best_epoch = 97 and best_valid_rmse = 0.31174999475479126
[I 2025-01-18 00:34:08,765] Trial 75 finished with value: 0.31178610771894455 and parameters: {'n_d': 116, 'n_a': 32, 'n_steps': 3, 'n_shared': 7, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.4411982090558846, 'momentum': 0.0717638209077907, 'lr': 0.011444198315038547, 'mask_type': 'entmax', 'lambda_sparse': 0.00016145142343868205, 'patience': 20, 'max_epochs': 110, 'virtual_batch_size': 512, 'batch_size': 2048}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 62 with best_epoch = 42 and best_valid_rmse = 0.3137899935245514

Fold 2:

Early stopping occurred at epoch 59 with best_epoch = 39 and best_valid_rmse = 0.31453999876976013

Fold 3:

Early stopping occurred at epoch 70 with best_epoch = 50 and best_valid_rmse = 0.31141000986099243
[I 2025-01-18 00:39:54,045] Trial 76 pruned.
Fold 1:

Early stopping occurred at epoch 125 with best_epoch = 115 and best_valid_rmse = 0.32166001200675964

Fold 2:

Early stopping occurred at epoch 136 with best_epoch = 126 and best_valid_rmse = 0.3203200101852417

Fold 3:

Early stopping occurred at epoch 131 with best_epoch = 121 and best_valid_rmse = 0.3172599971294403
[I 2025-01-18 01:07:08,330] Trial 77 pruned.
Fold 1:

Early stopping occurred at epoch 92 with best_epoch = 82 and best_valid_rmse = 0.3118399977684021

Fold 2:

Early stopping occurred at epoch 71 with best_epoch = 61 and best_valid_rmse = 0.31338000297546387

Fold 3:

Early stopping occurred at epoch 70 with best_epoch = 60 and best_valid_rmse = 0.31007999181747437
[I 2025-01-18 01:13:25,218] Trial 78 pruned.
Fold 1:

Early stopping occurred at epoch 39 with best_epoch = 19 and best_valid_rmse = 0.3304100036621094
[I 2025-01-18 01:15:58,725] Trial 79 pruned.
Fold 1:

Early stopping occurred at epoch 88 with best_epoch = 73 and best_valid_rmse = 0.3384000062942505

Fold 2:
Stop training because you reached max_epochs = 130 with best_epoch = 127 and best_valid_rmse = 0.33480000495910645

Fold 3:
Stop training because you reached max_epochs = 130 with best_epoch = 124 and best_valid_rmse = 0.3203200101852417
[I 2025-01-18 01:32:29,997] Trial 80 pruned.
Fold 1:
Stop training because you reached max_epochs = 110 with best_epoch = 95 and best_valid_rmse = 0.31310999393463135

Fold 2:

Early stopping occurred at epoch 100 with best_epoch = 80 and best_valid_rmse = 0.3137199878692627

Fold 3:

Early stopping occurred at epoch 100 with best_epoch = 80 and best_valid_rmse = 0.3093700110912323

Fold 4:
Stop training because you reached max_epochs = 110 with best_epoch = 101 and best_valid_rmse = 0.31049999594688416
[I 2025-01-18 01:44:11,398] Trial 81 finished with value: 0.31167732179164886 and parameters: {'n_d': 116, 'n_a': 40, 'n_steps': 3, 'n_shared': 7, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.3567480945652926, 'momentum': 0.0678070173061848, 'lr': 0.013494702185876478, 'mask_type': 'entmax', 'lambda_sparse': 0.00016488909321577347, 'patience': 20, 'max_epochs': 110, 'virtual_batch_size': 512, 'batch_size': 2048}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:
Stop training because you reached max_epochs = 110 with best_epoch = 94 and best_valid_rmse = 0.3126800060272217

Fold 2:
Stop training because you reached max_epochs = 110 with best_epoch = 91 and best_valid_rmse = 0.313289999961853

Fold 3:
Stop training because you reached max_epochs = 110 with best_epoch = 95 and best_valid_rmse = 0.3089999854564667

Fold 4:
Stop training because you reached max_epochs = 110 with best_epoch = 107 and best_valid_rmse = 0.3100599944591522
[I 2025-01-18 01:57:12,462] Trial 82 finished with value: 0.31126031279563904 and parameters: {'n_d': 128, 'n_a': 20, 'n_steps': 4, 'n_shared': 6, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.3550977524646246, 'momentum': 0.066936665517918, 'lr': 0.01435850123900418, 'mask_type': 'entmax', 'lambda_sparse': 0.00015977458694184366, 'patience': 20, 'max_epochs': 110, 'virtual_batch_size': 512, 'batch_size': 2048}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 74 with best_epoch = 54 and best_valid_rmse = 0.31373998522758484

Fold 2:

Early stopping occurred at epoch 74 with best_epoch = 54 and best_valid_rmse = 0.3122499883174896

Fold 3:

Early stopping occurred at epoch 100 with best_epoch = 80 and best_valid_rmse = 0.31220000982284546
[I 2025-01-18 02:05:43,643] Trial 83 pruned.
Fold 1:

Early stopping occurred at epoch 116 with best_epoch = 96 and best_valid_rmse = 0.3146499991416931
[I 2025-01-18 02:08:41,125] Trial 84 pruned.
Fold 1:

Early stopping occurred at epoch 78 with best_epoch = 58 and best_valid_rmse = 0.31422001123428345

Fold 2:

Early stopping occurred at epoch 90 with best_epoch = 70 and best_valid_rmse = 0.3112500011920929

Fold 3:

Early stopping occurred at epoch 97 with best_epoch = 77 and best_valid_rmse = 0.31029999256134033

Fold 4:

Early stopping occurred at epoch 79 with best_epoch = 59 and best_valid_rmse = 0.3112199902534485
[I 2025-01-18 02:21:11,169] Trial 85 finished with value: 0.3117472901940346 and parameters: {'n_d': 128, 'n_a': 16, 'n_steps': 5, 'n_shared': 6, 'cat_emb_dim': 8, 'n_independent': 2, 'gamma': 1.2761938652753675, 'momentum': 0.19504733573456667, 'lr': 0.035376434599307574, 'mask_type': 'entmax', 'lambda_sparse': 0.00018682731474363494, 'patience': 20, 'max_epochs': 105, 'virtual_batch_size': 512, 'batch_size': 1536}. Best is trial 25 with value: 0.3109540268778801.

Fold 1:

Early stopping occurred at epoch 69 with best_epoch = 49 and best_valid_rmse = 0.31154999136924744

Fold 2:

Early stopping occurred at epoch 64 with best_epoch = 44 and best_valid_rmse = 0.3122200071811676

Fold 3:

Early stopping occurred at epoch 66 with best_epoch = 46 and best_valid_rmse = 0.30948999524116516
[I 2025-01-18 02:31:48,521] Trial 86 pruned.
Fold 1:

Early stopping occurred at epoch 92 with best_epoch = 82 and best_valid_rmse = 0.3155600130558014

Fold 2:

Early stopping occurred at epoch 122 with best_epoch = 112 and best_valid_rmse = 0.31262001395225525

Fold 3:

Early stopping occurred at epoch 132 with best_epoch = 122 and best_valid_rmse = 0.3112500011920929
[I 2025-01-18 02:40:09,390] Trial 87 pruned.
Fold 1:

Early stopping occurred at epoch 58 with best_epoch = 43 and best_valid_rmse = 0.31720998883247375

Fold 2:

Early stopping occurred at epoch 46 with best_epoch = 31 and best_valid_rmse = 0.31317999958992004

Fold 3:

Early stopping occurred at epoch 41 with best_epoch = 26 and best_valid_rmse = 0.31112000346183777
[I 2025-01-18 02:53:31,562] Trial 88 pruned.
Fold 1:

Early stopping occurred at epoch 19 with best_epoch = 9 and best_valid_rmse = 0.40171000361442566
[I 2025-01-18 02:54:20,846] Trial 89 pruned.
Fold 1:

Early stopping occurred at epoch 52 with best_epoch = 37 and best_valid_rmse = 0.3162800073623657
[I 2025-01-18 02:57:23,520] Trial 90 pruned.
Fold 1:

Early stopping occurred at epoch 103 with best_epoch = 83 and best_valid_rmse = 0.3136399984359741

Fold 2:
Stop training because you reached max_epochs = 130 with best_epoch = 120 and best_valid_rmse = 0.313289999961853

Fold 3:

Early stopping occurred at epoch 118 with best_epoch = 98 and best_valid_rmse = 0.30979999899864197
[I 2025-01-18 03:07:02,247] Trial 91 pruned.
Fold 1:
Stop training because you reached max_epochs = 110 with best_epoch = 104 and best_valid_rmse = 0.32089000940322876

Fold 2:
Stop training because you reached max_epochs = 110 with best_epoch = 105 and best_valid_rmse = 0.31782999634742737

Fold 3:

Early stopping occurred at epoch 103 with best_epoch = 83 and best_valid_rmse = 0.31937000155448914
[I 2025-01-18 03:17:38,157] Trial 92 pruned.
Fold 1:
Stop training because you reached max_epochs = 80 with best_epoch = 75 and best_valid_rmse = 0.3175800144672394
[I 2025-01-18 03:19:51,238] Trial 93 pruned.
Fold 1:
Stop training because you reached max_epochs = 60 with best_epoch = 56 and best_valid_rmse = 0.31459999084472656
[I 2025-01-18 03:21:22,757] Trial 94 pruned.
Fold 1:

Early stopping occurred at epoch 62 with best_epoch = 42 and best_valid_rmse = 0.34485000371932983
[I 2025-01-18 03:23:13,679] Trial 95 pruned.
Fold 1:

Early stopping occurred at epoch 58 with best_epoch = 48 and best_valid_rmse = 0.3108699917793274

Fold 2:

Early stopping occurred at epoch 53 with best_epoch = 43 and best_valid_rmse = 0.31205999851226807

Fold 3:

Early stopping occurred at epoch 42 with best_epoch = 32 and best_valid_rmse = 0.30877000093460083

Fold 4:

Early stopping occurred at epoch 55 with best_epoch = 45 and best_valid_rmse = 0.3098300099372864
[I 2025-01-18 03:32:07,879] Trial 96 finished with value: 0.3103809058666229 and parameters: {'n_d': 112, 'n_a': 48, 'n_steps': 2, 'n_shared': 7, 'cat_emb_dim': 8, 'n_independent': 1, 'gamma': 1.4185666171617515, 'momentum': 0.11032455256395587, 'lr': 0.026111585839182588, 'mask_type': 'entmax', 'lambda_sparse': 0.00015379040651770232, 'patience': 10, 'max_epochs': 100, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 96 with value: 0.3103809058666229.

Fold 1:
Stop training because you reached max_epochs = 100 with best_epoch = 98 and best_valid_rmse = 0.3487800061702728

Fold 2:

Early stopping occurred at epoch 53 with best_epoch = 43 and best_valid_rmse = 0.3835799992084503

Fold 3:
Stop training because you reached max_epochs = 100 with best_epoch = 98 and best_valid_rmse = 0.34738001227378845
[I 2025-01-18 03:40:22,976] Trial 97 pruned.
Fold 1:

Early stopping occurred at epoch 43 with best_epoch = 33 and best_valid_rmse = 0.31358999013900757

Fold 2:

Early stopping occurred at epoch 43 with best_epoch = 33 and best_valid_rmse = 0.312389999628067

Fold 3:

Early stopping occurred at epoch 43 with best_epoch = 33 and best_valid_rmse = 0.30889999866485596

Fold 4:

Early stopping occurred at epoch 50 with best_epoch = 40 and best_valid_rmse = 0.3086499869823456
[I 2025-01-18 03:47:58,961] Trial 98 finished with value: 0.31088288873434067 and parameters: {'n_d': 108, 'n_a': 124, 'n_steps': 2, 'n_shared': 7, 'cat_emb_dim': 6, 'n_independent': 1, 'gamma': 1.4547467377738115, 'momentum': 0.10349272108371528, 'lr': 0.035609134197160254, 'mask_type': 'entmax', 'lambda_sparse': 0.00039352125290443007, 'patience': 10, 'max_epochs': 90, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 96 with value: 0.3103809058666229.

Fold 1:

Early stopping occurred at epoch 41 with best_epoch = 31 and best_valid_rmse = 0.3131900131702423

Fold 2:

Early stopping occurred at epoch 41 with best_epoch = 31 and best_valid_rmse = 0.31205999851226807

Fold 3:

Early stopping occurred at epoch 46 with best_epoch = 36 and best_valid_rmse = 0.30917999148368835

Fold 4:

Early stopping occurred at epoch 42 with best_epoch = 32 and best_valid_rmse = 0.31165000796318054
[I 2025-01-18 03:54:19,944] Trial 99 finished with value: 0.3115203082561493 and parameters: {'n_d': 108, 'n_a': 124, 'n_steps': 1, 'n_shared': 9, 'cat_emb_dim': 6, 'n_independent': 1, 'gamma': 1.4442864287877855, 'momentum': 0.10287916032269656, 'lr': 0.030974971700964398, 'mask_type': 'entmax', 'lambda_sparse': 0.0003862573441760338, 'patience': 10, 'max_epochs': 190, 'virtual_batch_size': 128, 'batch_size': 896}. Best is trial 96 with value: 0.3103809058666229.


Successfully saved best hyperparameters

Best hyperparameters: {'n_d': 112, 'n_a': 48, 'n_steps': 2, 'n_shared': 7, 'cat_emb_dim': 8, 'n_independent': 1, 'gamma': 1.4185666171617515, 'momentum': 0.11032455256395587, 'mask_type': 'entmax', 'lambda_sparse': 0.00015379040651770232, 'patience': 10, 'max_epochs': 100, 'virtual_batch_size': 128, 'batch_size': 896, 'optimizer_params': {'lr': 0.026111585839182588}}
[18]:
{'n_d': 112,
 'n_a': 48,
 'n_steps': 2,
 'n_shared': 7,
 'cat_emb_dim': 8,
 'n_independent': 1,
 'gamma': 1.4185666171617515,
 'momentum': 0.11032455256395587,
 'mask_type': 'entmax',
 'lambda_sparse': 0.00015379040651770232,
 'patience': 10,
 'max_epochs': 100,
 'virtual_batch_size': 128,
 'batch_size': 896,
 'optimizer_params': {'lr': 0.026111585839182588}}
[19]:
# Train sceve model using mdata_train dataset
scparadise.sceve.train_tuned(mdata_train,
                             path = '',
                             path_tuned='model_scEve_tuning',
                             model_name = 'model_scEve_hp_tuned',
                             rna_modality_name = 'rna',
                             second_modality_name = 'adt',
                             detailed_annotation = 'celltype_l3',
                             eval_metric=['rmse'])
Successfully saved genes names for training model

Successfully saved proteins names for training model

Train dataset contains: 42029 cells, it is 90.0 % of input dataset
Test dataset contains: 4670 cells, it is 10.0 % of input dataset

Accelerator: cuda
Start training with following hyperparameters: {'n_d': 112, 'n_a': 48, 'n_steps': 2, 'n_shared': 7, 'cat_emb_dim': 8, 'n_independent': 1, 'gamma': 1.4185666171617515, 'momentum': 0.11032455256395587, 'mask_type': 'entmax', 'lambda_sparse': 0.00015379040651770232, 'patience': 10, 'max_epochs': 100, 'virtual_batch_size': 128, 'batch_size': 896, 'optimizer_params': {'lr': 0.026111585839182588}, 'device_name': 'cuda'}

epoch 0  | loss: 0.57711 | train_rmse: 0.4005599915981293| valid_rmse: 0.39866000413894653|  0:00:03s
epoch 1  | loss: 0.10358 | train_rmse: 0.38192999362945557| valid_rmse: 0.37996000051498413|  0:00:05s
epoch 2  | loss: 0.0989  | train_rmse: 0.36601001024246216| valid_rmse: 0.3639099895954132|  0:00:08s
epoch 3  | loss: 0.0948  | train_rmse: 0.35738998651504517| valid_rmse: 0.3553900122642517|  0:00:11s
epoch 4  | loss: 0.08984 | train_rmse: 0.34676000475883484| valid_rmse: 0.34492000937461853|  0:00:14s
epoch 5  | loss: 0.0879  | train_rmse: 0.33757999539375305| valid_rmse: 0.33594000339508057|  0:00:17s
epoch 6  | loss: 0.08556 | train_rmse: 0.3377099931240082| valid_rmse: 0.3356899917125702|  0:00:20s
epoch 7  | loss: 0.08369 | train_rmse: 0.33149999380111694| valid_rmse: 0.3300899863243103|  0:00:23s
epoch 8  | loss: 0.08271 | train_rmse: 0.3266499936580658| valid_rmse: 0.3255400061607361|  0:00:26s
epoch 9  | loss: 0.08136 | train_rmse: 0.32592999935150146| valid_rmse: 0.32488998770713806|  0:00:29s
epoch 10 | loss: 0.08008 | train_rmse: 0.32374000549316406| valid_rmse: 0.32346999645233154|  0:00:31s
epoch 11 | loss: 0.07954 | train_rmse: 0.31953999400138855| valid_rmse: 0.3199799954891205|  0:00:34s
epoch 12 | loss: 0.07823 | train_rmse: 0.3188000023365021| valid_rmse: 0.319130003452301|  0:00:37s
epoch 13 | loss: 0.07775 | train_rmse: 0.3175100088119507| valid_rmse: 0.31812000274658203|  0:00:40s
epoch 14 | loss: 0.07734 | train_rmse: 0.31874001026153564| valid_rmse: 0.3187899887561798|  0:00:43s
epoch 15 | loss: 0.07654 | train_rmse: 0.31450000405311584| valid_rmse: 0.3154599964618683|  0:00:46s
epoch 16 | loss: 0.07589 | train_rmse: 0.3181000053882599| valid_rmse: 0.31672000885009766|  0:00:49s
epoch 17 | loss: 0.07578 | train_rmse: 0.313400000333786| valid_rmse: 0.3155499994754791|  0:00:52s
epoch 18 | loss: 0.07528 | train_rmse: 0.3125300109386444| valid_rmse: 0.31463998556137085|  0:00:55s
epoch 19 | loss: 0.0745  | train_rmse: 0.3112500011920929| valid_rmse: 0.3137100040912628|  0:00:58s
epoch 20 | loss: 0.07434 | train_rmse: 0.3099699914455414| valid_rmse: 0.3125799894332886|  0:01:00s
epoch 21 | loss: 0.07379 | train_rmse: 0.30897998809814453| valid_rmse: 0.3125399947166443|  0:01:03s
epoch 22 | loss: 0.07351 | train_rmse: 0.30730998516082764| valid_rmse: 0.31141000986099243|  0:01:06s
epoch 23 | loss: 0.07305 | train_rmse: 0.3070800006389618| valid_rmse: 0.3109799921512604|  0:01:09s
epoch 24 | loss: 0.07325 | train_rmse: 0.3064500093460083| valid_rmse: 0.3099600076675415|  0:01:12s
epoch 25 | loss: 0.07257 | train_rmse: 0.305400013923645| valid_rmse: 0.31057000160217285|  0:01:15s
epoch 26 | loss: 0.07221 | train_rmse: 0.30487000942230225| valid_rmse: 0.30976998805999756|  0:01:18s
epoch 27 | loss: 0.07184 | train_rmse: 0.3073999881744385| valid_rmse: 0.3125799894332886|  0:01:21s
epoch 28 | loss: 0.07193 | train_rmse: 0.3050200045108795| valid_rmse: 0.31119000911712646|  0:01:24s
epoch 29 | loss: 0.07127 | train_rmse: 0.30250999331474304| valid_rmse: 0.3091999888420105|  0:01:27s
epoch 30 | loss: 0.07114 | train_rmse: 0.3022899925708771| valid_rmse: 0.3087500035762787|  0:01:30s
epoch 31 | loss: 0.07045 | train_rmse: 0.30173999071121216| valid_rmse: 0.3088400065898895|  0:01:33s
epoch 32 | loss: 0.06997 | train_rmse: 0.30230000615119934| valid_rmse: 0.3093000054359436|  0:01:36s
epoch 33 | loss: 0.07006 | train_rmse: 0.29941999912261963| valid_rmse: 0.30810999870300293|  0:01:39s
epoch 34 | loss: 0.06974 | train_rmse: 0.29930999875068665| valid_rmse: 0.30761000514030457|  0:01:42s
epoch 35 | loss: 0.06924 | train_rmse: 0.29993999004364014| valid_rmse: 0.3101100027561188|  0:01:45s
epoch 36 | loss: 0.06935 | train_rmse: 0.2985000014305115| valid_rmse: 0.3088800013065338|  0:01:48s
epoch 37 | loss: 0.06892 | train_rmse: 0.298909991979599| valid_rmse: 0.3091300129890442|  0:01:51s
epoch 38 | loss: 0.06896 | train_rmse: 0.29715999960899353| valid_rmse: 0.3095499873161316|  0:01:53s
epoch 39 | loss: 0.06839 | train_rmse: 0.29642000794410706| valid_rmse: 0.3102099895477295|  0:01:56s
epoch 40 | loss: 0.06779 | train_rmse: 0.29517000913619995| valid_rmse: 0.3082199990749359|  0:01:59s
epoch 41 | loss: 0.06792 | train_rmse: 0.2943899929523468| valid_rmse: 0.30856001377105713|  0:02:02s
epoch 42 | loss: 0.06761 | train_rmse: 0.29462000727653503| valid_rmse: 0.309469997882843|  0:02:05s
epoch 43 | loss: 0.06748 | train_rmse: 0.29478999972343445| valid_rmse: 0.3099299967288971|  0:02:08s
epoch 44 | loss: 0.06729 | train_rmse: 0.2925499975681305| valid_rmse: 0.3080500066280365|  0:02:11s

Early stopping occurred at epoch 44 with best_epoch = 34 and best_valid_rmse = 0.30761000514030457
Successfully saved model at model_scEve_hp_tuned/model.zip

Check tuned model quality#

[20]:
# Predict surface proteins using tuned scEve model
mdata_pred = scparadise.sceve.predict(adata_test_rna,
                                      path_model = 'model_scEve_hp_tuned',
                                      rna_modality_name='rna', # default value
                                      second_modality_name='adt', # default value
                                      return_mdata = True # default value
                                      )
Successfully loaded list of genes used for training model

Successfully loaded list of features used for training model

Successfully loaded model

[21]:
# Add umap coordinates
mdata_pred.obsm['X_umap'] = adata_test_rna.obsm['X_umap']
[22]:
# List of proteins for visualization
proteins = ['adt_CLEC12A', 'adt_CD4-1', 'adt_CD8', 'adt_IgM', 'adt_CD133-2']

# Create list of predicted proteins
proteins_pred = []
for i in proteins:
    i = i + '_pred'
    proteins_pred.append(i)
[23]:
# Proteins visualization
proteins_dict = {
    "Observed": proteins,
    "Predicted": proteins_pred
}
# Choose a color map for protein abundance visualization
cmap = 'bwr'
# Make Axes
# Number of needed rows and columns (based on the row with the most columns)
nrow = len(proteins_dict)
ncol = max([len(vs) for vs in proteins_dict.values()])
fig, axs = plt.subplots(nrow, ncol, figsize=(3 * ncol, 3 * nrow))
# Plot expression for every marker on the corresponding Axes object
for row_idx, (prots, markers) in enumerate(proteins_dict.items()):
    col_idx = 0
    for marker in markers:
        ax = axs[row_idx, col_idx]
        if prots == 'Observed':
            mu.pl.embedding(mdata_test, basis = 'X_umap', color=marker, cmap = cmap, ax=ax, add_outline = True, show=False, frameon=False, s=10)
        else:
            mu.pl.embedding(mdata_pred, basis = 'X_umap', color=marker, cmap = cmap, ax=ax, add_outline = True, show=False, frameon=False, s=10)
        # Add cell type as row label - here we simply add it as ylabel of
        # the first Axes object in the row
        if col_idx == 0:
            # We disabled axis drawing in UMAP to have plots without background and border
            # so we need to re-enable axis to plot the ylabel
            ax.axis("on")
            ax.tick_params(
                top="off",
                bottom="off",
                left="off",
                right="off",
                labelleft="on",
                labelbottom="off",
            )
            ax.set_ylabel(prots + "\n", rotation=90, fontsize=14)
            ax.set(frame_on=False)
        col_idx += 1
    # Remove unused column Axes in the current row
    while col_idx < ncol:
        axs[row_idx, col_idx].remove()
        col_idx += 1
# Alignment within the Figure
fig.tight_layout()
../../../_images/tutorials_notebooks_scEve_scEve_train_32_0.png

The tuned model outperforms the default model in predicting protein markers of small cell populations, such as CD133 of HSPCs (Hematopoietic Stem/Progenitor Cells ).

[24]:
# Check prediction results
df_hp_tuned = scparadise.scnoah.report_reg(adata_prot = mdata_test.mod['adt'],
                                           adata_pred_prot = mdata_pred.mod['adt'])
df_hp_tuned
[24]:
EVS r2_score RMSE MedianAE MeanAE
score 0.311 0.277 0.321 0.197 0.243
EVS/r2_score higher value - better prediction
RMSE/MedianAE/MeanAE lower value - better prediction
[25]:
# Compare prediction results
# 'untuned' row represents untuned model
# 'hp tuned' row represents tuned model
df.compare(df_hp_tuned, keep_equal=True, align_axis = 0, result_names=('untuned', 'hp tuned'))
# EVS/r2_score:     higher value - better prediction
# RMSE/MedianAE/MeanAE:     lower value - better prediction
[25]:
EVS r2_score RMSE MedianAE MeanAE
score untuned 0.267 0.227 0.333 0.202 0.251
hp tuned 0.311 0.277 0.321 0.197 0.243
[26]:
scparadise.scnoah.regres_status(adata_prot = mdata_test.mod['adt'],
                                adata_pred_prot = mdata_pred.mod['adt'],
                                metric='RMSE')
scparadise.scnoah.regres_status(adata_prot = mdata_test.mod['adt'],
                                adata_pred_prot = mdata_pred.mod['adt'],
                                metric='MeanAE')
scparadise.scnoah.regres_status(adata_prot = mdata_test.mod['adt'],
                                adata_pred_prot = mdata_pred.mod['adt'],
                                metric='MedianAE')
scparadise.scnoah.regres_status(adata_prot = mdata_test.mod['adt'],
                                adata_pred_prot = mdata_pred.mod['adt'],
                                metric='EVS')
scparadise.scnoah.regres_status(adata_prot = mdata_test.mod['adt'],
                                adata_pred_prot = mdata_pred.mod['adt'],
                                metric='r2_score')
[27]:
for i in ['regres_status_RMSE', 'regres_status_MeanAE', 'regres_status_MedianAE', 'regres_status_EVS', 'regres_status_r2_score']:
    mdata_pred.obs[i] = mdata_pred.mod['adt'].obs[i]
[28]:
# Visualize observed proteins
mu.pl.embedding(mdata_pred,
                color=[
                    'regres_status_RMSE',
                    'regres_status_MeanAE',
                    'regres_status_MedianAE',
                    'regres_status_EVS',
                    'regres_status_r2_score'
                ],
                basis = 'X_umap',
                frameon = False,
                add_outline = True,
                cmap = 'bwr',
                ncols = 5,
                wspace = 0.15,
                hspace = 0.1)
../../../_images/tutorials_notebooks_scEve_scEve_train_38_0.png

Based on a per cell type regression metrics score, it is possible to draw conclusions about the uniformity of protein prediction across different cell types in the dataset.

[29]:
df_corr_hp_tuned = pd.DataFrame(columns=['Pearson coef', 'p-value'])
for i in mdata_test.mod['adt'].var_names.tolist():
    person_coef = scparadise.scnoah.pearson_coef_prot(adata_prot = mdata_test.mod['adt'],
                                                      adata_pred_prot = mdata_pred.mod['adt'],
                                                      protein = i,
                                                      protein_pred = i + '_pred',
                                                      print_res = False)
    df_corr_hp_tuned.loc[i] = [person_coef['Pearson coefficient'], person_coef['p-value']]
df_corr_hp_tuned['p-value'] = df_corr_hp_tuned['p-value'].astype('float64')
[30]:
df_corr_hp_tuned.head(5)
[30]:
Pearson coef p-value
adt_CD39 0.865 0.000000e+00
adt_CD107a 0.640 0.000000e+00
adt_CD62P 0.815 0.000000e+00
adt_TCR-2 0.550 0.000000e+00
adt_CD30 0.244 6.087000e-87
[31]:
df_high_corr_hp_tuned = df_corr_hp_tuned[df_corr_hp_tuned['Pearson coef'] >= 0.5].copy()
print('Number of proteins with Pearson coefficient >= 0.5:', len(df_high_corr_hp_tuned))
df_high_corr_hp_tuned.sort_values(by = 'Pearson coef', ascending=False).head(5)
Number of proteins with Pearson coefficient >= 0.5: 110
[31]:
Pearson coef p-value
adt_CLEC12A 0.962 0.0
adt_CD64 0.940 0.0
adt_CD11b-2 0.929 0.0
adt_CD8a 0.918 0.0
adt_CD93 0.915 0.0
[32]:
# Comparison of Pearson coefficients from the tuned and default models (tuned minus default).
df_compare = df_corr_hp_tuned.subtract(df_corr)
df_compare.rename(columns={'Pearson coef': 'Pearson coef comparison'}, inplace=True)
df_compare.sort_values(by = 'Pearson coef comparison', ascending=False).head(5)
[32]:
Pearson coef comparison p-value
adt_CD117 0.459 -4.147000e-28
adt_CD133-2 0.330 -2.926000e-82
adt_CD34 0.293 -1.021000e-283
adt_CD161 0.289 -2.854000e-146
adt_CD99 0.262 -1.179000e-77

The Pearson correlation coefficient for adt_CD117 increased by 0.459 as a result of model hyperparameters tuning.

Additionally, there is an increase in the number of proteins with a Pearson coefficient of 0.5 or greater from 94 to 110 as a result of model hyperparameter tuning.

[33]:
# Save mdata with imputed proteins
mdata_pred.write_h5mu('mdata_pred_tuned.h5mu')

Recommendation#

We strongly recommend hyperparameters tuning for scEve model training if possible. We also recommend testing the prediction results on a separate dataset (distinct from the training dataset).

[34]:
import session_info
session_info.show()
[34]:
Click to view session information
-----
anndata             0.10.8
matplotlib          3.9.2
mudata              0.2.4
muon                0.1.6
numpy               1.25.0
pandas              2.2.3
scanpy              1.10.3
scparadise          0.3.2_beta
session_info        1.0.0
-----
Click to view modules imported as dependencies
PIL                         10.4.0
alembic                     1.13.3
anyio                       NA
arrow                       1.3.0
asciitree                   NA
asttokens                   NA
attr                        24.2.0
attrs                       24.2.0
awkward                     2.7.1
awkward_cpp                 NA
babel                       2.16.0
backports                   NA
certifi                     2024.08.30
cffi                        1.17.1
charset_normalizer          3.3.2
cloudpickle                 3.1.0
colorlog                    NA
comm                        0.2.2
cycler                      0.12.1
cython_runtime              NA
dask                        2024.8.0
dateutil                    2.9.0.post0
debugpy                     1.8.6
decorator                   5.1.1
defusedxml                  0.7.1
exceptiongroup              1.2.2
executing                   2.1.0
fastjsonschema              NA
fqdn                        NA
fsspec                      2023.6.0
greenlet                    3.1.1
h5py                        3.12.1
idna                        3.10
igraph                      0.11.6
imblearn                    0.12.3
importlib_metadata          NA
importlib_resources         NA
ipykernel                   6.29.5
isoduration                 NA
jaraco                      NA
jedi                        0.19.1
jinja2                      3.1.4
joblib                      1.4.2
json5                       0.9.25
jsonpointer                 3.0.0
jsonschema                  4.23.0
jsonschema_specifications   NA
jupyter_events              0.10.0
jupyter_server              2.14.2
jupyterlab_server           2.27.3
kiwisolver                  1.4.7
legacy_api_wrap             NA
leidenalg                   0.10.2
llvmlite                    0.43.0
mako                        1.3.5
markupsafe                  2.1.5
matplotlib_inline           0.1.7
more_itertools              10.5.0
mpl_toolkits                NA
mpmath                      1.3.0
msgpack                     1.1.0
natsort                     8.4.0
nbformat                    5.10.4
numba                       0.60.0
numcodecs                   0.12.1
optuna                      4.0.0
overrides                   NA
packaging                   24.1
parso                       0.8.4
patsy                       0.5.6
pexpect                     4.9.0
platformdirs                4.3.6
plotly                      5.24.1
prometheus_client           NA
prompt_toolkit              3.0.48
psutil                      6.0.0
ptyprocess                  0.7.0
pure_eval                   0.2.3
pyarrow                     18.1.0
pycparser                   2.22
pydev_ipython               NA
pydevconsole                NA
pydevd                      3.1.0
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pydot                       3.0.3
pygments                    2.18.0
pynndescent                 0.5.13
pyparsing                   3.1.4
pythonjsonlogger            NA
pytorch_tabnet              NA
pytz                        2024.2
referencing                 NA
requests                    2.32.3
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rich                        NA
rpds                        NA
scipy                       1.13.1
seaborn                     0.13.2
send2trash                  NA
setuptools                  75.1.0
setuptools_scm              NA
shap                        0.46.0
six                         1.16.0
sklearn                     1.5.2
slicer                      NA
sniffio                     1.3.1
sqlalchemy                  2.0.35
stack_data                  0.6.3
statsmodels                 0.14.3
sympy                       1.13.3
tblib                       3.0.0
texttable                   1.7.0
threadpoolctl               3.5.0
tlz                         0.12.1
tomli                       2.0.1
toolz                       0.12.1
torch                       2.4.1+cu121
torchgen                    NA
tornado                     6.4.1
tqdm                        4.66.5
traitlets                   5.14.3
triton                      3.0.0
typing_extensions           NA
umap                        0.5.6
uri_template                NA
urllib3                     1.26.20
wcwidth                     0.2.13
webcolors                   24.8.0
websocket                   1.8.0
yaml                        6.0.2
zarr                        2.18.2
zipp                        NA
zmq                         26.2.0
zoneinfo                    NA
-----
IPython             8.18.1
jupyter_client      8.6.3
jupyter_core        5.7.2
jupyterlab          4.2.5
-----
Python 3.9.19 (main, May  6 2024, 19:43:03) [GCC 11.2.0]
Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
-----
Session information updated at 2025-01-18 10:17