From 96ce7d5fdab41c8b68f65d7b74584a9e08c4cf9a Mon Sep 17 00:00:00 2001 From: whaifree Date: Tue, 8 Oct 2024 16:50:11 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BB=A3=E7=A0=81=E7=BB=93?= =?UTF-8?q?=E6=9E=84=EF=BC=8C=E6=8F=90=E9=AB=98=E5=8F=AF=E8=AF=BB=E6=80=A7?= =?UTF-8?q?=E5=92=8C=E5=8F=AF=E7=BB=B4=E6=8A=A4=E6=80=A7=EF=BC=9B=E8=B0=83?= =?UTF-8?q?=E6=95=B4=E8=AE=AD=E7=BB=83=E8=BE=93=E5=87=BA=E9=A2=91=E7=8E=87?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 改进 self.enhancement_module 为 self.enhancement_module = WTConv2d(32, 32) --- net.py | 6 +----- train.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/net.py b/net.py index d9fb694..f40fb99 100644 --- a/net.py +++ b/net.py @@ -248,11 +248,7 @@ class DetailFeatureExtraction(nn.Module): super(DetailFeatureExtraction, self).__init__() INNmodules = [DetailNode() for _ in range(num_layers)] self.net = nn.Sequential(*INNmodules) - self.enhancement_module = nn.Sequential( - nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True), - nn.ReLU(inplace=True), - nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True), - ) + self.enhancement_module = WTConv2d(32, 32) def forward(self, x): # 1 64 128 128 z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 diff --git a/train.py b/train.py index f5a6abd..b963b00 100644 --- a/train.py +++ b/train.py @@ -222,7 +222,7 @@ for epoch in range(num_epochs): time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) epoch_time = time.time() - prev_time prev_time = time.time() - if step % 100 == 0: + if i % 100 == 0: sys.stdout.write( "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s" % (