114 lines
4.0 KiB
Python
114 lines
4.0 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
class Fusionloss(nn.Module):
|
||
|
def __init__(self):
|
||
|
super(Fusionloss, self).__init__()
|
||
|
self.sobelconv=Sobelxy()
|
||
|
|
||
|
def forward(self,image_vis,image_ir,generate_img):
|
||
|
# 加
|
||
|
loss_rmi_v=relative_diff_loss(image_vis, generate_img)
|
||
|
loss_rmi_i=relative_diff_loss(image_ir, generate_img)
|
||
|
x_rmi_max=torch.max(loss_rmi_v, loss_rmi_i)
|
||
|
loss_rmi=F.l1_loss(x_rmi_max, generate_img)
|
||
|
# 加
|
||
|
image_y=image_vis[:,:1,:,:]
|
||
|
x_in_max=torch.max(image_y,image_ir)
|
||
|
loss_in=F.l1_loss(x_in_max,generate_img)
|
||
|
y_grad=self.sobelconv(image_y)
|
||
|
ir_grad=self.sobelconv(image_ir)
|
||
|
generate_img_grad=self.sobelconv(generate_img)
|
||
|
x_grad_joint=torch.max(y_grad,ir_grad)
|
||
|
loss_grad=F.l1_loss(x_grad_joint,generate_img_grad)
|
||
|
# loss_total=loss_in+10*loss_grad
|
||
|
#改
|
||
|
loss_total = loss_in + 10 * loss_grad + loss_rmi
|
||
|
return loss_total,loss_in,loss_grad
|
||
|
|
||
|
class Sobelxy(nn.Module):
|
||
|
def __init__(self):
|
||
|
super(Sobelxy, self).__init__()
|
||
|
kernelx = [[-1, 0, 1],
|
||
|
[-2,0 , 2],
|
||
|
[-1, 0, 1]]
|
||
|
kernely = [[1, 2, 1],
|
||
|
[0,0 , 0],
|
||
|
[-1, -2, -1]]
|
||
|
kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
|
||
|
kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
|
||
|
self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda()
|
||
|
self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda()
|
||
|
def forward(self,x):
|
||
|
sobelx=F.conv2d(x, self.weightx, padding=1)
|
||
|
sobely=F.conv2d(x, self.weighty, padding=1)
|
||
|
return torch.abs(sobelx)+torch.abs(sobely)
|
||
|
|
||
|
|
||
|
def cc(img1, img2):
|
||
|
eps = torch.finfo(torch.float32).eps
|
||
|
"""Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.]."""
|
||
|
N, C, _, _ = img1.shape
|
||
|
img1 = img1.reshape(N, C, -1)
|
||
|
img2 = img2.reshape(N, C, -1)
|
||
|
img1 = img1 - img1.mean(dim=-1, keepdim=True)
|
||
|
img2 = img2 - img2.mean(dim=-1, keepdim=True)
|
||
|
cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum(img1 **
|
||
|
2, dim=-1)) * torch.sqrt(torch.sum(img2**2, dim=-1)))
|
||
|
cc = torch.clamp(cc, -1., 1.)
|
||
|
return cc.mean()
|
||
|
|
||
|
|
||
|
# def dice_coeff(img1, img2):
|
||
|
# smooth = 1.
|
||
|
# num = img1.size(0)
|
||
|
# m1 = img1.view(num, -1) # Flatten
|
||
|
# m2 = img2.view(num, -1) # Flatten
|
||
|
# intersection = (m1 * m2).sum()
|
||
|
#
|
||
|
# return 1 - (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)
|
||
|
|
||
|
# 用来衡量图像之间的平均灰度差异
|
||
|
def relative_diff_loss(img1, img2):
|
||
|
# 计算图像的平均灰度值
|
||
|
mean_intensity_img1 = torch.mean(img1)
|
||
|
mean_intensity_img2 = torch.mean(img2)
|
||
|
# print("mean_intensity_img1")
|
||
|
# print(mean_intensity_img1)
|
||
|
# print("mean_intensity_img2")
|
||
|
# print(mean_intensity_img2)
|
||
|
|
||
|
# 计算relative_diff
|
||
|
epsilon = 1e-10 # 防止除零错误
|
||
|
relative_diff = abs((mean_intensity_img1 - mean_intensity_img2) / (mean_intensity_img1 + epsilon))
|
||
|
|
||
|
return relative_diff
|
||
|
|
||
|
# 互信息MI损失
|
||
|
# def mutual_information_loss(img1, img2):
|
||
|
# # 计算 X 和 Y 的熵
|
||
|
# entropy_img1 = -torch.mean(torch.sum(F.softmax(img1, dim=-1) * F.log_softmax(img1, dim=-1), dim=-1))
|
||
|
# entropy_img2 = -torch.mean(torch.sum(F.softmax(img2, dim=-1) * F.log_softmax(img2, dim=-1), dim=-1))
|
||
|
#
|
||
|
# # 计算 X 和 Y 的联合熵
|
||
|
# joint_entropy = -torch.mean(torch.sum(F.softmax(img1, dim=-1) * F.log_softmax(img2, dim=-1), dim=-1))
|
||
|
#
|
||
|
# # 计算互信息损失
|
||
|
# mutual_information = entropy_img1 + entropy_img2 - joint_entropy
|
||
|
#
|
||
|
# return mutual_information
|
||
|
|
||
|
|
||
|
# # 余弦相似度计算
|
||
|
# def cosine_similarity(img1, img2):
|
||
|
# # Flatten the tensors
|
||
|
# img1_flat = img1.view(-1)
|
||
|
# img2_flat = img2.view(-1)
|
||
|
#
|
||
|
# # Calculate cosine similarity
|
||
|
# similarity = Fine_similarity(img1_flat, img2_flat, dim=0)
|
||
|
#
|
||
|
# loss = torch.abs(similarity - 1)
|
||
|
# return loss.item() # Convert to Python float
|