修改代码实现,提高代码可读性和可维护性
This commit is contained in:
parent
5e3fc11c37
commit
c9e054e236
33
net.py
33
net.py
@ -31,7 +31,7 @@ class DropPath(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
# 改点,使用Pooling替换AttentionBase
|
||||
class Pooling(nn.Module):
|
||||
def __init__(self, kernel_size=3):
|
||||
super().__init__()
|
||||
@ -44,8 +44,8 @@ class Pooling(nn.Module):
|
||||
|
||||
class PoolMlp(nn.Module):
|
||||
"""
|
||||
Implementation of MLP with 1*1 convolutions.
|
||||
Input: tensor with shape [B, C, H, W]
|
||||
实现基于1x1卷积的MLP模块。
|
||||
输入:形状为[B, C, H, W]的张量。
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -55,6 +55,17 @@ class PoolMlp(nn.Module):
|
||||
act_layer=nn.GELU,
|
||||
bias=False,
|
||||
drop=0.):
|
||||
"""
|
||||
初始化PoolMlp模块。
|
||||
|
||||
参数:
|
||||
in_features (int): 输入特征的数量。
|
||||
hidden_features (int, 可选): 隐藏层特征的数量。默认为None,设置为与in_features相同。
|
||||
out_features (int, 可选): 输出特征的数量。默认为None,设置为与in_features相同。
|
||||
act_layer (nn.Module, 可选): 使用的激活层。默认为nn.GELU。
|
||||
bias (bool, 可选): 是否在卷积层中包含偏置项。默认为False。
|
||||
drop (float, 可选): Dropout比率。默认为0。
|
||||
"""
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
@ -64,6 +75,15 @@ class PoolMlp(nn.Module):
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
通过PoolMlp模块的前向传播。
|
||||
|
||||
参数:
|
||||
x (torch.Tensor): 形状为[B, C, H, W]的输入张量。
|
||||
|
||||
返回:
|
||||
torch.Tensor: 形状为[B, C, H, W]的输出张量。
|
||||
"""
|
||||
x = self.fc1(x) # (B, C, H, W) --> (B, C, H, W)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
@ -71,6 +91,7 @@ class PoolMlp(nn.Module):
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class BaseFeatureExtraction(nn.Module):
|
||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||
act_layer=nn.GELU,
|
||||
@ -108,7 +129,7 @@ class BaseFeatureExtraction(nn.Module):
|
||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||
* self.poolmlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
|
||||
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
@ -131,11 +152,9 @@ class InvertedResidualBlock(nn.Module):
|
||||
nn.Conv2d(hidden_dim, oup, 1, bias=False),
|
||||
# nn.BatchNorm2d(oup),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.bottleneckBlock(x)
|
||||
|
||||
|
||||
class DetailNode(nn.Module):
|
||||
def __init__(self):
|
||||
super(DetailNode, self).__init__()
|
||||
@ -163,14 +182,12 @@ class DetailFeatureExtraction(nn.Module):
|
||||
super(DetailFeatureExtraction, self).__init__()
|
||||
INNmodules = [DetailNode() for _ in range(num_layers)]
|
||||
self.net = nn.Sequential(*INNmodules)
|
||||
|
||||
def forward(self, x):
|
||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
|
||||
for layer in self.net:
|
||||
z1, z2 = layer(z1, z2)
|
||||
return torch.cat((z1, z2), dim=1)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
# =============================================================================
|
||||
|
5
requirement.txt
Normal file
5
requirement.txt
Normal file
@ -0,0 +1,5 @@
|
||||
|
||||
scipy==1.9.3
|
||||
scikit-image==0.19.2
|
||||
scikit-learn==1.1.3
|
||||
tqdm==4.62.0
|
@ -13,13 +13,13 @@ logging.basicConfig(level=logging.CRITICAL)
|
||||
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
ckpt_path= r"models/PFCFuse.pth"
|
||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-18-13.pth"
|
||||
|
||||
for dataset_name in ["MSRS","TNO","RoadScene"]:
|
||||
for dataset_name in ["TNO"]:
|
||||
print("\n"*2+"="*80)
|
||||
model_name="PFCFuse "
|
||||
print("The test result of "+dataset_name+' :')
|
||||
test_folder=os.path.join('test_img',dataset_name)
|
||||
test_folder=os.path.join('/home/star/whaiDir/CDDFuse/test_img/',dataset_name)
|
||||
test_out_folder=os.path.join('test_result',dataset_name)
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
@ -39,6 +39,7 @@ for dataset_name in ["MSRS","TNO","RoadScene"]:
|
||||
|
||||
with torch.no_grad():
|
||||
for img_name in os.listdir(os.path.join(test_folder,"ir")):
|
||||
print(img_name)
|
||||
|
||||
data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
|
||||
data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0
|
||||
@ -60,7 +61,7 @@ for dataset_name in ["MSRS","TNO","RoadScene"]:
|
||||
rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)
|
||||
img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder)
|
||||
|
||||
eval_folder=test_out_folder
|
||||
eval_folder=test_out_folder
|
||||
ori_img_folder=test_folder
|
||||
|
||||
metric_result = np.zeros((8))
|
||||
|
5
train.py
5
train.py
@ -87,7 +87,7 @@ Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
|
||||
HuberLoss = nn.HuberLoss()
|
||||
|
||||
# data loader
|
||||
trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"),
|
||||
trainloader = DataLoader(H5Dataset(r"/home/star/whaiDir/CDDFuse/data/MSRS_train_imgsize_128_stride_200.h5"),
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=0)
|
||||
@ -201,13 +201,14 @@ for epoch in range(num_epochs):
|
||||
epoch_time = time.time() - prev_time
|
||||
prev_time = time.time()
|
||||
sys.stdout.write(
|
||||
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f]"
|
||||
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f] ETA: %.10s"
|
||||
% (
|
||||
epoch,
|
||||
num_epochs,
|
||||
i,
|
||||
len(loader['train']),
|
||||
loss.item(),
|
||||
time_left,
|
||||
)
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user