cell2net.interpretation.peak_attr
- cell2net.interpretation.peak_attr(model, idx=None, batch_size=4, num_workers=1, n_steps=100, multiply_by_inputs=True)
Compute feature attributions for peak accessibility using Integrated Gradients.
This function uses Integrated Gradients (IG) to compute attributions for the peak accessibility features of a Cell2Net model. Each attribution quantifies how much a feature contributed to the model's prediction relative to a baseline input.
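For reference, the textbook Integrated Gradients attribution of feature i, for a model output F, input x and baseline x', is shown below with a midpoint Riemann-sum approximation over m = n_steps points; the last line is the completeness property, which holds exactly for the integral and approximately for its discretization. This is the standard formulation, not necessarily the exact quadrature used here: Captum's IntegratedGradients defaults to a Gauss-Legendre rule. With multiply_by_inputs=False the leading (x_i - x'_i) factor is dropped.

```latex
\begin{aligned}
\mathrm{IG}_i(x) &= (x_i - x'_i) \int_0^1 \frac{\partial F\bigl(x' + \alpha (x - x')\bigr)}{\partial x_i}\, d\alpha \\
&\approx (x_i - x'_i)\, \frac{1}{m} \sum_{k=1}^{m} \frac{\partial F\bigl(x' + \alpha_k (x - x')\bigr)}{\partial x_i},
\qquad \alpha_k = \frac{k - 1/2}{m} \\
\sum_i \mathrm{IG}_i(x) &= F(x) - F(x')
\end{aligned}
```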
- Parameters:
  - model (Cell2Net) – The trained Cell2Net model for which feature attributions are computed. The model must expose its data through an mdata attribute and must support gradient computation.
  - idx (Sequence[int] | Sequence[str] | None, default: None) – Integer indices or string identifiers of the observations to include in the attribution computation. If None, all observations in the dataset are used.
  - batch_size (int, default: 4) – Number of samples per batch for the DataLoader.
  - num_workers (int, default: 1) – Number of worker processes the DataLoader uses for data loading.
  - n_steps (int, default: 100) – Number of interpolation steps used to approximate the Integrated Gradients path integral. Larger values give a more accurate approximation but require more forward and backward passes.
  - multiply_by_inputs (bool, default: True) – Whether to multiply the path-averaged gradients by (input - baseline), as in the original Integrated Gradients method. Because the baseline for peak_acc is zero, this amounts to scaling by the input values; if False, the attributions are returned without this scaling.
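The NumPy sketch below implements a midpoint Riemann-sum approximation of Integrated Gradients for a toy scalar function with a known gradient, to show what n_steps and multiply_by_inputs control. It is an illustration only, not the cell2net or Captum code; integrated_gradients, the softplus toy model and its weights are hypothetical.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline=None, n_steps=100, multiply_by_inputs=True):
    """Midpoint Riemann-sum approximation of Integrated Gradients (hypothetical helper).

    grad_fn returns dF/dx at a point, with the same shape as x.
    """
    x = np.asarray(x, dtype=float)
    # Zero baseline by default, mirroring the peak_acc baseline described in the Notes.
    baseline = np.zeros_like(x) if baseline is None else np.asarray(baseline, dtype=float)
    # Midpoints of n_steps equal sub-intervals of [0, 1].
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    # Average the gradient along the straight-line path from baseline to x.
    avg_grad = np.mean([grad_fn(baseline + a * (x - baseline)) for a in alphas], axis=0)
    if multiply_by_inputs:
        # Scale by (x - baseline); for a zero baseline this is just x.
        return (x - baseline) * avg_grad
    # Without scaling: the path-averaged gradient only.
    return avg_grad


# Toy model: F(v) = softplus(w @ v), whose gradient is w * sigmoid(w @ v).
w = np.array([0.5, -1.0, 2.0])


def F(v):
    return np.log1p(np.exp(w @ v))


def grad_F(v):
    return w / (1.0 + np.exp(-(w @ v)))


x = np.array([1.0, 0.5, 0.25])
attr = integrated_gradients(grad_F, x, n_steps=100)
# Completeness: the attributions should sum to approximately F(x) - F(baseline).
print(attr.sum(), F(x) - F(np.zeros_like(x)))
```

The two printed values should agree closely, and the agreement improves as n_steps grows. With a zero baseline, the unscaled result multiplied by x reproduces the scaled attributions.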
- Return type:
- Returns:
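An array of attributions for the peak accessibility (peak_acc) input, with one attribution per accessibility value of each selected observation, typically of shape (number_of_samples, number_of_features).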
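The sketch below shows how idx and batch_size could combine to produce the returned array. It assumes the implementation resolves idx to row positions, computes attributions batch by batch and concatenates the results along the observation axis. resolve_idx, batched_attributions and the toy data are hypothetical. A linear model stands in for the real per-batch IG call: with a zero baseline, the IG attributions of F(x) = w @ x are exactly w * x.

```python
import numpy as np


def resolve_idx(idx, obs_names):
    """Hypothetical helper: map idx (ints, string identifiers, or None) to row positions."""
    if idx is None:
        return np.arange(len(obs_names))
    idx = list(idx)
    if all(isinstance(i, str) for i in idx):
        lookup = {name: pos for pos, name in enumerate(obs_names)}
        return np.array([lookup[name] for name in idx], dtype=int)
    return np.asarray(idx, dtype=int)


def batched_attributions(attr_fn, data, positions, batch_size=4):
    """Hypothetical helper: run attr_fn on consecutive batches and stack along axis 0."""
    chunks = [
        attr_fn(data[positions[start:start + batch_size]])
        for start in range(0, len(positions), batch_size)
    ]
    return np.concatenate(chunks, axis=0)


obs_names = [f"cell_{i}" for i in range(10)]
# Toy data: 10 observations x 6 peaks.
peak_acc = np.random.default_rng(0).random((10, 6))
# Linear stand-in model: with a zero baseline, IG of w @ x is w * x.
w = np.linspace(-1.0, 1.0, 6)

positions = resolve_idx(["cell_0", "cell_3", "cell_7"], obs_names)
attr = batched_attributions(lambda batch: batch * w, peak_acc, positions, batch_size=2)
print(attr.shape)  # (3, 6)
```

Under this assumption the batch size only affects memory use and speed, not the result: the rows come back in the order given by idx.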
Notes
This function uses the Captum library's implementation of Integrated Gradients.
The model is expected to take the following inputs:
- peak_seq: One-hot encoded sequence of peaks.
- peak_acc: Peak accessibility values (the attributed input, with gradients enabled).
- peak_dist: Distances of peaks to transcription start sites.
- tf_exp: Transcription factor expression values.
- covariates: Additional covariates provided as arguments.
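The Integrated Gradients baseline for peak_acc is all zeros, so each attribution measures a feature's contribution relative to zero accessibility.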
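To illustrate the input setup described in the Notes, the sketch below attributes only peak_acc against a zero baseline. It assumes, as is usual when a single input is attributed, that the other inputs stay fixed at their observed values. toy_forward is a hypothetical stand-in for the model that omits peak_seq and covariates for brevity, and gradients come from central finite differences rather than autograd.

```python
import numpy as np


def toy_forward(peak_acc, peak_dist, tf_exp):
    """Hypothetical stand-in for a Cell2Net forward pass (not the real model)."""
    # Peaks closer to the TSS (smaller |distance|) receive larger weights.
    weights = np.exp(-np.abs(peak_dist) / 1000.0)
    return np.tanh(weights @ peak_acc) * tf_exp.sum()


def numerical_grad(f, x, eps=1e-5):
    """Central finite-difference gradient of a scalar function f at x."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step.flat[i] = eps
        grad.flat[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad


def ig_peak_acc(peak_acc, others, n_steps=100):
    """IG with respect to peak_acc only: zero baseline, other inputs held fixed."""
    def f(acc):
        return toy_forward(acc, *others)

    baseline = np.zeros_like(peak_acc)
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    avg_grad = np.mean(
        [numerical_grad(f, baseline + a * (peak_acc - baseline)) for a in alphas], axis=0
    )
    return (peak_acc - baseline) * avg_grad


peak_acc = np.array([0.8, 0.1, 0.5])
peak_dist = np.array([200.0, 5000.0, -1500.0])
tf_exp = np.array([1.2, 0.3])

attr = ig_peak_acc(peak_acc, (peak_dist, tf_exp))
# Output change when only peak_acc moves from the zero baseline to its observed value.
delta = toy_forward(peak_acc, peak_dist, tf_exp) - toy_forward(np.zeros(3), peak_dist, tf_exp)
print(attr.sum(), delta)
```

Because the non-attributed inputs never change along the path, the attributions account only for the output change caused by peak_acc, which is why their sum matches delta.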
Examples
>>> from cell2net.interpretation import peak_attr
>>> model = Cell2Net(mdata, ...)
>>> attributions = peak_attr(
...     model=model,
...     idx=[0, 1, 2],
...     batch_size=8,
...     num_workers=2,
...     n_steps=50,
... )
>>> print(attributions.shape)
(number_of_samples, number_of_features)