Source code for scparadise.scnoah

import warnings
warnings.filterwarnings("ignore")

from imblearn.over_sampling import RandomOverSampler
import sklearn.metrics as metrics_
from imblearn import metrics
from scipy import sparse
from scipy import stats
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import scanpy as sc
import shap
import json
import anndata
import os
from plottable.cmap import normed_cmap
import matplotlib as mpl
from plottable import ColumnDefinition
from plottable import Table
from plottable.plots import bar
from plottable.font import contrasting_font_color
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch
from scparadise import scadam
from torch.utils.data import Dataset, DataLoader, Subset

# Function for generation of synthetic cells for a cell type via PCA-based local interpolation.
def cells_generator(
    X_class,
    num_target,
    max_components = 20,
    random_state = 0
):
    """
    Augment one cell type with synthetic cells interpolated in PCA space.

    A PCA is fitted on the cells of the class, random pairs of distinct cells
    are blended in the reduced space, and the blends are projected back to
    gene space, where negative values are set to zero.

    Parameters
    ----------
    X_class : ndarray, shape (n_cells, n_genes)
        Expression matrix for one cell type.
    num_target : int
        Desired total number of samples after augmentation (including originals).
    max_components: int (default: 20)
        Upper bound for PCA components.
    random_state: int (default: 0)
        Seed for random number generators to ensure reproducibility.

    Returns
    -------
    ndarray
        Original cells followed by the synthetic ones. The input is returned
        unchanged if `n_cells >= num_target`, `n_cells < 2`, or no PCA
        component can be fitted.
    """
    generator = np.random.default_rng(random_state)
    n_cells, n_features = X_class.shape

    # Already large enough, or too few cells to form a pair
    if n_cells < 2 or n_cells >= num_target:
        return X_class

    # PCA cannot use more components than cells - 1 or genes
    n_pca = min(max_components, n_cells - 1, n_features)
    if n_pca < 1:
        return X_class

    reducer = PCA(n_components=n_pca, random_state=random_state)
    embedding = reducer.fit_transform(X_class)  # (n_cells, n_pca)

    def _blend_random_pair():
        # Pick two distinct cells and mix them; the weight stays in
        # [0.3, 0.7) so the new point lies well between both parents.
        first, second = generator.choice(n_cells, size=2, replace=False)
        weight = generator.uniform(0.3, 0.7)
        return weight * embedding[first] + (1.0 - weight) * embedding[second]

    n_missing = num_target - n_cells
    synthetic_embedding = np.vstack(
        [_blend_random_pair() for _ in range(n_missing)]
    )  # (n_missing, n_pca)
    synthetic = reducer.inverse_transform(synthetic_embedding)

    # Expression values cannot be negative
    synthetic[synthetic < 0] = 0.0

    return np.vstack([X_class, synthetic])


# Function for oversampling selected cell types
def oversample(
    adata,
    celltype_keys,
    target_per_class = None,
    max_oversample_factor = 7.0,
    min_oversample_cells = 5,
    random_state = 0
):
    """
    Oversample minor cell types in AnnData object.

    Cell types are taken from the annotation level in `celltype_keys` that has
    the most distinct cell types. Each of its cell types with fewer than
    `target_per_class` cells is extended with synthetic cells produced by
    `cells_generator` from adata.X, so synthetic values are on the same scale
    as the input matrix (counts in, count-scale values out; normalized data
    in, normalized values out).

    Parameters
    ----------
    adata: AnnData
        Input dataset to be oversampled.
    celltype_keys: list
        List of cell type annotations in adata.obs.
        Example: ['lineage', 'cell type', 'cell state']
        A single key given as a string is also accepted.
    target_per_class: int (default: None)
        Global target per cell type.
        If None, computed as average cells per cell type in celltype level with most cell types.
    max_oversample_factor: float (default: 7.0)
        Upper bound on how much a small class may be expanded relative to its original size.
    min_oversample_cells: int (default: 5)
        Minimal cell type size to allow substantial generation of new cells.
        Cell types smaller than this are extended to at most this many cells.
    random_state: int (default: 0)
        Seed for random number generators to ensure reproducibility.

    Returns
    -------
    AnnData
        New AnnData with oversampled minor cell types.
        Original and synthetic cells are concatenated. X is stored as a CSR
        matrix; only X, obs and var are carried over. The obs rows of synthetic
        cells are copies of randomly chosen original cells of the same cell
        type (index labels included), so obs names are not unique.

    Raises
    ------
    ValueError
        If `celltype_keys` is empty.
    """
    if isinstance(celltype_keys, str):
        celltype_keys = [celltype_keys]
    if len(celltype_keys) == 0:
        raise ValueError('celltype_keys must contain at least one key of adata.obs.')

    rng = np.random.default_rng(random_state)

    # Annotation level with the most distinct cell types (first one wins on ties)
    biggest_key = max(celltype_keys, key=lambda key: adata.obs[key].nunique())

    counts = adata.obs[biggest_key].value_counts()
    if target_per_class is None:
        target_per_class = round(len(adata) / adata.obs[biggest_key].nunique())

    # Convert data to dense matrix
    X = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X
    labels = adata.obs[biggest_key].values

    X_blocks = [X]
    obs_blocks = [adata.obs.copy()]

    for celltype, n in counts.items():
        if n >= target_per_class:
            continue  # no need to oversample this type

        # Class-specific target with safety caps: never more than
        # max_oversample_factor times the original size, and very small
        # classes are only raised up to min_oversample_cells.
        num_target = min(target_per_class, int(max_oversample_factor * n))
        if n < min_oversample_cells:
            num_target = min(num_target, min_oversample_cells)
        if num_target <= n:
            continue  # nothing to do

        idx = np.where(labels == celltype)[0]
        X_class = X[idx]
        n_current = X_class.shape[0]

        X_extended = cells_generator(
            X_class=X_class,
            num_target=num_target,
            random_state=random_state
        )
        # cells_generator returns the originals first, synthetic cells after them
        X_new = X_extended[n_current:]
        n_new = X_new.shape[0]
        if n_new == 0:
            continue

        # For metadata, reuse obs rows of existing cells (bootstrap indices)
        base_idx = rng.integers(0, n_current, size=n_new)
        X_blocks.append(X_new)
        obs_blocks.append(adata.obs.iloc[idx[base_idx]].copy())

    adata_oversampled = anndata.AnnData(
        X=sparse.csr_matrix(np.vstack(X_blocks)),
        obs=pd.concat(obs_blocks, axis=0),
        var=adata.var.copy()
    )
    return adata_oversampled
# Function for undersampling specific cell types
def undersample(
    adata,
    celltype_keys,
    target_per_class = None,
    min_keep_frac = 0.5,
    random_state = 0
):
    """
    Undersample major cell types in AnnData object.

    Cell types are taken from the annotation level in `celltype_keys` that has
    the most distinct cell types. Cells of over-represented cell types are
    dropped at random; all other cells are kept.

    Parameters
    ----------
    adata: AnnData
        Input dataset to be undersampled.
    celltype_keys: list
        List of cell type annotations in adata.obs.
        Example: ['lineage', 'cell type', 'cell state']
        A single key given as a string is also accepted.
    target_per_class: int (default: None)
        Global target per cell type.
        If None, computed as average cells per cell type in celltype level with most cell types.
    min_keep_frac: float (default: 0.5)
        Lower bound on fraction of original cell type size preserved after undersampling.
    random_state: int (default: 0)
        Seed for random number generators to ensure reproducibility.

    Returns
    -------
    AnnData
        Copy of the subset of adata containing the kept cells, in their
        original order.

    Raises
    ------
    ValueError
        If `celltype_keys` is empty.
    """
    if isinstance(celltype_keys, str):
        celltype_keys = [celltype_keys]
    if len(celltype_keys) == 0:
        raise ValueError('celltype_keys must contain at least one key of adata.obs.')

    rng = np.random.default_rng(random_state)

    # Annotation level with the most distinct cell types (first one wins on ties)
    biggest_key = max(celltype_keys, key=lambda key: adata.obs[key].nunique())

    counts = adata.obs[biggest_key].value_counts()
    if target_per_class is None:
        target_per_class = round(len(adata) / adata.obs[biggest_key].nunique())

    labels = adata.obs[biggest_key].values
    keep_indices_list = []

    for celltype, n in counts.items():
        idx = np.where(labels == celltype)[0]

        # Minimum acceptable number of cells after undersampling
        min_keep = int(max(min_keep_frac * n, 1))
        # Actual target: no less than min_keep and no more than n
        target = min(max(target_per_class, min_keep), n)

        if n > target:
            keep_indices_list.append(rng.choice(idx, size=target, replace=False))
        else:
            keep_indices_list.append(idx)

    # Sorting restores the original cell order
    keep_indices = np.sort(np.concatenate(keep_indices_list))

    return adata[keep_indices].copy()
# Function for balancing cell types in adata
def balance(
    adata,
    celltype_keys,
    min_keep_frac = 0.5,
    max_oversample_factor = 7.0,
    min_oversample_cells = 5,
    random_state = 0
):
    """
    Balance cell types in AnnData object.

    The common target size is the average number of cells per cell type in
    the annotation level of `celltype_keys` with the most distinct cell types.
    Over-represented cell types are first reduced with `undersample`, then
    under-represented ones are extended with `oversample`.

    Parameters
    ----------
    adata: AnnData
        Input dataset to be balanced.
    celltype_keys: list
        List of cell type annotations in adata.obs.
        Example: ['lineage', 'cell type', 'cell state']
        A single key given as a string is also accepted.
    min_keep_frac: float (default: 0.5)
        Safety lower bound for the fraction of original cells retained in large classes.
    max_oversample_factor: float (default: 7.0)
        Upper bound on how much a small class may be expanded relative to its original size.
    min_oversample_cells: int (default: 5)
        Minimal cell type size to allow substantial generation of new cells.
    random_state: int (default: 0)
        Seed for random number generators to ensure reproducibility.

    Returns
    -------
    AnnData
        Balanced dataset as returned by `oversample`.

    Raises
    ------
    ValueError
        If `celltype_keys` is empty.
    """
    if isinstance(celltype_keys, str):
        celltype_keys = [celltype_keys]
    if len(celltype_keys) == 0:
        raise ValueError('celltype_keys must contain at least one key of adata.obs.')

    # Annotation level with the most distinct cell types (first one wins on ties)
    biggest_key = max(celltype_keys, key=lambda key: adata.obs[key].nunique())
    target_per_class = round(len(adata) / adata.obs[biggest_key].nunique())

    # Step 1: undersample big classes
    adata_undersampled = undersample(
        adata=adata,
        celltype_keys=celltype_keys,
        target_per_class=target_per_class,
        min_keep_frac=min_keep_frac,
        random_state=random_state,
    )

    # Step 2: oversample small classes
    adata_balanced = oversample(
        adata=adata_undersampled,
        celltype_keys=celltype_keys,
        target_per_class=target_per_class,
        max_oversample_factor=max_oversample_factor,
        min_oversample_cells=min_oversample_cells,
        random_state=random_state
    )

    return adata_balanced
# Function for creation of full report
def report_classif_full(
    adata,
    celltype = None,
    pred_celltype = None,
    save_report = False,
    report_name = 'report.csv',
    save_path = '',
    ndigits = 4
):
    '''
    Build a per-cell-type classification report for predicted cell types.

    The report contains precision, recall (also called sensitivity),
    specificity, f1-score, geometric mean and index balanced accuracy of the
    geometric mean for every cell type, followed by the rows 'macro avg',
    'weighted avg', 'Accuracy' and 'Balanced accuracy'.
    Use it after prediction on an annotated test dataset (validation results).

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix. Previously annotated test dataset not used for model tuning or training.
    celltype : str, (default: None)
        Level of cell annotation to show metrics. Key in adata.obs dataframe.
    pred_celltype : str, (default: None)
        Predicted level of cell annotation. Key in adata.obs dataframe.
    save_report : bool, (default: False)
        Save report as csv file or not.
    report_name : str, (default: 'report.csv')
        Name of a file to save report.
    save_path : path object
        Path to a folder to save report.
    ndigits : int (default: 4)
        Round a number to a given precision in decimal digits.

    Returns
    -------
    pandas.DataFrame
        The report described above.
    '''
    y_true = adata.obs[celltype].to_numpy()
    y_pred = adata.obs[pred_celltype].to_numpy()
    metric_columns = [
        'precision',
        'recall/sensitivity',
        'specificity',
        'f1-score',
        'geometric mean',
        'index balanced accuracy',
    ]
    # A summary row holding a single number fills the other six cells with ''
    blanks = [''] * 6

    # Per-class metrics; the library's own summary entries are dropped
    raw_report = metrics.classification_report_imbalanced(
        y_true,
        y_pred,
        output_dict = True,
        digits = ndigits
    )
    report = pd.DataFrame(raw_report)
    for summary_key in ('avg_pre', 'avg_rec', 'avg_spe', 'avg_f1',
                        'avg_geo', 'avg_iba', 'total_support'):
        del report[summary_key]
    report = report.transpose()
    report['sup'] = report['sup'].astype('int')
    report = report.rename(columns = {
        'pre': 'precision',
        'rec': 'recall/sensitivity',
        'spe': 'specificity',
        'f1': 'f1-score',
        'geo': 'geometric mean',
        'iba': 'index balanced accuracy',
        'sup': 'number of cells',
    })

    # Balanced accuracy is the mean per-class recall; it has to be computed
    # before any summary row is appended
    balanced_accuracy_row = [
        round(np.mean(report['recall/sensitivity']), ndigits = ndigits)
    ] + blanks

    # Unweighted mean of every metric over the cell types
    report.loc['macro avg'] = [
        round(np.mean(report[column]), ndigits = ndigits)
        for column in metric_columns
    ] + ['']

    # Mean weighted by the number of cells; [:-1] leaves out the
    # 'macro avg' row that was just appended
    report.loc['weighted avg'] = [
        round(
            np.sum(report[column][:-1] * report['number of cells'])
            / np.sum(report['number of cells'][:-1]),
            ndigits = ndigits
        )
        for column in metric_columns
    ] + ['']

    # Round the metric columns while they are still numeric
    for column in metric_columns:
        report[column] = round(report[column], ndigits = ndigits)

    # Overall accuracy and balanced accuracy go last
    report.loc['Accuracy'] = [
        round(metrics_.accuracy_score(y_true, y_pred), ndigits = ndigits)
    ] + blanks
    report.loc['Balanced accuracy'] = balanced_accuracy_row

    # Save report to .csv
    if save_report:
        report.to_csv(os.path.join(save_path, report_name).replace("\\","/"))
        print('Successfully saved report')
        print()

    return report
# Function to find prediction status (correct or incorrect prediction)
def pred_status(
    adata,
    celltype = None,
    pred_celltype = None,
    key_added = 'pred_status'
):
    '''
    Find correct and incorrect predictions.

    Writes 'correct' or 'incorrect' for every cell to adata.obs[key_added]
    and a two-colour list to adata.uns[key_added + '_colors'].
    The function modifies adata in place and returns nothing.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix. Only adata.obs and adata.uns are used.
    celltype : str, (default: None)
        Cell annotation. Key in adata.obs dataframe.
    pred_celltype : str, (default: None)
        Predicted cell annotation. Key in adata.obs dataframe.
    key_added : str, (default: 'pred_status')
        Key to add in adata.obs
    '''
    observed = adata.obs[celltype]
    predicted = adata.obs[pred_celltype]

    # pandas refuses to compare two categorical columns whose category sets
    # differ (e.g. a cell type that was never predicted), so categorical
    # columns are compared by their plain values.
    if isinstance(observed.dtype, pd.CategoricalDtype):
        observed = observed.astype(object)
    if isinstance(predicted.dtype, pd.CategoricalDtype):
        predicted = predicted.astype(object)

    # Missing results of the comparison count as incorrect
    is_correct = (observed == predicted).fillna(False).astype(bool).to_numpy()

    adata.obs[key_added] = np.where(is_correct, 'correct', 'incorrect')
    adata.uns[key_added + '_colors'] = ['#3A3AFF', '#FF3737']
# Function for visualization of cell type prediction using confusion matrix
def conf_matrix(
    adata,
    celltype = None,
    pred_celltype = None,
    fmt = ".2f",
    annot = True,
    cmap = "Blues",
    ndigits_metrics = 3,
    grid = False,
    **kwargs
):
    '''
    Plot a row-normalized confusion matrix of predicted cell types.

    Rows are observed cell types, columns are predicted cell types, and each
    row is divided by the number of observed cells of that type. Accuracy and
    balanced accuracy are written below the x axis label. The heatmap is
    drawn on the current matplotlib axes; nothing is returned.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix. Only adata.obs is used.
    celltype : str, (default: None)
        Cell annotation. Key in adata.obs dataframe.
    pred_celltype : str, (default: None)
        Predicted cell annotation. Key in adata.obs dataframe.
    fmt : str, optional
        String formatting code to use when adding annotations.
    annot : bool or rectangular dataset, optional
        If True, write the data value in each cell.
        If an array-like with the same shape as ``data``, then use this to annotate the heatmap instead of the data.
        Note that DataFrames will match on position, not index.
    cmap : matplotlib colormap name or object, or list of colors, optional
        The mapping from data values to color space.
        If not provided, the default will depend on whether ``center`` is set.
    ndigits_metrics : int (default: 3)
        Round accuracy and balanced accuracy to a given precision in decimal digits.
    grid : bool (default: False)
        Show or hide grid lines.
    **kwargs: other keyword arguments
        All other keyword arguments are passed to sns.heatmap
    '''
    y_true = adata.obs[celltype].to_numpy()
    y_pred = adata.obs[pred_celltype].to_numpy()

    # The matrix has one row/column per label seen in EITHER column, so the
    # axis labels must be that same union (not only the observed labels).
    labels = np.unique(np.concatenate([y_true, y_pred]))

    # Create confusion matrix
    cm = metrics_.confusion_matrix(y_true, y_pred, labels=labels)

    # Calculate accuracy
    accuracy = round(np.trace(cm) / float(np.sum(cm)), ndigits = ndigits_metrics)
    accuracy = f"\n\nAccuracy={accuracy}"

    # Calculate balanced accuracy as the mean per-class recall
    report = metrics.classification_report_imbalanced(
        y_true,
        y_pred,
        output_dict = True
    )
    report = pd.DataFrame(report)
    for summary_key in ('avg_pre', 'avg_rec', 'avg_spe', 'avg_f1',
                        'avg_geo', 'avg_iba', 'total_support'):
        del report[summary_key]
    report = report.transpose()
    bal_accuracy = round(np.mean(report['rec']), ndigits = ndigits_metrics)
    bal_accuracy = f"\n\nBalanced accuracy={bal_accuracy}"

    # Normalize each row by its number of observed cells; a label that only
    # occurs among the predictions has no observed cells and keeps a zero row
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0
    )
    cm_normalized = pd.DataFrame(cm_normalized, index=labels, columns=labels)

    plt.grid(grid)
    sns.heatmap(cm_normalized, annot=annot, fmt=fmt, cmap=cmap, **kwargs)
    plt.xlabel('Predicted' + accuracy + bal_accuracy)
    plt.ylabel("Observed")
# Function for calculation of sensitivity and specificity of trained model
def report_classif_sens_spec(
    adata,
    celltype = None,
    pred_celltype = None,
    save_report = False,
    report_name = 'report_sens_spec.csv',
    save_path = '',
    ndigits = 3
):
    '''
    Returns specificity and recall (also called sensitivity) metrics of predicted cell types.
    You should use it after prediction on annotated test dataset (shows results of validation).
    Helps in understanding model quality.

    The report has one row per label found in either the observed or the
    predicted annotation, followed by the rows 'macro avg' and 'weighted avg'.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix. Previously annotated test dataset not used for model tuning or training.
    celltype : str, (default: None)
        Level of cell annotation to show metrics. Key in adata.obs dataframe.
    pred_celltype : str, (default: None)
        Predicted level of cell annotation. Key in adata.obs dataframe.
    save_report : bool, (default: False)
        Save report as csv file or not.
    report_name : str, (default: 'report_sens_spec.csv')
        Name of a file to save report.
    save_path : path object
        Path to a folder to save report.
    ndigits : int (default: 3)
        Round a number to a given precision in decimal digits.

    Returns
    -------
    pandas.DataFrame
        Columns 'recall/sensitivity', 'specificity' and 'number of cells'.
    '''
    y_true = adata.obs[celltype].to_numpy()
    y_pred = adata.obs[pred_celltype].to_numpy()

    # Scores are produced for every label seen in EITHER column, so the same
    # union is used for naming them (observed labels alone can be too few).
    labels = np.unique(np.concatenate([y_true, y_pred]))

    # Create report with sensitivity and specificity
    scores = metrics.sensitivity_specificity_support(y_true, y_pred, labels=labels)

    # Add cell types names and column names
    report = pd.DataFrame(
        list(scores),
        columns=labels.tolist(),
        index=['recall/sensitivity', 'specificity', 'number of cells']
    ).transpose()

    # Round data
    score_columns = ['recall/sensitivity', 'specificity']
    for column in score_columns:
        report[column] = round(report[column], ndigits = ndigits)
    report['number of cells'] = report['number of cells'].astype('int')

    # Add avg rows
    report.loc['macro avg'] = [
        round(np.mean(report[column]), ndigits = ndigits)
        for column in score_columns
    ] + ['']
    # [:-1] leaves out the 'macro avg' row that was just appended
    report.loc['weighted avg'] = [
        round(
            np.sum(report[column][:-1] * report['number of cells'])
            / np.sum(report['number of cells'][:-1]),
            ndigits = ndigits
        )
        for column in score_columns
    ] + ['']

    # Save report to .csv
    if save_report:
        report.to_csv(os.path.join(save_path, report_name).replace("\\","/"))
        print('Successfully saved report')
        print()

    return report
# Function for calculation of regression metrics of trained model
def report_reg(
    adata_prot,
    adata_pred_prot,
    multioutput = 'uniform_average',
    save_report = False,
    report_name = 'report_regression.csv',
    save_path = '',
    ndigits = 3
):
    '''
    Returns multiple metrics of cell surface proteins prediction.
    Root mean squared error (RMSE), mean absolute error (MeanAE), median absolute error (MedianAE) : lower value - better prediction.
    Coefficient of determination (R² score), explained variance score (EVS) : higher value - better prediction

    The matrices of both objects are compared position by position; features
    are not reordered by name.

    Parameters
    ----------
    adata_prot : AnnData
        Annotated data matrix with proteins. Test dataset not used for model tuning or training.
    adata_pred_prot : AnnData
        Annotated data matrix with predicted proteins.
    multioutput : {'raw_values', 'uniform_average'} or array-like of shape (n_outputs,), (default='uniform_average')
        Defines aggregating of multiple output values. Array-like value defines weights used to average errors.
        'raw_values' : Returns a full set of errors in case of multioutput input
        (one report row per feature, named after adata_prot.var_names).
        'uniform_average' : Errors of all outputs are averaged with uniform weight.
    save_report : bool, (default: False)
        Save report as csv file or not.
    report_name : str, (default: 'report_regression.csv')
        Name of a file to save report.
    save_path : path object
        Path to a folder to save report.
    ndigits : int (default: 3)
        Round a number to a given precision in decimal digits.

    Returns
    -------
    pandas.DataFrame
        Columns 'EVS', 'r2_score', 'RMSE', 'MedianAE' and 'MeanAE'. The score
        row(s) are followed by two explanatory rows.
    '''
    def _to_dense(matrix):
        # adata.X may be sparse or already dense
        return matrix.toarray() if hasattr(matrix, 'toarray') else np.asarray(matrix)

    def _rounded(value):
        # Aggregated scores are scalars; 'raw_values' yields one score per feature
        if np.ndim(value) == 0:
            return [round(value, ndigits)]
        return np.round(value, ndigits)

    y_true = _to_dense(adata_prot.X)
    y_pred = _to_dense(adata_pred_prot.X)

    # Report column -> name of the function in sklearn.metrics
    scorers = (
        ('EVS', 'explained_variance_score'),
        ('r2_score', 'r2_score'),
        ('RMSE', 'root_mean_squared_error'),
        ('MedianAE', 'median_absolute_error'),
        ('MeanAE', 'mean_absolute_error'),
    )
    scores = {}
    for column, function_name in scorers:
        score = getattr(metrics_, function_name)(
            y_true = y_true,
            y_pred = y_pred,
            multioutput = multioutput
        )
        scores[column] = _rounded(score)

    # Create report
    report = pd.DataFrame(scores)
    if isinstance(multioutput, str) and multioutput == 'raw_values':
        report.index = list(adata_prot.var_names)
    else:
        report.index = ['score']
    report.loc['EVS/r2_score'] = ['higher value - better prediction', '', '', '', '']
    report.loc['RMSE/MedianAE/MeanAE'] = ['lower value - better prediction', '', '', '', '']

    # Save report to .csv
    if save_report:
        report.to_csv(os.path.join(save_path, report_name).replace("\\","/"))
        print('Successfully saved report')
        print()

    return report
# Function for defining regression status
def regres_status(
    adata_prot,
    adata_pred_prot,
    metric = 'RMSE'
):
    '''
    Compute regression status of cells to visualize on UMAP.

    The chosen metric is computed for every cell over all features and stored
    in adata_pred_prot.obs['regres_status_' + metric]. The function modifies
    adata_pred_prot in place and returns nothing. The matrices of both
    objects are compared position by position.

    Parameters
    ----------
    adata_prot : AnnData
        Annotated data matrix with proteins. Test dataset not used for model tuning or training.
    adata_pred_prot : AnnData
        Annotated data matrix with predicted proteins.
    metric : str (default: 'RMSE')
        Metric used for regression status calculation.
        Available metrics: RMSE, MeanAE, MedianAE, EVS, r2_score.
        Root mean squared error (RMSE), mean absolute error (MeanAE), median absolute error (MedianAE) : lower value - better prediction.
        Coefficient of determination (R² score), explained variance score (EVS) : higher value - better prediction

    Raises
    ------
    ValueError
        If `metric` is not one of the available metrics.
    '''
    # Metric name -> name of the function in sklearn.metrics
    metric_functions = {
        'RMSE': 'root_mean_squared_error',
        'MeanAE': 'mean_absolute_error',
        'MedianAE': 'median_absolute_error',
        'EVS': 'explained_variance_score',
        'r2_score': 'r2_score',
    }
    if metric not in metric_functions:
        raise ValueError(
            f"Unknown metric '{metric}'. Available metrics: "
            + ', '.join(metric_functions)
        )
    score_function = getattr(metrics_, metric_functions[metric])

    def _to_dense(matrix):
        # adata.X may be sparse or already dense
        return matrix.toarray() if hasattr(matrix, 'toarray') else np.asarray(matrix)

    # Transpose so that every cell becomes one "output"; with
    # multioutput='raw_values' this yields one score per cell.
    y_true = _to_dense(adata_prot.X).transpose()
    y_pred = _to_dense(adata_pred_prot.X).transpose()

    per_cell_score = score_function(
        y_true = y_true,
        y_pred = y_pred,
        multioutput = 'raw_values'
    )

    # Add metric to predicted adata
    adata_pred_prot.obs['regres_status_' + metric] = per_cell_score.copy()
# Function for calculating the Pearson correlation coefficient per protein
def pearson_coef(
    adata,
    adata_pred,
    feature,
    feature_pred,
    ndigits = 3,
    print_res = False
):
    '''
    Compute Pearson correlation coefficient between a measured feature and
    its prediction. Varies between -1 and +1.

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix with features. Test dataset not used for model tuning or training.
    adata_pred: AnnData
        Annotated data matrix with predicted features.
    feature: str
        Name of the feature in adata.
    feature_pred: str
        Name of the predicted feature in adata_pred.
    ndigits: int (default: 3)
        Round a number to a given precision in decimal digits.
    print_res: bool (default: False)
        Print results or not.

    Returns
    -------
    dict
        'Pearson coefficient' (rounded value) and 'p-value' (string in
        scientific notation).
        Values close to 1 indicate strong positive correlation, and values
        close to -1 indicate strong negative correlation.
    '''
    def _feature_vector(data, name):
        # Values of one feature across all cells as a 1-D array
        row = data[:, name].X.T
        if hasattr(data.X, 'toarray'):
            row = row.toarray()
        return row[0]

    measured = _feature_vector(adata, feature)
    predicted = _feature_vector(adata_pred, feature_pred)

    # Calculate Pearson correlation coefficient
    outcome = stats.pearsonr(measured, predicted)
    coefficient = round(outcome.correlation, ndigits = ndigits)
    p_value_text = f"{outcome.pvalue:.3e}"

    if print_res == True:
        print(f'Pearson coefficient = {coefficient}, p-value = {p_value_text}')

    return {'Pearson coefficient' : coefficient,
            'p-value' : p_value_text}
# Function for calculating the Spearman correlation coefficient per protein
def spearman_coef(
    adata,
    adata_pred,
    feature,
    feature_pred,
    ndigits = 3,
    print_res = False
):
    '''
    Compute Spearman correlation coefficient between a measured feature and
    its prediction. Varies between -1 and +1.

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix with features. Test dataset not used for model tuning or training.
    adata_pred: AnnData
        Annotated data matrix with predicted features.
    feature: str
        Name of the feature in adata.
    feature_pred: str
        Name of the predicted feature in adata_pred.
    ndigits: int (default: 3)
        Round a number to a given precision in decimal digits.
    print_res: bool (default: False)
        Print results or not.

    Returns
    -------
    dict
        'Spearman coefficient' (rounded value) and 'p-value' (string in
        scientific notation).
        Values close to 1 indicate strong positive correlation, and values
        close to -1 indicate strong negative correlation.
    '''
    def _feature_vector(data, name):
        # Values of one feature across all cells as a 1-D array
        row = data[:, name].X.T
        if hasattr(data.X, 'toarray'):
            row = row.toarray()
        return row[0]

    measured = _feature_vector(adata, feature)
    predicted = _feature_vector(adata_pred, feature_pred)

    # Calculate Spearman correlation coefficient
    outcome = stats.spearmanr(measured, predicted)
    coefficient = round(outcome.correlation, ndigits = ndigits)
    p_value_text = f"{outcome.pvalue:.3e}"

    if print_res == True:
        print(f'Spearman coefficient = {coefficient}, p-value = {p_value_text}')

    return {'Spearman coefficient' : coefficient,
            'p-value' : p_value_text}
# Function for calculating Kendall's tau correlation coefficient per protein
def kendalltau_coef(
    adata,
    adata_pred,
    feature,
    feature_pred,
    ndigits = 3,
    print_res = False
):
    '''
    Compute Kendall's tau, a correlation measure, between a measured feature
    and its prediction. Varies between -1 and +1.

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix with features. Test dataset not used for model tuning or training.
    adata_pred: AnnData
        Annotated data matrix with predicted features.
    feature: str
        Name of the feature in adata.
    feature_pred: str
        Name of the predicted feature in adata_pred.
    ndigits: int (default: 3)
        Round a number to a given precision in decimal digits.
    print_res: bool (default: False)
        Print results or not.

    Returns
    -------
    dict
        'Kendall Tau' (rounded value) and 'p-value' (string in scientific
        notation).
        Values close to 1 indicate strong agreement, and values close to -1
        indicate strong disagreement.
    '''
    def _feature_vector(data, name):
        # Values of one feature across all cells as a 1-D array
        row = data[:, name].X.T
        if hasattr(data.X, 'toarray'):
            row = row.toarray()
        return row[0]

    measured = _feature_vector(adata, feature)
    predicted = _feature_vector(adata_pred, feature_pred)

    # Calculate Kendall's tau
    outcome = stats.kendalltau(measured, predicted)
    coefficient = round(outcome.correlation, ndigits = ndigits)
    p_value_text = f"{outcome.pvalue:.3e}"

    if print_res == True:
        print(f'Kendall Tau = {coefficient}, p-value = {p_value_text}')

    return {'Kendall Tau' : coefficient,
            'p-value' : p_value_text}
# Function for counting cell types in samples of an integrated dataset
def cell_counter(
    adata,
    sample = None,
    celltype = None
):
    '''
    Count cell types in samples.
    Useful for integrated/concatenated dataset.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    sample : str, (default: None)
        Samples names key in adata.obs dataframe.
    celltype : str, (default: None)
        Level of cell annotation to count. Key in adata.obs dataframe.

    Returns
    -------
    pandas.DataFrame
        One row per cell type and one column per sample, holding the number
        of cells.
    '''
    # Table that collects one column of counts per sample
    counts_table = pd.DataFrame()

    for sample_name in np.unique(adata.obs[sample]):
        # Cells of the current sample only
        sample_obs = adata[adata.obs[sample] == sample_name].obs
        # Counts per cell type, with the counts column named after the sample
        sample_counts = (
            sample_obs[celltype]
            .value_counts()
            .to_frame()
            .rename(columns={'count' : sample_name})
        )
        counts_table = pd.concat(
            [counts_table, sample_counts[sample_name]],
            axis = 1,
            join = 'outer'
        )

    return counts_table
# Function to get explanations
def explain(
    adata,
    layer = None,
    celltype = None,
    path_model = '',
    num_cells = 100,
    random_state = 0,
    max_evals = 2000,
    batch_size = 256,
    prefix = 'pred_',
    device = 'auto',
    verbose = True
):
    '''
    Identify the genes that are most important for determining cell type using a model.

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix. Must contain the predicted cell type columns
        (f"{prefix}{celltype_key}") in adata.obs.
    layer: str (default: None)
        If specified and present in adata.layers, use adata.layers[layer] for
        expression values; otherwise adata.X is used.
    celltype: str (default: None)
        Specific cell type in annotation to calculate explanations.
    path_model: str, path object
        Path to the folder containing the trained scAdam model ('model_v2.pth').
    num_cells: int (default: 100)
        Number of cells to make explanations.
        If the cell type has no more than num_cells predicted cells, all of them are used.
        Increasing the number of cells will lead to an increase in computation time.
    random_state: int (default: 0)
        Controls the random selection of cells from dataset and explainer for reproducibility.
        Pass an int for reproducible output across multiple function calls.
    max_evals: int (default: 2000)
        Maximum number of model evaluations SHAP may use per explained cell.
        It trades accuracy of the SHAP values against computation time.
        For a larger number of genes, it is necessary to increase max_evals.
    batch_size: int, (default: 256)
        Number of cells per batch when the model is evaluated.
    prefix: str (default: 'pred_')
        Prefix of predicted cell type columns in adata.obs.
    device: str (default: 'auto')
        Device used to load and evaluate the model ('cpu', 'cuda').
        Set 'auto' for automatic selection.
    verbose: bool (default: True)
        Print progress messages.

    Returns
    -------
    Explanations (SHAP values) of the specific cell type.

    Raises
    ------
    FileNotFoundError
        If 'model_v2.pth' is not found in path_model.
    ValueError
        If the model lacks var_names/celltype_keys, the cell type is unknown to the model,
        no cells were predicted as this cell type, or no model gene is present in adata.
    KeyError
        If the prediction column is missing from adata.obs.
    '''
    model_file = os.path.join(path_model, 'model_v2.pth')
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"Model file not found: {model_file}")

    # Load the model onto the requested device so that it matches the device
    # the input batches are moved to below
    model = scadam.load_model(path_model, device = device, verbose = verbose)

    # Device selection
    if device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Validate the model before any of its attributes are used
    if not hasattr(model, 'var_names') or model.var_names is None:
        raise ValueError("Model does not have var_names attribute. Train the model first.")
    if not hasattr(model, 'celltype_keys') or model.celltype_keys is None:
        raise ValueError("Model does not have celltype_keys attribute. Train the model first.")

    # Find the annotation level that knows this cell type (last level is checked first)
    celltype_key_used = None
    for celltype_key in reversed(model.celltype_keys):
        if celltype in model.celltype_encoders[celltype_key]['celltype_encoder'].classes_:
            celltype_key_used = celltype_key
            break
    if celltype_key_used is None:
        raise ValueError('Wrong name of cell type! There is no such cell type name in any prediction level.')

    # Select cells predicted as this cell type
    pred_col = f"{prefix}{celltype_key_used}"
    if pred_col not in adata.obs.columns:
        raise KeyError(f"'{pred_col}' not found in adata.obs. Check 'prefix' or run prediction first.")
    adata_cell_type = adata[adata.obs[pred_col] == celltype].copy()
    if adata_cell_type.n_obs == 0:
        raise ValueError(f"No cells with '{pred_col}' == '{celltype}' found in adata.")
    if verbose:
        print(f"Cell type '{celltype}' was successfully selected from '{prefix}{celltype_key_used}'")

    # Subsample only when there are more cells than requested; get_frac cannot
    # return all (or more than all) cells of a dataset
    if adata_cell_type.n_obs > num_cells:
        adata_cell_type = get_frac(adata_cell_type, fraction = num_cells, random_state = random_state)
    elif verbose:
        print(f"Only {adata_cell_type.n_obs} cells available for '{celltype}', all of them are used")

    # Get gene expression data
    if layer is not None and layer in adata_cell_type.layers:
        X = adata_cell_type.layers[layer]
    else:
        X = adata_cell_type.X
    # Convert sparse to dense if needed
    if hasattr(X, 'toarray'):
        X = X.toarray()
    X = np.asarray(X)

    # Align features: create matrix with genes in the same order as training.
    # Genes missing from adata stay zero. First occurrence wins for duplicated names.
    model_var_names = model.var_names
    gene_to_col = {}
    for col, gene in enumerate(adata_cell_type.var_names):
        gene_to_col.setdefault(gene, col)
    model_cols = [i for i, gene in enumerate(model_var_names) if gene in gene_to_col]
    if not model_cols:
        raise ValueError("None of the model genes were found in adata.var_names.")
    data_cols = [gene_to_col[model_var_names[i]] for i in model_cols]
    X_aligned = np.zeros((X.shape[0], len(model_var_names)), dtype = np.float32)
    X_aligned[:, model_cols] = X[:, data_cols]
    if verbose:
        print(f"{len(model_cols)} of {len(model_var_names)} model genes found in adata")

    # Min-max scale every gene over the selected cells
    X_aligned = np.maximum(X_aligned, 0)  # Remove negative values
    X_norm = (X_aligned - X_aligned.min(0)) / (np.ptp(X_aligned, axis = 0) + 1e-10)

    # Check for NaN/Inf
    if np.isnan(X_norm).any() or np.isinf(X_norm).any():
        warnings.warn("NaN or Inf found in normalized data. Replacing with zeros.")
        X_norm = np.nan_to_num(X_norm, nan = 0.0, posinf = 0.0, neginf = 0.0)

    # Model outputs are keyed by the position of the level in celltype_keys
    level_key = f'level_{list(model.celltype_keys).index(celltype_key_used)}'

    # Create explainer function for loaded model
    def explainer_model(
        X_norm,
        model = model
    ):
        # Returns, for every row, the encoded class index predicted at the
        # explained annotation level (as float64, which SHAP works with)
        dataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_norm))
        loader = DataLoader(
            dataset,
            batch_size = batch_size,
            shuffle = False
        )
        batch_predictions = []
        with torch.no_grad():
            for (batch_x,) in loader:
                outputs = model(batch_x.to(device))
                batch_predictions.append(
                    outputs[level_key]['logits'].argmax(dim = 1).cpu().numpy()
                )
        return np.concatenate(batch_predictions).astype('float64')

    # Create explainer using explainer function
    explainer = shap.Explainer(
        explainer_model,
        masker = X_norm,
        feature_names = model_var_names,
        seed = random_state
    )
    # Compute SHAP values
    explanations = explainer(
        X_norm,
        max_evals = max_evals
    )
    if verbose:
        print(f'The explanations for "{celltype}" have been completed')

    return explanations
# Function to get gene importance dataframe from explanations
def feature_importance_df(
    explanations
):
    '''
    Build a table of gene importances for a specific cell type.

    The importance of a gene is the mean absolute SHAP value of that gene
    over all explained cells.

    Parameters
    ----------
    explanations : explanations
        Output of function explain().

    Returns
    -------
    pandas.DataFrame
        Columns 'gene_name' and 'gene_importance', sorted from the most
        to the least important gene.
    '''
    # Average absolute contribution of every gene across cells
    mean_abs_shap = np.abs(explanations.values).mean(0)

    # Pair each gene with its score and rank genes by importance
    gene_rows = list(zip(explanations.feature_names, mean_abs_shap))
    importance_table = pd.DataFrame(gene_rows, columns=['gene_name', 'gene_importance'])
    return importance_table.sort_values(by = ['gene_importance'], ascending=False)
# Function to get fraction of AnnData
def get_frac(
    adata = None,
    path = None,
    path_save = None,
    stratify = None,
    fraction = 0.1,
    shuffle = True,
    random_state = 0
):
    '''
    Get fraction of AnnData object.
    Specify AnnData object OR path to AnnData.
    The function returns a portion of the AnnData object while maintaining the ratio of cell types.

    Parameters
    ----------
    adata: AnnData
        Annotated data matrix. If None uses path to anndata (e.g., "/Data/adata.h5ad").
    path: str, path object (default: None)
        Path to the AnnData object if AnnData is not loaded into RAM.
        Only used when adata is None; the file is opened in backed mode.
    path_save: str, path object (default: None)
        Folder to save the fraction of AnnData as 'adata_fraction.h5ad'.
        If None, nothing is written.
    stratify: str (default: None)
        Key in adata.obs dataframe.
        If specified, ensures the same key-based cell ratio as in the original adata.
    fraction: float or int (default: 0.1)
        If a float value is specified, it must be between 0.0 and 1.0 and represent the fraction of the dataset to be included in adata_fraction.
        If an integer is specified, it must be less than the number of cells in the dataset. This number of cells will be randomly allocated from adata.
    shuffle: bool (default: True)
        Whether or not to shuffle the data before subsetting.
        If shuffle = False, then stratify is not used to maintain the same ratio.
    random_state: int (default: 0)
        Controls the data shuffling and splitting.
        Pass an int for reproducible output across multiple function calls.

    Returns
    -------
    A fraction of the AnnData object while maintaining the same ratio of cell types (if stratify is specified).
    This part of the AnnData object can also be saved as adata_fraction.h5ad.

    Raises
    ------
    ValueError
        If neither adata nor path is given, or stratify is not a column of adata.obs.
    '''
    # Get AnnData (in-memory object takes precedence over the path)
    if adata is None:
        if path is None:
            raise ValueError("Specify either 'adata' or 'path'.")
        adata = sc.read_h5ad(path, backed = 'r')

    if stratify is not None and stratify not in adata.obs.columns:
        raise ValueError(f"{stratify} not found in adata.obs")

    # Stratified splitting is only possible with shuffling, so the labels
    # are dropped when shuffle is False (as documented above)
    labels = adata.obs[stratify] if (stratify is not None and shuffle) else None

    # Subset meta data: the "test" part of the split is the requested fraction
    _, obs_names = train_test_split(
        adata.obs_names,
        test_size = fraction,
        stratify = labels,
        shuffle = shuffle,
        random_state = random_state
    )

    # Get and save adata_fraction
    subset = adata[obs_names]
    adata_fraction = subset.to_memory() if adata.isbacked else subset.copy()
    if path_save is not None:
        adata_fraction.write_h5ad(os.path.join(path_save, 'adata_fraction.h5ad'))

    return adata_fraction
# Function to get samples from AnnData
def get_samples(
    adata = None,
    path = None,
    path_save = '',
    sample_col = None,
    samples = None
):
    '''
    Get samples from AnnData object.
    Specify path OR AnnData object.
    The function returns a new AnnData with selected samples.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix.
    path : str, path object
        Path to the AnnData object if AnnData is not loaded into RAM.
        If given, the file is opened in backed mode and used instead of adata.
    path_save : str, path object (default: '')
        Folder to save the selected samples as 'adata_samples.h5ad'.
        The default '' writes the file to the current working directory.
        Pass None to skip saving.
    sample_col : str, (default: None)
        Key in adata.obs dataframe with samples names.
    samples : list or str, (default: None)
        List of samples names in adata.obs[sample_col].
        A single sample name may be passed as a string.

    Returns
    -------
    New AnnData with some samples from the AnnData object.
    This part of the AnnData object can also be saved as adata_samples.h5ad.

    Raises
    ------
    ValueError
        If neither adata nor path is given, or sample_col/samples is not specified.
    '''
    # Get AnnData (a path takes precedence over an in-memory object)
    if path is not None:
        adata = sc.read_h5ad(path, backed = 'r')
    if adata is None:
        raise ValueError("Specify either 'adata' or 'path'.")
    if sample_col is None or samples is None:
        raise ValueError("Both 'sample_col' and 'samples' must be specified.")

    # isin() requires a list-like, so wrap a single sample name
    if isinstance(samples, str):
        samples = [samples]

    # Get and save adata_samples
    subset = adata[adata.obs[sample_col].isin(samples)]
    adata_samples = subset.to_memory() if adata.isbacked else subset.copy()
    if path_save is not None:
        adata_samples.write_h5ad(os.path.join(path_save, 'adata_samples.h5ad'))

    return adata_samples
# Function to find difference between clusters based on a specific scores
def _density_difference_integral(x1, y1, x2, y2, grid_size=500):
    # Integrate |KDE1 - KDE2| over a grid_size x grid_size grid spanning both point sets.
    xmin = min(x1.min(), x2.min())
    xmax = max(x1.max(), x2.max())
    ymin = min(y1.min(), y2.min())
    ymax = max(y1.max(), y2.max())

    X, Y = np.mgrid[xmin:xmax:grid_size*1j, ymin:ymax:grid_size*1j]
    positions = np.vstack([X.ravel(), Y.ravel()])

    # KDE of each group evaluated on the common grid
    Z1 = np.reshape(stats.gaussian_kde(np.vstack([x1, y1]))(positions).T, X.shape)
    Z2 = np.reshape(stats.gaussian_kde(np.vstack([x2, y2]))(positions).T, X.shape)

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # NOTE: integration uses unit spacing between grid nodes, so the value
    # scales with the grid resolution and is not a probability mass in [0, 2]
    return integrate(integrate(np.abs(Z1 - Z2), axis=1), axis=0)


def _histogram_mutual_information(x, y, bins=20):
    # Mutual information of two variables estimated from their 2D histogram.
    contingency = np.histogram2d(x, y, bins)[0]
    return metrics_.mutual_info_score(None, None, contingency=contingency)


def clust_diff(
    adata,
    groupby = None,
    group1 = None,
    group2 = None,
    score1 = None,
    score2 = None,
    plot = True,
    thresh = 0.02,
    fill = True,
    alpha = 0.5,
    **kwargs
):
    '''
    Calculate how well two clusters are separated by a pair of scores.

    Two metrics are computed:
    - Integral of absolute density difference between the 2D (score1, score2)
      kernel density estimates of the two clusters. The higher the value,
      the better the clusters separate the selected scores.
    - Mutual Information between score1 and score2, estimated from a 2D histogram
      of the cells of both clusters.

    Parameters
    ----------
    adata : AnnData or MuData
        Annotated data or Multimodal data.
    groupby : str
        The key of the grouping in AnnData.obs or MuData.obs.
    group1 : str
        Cluster in groupby.
    group2 : str
        Cluster in groupby.
    score1 : str
        Score 1 in AnnData.obs or MuData.obs calculated using scanpy.tl.score_genes.
    score2 : str
        Score 2 in AnnData.obs or MuData.obs calculated using scanpy.tl.score_genes.
    plot : bool, (default: True)
        Show kernel density estimate plot or not.
        If False, the metrics are printed instead.
    thresh : float in [0, 1], (default: 0.02)
        Lowest iso-proportion level at which to draw a contour line.
    fill : bool or None, (default: True)
        If True, fill in the area under univariate density curves or between bivariate contours.
    alpha : float or None, (default: 0.5)
        Transparency of the rectangle and connector lines.
    kwargs
        Other keyword arguments are passed to seaborn.kdeplot.

    Returns
    -------
    tuple
        (Integral of absolute density difference, Mutual Information).

    Raises
    ------
    ValueError
        If group1 and group2 are the same or one of them has no cells in adata.obs[groupby].
    '''
    # Create DataFrame with groups from groupby (copy: adata.obs must not be modified)
    obs = adata.obs
    df = obs.loc[obs[groupby].isin([group1, group2]), [groupby, score1, score2]].copy()
    df[groupby] = df[groupby].astype(object)

    first = df[df[groupby] == group1]
    second = df[df[groupby] == group2]
    if group1 == group2 or first.empty or second.empty:
        raise ValueError(
            f"Two different, non-empty groups are required in adata.obs['{groupby}'], "
            f"got '{group1}' ({len(first)} cells) and '{group2}' ({len(second)} cells)."
        )

    # Score1 and score2 of each group
    x1, y1 = first[[score1, score2]].values.T
    x2, y2 = second[[score1, score2]].values.T

    metric = _density_difference_integral(x1, y1, x2, y2)
    mi = _histogram_mutual_information(df[score1], df[score2])

    if plot:
        plt.title(f"KDE Plot of {score1} vs {score2} \n"
                  f"Integral of absolute density difference: {metric:.1f}, \nMutual Information: {mi:.3f}", fontsize=12)
        plt.xlabel(score1, fontsize=10)
        plt.ylabel(score2, fontsize=10)
        sns.kdeplot(
            data=df,
            x=score1,
            y=score2,
            hue=groupby,
            thresh=thresh,
            fill=fill,
            alpha=alpha,
            **kwargs
        )
    else:
        print(f"Integral of absolute density difference: {metric:.1f}, \nMutual Information: {mi:.3f}")

    return metric, mi
# Corrected function for plotting integration metrics results
def plot_results_table(
    bm,
    cmap_scores='bwr',
    cmap_metrics='PRGn',
    min_max_scale=False,
    show=True,
    save_dir=None,
    dpi=300
):
    """
    Draw the scib_metrics benchmarking results of data integration as a table.

    Individual metrics are rendered as coloured circles, aggregate scores as bars.
    Methods (rows) are sorted by the "Total" score when both batch correction and
    bio conservation metrics were computed, otherwise by the single available aggregate.

    Parameters
    ----------
    bm:
        benchmarking results from scib_metrics.
    cmap_scores: str (default: 'bwr')
        Color map for aggregate scores ('Batch correction', 'Bio conservation', 'Total').
    cmap_metrics: str (default: 'PRGn')
        Color map for individual metrics ('iLISI', 'KBET', ...).
    min_max_scale:
        Whether to min max scale the results.
    show: bool (default: True)
        Whether to show the plot.
    save_dir: path (default: None)
        The directory to save the plot to (as 'scib_results.png').
        If `None`, the plot is not saved.
    dpi: int (default: 300)
        Print resolution.
    """
    metric_type_row = 'Metric Type'
    n_embeddings = len(bm._embedding_obsm_keys)
    results = bm.get_results(min_max_scale=min_max_scale)

    # Pick the aggregate score the methods are ranked by
    has_batch = bm._batch_correction_metrics is not None
    has_bio = bm._bio_conservation_metrics is not None
    if has_batch and has_bio:
        sort_col = "Total"
    elif has_batch:
        sort_col = "Batch correction"
    else:
        sort_col = "Bio conservation"

    # The metric-type row is metadata, not a method: drop it before plotting
    plot_df = results.drop(metric_type_row, axis=0)
    plot_df = plot_df.sort_values(by=sort_col, ascending=False).astype(np.float64)
    plot_df["Method"] = plot_df.index

    # Split columns by metric type, using results as it doesn't have the new method col
    is_aggregate = results.loc[metric_type_row] == 'Aggregate score'
    score_cols = results.columns[is_aggregate]
    other_cols = results.columns[~is_aggregate]

    column_definitions = [
        ColumnDefinition("Method", width=1.5, textprops={"ha": "left", "weight": "bold"}),
    ]

    # Circles for the metric values
    for col in other_cols:
        circle_props = {
            "ha": "center",
            "bbox": {"boxstyle": "circle", "pad": 0.25},
        }
        column_definitions.append(
            ColumnDefinition(
                col,
                title=col.replace(" ", "\n", 1),
                width=1,
                textprops=circle_props,
                cmap=normed_cmap(plot_df[col], cmap=mpl.colormaps.get_cmap(cmap_metrics), num_stds=2.5),
                group=results.loc[metric_type_row, col],
                formatter="{:.3f}",
            )
        )

    # Bars for the aggregate scores; the first one gets a left border
    for position, col in enumerate(score_cols):
        bar_kw = {
            "cmap": normed_cmap(plot_df[col], cmap=mpl.colormaps.get_cmap(cmap_scores), num_stds=2.5),
            "plot_bg_bar": False,
            "annotate": True,
            "height": 0.9,
            "formatter": "{:.3f}",
            "xlim": (0, 1),
            "textprops": {"fontsize": 9}
        }
        column_definitions.append(
            ColumnDefinition(
                col,
                width=1,
                title=col.replace(" ", "\n", 1),
                plot_fn=bar,
                plot_kw=bar_kw,
                group=results.loc[metric_type_row, col],
                border="left" if position == 0 else None,
            )
        )

    # Allow to manipulate text
    with mpl.rc_context({"svg.fonttype": "none"}):
        fig_size = (len(results.columns) * 1.25, 3 + 0.3 * n_embeddings)
        fig, ax = plt.subplots(figsize=fig_size)
        table = Table(
            plot_df,
            cell_kw={
                "linewidth": 0,
                "edgecolor": "k",
            },
            column_definitions=column_definitions,
            ax=ax,
            row_dividers=True,
            footer_divider=True,
            textprops={"fontsize": 9, "ha": "center"},
            row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 5))},
            col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
            column_border_kw={"linewidth": 1, "linestyle": "-"},
            index_col="Method",
        )
        table.autoset_fontcolors(colnames=plot_df.columns)

    if show:
        plt.show()
    if save_dir is not None:
        out_file = os.path.join(save_dir, "scib_results.png")
        fig.savefig(out_file, bbox_inches ='tight', facecolor=ax.get_facecolor(), dpi=dpi)