增加了INN部分的残差连接模块,修改了训练和测试代码以提高代码的可读性和可维护性。- 在train.py中添加了打印所有参数的代码,以方便检查和记录
This commit is contained in:
parent
faacea007c
commit
7068b627c4
6
net.py
6
net.py
@ -248,7 +248,11 @@ class DetailFeatureExtraction(nn.Module):
|
||||
super(DetailFeatureExtraction, self).__init__()
|
||||
INNmodules = [DetailNode() for _ in range(num_layers)]
|
||||
self.net = nn.Sequential(*INNmodules)
|
||||
|
||||
self.enhancement_module = nn.Sequential(
|
||||
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x): # 1 64 128 128
|
||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
||||
|
@ -13,7 +13,7 @@ logging.basicConfig(level=logging.CRITICAL)
|
||||
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/PFCFuse_IVF.pth"
|
||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-20-46.pth"
|
||||
|
||||
for dataset_name in ["TNO"]:
|
||||
print("\n"*2+"="*80)
|
||||
|
24
train.py
24
train.py
@ -34,7 +34,7 @@ criteria_fusion = Fusionloss()
|
||||
model_str = 'PFCFuse'
|
||||
|
||||
# . Set the hyper-parameters for training
|
||||
num_epochs = 120 # total epoch
|
||||
num_epochs = 60 # total epoch
|
||||
epoch_gap = 40 # epoches of Phase I
|
||||
|
||||
lr = 1e-4
|
||||
@ -57,6 +57,28 @@ clip_grad_norm_value = 0.01
|
||||
optim_step = 20
|
||||
optim_gamma = 0.5
|
||||
|
||||
# 打印所有参数
|
||||
print(f"Model: {model_str}")
|
||||
print(f"Number of epochs: {num_epochs}")
|
||||
print(f"Epoch gap: {epoch_gap}")
|
||||
print(f"Learning rate: {lr}")
|
||||
print(f"Weight decay: {weight_decay}")
|
||||
print(f"Batch size: {batch_size}")
|
||||
print(f"GPU number: {GPU_number}")
|
||||
|
||||
print(f"Coefficient of MSE loss VF: {coeff_mse_loss_VF}")
|
||||
print(f"Coefficient of MSE loss IF: {coeff_mse_loss_IF}")
|
||||
print(f"Coefficient of RMI loss VF: {coeff_rmi_loss_VF}")
|
||||
print(f"Coefficient of RMI loss IF: {coeff_rmi_loss_IF}")
|
||||
print(f"Coefficient of Cosine loss VF: {coeff_cos_loss_VF}")
|
||||
print(f"Coefficient of Cosine loss IF: {coeff_cos_loss_IF}")
|
||||
print(f"Coefficient of Decomposition loss: {coeff_decomp}")
|
||||
print(f"Coefficient of Total Variation loss: {coeff_tv}")
|
||||
|
||||
print(f"Clip gradient norm value: {clip_grad_norm_value}")
|
||||
print(f"Optimization step: {optim_step}")
|
||||
print(f"Optimization gamma: {optim_gamma}")
|
||||
|
||||
|
||||
# Model
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
Loading…
Reference in New Issue
Block a user