import h5py
import numpy as np
import torch
import torch.utils.data as Data


class H5Dataset(Data.Dataset):
    """Map-style dataset of paired IR / visible samples read from one HDF5 file.

    The file must contain two groups (``'ir_patchs'`` and ``'vis_patchs'`` by
    default). The member names of the IR group define the sample order, and
    the same name is used to look up the matching member in the visible
    group. Presumably each member is one image patch — TODO confirm against
    the code that writes the file.

    Args:
        h5file_path: Path of the HDF5 file to read.
        ir_group: Name of the group holding the IR samples.
        vis_group: Name of the group holding the visible samples.
    """

    def __init__(self, h5file_path, ir_group='ir_patchs', vis_group='vis_patchs'):
        self.h5file_path = h5file_path
        self.ir_group = ir_group
        self.vis_group = vis_group
        # Only the list of member names is cached. No h5py handle is stored
        # on the instance, which keeps it safe to copy into DataLoader
        # worker processes. `with` guarantees the file is closed even if the
        # group lookup raises.
        with h5py.File(h5file_path, 'r') as h5f:
            self.keys = list(h5f[ir_group].keys())

    def __len__(self):
        """Return the number of samples (members of the IR group)."""
        return len(self.keys)

    def __getitem__(self, index):
        """Return ``(vis, ir)`` for sample ``index`` as default-dtype tensors.

        Raises:
            IndexError: if ``index`` is out of range.
            KeyError: if the key is missing from either group in the file.
        """
        key = self.keys[index]
        # Reopen the file for every item instead of holding a handle open;
        # `with` closes it even when a key is missing from one of the groups.
        with h5py.File(self.h5file_path, 'r') as h5f:
            ir = np.asarray(h5f[self.ir_group][key][()])
            vis = np.asarray(h5f[self.vis_group][key][()])
        # The legacy `torch.Tensor(arr)` constructor produced tensors of the
        # default floating dtype; keep that result explicitly.
        dtype = torch.get_default_dtype()
        return torch.as_tensor(vis, dtype=dtype), torch.as_tensor(ir, dtype=dtype)