Combinatorial prediction of therapeutic perturbations using causally inspired neural networks

Preliminaries

A calligraphic letter \({\mathcal{X}}\) indicates a set, an italic uppercase letter X denotes a graph, uppercase X denotes a matrix, lowercase x denotes a vector, and a monospaced letter X indicates a tuple. Uppercase letter \({\mathsf{X}}\) indicates a random variable, and lowercase letter \({\mathsf{x}}\) indicates its corresponding value; bold uppercase \({\boldsymbol{\mathsf{X}}}\) denotes a set of random variables, and bold lowercase \({\boldsymbol{\mathsf{x}}}\) indicates their corresponding values. We denote by \(P({\boldsymbol{\mathsf{X}}})\) a probability distribution over a set of random variables \({\boldsymbol{\mathsf{X}}}\) and by \(P({\boldsymbol{\mathsf{X}}}={\boldsymbol{\mathsf{x}}})\) the probability of \({\boldsymbol{\mathsf{X}}}\) taking the value \({\boldsymbol{\mathsf{x}}}\) under the distribution \(P({\boldsymbol{\mathsf{X}}})\). For simplicity, \(P({\boldsymbol{\mathsf{X}}}={\boldsymbol{\mathsf{x}}})\) is abbreviated as \(P({\boldsymbol{\mathsf{x}}})\). This section uses terminology and concepts from the framework of causal inference67.

Problem formulation for combinatorial prediction of targets

Intuitively, given a diseased cell line sample, we would like to predict the set of therapeutic genes that need to be targeted to reverse the effects of disease, that is, the genes that need to be perturbed to shift the cell gene expression state as close as possible to the healthy state. Next, we formalize this problem. Let \({\mathtt{M}}= < {\boldsymbol{\mathsf{E}}},{\boldsymbol{\mathsf{V}}},{\mathcal{F}},P(\boldsymbol{{\mathsf{E}}}) >\) be a structural causal model (SCM; see the description of related works in Supplementary Note 4) associated with causal graph G, where \({\boldsymbol{\mathsf{E}}}\) is a set of exogenous variables affecting the system, \({\boldsymbol{\mathsf{V}}}\) are the system variables, \({\mathcal{F}}\) are structural equations encoding causal relations between variables and \(P({\boldsymbol{\mathsf{E}}})\) is a probability distribution over exogenous variables. Let \({\mathcal{T}}=\{{{\mathtt{T}}}_{1},\ldots ,{{\mathtt{T}}}_{m}\}\) be a dataset of paired healthy and diseased samples (namely, disease intervention data), where each element is a triplet \({\mathtt{T}}= < {{\boldsymbol{\mathsf{v}}}}^{\rm{h}},{\boldsymbol{\mathsf{U}}},{{\boldsymbol{\mathsf{v}}}}^{\rm{d}} >\) with \({{\boldsymbol{\mathsf{v}}}}^{\rm{h}}\in {[0,1]}^{N}\) being normalized gene expression values of a healthy cell line (variable states before perturbation), \({\boldsymbol{\mathsf{U}}}\) being the disease-causing perturbed variable (gene) set in \({\boldsymbol{\mathsf{V}}}\), and \({{\boldsymbol{\mathsf{v}}}}^{\rm{d}}\in {[0,1]}^{N}\) being gene expression values of a diseased cell line (variable states after perturbation). Our goal is to find, for each sample \({\mathtt{T}}= < {{\boldsymbol{\mathsf{v}}}}^{\rm{h}},{\boldsymbol{\mathsf{U}}},{{\boldsymbol{\mathsf{v}}}}^{\rm{d}} >\), the variable set \({{\boldsymbol{\mathsf{U}}}}^{{\prime} }\) with the highest likelihood of shifting variable states from the diseased state \({{\boldsymbol{\mathsf{v}}}}^{\rm{d}}\) to the healthy state \({{\boldsymbol{\mathsf{v}}}}^{\rm{h}}\). To increase generality, we refer to the desired variable states as treated (\({{\boldsymbol{\mathsf{v}}}}^{\rm{t}}\)). Our goal can then be expressed as:

$${\rm{argmax}}_{{{\boldsymbol{\mathsf{U}}}}^{{\prime} }}{P}^{\;{G}^{{\boldsymbol{\mathsf{U}}}}}({\boldsymbol{\mathsf{V}}}={{\boldsymbol{\mathsf{v}}}}^{\rm{t}}\,| \,{\rm{do}}({{\boldsymbol{\mathsf{U}}}}^{{\prime} })),$$

(1)

where \({P}^{\;{G}^{{\boldsymbol{\mathsf{U}}}}}\) represents the probability on the graph G mutilated by perturbations in variables in \({\boldsymbol{\mathsf{U}}}\). Under the assumption of no unobserved confounders, the above interventional probability can be expressed as a conditional probability on the mutilated graph \({G}^{{{\boldsymbol{\mathsf{U}}}}^{{\prime} }}\):

$${\rm{argmax}}_{{{\boldsymbol{\mathsf{U}}}}^{{\prime} }}{P}^{\;{G}^{{{\boldsymbol{\mathsf{U}}}}^{{\prime} }}}({\boldsymbol{\mathsf{V}}}={{\boldsymbol{\mathsf{v}}}}^{\rm{t}}\,| \,{{\boldsymbol{\mathsf{U}}}}^{{\prime} }),$$

(2)

which under the causal Markov condition is:

$${\rm{argmax}}_{{{\boldsymbol{\mathsf{U}}}}^{{\prime} }}\prod _{i}P\left({{{\mathsf{V}}}}_{i}={{{\mathsf{v}}}}_{i}^{\rm{t}}\,| \,{{\boldsymbol{\mathsf{Pa}}}}_{{{{\mathsf{V}}}}_{i}}\right),$$

(3)

where \({{\boldsymbol{\mathsf{Pa}}}}_{{{\boldsymbol{\mathsf{V}}}}_{i}}\) represents the parents of variable \({{{\mathsf{V}}}}_{i}\) according to graph \({G}^{{{\boldsymbol{\mathsf{U}}}}^{{\prime} }}\) (that is, the graph mutilated upon intervening on the variables in \({{\boldsymbol{\mathsf{U}}}}^{{\prime} }\)). Here, the state of a variable \({{\mathsf{V}}}_{j}\in {{\boldsymbol{\mathsf{Pa}}}}_{{{{\mathsf{V}}}}_{i}}\) is set to an arbitrary value \({{\mathsf{v}}}_{j}^{{\prime} }\) if \({{\mathsf{V}}}_{j}\in {{\boldsymbol{\mathsf{U}}}}^{{\prime} }\). Therefore, intervening on the variable set \({{\boldsymbol{\mathsf{U}}}}^{{\prime} }\) both modifies the graph used to obtain conditional probabilities and determines the state of the variables in \({{\boldsymbol{\mathsf{U}}}}^{{\prime} }\).

Problem formulation from a representation learning perspective

In the previous section, we drew on the SCM framework to introduce a generic formulation for the task of combinatorial prediction of therapeutic targets. Formulating the problem in the SCM framework allows for explicit modelling of interventions and formulation of interventional queries (see the description of related works in Supplementary Note 4). However, rather than approaching the problem from a purely causal inference perspective, we draw on representation learning to approximate the queries of interest, which addresses the limiting real-world setting of a noisy and incomplete causal graph. Inspired by this principled problem formulation, we next restate the problem in a representation learning paradigm.

We let \(G=({\mathcal{V}},{\mathcal{E}})\) denote a graph with \(| {\mathcal{V}}| =n\) nodes and \(| {\mathcal{E}}|\) edges, which contains partial information on causal relationships between nodes in \({\mathcal{V}}\) as well as some noisy relationships. We refer to this graph as a proxy causal graph. Let \({\mathcal{T}}=\{{{\mathtt{T}}}_{1},\ldots ,{{\mathtt{T}}}_{m}\}\) be a dataset in which an individual sample is a triplet \({\mathtt{T}}= < {{\bf{x}}}^{\rm{h}},{\mathcal{U}},{{\bf{x}}}^{\rm{d}} >\) with \({{\bf{x}}}^{\rm{h}}\in {[0,1]}^{n}\) being the node states (attributes) of a healthy cell sample (before perturbation), \({\mathcal{U}}\) being the set of disease-causing perturbed nodes in \({\mathcal{V}}\), and \({{\bf{x}}}^{\rm{d}}\in {[0,1]}^{n}\) being the node states (attributes) of a diseased cell sample (after perturbation). We denote by \({G}^{{\mathcal{U}}}=({\mathcal{V}},{{\mathcal{E}}}^{{\mathcal{U}}})\) the graph resulting from the mutilation of edges in G as a result of perturbing the nodes in \({\mathcal{U}}\) (one graph per perturbagen; we omit per-sample indices for simplicity). Here again we refer to the desired variable states as treated (xt). Our goal is then to learn a function:

$$f:{G}^{{{\mathcal{U}}}^{{\prime} }},\,{{\bf{x}}}^{\rm{d}},{{\bf{x}}}^{\rm{t}}\to {\rm{argmax}}_{{{\mathcal{U}}}^{{\prime} }}{P}^{\;{G}^{{{\mathcal{U}}}^{{\prime} }}}({\bf{x}}={{\bf{x}}}^{\rm{t}}| {{\bf{x}}}^{\rm{d}},{{\mathcal{U}}}^{{\prime} }).$$

(4)

That is, given the graph \({G}^{{{\mathcal{U}}}^{{\prime} }}\), the diseased node states xd and the treated node states xt, f predicts the combinatorial set of nodes \({{\mathcal{U}}}^{{\prime} }\) that, if perturbed, have the highest chance of shifting the node states to the treated state xt. We note here that \({P}^{{G}^{{{\mathcal{U}}}^{{\prime} }}}\) represents probabilities over the graph G mutilated upon perturbing the nodes in \({{\mathcal{U}}}^{{\prime} }\). Under the causal Markov condition, we can factorize \({P}^{{G}^{{{\mathcal{U}}}^{{\prime} }}}\) over the graph \({G}^{{{\mathcal{U}}}^{{\prime} }}\):

$$f:{G}^{{{\mathcal{U}}}^{{\prime} }},\,{{\bf{x}}}^{\rm{d}},{{\bf{x}}}^{\rm{t}}\to {\rm{argmax}}_{{{\mathcal{U}}}^{{\prime} }}\prod _{i}P({{\bf{x}}}_{i}={{\bf{x}}}_{i}^{\rm{t}}| {{\bf{x}}}_{{{\mathcal{PA}}}_{i}}),$$

(5)

that is, the probability of each node i depends only on its parents \({{\mathcal{PA}}}_{i}\) in the graph \({G}^{{{\mathcal{U}}}^{{\prime} }}\).

We assume (1) real-valued node states, (2) G is fixed and given, and (3) atomic and non-atomic perturbagens (intervening on individual nodes or sets of nodes). Given that the value of each node should depend only on its parents in the graph \({G}^{{{\mathcal{U}}}^{{\prime} }}\), a message-passing framework appears especially suited to compute the factorized probabilities P.

In the SCM framework, the conditional probabilities in equation (3) are computed recursively on the graph, each being an expectation over the exogenous variables \(\boldsymbol{\mathsf{E}}\); node states at the previous time point are therefore not necessary. To translate this query into the representation learning realm, we disregard the existence of noise variables and directly try to learn a function encoding the transition from an initial state to a desired state. An exhaustive approach to solving equation (5) would be to search the space of all potential sets of therapeutic targets \({{\mathcal{U}}}^{{\prime} }\) and score how effective each is in achieving the desired treated state. This is how many cell response prediction approaches are used for perturbagen discovery22,23,68. However, with moderately sized graphs, this is computationally expensive, if not intractable. Instead, we propose to search for potential perturbagens efficiently with a perturbagen discovery module (ƒp) and to score each potential perturbagen with a response prediction module (ƒr).
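To see why the exhaustive approach quickly becomes intractable, consider the number of candidate target sets as a function of set size. The following back-of-the-envelope sketch (not from the paper; the gene count is illustrative) makes the combinatorial explosion concrete:

```python
# Illustration: the number of candidate k-gene target sets grows
# combinatorially with the gene count, which is why exhaustively scoring
# every set with a response model is intractable and a learned proposal
# module f_p is needed.
from math import comb

n_genes = 10_000  # order of magnitude of genes in a typical network
for k in (1, 2, 3):
    print(f"{k}-gene candidate sets: {comb(n_genes, k):,}")
# 1-gene: 10,000; 2-gene: ~5.0e7; 3-gene: ~1.7e11
```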

Relationship to conventional graph prediction tasks

Given that the prediction for each variable depends only on its parents in a graph, GNNs appear especially suited for this problem. We can formulate the query of interest under a graph representation learning paradigm as follows: given a graph \(G=({\mathcal{V}},{\mathcal{E}})\), paired sets of node attributes \({\mathcal{X}}=\{{{\bf{X}}}_{1},{{\bf{X}}}_{2},\ldots ,{{\bf{X}}}_{m}\}\) and node labels \({\mathcal{Y}}=\{{{\bf{Y}}}_{1},{{\bf{Y}}}_{2},\ldots ,{{\bf{Y}}}_{m}\}\), where each Y = {y1, …, yn} with yi ∈ [0, 1], we aim to train a neural message-passing architecture that, given node attributes Xi, predicts the corresponding node labels Yi. There are, however, differences between our problem formulation and the conventional graph prediction tasks, namely, graph and node classification (summarized in Supplementary Table 13).

In node classification, a single graph G is paired with node attributes X, and the task is to predict the node labels Y. Our formulation differs in that we have m paired sets of node attributes \({\mathcal{X}}\) and labels \({\mathcal{Y}}\) instead of a single set, yet the two are similar in that there is a single graph on which GNNs operate. In graph classification, a set of graphs \({\mathcal{G}}=\{{G}_{1},\ldots ,{G}_{m}\}\) is paired with a set of node attributes \({\mathcal{X}}=\{{{\bf{X}}}_{1},{{\bf{X}}}_{2},\ldots ,{{\bf{X}}}_{m}\}\), and the task is to predict a label for each graph, Y = {y1, …, ym}. Here graphs have varying structure, and both the topological information and the node attributes predict graph labels. In our formulation, a single graph is combined with each set of node attributes Xi, and the goal is to predict a label for each node, not for the whole graph.

PDGrapher model

PDGrapher is an approach for combinatorial prediction of therapeutic targets composed of two modules. First, a perturbagen discovery module ƒp searches the space of potential gene sets to predict a suitable candidate \({{\mathcal{U}}}^{{\prime} }\). Next, a response prediction module ƒr checks the goodness of the predicted set \({{\mathcal{U}}}^{{\prime} }\), that is, how effective intervening on the variables in \({{\mathcal{U}}}^{{\prime} }\) is at shifting the node states to the desired treated state xt:

$$(1)\,\,{{\bf{x}}}^{\rm{d}},{{\bf{x}}}^{\rm{t}}\mathop{\rightarrow}\limits^{ {f}_{\rm{p}}}\hat{{{\mathcal{U}}}^{{\prime} }}$$

$$(2)\,\,{{\bf{x}}}^{\rm{d}},\hat{{{\mathcal{U}}}^{{\prime} }}\mathop{\rightarrow}\limits^{ {f}_{\rm{r}}}\hat{{{\bf{x}}}^{\rm{t}}}.$$
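A minimal sketch of this two-step inference pipeline is shown below; `f_p`, `f_r` and the top-k selection are hypothetical interfaces standing in for the trained modules, not PDGrapher's released API.

```python
import torch

def predict_perturbagen(f_p, f_r, x_d, x_t, k):
    """Two-step pipeline: f_p proposes candidate target nodes, then the
    response module f_r checks how well perturbing them shifts the
    diseased state x_d towards the desired treated state x_t."""
    scores = f_p(x_d, x_t)                  # one real-valued score per node
    u_hat = torch.topk(scores, k).indices   # predicted therapeutic target set
    x_t_hat = f_r(x_d, u_hat)               # predicted response to perturbing u_hat
    goodness = -torch.norm(x_t_hat - x_t)   # proximity to the desired state
    return u_hat, goodness
```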

Model optimization

We optimize our response prediction module ƒr using cross-entropy (CE) loss on known triplets of disease intervention \(< {{\bf{x}}}^{\rm{h}},{\mathcal{U}},{{\bf{x}}}^{\rm{d}} >\) and treatment intervention \(< {{\bf{x}}}^{\rm{d}},{{\mathcal{U}}}^{{\prime} },{{\bf{x}}}^{\rm{t}} >\):

$${{\mathcal{L}}}_{{f}_{\rm{r}}}={\rm{CE}}\left({{\bf{x}}}^{\rm{d}},\,{f}_{\rm{r}}({{\bf{x}}}^{\rm{h}},\,{\mathcal{U}})\right)+{\rm{CE}}\left({{\bf{x}}}^{\rm{t}},\,{f}_{\rm{r}}({{\bf{x}}}^{\rm{d}},\,{{\mathcal{U}}}^{{\prime} })\right).$$

(6)

We optimize our perturbagen discovery module ƒp using a cycle loss, ensuring that the response to the predicted intervention set \({{\mathcal{U}}}^{{\prime} }\) closely matches the desired treated state (the first term of equation (7)). In addition, we provide a supervisory signal for predicting \({{\mathcal{U}}}^{{\prime} }\) in the form of a cross-entropy loss (the second term of equation (7)). The total loss is thus defined as:

$${{\mathcal{L}}}_{{f}_{\rm{p}}}={\rm{CE}}\left({{\bf{x}}}^{\rm{t}},\,{f}_{\rm{r}}({{\bf{x}}}^{\rm{d}},\,{f}_{\rm{p}}({{\bf{x}}}^{\rm{d}},\,{{\bf{x}}}^{\rm{t}}))\right)+{\rm{CE}}\left({{\mathcal{U}}}^{{\prime} },{f}_{\rm{p}}({{\bf{x}}}^{\rm{d}},\,{{\bf{x}}}^{\rm{t}})\right)\,\,({\rm{with}}\,{f}_{\rm{r}}\,{\rm{frozen}}).$$

(7)

We train ƒp and ƒr in parallel and implement early stopping separately (see ‘Experimental set-up’ for more details). Trained module ƒp is then used to predict, for each diseased cell sample, which nodes should be perturbed (\({{\mathcal{U}}}^{{\prime} }\)) to achieve a desired treated state (Fig. 1a).
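The following is a sketch of the ƒp objective in equation (7), assuming ƒr outputs node values in [0, 1], ƒp outputs one logit per node, and `u_true` is a binary vector marking the nodes in \({{\mathcal{U}}}^{{\prime} }\); the exact loss heads in the released code may differ.

```python
import torch
import torch.nn.functional as F

def loss_fp(f_p, f_r, x_d, x_t, u_true):
    """Sketch of equation (7): cycle loss through a frozen f_r plus a
    supervision loss on the predicted intervention set."""
    for p in f_r.parameters():  # f_r is frozen while optimizing f_p
        p.requires_grad_(False)
    logits = f_p(x_d, x_t)
    # Cycle loss: the response (through frozen f_r) to the predicted
    # intervention should recover the desired treated state x_t.
    x_t_hat = f_r(x_d, torch.sigmoid(logits))
    cycle_loss = F.binary_cross_entropy(x_t_hat, x_t)
    # Supervision loss: predicted target scores should match the true set U'.
    supervision_loss = F.binary_cross_entropy_with_logits(logits, u_true)
    return cycle_loss + supervision_loss
```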

Response prediction module

Our response prediction module ƒr should learn to map pre-perturbagen node values to post-perturbagen node values through learning relationships between connected nodes (equivalent to learning structural equations in SCMs) and propagating the effects of perturbations downstream in the graph (analogous to the recursive nature of query computations in SCMs).

Given a disease intervention triplet \(< {{\bf{x}}}^{\rm{h}},{\mathcal{U}},{{\bf{x}}}^{\rm{d}} >\), we propose a neural model operating on a mutilated graph, \({G}^{{\mathcal{U}}}\), whose node attributes are the concatenation of xh and \({{\bf{x}}}_{{\mathcal{U}}}^{{\prime} }\), and which predicts the diseased node values xd. Each node i thus has a two-dimensional attribute vector \({{\bf{d}}}_{i}=[{{\bf{x}}}_{i}^{\rm{h}}\,| | \,{{\bf{x}}}_{{\mathcal{U}}}^{{\prime} }]\): the first element is its gene expression value \({{\bf{x}}}_{i}^{\rm{h}}\) and the second is a perturbation flag, a binary label indicating whether a perturbation occurs at node i. In practice, we embed each node feature into a high-dimensional continuous space by assigning learnable embeddings to each node based on the value of each input feature dimension. Specifically, for each node, we use the binary perturbation flag to assign a d-dimensional learnable embedding, which differs between nodes but is shared across samples for each node. To embed the gene expression value \({{\bf{x}}}_{i}^{\rm{h}}\in [0,1]\), we first calculate thresholds using quantiles to assign the gene expression value to one of B bins. We use the bin index to assign a d-dimensional learnable embedding, which likewise differs between nodes but is shared across samples for each node. To increase our model's representation power, we concatenate a d-dimensional positional embedding (a d-dimensional vector initialized randomly following a normal distribution). Concatenating these three embeddings results in an input node representation of dimensionality 3d.
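As a concrete illustration, the following is a minimal sketch of this input embedding, with uniform bins over [0, 1] in place of training-set quantiles; the class name, dimensions and bin count are illustrative assumptions, not the released defaults.

```python
import torch
import torch.nn as nn

class NodeFeatureEmbedding(nn.Module):
    """Sketch of the input embedding described above. Bin thresholds are
    uniform here for simplicity (the paper uses quantiles of the training
    data); d and n_bins are illustrative values."""

    def __init__(self, n_nodes, d=16, n_bins=10):
        super().__init__()
        self.n_bins = n_bins
        # Index one shared table by node_id * n_bins + bin so that each node
        # gets its own embedding per bin ("different between nodes, shared
        # across samples"), and likewise for the binary perturbation flag.
        self.expr_emb = nn.Embedding(n_nodes * n_bins, d)
        self.flag_emb = nn.Embedding(n_nodes * 2, d)
        self.pos_emb = nn.Parameter(torch.randn(n_nodes, d))
        self.register_buffer("thresholds", torch.linspace(0, 1, n_bins + 1)[1:-1])

    def forward(self, x, flag):
        # x: (n_nodes,) expression values in [0, 1]; flag: (n_nodes,) 0/1.
        node_id = torch.arange(x.shape[0], device=x.device)
        bins = torch.bucketize(x, self.thresholds)       # bin index per node
        e = self.expr_emb(node_id * self.n_bins + bins)  # expression embedding
        f = self.flag_emb(node_id * 2 + flag.long())     # perturbation-flag embedding
        return torch.cat([e, f, self.pos_emb], dim=-1)   # (n_nodes, 3d)
```

For each node \(i\in {\mathcal{V}}\), an embedding zi is then computed using a GNN operating on the node's neighbours' attributes. The most general formulation of a GNN layer is: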

$${{\bf{h}}}_{i}^{{\prime} }=\phi \left({{\bf{h}}}_{i},\mathop{\bigoplus }\limits_{j\in {{\mathcal{N}}}^{i}}\psi ({{\bf{h}}}_{i},{{\bf{h}}}_{j})\right),$$

(8)

where \({{\bf{h}}}_{i}^{{\prime} }\) represents the updated information of node i, and hi represents the information of node i in the previous layer, with the embedded di being the input to the first layer. Here, ψ is a message function, ⊕ is a permutation-invariant aggregation function, and ϕ is an update function. We obtain an embedding zi for node i by stacking K GNN layers. The node embedding zi is then passed to a multilayer feedforward neural network to obtain an estimate of the post-perturbation node values xd.
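A minimal instance of equation (8) in Torch Geometric is sketched below: ψ builds messages from pairs (hi, hj), summation plays the role of the permutation-invariant aggregator ⊕, and ϕ updates hi. This is an illustrative layer, not the exact convolution used by PDGrapher.

```python
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing

class GenericGNNLayer(MessagePassing):
    """Illustrative message-passing layer implementing equation (8)."""

    def __init__(self, dim):
        super().__init__(aggr="add")  # the permutation-invariant aggregator
        self.psi = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU())  # message fn
        self.phi = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU())  # update fn

    def forward(self, h, edge_index):
        agg = self.propagate(edge_index, h=h)         # aggregate psi(h_i, h_j)
        return self.phi(torch.cat([h, agg], dim=-1))  # h_i' = phi(h_i, aggregate)

    def message(self, h_i, h_j):
        return self.psi(torch.cat([h_i, h_j], dim=-1))  # psi(h_i, h_j)
```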

Perturbation discovery module

Our perturbagen discovery module ƒp should learn which nodes in the graph should be perturbed to shift the node states (attributes) from the diseased state xd to the desired treated state xt. Given a triplet \(< {{\bf{x}}}^{\rm{d}},{{\mathcal{U}}}^{{\prime} },{{\bf{x}}}^{\rm{t}} >\), we propose a neural model operating on the graph \({G}^{{{\mathcal{U}}}^{{\prime} }}\) with node features xd and xt that predicts a ranking of the nodes, where the top P ranked nodes are predicted as the nodes in \({{\mathcal{U}}}^{{\prime} }\). Each node i has a two-dimensional attribute vector \({{\bf{d}}}_{i}=[{{\bf{x}}}_{i}^{\rm{d}}\,| | \,{{\bf{x}}}_{i}^{\rm{t}}]\). In practice, we represent these features in a continuous space using the same approach as described for our response prediction module ƒr.

For each node \(i\in {\mathcal{V}}\), an embedding zi is computed using a GNN operating on the node's neighbours' attributes; as before, we obtain zi by stacking K GNN layers. The node embedding zi is then passed to a multilayer feedforward neural network to predict a real-valued score for node i.

Model implementation and training

We implement PDGrapher using PyTorch 1.10.1 (ref. 69) and the Torch Geometric 2.0.4 library70. The architecture has two hyperparameters: the number of GNN layers and the number of prediction layers. We set the number of prediction layers to two and performed a grid search over the number of GNN layers (one to three layers). We train our model using a 5-fold cross-validation strategy and report PDGrapher's performance resulting from the best-performing hyperparameter setting.
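A sketch of this search procedure follows; `train_eval` is a hypothetical helper that trains on one fold and returns a validation metric, and is not part of the released code.

```python
from sklearn.model_selection import KFold

def grid_search(dataset, train_eval):
    """5-fold CV over the number of GNN layers (1-3), with two prediction
    layers fixed, returning the best-performing layer count."""
    results = {}
    for n_gnn_layers in (1, 2, 3):
        kf = KFold(n_splits=5, shuffle=True, random_state=0)
        scores = [train_eval(dataset, train_idx, val_idx,
                             n_gnn_layers, n_pred_layers=2)
                  for train_idx, val_idx in kf.split(dataset)]
        results[n_gnn_layers] = sum(scores) / len(scores)
    return max(results, key=results.get)
```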

Further details on statistical analysis

We next outline the evaluation set-up, baseline models and statistical tests used to evaluate PDGrapher. We evaluate the performance of PDGrapher against the following existing methods:

  • Random reference: Given a sample \({\mathtt{T}}= < {{\bf{x}}}^{\rm{d}},{{\mathcal{U}}}^{{\prime} },{{\bf{x}}}^{\rm{t}} >\), the random reference baseline returns N random genes as the prediction of target genes in \({{\mathcal{U}}}^{{\prime} }\), where N is the number of genes in \({{\mathcal{U}}}^{{\prime} }\).

  • Cancer genes: Given a sample \({\mathtt{T}}= < {{\bf{x}}}^{\rm{d}},{{\mathcal{U}}}^{{\prime} },{{\bf{x}}}^{\rm{t}} >\), the cancer genes baseline returns the top N genes from an ordered list where the first M genes are disease associated (cancer-driver genes). The remaining genes are ranked randomly, and N is the number of genes in \({{\mathcal{U}}}^{{\prime} }\). The processing of cancer genes is described in ‘Disease-genes information’ in Supplementary Note 3.

  • Cancer drug targets: Given a sample \({\mathtt{T}}= < {{\bf{x}}}^{\rm{d}},{{\mathcal{U}}}^{{\prime} },{{\bf{x}}}^{\rm{t}} >\), the cancer targets baseline returns the top N genes from an ordered list where the first M genes are cancer drug targets and the remaining genes are ranked randomly, and N is the number of genes in \({{\mathcal{U}}}^{{\prime} }\). The processing of drug target information is described in ‘Drug-targets information’ and ‘Cancer drug and target information’ in Supplementary Note 3.

  • Perturbed genes: Given a sample \({\mathtt{T}}= < {{\bf{x}}}^{\rm{d}},{{\mathcal{U}}}^{{\prime} },{{\bf{x}}}^{\rm{t}} >\), the perturbed genes baseline returns the top N genes from an ordered list where the first M genes are all perturbed genes in the training set and the remaining genes are ranked randomly, and N is the number of genes in \({{\mathcal{U}}}^{{\prime} }\).

  • scGen22: scGen is a widely used gold-standard latent variable model for response prediction71,72,73,74. Given a set of observed cell types in control and perturbed states, scGen predicts the response of a new cell type to a perturbagen seen in training. To use scGen as a baseline, we first fit it to our LINCS gene expression data for each dataset type to predict response to perturbagens, training one model per perturbagen (chemical or genetic). Then, given a sample of paired diseased-treated cell line states, \({\mathtt{T}}= < {{\bf{x}}}^{\rm{d}},{{\mathcal{U}}}^{{\prime} },{{\bf{x}}}^{\rm{t}} >\), we compute the predicted response gene expression \({{\bf{x}}}^{{\rm{d}}^{{\prime} }}\) of the cell line to all perturbagens. The predicted perturbagen is the one whose predicted response is closest to xt in R2 score, which quantifies the proportion of variance in the treated state explained by the prediction. As scGen trains one model per perturbagen, it requires a prohibitively long training time for datasets with a large number of perturbagens, especially in the leave-cell-out setting. Therefore, we set the maximum number of training epochs to 100 and only conducted leave-cell-out tests for one split of the data for scGen.

  • Biolord33: Biolord can predict perturbagen response for both chemical and genetic datasets. We followed the official tutorial from the Biolord GitHub repository (https://github.com/nitzanlab/biolord), using the recommended parameters. To prevent memory and quota errors, we implemented two filtering steps: (1) instead of storing the entire response gene expression (rGEX) matrix of all input (control) cells for each perturbagen, we only store a vector of the averaged rGEX of the input cells per perturbagen, which is necessary for calculating R2 for evaluation; and (2) during prediction, if the number of control cells exceeds 10,000, we randomly downsample the control cells to 10,000. Similar to scGen, we predict the response gene expression \({{\bf{x}}}^{{\rm{d}}^{{\prime} }}\) for all perturbagens and use it to calculate R2 to rank the predicted perturbagens.

  • ChemCPA23: ChemCPA is specifically designed for chemical perturbation. We followed the official tutorials on GitHub for running this model (https://github.com/theislab/chemCPA), with all parameters set following the authors' recommendations. Data processing was also conducted using the provided scripts. We constructed drug embeddings using RDKit with canonical SMILES sequences, as this is the default setting in the model and the tutorial. As the original ChemCPA model lacks functionality to obtain the predicted rGEX for each drug (averaged over dosages), we developed a custom script to perform this task. These predictions, \({{\bf{x}}}^{{\rm{d}}^{{\prime} }}\), were subsequently used for calculating R2 to rank the predicted perturbagens.

  • GEARS34: GEARS is capable of predicting perturbagen responses for genetic perturbation datasets, specifically predicting the rGEX for unseen perturbagens. However, it is limited to predicting only those genes that are present in the gene network used as prior knowledge for model training. In addition, GEARS cannot process perturbagens with only one sample, so we filtered the data accordingly. We followed the official tutorial from the GEARS GitHub repository (https://github.com/snap-stanford/GEARS), using the recommended parameters. After confirming with the authors, we established that GEARS is suitable only for within-cell-line prediction. Consequently, our experiments with GEARS were conducted exclusively within this scenario.

  • CellOT27: CellOT is capable of working with both chemical and genetic datasets. We ran this model following the official tutorial from GitHub (https://github.com/bunnech/cellot), ensuring that all parameters were set according to the provided guidelines. Because CellOT cannot process perturbagens with small sample sizes, we filtered the data to retain only those perturbagens with more than five samples or cells. We then used the predicted rGEX \({{\bf{x}}}^{{\rm{d}}^{{\prime} }}\) to calculate R2 and the predicted perturbagen ranks. Similar to scGen, CellOT trains one model per perturbagen, which results in a prohibitively long training time for datasets with a large number of perturbagens; this issue becomes even more pronounced in leave-cell-out evaluations. Therefore, for this method, we set the maximum number of training epochs to 100 and conduct only one split in leave-cell-out tests.

Dataset splits and evaluation settings

We evaluate PDGrapher and competing methods in two settings.

Systematic random dataset splits

For each cell line, the dataset is split randomly into train and test sets to measure our model performance in an independent and identically distributed setting.

Leave-cell-out dataset splits

To measure model performance on unseen cell lines, we train our model with random splits on one cell line and test on a new cell line. Specifically, for chemical perturbation data, we train a model for each random split per cell line and test it on the entire dataset of the remaining eight cell lines. For genetic data, we train a model for each random split per cell line and test it on the entire dataset of the remaining nine cell lines. For example, with nine cell lines with chemical perturbation (A549, MDAMB231, BT20, VCAP, MCF7, PC3, A375, HT29 and HELA), we conducted an experiment where each split of cell line A549 was used as the training set, and the trained model was tested on the remaining eight cell lines (MDAMB231, BT20, VCAP, MCF7, PC3, A375, HT29 and HELA). Similarly, for cell line MDAMB231, we trained the model on each split of it and tested the model on the other eight cell lines (A549, BT20, VCAP, MCF7, PC3, A375, HT29 and HELA). This process was repeated for all cell lines, providing a comprehensive evaluation of PDGrapher and all competing methods.

Evaluation set-up

For all dataset split settings, our model is trained using 5-fold cross-validation, and metrics are reported as the average on the test set. Within each fold, we further split the training set into training and validation sets (8:2) to perform early stopping: we train the model on the training set until the validation loss has not decreased by at least \(10^{-5}\) for 15 consecutive epochs.
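A minimal sketch of this early-stopping rule is given below; the class is illustrative, with the patience and tolerance values taken from the text.

```python
class EarlyStopper:
    """Stop when the validation loss has not decreased by at least
    min_delta for `patience` consecutive epochs (1e-5 and 15 in the text)."""

    def __init__(self, patience=15, min_delta=1e-5):
        self.patience, self.min_delta = patience, min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        if val_loss < self.best - self.min_delta:  # meaningful improvement
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience    # True -> stop training
```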

Evaluation metrics

We report the average sample-wise R2 score and the average perturbagen-wise R2 score to measure performance in the prediction of xt. The sample-wise R2 score is computed as the square of the Pearson correlation between a predicted sample \({\hat{{\bf{x}}}}^{\rm{t}}\in {{\mathbb{R}}}^{N}\) and the real sample \({{\bf{x}}}^{\rm{t}}\in {{\mathbb{R}}}^{N}\). The perturbagen-wise R2 score is adopted from scGen: it is computed as the square of the Pearson correlation of a linear least-squares regression between a set of predicted treated samples \({\hat{{\bf{X}}}}^{\rm{t}}\in {{\mathbb{R}}}^{N\times S}\) and a set of real treated samples \({{\bf{X}}}^{\rm{t}}\in {{\mathbb{R}}}^{N\times S}\) for the same perturbagen, where S indicates the size of the sets. Higher values indicate better performance in predicting the treated sample xt given the diseased sample xd and the predicted perturbagen; these scores are used for evaluating the performance of response prediction.

For evaluating perturbagen discovery with competing methods that cannot directly predict perturbagen ranks for chemical perturbation data, we first calculate a rank of drugs based on the R2 score. We then build a target gene rank from the drug rank by substituting each drug with its target genes acquired from DrugBank75 (accessed in November 2022; see details in Supplementary Note 3). A single drug can have multiple target genes, which we place in the rank in random order. As some methods cannot predict unseen drugs, their predicted target gene lists are often short, introducing bias in the evaluation. To address this, we shuffle the missing target genes and append them to the predicted ranks to create a complete rank. For genetic perturbation data, we directly obtain the target gene rank from the results and then append the shuffled missing genes to the rank.
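The sample-wise score can be sketched directly from its definition; for the perturbagen-wise score, the aggregation over samples via per-gene means is our assumption in the spirit of scGen, and the released implementation may differ.

```python
import numpy as np

def sample_r2(x_pred, x_true):
    """Sample-wise R2: squared Pearson correlation between one predicted
    and one real treated expression profile (length-N vectors)."""
    return np.corrcoef(x_pred, x_true)[0, 1] ** 2

def perturbagen_r2(X_pred, X_true):
    """Perturbagen-wise R2 (assumed aggregation): squared Pearson
    correlation between mean predicted and mean real profiles over all
    S samples of one perturbagen (N x S matrices)."""
    return np.corrcoef(X_pred.mean(axis=1), X_true.mean(axis=1))[0, 1] ** 2
```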

To evaluate the performance of our model in ranking predicted therapeutic targets, we use the normalized discounted cumulative gain (nDCG), a widely used metric in information retrieval adapted to our setting. The raw DCG score is computed by summing the relevance of each correct target based on its rank in the predicted list, with relevance weighted by a logarithmic discount factor to prioritize higher-ranked interventions. The gain function is defined as 1 − ranking/N, ensuring that the score reflects the quality of the ranking relative to the total number of nodes in the system. To ensure comparability across datasets or experiments with different numbers of correct interventions, the DCG is normalized by the ideal DCG, which represents the maximum possible score for a perfect ranking. This yields nDCG values in the range [0, 1], where higher values indicate better ranking performance and closer alignment with the ground truth. This metric is particularly suited to our task as it emphasizes the accuracy of top-ranked interventions while accounting for the diminishing importance of lower-ranked predictions.
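The sketch below is one plausible reading of this description (gain 1 − rank/N, logarithmic discount, normalization by the ideal DCG); the released implementation may differ in details.

```python
import numpy as np

def ndcg(ranks_of_true, n_nodes):
    """ranks_of_true: 1-based positions of the ground-truth targets in the
    predicted ranking; n_nodes is the total number of nodes N."""
    dcg = sum((1 - r / n_nodes) / np.log2(r + 1) for r in ranks_of_true)
    ideal_ranks = range(1, len(ranks_of_true) + 1)  # targets at the very top
    idcg = sum((1 - r / n_nodes) / np.log2(r + 1) for r in ideal_ranks)
    return dcg / idcg
```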

In addition, we report the proportion of test samples for which the predicted therapeutic target set has at least one gene overlapping with the ground-truth therapeutic target set (denoted as the percentage of accurately predicted samples). We also calculate the ratio of correct therapeutic targets that appear in the top 1, top 10 and top 100 positions of the predicted rank, denoted as recall@1, recall@10 and recall@100, respectively.
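Both quantities are straightforward to compute from a predicted gene ranking; the following minimal sketch assumes `predicted_genes` is the ordered list of predicted targets.

```python
def recall_at_k(predicted_genes, true_targets, k):
    """Fraction of ground-truth targets among the top-k predicted genes."""
    return len(set(predicted_genes[:k]) & set(true_targets)) / len(true_targets)

def accurately_predicted(predicted_genes, true_targets):
    """A sample counts as accurate if the predicted set (top-|U'| genes)
    overlaps the ground-truth set in at least one gene."""
    top = set(predicted_genes[:len(true_targets)])
    return len(top & set(true_targets)) > 0
```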

To assess the overall performance across all experiments and metrics, we calculated an aggregated metric, averaging all metric values for each method.

Statistical tests

In the benchmarking experiments, we performed a one-tailed pairwise t-test to evaluate whether PDGrapher significantly outperforms the competing methods. For other experiments, such as ablation studies, we used a two-tailed t-test to determine whether there is a significant difference in performance between the two models. A significance threshold of 0.05 was used for all tests. P values of perturbagen discovery and response prediction tests are presented in the Source data.
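Reading "pairwise t-test" as a paired test on per-sample metrics (our assumption), these tests can be sketched with SciPy as follows; the score arrays are hypothetical.

```python
import numpy as np
from scipy import stats

# Hypothetical per-sample metrics for two methods on the same test samples.
pdgrapher_scores = np.array([0.82, 0.75, 0.91, 0.68, 0.88])
baseline_scores = np.array([0.70, 0.72, 0.85, 0.66, 0.80])

# Benchmarking: one-tailed paired t-test (is PDGrapher's metric higher?).
t1, p_one = stats.ttest_rel(pdgrapher_scores, baseline_scores,
                            alternative="greater")
# Ablations: two-tailed paired t-test (is there any difference?).
t2, p_two = stats.ttest_rel(pdgrapher_scores, baseline_scores)
print(p_one < 0.05, p_two < 0.05)  # significance at alpha = 0.05
```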

Ablation studies

In the ablation study, we evaluated PDGrapher optimized with only the supervision loss (PDGrapher-Super) and with only the cycle loss (PDGrapher-Cycle) across all chemical datasets. We then compared the perturbagen prediction performance of these submodels with that of the full model (PDGrapher-SuperCycle). To train PDGrapher-Super and PDGrapher-Cycle, for each cell line we set the number of layers to the value found optimal on the validation set in the random splitting setting for PDGrapher-SuperCycle.

Sensitivity studies

To test the sensitivity of PDGrapher on PPI networks, we used data from STRING (string-db.org), which provides a confidence score for each edge. The method for acquiring and preprocessing the PPI networks from STRING is detailed in Supplementary Note 3. For the sensitivity tests, we selected two cell lines: the chemical dataset MDAMB231 and the genetic dataset MCF7. For each cell line, we processed the data using the five PPI networks described in Supplementary Note 3. We optimized PDGrapher using 5-fold cross-validation as described in ‘Evaluation set-up’ and optimized the number of GNN layers using the validation set in each split.

Synthetic datasets

We generated three synthetic datasets:

  1. Dataset with missing components, removing bridge edges: this dataset is generated by progressively removing bridge edges from the existing PPI network. Bridge edges are those whose removal disconnects parts of the network. We vary the fraction of bridge edges removed in increments (from zero to one) and, for each fraction, we create a new edge list representing the modified network (Supplementary Table 5). This process introduces different levels of network sparsity, affecting the overall structure and connectivity. We pair these networks with gene expression data from chemical-PPI-breast-MDAMB231 (a sketch of the generation procedures follows this list).

  2. Dataset with missing components, removing random edges: this dataset is generated by progressively removing random edges from the existing PPI network. We vary the fraction of random edges removed in increments ([0, 0.1, …, 0.6]) and, for each fraction, we create a new edge list representing the modified network. The numbers of remaining directed edges in the network upon random edge removal are 273,319, 242,912, 212,525, 182,177, 151,811 and 121,472.

  3. Dataset with latent confounder noise: our starting point is the chemical-PPI-breast-MDAMB231 dataset. The synthetic datasets were constructed with varying levels of confounding bias introduced into the gene expression data. To simulate latent confounder effects, Gaussian noise with distinct means and variances was progressively added to random subsets of genes. Genes were grouped into 50 predefined subsets, each representing a latent confounder group. For each group, a Gaussian distribution was defined, with the mean drawn randomly from a uniform distribution in the range [−0.5, 0.5] and the standard deviation drawn from [0.1, 0.5]. A fraction ([0.2, 0.4, 0.6, 0.8, 1]) of these subsets was randomly selected for perturbation and, for each gene in these subsets, the expression value was incremented by a value sampled from the respective Gaussian distribution. The perturbed gene expression values were then clamped between zero and one to ensure validity. This strategy introduces different latent global biases into gene expression patterns while maintaining controlled variability. We pair the noisy version of the gene expression data with the global unperturbed PPI network.
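The sketch below illustrates the generation procedures for datasets 1 and 3 under simplifying assumptions: the PPI graph is treated as undirected (nx.bridges is defined for undirected graphs), and function names and the random seed are illustrative.

```python
import networkx as nx
import numpy as np

rng = np.random.default_rng(0)

def remove_bridge_edges(G, fraction):
    """Sketch of dataset 1: drop a given fraction of bridge edges (edges
    whose removal disconnects parts of the network). G assumed undirected."""
    bridges = list(nx.bridges(G))
    drop_idx = rng.choice(len(bridges), size=int(fraction * len(bridges)),
                          replace=False)
    H = G.copy()
    H.remove_edges_from(bridges[i] for i in drop_idx)
    return H

def add_confounder_noise(x, gene_groups, fraction):
    """Sketch of dataset 3: per-group Gaussian noise with mean ~ U(-0.5, 0.5)
    and s.d. ~ U(0.1, 0.5), added to a random fraction of the gene groups,
    then clamped to [0, 1]."""
    x = x.copy()
    picked = rng.choice(len(gene_groups),
                        size=int(fraction * len(gene_groups)), replace=False)
    for g in picked:
        mu, sigma = rng.uniform(-0.5, 0.5), rng.uniform(0.1, 0.5)
        idx = gene_groups[g]  # indices of genes in this confounder group
        x[idx] += rng.normal(mu, sigma, size=len(idx))
    return np.clip(x, 0.0, 1.0)
```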

Network proximity between predicted and true perturbagens

Let \({\mathcal{P}}\) be the set of predicted therapeutic targets, \({\mathcal{R}}\) be the set of ground-truth therapeutic targets, and spd(p, r) be the shortest-path distance between nodes \(p\in {\mathcal{P}}\) and \(r\in {\mathcal{R}}\). We measure the distance between \({\mathcal{P}}\) and \({\mathcal{R}}\) as the average shortest-path distance over all pairs:

$$d({\mathcal{P}},{\mathcal{R}})=\frac{1}{| {\mathcal{R}}| | {\mathcal{P}}| }\sum _{r\in {\mathcal{R}}}\sum _{p\in {\mathcal{P}}}{\rm{spd}}(p,r).$$

(9)

As part of our performance analyses, we measure the network proximity of PDGrapher and competing methods. We compared the distributions of network proximity values using a Mann–Whitney U-test, along with a rank-biserial correlation to measure effect size. To assess the uncertainty of effect sizes, we performed bootstrapping with 1,000 resamples to estimate 95% CIs.
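Equation (9) translates directly into code; the sketch below assumes every (p, r) pair is connected in the network, and unreachable pairs would need special handling.

```python
import networkx as nx

def network_proximity(G, predicted, true):
    """Average shortest-path distance between the predicted target set P
    and the ground-truth set R, as in equation (9)."""
    total = sum(nx.shortest_path_length(G, source=p, target=r)
                for p in predicted for r in true)
    return total / (len(predicted) * len(true))
```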

Reporting summary

Further information on research design is available in the Nature Portfolio Reporting Summary linked to this article.

