This function returns prediction sets for the cell type of the cells in a SingleCellExperiment object. It implements two methods: the first uses standard conformal inference, while the second uses conformal risk control (see Details). The output is either a SingleCellExperiment object with the prediction sets in the colData or a list.

getPredictionSets(
  x_query,
  x_cal,
  y_cal,
  onto = NULL,
  alpha = 0.1,
  lambdas = seq(0.001, 0.999, length.out = 100),
  follow_ontology = TRUE,
  resample = FALSE,
  labels = NULL,
  return_sc = NULL,
  pr_name = "pred.set",
  simplify = FALSE,
  BPPARAM = SerialParam()
)

Arguments

x_query

query data for which we want to build prediction sets. Either a SingleCellExperiment object with the estimated probabilities for each cell type in the colData, or a named matrix of dimension n x K, where n is the number of cells and K is the number of distinct labels. The colnames of the matrix must correspond to the cell labels.

x_cal

calibration data. Either a SingleCellExperiment object with the estimated probabilities for each cell type in the colData, or a named matrix of dimension m x K, where m is the number of cells and K is the number of distinct labels. The colnames of the matrix must correspond to the cell labels.

y_cal

a vector of length m with the true labels of the cells in the calibration data.

onto

the considered section of the cell ontology as an igraph object.

alpha

a number between 0 and 1 that indicates the allowed miscoverage rate.

lambdas

a vector of possible lambda values to be considered. Necessary only when follow_ontology=TRUE.

follow_ontology

If TRUE, then the function returns hierarchical prediction sets that follow the cell ontology structure. If FALSE, it returns classical conformal prediction sets. See details.

resample

Should the calibration dataset be resampled according to the estimated relative frequencies of cell types in the query data?

labels

labels of different considered cell types. Necessary if onto=NULL, otherwise they are set equal to the leaf nodes of the provided graph.

return_sc

parameter that controls the output. If TRUE, the function returns a SingleCellExperiment. If FALSE, the function returns a list. By default, it is set to TRUE when x_query is a SingleCellExperiment or SpatialExperiment object and to FALSE when x_query is a matrix.

pr_name

name of the colData variable in the returned SingleCellExperiment object that will contain the prediction sets. The default name is pred.set.

simplify

if TRUE, the output will be the common ancestor of the labels inserted into the prediction set. If FALSE (default), the output will be the set of the leaf labels.

BPPARAM

BiocParallel instance for parallel computing. Default is SerialParam().

Value

return_sc = TRUE

the function getPredictionSets returns a SingleCellExperiment or SpatialExperiment object with the prediction sets in the colData. The name of the variable containing the prediction sets is given by the parameter pr_name.

return_sc = FALSE

the function getPredictionSets returns a list of length equal to the number of cells in the query data. Each element of the list contains the prediction set for that cell.

Details

Split conformal sets

Conformal inference is a statistical framework that allows one to build prediction sets for any probabilistic or machine learning model. Suppose we have a classification task with \(K\) classes. We fit a classification model \(\hat{f}\) that outputs estimated probabilities for each class: \(\hat{f}(x) \in [0,1]^K\). Split conformal inference requires reserving a portion of the labelled training data, \((X_1,Y_1),\dots, (X_n,Y_n)\), to be used as calibration data. Given \(\hat{f}\) and the calibration data, the objective of conformal inference is to build, for a new observation \(X_{n+1}\), a prediction set \(C(X_{n+1}) \subseteq\{1,\dots,K\}\) that satisfies $$P\left(Y_{n+1}\in C(X_{n+1})\right) \geq 1-\alpha$$ for a user-chosen error rate \(\alpha\). Note that conformal inference is distribution-free and the sets provided have finite-sample validity. The only assumption is that the test data and the calibration data are exchangeable. The algorithm of split conformal inference is the following:

  1. For the data in the calibration set, \((X_1,Y_1),\dots, (X_n,Y_n)\), obtain the conformal scores \(s_i=1-\hat{f}(X_i)_{Y_i}, \;i=1,\dots,n\). These scores are high when the model assigns a small probability to the true class, and low otherwise.

  2. Obtain \(\hat{q}\), the \(\lceil(1-\alpha)(n+1)\rceil/n\) empirical quantile of the conformal scores.

  3. Finally, for a new observation \(X_{n+1}\), construct a prediction set by including all the classes for which the estimated probability is higher than \(1-\hat{q}\): $$C(X_{n+1})=\{y: \hat{f}(X_{n+1})_y\geq 1-\hat{q}\}.$$
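The three steps above can be sketched in a few lines of base R. This is a minimal illustration, not the package's internal implementation; all object names (`p_cal`, `y_cal`, `p_new`, etc.) are made up for the example.

```r
# Minimal sketch of split conformal classification in base R.
set.seed(1)
classes <- c("A", "B", "C")
m <- 500

# Simulated calibration probabilities (rows sum to 1) and true labels
p_cal <- matrix(rexp(m * length(classes)), ncol = length(classes))
p_cal <- p_cal / rowSums(p_cal)
colnames(p_cal) <- classes
y_cal <- sample(classes, m, replace = TRUE)

alpha <- 0.1

# Step 1: conformal scores s_i = 1 - f(X_i)_{Y_i}
s <- 1 - p_cal[cbind(seq_len(m), match(y_cal, classes))]

# Step 2: the ceiling((1 - alpha)(n + 1))/n empirical quantile of the scores
q_hat <- quantile(s, ceiling((1 - alpha) * (m + 1)) / m, type = 1)

# Step 3: for a new observation, keep every class with probability >= 1 - q_hat
p_new <- c(A = 0.7, B = 0.2, C = 0.1)
pred_set <- names(p_new)[p_new >= 1 - q_hat]
```

With a well-calibrated model the resulting sets contain the true label for at least a \(1-\alpha\) fraction of new observations, on average over calibration sets.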

Hierarchical conformal sets

Let \(\hat{y}(x)\) be the class with maximum estimated probability. Moreover, given a directed graph, let \(\mathcal{P}(v)\) and \(\mathcal{A}(v)\) be the sets of children nodes and ancestor nodes of \(v\), respectively. Finally, for each node \(v\) define a score \(g(v,x)\) as the sum of the predicted probabilities of the leaf nodes that are descendants of \(v\). To build the sets we propose the following algorithm: $$\mathcal{P}(v) \cup \{\mathcal{P}(a): a\in\mathcal{A}(\hat{y}(x)): g(a,x)\leq\lambda \},$$ where \(v:v\in \mathcal{A}(\hat{y}(x)), \;g(v,x)\geq\lambda,\; v=\arg\min_{u:g(u,x)\geq\lambda}g(u,x)\). In words, we start from the predicted class and go up in the graph until we find an ancestor of \(\hat{y}(x)\) whose score is at least \(\lambda\), and we include in the prediction set all its children. For theoretical reasons, to this subgraph we add all the other ones that contain \(\hat{y}(x)\) and for which the score is less than \(\lambda\). To choose \(\lambda\), we follow eq. (4) in Angelopoulos et al. (2023), using the miscoverage as loss function. In this way, it is still guaranteed that $$P(Y_{n+1}\notin C_\lambda (X_{n+1})) \leq \alpha.$$
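The core of this construction, computing \(g(v,x)\) and walking up from the predicted class, can be sketched in base R on a toy ontology. The ontology encoding, node names, and the fixed `lambda` below are all illustrative (the package calibrates \(\lambda\) on the calibration data), and for brevity the sketch omits the extra subgraphs added for theoretical validity.

```r
# Toy ontology encoded as a children list plus a parent lookup
children <- list(
  root     = c("lymphoid", "NK"),
  lymphoid = c("T", "B"),
  T        = character(0),
  B        = character(0),
  NK       = character(0)
)
parent <- c(lymphoid = "root", NK = "root", T = "lymphoid", B = "lymphoid")

# Leaf descendants of a node (a leaf returns itself)
leaves_of <- function(v) {
  kids <- children[[v]]
  if (length(kids) == 0) return(v)
  unlist(lapply(kids, leaves_of))
}

# Predicted leaf probabilities for one cell (illustrative values)
p <- c(T = 0.55, B = 0.25, NK = 0.20)

# g(v, x): sum of the predicted probabilities of the leaves under v
g <- function(v) sum(p[leaves_of(v)])

# Walk up from the predicted class until g(v, x) >= lambda, then report
# the leaves under that ancestor as the prediction set
lambda <- 0.7
v <- names(which.max(p))            # start at the predicted class
while (g(v) < lambda && v %in% names(parent)) v <- parent[[v]]
pred_set <- leaves_of(v)            # here: the leaves under "lymphoid"
```

Here \(g(\mathrm{T}, x) = 0.55 < \lambda\), so the walk moves up to "lymphoid", whose score \(0.55 + 0.25 = 0.8\) exceeds \(\lambda\), and the set becomes the leaves T and B.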

References

For an introduction to conformal prediction, see Angelopoulos, Anastasios N., and Stephen Bates. "A gentle introduction to conformal prediction and distribution-free uncertainty quantification." arXiv preprint arXiv:2107.07511 (2021). For reference on conformal risk control, see Angelopoulos, Anastasios N., et al. "Conformal risk control." arXiv preprint arXiv:2208.02814 (2023).

Examples

# random p matrix
set.seed(1040)
p <- matrix(rnorm(2000 * 4), ncol = 4)
# Row-wise softmax: turn each row of p into probabilities that sum to 1
p <- exp(p - apply(p, 1, max))
p <- p / rowSums(p)
cell_types <- c("T (CD4+)", "T (CD8+)", "B", "NK")
colnames(p) <- cell_types

# Take 1000 rows as calibration and 1000 as test
p_cal <- p[1:1000, ]
p_test <- p[1001:2000, ]

# Randomly create the vector of real cell types for p_cal and p_test
y_cal <- sample(cell_types, 1000, replace = TRUE)
y_test <- sample(cell_types, 1000, replace = TRUE)

# Obtain conformal prediction sets
conf_sets <- getPredictionSets(
    x_query = p_test,
    x_cal = p_cal,
    y_cal = y_cal,
    onto = NULL,
    alpha = 0.1,
    follow_ontology = FALSE,
    resample = FALSE,
    labels = cell_types,
    return_sc = FALSE
)