Add files via upload

Author: HXY13, 2024-06-03 19:36:29 +08:00 (committed by GitHub)
parent e6852193de
commit da5da74611
14 changed files with 1322 additions and 0 deletions

dataprocessing.py (new file, 93 lines)

@@ -0,0 +1,93 @@
import os
import h5py
import numpy as np
from tqdm import tqdm
from skimage.io import imread
def get_img_file(file_name):
imagelist = []
for parent, dirnames, filenames in os.walk(file_name):
for filename in filenames:
if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
imagelist.append(os.path.join(parent, filename))
return imagelist
def rgb2y(img):
y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
return y
def Im2Patch(img, win, stride=1):
k = 0
endc = img.shape[0]
endw = img.shape[1]
endh = img.shape[2]
patch = img[:, 0:endw-win+1:stride, 0:endh-win+1:stride]
TotalPatNum = patch.shape[1] * patch.shape[2]
Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
for i in range(win):
for j in range(win):
patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
k = k + 1
return Y.reshape([endc, win, win, TotalPatNum])
def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
upper_percentile=90):
"""Determine if an image is low contrast."""
limits = np.percentile(image, [lower_percentile, upper_percentile])
ratio = (limits[1] - limits[0]) / limits[1]
return ratio < fraction_threshold
data_name = "MSRS_train"
img_size = 128  # patch size
stride = 200  # patch stride
IR_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/ir"))
VIS_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/vi"))
assert len(IR_files) == len(VIS_files)
h5f = h5py.File(os.path.join('data',
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
'w')
h5_ir = h5f.create_group('ir_patchs')
h5_vis = h5f.create_group('vis_patchs')
train_num=0
for i in tqdm(range(len(IR_files))):
I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255. # [3, H, W] Uint8->float32
I_VIS = rgb2y(I_VIS) # [1, H, W] Float32
I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255. # [1, H, W] Float32
# crop
I_IR_Patch_Group = Im2Patch(I_IR,img_size,stride)
I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride) # [1, img_size, img_size, N]
for ii in range(I_IR_Patch_Group.shape[-1]):
bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii])
bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii])
# Determine if the contrast is low
if not (bad_IR or bad_VIS):
avl_IR= I_IR_Patch_Group[0,:,:,ii] # available IR
avl_VIS= I_VIS_Patch_Group[0,:,:,ii]
avl_IR=avl_IR[None,...]
avl_VIS=avl_VIS[None,...]
h5_ir.create_dataset(str(train_num), data=avl_IR,
dtype=avl_IR.dtype, shape=avl_IR.shape)
h5_vis.create_dataset(str(train_num), data=avl_VIS,
dtype=avl_VIS.dtype, shape=avl_VIS.shape)
train_num += 1
h5f.close()
with h5py.File(os.path.join('data',
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
for key in f.keys():
print(f[key], key, f[key].name)
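
A quick way to sanity-check the patch extraction above is to run Im2Patch on a dummy array. The sketch below is not part of the original file; it assumes it is executed in this module (note the dataset build above also runs on import, so treat it as illustration only):

# Hedged sketch: check Im2Patch output shape on a synthetic [C, H, W] array.
import numpy as np

dummy = np.random.rand(1, 480, 640).astype(np.float32)
win, stride = 128, 200
patches = Im2Patch(dummy, win, stride)          # -> [C, win, win, N]
n_h = len(range(0, 480 - win + 1, stride))      # patch positions along H
n_w = len(range(0, 640 - win + 1, stride))      # patch positions along W
assert patches.shape == (1, win, win, n_h * n_w)
print(patches.shape)                            # (1, 128, 128, 6)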

net.py (new file, 427 lines)

@@ -0,0 +1,427 @@
# poolformer
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Pooling(nn.Module):
def __init__(self, kernel_size=3):
super().__init__()
self.pool = nn.AvgPool2d(
kernel_size, stride=1, padding=kernel_size // 2)
def forward(self, x):
return self.pool(x) - x
class PoolMlp(nn.Module):
"""
Implementation of MLP with 1*1 convolutions.
Input: tensor with shape [B, C, H, W]
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=False,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=bias)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=bias)
self.drop = nn.Dropout(drop)
# self.apply(self._init_weights)
# def _init_weights(self, m):
# if isinstance(m, nn.Conv2D):
# trunc_normal_(m.weight)
# if m.bias is not None:
# zeros_(m.bias)
def forward(self, x):
x = self.fc1(x) # (B, C, H, W) --> (B, C, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x) # (B, C, H, W) --> (B, C, H, W)
x = self.drop(x)
return x
class BaseFeatureExtraction(nn.Module):
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU,
# norm_layer=nn.LayerNorm,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size)  # ViTs use MSA and MLP-Mixers use MLPs as the token mixer; here pooling is used instead
self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio)
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
# The following two techniques are useful to train deep PoolFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
def forward(self, x):
if self.use_layer_scale:
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
* self.token_mixer(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.poolmlp(self.norm2(x)))
else:
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
return x
class InvertedResidualBlock(nn.Module):
def __init__(self, inp, oup, expand_ratio):
super(InvertedResidualBlock, self).__init__()
hidden_dim = int(inp * expand_ratio)
self.bottleneckBlock = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# dw
nn.ReflectionPad2d(1),
nn.Conv2d(hidden_dim, hidden_dim, 3, groups=hidden_dim, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, bias=False),
# nn.BatchNorm2d(oup),
)
def forward(self, x):
return self.bottleneckBlock(x)
class DetailNode(nn.Module):
def __init__(self):
super(DetailNode, self).__init__()
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
stride=1, padding=0, bias=True)
def separateFeature(self, x):
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
return z1, z2
def forward(self, z1, z2):
z1, z2 = self.separateFeature(
self.shffleconv(torch.cat((z1, z2), dim=1)))
z2 = z2 + self.theta_phi(z1)
z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
return z1, z2
class DetailFeatureExtraction(nn.Module):
def __init__(self, num_layers=3):
super(DetailFeatureExtraction, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules)
def forward(self, x):
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
for layer in self.net:
z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1)
# =============================================================================
# =============================================================================
import numbers
##########################################################################
## Layer Norm
def to_3d(x):
return rearrange(x, 'b c h w -> b (h w) c')
def to_4d(x, h, w):
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(BiasFree_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
sigma = x.var(-1, keepdim=True, unbiased=False)
return x / torch.sqrt(sigma + 1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(WithBias_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
mu = x.mean(-1, keepdim=True)
sigma = x.var(-1, keepdim=True, unbiased=False)
return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
def __init__(self, dim, LayerNorm_type):
super(LayerNorm, self).__init__()
if LayerNorm_type == 'BiasFree':
self.body = BiasFree_LayerNorm(dim)
else:
self.body = WithBias_LayerNorm(dim)
def forward(self, x):
h, w = x.shape[-2:]
return to_4d(self.body(to_3d(x)), h, w)
##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FeedForward, self).__init__()
hidden_features = int(dim * ffn_expansion_factor)
self.project_in = nn.Conv2d(
dim, hidden_features * 2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
stride=1, padding=1, groups=hidden_features * 2, bias=bias)
self.project_out = nn.Conv2d(
hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
def __init__(self, dim, num_heads, bias):
super(Attention, self).__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(
dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w',
head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
##########################################################################
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
super(TransformerBlock, self).__init__()
self.norm1 = LayerNorm(dim, LayerNorm_type)
self.attn = Attention(dim, num_heads, bias)
self.norm2 = LayerNorm(dim, LayerNorm_type)
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
def __init__(self, in_c=3, embed_dim=48, bias=False):
super(OverlapPatchEmbed, self).__init__()
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
stride=1, padding=1, bias=bias)
def forward(self, x):
x = self.proj(x)
return x
class Restormer_Encoder(nn.Module):
def __init__(self,
inp_channels=1,
out_channels=1,
dim=64,
num_blocks=[4, 4],
heads=[8, 8, 8],
ffn_expansion_factor=2,
bias=False,
LayerNorm_type='WithBias',
):
super(Restormer_Encoder, self).__init__()
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
self.encoder_level1 = nn.Sequential(
*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
self.baseFeature = BaseFeatureExtraction(dim=dim)
self.detailFeature = DetailFeatureExtraction()
def forward(self, inp_img):
inp_enc_level1 = self.patch_embed(inp_img)
out_enc_level1 = self.encoder_level1(inp_enc_level1)
base_feature = self.baseFeature(out_enc_level1)
detail_feature = self.detailFeature(out_enc_level1)
return base_feature, detail_feature, out_enc_level1
class Restormer_Decoder(nn.Module):
def __init__(self,
inp_channels=1,
out_channels=1,
dim=64,
num_blocks=[4, 4],
heads=[8, 8, 8],
ffn_expansion_factor=2,
bias=False,
LayerNorm_type='WithBias',
):
super(Restormer_Decoder, self).__init__()
self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
self.encoder_level2 = nn.Sequential(
*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
self.output = nn.Sequential(
nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
stride=1, padding=1, bias=bias),
nn.LeakyReLU(),
nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
stride=1, padding=1, bias=bias), )
self.sigmoid = nn.Sigmoid()
def forward(self, inp_img, base_feature, detail_feature):
out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
out_enc_level0 = self.reduce_channel(out_enc_level0)
out_enc_level1 = self.encoder_level2(out_enc_level0)
if inp_img is not None:
out_enc_level1 = self.output(out_enc_level1) + inp_img
else:
out_enc_level1 = self.output(out_enc_level1)
return self.sigmoid(out_enc_level1), out_enc_level0
if __name__ == '__main__':
height = 128
width = 128
window_size = 8
modelE = Restormer_Encoder().cuda()
modelD = Restormer_Decoder().cuda()
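
For reference, a minimal shape check that could follow the two lines above is sketched below; it is not part of the original file and assumes a CUDA device and a 1-channel input:

# Hedged sketch: push a dummy 1-channel image through the encoder/decoder pair.
with torch.no_grad():
    x = torch.randn(1, 1, height, width).cuda()
    base_feat, detail_feat, shallow_feat = modelE(x)   # each: [1, 64, 128, 128]
    fused, _ = modelD(x, base_feat, detail_feat)       # [1, 1, 128, 128] after sigmoid
    print(base_feat.shape, detail_feat.shape, fused.shape)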

test_IVF.py (new file, 95 lines)

@@ -0,0 +1,95 @@
import cv2
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
from utils.Evaluator import Evaluator
import torch
import torch.nn as nn
from utils.img_read_save import img_save,image_read_cv2
import warnings
import logging
# added: suppress warnings and logging output
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path= r"models/PFCFuse.pth"
for dataset_name in ["MSRS","TNO","RoadScene"]:
print("\n"*2+"="*80)
model_name="PFCFuse "
print("The test result of "+dataset_name+' :')
test_folder=os.path.join('test_img',dataset_name)
test_out_folder=os.path.join('test_result',dataset_name)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
Encoder.load_state_dict(torch.load(ckpt_path)['DIDF_Encoder'],strict=False)
Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
Encoder.eval()
Decoder.eval()
BaseFuseLayer.eval()
DetailFuseLayer.eval()
with torch.no_grad():
for img_name in os.listdir(os.path.join(test_folder,"ir")):
data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
# modified: take the Y channel of the visible image
data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0
# ycrcb, uint8
data_VIS_BGR = cv2.imread(os.path.join(test_folder, "vi", img_name))
_, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb))
# modified
data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS)
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
feature_V_B, feature_V_D, feature_V = Encoder(data_VIS)
feature_I_B, feature_I_D, feature_I = Encoder(data_IR)
feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse))
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
# modified
# float32 to uint8
fi = fi.astype(np.uint8)
ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb))
rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)
img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder)
# modified
eval_folder=test_out_folder
ori_img_folder=test_folder
metric_result = np.zeros((8))
for img_name in os.listdir(os.path.join(ori_img_folder,"ir")):
ir = image_read_cv2(os.path.join(ori_img_folder,"ir", img_name), 'GRAY')
vi = image_read_cv2(os.path.join(ori_img_folder,"vi", img_name), 'GRAY')
fi = image_read_cv2(os.path.join(eval_folder, img_name.split('.')[0]+".png"), 'GRAY')
metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi)
, Evaluator.SF(fi), Evaluator.MI(fi, ir, vi)
, Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi)
, Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])
metric_result /= len(os.listdir(eval_folder))
print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
print(model_name+'\t'+str(np.round(metric_result[0], 2))+'\t'
+str(np.round(metric_result[1], 2))+'\t'
+str(np.round(metric_result[2], 2))+'\t'
+str(np.round(metric_result[3], 2))+'\t'
+str(np.round(metric_result[4], 2))+'\t'
+str(np.round(metric_result[5], 2))+'\t'
+str(np.round(metric_result[6], 2))+'\t'
+str(np.round(metric_result[7], 2))
)
print("="*80)

train.py (new file, 239 lines)

@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-
'''
------------------------------------------------------------------------------
Import packages
------------------------------------------------------------------------------
'''
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
from utils.dataset import H5Dataset
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.loss import Fusionloss, cc, relative_diff_loss
import kornia
print(torch.__version__)
print(torch.cuda.is_available())
'''
------------------------------------------------------------------------------
Configure our network
------------------------------------------------------------------------------
'''
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
criteria_fusion = Fusionloss()
model_str = 'PFCFuse'
# Set the hyper-parameters for training
num_epochs = 120 # total epochs
epoch_gap = 40 # epochs of Phase I
lr = 1e-4
weight_decay = 0
batch_size = 1
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']
# Coefficients of the loss function
coeff_mse_loss_VF = 1. # alpha1
coeff_mse_loss_IF = 1.
coeff_rmi_loss_VF = 1.
coeff_rmi_loss_IF = 1.
coeff_cos_loss_VF = 1.
coeff_cos_loss_IF = 1.
coeff_decomp = 2. # alpha2 and alpha4
coeff_tv = 5.
clip_grad_norm_value = 0.01
optim_step = 20
optim_gamma = 0.5
# Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
# optimizer, scheduler and loss function
optimizer1 = torch.optim.Adam(
DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.Adam(
DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer3 = torch.optim.Adam(
BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
optimizer4 = torch.optim.Adam(
DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)
MSELoss = nn.MSELoss()
L1Loss = nn.L1Loss()
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')
HuberLoss = nn.HuberLoss()
# data loader
trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"),
batch_size=batch_size,
shuffle=True,
num_workers=0)
loader = {'train': trainloader, }
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")
'''
------------------------------------------------------------------------------
Train
------------------------------------------------------------------------------
'''
step = 0
torch.backends.cudnn.benchmark = True
prev_time = time.time()
start = time.time()
for epoch in range(num_epochs):
''' train '''
for i, (data_VIS, data_IR) in enumerate(loader['train']):
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
DIDF_Encoder.train()
DIDF_Decoder.train()
BaseFuseLayer.train()
DetailFuseLayer.train()
DIDF_Encoder.zero_grad()
DIDF_Decoder.zero_grad()
BaseFuseLayer.zero_grad()
DetailFuseLayer.zero_grad()
optimizer1.zero_grad()
optimizer2.zero_grad()
optimizer3.zero_grad()
optimizer4.zero_grad()
if epoch < epoch_gap: #Phase I
feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR)
data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)
cc_loss_B = cc(feature_V_B, feature_I_B)
cc_loss_D = cc(feature_V_D, feature_I_D)
mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat)
mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat)
# print("mse_loss_V", mse_loss_V)
# print("mse_loss_I", mse_loss_I)
Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
kornia.filters.SpatialGradient()(data_VIS_hat))
# print("Gradient_loss", Gradient_loss)
loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
# print("loss_decomp", loss_decomp)
loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat)
loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat)
loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss + \
coeff_rmi_loss_IF * loss_rmi_i + coeff_rmi_loss_VF * loss_rmi_v
loss.backward()
nn.utils.clip_grad_norm_(
DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
nn.utils.clip_grad_norm_(
DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
optimizer1.step()
optimizer2.step()
else: #Phase II
feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR)
feature_F_B = BaseFuseLayer(feature_I_B+feature_V_B)
feature_F_D = DetailFuseLayer(feature_I_D+feature_V_D)
data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)
mse_loss_V = 5 * Loss_ssim(data_VIS, data_Fuse) + HuberLoss(data_VIS, data_Fuse)
mse_loss_I = 5 * Loss_ssim(data_IR, data_Fuse) + HuberLoss(data_IR, data_Fuse)
cc_loss_B = cc(feature_V_B, feature_I_B)
cc_loss_D = cc(feature_V_D, feature_I_D)
loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)
fusionloss, _,_ = criteria_fusion(data_VIS, data_IR, data_Fuse)
loss = fusionloss + coeff_decomp * loss_decomp
loss.backward()
nn.utils.clip_grad_norm_(
DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
nn.utils.clip_grad_norm_(
DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
nn.utils.clip_grad_norm_(
BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
nn.utils.clip_grad_norm_(
DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
optimizer1.step()
optimizer2.step()
optimizer3.step()
optimizer4.step()
# Determine approximate time left
batches_done = epoch * len(loader['train']) + i
batches_left = num_epochs * len(loader['train']) - batches_done
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
epoch_time = time.time() - prev_time
prev_time = time.time()
sys.stdout.write(
"\r[Epoch %d/%d] [Batch %d/%d] [loss: %f]"
% (
epoch,
num_epochs,
i,
len(loader['train']),
loss.item(),
)
)
# adjust the learning rate
scheduler1.step()
scheduler2.step()
if not epoch < epoch_gap:
scheduler3.step()
scheduler4.step()
if optimizer1.param_groups[0]['lr'] <= 1e-6:
optimizer1.param_groups[0]['lr'] = 1e-6
if optimizer2.param_groups[0]['lr'] <= 1e-6:
optimizer2.param_groups[0]['lr'] = 1e-6
if optimizer3.param_groups[0]['lr'] <= 1e-6:
optimizer3.param_groups[0]['lr'] = 1e-6
if optimizer4.param_groups[0]['lr'] <= 1e-6:
optimizer4.param_groups[0]['lr'] = 1e-6
if True:
checkpoint = {
'DIDF_Encoder': DIDF_Encoder.state_dict(),
'DIDF_Decoder': DIDF_Decoder.state_dict(),
'BaseFuseLayer': BaseFuseLayer.state_dict(),
'DetailFuseLayer': DetailFuseLayer.state_dict(),
}
torch.save(checkpoint, os.path.join("models", "PFCFusion" + timestamp + '.pth'))
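
test_IVF.py loads a checkpoint with the same four keys; a minimal reload sketch (not part of the original file, path shown only for illustration) might look like:

# Hedged sketch: restore the four sub-networks from the checkpoint saved above.
ckpt = torch.load(os.path.join("models", "PFCFusion" + timestamp + '.pth'))
DIDF_Encoder.load_state_dict(ckpt['DIDF_Encoder'])
DIDF_Decoder.load_state_dict(ckpt['DIDF_Decoder'])
BaseFuseLayer.load_state_dict(ckpt['BaseFuseLayer'])
DetailFuseLayer.load_state_dict(ckpt['DetailFuseLayer'])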

utils/Evaluator.py (new file, 305 lines)

@@ -0,0 +1,305 @@
import numpy as np
import cv2
import sklearn.metrics as skm
from scipy.signal import convolve2d
import math
from skimage.metrics import structural_similarity as ssim
def image_read_cv2(path, mode='RGB'):
img_BGR = cv2.imread(path).astype('float32')
assert mode == 'RGB' or mode == 'GRAY' or mode == 'YCrCb', 'mode error'
if mode == 'RGB':
img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2RGB)
elif mode == 'GRAY':
img = np.round(cv2.cvtColor(img_BGR, cv2.COLOR_BGR2GRAY))
elif mode == 'YCrCb':
img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2YCrCb)
return img
class Evaluator():
@classmethod
def input_check(cls, imgF, imgA=None, imgB=None):
if imgA is None:
assert type(imgF) == np.ndarray, 'type error'
assert len(imgF.shape) == 2, 'dimension error'
else:
assert type(imgF) == type(imgA) == type(imgB) == np.ndarray, 'type error'
assert imgF.shape == imgA.shape == imgB.shape, 'shape error'
assert len(imgF.shape) == 2, 'dimension error'
@classmethod
def EN(cls, img): # entropy
cls.input_check(img)
a = np.uint8(np.round(img)).flatten()
h = np.bincount(a) / a.shape[0]
return -sum(h * np.log2(h + (h == 0)))
@classmethod
def SD(cls, img):
cls.input_check(img)
return np.std(img)
@classmethod
def SF(cls, img):
cls.input_check(img)
return np.sqrt(np.mean((img[:, 1:] - img[:, :-1]) ** 2) + np.mean((img[1:, :] - img[:-1, :]) ** 2))
@classmethod
def AG(cls, img): # Average gradient
cls.input_check(img)
Gx, Gy = np.zeros_like(img), np.zeros_like(img)
Gx[:, 0] = img[:, 1] - img[:, 0]
Gx[:, -1] = img[:, -1] - img[:, -2]
Gx[:, 1:-1] = (img[:, 2:] - img[:, :-2]) / 2
Gy[0, :] = img[1, :] - img[0, :]
Gy[-1, :] = img[-1, :] - img[-2, :]
Gy[1:-1, :] = (img[2:, :] - img[:-2, :]) / 2
return np.mean(np.sqrt((Gx ** 2 + Gy ** 2) / 2))
@classmethod
def MI(cls, image_F, image_A, image_B):
cls.input_check(image_F, image_A, image_B)
return skm.mutual_info_score(image_F.flatten(), image_A.flatten()) + skm.mutual_info_score(image_F.flatten(),
image_B.flatten())
@classmethod
def MSE(cls, image_F, image_A, image_B): # MSE
cls.input_check(image_F, image_A, image_B)
return (np.mean((image_A - image_F) ** 2) + np.mean((image_B - image_F) ** 2)) / 2
@classmethod
def CC(cls, image_F, image_A, image_B):
cls.input_check(image_F, image_A, image_B)
rAF = np.sum((image_A - np.mean(image_A)) * (image_F - np.mean(image_F))) / np.sqrt(
(np.sum((image_A - np.mean(image_A)) ** 2)) * (np.sum((image_F - np.mean(image_F)) ** 2)))
rBF = np.sum((image_B - np.mean(image_B)) * (image_F - np.mean(image_F))) / np.sqrt(
(np.sum((image_B - np.mean(image_B)) ** 2)) * (np.sum((image_F - np.mean(image_F)) ** 2)))
return (rAF + rBF) / 2
@classmethod
def PSNR(cls, image_F, image_A, image_B):
cls.input_check(image_F, image_A, image_B)
return 10 * np.log10(np.max(image_F) ** 2 / cls.MSE(image_F, image_A, image_B))
@classmethod
def SCD(cls, image_F, image_A, image_B): # The sum of the correlations of differences
cls.input_check(image_F, image_A, image_B)
imgF_A = image_F - image_A
imgF_B = image_F - image_B
corr1 = np.sum((image_A - np.mean(image_A)) * (imgF_B - np.mean(imgF_B))) / np.sqrt(
(np.sum((image_A - np.mean(image_A)) ** 2)) * (np.sum((imgF_B - np.mean(imgF_B)) ** 2)))
corr2 = np.sum((image_B - np.mean(image_B)) * (imgF_A - np.mean(imgF_A))) / np.sqrt(
(np.sum((image_B - np.mean(image_B)) ** 2)) * (np.sum((imgF_A - np.mean(imgF_A)) ** 2)))
return corr1 + corr2
@classmethod
def VIFF(cls, image_F, image_A, image_B):
cls.input_check(image_F, image_A, image_B)
return cls.compare_viff(image_A, image_F)+cls.compare_viff(image_B, image_F)
@classmethod
def compare_viff(cls,ref, dist): # viff of a pair of pictures
sigma_nsq = 2
eps = 1e-10
num = 0.0
den = 0.0
for scale in range(1, 5):
N = 2 ** (4 - scale + 1) + 1
sd = N / 5.0
# Create a Gaussian kernel, as in MATLAB
m, n = [(ss - 1.) / 2. for ss in (N, N)]
y, x = np.ogrid[-m:m + 1, -n:n + 1]
h = np.exp(-(x * x + y * y) / (2. * sd * sd))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
sumh = h.sum()
if sumh != 0:
win = h / sumh
if scale > 1:
ref = convolve2d(ref, np.rot90(win, 2), mode='valid')
dist = convolve2d(dist, np.rot90(win, 2), mode='valid')
ref = ref[::2, ::2]
dist = dist[::2, ::2]
mu1 = convolve2d(ref, np.rot90(win, 2), mode='valid')
mu2 = convolve2d(dist, np.rot90(win, 2), mode='valid')
mu1_sq = mu1 * mu1
mu2_sq = mu2 * mu2
mu1_mu2 = mu1 * mu2
sigma1_sq = convolve2d(ref * ref, np.rot90(win, 2), mode='valid') - mu1_sq
sigma2_sq = convolve2d(dist * dist, np.rot90(win, 2), mode='valid') - mu2_sq
sigma12 = convolve2d(ref * dist, np.rot90(win, 2), mode='valid') - mu1_mu2
sigma1_sq[sigma1_sq < 0] = 0
sigma2_sq[sigma2_sq < 0] = 0
g = sigma12 / (sigma1_sq + eps)
sv_sq = sigma2_sq - g * sigma12
g[sigma1_sq < eps] = 0
sv_sq[sigma1_sq < eps] = sigma2_sq[sigma1_sq < eps]
sigma1_sq[sigma1_sq < eps] = 0
g[sigma2_sq < eps] = 0
sv_sq[sigma2_sq < eps] = 0
sv_sq[g < 0] = sigma2_sq[g < 0]
g[g < 0] = 0
sv_sq[sv_sq <= eps] = eps
num += np.sum(np.log10(1 + g * g * sigma1_sq / (sv_sq + sigma_nsq)))
den += np.sum(np.log10(1 + sigma1_sq / sigma_nsq))
vifp = num / den
if np.isnan(vifp):
return 1.0
else:
return vifp
@classmethod
def Qabf(cls, image_F, image_A, image_B):
cls.input_check(image_F, image_A, image_B)
gA, aA = cls.Qabf_getArray(image_A)
gB, aB = cls.Qabf_getArray(image_B)
gF, aF = cls.Qabf_getArray(image_F)
QAF = cls.Qabf_getQabf(aA, gA, aF, gF)
QBF = cls.Qabf_getQabf(aB, gB, aF, gF)
# compute the overall Qabf score
deno = np.sum(gA + gB)
nume = np.sum(np.multiply(QAF, gA) + np.multiply(QBF, gB))
return nume / deno
@classmethod
def Qabf_getArray(cls,img):
# Sobel operators
h1 = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).astype(np.float32)
h2 = np.array([[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]).astype(np.float32)
h3 = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).astype(np.float32)
SAx = convolve2d(img, h3, mode='same')
SAy = convolve2d(img, h1, mode='same')
gA = np.sqrt(np.multiply(SAx, SAx) + np.multiply(SAy, SAy))
aA = np.zeros_like(img)
aA[SAx == 0] = math.pi / 2
aA[SAx != 0]=np.arctan(SAy[SAx != 0] / SAx[SAx != 0])
return gA, aA
@classmethod
def Qabf_getQabf(cls,aA, gA, aF, gF):
L = 1
Tg = 0.9994
kg = -15
Dg = 0.5
Ta = 0.9879
ka = -22
Da = 0.8
GAF,AAF,QgAF,QaAF,QAF = np.zeros_like(aA),np.zeros_like(aA),np.zeros_like(aA),np.zeros_like(aA),np.zeros_like(aA)
GAF[gA>gF]=gF[gA>gF]/gA[gA>gF]
GAF[gA == gF] = gF[gA == gF]
GAF[gA <gF] = gA[gA<gF]/gF[gA<gF]
AAF = 1 - np.abs(aA - aF) / (math.pi / 2)
QgAF = Tg / (1 + np.exp(kg * (GAF - Dg)))
QaAF = Ta / (1 + np.exp(ka * (AAF - Da)))
QAF = QgAF* QaAF
return QAF
@classmethod
def SSIM(cls, image_F, image_A, image_B):
cls.input_check(image_F, image_A, image_B)
return ssim(image_F,image_A)+ssim(image_F,image_B)
def VIFF(image_F, image_A, image_B):
refA=image_A
refB=image_B
dist=image_F
sigma_nsq = 2
eps = 1e-10
numA = 0.0
denA = 0.0
numB = 0.0
denB = 0.0
for scale in range(1, 5):
N = 2 ** (4 - scale + 1) + 1
sd = N / 5.0
# Create a Gaussian kernel, as in MATLAB
m, n = [(ss - 1.) / 2. for ss in (N, N)]
y, x = np.ogrid[-m:m + 1, -n:n + 1]
h = np.exp(-(x * x + y * y) / (2. * sd * sd))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
sumh = h.sum()
if sumh != 0:
win = h / sumh
if scale > 1:
refA = convolve2d(refA, np.rot90(win, 2), mode='valid')
refB = convolve2d(refB, np.rot90(win, 2), mode='valid')
dist = convolve2d(dist, np.rot90(win, 2), mode='valid')
refA = refA[::2, ::2]
refB = refB[::2, ::2]
dist = dist[::2, ::2]
mu1A = convolve2d(refA, np.rot90(win, 2), mode='valid')
mu1B = convolve2d(refB, np.rot90(win, 2), mode='valid')
mu2 = convolve2d(dist, np.rot90(win, 2), mode='valid')
mu1_sq_A = mu1A * mu1A
mu1_sq_B = mu1B * mu1B
mu2_sq = mu2 * mu2
mu1A_mu2 = mu1A * mu2
mu1B_mu2 = mu1B * mu2
sigma1A_sq = convolve2d(refA * refA, np.rot90(win, 2), mode='valid') - mu1_sq_A
sigma1B_sq = convolve2d(refB * refB, np.rot90(win, 2), mode='valid') - mu1_sq_B
sigma2_sq = convolve2d(dist * dist, np.rot90(win, 2), mode='valid') - mu2_sq
sigma12_A = convolve2d(refA * dist, np.rot90(win, 2), mode='valid') - mu1A_mu2
sigma12_B = convolve2d(refB * dist, np.rot90(win, 2), mode='valid') - mu1B_mu2
sigma1A_sq[sigma1A_sq < 0] = 0
sigma1B_sq[sigma1B_sq < 0] = 0
sigma2_sq[sigma2_sq < 0] = 0
gA = sigma12_A / (sigma1A_sq + eps)
gB = sigma12_B / (sigma1B_sq + eps)
sv_sq_A = sigma2_sq - gA * sigma12_A
sv_sq_B = sigma2_sq - gB * sigma12_B
gA[sigma1A_sq < eps] = 0
gB[sigma1B_sq < eps] = 0
sv_sq_A[sigma1A_sq < eps] = sigma2_sq[sigma1A_sq < eps]
sv_sq_B[sigma1B_sq < eps] = sigma2_sq[sigma1B_sq < eps]
sigma1A_sq[sigma1A_sq < eps] = 0
sigma1B_sq[sigma1B_sq < eps] = 0
gA[sigma2_sq < eps] = 0
gB[sigma2_sq < eps] = 0
sv_sq_A[sigma2_sq < eps] = 0
sv_sq_B[sigma2_sq < eps] = 0
sv_sq_A[gA < 0] = sigma2_sq[gA < 0]
sv_sq_B[gB < 0] = sigma2_sq[gB < 0]
gA[gA < 0] = 0
gB[gB < 0] = 0
sv_sq_A[sv_sq_A <= eps] = eps
sv_sq_B[sv_sq_B <= eps] = eps
numA += np.sum(np.log10(1 + gA * gA * sigma1A_sq / (sv_sq_A + sigma_nsq)))
numB += np.sum(np.log10(1 + gB * gB * sigma1B_sq / (sv_sq_B + sigma_nsq)))
denA += np.sum(np.log10(1 + sigma1A_sq / sigma_nsq))
denB += np.sum(np.log10(1 + sigma1B_sq / sigma_nsq))
vifpA = numA / denA
vifpB =numB / denB
if np.isnan(vifpA):
vifpA=1
if np.isnan(vifpB):
vifpB = 1
return vifpA+vifpB
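
A minimal usage sketch for the Evaluator class above, with random arrays standing in for real grayscale images (illustration only, not part of the original file):

# Hedged sketch: evaluate a dummy fused image against two dummy source images.
if __name__ == '__main__':
    ir = np.random.randint(0, 256, (128, 128)).astype(np.float64)
    vi = np.random.randint(0, 256, (128, 128)).astype(np.float64)
    fu = (ir + vi) / 2
    print("EN  :", Evaluator.EN(fu))
    print("MI  :", Evaluator.MI(fu, ir, vi))
    print("Qabf:", Evaluator.Qabf(fu, ir, vi))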

6 binary files not shown.

utils/dataset.py (new file, 22 lines)

@@ -0,0 +1,22 @@
import torch.utils.data as Data
import h5py
import numpy as np
import torch
class H5Dataset(Data.Dataset):
def __init__(self, h5file_path):
self.h5file_path = h5file_path
h5f = h5py.File(h5file_path, 'r')
self.keys = list(h5f['ir_patchs'].keys())
h5f.close()
def __len__(self):
return len(self.keys)
def __getitem__(self, index):
h5f = h5py.File(self.h5file_path, 'r')
key = self.keys[index]
IR = np.array(h5f['ir_patchs'][key])
VIS = np.array(h5f['vis_patchs'][key])
h5f.close()
return torch.Tensor(VIS), torch.Tensor(IR)

utils/img_read_save.py (new file, 28 lines)

@@ -0,0 +1,28 @@
import numpy as np
import cv2
import os
from skimage.io import imsave
def image_read_cv2(path, mode='RGB'):
img_BGR = cv2.imread(path).astype('float32')
# img_BGR = cv2.imread(path)
# print(img_BGR)
# if img_BGR is not None:
# img_BGR = img_BGR.astype('float32')
# else:
# print("处理图像加载失败的情况")
assert mode == 'RGB' or mode == 'GRAY' or mode == 'YCrCb', 'mode error'
if mode == 'RGB':
img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2RGB)
elif mode == 'GRAY':
img = np.round(cv2.cvtColor(img_BGR, cv2.COLOR_BGR2GRAY))
elif mode == 'YCrCb':
img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2YCrCb)
return img
def img_save(image,imagename,savepath):
if not os.path.exists(savepath):
os.makedirs(savepath)
# Gray_pic
imsave(os.path.join(savepath, "{}.png".format(imagename)),image.astype(np.uint8))

utils/loss.py (new file, 113 lines)

@@ -0,0 +1,113 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Fusionloss(nn.Module):
def __init__(self):
super(Fusionloss, self).__init__()
self.sobelconv=Sobelxy()
def forward(self,image_vis,image_ir,generate_img):
# added: relative mean-intensity term
loss_rmi_v=relative_diff_loss(image_vis, generate_img)
loss_rmi_i=relative_diff_loss(image_ir, generate_img)
x_rmi_max=torch.max(loss_rmi_v, loss_rmi_i)
loss_rmi=F.l1_loss(x_rmi_max, generate_img)
# added
image_y=image_vis[:,:1,:,:]
x_in_max=torch.max(image_y,image_ir)
loss_in=F.l1_loss(x_in_max,generate_img)
y_grad=self.sobelconv(image_y)
ir_grad=self.sobelconv(image_ir)
generate_img_grad=self.sobelconv(generate_img)
x_grad_joint=torch.max(y_grad,ir_grad)
loss_grad=F.l1_loss(x_grad_joint,generate_img_grad)
# loss_total=loss_in+10*loss_grad
# modified
loss_total = loss_in + 10 * loss_grad + loss_rmi
return loss_total,loss_in,loss_grad
class Sobelxy(nn.Module):
def __init__(self):
super(Sobelxy, self).__init__()
kernelx = [[-1, 0, 1],
[-2,0 , 2],
[-1, 0, 1]]
kernely = [[1, 2, 1],
[0,0 , 0],
[-1, -2, -1]]
kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda()
self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda()
def forward(self,x):
sobelx=F.conv2d(x, self.weightx, padding=1)
sobely=F.conv2d(x, self.weighty, padding=1)
return torch.abs(sobelx)+torch.abs(sobely)
def cc(img1, img2):
"""Correlation coefficient for (N, C, H, W) images; torch.float32 in [0., 1.]."""
eps = torch.finfo(torch.float32).eps
N, C, _, _ = img1.shape
img1 = img1.reshape(N, C, -1)
img2 = img2.reshape(N, C, -1)
img1 = img1 - img1.mean(dim=-1, keepdim=True)
img2 = img2 - img2.mean(dim=-1, keepdim=True)
cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum(img1 **
2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1)))
cc = torch.clamp(cc, -1., 1.)
return cc.mean()
# def dice_coeff(img1, img2):
# smooth = 1.
# num = img1.size(0)
# m1 = img1.view(num, -1) # Flatten
# m2 = img2.view(num, -1) # Flatten
# intersection = (m1 * m2).sum()
#
# return 1 - (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
# measures the average intensity difference between two images
def relative_diff_loss(img1, img2):
# compute the mean intensity of each image
mean_intensity_img1 = torch.mean(img1)
mean_intensity_img2 = torch.mean(img2)
# print("mean_intensity_img1")
# print(mean_intensity_img1)
# print("mean_intensity_img2")
# print(mean_intensity_img2)
# compute the relative difference between the two means
epsilon = 1e-10  # guard against division by zero
relative_diff = abs((mean_intensity_img1 - mean_intensity_img2) / (mean_intensity_img1 + epsilon))
return relative_diff
# mutual information (MI) loss
# def mutual_information_loss(img1, img2):
# # compute the entropy of X and Y
# entropy_img1 = -torch.mean(torch.sum(F.softmax(img1, dim=-1) * F.log_softmax(img1, dim=-1), dim=-1))
# entropy_img2 = -torch.mean(torch.sum(F.softmax(img2, dim=-1) * F.log_softmax(img2, dim=-1), dim=-1))
#
# # compute the joint entropy of X and Y
# joint_entropy = -torch.mean(torch.sum(F.softmax(img1, dim=-1) * F.log_softmax(img2, dim=-1), dim=-1))
#
# # compute the mutual information loss
# mutual_information = entropy_img1 + entropy_img2 - joint_entropy
#
# return mutual_information
# # cosine similarity computation
# def cosine_similarity(img1, img2):
# # Flatten the tensors
# img1_flat = img1.view(-1)
# img2_flat = img2.view(-1)
#
# # Calculate cosine similarity
# similarity = F.cosine_similarity(img1_flat, img2_flat, dim=0)
#
# loss = torch.abs(similarity - 1)
# return loss.item() # Convert to Python float
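
A short sketch exercising the two active helpers above on random tensors (not part of the original file; cc and relative_diff_loss run fine on CPU, while Fusionloss itself expects CUDA because Sobelxy moves its kernels to the GPU):

# Hedged sketch: sanity-check cc() and relative_diff_loss() on random data.
if __name__ == '__main__':
    a = torch.rand(2, 1, 64, 64)
    b = torch.rand(2, 1, 64, 64)
    print(cc(a, b).item())                  # correlation coefficient in [-1, 1]
    print(relative_diff_loss(a, b).item())  # relative mean-intensity gap, >= 0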