Add files via upload
This commit is contained in:
parent
e6852193de
commit
da5da74611
93
dataprocessing.py
Normal file
@@ -0,0 +1,93 @@
import os
import h5py
import numpy as np
from tqdm import tqdm
from skimage.io import imread


def get_img_file(file_name):
    imagelist = []
    for parent, dirnames, filenames in os.walk(file_name):
        for filename in filenames:
            if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
                imagelist.append(os.path.join(parent, filename))
    return imagelist


def rgb2y(img):
    # ITU-R BT.601 luma: Y = 0.299 R + 0.587 G + 0.114 B, on a [3, H, W] array
    y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
    return y


def Im2Patch(img, win, stride=1):
    # Slice a [C, H, W] array into overlapping win x win patches with the given
    # stride; returns an array of shape [C, win, win, TotalPatNum].
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw - win + 1:stride, 0:endh - win + 1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win * win, TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:, i:endw - win + i + 1:stride, j:endh - win + j + 1:stride]
            Y[:, k, :] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])
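

# A minimal usage sketch (not part of the original file) showing the shapes
# Im2Patch produces; the 480x640 input size is an assumption for illustration.
def _demo_im2patch():
    img = np.zeros((1, 480, 640), np.float32)
    patches = Im2Patch(img, win=128, stride=200)
    print(patches.shape)  # (1, 128, 128, 6): 2 patch rows x 3 patch columns
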
def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
                    upper_percentile=90):
    """Determine if an image is low contrast."""
    limits = np.percentile(image, [lower_percentile, upper_percentile])
    ratio = (limits[1] - limits[0]) / limits[1]
    return ratio < fraction_threshold


data_name = "MSRS_train"
img_size = 128  # patch size
stride = 200    # patch stride

IR_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/ir"))
VIS_files = sorted(get_img_file(r"MSRS_train/MSRS-main/train/vi"))

assert len(IR_files) == len(VIS_files)
h5f = h5py.File(os.path.join('data',
                             data_name + '_imgsize_' + str(img_size) + "_stride_" + str(stride) + '.h5'),
                'w')
h5_ir = h5f.create_group('ir_patchs')
h5_vis = h5f.create_group('vis_patchs')
train_num = 0
for i in tqdm(range(len(IR_files))):
    I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2, 0, 1) / 255.  # [3, H, W] uint8 -> float32
    I_VIS = rgb2y(I_VIS)  # [1, H, W] float32
    I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :] / 255.  # [1, H, W] float32

    # crop into overlapping patches
    I_IR_Patch_Group = Im2Patch(I_IR, img_size, stride)
    I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride)  # [1, img_size, img_size, N]

    for ii in range(I_IR_Patch_Group.shape[-1]):
        bad_IR = is_low_contrast(I_IR_Patch_Group[0, :, :, ii])
        bad_VIS = is_low_contrast(I_VIS_Patch_Group[0, :, :, ii])
        # keep the patch pair only if neither modality is low contrast
        if not (bad_IR or bad_VIS):
            avl_IR = I_IR_Patch_Group[0, :, :, ii]  # available IR
            avl_VIS = I_VIS_Patch_Group[0, :, :, ii]
            avl_IR = avl_IR[None, ...]
            avl_VIS = avl_VIS[None, ...]

            h5_ir.create_dataset(str(train_num), data=avl_IR,
                                 dtype=avl_IR.dtype, shape=avl_IR.shape)
            h5_vis.create_dataset(str(train_num), data=avl_VIS,
                                  dtype=avl_VIS.dtype, shape=avl_VIS.shape)
            train_num += 1

h5f.close()

with h5py.File(os.path.join('data',
                            data_name + '_imgsize_' + str(img_size) + "_stride_" + str(stride) + '.h5'), "r") as f:
    for key in f.keys():
        print(f[key], key, f[key].name)
427
net.py
Normal file
@@ -0,0 +1,427 @@
# poolformer
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # work with diff dim tensors, not just 2D ConvNets
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + \
        torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    Note: this local definition shadows the DropPath imported from timm above.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Pooling(nn.Module):
    def __init__(self, kernel_size=3):
        super().__init__()
        self.pool = nn.AvgPool2d(
            kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        # subtracting the input makes the mixer output the local difference only
        return self.pool(x) - x
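

# PoolFormer replaces self-attention with average pooling as the token mixer;
# the "- x" above means the module emits only the neighborhood difference.
# A minimal sketch (not part of the original file) of that behavior:
def _demo_pooling():
    pool = Pooling(kernel_size=3)
    x = torch.ones(1, 8, 16, 16)
    # interior of a constant map mixes to exactly zero (borders see zero padding)
    print(pool(x)[0, 0, 1:-1, 1:-1].abs().max())  # tensor(0.)
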
class PoolMlp(nn.Module):
    """
    Implementation of MLP with 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 bias=False,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=bias)
        self.drop = nn.Dropout(drop)
        # self.apply(self._init_weights)

    # def _init_weights(self, m):
    #     if isinstance(m, nn.Conv2D):
    #         trunc_normal_(m.weight)
    #         if m.bias is not None:
    #             zeros_(m.bias)

    def forward(self, x):
        x = self.fc1(x)  # (B, C, H, W) --> (B, C, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)  # (B, C, H, W) --> (B, C, H, W)
        x = self.drop(x)
        return x


class BaseFeatureExtraction(nn.Module):
    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 # norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):

        super().__init__()

        self.norm1 = LayerNorm(dim, 'WithBias')
        # ViTs mix tokens with MSA and MLP-Mixers with an MLP; here pooling replaces them
        self.token_mixer = Pooling(kernel_size=pool_size)
        self.norm2 = LayerNorm(dim, 'WithBias')
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
                               act_layer=act_layer, drop=drop)

        # The following two techniques are useful to train deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale

        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)

            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)

    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                * self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                * self.poolmlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(self.poolmlp(self.norm2(x)))
        return x


class InvertedResidualBlock(nn.Module):
    def __init__(self, inp, oup, expand_ratio):
        super(InvertedResidualBlock, self).__init__()
        hidden_dim = int(inp * expand_ratio)
        self.bottleneckBlock = nn.Sequential(
            # pw
            nn.Conv2d(inp, hidden_dim, 1, bias=False),
            # nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # dw
            nn.ReflectionPad2d(1),
            nn.Conv2d(hidden_dim, hidden_dim, 3, groups=hidden_dim, bias=False),
            # nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, bias=False),
            # nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        return self.bottleneckBlock(x)


class DetailNode(nn.Module):
    def __init__(self):
        super(DetailNode, self).__init__()

        self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
        self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
        self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
        self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
                                    stride=1, padding=0, bias=True)

    def separateFeature(self, x):
        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
        return z1, z2

    def forward(self, z1, z2):
        z1, z2 = self.separateFeature(
            self.shffleconv(torch.cat((z1, z2), dim=1)))
        z2 = z2 + self.theta_phi(z1)
        z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
        return z1, z2


class DetailFeatureExtraction(nn.Module):
    def __init__(self, num_layers=3):
        super(DetailFeatureExtraction, self).__init__()
        INNmodules = [DetailNode() for _ in range(num_layers)]
        self.net = nn.Sequential(*INNmodules)

    def forward(self, x):
        # split channels in half, run the coupling layers, re-concatenate
        z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
        for layer in self.net:
            z1, z2 = layer(z1, z2)
        return torch.cat((z1, z2), dim=1)
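

# DetailNode is an INN-style coupling step: z2 <- z2 + phi(z1), then
# z1 <- z1 * exp(rho(z2)) + eta(z2), so each sub-step updates one half of the
# channels from the other. A shape-level sketch (not part of the original
# file), using the 32+32 channel split the module assumes:
def _demo_detail_node():
    node = DetailNode()
    z1, z2 = torch.randn(1, 32, 16, 16), torch.randn(1, 32, 16, 16)
    z1, z2 = node(z1, z2)
    print(z1.shape, z2.shape)  # both torch.Size([1, 32, 16, 16])
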
# =============================================================================

# =============================================================================
import numbers


##########################################################################
## Layer Norm
def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)
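

# Note on the LayerNorm above: the [b, c, h, w] map is flattened to
# [b, h*w, c], normalized over the channel dimension per spatial position,
# then reshaped back. The 'BiasFree' variant skips mean subtraction and the
# additive bias, as in Restormer; the blocks in this file default to 'WithBias'.
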
##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(
            dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
                                stride=1, padding=1, groups=hidden_features * 2, bias=bias)

        self.project_out = nn.Conv2d(
            hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        # split into a gate branch and a value branch
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x


##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # learned per-head temperature replaces the usual 1/sqrt(d) scaling
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # attention over channels: [b, head, c, c] rather than [b, head, hw, hw]
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        out = rearrange(out, 'b head c (h w) -> b (head c) h w',
                        head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out
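

# MDTA applies attention across channels instead of spatial positions: q, k, v
# are reshaped to [b, head, c, h*w], so the attention map is c x c per head
# and the cost grows linearly with image size. A minimal sketch (not part of
# the original file):
def _demo_mdta():
    attn = Attention(dim=64, num_heads=8, bias=False)
    x = torch.randn(1, 64, 32, 32)
    print(attn(x).shape)  # torch.Size([1, 64, 32, 32])
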
##########################################################################
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x


##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
                              stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        return x


class Restormer_Encoder(nn.Module):
    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=[4, 4],
                 heads=[8, 8, 8],
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):
        super(Restormer_Encoder, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(
            *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                               bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])

        self.baseFeature = BaseFeatureExtraction(dim=dim)

        self.detailFeature = DetailFeatureExtraction()

    def forward(self, inp_img):
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        base_feature = self.baseFeature(out_enc_level1)
        detail_feature = self.detailFeature(out_enc_level1)
        return base_feature, detail_feature, out_enc_level1


class Restormer_Decoder(nn.Module):
    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=[4, 4],
                 heads=[8, 8, 8],
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):

        super(Restormer_Decoder, self).__init__()
        self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
        self.encoder_level2 = nn.Sequential(
            *[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
                               bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        self.output = nn.Sequential(
            nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
                      stride=1, padding=1, bias=bias),
            nn.LeakyReLU(),
            nn.Conv2d(int(dim) // 2, out_channels, kernel_size=3,
                      stride=1, padding=1, bias=bias), )
        self.sigmoid = nn.Sigmoid()

    def forward(self, inp_img, base_feature, detail_feature):
        out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
        out_enc_level0 = self.reduce_channel(out_enc_level0)
        out_enc_level1 = self.encoder_level2(out_enc_level0)
        if inp_img is not None:
            out_enc_level1 = self.output(out_enc_level1) + inp_img
        else:
            out_enc_level1 = self.output(out_enc_level1)
        return self.sigmoid(out_enc_level1), out_enc_level0


if __name__ == '__main__':
    height = 128
    width = 128
    window_size = 8
    modelE = Restormer_Encoder().cuda()
    modelD = Restormer_Decoder().cuda()
    # Smoke test (a sketch, assuming a CUDA device is available): run a dummy
    # single-channel image through the encoder/decoder pair.
    img = torch.randn(1, 1, height, width).cuda()
    base, detail, _ = modelE(img)
    out, _ = modelD(img, base, detail)
    print(out.shape)  # expected: torch.Size([1, 1, 128, 128])
95
test_IVF.py
Normal file
@@ -0,0 +1,95 @@
import cv2
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
from utils.Evaluator import Evaluator
import torch
import torch.nn as nn
from utils.img_read_save import img_save, image_read_cv2
import warnings
import logging
# added: silence warnings and logging output during evaluation
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)


os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path = r"models/PFCFuse.pth"

for dataset_name in ["MSRS", "TNO", "RoadScene"]:
    print("\n" * 2 + "=" * 80)
    model_name = "PFCFuse "
    print("The test result of " + dataset_name + ' :')
    test_folder = os.path.join('test_img', dataset_name)
    test_out_folder = os.path.join('test_result', dataset_name)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
    Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
    # BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
    BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
    DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

    Encoder.load_state_dict(torch.load(ckpt_path)['DIDF_Encoder'], strict=False)
    Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
    BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
    DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
    Encoder.eval()
    Decoder.eval()
    BaseFuseLayer.eval()
    DetailFuseLayer.eval()

    with torch.no_grad():
        for img_name in os.listdir(os.path.join(test_folder, "ir")):

            data_IR = image_read_cv2(os.path.join(test_folder, "ir", img_name), mode='GRAY')[np.newaxis, np.newaxis, ...] / 255.0
            # modified: fuse only the luminance (Y) channel of the visible image
            data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0
            # ycrcb, uint8: keep the Cr/Cb chroma planes to recolor the fused result
            data_VIS_BGR = cv2.imread(os.path.join(test_folder, "vi", img_name))
            _, data_VIS_Cr, data_VIS_Cb = cv2.split(cv2.cvtColor(data_VIS_BGR, cv2.COLOR_BGR2YCrCb))

            data_IR, data_VIS = torch.FloatTensor(data_IR), torch.FloatTensor(data_VIS)
            data_VIS, data_IR = data_VIS.to(device), data_IR.to(device)

            feature_V_B, feature_V_D, feature_V = Encoder(data_VIS)
            feature_I_B, feature_I_D, feature_I = Encoder(data_IR)
            feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
            feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
            data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
            data_Fuse = (data_Fuse - torch.min(data_Fuse)) / (torch.max(data_Fuse) - torch.min(data_Fuse))
            fi = np.squeeze((data_Fuse * 255).cpu().numpy())
            # modified: convert float32 back to uint8 and restore color
            fi = fi.astype(np.uint8)
            ycrcb_fi = np.dstack((fi, data_VIS_Cr, data_VIS_Cb))
            rgb_fi = cv2.cvtColor(ycrcb_fi, cv2.COLOR_YCrCb2RGB)
            img_save(rgb_fi, img_name.split(sep='.')[0], test_out_folder)
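
    # Note: the loop above fuses only the luminance channel and reattaches the
    # visible image's Cr/Cb planes, so all color in the output comes from the
    # visible input. The fused results are scored against both sources below.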
    eval_folder = test_out_folder
    ori_img_folder = test_folder

    metric_result = np.zeros((8))
    for img_name in os.listdir(os.path.join(ori_img_folder, "ir")):
        ir = image_read_cv2(os.path.join(ori_img_folder, "ir", img_name), 'GRAY')
        vi = image_read_cv2(os.path.join(ori_img_folder, "vi", img_name), 'GRAY')
        fi = image_read_cv2(os.path.join(eval_folder, img_name.split('.')[0] + ".png"), 'GRAY')
        metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi),
                                   Evaluator.SF(fi), Evaluator.MI(fi, ir, vi),
                                   Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi),
                                   Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])

    metric_result /= len(os.listdir(eval_folder))
    print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
    print(model_name + '\t' + str(np.round(metric_result[0], 2)) + '\t'
          + str(np.round(metric_result[1], 2)) + '\t'
          + str(np.round(metric_result[2], 2)) + '\t'
          + str(np.round(metric_result[3], 2)) + '\t'
          + str(np.round(metric_result[4], 2)) + '\t'
          + str(np.round(metric_result[5], 2)) + '\t'
          + str(np.round(metric_result[6], 2)) + '\t'
          + str(np.round(metric_result[7], 2))
          )
    print("=" * 80)
239
train.py
Normal file
@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-

'''
------------------------------------------------------------------------------
Import packages
------------------------------------------------------------------------------
'''

from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
from utils.dataset import H5Dataset
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import sys
import time
import datetime
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.loss import Fusionloss, cc, relative_diff_loss
import kornia

print(torch.__version__)
print(torch.cuda.is_available())

'''
------------------------------------------------------------------------------
Configure our network
------------------------------------------------------------------------------
'''


os.environ['CUDA_VISIBLE_DEVICES'] = '0'
criteria_fusion = Fusionloss()
model_str = 'PFCFuse'

# Set the hyper-parameters for training
num_epochs = 120  # total epochs
epoch_gap = 40  # epochs of Phase I

lr = 1e-4
weight_decay = 0
batch_size = 1
GPU_number = os.environ['CUDA_VISIBLE_DEVICES']
# Coefficients of the loss function
coeff_mse_loss_VF = 1.  # alpha1
coeff_mse_loss_IF = 1.

coeff_rmi_loss_VF = 1.
coeff_rmi_loss_IF = 1.

coeff_cos_loss_VF = 1.
coeff_cos_loss_IF = 1.
coeff_decomp = 2.  # alpha2 and alpha4
coeff_tv = 5.

clip_grad_norm_value = 0.01
optim_step = 20
optim_gamma = 0.5
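
# Reading aid (not original code): the Phase I objective assembled in the
# training loop below, written with the coefficient names above, is
#   loss = coeff_mse_loss_VF * (5*SSIM + Huber)(VIS, VIS_hat)
#        + coeff_mse_loss_IF * (5*SSIM + Huber)(IR, IR_hat)
#        + coeff_decomp * cc(detail_V, detail_I)^2 / (1.01 + cc(base_V, base_I))
#        + coeff_tv * L1(grad(VIS), grad(VIS_hat))
#        + coeff_rmi_loss_VF * rmi(VIS) + coeff_rmi_loss_IF * rmi(IR)
# Phase II swaps the reconstruction terms for Fusionloss plus the same
# decomposition term.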

# Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
DIDF_Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
DIDF_Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
# BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

# optimizer, scheduler and loss function
optimizer1 = torch.optim.Adam(
    DIDF_Encoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer2 = torch.optim.Adam(
    DIDF_Decoder.parameters(), lr=lr, weight_decay=weight_decay)
optimizer3 = torch.optim.Adam(
    BaseFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)
optimizer4 = torch.optim.Adam(
    DetailFuseLayer.parameters(), lr=lr, weight_decay=weight_decay)

scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=optim_step, gamma=optim_gamma)
scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=optim_step, gamma=optim_gamma)
scheduler3 = torch.optim.lr_scheduler.StepLR(optimizer3, step_size=optim_step, gamma=optim_gamma)
scheduler4 = torch.optim.lr_scheduler.StepLR(optimizer4, step_size=optim_step, gamma=optim_gamma)

MSELoss = nn.MSELoss()
L1Loss = nn.L1Loss()
Loss_ssim = kornia.losses.SSIM(11, reduction='mean')  # note: newer kornia versions name this SSIMLoss
HuberLoss = nn.HuberLoss()

# data loader
trainloader = DataLoader(H5Dataset(r"data/MSRS_train_imgsize_128_stride_200.h5"),
                         batch_size=batch_size,
                         shuffle=True,
                         num_workers=0)

loader = {'train': trainloader, }
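
# Training runs in two phases: Phase I (epoch < epoch_gap) trains the encoder
# and decoder for reconstruction only; Phase II adds the base/detail fusion
# layers and the fusion loss. Schedulers 3/4 therefore only step in Phase II.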
timestamp = datetime.datetime.now().strftime("%m-%d-%H-%M")

'''
------------------------------------------------------------------------------
Train
------------------------------------------------------------------------------
'''

step = 0
torch.backends.cudnn.benchmark = True
prev_time = time.time()

start = time.time()
for epoch in range(num_epochs):
    ''' train '''
    for i, (data_VIS, data_IR) in enumerate(loader['train']):
        data_VIS, data_IR = data_VIS.to(device), data_IR.to(device)
        DIDF_Encoder.train()
        DIDF_Decoder.train()
        BaseFuseLayer.train()
        DetailFuseLayer.train()

        DIDF_Encoder.zero_grad()
        DIDF_Decoder.zero_grad()
        BaseFuseLayer.zero_grad()
        DetailFuseLayer.zero_grad()

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        optimizer4.zero_grad()

        if epoch < epoch_gap:  # Phase I: reconstruction only
            feature_V_B, feature_V_D, _ = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, _ = DIDF_Encoder(data_IR)
            data_VIS_hat, _ = DIDF_Decoder(data_VIS, feature_V_B, feature_V_D)
            data_IR_hat, _ = DIDF_Decoder(data_IR, feature_I_B, feature_I_D)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)

            # despite the names, these are SSIM + Huber reconstruction terms
            mse_loss_V = 5 * Loss_ssim(data_VIS, data_VIS_hat) + HuberLoss(data_VIS, data_VIS_hat)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_IR_hat) + HuberLoss(data_IR, data_IR_hat)

            Gradient_loss = L1Loss(kornia.filters.SpatialGradient()(data_VIS),
                                   kornia.filters.SpatialGradient()(data_VIS_hat))

            # decomposition: decorrelate detail features, correlate base features
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)

            loss_rmi_v = relative_diff_loss(data_VIS, data_VIS_hat)
            loss_rmi_i = relative_diff_loss(data_IR, data_IR_hat)

            loss = coeff_mse_loss_VF * mse_loss_V + coeff_mse_loss_IF * \
                mse_loss_I + coeff_decomp * loss_decomp + coeff_tv * Gradient_loss + \
                coeff_rmi_loss_IF * loss_rmi_i + coeff_rmi_loss_VF * loss_rmi_v

            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
        else:  # Phase II: train the fusion layers as well
            feature_V_B, feature_V_D, feature_V = DIDF_Encoder(data_VIS)
            feature_I_B, feature_I_D, feature_I = DIDF_Encoder(data_IR)
            feature_F_B = BaseFuseLayer(feature_I_B + feature_V_B)
            feature_F_D = DetailFuseLayer(feature_I_D + feature_V_D)
            data_Fuse, feature_F = DIDF_Decoder(data_VIS, feature_F_B, feature_F_D)

            mse_loss_V = 5 * Loss_ssim(data_VIS, data_Fuse) + HuberLoss(data_VIS, data_Fuse)
            mse_loss_I = 5 * Loss_ssim(data_IR, data_Fuse) + HuberLoss(data_IR, data_Fuse)

            cc_loss_B = cc(feature_V_B, feature_I_B)
            cc_loss_D = cc(feature_V_D, feature_I_D)
            loss_decomp = (cc_loss_D) ** 2 / (1.01 + cc_loss_B)

            fusionloss, _, _ = criteria_fusion(data_VIS, data_IR, data_Fuse)

            loss = fusionloss + coeff_decomp * loss_decomp
            loss.backward()
            nn.utils.clip_grad_norm_(
                DIDF_Encoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DIDF_Decoder.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                BaseFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            nn.utils.clip_grad_norm_(
                DetailFuseLayer.parameters(), max_norm=clip_grad_norm_value, norm_type=2)
            optimizer1.step()
            optimizer2.step()
            optimizer3.step()
            optimizer4.step()

        # Determine approximate time left
        batches_done = epoch * len(loader['train']) + i
        batches_left = num_epochs * len(loader['train']) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        epoch_time = time.time() - prev_time
        prev_time = time.time()
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [loss: %f]"
            % (
                epoch,
                num_epochs,
                i,
                len(loader['train']),
                loss.item(),
            )
        )

    # adjust the learning rate
    scheduler1.step()
    scheduler2.step()
    if not epoch < epoch_gap:
        scheduler3.step()
        scheduler4.step()

    if optimizer1.param_groups[0]['lr'] <= 1e-6:
        optimizer1.param_groups[0]['lr'] = 1e-6
    if optimizer2.param_groups[0]['lr'] <= 1e-6:
        optimizer2.param_groups[0]['lr'] = 1e-6
    if optimizer3.param_groups[0]['lr'] <= 1e-6:
        optimizer3.param_groups[0]['lr'] = 1e-6
    if optimizer4.param_groups[0]['lr'] <= 1e-6:
        optimizer4.param_groups[0]['lr'] = 1e-6

# save the trained weights once training finishes
if True:
    checkpoint = {
        'DIDF_Encoder': DIDF_Encoder.state_dict(),
        'DIDF_Decoder': DIDF_Decoder.state_dict(),
        'BaseFuseLayer': BaseFuseLayer.state_dict(),
        'DetailFuseLayer': DetailFuseLayer.state_dict(),
    }
    torch.save(checkpoint, os.path.join("models/PFCFusion" + timestamp + '.pth'))
305
utils/Evaluator.py
Normal file
@@ -0,0 +1,305 @@
import numpy as np
import cv2
import sklearn.metrics as skm
from scipy.signal import convolve2d
import math
from skimage.metrics import structural_similarity as ssim


def image_read_cv2(path, mode='RGB'):
    img_BGR = cv2.imread(path).astype('float32')
    assert mode == 'RGB' or mode == 'GRAY' or mode == 'YCrCb', 'mode error'
    if mode == 'RGB':
        img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2RGB)
    elif mode == 'GRAY':
        img = np.round(cv2.cvtColor(img_BGR, cv2.COLOR_BGR2GRAY))
    elif mode == 'YCrCb':
        img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2YCrCb)
    return img


class Evaluator():
    @classmethod
    def input_check(cls, imgF, imgA=None, imgB=None):
        if imgA is None:
            assert type(imgF) == np.ndarray, 'type error'
            assert len(imgF.shape) == 2, 'dimension error'
        else:
            assert type(imgF) == type(imgA) == type(imgB) == np.ndarray, 'type error'
            assert imgF.shape == imgA.shape == imgB.shape, 'shape error'
            assert len(imgF.shape) == 2, 'dimension error'

    @classmethod
    def EN(cls, img):  # entropy
        cls.input_check(img)
        a = np.uint8(np.round(img)).flatten()
        h = np.bincount(a) / a.shape[0]
        return -sum(h * np.log2(h + (h == 0)))

    @classmethod
    def SD(cls, img):  # standard deviation
        cls.input_check(img)
        return np.std(img)

    @classmethod
    def SF(cls, img):  # spatial frequency
        cls.input_check(img)
        return np.sqrt(np.mean((img[:, 1:] - img[:, :-1]) ** 2) + np.mean((img[1:, :] - img[:-1, :]) ** 2))

    @classmethod
    def AG(cls, img):  # average gradient
        cls.input_check(img)
        Gx, Gy = np.zeros_like(img), np.zeros_like(img)

        Gx[:, 0] = img[:, 1] - img[:, 0]
        Gx[:, -1] = img[:, -1] - img[:, -2]
        Gx[:, 1:-1] = (img[:, 2:] - img[:, :-2]) / 2

        Gy[0, :] = img[1, :] - img[0, :]
        Gy[-1, :] = img[-1, :] - img[-2, :]
        Gy[1:-1, :] = (img[2:, :] - img[:-2, :]) / 2
        return np.mean(np.sqrt((Gx ** 2 + Gy ** 2) / 2))

    @classmethod
    def MI(cls, image_F, image_A, image_B):  # mutual information
        cls.input_check(image_F, image_A, image_B)
        return skm.mutual_info_score(image_F.flatten(), image_A.flatten()) + skm.mutual_info_score(image_F.flatten(),
                                                                                                   image_B.flatten())

    @classmethod
    def MSE(cls, image_F, image_A, image_B):  # MSE
        cls.input_check(image_F, image_A, image_B)
        return (np.mean((image_A - image_F) ** 2) + np.mean((image_B - image_F) ** 2)) / 2

    @classmethod
    def CC(cls, image_F, image_A, image_B):  # correlation coefficient
        cls.input_check(image_F, image_A, image_B)
        rAF = np.sum((image_A - np.mean(image_A)) * (image_F - np.mean(image_F))) / np.sqrt(
            (np.sum((image_A - np.mean(image_A)) ** 2)) * (np.sum((image_F - np.mean(image_F)) ** 2)))
        rBF = np.sum((image_B - np.mean(image_B)) * (image_F - np.mean(image_F))) / np.sqrt(
            (np.sum((image_B - np.mean(image_B)) ** 2)) * (np.sum((image_F - np.mean(image_F)) ** 2)))
        return (rAF + rBF) / 2

    @classmethod
    def PSNR(cls, image_F, image_A, image_B):
        cls.input_check(image_F, image_A, image_B)
        return 10 * np.log10(np.max(image_F) ** 2 / cls.MSE(image_F, image_A, image_B))

    @classmethod
    def SCD(cls, image_F, image_A, image_B):  # the sum of the correlations of differences
        cls.input_check(image_F, image_A, image_B)
        imgF_A = image_F - image_A
        imgF_B = image_F - image_B
        corr1 = np.sum((image_A - np.mean(image_A)) * (imgF_B - np.mean(imgF_B))) / np.sqrt(
            (np.sum((image_A - np.mean(image_A)) ** 2)) * (np.sum((imgF_B - np.mean(imgF_B)) ** 2)))
        corr2 = np.sum((image_B - np.mean(image_B)) * (imgF_A - np.mean(imgF_A))) / np.sqrt(
            (np.sum((image_B - np.mean(image_B)) ** 2)) * (np.sum((imgF_A - np.mean(imgF_A)) ** 2)))
        return corr1 + corr2

    @classmethod
    def VIFF(cls, image_F, image_A, image_B):
        cls.input_check(image_F, image_A, image_B)
        return cls.compare_viff(image_A, image_F) + cls.compare_viff(image_B, image_F)

    @classmethod
    def compare_viff(cls, ref, dist):  # VIF of a pair of pictures
        sigma_nsq = 2
        eps = 1e-10

        num = 0.0
        den = 0.0
        for scale in range(1, 5):

            N = 2 ** (4 - scale + 1) + 1
            sd = N / 5.0

            # Create a Gaussian kernel as MATLAB's
            m, n = [(ss - 1.) / 2. for ss in (N, N)]
            y, x = np.ogrid[-m:m + 1, -n:n + 1]
            h = np.exp(-(x * x + y * y) / (2. * sd * sd))
            h[h < np.finfo(h.dtype).eps * h.max()] = 0
            sumh = h.sum()
            if sumh != 0:
                win = h / sumh

            if scale > 1:
                ref = convolve2d(ref, np.rot90(win, 2), mode='valid')
                dist = convolve2d(dist, np.rot90(win, 2), mode='valid')
                ref = ref[::2, ::2]
                dist = dist[::2, ::2]

            mu1 = convolve2d(ref, np.rot90(win, 2), mode='valid')
            mu2 = convolve2d(dist, np.rot90(win, 2), mode='valid')
            mu1_sq = mu1 * mu1
            mu2_sq = mu2 * mu2
            mu1_mu2 = mu1 * mu2
            sigma1_sq = convolve2d(ref * ref, np.rot90(win, 2), mode='valid') - mu1_sq
            sigma2_sq = convolve2d(dist * dist, np.rot90(win, 2), mode='valid') - mu2_sq
            sigma12 = convolve2d(ref * dist, np.rot90(win, 2), mode='valid') - mu1_mu2

            sigma1_sq[sigma1_sq < 0] = 0
            sigma2_sq[sigma2_sq < 0] = 0

            g = sigma12 / (sigma1_sq + eps)
            sv_sq = sigma2_sq - g * sigma12

            g[sigma1_sq < eps] = 0
            sv_sq[sigma1_sq < eps] = sigma2_sq[sigma1_sq < eps]
            sigma1_sq[sigma1_sq < eps] = 0

            g[sigma2_sq < eps] = 0
            sv_sq[sigma2_sq < eps] = 0

            sv_sq[g < 0] = sigma2_sq[g < 0]
            g[g < 0] = 0
            sv_sq[sv_sq <= eps] = eps

            num += np.sum(np.log10(1 + g * g * sigma1_sq / (sv_sq + sigma_nsq)))
            den += np.sum(np.log10(1 + sigma1_sq / sigma_nsq))

        vifp = num / den

        if np.isnan(vifp):
            return 1.0
        else:
            return vifp

    @classmethod
    def Qabf(cls, image_F, image_A, image_B):
        cls.input_check(image_F, image_A, image_B)
        gA, aA = cls.Qabf_getArray(image_A)
        gB, aB = cls.Qabf_getArray(image_B)
        gF, aF = cls.Qabf_getArray(image_F)
        QAF = cls.Qabf_getQabf(aA, gA, aF, gF)
        QBF = cls.Qabf_getQabf(aB, gB, aF, gF)

        # compute the overall Qabf score
        deno = np.sum(gA + gB)
        nume = np.sum(np.multiply(QAF, gA) + np.multiply(QBF, gB))
        return nume / deno

    @classmethod
    def Qabf_getArray(cls, img):
        # Sobel operator
        h1 = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).astype(np.float32)
        h2 = np.array([[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]).astype(np.float32)
        h3 = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).astype(np.float32)

        SAx = convolve2d(img, h3, mode='same')
        SAy = convolve2d(img, h1, mode='same')
        gA = np.sqrt(np.multiply(SAx, SAx) + np.multiply(SAy, SAy))
        aA = np.zeros_like(img)
        aA[SAx == 0] = math.pi / 2
        aA[SAx != 0] = np.arctan(SAy[SAx != 0] / SAx[SAx != 0])
        return gA, aA

    @classmethod
    def Qabf_getQabf(cls, aA, gA, aF, gF):
        L = 1
        Tg = 0.9994
        kg = -15
        Dg = 0.5
        Ta = 0.9879
        ka = -22
        Da = 0.8
        GAF, AAF, QgAF, QaAF, QAF = np.zeros_like(aA), np.zeros_like(aA), np.zeros_like(aA), np.zeros_like(aA), np.zeros_like(aA)
        GAF[gA > gF] = gF[gA > gF] / gA[gA > gF]
        GAF[gA == gF] = gF[gA == gF]
        GAF[gA < gF] = gA[gA < gF] / gF[gA < gF]
        AAF = 1 - np.abs(aA - aF) / (math.pi / 2)
        QgAF = Tg / (1 + np.exp(kg * (GAF - Dg)))
        QaAF = Ta / (1 + np.exp(ka * (AAF - Da)))
        QAF = QgAF * QaAF
        return QAF

    @classmethod
    def SSIM(cls, image_F, image_A, image_B):
        cls.input_check(image_F, image_A, image_B)
        # note: recent skimage versions require an explicit data_range for float inputs
        return ssim(image_F, image_A) + ssim(image_F, image_B)
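

# A minimal usage sketch (not part of the original file): every metric takes
# 2-D grayscale arrays with the fused image first; the values here are random.
def _demo_evaluator():
    fused = np.random.rand(64, 64) * 255
    ir = np.random.rand(64, 64) * 255
    vis = np.random.rand(64, 64) * 255
    print(Evaluator.EN(fused), Evaluator.SD(fused), Evaluator.MI(fused, ir, vis))
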
def VIFF(image_F, image_A, image_B):
    # module-level variant: VIF of the fused image against both source images
    refA = image_A
    refB = image_B
    dist = image_F

    sigma_nsq = 2
    eps = 1e-10
    numA = 0.0
    denA = 0.0
    numB = 0.0
    denB = 0.0
    for scale in range(1, 5):
        N = 2 ** (4 - scale + 1) + 1
        sd = N / 5.0
        # Create a Gaussian kernel as MATLAB's
        m, n = [(ss - 1.) / 2. for ss in (N, N)]
        y, x = np.ogrid[-m:m + 1, -n:n + 1]
        h = np.exp(-(x * x + y * y) / (2. * sd * sd))
        h[h < np.finfo(h.dtype).eps * h.max()] = 0
        sumh = h.sum()
        if sumh != 0:
            win = h / sumh

        if scale > 1:
            refA = convolve2d(refA, np.rot90(win, 2), mode='valid')
            refB = convolve2d(refB, np.rot90(win, 2), mode='valid')
            dist = convolve2d(dist, np.rot90(win, 2), mode='valid')
            refA = refA[::2, ::2]
            refB = refB[::2, ::2]
            dist = dist[::2, ::2]

        mu1A = convolve2d(refA, np.rot90(win, 2), mode='valid')
        mu1B = convolve2d(refB, np.rot90(win, 2), mode='valid')
        mu2 = convolve2d(dist, np.rot90(win, 2), mode='valid')
        mu1_sq_A = mu1A * mu1A
        mu1_sq_B = mu1B * mu1B
        mu2_sq = mu2 * mu2
        mu1A_mu2 = mu1A * mu2
        mu1B_mu2 = mu1B * mu2
        sigma1A_sq = convolve2d(refA * refA, np.rot90(win, 2), mode='valid') - mu1_sq_A
        sigma1B_sq = convolve2d(refB * refB, np.rot90(win, 2), mode='valid') - mu1_sq_B
        sigma2_sq = convolve2d(dist * dist, np.rot90(win, 2), mode='valid') - mu2_sq
        sigma12_A = convolve2d(refA * dist, np.rot90(win, 2), mode='valid') - mu1A_mu2
        sigma12_B = convolve2d(refB * dist, np.rot90(win, 2), mode='valid') - mu1B_mu2

        sigma1A_sq[sigma1A_sq < 0] = 0
        sigma1B_sq[sigma1B_sq < 0] = 0
        sigma2_sq[sigma2_sq < 0] = 0

        gA = sigma12_A / (sigma1A_sq + eps)
        gB = sigma12_B / (sigma1B_sq + eps)
        sv_sq_A = sigma2_sq - gA * sigma12_A
        sv_sq_B = sigma2_sq - gB * sigma12_B

        gA[sigma1A_sq < eps] = 0
        gB[sigma1B_sq < eps] = 0
        sv_sq_A[sigma1A_sq < eps] = sigma2_sq[sigma1A_sq < eps]
        sv_sq_B[sigma1B_sq < eps] = sigma2_sq[sigma1B_sq < eps]
        sigma1A_sq[sigma1A_sq < eps] = 0
        sigma1B_sq[sigma1B_sq < eps] = 0

        gA[sigma2_sq < eps] = 0
        gB[sigma2_sq < eps] = 0
        sv_sq_A[sigma2_sq < eps] = 0
        sv_sq_B[sigma2_sq < eps] = 0

        sv_sq_A[gA < 0] = sigma2_sq[gA < 0]
        sv_sq_B[gB < 0] = sigma2_sq[gB < 0]
        gA[gA < 0] = 0
        gB[gB < 0] = 0
        sv_sq_A[sv_sq_A <= eps] = eps
        sv_sq_B[sv_sq_B <= eps] = eps

        numA += np.sum(np.log10(1 + gA * gA * sigma1A_sq / (sv_sq_A + sigma_nsq)))
        numB += np.sum(np.log10(1 + gB * gB * sigma1B_sq / (sv_sq_B + sigma_nsq)))
        denA += np.sum(np.log10(1 + sigma1A_sq / sigma_nsq))
        denB += np.sum(np.log10(1 + sigma1B_sq / sigma_nsq))

    vifpA = numA / denA
    vifpB = numB / denB

    if np.isnan(vifpA):
        vifpA = 1
    if np.isnan(vifpB):
        vifpB = 1
    return vifpA + vifpB
BIN
utils/__pycache__/Evaluator.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/dataset.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/dataset_MIF1.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/dataset_MIF4.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/img_read_save.cpython-38.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/loss.cpython-38.pyc
Normal file
Binary file not shown.
22
utils/dataset.py
Normal file
@@ -0,0 +1,22 @@
import torch.utils.data as Data
import h5py
import numpy as np
import torch


class H5Dataset(Data.Dataset):
    def __init__(self, h5file_path):
        self.h5file_path = h5file_path
        h5f = h5py.File(h5file_path, 'r')
        self.keys = list(h5f['ir_patchs'].keys())
        h5f.close()

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        # reopen per item so the dataset stays picklable for DataLoader workers
        h5f = h5py.File(self.h5file_path, 'r')
        key = self.keys[index]
        IR = np.array(h5f['ir_patchs'][key])
        VIS = np.array(h5f['vis_patchs'][key])
        h5f.close()
        return torch.Tensor(VIS), torch.Tensor(IR)
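

# A minimal usage sketch (not part of the original file); the .h5 path is the
# file produced by dataprocessing.py.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    loader = DataLoader(H5Dataset("data/MSRS_train_imgsize_128_stride_200.h5"),
                        batch_size=4, shuffle=True)
    vis, ir = next(iter(loader))
    print(vis.shape, ir.shape)  # torch.Size([4, 1, 128, 128]) each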
28
utils/img_read_save.py
Normal file
@@ -0,0 +1,28 @@
import numpy as np
import cv2
import os
from skimage.io import imsave


def image_read_cv2(path, mode='RGB'):
    img_BGR = cv2.imread(path).astype('float32')
    # img_BGR = cv2.imread(path)
    # print(img_BGR)
    # if img_BGR is not None:
    #     img_BGR = img_BGR.astype('float32')
    # else:
    #     print("handle the case where the image failed to load")
    assert mode == 'RGB' or mode == 'GRAY' or mode == 'YCrCb', 'mode error'
    if mode == 'RGB':
        img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2RGB)
    elif mode == 'GRAY':
        img = np.round(cv2.cvtColor(img_BGR, cv2.COLOR_BGR2GRAY))
    elif mode == 'YCrCb':
        img = cv2.cvtColor(img_BGR, cv2.COLOR_BGR2YCrCb)
    return img


def img_save(image, imagename, savepath):
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    # Gray_pic
    imsave(os.path.join(savepath, "{}.png".format(imagename)), image.astype(np.uint8))
113
utils/loss.py
Normal file
@@ -0,0 +1,113 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Fusionloss(nn.Module):
    def __init__(self):
        super(Fusionloss, self).__init__()
        self.sobelconv = Sobelxy()

    def forward(self, image_vis, image_ir, generate_img):
        # added: relative mean-intensity terms
        loss_rmi_v = relative_diff_loss(image_vis, generate_img)
        loss_rmi_i = relative_diff_loss(image_ir, generate_img)
        x_rmi_max = torch.max(loss_rmi_v, loss_rmi_i)
        loss_rmi = F.l1_loss(x_rmi_max, generate_img)
        # intensity term: the fused image should follow the per-pixel max
        image_y = image_vis[:, :1, :, :]
        x_in_max = torch.max(image_y, image_ir)
        loss_in = F.l1_loss(x_in_max, generate_img)
        # gradient term: the fused image should keep the stronger edges
        y_grad = self.sobelconv(image_y)
        ir_grad = self.sobelconv(image_ir)
        generate_img_grad = self.sobelconv(generate_img)
        x_grad_joint = torch.max(y_grad, ir_grad)
        loss_grad = F.l1_loss(x_grad_joint, generate_img_grad)
        # loss_total=loss_in+10*loss_grad
        # modified: add the relative-intensity term
        loss_total = loss_in + 10 * loss_grad + loss_rmi
        return loss_total, loss_in, loss_grad


class Sobelxy(nn.Module):
    def __init__(self):
        super(Sobelxy, self).__init__()
        kernelx = [[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]]
        kernely = [[1, 2, 1],
                   [0, 0, 0],
                   [-1, -2, -1]]
        kernelx = torch.FloatTensor(kernelx).unsqueeze(0).unsqueeze(0)
        kernely = torch.FloatTensor(kernely).unsqueeze(0).unsqueeze(0)
        self.weightx = nn.Parameter(data=kernelx, requires_grad=False).cuda()
        self.weighty = nn.Parameter(data=kernely, requires_grad=False).cuda()

    def forward(self, x):
        sobelx = F.conv2d(x, self.weightx, padding=1)
        sobely = F.conv2d(x, self.weighty, padding=1)
        return torch.abs(sobelx) + torch.abs(sobely)


def cc(img1, img2):
    """Correlation coefficient for (N, C, H, W) image; torch.float32 [0.,1.]."""
    eps = torch.finfo(torch.float32).eps
    N, C, _, _ = img1.shape
    img1 = img1.reshape(N, C, -1)
    img2 = img2.reshape(N, C, -1)
    img1 = img1 - img1.mean(dim=-1, keepdim=True)
    img2 = img2 - img2.mean(dim=-1, keepdim=True)
    cc = torch.sum(img1 * img2, dim=-1) / (eps + torch.sqrt(torch.sum(img1 **
                                           2, dim=-1)) * torch.sqrt(torch.sum(img2 ** 2, dim=-1)))
    cc = torch.clamp(cc, -1., 1.)
    return cc.mean()


# def dice_coeff(img1, img2):
#     smooth = 1.
#     num = img1.size(0)
#     m1 = img1.view(num, -1)  # Flatten
#     m2 = img2.view(num, -1)  # Flatten
#     intersection = (m1 * m2).sum()
#
#     return 1 - (2. * intersection + smooth) / (m1.sum() + m2.sum() + smooth)


# measures the relative difference in mean gray level between two images
def relative_diff_loss(img1, img2):
    # mean gray level of each image
    mean_intensity_img1 = torch.mean(img1)
    mean_intensity_img2 = torch.mean(img2)

    # compute the relative difference
    epsilon = 1e-10  # avoid division by zero
    relative_diff = abs((mean_intensity_img1 - mean_intensity_img2) / (mean_intensity_img1 + epsilon))

    return relative_diff
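

# A minimal check (not part of the original file): identical inputs give 0,
# and halving the brightness gives ~0.5, so the term penalizes only mean
# brightness drift, not structure.
def _demo_relative_diff():
    img = torch.rand(1, 1, 8, 8)
    print(relative_diff_loss(img, img))        # tensor(0.)
    print(relative_diff_loss(img, img * 0.5))  # ~tensor(0.5)
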

# mutual information (MI) loss
# def mutual_information_loss(img1, img2):
#     # entropies of X and Y
#     entropy_img1 = -torch.mean(torch.sum(F.softmax(img1, dim=-1) * F.log_softmax(img1, dim=-1), dim=-1))
#     entropy_img2 = -torch.mean(torch.sum(F.softmax(img2, dim=-1) * F.log_softmax(img2, dim=-1), dim=-1))
#
#     # joint entropy of X and Y
#     joint_entropy = -torch.mean(torch.sum(F.softmax(img1, dim=-1) * F.log_softmax(img2, dim=-1), dim=-1))
#
#     # mutual information loss
#     mutual_information = entropy_img1 + entropy_img2 - joint_entropy
#
#     return mutual_information


# # cosine similarity
# def cosine_similarity(img1, img2):
#     # Flatten the tensors
#     img1_flat = img1.view(-1)
#     img2_flat = img2.view(-1)
#
#     # Calculate cosine similarity
#     similarity = F.cosine_similarity(img1_flat, img2_flat, dim=0)
#
#     loss = torch.abs(similarity - 1)
#     return loss.item()  # Convert to Python float