import warnings
warnings.filterwarnings("ignore")
from imblearn.over_sampling import RandomOverSampler
import sklearn.metrics as metrics_
from imblearn import metrics
from scipy import sparse
from scipy import stats
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import scanpy as sc
import shap
import json
import anndata
import os
from plottable.cmap import normed_cmap
import matplotlib as mpl
from plottable import ColumnDefinition
from plottable import Table
from plottable.plots import bar
from plottable.font import contrasting_font_color
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch
from scparadise import scadam
from torch.utils.data import Dataset, DataLoader, Subset
# Function to generate synthetic cells for a cell type via PCA-based local interpolation.
def cells_generator(
    X_class,
    num_target,
    max_components = 20,
    random_state = 0
):
    """
    Create synthetic cells for a single cell type by interpolating in PCA space.

    A PCA model is fitted on the cells of the class. Each synthetic cell lies on
    the segment between two distinct, randomly chosen cells of the class in that
    embedding (mixing weight drawn uniformly from [0.3, 0.7)). The interpolated
    points are mapped back to expression space and negative values are set to 0.

    Parameters
    ----------
    X_class : ndarray, shape (n_cells, n_genes)
        Expression matrix of one cell type.
    num_target : int
        Total number of cells wanted after augmentation (originals included).
    max_components : int (default: 20)
        Upper bound on the number of PCA components. The value actually used is
        ``min(max_components, n_cells - 1, n_genes)``.
    random_state : int (default: 0)
        Seed for the NumPy generator and for PCA, for reproducible output.

    Returns
    -------
    ndarray
        ``X_class`` followed by ``num_target - n_cells`` synthetic rows.
        ``X_class`` itself is returned unchanged when ``n_cells >= num_target``,
        when ``n_cells < 2``, or when no PCA component can be fitted.
    """
    generator = np.random.default_rng(random_state)
    n_cells, n_genes = X_class.shape
    # Already big enough, or fewer than two cells to interpolate between.
    if n_cells >= num_target or n_cells < 2:
        return X_class
    n_components = min(max_components, n_cells - 1, n_genes)
    if n_components < 1:
        return X_class
    pca = PCA(n_components=n_components, random_state=random_state)
    embedding = pca.fit_transform(X_class)  # (n_cells, n_components)

    def _interpolate_once():
        # The pair is drawn before the weight; this fixes the RNG call order.
        first, second = generator.choice(n_cells, size=2, replace=False)
        weight = generator.uniform(0.3, 0.7)
        return weight * embedding[first] + (1.0 - weight) * embedding[second]

    synthetic_embedding = np.vstack(
        [_interpolate_once() for _ in range(num_target - n_cells)]
    )
    synthetic = pca.inverse_transform(synthetic_embedding)
    # scRNA-seq expression values are non-negative.
    synthetic[synthetic < 0] = 0.0
    return np.vstack([X_class, synthetic])
# Function to oversample selected cell types
def oversample(
    adata,
    celltype_keys,
    target_per_class = None,
    max_oversample_factor = 7.0,
    min_oversample_cells = 5,
    random_state = 0
):
    """
    Oversample minor cell types in an AnnData object.

    Cell types are taken from the most detailed annotation level, i.e. the key in
    ``celltype_keys`` with the most distinct values (the first such key on ties).
    For every cell type with fewer than ``target_per_class`` cells, synthetic cells
    are generated with ``cells_generator`` (PCA-based interpolation between cells
    of the same type). The ``obs`` row of each synthetic cell is copied from a
    randomly chosen original cell of the same type, so all annotation levels
    (and any other obs columns, e.g. sample) stay consistent.
    Synthetic values live in the same space as ``adata.X``: counts in give
    count-scaled values out, and normalized data in gives normalized values out.
    Interpolated values are not integers.

    Parameters
    ----------
    adata: AnnData
        Input dataset to be oversampled.
    celltype_keys: list or str
        Cell type annotation keys in adata.obs (a single key may be given as str).
        Example: ['lineage', 'cell type', 'cell state']
    target_per_class: int (default: None)
        Global target per cell type. If None, computed as the average number of
        cells per cell type in the annotation level with the most cell types.
    max_oversample_factor: float (default: 7.0)
        Upper bound on how much a small class may be expanded relative to its original size.
    min_oversample_cells: int (default: 5)
        Cell types smaller than this are grown to at most this many cells.
    random_state: int (default: 0)
        Seed for random number generators to ensure reproducibility.

    Returns
    -------
    AnnData
        New AnnData whose ``X`` is a CSR sparse matrix of the original cells
        followed by the synthetic cells. Only ``X``, ``obs`` and ``var`` are
        carried over. Synthetic cells reuse the obs_names of the cells their
        metadata was copied from, so obs_names may not be unique.

    Raises
    ------
    ValueError
        If ``celltype_keys`` is empty.
    """
    if isinstance(celltype_keys, str):
        celltype_keys = [celltype_keys]
    if len(celltype_keys) == 0:
        raise ValueError("celltype_keys must contain at least one key of adata.obs.")
    rng = np.random.default_rng(random_state)
    # Most detailed annotation level = key with the most distinct cell types.
    # max() keeps the first key on ties.
    biggest_key = max(celltype_keys, key=lambda key: adata.obs[key].nunique())
    labels = adata.obs[biggest_key].values
    counts = adata.obs[biggest_key].value_counts()
    if target_per_class is None:
        target_per_class = round(len(adata) / adata.obs[biggest_key].nunique())
    # Work on a dense matrix
    X = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X
    X_blocks = [X]
    obs_blocks = [adata.obs.copy()]
    for celltype, n in counts.items():
        if n >= target_per_class:
            continue  # no need to oversample this type
        idx = np.where(labels == celltype)[0]
        X_class = X[idx]
        # Class-specific target, capped at max_oversample_factor times the class size
        num_target = min(target_per_class, int(max_oversample_factor * n))
        # Very small classes may only grow up to min_oversample_cells cells
        if n < min_oversample_cells:
            num_target = min(num_target, min_oversample_cells)
        if num_target <= n:
            continue  # nothing to do
        X_extended = cells_generator(
            X_class=X_class,
            num_target=num_target,
            random_state=random_state
        )
        n_current = X_class.shape[0]
        X_new = X_extended[n_current:]
        n_new = X_new.shape[0]
        if n_new == 0:
            continue
        # Metadata of synthetic cells: obs rows of randomly chosen original cells of this type
        base_idx = rng.integers(0, n_current, size=n_new)
        X_blocks.append(X_new)
        obs_blocks.append(adata.obs.iloc[idx[base_idx]].copy())
    X_oversampled = np.vstack(X_blocks)
    obs_bal = pd.concat(obs_blocks, axis=0)
    return anndata.AnnData(X=sparse.csr_matrix(X_oversampled), obs=obs_bal, var=adata.var.copy())
# Function to undersample specific cell types
def undersample(
    adata,
    celltype_keys,
    target_per_class = None,
    min_keep_frac = 0.5,
    random_state = 0
):
    """
    Undersample large cell types in an AnnData object.

    Cell types are taken from the most detailed annotation level, i.e. the key in
    ``celltype_keys`` with the most distinct values (the first such key on ties).
    Each cell type keeps ``min(max(target_per_class, min_keep), n)`` randomly chosen
    cells, where ``min_keep = int(max(min_keep_frac * n, 1))``. Cells whose label
    at that level is missing (NaN) are not counted by ``value_counts`` and are
    therefore not kept.

    Parameters
    ----------
    adata: AnnData
        Input dataset to be undersampled.
    celltype_keys: list or str
        Cell type annotation keys in adata.obs (a single key may be given as str).
        Example: ['lineage', 'cell type', 'cell state']
    target_per_class: int (default: None)
        Global target per cell type. If None, computed as the average number of
        cells per cell type in the annotation level with the most cell types.
    min_keep_frac: float (default: 0.5)
        Lower bound on fraction of original cell type size preserved after undersampling.
    random_state: int (default: 0)
        Seed for random number generators to ensure reproducibility.

    Returns
    -------
    AnnData
        Copy of the selected cells, in their original order. All obs columns
        (and hence the cell type hierarchy) are preserved.

    Raises
    ------
    ValueError
        If ``celltype_keys`` is empty.
    """
    if isinstance(celltype_keys, str):
        celltype_keys = [celltype_keys]
    if len(celltype_keys) == 0:
        raise ValueError("celltype_keys must contain at least one key of adata.obs.")
    rng = np.random.default_rng(random_state)
    # Most detailed annotation level = key with the most distinct cell types.
    # max() keeps the first key on ties.
    biggest_key = max(celltype_keys, key=lambda key: adata.obs[key].nunique())
    labels = adata.obs[biggest_key].values
    counts = adata.obs[biggest_key].value_counts()
    if target_per_class is None:
        target_per_class = round(len(adata) / adata.obs[biggest_key].nunique())
    keep_indices_list = []
    for celltype, n in counts.items():
        idx = np.where(labels == celltype)[0]
        # Minimum acceptable number of cells after undersampling
        min_keep = int(max(min_keep_frac * n, 1))
        # Actual target: no less than min_keep and no more than n
        target = min(max(target_per_class, min_keep), n)
        if n > target:
            keep_indices_list.append(rng.choice(idx, size=target, replace=False))
        else:
            keep_indices_list.append(idx)
    if keep_indices_list:
        keep_indices = np.sort(np.concatenate(keep_indices_list))
    else:
        keep_indices = np.array([], dtype=int)
    return adata[keep_indices].copy()
# Function to balance cell types in adata
def balance(
    adata,
    celltype_keys,
    min_keep_frac = 0.5,
    max_oversample_factor = 7.0,
    min_oversample_cells = 5,
    random_state = 0
):
    """
    Balance cell types in an AnnData object.

    The target size per cell type is the average number of cells per cell type in
    the most detailed annotation level, i.e. the key in ``celltype_keys`` with the
    most distinct values (the first such key on ties). Large cell types are first
    undersampled to that target with ``undersample``. Small cell types are then
    oversampled towards it with ``oversample``.
    Synthetic values live in the same space as ``adata.X`` (counts in give
    count-scaled values out; normalized data in gives normalized values out).

    Parameters
    ----------
    adata: AnnData
        Input dataset to be balanced.
    celltype_keys: list or str
        Cell type annotation keys in adata.obs (a single key may be given as str).
        Example: ['lineage', 'cell type', 'cell state']
    min_keep_frac: float (default: 0.5)
        Safety lower bound for the fraction of original cells retained in large classes.
    max_oversample_factor: float (default: 7.0)
        Upper bound on how much a small class may be expanded relative to its original size.
    min_oversample_cells: int (default: 5)
        Cell types smaller than this are grown to at most this many cells.
    random_state: int (default: 0)
        Seed for random number generators to ensure reproducibility.

    Returns
    -------
    AnnData
        Balanced dataset with preserved obs annotations and var features.

    Raises
    ------
    ValueError
        If ``celltype_keys`` is empty.
    """
    if isinstance(celltype_keys, str):
        celltype_keys = [celltype_keys]
    if len(celltype_keys) == 0:
        raise ValueError("celltype_keys must contain at least one key of adata.obs.")
    # Most detailed annotation level = key with the most distinct cell types.
    # max() keeps the first key on ties.
    biggest_key = max(celltype_keys, key=lambda key: adata.obs[key].nunique())
    # Computed once on the input so both steps aim at the same target.
    target_per_class = round(len(adata) / adata.obs[biggest_key].nunique())
    # Step 1: undersample big classes
    adata_undersampled = undersample(
        adata=adata,
        celltype_keys=celltype_keys,
        target_per_class=target_per_class,
        min_keep_frac=min_keep_frac,
        random_state=random_state,
    )
    # Step 2: oversample small classes
    adata_balanced = oversample(
        adata=adata_undersampled,
        celltype_keys=celltype_keys,
        target_per_class=target_per_class,
        max_oversample_factor=max_oversample_factor,
        min_oversample_cells=min_oversample_cells,
        random_state=random_state
    )
    return adata_balanced
# Function to create a full classification report
def report_classif_full(
    adata,
    celltype = None,
    pred_celltype = None,
    save_report = False,
    report_name = 'report.csv',
    save_path = '',
    ndigits = 4
):
    '''
    Build a per-class classification report for predicted cell types.

    For every class the report lists precision, recall (also called sensitivity),
    specificity, f1-score, geometric mean and index balanced accuracy of the
    geometric mean (computed with imbalanced-learn) and the number of cells.
    It then appends 'macro avg', 'weighted avg' (weighted by the number of
    cells), 'Accuracy' and 'Balanced accuracy' (unweighted mean recall) rows.
    Use it after prediction on an annotated test dataset to validate a model.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix. Previously annotated test dataset not used for model tuning or training.
    celltype : str, (default: None)
        Level of cell annotation to show metrics. Key in adata.obs dataframe.
    pred_celltype : str, (default: None)
        Predicted level of cell annotation. Key in adata.obs dataframe.
    save_report : bool, (default: False)
        Save report as csv file or not.
    report_name : str, (default: 'report.csv')
        Name of a file to save report.
    save_path : path object
        Path to a folder to save report.
    ndigits : int (default: 4)
        Round a number to a given precision in decimal digits.

    Returns
    -------
    pandas.DataFrame
        The report. Summary rows hold '' in columns where no value applies.
    '''
    y_true = adata.obs[celltype].to_numpy()
    y_pred = adata.obs[pred_celltype].to_numpy()
    # imbalanced-learn returns per-class entries plus overall averages; the
    # averages are dropped and recomputed below.
    per_class = pd.DataFrame(
        metrics.classification_report_imbalanced(y_true, y_pred, output_dict = True, digits = ndigits)
    )
    per_class = per_class.drop(
        columns = ['avg_pre', 'avg_rec', 'avg_spe', 'avg_f1', 'avg_geo', 'avg_iba', 'total_support']
    )
    report = per_class.transpose()
    report['sup'] = report['sup'].astype('int')
    report = report.rename(columns = {'pre': 'precision',
                                      'rec': 'recall/sensitivity',
                                      'spe': 'specificity',
                                      'f1': 'f1-score',
                                      'geo': 'geometric mean',
                                      'iba': 'index balanced accuracy',
                                      'sup': 'number of cells'
                                      })
    metric_columns = ['precision', 'recall/sensitivity', 'specificity',
                      'f1-score', 'geometric mean', 'index balanced accuracy']
    padding = [''] * 6
    # Balanced accuracy = unweighted mean recall over the per-class rows only
    balanced_accuracy_row = [round(np.mean(report['recall/sensitivity']), ndigits = ndigits)] + padding
    report.loc['macro avg'] = [
        round(np.mean(report[column]), ndigits = ndigits) for column in metric_columns
    ] + ['']
    # Weighted by number of cells; [:-1] leaves out the 'macro avg' row just added
    cells = report['number of cells']
    total_cells = np.sum(cells[:-1])
    report.loc['weighted avg'] = [
        round(np.sum(report[column][:-1] * cells) / total_cells, ndigits = ndigits)
        for column in metric_columns
    ] + ['']
    for column in metric_columns:
        report[column] = round(report[column], ndigits = ndigits)
    accuracy_row = [round(metrics_.accuracy_score(y_true, y_pred), ndigits = ndigits)] + padding
    report.loc['Accuracy'] = accuracy_row
    report.loc['Balanced accuracy'] = balanced_accuracy_row
    if save_report:
        report.to_csv(os.path.join(save_path, report_name).replace("\\","/"))
        print('Successfully saved report')
        print()
    return report
# Function to find prediction status (correct or incorrect prediction)
def pred_status(
    adata,
    celltype = None,
    pred_celltype = None,
    key_added = 'pred_status'
):
    '''
    Mark every cell as correctly or incorrectly predicted.

    Writes 'correct' / 'incorrect' to ``adata.obs[key_added]`` depending on whether
    ``adata.obs[celltype]`` equals ``adata.obs[pred_celltype]``. It also stores the
    plotting colors ['#3A3AFF', '#FF3737'] in ``adata.uns[key_added + '_colors']``.
    Modifies ``adata`` in place and returns None.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix with observed and predicted annotations in adata.obs.
    celltype : str, (default: None)
        Cell annotation. Key in adata.obs dataframe.
    pred_celltype : str, (default: None)
        Predicted cell annotation. Key in adata.obs dataframe.
    key_added : str, (default: 'pred_status')
        Key to add in adata.obs
    '''
    matches = adata.obs[celltype] == adata.obs[pred_celltype]
    adata.obs[key_added] = matches.astype('str').replace({'True': 'correct', 'False': 'incorrect'})
    adata.uns[key_added + '_colors'] = ['#3A3AFF', '#FF3737']
# Function to visualize cell type predictions using a confusion matrix
def conf_matrix(
    adata,
    celltype = None,
    pred_celltype = None,
    fmt = ".2f",
    annot = True,
    cmap = "Blues",
    ndigits_metrics = 3,
    grid = False,
    **kwargs
):
    '''
    Plot a row-normalized confusion matrix to evaluate a cell type classification.

    Rows are observed cell types and columns are all labels that occur in either
    the observed or the predicted annotation. Each row shows the fraction of cells
    of that observed type assigned to each label. Accuracy and balanced accuracy
    (mean per-class recall as reported by imbalanced-learn) are appended to the
    x-axis label.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix with observed and predicted annotations in adata.obs.
    celltype : str, (default: None)
        Cell annotation. Key in adata.obs dataframe.
    pred_celltype : str, (default: None)
        Predicted cell annotation. Key in adata.obs dataframe.
    fmt : str, optional
        String formatting code to use when adding annotations.
    annot : bool or rectangular dataset, optional
        If True, write the data value in each cell. If an array-like with the
        same shape as ``data``, then use this to annotate the heatmap instead
        of the data. Note that DataFrames will match on position, not index.
    cmap : matplotlib colormap name or object, or list of colors, optional
        The mapping from data values to color space. If not provided, the
        default will depend on whether ``center`` is set.
    ndigits_metrics : int (default: 3)
        Round accuracy and balanced accuracy to a given precision in decimal digits.
    grid : bool (default: False)
        Show or hide grid lines.
    **kwargs: other keyword arguments
        All other keyword arguments are passed to
        sns.heatmap
    '''
    y_true = adata.obs[celltype].to_numpy()
    y_pred = adata.obs[pred_celltype].to_numpy()
    # Use the sorted union of observed and predicted labels explicitly so that the
    # row/column names always match the matrix, even when the model predicts a
    # label that never occurs in the observed annotation.
    labels = np.unique(np.concatenate([y_true, y_pred]))
    cm = metrics_.confusion_matrix(y_true, y_pred, labels=labels)
    # Calculate accuracy
    accuracy = round(np.trace(cm) / float(np.sum(cm)), ndigits = ndigits_metrics)
    accuracy = f"\n\nAccuracy={accuracy}"
    # Calculate balanced accuracy (mean recall over the per-class rows)
    report = pd.DataFrame(
        metrics.classification_report_imbalanced(y_true, y_pred, output_dict = True)
    )
    report = report.drop(
        columns = ['avg_pre', 'avg_rec', 'avg_spe', 'avg_f1', 'avg_geo', 'avg_iba', 'total_support']
    ).transpose()
    bal_accuracy = round(np.mean(report['rec']), ndigits = ndigits_metrics)
    bal_accuracy = f"\n\nBalanced accuracy={bal_accuracy}"
    # Row-normalize. Labels that occur only in the predictions have an all-zero
    # row and are dropped from the rows; they remain as columns.
    support = cm.sum(axis=1)
    observed = support > 0
    cm_norm = cm[observed].astype("float") / support[observed][:, np.newaxis]
    cm_df = pd.DataFrame(cm_norm, index=labels[observed], columns=labels)
    plt.grid(grid)
    sns.heatmap(cm_df, annot=annot, fmt=fmt, cmap=cmap, **kwargs)
    plt.xlabel('Predicted' + accuracy + bal_accuracy)
    plt.ylabel("Observed")
# Function to calculate sensitivity and specificity of a trained model
def report_classif_sens_spec(
    adata,
    celltype = None,
    pred_celltype = None,
    save_report = False,
    report_name = 'report_sens_spec.csv',
    save_path = '',
    ndigits = 3
):
    '''
    Report recall (also called sensitivity) and specificity of predicted cell types.

    Metrics are computed per label over the sorted union of observed and predicted
    labels, followed by 'macro avg' and 'weighted avg' (weighted by the number of
    cells) rows. Use it after prediction on an annotated test dataset to validate
    a model.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix. Previously annotated test dataset not used for model tuning or training.
    celltype : str, (default: None)
        Level of cell annotation to show metrics. Key in adata.obs dataframe.
    pred_celltype : str, (default: None)
        Predicted level of cell annotation. Key in adata.obs dataframe.
    save_report : bool, (default: False)
        Save report as csv file or not.
    report_name : str, (default: 'report_sens_spec.csv')
        Name of a file to save report.
    save_path : path object
        Path to a folder to save report.
    ndigits : int (default: 3)
        Round a number to a given precision in decimal digits.

    Returns
    -------
    pandas.DataFrame
        The report. Summary rows hold '' in the 'number of cells' column.
    '''
    y_true = adata.obs[celltype].to_numpy()
    y_pred = adata.obs[pred_celltype].to_numpy()
    # Pass the labels explicitly: the metrics come back in this order, so the row
    # names always match, even if a label occurs only in the predictions.
    labels = np.unique(np.concatenate([y_true, y_pred]))
    sensitivity, specificity, support = metrics.sensitivity_specificity_support(
        y_true, y_pred, labels = labels
    )
    report = pd.DataFrame(
        {'recall/sensitivity': sensitivity,
         'specificity': specificity,
         'number of cells': support},
        index = labels.tolist()
    )
    # Round data
    report['recall/sensitivity'] = round(report['recall/sensitivity'], ndigits = ndigits)
    report['specificity'] = round(report['specificity'], ndigits = ndigits)
    report['number of cells'] = report['number of cells'].astype('int')
    # Add avg rows ([:-1] leaves out the 'macro avg' row in the weighted average)
    report.loc['macro avg'] = [round(np.mean(report['recall/sensitivity']), ndigits = ndigits),
                               round(np.mean(report['specificity']), ndigits = ndigits), '']
    report.loc['weighted avg'] = [round(np.sum(report['recall/sensitivity'][:-1] * report['number of cells'])/np.sum(report['number of cells'][:-1]), ndigits = ndigits),
                                  round(np.sum(report['specificity'][:-1] * report['number of cells'])/np.sum(report['number of cells'][:-1]), ndigits = ndigits), '']
    # Save report to .csv
    if save_report:
        report.to_csv(os.path.join(save_path, report_name).replace("\\","/"))
        print('Successfully saved report')
        print()
    return report
# Function to calculate regression metrics of a trained model
def report_reg(
    adata_prot,
    adata_pred_prot,
    multioutput = 'uniform_average',
    save_report = False,
    report_name = 'report_regression.csv',
    save_path = '',
    ndigits = 3
):
    '''
    Report several metrics of cell surface protein prediction.

    Root mean squared error (RMSE), mean absolute error (MeanAE) and median
    absolute error (MedianAE): lower value - better prediction.
    Coefficient of determination (R² score) and explained variance score (EVS):
    higher value - better prediction.
    Rows of ``adata_prot.X`` and ``adata_pred_prot.X`` are compared position by
    position (cells x proteins), so both objects must share the same cell and
    protein order.

    Parameters
    ----------
    adata_prot : AnnData
        Annotated data matrix with proteins. Test dataset not used for model tuning or training.
        X may be sparse or dense.
    adata_pred_prot : AnnData
        Annotated data matrix with predicted proteins. X may be sparse or dense.
    multioutput : {'raw_values', 'uniform_average'} or array-like of shape (n_outputs,), (default='uniform_average')
        Defines aggregating of multiple output values. Array-like value defines weights used to average errors.
        'raw_values' : Returns a full set of errors in case of multioutput input.
        'uniform_average' : Errors of all outputs are averaged with uniform weight.
    save_report : bool, (default: False)
        Save report as csv file or not.
    report_name : str, (default: 'report_regression.csv')
        Name of a file to save report.
    save_path : path object
        Path to a folder to save report.
    ndigits : int (default: 3)
        Round a number to a given precision in decimal digits.

    Returns
    -------
    pandas.DataFrame
        A 'score' row with EVS, r2_score, RMSE, MedianAE and MeanAE, plus two rows
        describing whether higher or lower values are better.
    '''
    def _dense(X):
        # Accept both sparse matrices and dense arrays
        return X.toarray() if hasattr(X, 'toarray') else np.asarray(X)

    y_true = _dense(adata_prot.X)
    y_pred = _dense(adata_pred_prot.X)
    scores = {
        'EVS': metrics_.explained_variance_score(y_true = y_true, y_pred = y_pred, multioutput = multioutput),
        'r2_score': metrics_.r2_score(y_true = y_true, y_pred = y_pred, multioutput = multioutput),
        'RMSE': metrics_.root_mean_squared_error(y_true = y_true, y_pred = y_pred, multioutput = multioutput),
        'MedianAE': metrics_.median_absolute_error(y_true = y_true, y_pred = y_pred, multioutput = multioutput),
        'MeanAE': metrics_.mean_absolute_error(y_true = y_true, y_pred = y_pred, multioutput = multioutput),
    }
    # Create report (one column per metric, rounded)
    report = pd.DataFrame()
    for name, value in scores.items():
        report[name] = [round(value, ndigits = ndigits)]
    report.loc['EVS/r2_score'] = ['higher value - better prediction', '', '', '', '']
    report.loc['RMSE/MedianAE/MeanAE'] = ['lower value - better prediction', '', '', '', '']
    report = report.rename(index = {0: "score"})
    # Save report to .csv
    if save_report:
        report.to_csv(os.path.join(save_path, report_name).replace("\\","/"))
        print('Successfully saved report')
        print()
    return report
# Function to compute per-cell regression status
def regres_status(
    adata_prot,
    adata_pred_prot,
    metric = 'RMSE'
):
    '''
    Compute a per-cell regression metric to visualize on UMAP.

    The chosen metric is computed across proteins separately for every cell. It is
    stored in ``adata_pred_prot.obs['regres_status_' + metric]`` (modified in place).
    Rows of both matrices are compared position by position, so both objects must
    share the same cell and protein order.

    Parameters
    ----------
    adata_prot : AnnData
        Annotated data matrix with proteins. Test dataset not used for model tuning or training.
        X may be sparse or dense.
    adata_pred_prot : AnnData
        Annotated data matrix with predicted proteins. X may be sparse or dense.
    metric : str (default: 'RMSE')
        Metric used for regression status calculation.
        Available metrics: RMSE, MeanAE, MedianAE, EVS, r2_score.
        Root mean squared error (RMSE), mean absolute error (MeanAE), median absolute error (MedianAE) : lower value - better prediction.
        Coefficient of determination (R² score), explained variance score (EVS) : higher value - better prediction

    Raises
    ------
    ValueError
        If ``metric`` is not one of the available metrics.
    '''
    available_metrics = ('RMSE', 'MeanAE', 'MedianAE', 'EVS', 'r2_score')
    # Validate before touching the data so an unknown metric fails with a clear message
    if metric not in available_metrics:
        raise ValueError(
            f"Unknown metric '{metric}'. Available metrics: {', '.join(available_metrics)}."
        )

    def _proteins_by_cell(adata):
        # Dense matrix transposed to (proteins, cells): with multioutput='raw_values'
        # every cell becomes one output, giving one score per cell.
        X = adata.X.toarray() if hasattr(adata.X, 'toarray') else np.asarray(adata.X)
        return X.T

    y_true = _proteins_by_cell(adata_prot)
    y_pred = _proteins_by_cell(adata_pred_prot)
    metric_functions = {
        'RMSE': metrics_.root_mean_squared_error,
        'MeanAE': metrics_.mean_absolute_error,
        'MedianAE': metrics_.median_absolute_error,
        'EVS': metrics_.explained_variance_score,
        'r2_score': metrics_.r2_score,
    }
    scores = metric_functions[metric](y_true = y_true, y_pred = y_pred, multioutput = 'raw_values')
    # Add metric to predicted adata
    adata_pred_prot.obs['regres_status_' + metric] = scores.copy()
# Function to calculate the Pearson correlation coefficient per protein
def pearson_coef(
    adata,
    adata_pred,
    feature,
    feature_pred,
    ndigits = 3,
    print_res = False
):
    '''
    Compute the Pearson correlation coefficient between a feature and its prediction.

    Varies between -1 and +1. Values close to 1 indicate strong positive
    correlation, and values close to -1 indicate strong negative correlation.

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix with features. Test dataset not used for model tuning or training.
    adata_pred: AnnData
        Annotated data matrix with predicted features (same cell order as adata).
    feature: str
        Name of the feature in adata.
    feature_pred: str
        Name of the predicted feature in adata_pred.
    ndigits: int (default: 3)
        Round a number to a given precision in decimal digits.
    print_res: bool (default: False)
        Print results or not.

    Returns
    -------
    dict
        {'Pearson coefficient': coefficient rounded to ndigits,
         'p-value': p-value formatted as a string with 3 decimals in scientific notation}
    '''
    def _feature_values(data, name):
        # One feature column as a flat 1-D array (sparse or dense X)
        X = data[:, name].X
        X = X.toarray() if hasattr(X, 'toarray') else X
        return np.asarray(X).ravel()

    feature_values = _feature_values(adata, feature)
    feature_pred_values = _feature_values(adata_pred, feature_pred)
    # Tuple unpacking works for every scipy result type (plain tuple or result object)
    statistic, pvalue = stats.pearsonr(feature_values, feature_pred_values)
    coefficient = round(float(statistic), ndigits = ndigits)
    pvalue_str = "{:.3e}".format(pvalue)
    results = {'Pearson coefficient' : coefficient,
               'p-value' : pvalue_str}
    if print_res:
        print(f'Pearson coefficient = {coefficient}, p-value = {pvalue_str}')
    return results
# Function to calculate the Spearman correlation coefficient per protein
def spearman_coef(
    adata,
    adata_pred,
    feature,
    feature_pred,
    ndigits = 3,
    print_res = False
):
    '''
    Compute the Spearman rank correlation coefficient between a feature and its prediction.

    Varies between -1 and +1. Values close to 1 indicate strong positive
    correlation, and values close to -1 indicate strong negative correlation.

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix with features. Test dataset not used for model tuning or training.
    adata_pred: AnnData
        Annotated data matrix with predicted features (same cell order as adata).
    feature: str
        Name of the feature in adata.
    feature_pred: str
        Name of the predicted feature in adata_pred.
    ndigits: int (default: 3)
        Round a number to a given precision in decimal digits.
    print_res: bool (default: False)
        Print results or not.

    Returns
    -------
    dict
        {'Spearman coefficient': coefficient rounded to ndigits,
         'p-value': p-value formatted as a string with 3 decimals in scientific notation}
    '''
    def _feature_values(data, name):
        # One feature column as a flat 1-D array (sparse or dense X)
        X = data[:, name].X
        X = X.toarray() if hasattr(X, 'toarray') else X
        return np.asarray(X).ravel()

    feature_values = _feature_values(adata, feature)
    feature_pred_values = _feature_values(adata_pred, feature_pred)
    # Tuple unpacking works for every scipy result type (namedtuple or result object)
    statistic, pvalue = stats.spearmanr(feature_values, feature_pred_values)
    coefficient = round(float(statistic), ndigits = ndigits)
    pvalue_str = "{:.3e}".format(pvalue)
    results = {'Spearman coefficient' : coefficient,
               'p-value' : pvalue_str}
    if print_res:
        print(f'Spearman coefficient = {coefficient}, p-value = {pvalue_str}')
    return results
# Function to calculate Kendall's tau correlation coefficient per protein
def kendalltau_coef(
    adata,
    adata_pred,
    feature,
    feature_pred,
    ndigits = 3,
    print_res = False
):
    '''
    Compute Kendall's tau, a rank correlation measure, between a feature and its prediction.

    Varies between -1 and +1. Values close to 1 indicate strong agreement, and
    values close to -1 indicate strong disagreement.

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix with features. Test dataset not used for model tuning or training.
    adata_pred: AnnData
        Annotated data matrix with predicted features (same cell order as adata).
    feature: str
        Name of the feature in adata.
    feature_pred: str
        Name of the predicted feature in adata_pred.
    ndigits: int (default: 3)
        Round a number to a given precision in decimal digits.
    print_res: bool (default: False)
        Print results or not.

    Returns
    -------
    dict
        {'Kendall Tau': coefficient rounded to ndigits,
         'p-value': p-value formatted as a string with 3 decimals in scientific notation}
    '''
    def _feature_values(data, name):
        # One feature column as a flat 1-D array (sparse or dense X)
        X = data[:, name].X
        X = X.toarray() if hasattr(X, 'toarray') else X
        return np.asarray(X).ravel()

    feature_values = _feature_values(adata, feature)
    feature_pred_values = _feature_values(adata_pred, feature_pred)
    # Calculate Kendall's tau; tuple unpacking works for every scipy result type
    statistic, pvalue = stats.kendalltau(feature_values, feature_pred_values)
    coefficient = round(float(statistic), ndigits = ndigits)
    pvalue_str = "{:.3e}".format(pvalue)
    results = {'Kendall Tau' : coefficient,
               'p-value' : pvalue_str}
    if print_res:
        print(f'Kendall Tau = {coefficient}, p-value = {pvalue_str}')
    return results
# Function to count cell types in samples of an integrated dataset
def cell_counter(
    adata,
    sample = None,
    celltype = None
):
    '''
    Count cell types in every sample. Useful for integrated/concatenated datasets.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    sample : str, (default: None)
        Samples names key in adata.obs dataframe.
    celltype : str, (default: None)
        Level of cell annotation to count. Key in adata.obs dataframe.

    Returns
    -------
    pandas.DataFrame
        Rows are cell types and columns are samples (in sorted order). Cell types
        absent from a sample are NaN (outer join). An empty DataFrame is returned
        when adata has no cells.
    '''
    columns = []
    for name in np.unique(adata.obs[sample]):
        # Cell type counts among the cells of this sample; the Series is named after
        # the sample so it becomes that column (independent of the pandas version).
        counts = adata[adata.obs[sample] == name].obs[celltype].value_counts()
        columns.append(counts.rename(name))
    if not columns:
        return pd.DataFrame()
    return pd.concat(columns, axis = 1, join = 'outer')
# Function to get SHAP explanations
def explain(
    adata,
    layer = None,
    celltype = None,
    path_model = '',
    num_cells = 100,
    random_state = 0,
    max_evals = 2000,
    batch_size = 256,
    prefix = 'pred_',
    device = 'auto',
    verbose = True
):
    '''
    Identify the genes that are most important for predicting a cell type with a model.

    Cells predicted as ``celltype`` (column ``prefix + level`` in adata.obs, using the
    most detailed model level that knows this cell type) are subsampled with
    ``get_frac``. Their genes are aligned to the model's genes (missing genes are
    zero) and min-max scaled per gene. The model's predicted class index for that
    level is then explained with SHAP.

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix with predicted cell types in adata.obs.
    layer: str (default: None)
        If specified and present in adata.layers, use adata.layers[layer] for
        expression values instead of adata.X.
        Otherwise adata.X is used.
    celltype: str (default: None)
        Specific cell type in annotation to calculate explanations.
    path_model: str, path object
        Path to the folder containing the trained scAdam model.
    num_cells: int (default: 100)
        Number of cells to make explanations.
        Increasing the number of cells will lead to an increase in computation time.
    random_state: int (default: 0)
        Controls the random selection of cells from dataset and explainer for reproducibility.
        Pass an int for reproducible output across multiple function calls.
    max_evals: int (default: 2000)
        The max_evals parameter in SHAP is a tunable setting that significantly affects both
        the accuracy and computational efficiency of SHAP value calculations.
        By adjusting this parameter, you can balance between obtaining detailed explanations
        and managing computational resources effectively.
        For a larger number of genes, it is necessary to increase max_evals.
    batch_size: int, (default: 256)
        Number of cells per batch when running the model.
    prefix: str (default: 'pred_')
        Prefix of predicted cell type columns in adata.obs.
    device: str (default: 'auto')
        Device to run the model on ('cpu', 'cuda'). Set 'auto' for automatic selection.
    verbose: bool (default: True)
        Print progress messages (also passed to scadam.load_model).

    Returns
    -------
    shap.Explanation
        Explanations of the specific cell type.

    Raises
    ------
    FileNotFoundError
        If path_model does not contain 'model_v2.pth'.
    ValueError
        If the model lacks var_names/celltype_keys, the cell type is unknown to the
        model, or no cell in adata is predicted as that cell type.
    '''
    model_file = os.path.join(path_model, 'model_v2.pth')
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"No trained scAdam model found: '{model_file}' does not exist.")
    # Load on the requested device so that model and input batches agree
    model = scadam.load_model(path_model, device = device, verbose=verbose)
    # Device used for the input batches
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device=='auto' else device
    # Validate the model before reading its attributes
    if not hasattr(model, 'var_names') or model.var_names is None:
        raise ValueError("Model does not have var_names attribute. Train the model first.")
    if not hasattr(model, 'celltype_keys') or model.celltype_keys is None:
        raise ValueError("Model does not have celltype_keys attribute. Train the model first.")
    # Most detailed annotation level whose encoder knows this cell type
    celltype_key_used = None
    for celltype_key in reversed(list(model.celltype_keys)):
        if celltype in model.celltype_encoders[celltype_key]['celltype_encoder'].classes_.tolist():
            celltype_key_used = celltype_key
            break
    if celltype_key_used is None:
        raise ValueError('Wrong name of cell type! There is no such cell type name in any prediction level.')
    pred_column = f"{prefix}{celltype_key_used}"
    adata_cell_type = adata[adata.obs[pred_column] == celltype].copy()
    if verbose:
        print(f"Cell type '{celltype}' was successfully selected from '{prefix}{celltype_key_used}'")
    if adata_cell_type.n_obs == 0:
        raise ValueError(f"No cells are predicted as '{celltype}' in adata.obs['{pred_column}'].")
    adata_cell_type = get_frac(adata_cell_type, fraction=num_cells, random_state=random_state)
    # Get gene expression data
    if layer is not None and layer in adata_cell_type.layers:
        X = adata_cell_type.layers[layer]
    else:
        X = adata_cell_type.X
    # Convert sparse to dense if needed
    if hasattr(X, 'toarray'):
        X = X.toarray()
    X = np.asarray(X)
    # Align features: genes in the same order as during training, zeros for missing genes.
    # Sized by the cells actually selected, which may be fewer than num_cells.
    model_var_names = model.var_names
    X_aligned = np.zeros((X.shape[0], len(model_var_names)), dtype=np.float32)
    # First occurrence of every gene name in the data (same as list.index)
    gene_to_column = {}
    for column_idx, gene in enumerate(adata_cell_type.var_names):
        gene_to_column.setdefault(gene, column_idx)
    pairs = [(i, gene_to_column[gene]) for i, gene in enumerate(model_var_names) if gene in gene_to_column]
    if pairs:
        model_columns, data_columns = (list(cols) for cols in zip(*pairs))
        X_aligned[:, model_columns] = X[:, data_columns]
    # Per-gene min-max scaling over the selected cells (intended to match training-time preprocessing)
    X_aligned = np.maximum(X_aligned, 0)  # Remove negative values
    X_norm = (X_aligned - X_aligned.min(0)) / (np.ptp(X_aligned, axis=0) + 1e-10)
    # Check for NaN/Inf
    if np.isnan(X_norm).any() or np.isinf(X_norm).any():
        warnings.warn("NaN or Inf found in normalized data. Replacing with zeros.")
        X_norm = np.nan_to_num(X_norm, nan=0.0, posinf=0.0, neginf=0.0)
    # Model output key of the level being explained
    level_key = f'level_{list(model.celltype_keys).index(celltype_key_used)}'

    def explainer_model(
        X_norm,
        model = model
    ):
        # SHAP calls this with arrays of (masked) cells; returns the predicted class
        # index of the explained level as float64.
        dataset = torch.utils.data.TensorDataset(torch.as_tensor(np.asarray(X_norm), dtype=torch.float32))
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False
        )
        predictions = []
        with torch.no_grad():
            for (batch_x,) in loader:
                outputs = model(batch_x.to(device))
                predictions.append(outputs[level_key]['logits'].argmax(dim=1).cpu().numpy())
        return np.concatenate(predictions).astype('float64')

    # Create explainer using explainer function
    explainer = shap.Explainer(
        explainer_model,
        masker = X_norm,
        feature_names=model_var_names,
        seed=random_state
    )
    # Compute SHAP values
    explanations = explainer(
        X_norm,
        max_evals=max_evals
    )
    if verbose:
        print(f'The explanations for "{celltype}" have been completed')
    return explanations
# Function to get gene importance dataframe from explanations
[docs]
def feature_importance_df(
    explanations
):
    '''
    Get dataframe with gene importances for a specific cell type.

    Parameters
    ----------
    explanations : shap.Explanation
        Output of function explain(). `explanations.values` holds the SHAP
        values (rows are averaged over, columns correspond to genes) and
        `explanations.feature_names` holds the gene names in column order.

    Returns
    -------
    pandas.DataFrame
        Columns 'gene_name' and 'gene_importance', where the importance is the
        mean absolute SHAP value of the gene across all explained cells.
        Rows are sorted by importance in descending order; the original
        positional index is kept.
    '''
    shap_values = explanations.values
    # Mean absolute contribution of each gene (average over axis 0 = cells)
    vals = np.abs(shap_values).mean(0)
    feature_importance = pd.DataFrame(
        list(zip(explanations.feature_names, vals)),
        columns=['gene_name', 'gene_importance'],
    )
    feature_importance.sort_values(by=['gene_importance'], ascending=False, inplace=True)
    return feature_importance
# Function to get fraction of AnnData
[docs]
def get_frac(
    adata = None,
    path = None,
    path_save = None,
    stratify = None,
    fraction = 0.1,
    shuffle = True,
    random_state = 0
):
    '''
    Get fraction of AnnData object. Specify AnnData object OR path to AnnData.
    The function returns a portion of the AnnData object while maintaining the ratio of cell types.

    Parameters
    ----------
    adata: AnnData (default: None)
        Annotated data matrix. If None, the AnnData is read from `path` (e.g., "/Data/adata.h5ad").
        If both are given, `adata` is used.
    path: str, path object (default: None)
        Path to the AnnData object if AnnData is not loaded into RAM.
        The file is opened in backed (read-only) mode and closed again before returning.
    path_save: str, path object (default: None)
        Directory in which the fraction is saved as adata_fraction.h5ad. If None, nothing is saved.
    stratify: str (default: None)
        Key in adata.obs dataframe.
        If specified, ensures the same key-based cell ratio as in the original adata.
    fraction: float or int (default: 0.1)
        If a float value is specified, it must be between 0.0 and 1.0 and represent the fraction of the dataset to be included in adata_fraction.
        If an integer is specified, it must be less than the number of cells in the dataset. This number of cells will be randomly allocated from adata.
    shuffle: bool (default: True)
        Whether or not to shuffle the data before subsetting.
        If shuffle = False, then stratify is not used to maintain the same ratio.
    random_state: int (default: 0)
        Controls the data shuffling and splitting.
        Pass an int for reproducible output across multiple function calls.

    Returns
    -------
    A fraction of the AnnData object (in memory) while maintaining the same ratio of cell types (if stratify is specified).
    This part of the AnnData object can also be saved as adata_fraction.h5ad.

    Raises
    ------
    ValueError
        If neither `adata` nor `path` is specified, or if `stratify` is not a column of adata.obs.
    '''
    opened_from_path = False
    if adata is None:
        if path is None:
            raise ValueError("Specify either adata or path to AnnData.")
        adata = sc.read_h5ad(path, backed = 'r')
        opened_from_path = True
    try:
        # Use `raise` (an `assert` on an exception instance is always truthy and never fires)
        if stratify is not None and stratify not in adata.obs.columns:
            raise ValueError(f"{stratify} not found in adata.obs")
        # sklearn cannot stratify an unshuffled split, so stratification is
        # only applied when shuffle=True (as documented above)
        stratify_labels = adata.obs[stratify] if (stratify is not None and shuffle) else None
        # The "test" part of the split is the requested fraction
        _, obs_names = train_test_split(
            adata.obs_names,
            test_size = fraction,
            stratify = stratify_labels,
            shuffle = shuffle,
            random_state = random_state
        )
        # Materialize the subset in memory
        if adata.isbacked:
            adata_fraction = adata[obs_names].to_memory()
        else:
            adata_fraction = adata[obs_names].copy()
    finally:
        # Only close a backing file this function opened itself
        if opened_from_path:
            adata.file.close()
    if path_save is not None:
        adata_fraction.write_h5ad(os.path.join(path_save, 'adata_fraction.h5ad'))
    return adata_fraction
# Function to get samples from AnnData
[docs]
def get_samples(
    adata = None,
    path = None,
    path_save = '',
    sample_col = None,
    samples = None
):
    '''
    Get samples from AnnData object. Specify path OR AnnData object.
    The function returns a new AnnData with selected samples.

    Parameters
    ----------
    adata : AnnData (default: None)
        Annotated data matrix. Ignored when `path` is given.
    path : str, path object (default: None)
        Path to the AnnData object if AnnData is not loaded into RAM.
        The file is opened in backed (read-only) mode and closed again before returning.
    path_save : str, path object or None (default: '')
        Directory in which the result is saved as adata_samples.h5ad.
        With the default '' the file is written to the current working
        directory; pass None to skip saving.
    sample_col : str, (default: None)
        Key in adata.obs dataframe with samples names.
    samples : list or str, (default: None)
        List of samples names (or a single sample name) in adata.obs[sample_col].

    Returns
    -------
    New in-memory AnnData with the selected samples from the AnnData object.
    This part of the AnnData object can also be saved as adata_samples.h5ad.

    Raises
    ------
    ValueError
        If neither `adata` nor `path` is specified.
    '''
    opened_from_path = False
    # A path takes precedence over an AnnData passed directly
    if path is not None:
        adata = sc.read_h5ad(path, backed = 'r')
        opened_from_path = True
    elif adata is None:
        raise ValueError("Specify either adata or path to AnnData.")
    # pandas.Series.isin needs a list-like; accept a single sample name too
    if isinstance(samples, str):
        samples = [samples]
    try:
        mask = adata.obs[sample_col].isin(samples)
        if adata.isbacked:
            adata_samples = adata[mask].to_memory()
        else:
            adata_samples = adata[mask].copy()
    finally:
        # Only close a backing file this function opened itself
        if opened_from_path:
            adata.file.close()
    if path_save is not None:
        adata_samples.write_h5ad(os.path.join(path_save, 'adata_samples.h5ad'))
    return adata_samples
# Function to find difference between clusters based on a specific scores
[docs]
def clust_diff(
adata,
groupby = None,
group1 = None,
group2 = None,
score1 = None,
score2 = None,
plot = True,
thresh = 0.02,
fill = True,
alpha = 0.5,
**kwargs
):
'''
Calculates metrics to Integral of absolute density difference and Mutual Information between two clusterings.
Each metric follows the principle that the higher the value, the better the clusters separate the selected scores.
adata : AnnData or MuData
Annotated data or Multimodal data.
groupby : str
The key of the grouping in AnnData.obs or MuData.obs.
group1 : str
Cluster in groupby.
group2 : str
Cluster in groupby.
score1 : str
Score 1 in AnnData.obs or MuData.obs calculated using scanpy.tl.score_genes.
score2 : str
Score 2 in AnnData.obs or MuData.obs calculated using scanpy.tl.score_genes.
plot : bool, (default: True)
Show kernel density estimate plot or not.
thresh : float in [0, 1], (default: 0.02)
Lowest iso-proportion level at which to draw a contour line.
fill : bool or None, (default: True)
If True, fill in the area under univariate density curves or between bivariate contours.
alpha : float or None, (default: 0.5)
Transparency of the rectangle and connector lines.
kwargs
Other keyword arguments are passed to seaborn.kdeplot.
Returns a plot and Integral of absolute density difference and Mutual Information between two clusterings.
'''
# Create DataFrame with groups from groupby
df = adata.obs[adata.obs[groupby].isin([group1, group2])][[groupby, score1, score2]]
df[groupby] = df[groupby].astype(object)
groups = df.groupby(groupby)
group1, group2 = list(groups)[:2]
# Overlap metric
def compute_overlap_metric(x1, y1, x2, y2, grid_size=500):
xmin = min(x1.min(), x2.min())
xmax = max(x1.max(), x2.max())
ymin = min(y1.min(), y2.min())
ymax = max(y1.max(), y2.max())
X, Y = np.mgrid[xmin:xmax:grid_size*1j, ymin:ymax:grid_size*1j]
positions = np.vstack([X.ravel(), Y.ravel()])
#KDE
values1 = np.vstack([x1, y1])
kernel1 = stats.gaussian_kde(values1)
Z1 = np.reshape(kernel1(positions).T, X.shape)
values2 = np.vstack([x2, y2])
kernel2 = stats.gaussian_kde(values2)
Z2 = np.reshape(kernel2(positions).T, X.shape)
# Absolute difference
diff = np.abs(Z1 - Z2)
overlap_metric = np.trapz(np.trapz(diff, axis=1), axis=0)
return overlap_metric
# Mutual information score
def compute_mutual_information(x, y, bins=20):
c_xy = np.histogram2d(x, y, bins)[0]
mi = metrics_.mutual_info_score(None, None, contingency=c_xy)
return mi
# Score1 and score2
x1, y1 = group1[1][[score1, score2]].values.T
x2, y2 = group2[1][[score1, score2]].values.T
metric = compute_overlap_metric(x1, y1, x2, y2)
mi = compute_mutual_information(df[score1], df[score2])
if plot:
plt.title(f"KDE Plot of {score1} vs {score2} \n"
f"Integral of absolute density difference: {metric:.1f}, \nMutual Information: {mi:.3f}", fontsize=12)
plt.xlabel(score1, fontsize=10)
plt.ylabel(score2, fontsize=10)
sns.kdeplot(
data=df,
x=score1,
y=score2,
hue=groupby,
thresh=thresh,
fill=fill,
alpha=alpha,
**kwargs
)
else:
print(f"Integral of absolute density difference: {metric:.1f}, \nMutual Information: {mi:.3f}")
return metric, mi
# Corrected function for plotting integration metrics results
def plot_results_table(
bm,
cmap_scores='bwr',
cmap_metrics='PRGn',
min_max_scale=False,
show=True,
save_dir=None,
dpi=300
):
"""
Plot the benchmarking results (scib_metrics) of data integration.
Parameters
----------
bm: benchmarking results from scib_metrics.
cmap_scores: str (default: 'bwr')
Color map for aggregate scores ('Batch correction', 'Bio conservation', 'Total').
cmap_metrics: str (default: 'PRGn')
Color map for individual metrics ('iLISI', 'KBET', ...).
min_max_scale:
Whether to min max scale the results.
show: bool (default: True)
Whether to show the plot.
save_dir: path (default: None)
The directory to save the plot to. If `None`, the plot is not saved.
dpi: int (default: 300)
Print resolution.
"""
_METRIC_TYPE = 'Metric Type'
num_embeds = len(bm._embedding_obsm_keys)
df = bm.get_results(min_max_scale=min_max_scale)
# Do not want to plot what kind of metric it is
plot_df = df.drop(_METRIC_TYPE, axis=0)
# Sort by total ''
if bm._batch_correction_metrics is not None and bm._bio_conservation_metrics is not None:
sort_col = "Total"
elif bm._batch_correction_metrics is not None:
sort_col = "Batch correction"
else:
sort_col = "Bio conservation"
plot_df = plot_df.sort_values(by=sort_col, ascending=False).astype(np.float64)
plot_df["Method"] = plot_df.index
# Split columns by metric type, using df as it doesn't have the new method col
score_cols = df.columns[df.loc[_METRIC_TYPE] == 'Aggregate score']
other_cols = df.columns[df.loc[_METRIC_TYPE] != 'Aggregate score']
column_definitions = [
ColumnDefinition("Method", width=1.5, textprops={"ha": "left", "weight": "bold"}),
]
# Circles for the metric values
column_definitions += [
ColumnDefinition(
col,
title=col.replace(" ", "\n", 1),
width=1,
textprops={
"ha": "center",
"bbox": {"boxstyle": "circle", "pad": 0.25},
},
cmap=normed_cmap(plot_df[col], cmap=mpl.colormaps.get_cmap(cmap_metrics), num_stds=2.5),
group=df.loc[_METRIC_TYPE, col],
formatter="{:.3f}",
)
for i, col in enumerate(other_cols)
]
# Bars for the aggregate scores
column_definitions += [
ColumnDefinition(
col,
width=1,
title=col.replace(" ", "\n", 1),
plot_fn=bar,
plot_kw={
"cmap": normed_cmap(plot_df[col], cmap=mpl.colormaps.get_cmap(cmap_scores), num_stds=2.5),
"plot_bg_bar": False,
"annotate": True,
"height": 0.9,
"formatter": "{:.3f}",
"xlim": (0, 1),
"textprops": {"fontsize": 9}
},
group=df.loc[_METRIC_TYPE, col],
border="left" if i == 0 else None,
)
for i, col in enumerate(score_cols)
]
# Allow to manipulate text
with mpl.rc_context({"svg.fonttype": "none"}):
fig, ax = plt.subplots(figsize=(len(df.columns) * 1.25, 3 + 0.3 * num_embeds))
tab = Table(
plot_df,
cell_kw={
"linewidth": 0,
"edgecolor": "k",
},
column_definitions=column_definitions,
ax=ax,
row_dividers=True,
footer_divider=True,
textprops={"fontsize": 9, "ha": "center"},
row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 5))},
col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
column_border_kw={"linewidth": 1, "linestyle": "-"},
index_col="Method",
).autoset_fontcolors(colnames=plot_df.columns)
if show:
plt.show()
if save_dir is not None:
fig.savefig(os.path.join(save_dir, "scib_results.png"), bbox_inches ='tight', facecolor=ax.get_facecolor(), dpi=dpi)