This function returns prediction sets for the cell type of the cells in a SingleCellExperiment object. It implements two methods: the first uses standard conformal inference, while the second uses conformal risk control (see details). The output is either a SingleCellExperiment object with the prediction sets in the colData, or a list.
getPredictionSets(
  x_query,
  x_cal,
  y_cal,
  onto = NULL,
  alpha = 0.1,
  lambdas = seq(0.001, 0.999, length.out = 100),
  follow_ontology = TRUE,
  resample = FALSE,
  labels = NULL,
  return_sc = NULL,
  pr_name = "pred.set",
  simplify = FALSE,
  BPPARAM = SerialParam()
)
x_query: query data for which we want to build prediction sets. Could be either a SingleCellExperiment object with the estimated probabilities for each cell type in the colData, or a named matrix of dimension n x K, where n is the number of cells and K is the number of different labels. The colnames of the matrix have to correspond to the cell labels.
x_cal: calibration data. Could be either a SingleCellExperiment object with the estimated probabilities for each cell type in the colData, or a named matrix of dimension m x K, where m is the number of cells and K is the number of different labels. The colnames of the matrix have to correspond to the cell labels.
y_cal: a vector of length m with the true labels of the cells in the calibration data.
onto: the considered section of the cell ontology, as an igraph object.
alpha: a number between 0 and 1 that indicates the allowed miscoverage.
lambdas: a vector of possible lambda values to be considered. Necessary only when follow_ontology = TRUE.
follow_ontology: if TRUE, the function returns hierarchical prediction sets that follow the cell ontology structure. If FALSE, it returns classical conformal prediction sets. See details.
resample: should the calibration dataset be resampled according to the estimated relative frequencies of cell types in the query data?
labels: labels of the different cell types considered. Necessary if onto = NULL; otherwise they are set equal to the leaf nodes of the provided graph.
return_sc: parameter that controls the output. If TRUE, the function returns a SingleCellExperiment or SpatialExperiment object. If FALSE, it returns a list. By default, it is set to TRUE when x_query is a SingleCellExperiment or SpatialExperiment object, and to FALSE when x_query is a matrix.
pr_name: name of the colData variable in the returned SingleCellExperiment object that will contain the prediction sets. The default name is pred.set.
simplify: if TRUE, the output will be the common ancestor of the labels inserted into the prediction set. If FALSE (default), the output will be the set of leaf labels.
BPPARAM: a BiocParallel instance for parallel computing. The default is SerialParam().
If return_sc = TRUE, getPredictionSets returns a SingleCellExperiment or SpatialExperiment object with the prediction sets in the colData. The name of the colData variable containing the prediction sets is given by the parameter pr_name. If return_sc = FALSE, getPredictionSets returns a list of length equal to the number of cells in the query data; each element of the list contains the prediction set for the corresponding cell.
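When the function returns a SingleCellExperiment, the sets can be retrieved from the colData. A minimal sketch, assuming sce_out stands for the object returned by getPredictionSets with the default pr_name:

library(SingleCellExperiment)
# Prediction sets are stored under the pr_name column of the colData
pred_sets <- colData(sce_out)$pred.set
pred_sets[[1]]  # prediction set (a vector of labels) for the first cell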
Conformal inference is a statistical framework that makes it possible to build prediction sets for any probabilistic or machine learning model. Suppose we have a classification task with \(K\) classes. We fit a classification model \(\hat{f}\) that outputs estimated probabilities for each class: \(\hat{f}(x) \in [0,1]^K\). Split conformal inference requires reserving a portion of the labelled training data, \((X_1,Y_1),\dots,(X_n,Y_n)\), to be used as calibration data. Given \(\hat{f}\) and the calibration data, the objective of conformal inference is to build, for a new observation \(X_{n+1}\), a prediction set \(C(X_{n+1}) \subseteq\{1,\dots,K\}\) that satisfies $$P\left(Y_{n+1}\in C(X_{n+1})\right) \geq 1-\alpha$$ for a user-chosen error rate \(\alpha\). Note that conformal inference is distribution-free and the sets provided have finite-sample validity: the only assumption is that the test data and the calibration data are exchangeable. The algorithm of split conformal inference is the following (a small R sketch is given after the steps):
1. For the data in the calibration set, \((X_1,Y_1),\dots,(X_n,Y_n)\), obtain the conformal scores \(s_i=1-\hat{f}(X_i)_{Y_i}, \;i=1,\dots,n\). These scores will be high when the model assigns a small probability to the true class, and low otherwise.
2. Obtain \(\hat{q}\), the \(\lceil(1-\alpha)(n+1)\rceil/n\) empirical quantile of the conformal scores.
3. Finally, for a new observation \(X_{n+1}\), construct a prediction set by including all the classes for which the estimated probability is higher than \(1-\hat{q}\): $$C(X_{n+1})=\{y: \hat{f}(X_{n+1})_y\geq 1-\hat{q}\}.$$
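As an illustration, the three steps can be written in a few lines of base R. This is a hedged sketch, not the package's internal implementation; it assumes p_cal and p_test are matrices of estimated class probabilities with the labels as colnames, and y_cal holds the true calibration labels:

alpha <- 0.1
n <- nrow(p_cal)
# Step 1: conformal scores, one minus the probability of the true class
s <- 1 - p_cal[cbind(seq_len(n), match(y_cal, colnames(p_cal)))]
# Step 2: the ceiling((1 - alpha) * (n + 1)) / n empirical quantile
# (capped at 1 so that quantile() always receives a valid probability)
q_hat <- quantile(s, min(1, ceiling((1 - alpha) * (n + 1)) / n), type = 1)
# Step 3: include every class with estimated probability >= 1 - q_hat
sets <- lapply(seq_len(nrow(p_test)), function(i) {
    colnames(p_test)[p_test[i, ] >= 1 - q_hat]
})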
When follow_ontology = TRUE, the prediction sets are instead built so that they respect the structure of the cell ontology. Let \(\hat{y}(x)\) be the class with maximum estimated probability. Moreover, given a directed graph, let \(\mathcal{P}(v)\) and \(\mathcal{A}(v)\) be the sets of children nodes and of ancestor nodes of \(v\), respectively. Finally, for each node \(v\) define a score \(g(v,x)\) as the sum of the predicted probabilities of the leaf nodes that are children of \(v\). To build the sets we propose the following algorithm: $$\mathcal{P}(v) \cup \{\mathcal{P}(a): a\in\mathcal{A}(\hat{y}(x)): g(a,x)\leq\lambda \},$$ where \(v:v\in \mathcal{A}(\hat{y}(x)), \;g(v,x)\geq\lambda,\; v=\arg\min_{u:g(u,x)\geq\lambda}g(u,x)\). In words, we start from the predicted class and go up in the graph until we find an ancestor of \(\hat{y}(x)\) whose score is at least \(\lambda\), and we include in the prediction set all of its children. For theoretical reasons, to this subgraph we add all the other ones that contain \(\hat{y}(x)\) and whose score is less than \(\lambda\). To choose \(\lambda\), we follow eq. (4) in Angelopoulos et al. (2023), using the miscoverage as loss function. In this way, it is still guaranteed that $$P(Y_{n+1}\notin C_\lambda (X_{n+1})) \leq \alpha.$$
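To make the score \(g(v,x)\) concrete, here is a small illustrative sketch with igraph. The toy ontology and the probability vector are invented for the example, and "children" is interpreted as the leaf nodes reachable from \(v\):

library(igraph)
# Toy ontology: "cell" splits into "T" and "B"; "T" has two leaf children
onto <- graph_from_edgelist(rbind(
    c("cell", "T"), c("cell", "B"),
    c("T", "T (CD4+)"), c("T", "T (CD8+)")
))
leaves <- V(onto)[degree(onto, mode = "out") == 0]$name
# Invented predicted probabilities for one cell, one entry per leaf label
p_x <- c("T (CD4+)" = 0.60, "T (CD8+)" = 0.25, "B" = 0.15)
# g(v, x): sum of the probabilities of the leaves reachable from v
g <- vapply(V(onto)$name, function(v) {
    desc <- subcomponent(onto, v, mode = "out")$name
    sum(p_x[intersect(desc, leaves)])
}, numeric(1))
g  # g["T"] is 0.85, g["cell"] is 1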
For an introduction to conformal prediction, see Angelopoulos, Anastasios N., and Stephen Bates. "A gentle introduction to conformal prediction and distribution-free uncertainty quantification." arXiv preprint arXiv:2107.07511 (2021). For reference on conformal risk control, see Angelopoulos, Anastasios N., et al. "Conformal risk control." arXiv preprint arXiv:2208.02814 (2023).
# random p matrix
set.seed(1040)
p <- matrix(rnorm(2000 * 4), ncol = 4)
# Normalize the matrix p to have all numbers between 0 and 1 that sum to 1
# by row
p <- exp(p - apply(p, 1, max))
p <- p / rowSums(p)
cell_types <- c("T (CD4+)", "T (CD8+)", "B", "NK")
colnames(p) <- cell_types
# Take 1000 rows as calibration and 1000 as test
p_cal <- p[1:1000, ]
p_test <- p[1001:2000, ]
# Randomly create the vector of real cell types for p_cal and p_test
y_cal <- sample(cell_types, 1000, replace = TRUE)
y_test <- sample(cell_types, 1000, replace = TRUE)
# Obtain conformal prediction sets
conf_sets <- getPredictionSets(
    x_query = p_test,
    x_cal = p_cal,
    y_cal = y_cal,
    onto = NULL,
    alpha = 0.1,
    follow_ontology = FALSE,
    resample = FALSE,
    labels = cell_types,
    return_sc = FALSE
)
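Since the true test labels are known in this simulated example, one can inspect the sets and check their empirical coverage as a quick sanity check; it should be close to 1 - alpha = 0.9:

# Prediction set for the first query cell
conf_sets[[1]]
# Fraction of test cells whose true label is inside their prediction set
mean(mapply(function(set, y) y %in% set, conf_sets, y_test))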