cell2net.interpretation.tf_attr
- cell2net.interpretation.tf_attr(model, idx=None, batch_size=4, num_workers=1, n_steps=100, multiply_by_inputs=True)
Compute transcription factor (TF) attributions using Integrated Gradients.
This function calculates the attribution of TF expression to the output of a Cell2Net model using the Integrated Gradients method. The attributions are computed over a specified dataset and returned as a NumPy array.
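As a rough illustration of what Integrated Gradients computes, the sketch below approximates the path integral from a baseline to the input with a midpoint Riemann sum, for a toy function whose gradient is known in closed form. The function toy_model, its gradient toy_grad, and the all-zeros baseline are assumptions for illustration only; tf_attr itself delegates this computation to captum and a trained Cell2Net model.

```python
import numpy as np

W = np.array([1.0, -2.0, 0.5])  # hypothetical weights of the toy model


def toy_model(x: np.ndarray) -> float:
    """Hypothetical scalar model: f(x) = sum(w * x**2)."""
    return float(np.sum(W * x**2))


def toy_grad(x: np.ndarray) -> np.ndarray:
    """Analytic gradient of toy_model: df/dx = 2 * w * x."""
    return 2.0 * W * x


def integrated_gradients(x, baseline, grad_fn, n_steps=100):
    """Midpoint Riemann approximation of Integrated Gradients.

    IG_i(x) = (x_i - x'_i) * integral over a in [0, 1] of dF/dx_i(x' + a * (x - x')).
    """
    alphas = (np.arange(n_steps) + 0.5) / n_steps  # midpoints of n_steps equal intervals
    diff = x - baseline
    # Average the gradient at each interpolated point along the straight-line path.
    avg_grad = np.mean([grad_fn(baseline + a * diff) for a in alphas], axis=0)
    return diff * avg_grad


x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros_like(x)
attr = integrated_gradients(x, baseline, toy_grad)

# Completeness: the attributions should sum to f(x) - f(baseline).
print(attr, attr.sum(), toy_model(x) - toy_model(baseline))
```

Each entry of attr is the share of the change in model output credited to one input feature; in tf_attr the features being credited are the TF expression values.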
- Parameters:
  - model (Cell2Net) – The trained Cell2Net model. It must have an mdata attribute for metadata and a covariates attribute for covariate information.
  - idx (list[int] | list[str] | None, default: None) – Indices or identifiers specifying the subset of the data for which to compute attributions. If None, the entire dataset is used.
  - batch_size (default: 4) – Batch size used for data loading.
  - num_workers (default: 1) – Number of worker processes used for data loading.
  - n_steps (default: 100) – Number of steps used to approximate the Integrated Gradients path integral. Larger values give more accurate estimates but increase computation time.
  - multiply_by_inputs (default: True) – Whether to multiply the attributions by the inputs (more precisely, by the difference between inputs and baselines). This yields standard, "global" Integrated Gradients attributions that satisfy the completeness property; if False, only the path-averaged gradients are returned.
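To make n_steps and multiply_by_inputs concrete, the following self-contained sketch applies the same midpoint rule to a non-polynomial toy function (an assumption, not Cell2Net). Increasing n_steps shrinks the completeness error, and with multiply_by_inputs=False the result is just the path-averaged gradient, whereas True scales it by (input − baseline). The zero baseline mirrors captum's default but is an assumption about how tf_attr is configured.

```python
import numpy as np

W = np.array([0.8, -1.5, 2.0])  # hypothetical weights of the toy model


def f(x: np.ndarray) -> float:
    """Toy scalar model: f(x) = tanh(w . x)."""
    return float(np.tanh(W @ x))


def grad_f(x: np.ndarray) -> np.ndarray:
    """Analytic gradient: (1 - tanh(w . x)**2) * w."""
    return (1.0 - np.tanh(W @ x) ** 2) * W


def ig(x, baseline, n_steps, multiply_by_inputs=True):
    alphas = (np.arange(n_steps) + 0.5) / n_steps  # midpoint rule on [0, 1]
    diff = x - baseline
    avg_grad = np.mean([grad_f(baseline + a * diff) for a in alphas], axis=0)
    # False: return the averaged gradient only ("local" attribution).
    # True: scale by (input - baseline), as in standard Integrated Gradients.
    return diff * avg_grad if multiply_by_inputs else avg_grad


x = np.array([0.5, 0.2, 0.4])
baseline = np.zeros_like(x)
target = f(x) - f(baseline)  # what completeness says the attributions should sum to

errors = {n: abs(ig(x, baseline, n).sum() - target) for n in (2, 10, 100)}
for n, err in errors.items():
    print(f"n_steps={n:>3}: completeness error = {err:.2e}")

local = ig(x, baseline, 100, multiply_by_inputs=False)
```

The error drops roughly quadratically in n_steps for this smooth function, which is why moderate values such as the default of 100 are often adequate; the cost grows linearly because each step requires one gradient evaluation.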
- Return type:
  numpy.ndarray
- Returns:
  A NumPy array containing the attributions for TF expression. Its shape depends on the selected samples and on the number of TFs modeled.
Notes
This function uses the Integrated Gradients algorithm for attribution computation. The captum library is required to perform this calculation.
The attributions are computed for the tf_exp input (transcription factor expression) while keeping other inputs (peak sequence, accessibility, and distance) fixed.
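Attributing only tf_exp while holding the other inputs fixed amounts to integrating along a path in tf_exp alone, with the remaining inputs captured as constants. In captum, non-attributed inputs are typically passed through the additional_forward_args argument of IntegratedGradients.attribute; whether tf_attr does exactly this is not stated here. The sketch below shows the idea on a hypothetical two-input model (toy_model, V, and the zero baseline are assumptions standing in for Cell2Net and its peak inputs), using central-difference gradients so the routine works for any scalar function.

```python
import numpy as np

rng = np.random.default_rng(0)
N_TFS, N_PEAKS = 4, 3
V = rng.normal(size=(N_TFS, N_PEAKS))  # hypothetical TF-by-peak interaction weights


def toy_model(tf_exp: np.ndarray, accessibility: np.ndarray) -> float:
    """Hypothetical two-input model standing in for a Cell2Net forward pass."""
    return float(np.tanh(tf_exp @ V @ accessibility))


def numerical_grad(fn, x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Central-difference gradient of a scalar function fn at x."""
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (fn(x + step) - fn(x - step)) / (2 * eps)
    return grad


def ig_on_tf_exp(tf_exp, accessibility, n_steps=100):
    """Integrated Gradients over tf_exp only; accessibility stays fixed."""

    def fn(t):
        return toy_model(t, accessibility)  # accessibility is closed over, not interpolated

    baseline = np.zeros_like(tf_exp)
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    diff = tf_exp - baseline
    avg_grad = np.mean([numerical_grad(fn, baseline + a * diff) for a in alphas], axis=0)
    return diff * avg_grad


tf_exp = rng.uniform(size=N_TFS)
accessibility = rng.uniform(size=N_PEAKS)
attr = ig_on_tf_exp(tf_exp, accessibility)
print(attr)  # one attribution per TF; no attributions are produced for accessibility
```

Because accessibility is the same at every point on the path, the attributions sum to the change in output caused by moving tf_exp from the baseline to its observed value, conditional on that accessibility.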
The model is set to training mode (model.module.train()) during computation.
Examples
>>> from cell2net.interpretation import tf_attr
>>> model = Cell2Net(...)  # a trained Cell2Net model
>>> idx = [0, 1, 2, 3]  # indices of samples to compute attributions for
>>> attributions = tf_attr(
...     model=model,
...     idx=idx,
...     batch_size=4,
...     num_workers=2,
...     n_steps=50,
...     multiply_by_inputs=True,
... )
>>> attributions.shape
(4, num_tfs)