# Source listing metadata: 22 lines, 629 B, Python.
|
import torch.utils.data as Data
|
||
|
import h5py
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
class H5Dataset(Data.Dataset):
    """Map-style dataset of paired infrared / visible patches in an HDF5 file.

    The file must contain two groups (by default ``'ir_patchs'`` and
    ``'vis_patchs'``) whose members share the same names. The member names of
    the infrared group, read once at construction, define the dataset index.

    Args:
        h5file_path: Path to the HDF5 file.
        ir_group: Name of the group holding the infrared patches.
        vis_group: Name of the group holding the visible patches.
    """

    def __init__(self, h5file_path, ir_group='ir_patchs',
                 vis_group='vis_patchs'):
        self.h5file_path = h5file_path
        self.ir_group = ir_group
        self.vis_group = vis_group
        # Only the key list is read here and the file is closed immediately,
        # so no h5py handle is stored on the instance (presumably so the
        # dataset can be pickled into DataLoader worker processes -- confirm).
        with h5py.File(h5file_path, 'r') as h5f:
            self.keys = list(h5f[ir_group].keys())

    def __len__(self):
        """Return the number of patch pairs (members of the infrared group)."""
        return len(self.keys)

    def __getitem__(self, index):
        """Load the patch pair stored under ``self.keys[index]``.

        Returns:
            A ``(VIS, IR)`` tuple -- visible patch first, infrared second --
            built with the legacy ``torch.Tensor`` constructor, i.e. cast to
            torch's default floating dtype.

        Raises:
            IndexError: if ``index`` is out of range for ``self.keys``.
            h5py raises if the visible group has no member named like the
            infrared one.
        """
        # Resolve the key before touching the file so a bad index never
        # opens a handle.
        key = self.keys[index]
        # The file is reopened per item (no persistent handle), and ``with``
        # guarantees it is closed even if a read raises.
        with h5py.File(self.h5file_path, 'r') as h5f:
            ir = np.array(h5f[self.ir_group][key])
            vis = np.array(h5f[self.vis_group][key])
        return torch.Tensor(vis), torch.Tensor(ir)