scAdam model explainability

Model explainability refers to the ability to understand and articulate how a machine learning model makes its predictions or decisions.

scNoah uses SHAP (SHapley Additive exPlanations) for model explainability analysis.

SHAP values are derived from cooperative game theory, specifically using Shapley values, which quantify the contribution of each player (or feature/gene) to the total payoff (or prediction). This approach allows for a fair distribution of contributions among features, similar to how players in a coalition share the rewards based on their input.
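The game-theory definition can be made concrete by computing exact Shapley values: for each player i, average its marginal contribution v(S ∪ {i}) - v(S) over every coalition S that excludes i, weighted by |S|! (M - |S| - 1)! / M!. The sketch below does this by brute force for a made-up three-player payoff function (a hypothetical toy, not an scAdam model). The payoff includes an interaction between players 0 and 1. Enumeration grows exponentially with the number of players, which is why tools such as SHAP approximate Shapley values for models with hundreds of genes.

```python
from itertools import combinations
from math import factorial


def shapley_values(value, n_players):
    """Exact Shapley values by enumerating every coalition.

    value: function mapping a frozenset of player indices to a payoff.
    """
    phi = [0.0] * n_players
    for i in range(n_players):
        others = [j for j in range(n_players) if j != i]
        for size in range(n_players):
            for coalition in combinations(others, size):
                s = frozenset(coalition)
                # Shapley weight: |S|! * (M - |S| - 1)! / M!
                weight = factorial(size) * factorial(n_players - size - 1) / factorial(n_players)
                # Marginal contribution of player i when joining coalition S
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


def toy_value(coalition):
    """Hypothetical payoff: additive terms plus an interaction between players 0 and 1."""
    payoff = 0.0
    if 0 in coalition:
        payoff += 2.0
    if 1 in coalition:
        payoff += 1.0
    if 0 in coalition and 1 in coalition:
        payoff += 1.0  # interaction bonus, shared by players 0 and 1
    if 2 in coalition:
        payoff += 0.5
    return payoff


phi = shapley_values(toy_value, 3)
print([round(p, 3) for p in phi])  # [2.5, 1.5, 0.5]
```

The interaction bonus is split equally between players 0 and 1, and the three values add up to the payoff of the full coalition minus that of the empty one (4.5).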

The SHAP framework operates under the principle that the prediction for an instance can be expressed as the sum of the average prediction and the contributions of each feature. Mathematically, this can be represented as:

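\[\text{prediction} = f_0 + \sum_{i=1}^{M} f_i\]

where \(\text{prediction}\) is the model’s output for a given instance (cell), \(f_0\) is the expected value of the model output (the base value), \(f_i\) is the SHAP value of feature \(i\), and \(M\) is the number of features (genes).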
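For a linear model with independent features, SHAP values have a known closed form, f_i = w_i (x_i - E[x_i]), which makes the additive decomposition above easy to check numerically. The sketch below uses made-up background data, weights and a made-up cell. It is only an illustration of the formula, not how scNoah computes SHAP values for scAdam models.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical background data: 200 cells x 4 genes
background = rng.normal(size=(200, 4))
weights = np.array([0.8, -0.5, 0.0, 1.2])
bias = 0.3


def model(X):
    """Hypothetical linear model: one output per row of X."""
    return X @ weights + bias


x = np.array([1.0, 2.0, -1.0, 0.5])  # the cell to explain

f_0 = model(background).mean()                  # base value: expected model output
f_i = weights * (x - background.mean(axis=0))   # SHAP value of each gene (linear closed form)
prediction = model(x[None, :])[0]

print(np.isclose(prediction, f_0 + f_i.sum()))  # True
```

A gene with zero weight (the third one here) receives a SHAP value of zero regardless of its expression.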

SHAP provides a framework for both local explainability and global explainability in machine learning models.

  1. Local Explainability: Focuses on explaining individual predictions (cells) made by the model. It seeks to clarify why a specific output was generated for a given input.

  2. Global Explainability: Aims to provide an overall understanding of the model’s behavior across all predictions (cells), identifying the main factors that influence outputs.
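The link between the two views can be sketched as follows: local explanations form a cells × genes matrix of SHAP values, and a common global summary is the mean absolute SHAP value of each gene across cells, which is the aggregation SHAP’s bar plot uses when given many samples. The matrix and gene names below are hypothetical, and whether scparadise.scnoah.feature_importance uses exactly this aggregation is an assumption not confirmed here.

```python
import numpy as np
import pandas as pd

# Hypothetical local explanations: SHAP values for 4 cells (rows) x 3 genes (columns)
shap_values = np.array([
    [0.05,  0.30, -0.10],
    [0.00,  0.20,  0.15],
    [0.10, -0.25, -0.05],
    [0.05,  0.35,  0.10],
])
genes = ["gene_A", "gene_B", "gene_C"]

# Local view: contribution of each gene to the prediction for the first cell
cell_0 = pd.Series(shap_values[0], index=genes)

# Global view: mean absolute SHAP value of each gene across all cells
global_importance = (
    pd.Series(np.abs(shap_values).mean(axis=0), index=genes)
    .sort_values(ascending=False)
)
print(global_importance)
```

Taking the absolute value matters: gene_B has a negative contribution in one cell, yet it still ranks first because its contributions are large in magnitude.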

[1]:
# Python packages
import warnings
warnings.simplefilter('ignore')

import scanpy as sc
import scparadise
import shap
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

sc.set_figure_params(dpi = 120)

Recommendations about dataset

Our models were trained on data normalized with the shifted logarithm. To use the models properly, we recommend normalizing your data the same way:

sc.pp.normalize_total(adata, target_sum=None)
sc.pp.log1p(adata)
adata.raw = adata
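To make the recommended normalization concrete, this NumPy sketch mirrors what the two Scanpy calls do on a small dense toy matrix: with target_sum=None, normalize_total rescales every cell to the median total count, and log1p applies log(x + 1). The count matrix is made up, and the sketch ignores details Scanpy handles itself (sparse matrices, cells with zero counts, other normalize_total options).

```python
import numpy as np

# Hypothetical raw counts: 3 cells x 4 genes
counts = np.array([
    [10, 0, 5, 5],
    [2, 2, 0, 6],
    [30, 10, 0, 0],
], dtype=float)

totals = counts.sum(axis=1, keepdims=True)  # library size per cell: 20, 10, 40
target = np.median(totals)                  # target_sum=None -> median library size (20)
scaled = counts / totals * target           # every cell now sums to the target
log_norm = np.log1p(scaled)                 # shifted logarithm: log(x + 1)
```

After scaling, every cell has the same total (20 here), so the log-transformed values are comparable across cells with different sequencing depth.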

[2]:
# Load dataset from 10x Genomics
url = "https://cf.10xgenomics.com/samples/cell-exp/6.1.0/10k_PBMC_3p_nextgem_Chromium_X/10k_PBMC_3p_nextgem_Chromium_X_filtered_feature_bc_matrix.h5"
adata = sc.read_10x_h5("dataset.h5", backup_url = url)
adata.var_names_make_unique()
adata
[2]:
AnnData object with n_obs × n_vars = 11996 × 36601
    var: 'gene_ids', 'feature_types', 'genome'

QC

[3]:
# mitochondrial genes: "MT-" prefix for human, "mt-" for mouse
adata.var["mt"] = adata.var_names.str.startswith("MT-")
# ribosomal genes
adata.var["ribo"] = adata.var_names.str.startswith(("RPS", "RPL"))
# hemoglobin genes
adata.var["hb"] = adata.var_names.str.contains("^HB[^(P)]")
[4]:
sc.pp.calculate_qc_metrics(
    adata, qc_vars=["mt", "ribo", "hb"], inplace=True, log1p=True
)
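calculate_qc_metrics adds, among other columns, pct_counts_<qc_var> to adata.obs: the percentage of a cell’s total counts that come from the genes flagged in that var column. A minimal NumPy re-derivation on made-up counts, not Scanpy’s implementation, looks like this:

```python
import numpy as np
import pandas as pd

# Hypothetical gene names and counts: 2 cells x 5 genes
genes = pd.Index(["MT-CO1", "MT-ND1", "RPS3", "CD3E", "HBB"])
counts = np.array([
    [20, 10, 30, 40, 0],
    [5, 0, 10, 80, 5],
], dtype=float)

# Gene flags built the same way as adata.var["mt"] and adata.var["ribo"]
mt = np.asarray(genes.str.startswith("MT-"), dtype=bool)
ribo = np.asarray(genes.str.startswith(("RPS", "RPL")), dtype=bool)

total_counts = counts.sum(axis=1)
# Percentage of each cell's counts coming from the flagged genes
pct_counts_mt = 100 * counts[:, mt].sum(axis=1) / total_counts
pct_counts_ribo = 100 * counts[:, ribo].sum(axis=1) / total_counts
```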
[5]:
sc.pl.violin(
    adata,
    ["n_genes_by_counts", "total_counts", "pct_counts_mt", 'pct_counts_ribo', 'pct_counts_hb'],
    jitter=0.4,
    multi_panel=True,
)
[Figure: violin plots of n_genes_by_counts, total_counts, pct_counts_mt, pct_counts_ribo and pct_counts_hb]
[6]:
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)
[7]:
# Detect doublets
sc.pp.scrublet(adata)
[8]:
# Remove predicted doublets and apply additional QC thresholds
# .copy() turns each filtered view into an independent AnnData object
adata = adata[~adata.obs['predicted_doublet']].copy()
sc.pp.filter_cells(adata, max_genes = 5000)
sc.pp.filter_cells(adata, max_counts = 20000)
adata = adata[adata.obs['pct_counts_mt'] < 15].copy()
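The QC steps above amount to a single boolean mask over cells. This pandas sketch applies the same thresholds to a made-up obs table with the column names used above. It is an approximation for illustration: sc.pp.filter_cells recounts genes and counts from adata.X instead of reading obs columns, and the inclusive comparison for the max_* thresholds is an assumption here.

```python
import pandas as pd

# Hypothetical per-cell QC table (same column names as adata.obs after the steps above)
obs = pd.DataFrame(
    {
        "predicted_doublet": [False, True, False, False, False],
        "n_genes_by_counts": [1200, 1500, 5200, 900, 2000],
        "total_counts": [4000, 6000, 25000, 2500, 8000],
        "pct_counts_mt": [3.0, 4.0, 5.0, 22.0, 8.0],
    },
    index=[f"cell_{i}" for i in range(5)],
)

keep = (
    ~obs["predicted_doublet"]                # drop predicted doublets
    & (obs["n_genes_by_counts"] <= 5000)     # max_genes = 5000 (assumed inclusive)
    & (obs["total_counts"] <= 20000)         # max_counts = 20000 (assumed inclusive)
    & (obs["pct_counts_mt"] < 15)            # mitochondrial fraction below 15 %
)
print(obs.index[keep].tolist())  # ['cell_0', 'cell_4']
```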

Normalization, HVG, PCA, neighbors, UMAP

We recommend using shifted logarithm data normalization as described here.

[9]:
# Saving count data
adata.layers["counts"] = adata.X.copy()
# Normalization (shifted logarithm)
sc.pp.normalize_total(adata, target_sum=None)
sc.pp.log1p(adata)
# scParadise uses the normalized data stored in adata.raw
adata.raw = adata
# HVG
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
# PCA
sc.tl.pca(adata)
# Nearest neighbors analysis
sc.pp.neighbors(adata, n_neighbors=10, n_pcs=20)
# UMAP
sc.tl.umap(adata)

Cell type prediction using scAdam

[10]:
# Available models for cell type annotation
df = scparadise.scadam.available_models()
# Show models related to humans
df_human = df[df['Tissue/Model name'].str.startswith('Human_')]
df_human
[10]:
    Tissue/Model name        | Description                                        | Suspension | Accuracy | Balanced Accuracy | Number of Levels
0   Human_PBMC               | Peripheral blood mononuclear cells of healthy ...  | cells      | 0.979    | 0.979             | 3
1   Human_BMMC               | Bone marrow mononuclear cell of healthy adults     | cells      | 0.947    | 0.942             | 3
2   Human_Heart              | Human heart CITE-seq analysis of healthy and d...  | cells      | 0.957    | 0.956             | 2
3   Human_Lung               | Core Human Lung Cell Atlas                         | cells      | 0.965    | 0.964             | 5
4   Human_Lung_Cancer        | Extended single-cell lung cancer atlas (LuCA)      | cells      | 0.937    | 0.936             | 3
5   Human_oropharyngeal_SCC  | Oropharyngeal HPV+/HPV- squamous cell carcinom...  | cells      | 0.972    | 0.968             | 2
6   Human_Retina             | Single cell atlas of the human retina              | cells      | 0.984    | 0.979             | 4
[11]:
# Download model for cell type prediction
scparadise.scadam.download_model('Human_PBMC', save_path='')
[12]:
# Predict cell types using trained model
adata = scparadise.scadam.predict(adata,
                                  path_model = 'Human_PBMC_scAdam/')
Successfully loaded list of genes used for training model

Successfully loaded dictionary of dataset annotations

Successfully loaded model

Successfully added predicted celltype_l1 and cell type probabilities
Successfully added predicted celltype_l2 and cell type probabilities
Successfully added predicted celltype_l3 and cell type probabilities
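Each level gets a predicted label (pred_celltype_l*) and a probability (prob_celltype_l*). A common way a classifier turns per-class probabilities into a label and a confidence score is to take the most probable class and its probability. The sketch below shows that pattern on made-up probabilities; that scAdam fills the pred_/prob_ columns exactly this way is an assumption, not something confirmed by the output above.

```python
import numpy as np

# Hypothetical class probabilities for 3 cells over 3 cell types (rows sum to 1)
classes = np.array(["B", "CD4 T", "CD8 T"])
probs = np.array([
    [0.90, 0.05, 0.05],
    [0.10, 0.30, 0.60],
    [0.20, 0.50, 0.30],
])

pred_label = classes[probs.argmax(axis=1)]  # most probable class per cell
pred_prob = probs.max(axis=1)               # its probability, used as a confidence score
```

Under this reading, low values in the prob_celltype_l* panels mark cells where the model spreads probability over several cell types.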
[13]:
# Visualize predicted cell type levels and prediction probabilities
sc.set_figure_params(figsize=[8, 5])
sc.pl.embedding(adata,
                color = [
                    'pred_celltype_l1',
                    'prob_celltype_l1',
                    'pred_celltype_l2',
                    'prob_celltype_l2',
                    'pred_celltype_l3',
                    'prob_celltype_l3'
                ],
                basis = 'X_umap',
                frameon = False,
                add_outline = True,
                legend_loc = 'on data',
                legend_fontsize = 10,
                legend_fontoutline = 2,
                ncols = 2,
                size = 30,
                wspace = 0,
                hspace = 0.1)
[Figure: UMAP colored by predicted cell types (pred_celltype_l1, l2, l3) and prediction probabilities (prob_celltype_l1, l2, l3)]

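Create explanations matrix

Compute SHAP values for cells of a specific cell type predicted by the scAdam model. As the log below shows, scparadise.scnoah.explain selects the cells predicted as the requested type (here from pred_celltype_l2) and explains num_cells of them with SHAP’s PermutationExplainer.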

[14]:
# Compute explanations for CD8+ T cells
random_state = 0
num_cells = 100
explanations  = scparadise.scnoah.explain(adata,
                                          celltype = 'CD8 T',
                                          path_model = 'Human_PBMC_scAdam/',
                                          num_cells = num_cells,
                                          random_state = random_state)
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Successfully loaded model
Cell type "CD8 T" was successfully selected from pred_celltype_l2
PermutationExplainer explainer: 101it [02:40,  1.69s/it]
The explanations for "CD8 T" have been completed
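The idea behind permutation-based SHAP estimation can be sketched in a few lines: add features one at a time in a random order, starting from a reference input, and credit each feature with the change in model output. Every permutation costs one model evaluation per feature, which is why explaining a model with hundreds of genes takes over a second per cell. This is a simplified Monte Carlo version with a single reference vector and a hypothetical model; SHAP’s PermutationExplainer is more elaborate (for example, it evaluates permutations in both forward and reverse order and uses background data), so treat it as an illustration of the principle only.

```python
import numpy as np


def permutation_shap(model, x, reference, n_permutations=50, seed=0):
    """Monte Carlo estimate of SHAP values by sampling feature orderings.

    Features not yet added to the coalition keep their value from `reference`
    (a single reference input, a simplifying assumption).
    """
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    phi = np.zeros(n)
    for _ in range(n_permutations):
        order = rng.permutation(n)
        current = reference.copy()
        prev_out = model(current)
        for j in order:
            current[j] = x[j]         # add feature j to the coalition
            out = model(current)
            phi[j] += out - prev_out  # marginal contribution of feature j
            prev_out = out
    return phi / n_permutations


def model(v):
    """Hypothetical model with an interaction between features 0 and 1."""
    return 2.0 * v[0] + v[1] + v[0] * v[1] + 0.5 * v[2]


x = np.array([1.0, 1.0, 1.0])
reference = np.zeros(3)
phi = permutation_shap(model, x, reference, n_permutations=200)
```

Within every permutation the marginal contributions telescope to model(x) - model(reference), so the estimates always add up exactly; only the split of the interaction term between features 0 and 1 depends on the sampled orderings.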

[15]:
# Global explainability
# Calculate gene importances
gene_importances = scparadise.scnoah.feature_importance(explanations,
                                                         path_model = 'Human_PBMC_scAdam/')
# Show the 15 most important genes for CD8+ T cell annotation
gene_importances.head(15)
[15]:
gene_name gene_importance
76 GZMK 0.284200
593 CCL5 0.221275
108 GZMA 0.082325
539 CD8A 0.078000
425 IL7R 0.051650
40 CCR7 0.048450
72 NKG7 0.041350
403 CD8B 0.040450
116 TIGIT 0.033150
220 ITGB1 0.025675
386 LYAR 0.025250
506 FCMR 0.024350
305 KLRD1 0.024000
17 NUCB2 0.023825
604 MAL 0.023225
[16]:
# Visualization of gene importances on cell embeddings
sc.pl.embedding(adata,
                color = gene_importances.head(16)['gene_name'],
                basis = 'X_umap',
                frameon = False,
                add_outline = True,
                size = 30,
                ncols=4)
[Figure: UMAP colored by expression of the 16 genes with the highest importance for CD8 T]

The gene importances highlight CD8A and CD8B, canonical markers of CD8+ T cells, which is expected for the annotation of this cell type. The top-ranked genes also include markers of CD8+ T cell subtypes: GZMK, GZMA, CCL5, NKG7, and KLRD1 are markers of memory CD8+ T cells, while IL7R and CCR7 are markers of naive CD8+ T cells.

Analysis of explanations

  1. A positive SHAP value means that the gene (feature) pushes the model’s output for the explained cell type above the base value for that cell, i.e. it supports the prediction.

  2. A negative SHAP value means that the gene pushes the model’s output below the base value for that cell, i.e. it argues against the prediction.
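Following these two rules, the SHAP values of a single cell can be split into genes that support the prediction and genes that argue against it. The values and gene names below are hypothetical; with a real explanations object, the same per-cell values would come from one row of the explanation.

```python
import pandas as pd

# Hypothetical SHAP values of one cell for the explained cell type
cell_shap = pd.Series({"gene_A": 0.31, "gene_B": 0.12, "gene_C": -0.18, "gene_D": -0.02, "gene_E": 0.0})

# Positive values push the output above the base value (support the prediction)
supporting = cell_shap[cell_shap > 0].sort_values(ascending=False)
# Negative values push the output below the base value (argue against it)
opposing = cell_shap[cell_shap < 0].sort_values()

print(supporting.index.tolist())  # ['gene_A', 'gene_B']
print(opposing.index.tolist())    # ['gene_C', 'gene_D']
```

A gene with a SHAP value of exactly zero (gene_E) does not move the prediction for this cell in either direction.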

[17]:
# Visualization of explanations
# For the annotation of CD8+ T cells, the following genes are of the greatest importance:
plt.grid(visible=False, axis='both')
shap.plots.bar(explanations, max_display=15)
[Figure: SHAP bar plot of the 15 most important genes for CD8 T]
[18]:
# It is also possible to display SHAP values for individual cells
shap.plots.heatmap(explanations, max_display=15, show=False)
plt.grid(visible=False, axis='both')
[Figure: SHAP heatmap of the explained CD8 T cells for the 15 most important genes]

Differences between CD8+ T cells

The SHAP heatmap suggests at least two distinct groups of CD8+ T cells:

  1. Cells with positive SHAP values for CCL5 and negative SHAP values for GZMK.

  2. Cells with positive or zero SHAP values for GZMK and negative or zero SHAP values for CCL5.

For individual cells, SHAP values can be displayed together with the gene expression levels.

Warning! shap.plots.heatmap automatically orders cells by hierarchical clustering of their SHAP values. To find which cell in explanations corresponds to an instance number on the heatmap, compute the same ordering with explanations.hclust(): the value at position i of the returned array is the index in explanations of heatmap instance i.
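The position-to-cell mapping, and its inverse, can be sketched with SciPy’s hierarchical clustering on a made-up SHAP matrix. The clustering settings here (squared Euclidean distance, complete linkage, no optimal leaf ordering) are assumptions and may not reproduce Explanation.hclust() exactly, so on real data use explanations.hclust() itself; the indexing logic is the same either way.

```python
import numpy as np
from scipy.cluster.hierarchy import leaves_list, linkage
from scipy.spatial.distance import pdist

rng = np.random.default_rng(0)
# Hypothetical SHAP matrix: 6 cells x 4 genes
values = rng.normal(size=(6, 4))

# Order cells by hierarchical clustering of their SHAP profiles
order = leaves_list(linkage(pdist(values, metric="sqeuclidean"), method="complete"))

heatmap_position = 1
cell_index = order[heatmap_position]  # heatmap instance -> row in the SHAP matrix
position_of_cell = np.argsort(order)  # inverse mapping: row -> heatmap instance
```

The inverse mapping is useful in the other direction, for example to find where a cell of interest appears on the heatmap.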

[19]:
# Memory CD8+ T cell
# Left column of numbers (near gene names) represents gene expression
shap.plots.bar(explanations[explanations.hclust()[1]], max_display=15, show=False)
plt.grid(visible=False, axis='both')
explanations.hclust()[1] # Actual cell from explanations
[19]:
5
[Figure: SHAP bar plot for heatmap instance 1 (cell 5 in explanations)]
[20]:
# Naive CD8+ T cell
# Left column of numbers (near gene names) represents gene expression
shap.plots.bar(explanations[explanations.hclust()[70]], max_display=15, show=False)
plt.grid(visible=False, axis='both')
explanations.hclust()[70] # Actual cell from explanations
[20]:
91
[Figure: SHAP bar plot for heatmap instance 70 (cell 91 in explanations)]
  1. For heatmap instance 1 (cell 5 in explanations), the most important contributions come from the expression of GZMK and GZMA.

  2. For heatmap instance 70 (cell 91 in explanations), the most important contributions come from the absence of CCL5 and NKG7 expression.

[21]:
# Extract the same 100 cells from the dataset
adata_subset = adata[adata.obs['pred_celltype_l2'] == 'CD8 T'].copy()
indexes = pd.DataFrame(index = adata_subset.obs_names).sample(n = num_cells, axis = 0, random_state = random_state).index.tolist()
adata_subset = adata_subset[indexes]
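This cell assumes that scparadise.scnoah.explain draws its cells with pandas DataFrame.sample and the same random_state. If so, a fixed seed reproduces exactly the same selection, as the sketch below shows with hypothetical cell names; the sample is reproducible only if the candidate cells (all cells predicted as CD8 T) are in the same order.

```python
import pandas as pd

cells = pd.DataFrame(index=[f"cell_{i}" for i in range(1000)])  # hypothetical obs_names

first = cells.sample(n=100, random_state=0).index.tolist()
second = cells.sample(n=100, random_state=0).index.tolist()  # same seed -> same cells, same order
other = cells.sample(n=100, random_state=1).index.tolist()   # different seed -> different cells
```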
[22]:
# Visualization of predicted CD8+ T cell subtypes and expression of key genes on the UMAP
# CD8 TCM - central memory
# CD8 TEM - effector memory
sc.pl.embedding(adata_subset,
                color = ['pred_celltype_l3',
                         'GZMK',
                         'GZMA',
                         'CCL5',
                         'NKG7'],
                basis = 'X_umap',
                frameon = False,
                add_outline = True,
                legend_loc = 'right margin',
                palette = 'tab20',
                ncols = 3,
                size = 300)
[Figure: UMAP of the explained CD8 T cells colored by pred_celltype_l3 and expression of GZMK, GZMA, CCL5 and NKG7]

Marker gene search based on gene importances

[23]:
# Compute explanations for plasmacytoid dendritic cells (pDC)
explanations  = scparadise.scnoah.explain(adata,
                                          celltype = 'pDC',
                                          num_cells = 75,
                                          path_model = 'Human_PBMC_scAdam/')
Successfully loaded list of genes used for training model
Successfully loaded dictionary of dataset annotations
Successfully loaded model
Cell type "pDC" was successfully selected from pred_celltype_l2
PermutationExplainer explainer: 76it [01:49,  1.61s/it]
The explanations for "pDC" have been completed

[24]:
# Calculate gene importances
gene_importances = scparadise.scnoah.feature_importance(explanations,
                                                         path_model = 'Human_PBMC_scAdam/')
# Show 10 most important genes for pDC cell annotation
gene_importances.head(10)
[24]:
gene_name gene_importance
507 LINC00996 0.003733
449 SMPD3 0.001244
258 SERPINF1 0.001244
372 GZMB 0.001244
283 GAS6 0.001244
562 HLA-DQA1 0.001244
344 PTCRA 0.001244
125 TSPAN13 0.001244
91 APP 0.001244
318 AXL 0.001244
[25]:
# Visualization of gene importances on cell embeddings
sc.pl.embedding(adata,
                color = gene_importances.head(16)['gene_name'],
                basis = 'X_umap',
                frameon = False,
                add_outline = True,
                size = 30,
                ncols = 4)
[Figure: UMAP colored by expression of the 16 genes with the highest importance for pDC]

LINC00996, SMPD3, SERPINF1, GAS6, TSPAN13, APP, PLD4, and ZFAT, which are among the genes with the highest importance values (the embedding plot above shows the top 16), can be considered marker genes for pDC.
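Gene importance reflects what the model relies on, which can include genes whose absence is informative. Before calling a gene a marker, it can help to check that it is actually higher in the target cells than elsewhere. The sketch below compares mean log-expression between target and non-target cells on a made-up matrix; it is an illustrative check, not part of scparadise, and Scanpy’s sc.tl.rank_genes_groups is the standard tool for a proper differential expression test.

```python
import numpy as np
import pandas as pd

# Hypothetical log-normalized expression: 6 cells x 3 genes
expr = pd.DataFrame(
    [[3.1, 0.2, 1.0],
     [2.8, 0.0, 1.2],
     [2.9, 0.1, 0.9],
     [0.1, 0.3, 1.1],
     [0.0, 0.2, 1.0],
     [0.2, 0.1, 0.8]],
    columns=["gene_A", "gene_B", "gene_C"],
)
is_target = np.array([True, True, True, False, False, False])  # e.g. cells predicted as pDC

# Difference of mean log-expression between target cells and all other cells
mean_diff = expr[is_target].mean() - expr[~is_target].mean()
print(mean_diff.sort_values(ascending=False))
```

Here only gene_A is clearly higher in the target cells; a highly important gene with a negative difference would be informative by its absence rather than a positive marker.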

[26]:
# Save anndata with predicted annotations
adata.write_h5ad('adata_predicted.h5ad')
[27]:
import session_info
session_info.show()
[27]:
-----
anndata             0.10.8
matplotlib          3.9.2
numpy               1.25.0
pandas              2.2.3
scanpy              1.10.3
scparadise          0.3.1_beta
session_info        1.0.0
shap                0.46.0
-----
PIL                         10.4.0
anyio                       NA
arrow                       1.3.0
asciitree                   NA
asttokens                   NA
attr                        24.2.0
attrs                       24.2.0
awkward                     2.7.1
awkward_cpp                 NA
babel                       2.16.0
backports                   NA
certifi                     2024.08.30
cffi                        1.17.1
charset_normalizer          3.3.2
cloudpickle                 3.1.0
colorlog                    NA
comm                        0.2.2
cycler                      0.12.1
cython_runtime              NA
dask                        2024.8.0
dateutil                    2.9.0.post0
debugpy                     1.8.6
decorator                   5.1.1
defusedxml                  0.7.1
exceptiongroup              1.2.2
executing                   2.1.0
fastjsonschema              NA
fqdn                        NA
fsspec                      2023.6.0
h5py                        3.12.1
idna                        3.10
igraph                      0.11.6
imblearn                    0.12.3
importlib_metadata          NA
importlib_resources         NA
ipykernel                   6.29.5
isoduration                 NA
jaraco                      NA
jedi                        0.19.1
jinja2                      3.1.4
joblib                      1.4.2
json5                       0.9.25
jsonpointer                 3.0.0
jsonschema                  4.23.0
jsonschema_specifications   NA
jupyter_events              0.10.0
jupyter_server              2.14.2
jupyterlab_server           2.27.3
kiwisolver                  1.4.7
lazy_loader                 0.4
legacy_api_wrap             NA
leidenalg                   0.10.2
llvmlite                    0.43.0
markupsafe                  2.1.5
matplotlib_inline           0.1.7
more_itertools              10.5.0
mpl_toolkits                NA
msgpack                     1.1.0
mudata                      0.2.4
muon                        0.1.6
natsort                     8.4.0
nbformat                    5.10.4
numba                       0.60.0
numcodecs                   0.12.1
optuna                      4.0.0
overrides                   NA
packaging                   24.1
parso                       0.8.4
patsy                       0.5.6
pexpect                     4.9.0
platformdirs                4.3.6
plotly                      5.24.1
pooch                       v1.8.2
prometheus_client           NA
prompt_toolkit              3.0.48
psutil                      6.0.0
ptyprocess                  0.7.0
pure_eval                   0.2.3
pyarrow                     18.1.0
pycparser                   2.22
pydev_ipython               NA
pydevconsole                NA
pydevd                      3.1.0
pydevd_file_utils           NA
pydevd_plugins              NA
pydevd_tracing              NA
pydot                       3.0.3
pygments                    2.18.0
pynndescent                 0.5.13
pyparsing                   3.1.4
pythonjsonlogger            NA
pytorch_tabnet              NA
pytz                        2024.2
referencing                 NA
requests                    2.32.3
rfc3339_validator           0.1.4
rfc3986_validator           0.1.1
rich                        NA
rpds                        NA
scipy                       1.13.1
seaborn                     0.13.2
send2trash                  NA
setuptools                  75.1.0
setuptools_scm              NA
six                         1.16.0
skimage                     0.24.0
sklearn                     1.5.2
slicer                      NA
sniffio                     1.3.1
stack_data                  0.6.3
statsmodels                 0.14.3
tblib                       3.0.0
texttable                   1.7.0
threadpoolctl               3.5.0
tlz                         0.12.1
tomli                       2.0.1
toolz                       0.12.1
torch                       2.4.1+cu121
torchgen                    NA
tornado                     6.4.1
tqdm                        4.66.5
traitlets                   5.14.3
typing_extensions           NA
umap                        0.5.6
uri_template                NA
urllib3                     1.26.20
wcwidth                     0.2.13
webcolors                   24.8.0
websocket                   1.8.0
yaml                        6.0.2
zarr                        2.18.2
zipp                        NA
zmq                         26.2.0
zoneinfo                    NA
-----
IPython             8.18.1
jupyter_client      8.6.3
jupyter_core        5.7.2
jupyterlab          4.2.5
-----
Python 3.9.19 (main, May  6 2024, 19:43:03) [GCC 11.2.0]
Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
-----
Session information updated at 2025-01-15 22:52