cell2net.prediction.data.get_dataloader
- cell2net.prediction.data.get_dataloader(mdata, rna_mod='rna', atac_mod='atac', with_variants=False, idx=None, covariates=None, batch_size=128, num_workers=4, pin_memory=True, shuffle=True, drop_last=True, persistent_workers=True, **kwargs)
Creates a PyTorch DataLoader from a MuData object.
This function converts a MuData object into a PyTorch DataLoader for training or evaluation. It allows customization of data loading parameters, such as batch size, shuffling, and number of workers.
- Parameters:
  - mdata (MuData) – A MuData object containing multimodal data.
  - rna_mod (str, default: 'rna') – The name of the RNA modality in the MuData object.
  - atac_mod (str, default: 'atac') – The name of the ATAC modality in the MuData object.
  - idx (Sequence[int] | Sequence[str] | None, default: None) – Indices or keys to subset the MuData object. If None, the entire dataset is used.
  - covariates (Sequence[str] | None, default: None) – Covariates to include from the MuData object.
  - batch_size (int, default: 128) – The number of samples per batch.
  - num_workers (int, default: 4) – The number of worker processes for data loading.
  - pin_memory (bool, default: True) – Whether to pin memory in the DataLoader for faster GPU transfers.
  - shuffle (bool, default: True) – Whether to shuffle the dataset.
  - drop_last (bool, default: True) – Whether to drop the last incomplete batch.
  - persistent_workers (bool, default: True) – Whether to keep data-loading workers alive between epochs.
  - **kwargs – Additional keyword arguments passed to torch.utils.data.DataLoader (see the sketch below this list).
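For instance (a minimal sketch, assuming the keyword arguments are forwarded unchanged to torch.utils.data.DataLoader as described above; the seed_workers helper and the file path are hypothetical), standard DataLoader options such as worker_init_fn or prefetch_factor can be supplied through **kwargs:

>>> import torch
>>> import mudata
>>> from cell2net.prediction.data import get_dataloader
>>> def seed_workers(worker_id):
...     # hypothetical helper: give each worker process a distinct, reproducible seed
...     torch.manual_seed(42 + worker_id)
>>> mdata = mudata.read_h5mu("data.h5mu")
>>> dataloader = get_dataloader(
...     mdata,
...     worker_init_fn=seed_workers,  # standard DataLoader option, forwarded via **kwargs
...     prefetch_factor=2,            # standard DataLoader option, forwarded via **kwargs
... )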
- Return type:
DataLoader
- Returns:
A PyTorch DataLoader for the specified MuData dataset.
Examples
>>> import mudata
>>> from cell2net.prediction.data import get_dataloader
>>> mdata = mudata.read_h5mu("data.h5mu")
>>> dataloader = get_dataloader(mdata, batch_size=32, shuffle=True)
>>> for batch in dataloader:
...     print(batch)
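A second sketch, based only on the signature documented above (the covariate column "batch" and the barcode subset are hypothetical), restricts loading to a subset of cells and includes a covariate:

>>> import mudata
>>> from cell2net.prediction.data import get_dataloader
>>> mdata = mudata.read_h5mu("data.h5mu")
>>> subset = list(mdata.obs_names[:1000])  # hypothetical: keep the first 1,000 cells
>>> dataloader = get_dataloader(
...     mdata,
...     idx=subset,
...     covariates=["batch"],  # hypothetical covariate column
...     batch_size=64,
...     shuffle=False,
...     drop_last=False,
... )
>>> n_batches = len(dataloader)  # number of batches, assuming a map-style dataset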