cell2net.interpretation.peak_attr

cell2net.interpretation.peak_attr(model, idx=None, batch_size=4, num_workers=1, n_steps=100, multiply_by_inputs=True)

Compute feature attributions for peak accessibility using Integrated Gradients.

This function uses Integrated Gradients (IG) to compute the attributions for peak accessibility features in a Cell2Net model. Attributions indicate the importance of each feature for the model’s predictions.
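For reference, the standard Integrated Gradients definition is given below for an input x, a baseline x', and a scalar model output F. The symbols are generic notation rather than names taken from the Cell2Net code. In practice the integral is approximated by a finite sum over m interpolation steps; the exact quadrature rule depends on the library.

```latex
\mathrm{IG}_i(x)
  = (x_i - x'_i) \int_0^1
    \frac{\partial F\bigl(x' + \alpha\,(x - x')\bigr)}{\partial x_i}\, d\alpha
  \approx (x_i - x'_i)\, \frac{1}{m} \sum_{k=1}^{m}
    \frac{\partial F\bigl(x' + \tfrac{k}{m}\,(x - x')\bigr)}{\partial x_i}
```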

Parameters:
  • model (Cell2Net) – The trained Cell2Net model for which feature attributions are computed. The model must expose its data through the mdata attribute and must support gradient computation, since attributions are derived from gradients with respect to peak_acc.

  • idx (Sequence[int] | Sequence[str] | None (default: None)) – Indices or identifiers for the specific observations to include in the attribution computation. If None, all observations in the dataset are used.

  • batch_size (int (default: 4)) – The number of samples per batch for the DataLoader.

  • num_workers (int (default: 1)) – The number of worker processes the DataLoader uses for data loading.

  • n_steps (int (default: 100)) – The number of interpolation steps for the Integrated Gradients computation.

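  • multiply_by_inputs (bool (default: True)) – Whether to multiply the integrated gradients by the difference between the input and the baseline, as in the standard Integrated Gradients formulation. Because the baseline for peak_acc is zero, this amounts to scaling by the input values.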
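To illustrate what n_steps and multiply_by_inputs control, here is a minimal, dependency-free sketch of Integrated Gradients on a toy function with a hand-written gradient. It is not the Cell2Net or captum implementation: the helper integrated_gradients, the toy function f, and the midpoint quadrature rule are all assumptions chosen for simplicity, and captum may use a different rule. The point is that the attributions sum to f(x) - f(baseline), and that the approximation error shrinks as n_steps grows.

```python
def integrated_gradients(grad_fn, x, baseline, n_steps=100, multiply_by_inputs=True):
    """Midpoint Riemann-sum approximation of Integrated Gradients.

    grad_fn maps a point (list of floats) to the gradient of a scalar
    function at that point. All names here are illustrative only.
    """
    dim = len(x)
    avg_grad = [0.0] * dim
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grad = grad_fn(point)
        for i in range(dim):
            avg_grad[i] += grad[i] / n_steps
    if not multiply_by_inputs:
        # Path-averaged gradients only (no scaling by input - baseline)
        return avg_grad
    return [(xi - b) * g for xi, b, g in zip(x, baseline, avg_grad)]


# Toy scalar function and its analytic gradient (hypothetical example)
def f(p):
    return p[0] ** 2 * p[1] + 3.0 * p[1]


def grad_f(p):
    return [2.0 * p[0] * p[1], p[0] ** 2 + 3.0]


x = [2.0, 1.0]
baseline = [0.0, 0.0]  # zero baseline, mirroring the peak_acc baseline

coarse = integrated_gradients(grad_f, x, baseline, n_steps=5)
fine = integrated_gradients(grad_f, x, baseline, n_steps=100)

# Completeness: attributions should sum to f(x) - f(baseline)
target = f(x) - f(baseline)
error_coarse = abs(sum(coarse) - target)
error_fine = abs(sum(fine) - target)
```

Comparing error_coarse with error_fine shows the trade-off behind n_steps: more steps cost more gradient evaluations but bring the attribution sum closer to the true output difference.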

Return type:

ndarray

Returns:

An array of attributions for peak accessibility features, with shape (number_of_samples, number_of_features), where the samples are the observations selected by idx.

Notes

  • This function uses the captum library for Integrated Gradients.

  • The model is expected to have the following inputs:

    • peak_seq: One-hot encoded sequence of peaks.

    • peak_acc: Peak accessibility values (with gradients enabled).

    • peak_dist: Distances of peaks to transcription start sites.

    • tf_exp: Transcription factor expression values.

    • covariates: Additional covariates provided as arguments.

  • Baseline values are set to zero for peak_acc.

Examples

>>> from cell2net.interpretation import peak_attr
>>> model = Cell2Net(mdata, ...)
>>> attributions = peak_attr(
...     model=model,
...     idx=[0, 1, 2],
...     batch_size=8,
...     num_workers=2,
...     n_steps=50,
... )
>>> print(attributions.shape)
(number_of_samples, number_of_features)
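Once the attributions are available, a typical next step is to summarize them per peak. The sketch below assumes the array has shape (number_of_samples, number_of_features) as in the example above, and uses a randomly generated placeholder array in place of real output. Ranking by mean absolute attribution is one common summary, not something this function prescribes.

```python
import numpy as np

# Placeholder standing in for the ndarray returned by peak_attr
rng = np.random.default_rng(0)
attributions = rng.normal(size=(3, 5))  # (number_of_samples, number_of_features)

# Average magnitude of each peak's attribution across samples
peak_importance = np.abs(attributions).mean(axis=0)

# Column indices of the three highest-ranked peaks, most important first
top_peaks = np.argsort(peak_importance)[::-1][:3]
```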