修改代码结构,提高可读性和可维护性;调整训练输出频率。
改进 self.enhancement_module 为 self.enhancement_module = WTConv2d(32, 32)
This commit is contained in:
parent
afd55abe9e
commit
96ce7d5fda
6
net.py
6
net.py
@ -248,11 +248,7 @@ class DetailFeatureExtraction(nn.Module):
|
|||||||
super(DetailFeatureExtraction, self).__init__()
|
super(DetailFeatureExtraction, self).__init__()
|
||||||
INNmodules = [DetailNode() for _ in range(num_layers)]
|
INNmodules = [DetailNode() for _ in range(num_layers)]
|
||||||
self.net = nn.Sequential(*INNmodules)
|
self.net = nn.Sequential(*INNmodules)
|
||||||
self.enhancement_module = nn.Sequential(
|
self.enhancement_module = WTConv2d(32, 32)
|
||||||
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
|
|
||||||
nn.ReLU(inplace=True),
|
|
||||||
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x): # 1 64 128 128
|
def forward(self, x): # 1 64 128 128
|
||||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
||||||
|
2
train.py
2
train.py
@ -222,7 +222,7 @@ for epoch in range(num_epochs):
|
|||||||
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
|
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
|
||||||
epoch_time = time.time() - prev_time
|
epoch_time = time.time() - prev_time
|
||||||
prev_time = time.time()
|
prev_time = time.time()
|
||||||
if step % 100 == 0:
|
if i % 100 == 0:
|
||||||
sys.stdout.write(
|
sys.stdout.write(
|
||||||
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
|
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
|
||||||
% (
|
% (
|
||||||
|
Loading…
Reference in New Issue
Block a user