-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathDataSet.py
More file actions
38 lines (25 loc) · 1.38 KB
/
Copy pathDataSet.py
File metadata and controls
38 lines (25 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
import glob
import os
import Opers
class imageDataset(Dataset):
    """Paired dataset of sharp ("normal") and blurry images.

    Images are paired by file name: every file matching ``pattern`` in the
    normal directory must have a file with the same name in the blurry
    directory (and vice versa). Each item is a dict
    ``{'normal': ..., 'blurry': ...}`` holding the two images as returned by
    ``io.imread``, optionally passed through ``transform``.
    """

    def __init__(self, normal_dir, blurry_dir, transform=None, pattern='*.jpg'):
        """Index the image pairs found in two directories.

        Args:
            normal_dir: directory of sharp images. A relative path is resolved
                against the directory containing this module, not the CWD.
            blurry_dir: directory of blurry images, resolved the same way.
            transform: optional callable applied to each image of a pair in
                ``__getitem__``.
            pattern: glob pattern selecting image files in both directories.

        Raises:
            ValueError: if the two directories do not contain exactly the same
                set of file names matching ``pattern``.
        """
        working_dir = os.path.dirname(os.path.realpath(__file__))
        self.path_normal_dir = os.path.join(working_dir, normal_dir)
        self.path_blurry_dir = os.path.join(working_dir, blurry_dir)
        # Bare file names, sorted so the index -> pair mapping does not depend
        # on the filesystem's directory-listing order.
        self.blurry_images = self._list_images(self.path_blurry_dir, pattern)
        self.normal_images = self._list_images(self.path_normal_dir, pattern)
        self.transform = transform
        # __getitem__ opens the blurry file using the normal file's name, so the
        # two directories must hold the same names, not merely the same count.
        if set(self.blurry_images) != set(self.normal_images):
            raise ValueError('mismatch between the normal images and the blurry ones')

    @staticmethod
    def _list_images(directory, pattern):
        """Return the sorted base names of files in *directory* matching *pattern*."""
        # glob.escape keeps metacharacters such as '[' in the directory literal.
        matches = glob.glob(os.path.join(glob.escape(directory), pattern))
        return sorted(os.path.basename(path) for path in matches)

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.normal_images)

    def __getitem__(self, idx):
        """Load the image pair at position *idx*.

        Returns:
            dict with keys ``'normal'`` and ``'blurry'``. Without a transform
            the values are whatever ``io.imread`` returns; with one, each image
            is passed through ``transform`` separately.
        """
        name = self.normal_images[idx]
        # The shared file name is what pairs the two images.
        normal_image = io.imread(os.path.join(self.path_normal_dir, name))
        blurry_image = io.imread(os.path.join(self.path_blurry_dir, name))
        if self.transform is not None:
            # NOTE(review): transform is called once per image, so a randomized
            # transform (random crop/flip) would desynchronise the pair —
            # confirm the transforms used here are deterministic.
            normal_image = self.transform(normal_image)
            blurry_image = self.transform(blurry_image)
        return {'normal': normal_image, 'blurry': blurry_image}