cell2net.prediction.data.get_dataloader

cell2net.prediction.data.get_dataloader(mdata, rna_mod='rna', atac_mod='atac', with_variants=False, idx=None, covariates=None, batch_size=128, num_workers=4, pin_memory=True, shuffle=True, drop_last=True, persistent_workers=True, **kwargs)

Creates a PyTorch DataLoader from a MuData object.

This function converts a MuData object into a PyTorch DataLoader for training or evaluation. It allows customization of data loading parameters, such as batch size, shuffling, and number of workers.
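The batching behaviour controlled by batch_size, shuffle and drop_last can be illustrated in plain Python. The sketch below is not cell2net code: it assumes these options are forwarded to torch.utils.data.DataLoader and mimics how a DataLoader over a map-style dataset groups sample positions (an optional random permutation, then consecutive chunks of batch_size). The name iter_index_batches is hypothetical.

```python
import random
from typing import Iterator, List, Optional


def iter_index_batches(
    n_samples: int,
    batch_size: int = 128,
    shuffle: bool = True,
    drop_last: bool = True,
    seed: Optional[int] = None,
) -> Iterator[List[int]]:
    """Yield one list of sample positions per batch, DataLoader-style (illustrative)."""
    order = list(range(n_samples))
    if shuffle:
        # DataLoader draws a new permutation every epoch; the seed here
        # only makes this illustration reproducible.
        random.Random(seed).shuffle(order)
    for start in range(0, n_samples, batch_size):
        batch = order[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            # Discard the trailing incomplete batch.
            break
        yield batch


print(list(iter_index_batches(10, batch_size=4, shuffle=False)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With 10 samples and batch_size=4, the default drop_last=True yields two batches and skips the last two samples for that epoch; passing drop_last=False adds a final batch of two.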

Parameters:
  • mdata (MuData) – A MuData object containing multimodal data.

  • rna_mod (str (default: 'rna')) – The name of the RNA modality in the MuData object. Default is “rna”.

  • atac_mod (str (default: 'atac')) – The name of the ATAC modality in the MuData object. Default is “atac”.

  • idx (Sequence[int] | Sequence[str] | None (default: None)) – Indices or keys to subset the MuData object. If None, the entire dataset is used. Default is None.

  • covariates (Sequence[str] | None (default: None)) – Covariates to include from the MuData object. Default is None.

  • batch_size (int (default: 128)) – The number of samples per batch. Default is 128.

  • num_workers (int (default: 4)) – The number of worker processes for data loading. Default is 4.

  • pin_memory (bool (default: True)) – Whether the DataLoader places batches in pinned (page-locked) host memory, which speeds up copies to the GPU. Default is True.

  • shuffle (bool (default: True)) – Whether to shuffle the dataset. Default is True.

  • drop_last (bool (default: True)) – Whether to drop the last incomplete batch. Default is True.

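  • persistent_workers (bool (default: True)) – Whether to keep data loading workers alive between epochs. Default is True. Set it to False when num_workers=0, since torch.utils.data.DataLoader rejects persistent workers when there are no worker processes.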

  • **kwargs – Additional keyword arguments passed to torch.utils.data.DataLoader.
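The idx parameter accepts either integer positions or string keys. One plausible way to normalize both forms into row positions is sketched below. The helper resolve_idx is hypothetical, and it assumes that string keys refer to observation (cell) names such as those in mdata.obs_names, which this page does not state explicitly.

```python
from numbers import Integral
from typing import List, Optional, Sequence, Union


def resolve_idx(
    obs_names: Sequence[str],
    idx: Optional[Union[Sequence[int], Sequence[str]]] = None,
) -> List[int]:
    """Normalize an idx argument to integer row positions (hypothetical helper)."""
    if idx is None:
        # No subset requested: keep every observation.
        return list(range(len(obs_names)))
    items = list(idx)
    if all(isinstance(i, str) for i in items):
        # Assumption: string keys are observation names.
        position = {name: pos for pos, name in enumerate(obs_names)}
        missing = [name for name in items if name not in position]
        if missing:
            raise KeyError(f"unknown observation names: {missing}")
        return [position[name] for name in items]
    if all(isinstance(i, Integral) for i in items):
        # Integral also covers NumPy integer scalars.
        return [int(i) for i in items]
    raise TypeError("idx must contain only integers or only strings")
```

For example, resolve_idx(["c1", "c2", "c3"], ["c3", "c1"]) returns [2, 0], while an integer idx is passed through unchanged.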

Return type:

DataLoader

Returns:

A PyTorch DataLoader for the specified MuData dataset.
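Because the last incomplete batch is dropped by default, an epoch can contain fewer batches than a naive count suggests. Assuming batch_size and drop_last are passed unchanged to torch.utils.data.DataLoader, len() of the returned loader follows the arithmetic below; batches_per_epoch is an illustrative name, not part of cell2net.

```python
def batches_per_epoch(n_samples: int, batch_size: int = 128, drop_last: bool = True) -> int:
    """Number of batches a DataLoader over n_samples items yields per epoch."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if drop_last:
        # Only full batches are kept.
        return n_samples // batch_size
    # A final, smaller batch is kept (ceiling division).
    return (n_samples + batch_size - 1) // batch_size


print(batches_per_epoch(1000))                   # 7
print(batches_per_epoch(1000, drop_last=False))  # 8
```

With the defaults, 1,000 cells give 7 batches of 128, leaving 104 cells unseen in each epoch (a different 104 each time when shuffle=True). A subset smaller than batch_size yields no batches at all unless drop_last=False.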

Examples

>>> import mudata as md
>>> from cell2net.prediction.data import get_dataloader
>>> mdata = md.read_h5mu("data.h5mu")
>>> dataloader = get_dataloader(mdata, batch_size=32, shuffle=True)
>>> for batch in dataloader:
...     print(batch)
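The idx parameter also makes it possible to build separate training and validation loaders from one MuData object. The sketch below draws a random split of observation positions in plain Python. train_val_split is a hypothetical helper, and the get_dataloader calls in the comments assume that integer positions are accepted for idx, as the parameter type above indicates. For evaluation, shuffle=False and drop_last=False make every cell appear exactly once per pass.

```python
import random
from typing import List, Tuple


def train_val_split(
    n_obs: int, val_fraction: float = 0.1, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Randomly partition positions 0..n_obs-1 into (train, validation) lists."""
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must lie strictly between 0 and 1")
    order = list(range(n_obs))
    random.Random(seed).shuffle(order)
    n_val = round(n_obs * val_fraction)
    # Sort each part so rows keep their original order in the MuData object.
    return sorted(order[n_val:]), sorted(order[:n_val])


train_idx, val_idx = train_val_split(1000, val_fraction=0.1)
# Hypothetical usage with the documented parameters (mdata.n_obs gives n_obs):
# train_loader = get_dataloader(mdata, idx=train_idx)
# val_loader = get_dataloader(mdata, idx=val_idx, shuffle=False, drop_last=False)
```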