cell2net.interpretation.tf_attr

cell2net.interpretation.tf_attr(model, idx=None, batch_size=4, num_workers=1, n_steps=100, multiply_by_inputs=True)

Compute transcription factor (TF) attributions using Integrated Gradients.

This function uses the Integrated Gradients method to estimate how much each TF's expression contributes to the output of a Cell2Net model. Attributions are computed for the samples selected by idx (the entire dataset by default) and returned as a NumPy array.

Parameters:
  • model (Cell2Net) – The trained Cell2Net model. It must have an mdata attribute for metadata and a covariates attribute for covariate information.

  • idx (list[int] | list[str] | None (default: None)) – Indices or identifiers specifying the subset of the data to compute attributions for. If None, the entire dataset is used.

  • batch_size (default: 4) – The batch size to use for data loading.

  • num_workers (default: 1) – Number of worker processes for data loading.

  • n_steps (default: 100) – The number of steps to use for the Integrated Gradients computation. Larger values provide more accurate estimates but increase computation time.

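  • multiply_by_inputs (default: True) – Whether to multiply the integrated gradients by the inputs (more precisely, by the difference between the inputs and the baseline). With True, the attributions sum approximately to the difference between the model output at the input and at the baseline (the completeness property of Integrated Gradients). With False, the path-averaged gradients are returned unscaled.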
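The roles of n_steps and multiply_by_inputs can be illustrated with a minimal pure-Python sketch of Integrated Gradients. Everything in it is an assumption made for illustration: the toy model and its weights, the zero baseline, and the midpoint Riemann sum used for the path integral (Captum's default integration rule is Gauss-Legendre quadrature). It is not the Cell2Net or Captum implementation.

```python
import math

WEIGHTS = [0.8, -0.5, 0.3]  # hypothetical weights of a toy model


def model(tf_exp):
    """Toy differentiable model: sigmoid of a weighted sum of TF expression."""
    z = sum(w * x for w, x in zip(WEIGHTS, tf_exp))
    return 1.0 / (1.0 + math.exp(-z))


def gradient(tf_exp):
    """Analytic gradient of the toy model with respect to each input."""
    y = model(tf_exp)
    return [w * y * (1.0 - y) for w in WEIGHTS]


def integrated_gradients(x, baseline, n_steps=100, multiply_by_inputs=True):
    """Approximate Integrated Gradients with a midpoint Riemann sum."""
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grad = gradient(point)
        avg_grad = [a + g / n_steps for a, g in zip(avg_grad, grad)]
    if multiply_by_inputs:
        # Scale the path-averaged gradients by (input - baseline).
        return [a * (xi - b) for a, xi, b in zip(avg_grad, x, baseline)]
    return avg_grad


x = [2.0, 1.0, 3.0]
baseline = [0.0, 0.0, 0.0]
target = model(x) - model(baseline)

# Completeness error: how far the summed attributions are from f(x) - f(baseline).
errors = {}
for n in (2, 10, 100):
    attr = integrated_gradients(x, baseline, n_steps=n)
    errors[n] = abs(sum(attr) - target)
    print(f"n_steps={n:>3}  completeness error={errors[n]:.2e}")
```

In this sketch the completeness error shrinks as n_steps grows, which is the trade-off between accuracy and computation time described for n_steps.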

Return type:

ndarray

Returns:

A NumPy array containing the attributions for TF expression. The shape of the output depends on the dataset and the number of TFs modeled.
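One common way to summarize such an array is to rank TFs by their mean absolute attribution across samples. The sketch below assumes the array has shape (n_samples, n_tfs) and uses a synthetic array and hypothetical TF labels in place of a real tf_attr result.

```python
import numpy as np

# Synthetic stand-in for the array returned by tf_attr.
# The (n_samples, n_tfs) layout is an assumption.
rng = np.random.default_rng(0)
attributions = rng.normal(size=(4, 6))
tf_names = [f"TF{i}" for i in range(attributions.shape[1])]  # hypothetical labels

# Mean absolute attribution per TF: importance regardless of sign.
importance = np.abs(attributions).mean(axis=0)

# Sort TFs from most to least important.
order = np.argsort(importance)[::-1]
ranked = [(tf_names[i], float(importance[i])) for i in order]

for name, score in ranked[:3]:
    print(f"{name}: {score:.3f}")
```

Averaging absolute values discards the direction of each TF's effect. Averaging the signed values instead keeps the direction but lets opposing effects cancel.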

Notes

  • This function uses the Integrated Gradients algorithm for attribution computation. The captum library is required to perform this calculation.

  • The attributions are computed for the tf_exp input (transcription factor expression) while keeping other inputs (peak sequence, accessibility, and distance) fixed.

  • The model is set to training mode (model.module.train()) during computation.

Examples

>>> from cell2net.interpretation import tf_attr
>>> model = Cell2Net(...)  # a trained Cell2Net model
>>> idx = [0, 1, 2, 3]  # Indices of samples to compute attributions for
>>> attributions = tf_attr(
...     model=model,
...     idx=idx,
...     batch_size=4,
...     num_workers=2,
...     n_steps=50,
...     multiply_by_inputs=True,
... )
>>> attributions.shape
(4, num_tfs)