添加数据处理脚本和H5数据集类

- 新增dataprocessing.py脚本,实现图像数据处理功能,包括文件读取、格式转换、低对比度筛选等
- 新增H5Dataset类,用于加载和访问H5格式的图像数据集
- 在项目中配置远程服务器部署和代码自动上传
- 添加IDE配置文件,包括项目路径、模块管理、代码检查等设置
This commit is contained in:
whaifree 2024-10-05 12:58:08 +08:00
commit 82acfa83dc
452 changed files with 1850 additions and 0 deletions

8
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

8
.idea/CDDFuse.iml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Remote Python 3.12.4 (sftp://star@192.168.50.108:22/home/star/anaconda3/bin/python)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

57
.idea/deployment.xml Normal file
View File

@ -0,0 +1,57 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="star@192.168.50.108:22 password (6)" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="star@192.168.50.108:22 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (6)">
<serverdata>
<mappings>
<mapping deploy="/home/star/whaiDir/CDDFuse" local="$PROJECT_DIR$" />
</mappings>
</serverdata>
</paths>
<paths name="v100">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
<option name="myAutoUpload" value="ALWAYS" />
</component>
</project>

View File

@ -0,0 +1,264 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<Languages>
<language minSize="226" name="Python" />
</Languages>
</inspection_tool>
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="245">
<item index="0" class="java.lang.String" itemvalue="numba" />
<item index="1" class="java.lang.String" itemvalue="tensorflow-estimator" />
<item index="2" class="java.lang.String" itemvalue="greenlet" />
<item index="3" class="java.lang.String" itemvalue="Babel" />
<item index="4" class="java.lang.String" itemvalue="scikit-learn" />
<item index="5" class="java.lang.String" itemvalue="testpath" />
<item index="6" class="java.lang.String" itemvalue="py" />
<item index="7" class="java.lang.String" itemvalue="gitdb" />
<item index="8" class="java.lang.String" itemvalue="torchvision" />
<item index="9" class="java.lang.String" itemvalue="patsy" />
<item index="10" class="java.lang.String" itemvalue="mccabe" />
<item index="11" class="java.lang.String" itemvalue="bleach" />
<item index="12" class="java.lang.String" itemvalue="lxml" />
<item index="13" class="java.lang.String" itemvalue="torchaudio" />
<item index="14" class="java.lang.String" itemvalue="jsonschema" />
<item index="15" class="java.lang.String" itemvalue="xlrd" />
<item index="16" class="java.lang.String" itemvalue="Werkzeug" />
<item index="17" class="java.lang.String" itemvalue="anaconda-project" />
<item index="18" class="java.lang.String" itemvalue="tensorboard-data-server" />
<item index="19" class="java.lang.String" itemvalue="typing-extensions" />
<item index="20" class="java.lang.String" itemvalue="click" />
<item index="21" class="java.lang.String" itemvalue="regex" />
<item index="22" class="java.lang.String" itemvalue="fastcache" />
<item index="23" class="java.lang.String" itemvalue="tensorboard" />
<item index="24" class="java.lang.String" itemvalue="imageio" />
<item index="25" class="java.lang.String" itemvalue="pytest-remotedata" />
<item index="26" class="java.lang.String" itemvalue="matplotlib" />
<item index="27" class="java.lang.String" itemvalue="idna" />
<item index="28" class="java.lang.String" itemvalue="Bottleneck" />
<item index="29" class="java.lang.String" itemvalue="rsa" />
<item index="30" class="java.lang.String" itemvalue="networkx" />
<item index="31" class="java.lang.String" itemvalue="pycurl" />
<item index="32" class="java.lang.String" itemvalue="smmap" />
<item index="33" class="java.lang.String" itemvalue="pluggy" />
<item index="34" class="java.lang.String" itemvalue="cffi" />
<item index="35" class="java.lang.String" itemvalue="pep8" />
<item index="36" class="java.lang.String" itemvalue="numpy" />
<item index="37" class="java.lang.String" itemvalue="jdcal" />
<item index="38" class="java.lang.String" itemvalue="alabaster" />
<item index="39" class="java.lang.String" itemvalue="jupyter" />
<item index="40" class="java.lang.String" itemvalue="pyOpenSSL" />
<item index="41" class="java.lang.String" itemvalue="PyWavelets" />
<item index="42" class="java.lang.String" itemvalue="prompt-toolkit" />
<item index="43" class="java.lang.String" itemvalue="QtAwesome" />
<item index="44" class="java.lang.String" itemvalue="msgpack-python" />
<item index="45" class="java.lang.String" itemvalue="Flask-Cors" />
<item index="46" class="java.lang.String" itemvalue="glob2" />
<item index="47" class="java.lang.String" itemvalue="Send2Trash" />
<item index="48" class="java.lang.String" itemvalue="imagesize" />
<item index="49" class="java.lang.String" itemvalue="et-xmlfile" />
<item index="50" class="java.lang.String" itemvalue="pathlib2" />
<item index="51" class="java.lang.String" itemvalue="docker-pycreds" />
<item index="52" class="java.lang.String" itemvalue="importlib-resources" />
<item index="53" class="java.lang.String" itemvalue="pathtools" />
<item index="54" class="java.lang.String" itemvalue="spyder" />
<item index="55" class="java.lang.String" itemvalue="pylint" />
<item index="56" class="java.lang.String" itemvalue="statsmodels" />
<item index="57" class="java.lang.String" itemvalue="tensorboardX" />
<item index="58" class="java.lang.String" itemvalue="isort" />
<item index="59" class="java.lang.String" itemvalue="ruamel_yaml" />
<item index="60" class="java.lang.String" itemvalue="pytz" />
<item index="61" class="java.lang.String" itemvalue="unicodecsv" />
<item index="62" class="java.lang.String" itemvalue="pytest-astropy" />
<item index="63" class="java.lang.String" itemvalue="traitlets" />
<item index="64" class="java.lang.String" itemvalue="absl-py" />
<item index="65" class="java.lang.String" itemvalue="protobuf" />
<item index="66" class="java.lang.String" itemvalue="nltk" />
<item index="67" class="java.lang.String" itemvalue="partd" />
<item index="68" class="java.lang.String" itemvalue="promise" />
<item index="69" class="java.lang.String" itemvalue="gast" />
<item index="70" class="java.lang.String" itemvalue="filelock" />
<item index="71" class="java.lang.String" itemvalue="numpydoc" />
<item index="72" class="java.lang.String" itemvalue="pyzmq" />
<item index="73" class="java.lang.String" itemvalue="oauthlib" />
<item index="74" class="java.lang.String" itemvalue="astropy" />
<item index="75" class="java.lang.String" itemvalue="keras" />
<item index="76" class="java.lang.String" itemvalue="entrypoints" />
<item index="77" class="java.lang.String" itemvalue="bkcharts" />
<item index="78" class="java.lang.String" itemvalue="pyparsing" />
<item index="79" class="java.lang.String" itemvalue="munch" />
<item index="80" class="java.lang.String" itemvalue="sphinxcontrib-websupport" />
<item index="81" class="java.lang.String" itemvalue="beautifulsoup4" />
<item index="82" class="java.lang.String" itemvalue="path.py" />
<item index="83" class="java.lang.String" itemvalue="clyent" />
<item index="84" class="java.lang.String" itemvalue="navigator-updater" />
<item index="85" class="java.lang.String" itemvalue="tifffile" />
<item index="86" class="java.lang.String" itemvalue="cryptography" />
<item index="87" class="java.lang.String" itemvalue="pygdal" />
<item index="88" class="java.lang.String" itemvalue="fastrlock" />
<item index="89" class="java.lang.String" itemvalue="widgetsnbextension" />
<item index="90" class="java.lang.String" itemvalue="multipledispatch" />
<item index="91" class="java.lang.String" itemvalue="numexpr" />
<item index="92" class="java.lang.String" itemvalue="jupyter-core" />
<item index="93" class="java.lang.String" itemvalue="ipython_genutils" />
<item index="94" class="java.lang.String" itemvalue="yapf" />
<item index="95" class="java.lang.String" itemvalue="rope" />
<item index="96" class="java.lang.String" itemvalue="wcwidth" />
<item index="97" class="java.lang.String" itemvalue="cupy-cuda110" />
<item index="98" class="java.lang.String" itemvalue="llvmlite" />
<item index="99" class="java.lang.String" itemvalue="Jinja2" />
<item index="100" class="java.lang.String" itemvalue="pycrypto" />
<item index="101" class="java.lang.String" itemvalue="Keras-Preprocessing" />
<item index="102" class="java.lang.String" itemvalue="ptflops" />
<item index="103" class="java.lang.String" itemvalue="cupy-cuda111" />
<item index="104" class="java.lang.String" itemvalue="cupy-cuda114" />
<item index="105" class="java.lang.String" itemvalue="some-package" />
<item index="106" class="java.lang.String" itemvalue="wandb" />
<item index="107" class="java.lang.String" itemvalue="netaddr" />
<item index="108" class="java.lang.String" itemvalue="sortedcollections" />
<item index="109" class="java.lang.String" itemvalue="six" />
<item index="110" class="java.lang.String" itemvalue="timm" />
<item index="111" class="java.lang.String" itemvalue="pyflakes" />
<item index="112" class="java.lang.String" itemvalue="asn1crypto" />
<item index="113" class="java.lang.String" itemvalue="parso" />
<item index="114" class="java.lang.String" itemvalue="pytest-doctestplus" />
<item index="115" class="java.lang.String" itemvalue="ipython" />
<item index="116" class="java.lang.String" itemvalue="xlwt" />
<item index="117" class="java.lang.String" itemvalue="packaging" />
<item index="118" class="java.lang.String" itemvalue="chardet" />
<item index="119" class="java.lang.String" itemvalue="jupyterlab-launcher" />
<item index="120" class="java.lang.String" itemvalue="click-plugins" />
<item index="121" class="java.lang.String" itemvalue="PyYAML" />
<item index="122" class="java.lang.String" itemvalue="pickleshare" />
<item index="123" class="java.lang.String" itemvalue="pycparser" />
<item index="124" class="java.lang.String" itemvalue="pyasn1-modules" />
<item index="125" class="java.lang.String" itemvalue="tables" />
<item index="126" class="java.lang.String" itemvalue="Pygments" />
<item index="127" class="java.lang.String" itemvalue="sentry-sdk" />
<item index="128" class="java.lang.String" itemvalue="docutils" />
<item index="129" class="java.lang.String" itemvalue="gevent" />
<item index="130" class="java.lang.String" itemvalue="shortuuid" />
<item index="131" class="java.lang.String" itemvalue="qtconsole" />
<item index="132" class="java.lang.String" itemvalue="terminado" />
<item index="133" class="java.lang.String" itemvalue="GitPython" />
<item index="134" class="java.lang.String" itemvalue="distributed" />
<item index="135" class="java.lang.String" itemvalue="jupyter-client" />
<item index="136" class="java.lang.String" itemvalue="pexpect" />
<item index="137" class="java.lang.String" itemvalue="ipykernel" />
<item index="138" class="java.lang.String" itemvalue="nbconvert" />
<item index="139" class="java.lang.String" itemvalue="attrs" />
<item index="140" class="java.lang.String" itemvalue="psutil" />
<item index="141" class="java.lang.String" itemvalue="simplejson" />
<item index="142" class="java.lang.String" itemvalue="jedi" />
<item index="143" class="java.lang.String" itemvalue="flatbuffers" />
<item index="144" class="java.lang.String" itemvalue="cytoolz" />
<item index="145" class="java.lang.String" itemvalue="odo" />
<item index="146" class="java.lang.String" itemvalue="decorator" />
<item index="147" class="java.lang.String" itemvalue="pandocfilters" />
<item index="148" class="java.lang.String" itemvalue="backports.shutil-get-terminal-size" />
<item index="149" class="java.lang.String" itemvalue="pycodestyle" />
<item index="150" class="java.lang.String" itemvalue="pycosat" />
<item index="151" class="java.lang.String" itemvalue="pyasn1" />
<item index="152" class="java.lang.String" itemvalue="requests" />
<item index="153" class="java.lang.String" itemvalue="bitarray" />
<item index="154" class="java.lang.String" itemvalue="kornia" />
<item index="155" class="java.lang.String" itemvalue="mkl-fft" />
<item index="156" class="java.lang.String" itemvalue="tensorflow" />
<item index="157" class="java.lang.String" itemvalue="XlsxWriter" />
<item index="158" class="java.lang.String" itemvalue="seaborn" />
<item index="159" class="java.lang.String" itemvalue="tensorboard-plugin-wit" />
<item index="160" class="java.lang.String" itemvalue="blaze" />
<item index="161" class="java.lang.String" itemvalue="zipp" />
<item index="162" class="java.lang.String" itemvalue="pkginfo" />
<item index="163" class="java.lang.String" itemvalue="cached-property" />
<item index="164" class="java.lang.String" itemvalue="torchstat" />
<item index="165" class="java.lang.String" itemvalue="datashape" />
<item index="166" class="java.lang.String" itemvalue="itsdangerous" />
<item index="167" class="java.lang.String" itemvalue="ipywidgets" />
<item index="168" class="java.lang.String" itemvalue="scipy" />
<item index="169" class="java.lang.String" itemvalue="thop" />
<item index="170" class="java.lang.String" itemvalue="tornado" />
<item index="171" class="java.lang.String" itemvalue="google-auth-oauthlib" />
<item index="172" class="java.lang.String" itemvalue="opencv-python" />
<item index="173" class="java.lang.String" itemvalue="torch" />
<item index="174" class="java.lang.String" itemvalue="singledispatch" />
<item index="175" class="java.lang.String" itemvalue="sortedcontainers" />
<item index="176" class="java.lang.String" itemvalue="mistune" />
<item index="177" class="java.lang.String" itemvalue="pandas" />
<item index="178" class="java.lang.String" itemvalue="termcolor" />
<item index="179" class="java.lang.String" itemvalue="clang" />
<item index="180" class="java.lang.String" itemvalue="toolz" />
<item index="181" class="java.lang.String" itemvalue="Sphinx" />
<item index="182" class="java.lang.String" itemvalue="mpmath" />
<item index="183" class="java.lang.String" itemvalue="jupyter-console" />
<item index="184" class="java.lang.String" itemvalue="bokeh" />
<item index="185" class="java.lang.String" itemvalue="cachetools" />
<item index="186" class="java.lang.String" itemvalue="gmpy2" />
<item index="187" class="java.lang.String" itemvalue="setproctitle" />
<item index="188" class="java.lang.String" itemvalue="webencodings" />
<item index="189" class="java.lang.String" itemvalue="html5lib" />
<item index="190" class="java.lang.String" itemvalue="colorlog" />
<item index="191" class="java.lang.String" itemvalue="python-dateutil" />
<item index="192" class="java.lang.String" itemvalue="QtPy" />
<item index="193" class="java.lang.String" itemvalue="astroid" />
<item index="194" class="java.lang.String" itemvalue="cycler" />
<item index="195" class="java.lang.String" itemvalue="mkl-random" />
<item index="196" class="java.lang.String" itemvalue="pytest-arraydiff" />
<item index="197" class="java.lang.String" itemvalue="locket" />
<item index="198" class="java.lang.String" itemvalue="heapdict" />
<item index="199" class="java.lang.String" itemvalue="snowballstemmer" />
<item index="200" class="java.lang.String" itemvalue="contextlib2" />
<item index="201" class="java.lang.String" itemvalue="certifi" />
<item index="202" class="java.lang.String" itemvalue="Markdown" />
<item index="203" class="java.lang.String" itemvalue="sympy" />
<item index="204" class="java.lang.String" itemvalue="notebook" />
<item index="205" class="java.lang.String" itemvalue="pyodbc" />
<item index="206" class="java.lang.String" itemvalue="boto" />
<item index="207" class="java.lang.String" itemvalue="cligj" />
<item index="208" class="java.lang.String" itemvalue="h5py" />
<item index="209" class="java.lang.String" itemvalue="wrapt" />
<item index="210" class="java.lang.String" itemvalue="kiwisolver" />
<item index="211" class="java.lang.String" itemvalue="pytest-openfiles" />
<item index="212" class="java.lang.String" itemvalue="anaconda-client" />
<item index="213" class="java.lang.String" itemvalue="backcall" />
<item index="214" class="java.lang.String" itemvalue="PySocks" />
<item index="215" class="java.lang.String" itemvalue="charset-normalizer" />
<item index="216" class="java.lang.String" itemvalue="typing" />
<item index="217" class="java.lang.String" itemvalue="dask" />
<item index="218" class="java.lang.String" itemvalue="enum34" />
<item index="219" class="java.lang.String" itemvalue="torchsummary" />
<item index="220" class="java.lang.String" itemvalue="scikit-image" />
<item index="221" class="java.lang.String" itemvalue="ptyprocess" />
<item index="222" class="java.lang.String" itemvalue="more-itertools" />
<item index="223" class="java.lang.String" itemvalue="SQLAlchemy" />
<item index="224" class="java.lang.String" itemvalue="tblib" />
<item index="225" class="java.lang.String" itemvalue="cloudpickle" />
<item index="226" class="java.lang.String" itemvalue="importlib-metadata" />
<item index="227" class="java.lang.String" itemvalue="simplegeneric" />
<item index="228" class="java.lang.String" itemvalue="zict" />
<item index="229" class="java.lang.String" itemvalue="urllib3" />
<item index="230" class="java.lang.String" itemvalue="jupyterlab" />
<item index="231" class="java.lang.String" itemvalue="Cython" />
<item index="232" class="java.lang.String" itemvalue="Flask" />
<item index="233" class="java.lang.String" itemvalue="nose" />
<item index="234" class="java.lang.String" itemvalue="pytorch-msssim" />
<item index="235" class="java.lang.String" itemvalue="pytest" />
<item index="236" class="java.lang.String" itemvalue="nbformat" />
<item index="237" class="java.lang.String" itemvalue="matmul" />
<item index="238" class="java.lang.String" itemvalue="tqdm" />
<item index="239" class="java.lang.String" itemvalue="lazy-object-proxy" />
<item index="240" class="java.lang.String" itemvalue="colorama" />
<item index="241" class="java.lang.String" itemvalue="grpcio" />
<item index="242" class="java.lang.String" itemvalue="ply" />
<item index="243" class="java.lang.String" itemvalue="google-auth" />
<item index="244" class="java.lang.String" itemvalue="openpyxl" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

13
.idea/misc.xml Normal file
View File

@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="MavenImportPreferences">
<option name="generalSettings">
<MavenGeneralSettings>
<option name="localRepository" value="E:\maven\repository" />
<option name="showDialogWithAdvancedSettings" value="true" />
<option name="userSettingsFile" value="E:\maven\settings.xml" />
</MavenGeneralSettings>
</option>
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.12.4 (sftp://star@192.168.50.108:22/home/star/anaconda3/bin/python)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/CDDFuse.iml" filepath="$PROJECT_DIR$/.idea/CDDFuse.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

1
MSRS_train/readme.md Normal file
View File

@ -0,0 +1 @@
Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it here.

186
README.md Normal file
View File

@ -0,0 +1,186 @@
# CDDFuse
Codes for ***CDDFuse: Correlation-Driven Dual-Branch Feature Decomposition for Multi-Modality Image Fusion. (CVPR 2023)***
[Zixiang Zhao](https://zhaozixiang1228.github.io/), [Haowen Bai](), [Jiangshe Zhang](http://gr.xjtu.edu.cn/web/jszhang), [Yulun Zhang](https://yulunzhang.com/), [Shuang Xu](https://shuangxu96.github.io/), [Zudi Lin](https://zudi-lin.github.io/), [Radu Timofte](https://www.informatik.uni-wuerzburg.de/computervision/home/) and [Luc Van Gool](https://vision.ee.ethz.ch/people-details.OTAyMzM=.TGlzdC8zMjQ4LC0xOTcxNDY1MTc4.html).
-[*[Paper]*](https://openaccess.thecvf.com/content/CVPR2023/html/Zhao_CDDFuse_Correlation-Driven_Dual-Branch_Feature_Decomposition_for_Multi-Modality_Image_Fusion_CVPR_2023_paper.html)
-[*[ArXiv]*](https://arxiv.org/abs/2104.06977)
-[*[Supplementary Materials]*](https://openaccess.thecvf.com/content/CVPR2023/supplemental/Zhao_CDDFuse_Correlation-Driven_Dual-Branch_CVPR_2023_supplemental.pdf)
## Update
- [2023/6] Training codes and config files are public available.
- [2023/4] Release inference code for infrared-visible image fusion and medical image fusion.
## Citation
```
@InProceedings{Zhao_2023_CVPR,
author = {Zhao, Zixiang and Bai, Haowen and Zhang, Jiangshe and Zhang, Yulun and Xu, Shuang and Lin, Zudi and Timofte, Radu and Van Gool, Luc},
title = {CDDFuse: Correlation-Driven Dual-Branch Feature Decomposition for Multi-Modality Image Fusion},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2023},
pages = {5906-5916}
}
```
## Abstract
Multi-modality (MM) image fusion aims to render fused images that maintain the merits of different modalities, e.g., functional highlight and detailed textures. To tackle the challenge in modeling cross-modality features and decomposing desirable modality-specific and modality-shared features, we propose a novel Correlation-Driven feature Decomposition Fusion (CDDFuse) network. Firstly, CDDFuse uses Restormer blocks to extract cross-modality shallow features. We then introduce a dual-branch Transformer-CNN feature extractor with Lite Transformer (LT) blocks leveraging long-range attention to handle low-frequency global features and Invertible Neural Networks (INN) blocks focusing on extracting high-frequency local information. A correlation-driven loss is further proposed to make the low-frequency features correlated while the high-frequency features uncorrelated based on the embedded information. Then, the LT-based global fusion and INN-based local fusion layers output the fused image. Extensive experiments demonstrate that our CDDFuse achieves promising results in multiple fusion tasks, including infrared-visible image fusion and medical image fusion. We also show that CDDFuse can boost the performance in downstream infrared-visible semantic segmentation and object detection in a unified benchmark.
## 🌐 Usage
### ⚙ Network Architecture
Our CDDFuse is implemented in ``net.py``.
### 🏊 Training
**1. Virtual Environment**
```
# create virtual environment
conda create -n cddfuse python=3.8.10
conda activate cddfuse
# select pytorch version yourself
# install cddfuse requirements
pip install -r requirements.txt
```
**2. Data Preparation**
Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it in the folder ``'./MSRS_train/'``.
**3. Pre-Processing**
Run
```
python dataprocessing.py
```
and the processed training dataset is in ``'./data/MSRS_train_imgsize_128_stride_200.h5'``.
**4. CDDFuse Training**
Run
```
python train.py
```
and the trained model is available in ``'./models/'``.
### 🏄 Testing
**1. Pretrained models**
Pretrained models are available in ``'./models/CDDFuse_IVF.pth'`` and ``'./models/CDDFuse_MIF.pth'``, which are responsible for the Infrared-Visible Fusion (IVF) and Medical Image Fusion (MIF) tasks, respectively.
**2. Test datasets**
The test datasets used in the paper have been stored in ``'./test_img/RoadScene'``, ``'./test_img/TNO'`` for IVF, ``'./test_img/MRI_CT'``, ``'./test_img/MRI_PET'`` and ``'./test_img/MRI_SPECT'`` for MIF.
Unfortunately, since the size of **MSRS dataset** for IVF is 500+MB, we can not upload it for exhibition. It can be downloaded via [this link](https://github.com/Linfeng-Tang/MSRS). The other datasets contain all the test images.
**3. Results in Our Paper**
If you want to infer with our CDDFuse and obtain the fusion results in our paper, please run
```
python test_IVF.py
```
for Infrared-Visible Fusion and
```
python test_MIF.py
```
for Medical Image Fusion.
The testing results will be printed in the terminal.
The output for ``'test_IVF.py'`` is:
```
================================================================================
The test result of TNO :
EN SD SF MI SCD VIF Qabf SSIM
CDDFuse 7.12 46.0 13.15 2.19 1.76 0.77 0.54 1.03
================================================================================
================================================================================
The test result of RoadScene :
EN SD SF MI SCD VIF Qabf SSIM
CDDFuse 7.44 54.67 16.36 2.3 1.81 0.69 0.52 0.98
================================================================================
```
which can match the results in Table 1 in our original paper.
The output for ``'test_MIF.py'`` is:
```
================================================================================
The test result of MRI_CT :
EN SD SF MI SCD VIF Qabf SSIM
CDDFuse_IVF 4.83 88.59 33.83 2.24 1.74 0.5 0.59 1.31
CDDFuse_MIF 4.88 79.17 38.14 2.61 1.41 0.61 0.68 1.34
================================================================================
================================================================================
The test result of MRI_PET :
EN SD SF MI SCD VIF Qabf SSIM
CDDFuse_IVF 4.23 81.69 28.04 1.87 1.82 0.66 0.65 1.46
CDDFuse_MIF 4.22 70.74 29.57 2.03 1.69 0.71 0.71 1.49
================================================================================
================================================================================
The test result of MRI_SPECT :
EN SD SF MI SCD VIF Qabf SSIM
CDDFuse_IVF 3.91 71.81 20.66 1.9 1.87 0.65 0.68 1.45
CDDFuse_MIF 3.9 58.31 20.87 2.49 1.35 0.97 0.78 1.48
================================================================================
```
which can match the results in Table 5 in our original paper.
## 🙌 CDDFuse
### Illustration of our CDDFuse model.
<img src="image//Workflow.png" width="90%" align=center />
### Qualitative fusion results.
<img src="image//IVF1.png" width="90%" align=center />
<img src="image//IVF2.png" width="90%" align=center />
<img src="image//MIF.png" width="60%" align=center />
### Quantitative fusion results.
Infrared-Visible Image Fusion
<img src="image//Quantitative_IVF.png" width="60%" align=center />
Medical Image Fusion
<img src="image//Quantitative_MIF.png" width="60%" align=center />
MM detection
<img src="image//MMDet.png" width="60%" align=center />
MM segmentation
<img src="image//MMSeg.png" width="60%" align=center />
## 📖 Related Work
- Zixiang Zhao, Haowen Bai, Jiangshe Zhang, Yulun Zhang, Kai Zhang, Shuang Xu, Dongdong Chen, Radu Timofte, Luc Van Gool. *Equivariant Multi-Modality Image Fusion.* **CVPR 2024**, https://arxiv.org/abs/2305.11443
- Zixiang Zhao, Haowen Bai, Yuanzhi Zhu, Jiangshe Zhang, Shuang Xu, Yulun Zhang, Kai Zhang, Deyu Meng, Radu Timofte, Luc Van Gool.
*DDFM: Denoising Diffusion Model for Multi-Modality Image Fusion.* **ICCV 2023 (Oral)**, https://arxiv.org/abs/2303.06840
- Zixiang Zhao, Shuang Xu, Chunxia Zhang, Junmin Liu, Jiangshe Zhang and Pengfei Li. *DIDFuse: Deep Image Decomposition for Infrared and Visible Image Fusion.* **IJCAI 2020**, https://www.ijcai.org/Proceedings/2020/135.
- Zixiang Zhao, Shuang Xu, Jiangshe Zhang, Chengyang Liang, Chunxia Zhang and Junmin Liu. *Efficient and Model-Based Infrared and Visible Image Fusion via Algorithm Unrolling.* **IEEE Transactions on Circuits and Systems for Video Technology 2021**, https://ieeexplore.ieee.org/document/9416456.
- Zixiang Zhao, Jiangshe Zhang, Haowen Bai, Yicheng Wang, Yukun Cui, Lilun Deng, Kai Sun, Chunxia Zhang, Junmin Liu, Shuang Xu. *Deep Convolutional Sparse Coding Networks for Interpretable Image Fusion.* **CVPR Workshop 2023**. https://robustart.github.io/long_paper/26.pdf.
- Zixiang Zhao, Shuang Xu, Chunxia Zhang, Junmin Liu, Jiangshe Zhang. *Bayesian fusion for infrared and visible images.* **Signal Processing**, https://doi.org/10.1016/j.sigpro.2020.107734.

93
dataprocessing.py Normal file
View File

@ -0,0 +1,93 @@
import os
import h5py
import numpy as np
from tqdm import tqdm
from skimage.io import imread
def get_img_file(file_name):
imagelist = []
for parent, dirnames, filenames in os.walk(file_name):
for filename in filenames:
if filename.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff', '.npy')):
imagelist.append(os.path.join(parent, filename))
return imagelist
def rgb2y(img):
y = img[0:1, :, :] * 0.299000 + img[1:2, :, :] * 0.587000 + img[2:3, :, :] * 0.114000
return y
def Im2Patch(img, win, stride=1):
k = 0
endc = img.shape[0]
endw = img.shape[1]
endh = img.shape[2]
patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
TotalPatNum = patch.shape[1] * patch.shape[2]
Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
for i in range(win):
for j in range(win):
patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
k = k + 1
return Y.reshape([endc, win, win, TotalPatNum])
def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
upper_percentile=90):
"""Determine if an image is low contrast."""
limits = np.percentile(image, [lower_percentile, upper_percentile])
ratio = (limits[1] - limits[0]) / limits[1]
return ratio < fraction_threshold
data_name="MSRS_train"
img_size=128 #patch size
stride=200 #patch stride
IR_files = sorted(get_img_file(r"MSRS_train/ir"))
VIS_files = sorted(get_img_file(r"MSRS_train/vi"))
assert len(IR_files) == len(VIS_files)
h5f = h5py.File(os.path.join('.\\data',
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),
'w')
h5_ir = h5f.create_group('ir_patchs')
h5_vis = h5f.create_group('vis_patchs')
train_num=0
for i in tqdm(range(len(IR_files))):
I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2,0,1)/255. # [3, H, W] Uint8->float32
I_VIS = rgb2y(I_VIS) # [1, H, W] Float32
I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :]/255. # [1, H, W] Float32
# crop
I_IR_Patch_Group = Im2Patch(I_IR,img_size,stride)
I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride) # (3, 256, 256, 12)
for ii in range(I_IR_Patch_Group.shape[-1]):
bad_IR = is_low_contrast(I_IR_Patch_Group[0,:,:,ii])
bad_VIS = is_low_contrast(I_VIS_Patch_Group[0,:,:,ii])
# Determine if the contrast is low
if not (bad_IR or bad_VIS):
avl_IR= I_IR_Patch_Group[0,:,:,ii] # available IR
avl_VIS= I_VIS_Patch_Group[0,:,:,ii]
avl_IR=avl_IR[None,...]
avl_VIS=avl_VIS[None,...]
h5_ir.create_dataset(str(train_num), data=avl_IR,
dtype=avl_IR.dtype, shape=avl_IR.shape)
h5_vis.create_dataset(str(train_num), data=avl_VIS,
dtype=avl_VIS.dtype, shape=avl_VIS.shape)
train_num += 1
h5f.close()
with h5py.File(os.path.join('data',
data_name+'_imgsize_'+str(img_size)+"_stride_"+str(stride)+'.h5'),"r") as f:
for key in f.keys():
print(f[key], key, f[key].name)

BIN
image/IVF1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 868 KiB

BIN
image/IVF2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1008 KiB

BIN
image/MIF.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

BIN
image/MMDet.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 122 KiB

BIN
image/MMSeg.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 126 KiB

BIN
image/Quantitative_IVF.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 378 KiB

BIN
image/Quantitative_MIF.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 417 KiB

BIN
image/Workflow.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 249 KiB

BIN
models/CDDFuse_IVF.pth Normal file

Binary file not shown.

BIN
models/CDDFuse_MIF.pth Normal file

Binary file not shown.

403
net.py Normal file
View File

@ -0,0 +1,403 @@
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class AttentionBase(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,):
super(AttentionBase, self).__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv1 = nn.Conv2d(dim, dim*3, kernel_size=1, bias=qkv_bias)
self.qkv2 = nn.Conv2d(dim*3, dim*3, kernel_size=3, padding=1, bias=qkv_bias)
self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=qkv_bias)
def forward(self, x):
# [batch_size, num_patches + 1, total_embed_dim]
b, c, h, w = x.shape
qkv = self.qkv2(self.qkv1(x))
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
# transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
# @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w',
head=self.num_heads, h=h, w=w)
out = self.proj(out)
return out
class Mlp(nn.Module):
"""
MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self,
in_features,
hidden_features=None,
ffn_expansion_factor = 2,
bias = False):
super().__init__()
hidden_features = int(in_features*ffn_expansion_factor)
self.project_in = nn.Conv2d(
in_features, hidden_features*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3,
stride=1, padding=1, groups=hidden_features, bias=bias)
self.project_out = nn.Conv2d(
hidden_features, in_features, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
class BaseFeatureExtraction(nn.Module):
def __init__(self,
dim,
num_heads,
ffn_expansion_factor=1.,
qkv_bias=False,):
super(BaseFeatureExtraction, self).__init__()
self.norm1 = LayerNorm(dim, 'WithBias')
self.attn = AttentionBase(dim, num_heads=num_heads, qkv_bias=qkv_bias,)
self.norm2 = LayerNorm(dim, 'WithBias')
self.mlp = Mlp(in_features=dim,
ffn_expansion_factor=ffn_expansion_factor,)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class InvertedResidualBlock(nn.Module):
def __init__(self, inp, oup, expand_ratio):
super(InvertedResidualBlock, self).__init__()
hidden_dim = int(inp * expand_ratio)
self.bottleneckBlock = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# dw
nn.ReflectionPad2d(1),
nn.Conv2d(hidden_dim, hidden_dim, 3, groups=hidden_dim, bias=False),
# nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, bias=False),
# nn.BatchNorm2d(oup),
)
def forward(self, x):
return self.bottleneckBlock(x)
class DetailNode(nn.Module):
def __init__(self):
super(DetailNode, self).__init__()
# Scale is Ax + b, i.e. affine transformation
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
stride=1, padding=0, bias=True)
def separateFeature(self, x):
z1, z2 = x[:, :x.shape[1]//2], x[:, x.shape[1]//2:x.shape[1]]
return z1, z2
def forward(self, z1, z2):
z1, z2 = self.separateFeature(
self.shffleconv(torch.cat((z1, z2), dim=1)))
z2 = z2 + self.theta_phi(z1)
z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
return z1, z2
class DetailFeatureExtraction(nn.Module):
def __init__(self, num_layers=3):
super(DetailFeatureExtraction, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules)
def forward(self, x):
z1, z2 = x[:, :x.shape[1]//2], x[:, x.shape[1]//2:x.shape[1]]
for layer in self.net:
z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1)
# =============================================================================
# =============================================================================
import numbers
##########################################################################
## Layer Norm
def to_3d(x):
return rearrange(x, 'b c h w -> b (h w) c')
def to_4d(x, h, w):
return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
class BiasFree_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(BiasFree_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
sigma = x.var(-1, keepdim=True, unbiased=False)
return x / torch.sqrt(sigma+1e-5) * self.weight
class WithBias_LayerNorm(nn.Module):
def __init__(self, normalized_shape):
super(WithBias_LayerNorm, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
normalized_shape = torch.Size(normalized_shape)
assert len(normalized_shape) == 1
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.normalized_shape = normalized_shape
def forward(self, x):
mu = x.mean(-1, keepdim=True)
sigma = x.var(-1, keepdim=True, unbiased=False)
return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
class LayerNorm(nn.Module):
def __init__(self, dim, LayerNorm_type):
super(LayerNorm, self).__init__()
if LayerNorm_type == 'BiasFree':
self.body = BiasFree_LayerNorm(dim)
else:
self.body = WithBias_LayerNorm(dim)
def forward(self, x):
h, w = x.shape[-2:]
return to_4d(self.body(to_3d(x)), h, w)
##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FeedForward, self).__init__()
hidden_features = int(dim*ffn_expansion_factor)
self.project_in = nn.Conv2d(
dim, hidden_features*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3,
stride=1, padding=1, groups=hidden_features*2, bias=bias)
self.project_out = nn.Conv2d(
hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x1) * x2
x = self.project_out(x)
return x
##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
def __init__(self, dim, num_heads, bias):
super(Attention, self).__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(
dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)',
head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out = (attn @ v)
out = rearrange(out, 'b head c (h w) -> b (head c) h w',
head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
##########################################################################
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
super(TransformerBlock, self).__init__()
self.norm1 = LayerNorm(dim, LayerNorm_type)
self.attn = Attention(dim, num_heads, bias)
self.norm2 = LayerNorm(dim, LayerNorm_type)
self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
def __init__(self, in_c=3, embed_dim=48, bias=False):
super(OverlapPatchEmbed, self).__init__()
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
stride=1, padding=1, bias=bias)
def forward(self, x):
x = self.proj(x)
return x
class Restormer_Encoder(nn.Module):
def __init__(self,
inp_channels=1,
out_channels=1,
dim=64,
num_blocks=[4, 4],
heads=[8, 8, 8],
ffn_expansion_factor=2,
bias=False,
LayerNorm_type='WithBias',
):
super(Restormer_Encoder, self).__init__()
self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
self.baseFeature = BaseFeatureExtraction(dim=dim, num_heads = heads[2])
self.detailFeature = DetailFeatureExtraction()
def forward(self, inp_img):
inp_enc_level1 = self.patch_embed(inp_img)
out_enc_level1 = self.encoder_level1(inp_enc_level1)
base_feature = self.baseFeature(out_enc_level1)
detail_feature = self.detailFeature(out_enc_level1)
return base_feature, detail_feature, out_enc_level1
class Restormer_Decoder(nn.Module):
def __init__(self,
inp_channels=1,
out_channels=1,
dim=64,
num_blocks=[4, 4],
heads=[8, 8, 8],
ffn_expansion_factor=2,
bias=False,
LayerNorm_type='WithBias',
):
super(Restormer_Decoder, self).__init__()
self.reduce_channel = nn.Conv2d(int(dim*2), int(dim), kernel_size=1, bias=bias)
self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
self.output = nn.Sequential(
nn.Conv2d(int(dim), int(dim)//2, kernel_size=3,
stride=1, padding=1, bias=bias),
nn.LeakyReLU(),
nn.Conv2d(int(dim)//2, out_channels, kernel_size=3,
stride=1, padding=1, bias=bias),)
self.sigmoid = nn.Sigmoid()
def forward(self, inp_img, base_feature, detail_feature):
out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
out_enc_level0 = self.reduce_channel(out_enc_level0)
out_enc_level1 = self.encoder_level2(out_enc_level0)
if inp_img is not None:
out_enc_level1 = self.output(out_enc_level1) + inp_img
else:
out_enc_level1 = self.output(out_enc_level1)
return self.sigmoid(out_enc_level1), out_enc_level0
if __name__ == '__main__':
height = 128
width = 128
window_size = 8
modelE = Restormer_Encoder().cuda()
modelD = Restormer_Decoder().cuda()

10
requirements.txt Normal file
View File

@ -0,0 +1,10 @@
einops==0.4.1
kornia==0.2.0
numpy==1.21.5
opencv_python==4.5.3.56
scikit_image==0.19.2
scikit_learn==1.1.3
scipy==1.7.3
tensorboardX==2.5.1
timm==0.4.12
torch==1.8.1+cu111

80
test_IVF.py Normal file
View File

@ -0,0 +1,80 @@
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
from utils.Evaluator import Evaluator
import torch
import torch.nn as nn
from utils.img_read_save import img_save,image_read_cv2
import warnings
import logging
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path=r"models/CDDFuse_IVF.pth"
for dataset_name in ["TNO","RoadScene"]:
print("\n"*2+"="*80)
model_name="CDDFuse "
print("The test result of "+dataset_name+' :')
test_folder=os.path.join('test_img',dataset_name)
test_out_folder=os.path.join('test_result',dataset_name)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
Encoder.load_state_dict(torch.load(ckpt_path)['DIDF_Encoder'])
Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
Encoder.eval()
Decoder.eval()
BaseFuseLayer.eval()
DetailFuseLayer.eval()
with torch.no_grad():
for img_name in os.listdir(os.path.join(test_folder,"ir")):
data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
data_VIS = image_read_cv2(os.path.join(test_folder,"vi",img_name), mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS)
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
feature_V_B, feature_V_D, feature_V = Encoder(data_VIS)
feature_I_B, feature_I_D, feature_I = Encoder(data_IR)
feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse))
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
img_save(fi, img_name.split(sep='.')[0], test_out_folder)
eval_folder=test_out_folder
ori_img_folder=test_folder
metric_result = np.zeros((8))
for img_name in os.listdir(os.path.join(ori_img_folder,"ir")):
ir = image_read_cv2(os.path.join(ori_img_folder,"ir", img_name), 'GRAY')
vi = image_read_cv2(os.path.join(ori_img_folder,"vi", img_name), 'GRAY')
fi = image_read_cv2(os.path.join(eval_folder, img_name.split('.')[0]+".png"), 'GRAY')
metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi)
, Evaluator.SF(fi), Evaluator.MI(fi, ir, vi)
, Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi)
, Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])
metric_result /= len(os.listdir(eval_folder))
print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
print(model_name+'\t'+str(np.round(metric_result[0], 2))+'\t'
+str(np.round(metric_result[1], 2))+'\t'
+str(np.round(metric_result[2], 2))+'\t'
+str(np.round(metric_result[3], 2))+'\t'
+str(np.round(metric_result[4], 2))+'\t'
+str(np.round(metric_result[5], 2))+'\t'
+str(np.round(metric_result[6], 2))+'\t'
+str(np.round(metric_result[7], 2))
)
print("="*80)

86
test_MIF.py Normal file
View File

@ -0,0 +1,86 @@
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
from utils.Evaluator import Evaluator
import torch
import torch.nn as nn
from utils.img_read_save import img_save,image_read_cv2
import warnings
import logging
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)
import cv2
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
CDDFuse_path=r"models/CDDFuse_IVF.pth"
CDDFuse_MIF_path=r"models/CDDFuse_MIF.pth"
for dataset_name in ["MRI_CT","MRI_PET","MRI_SPECT"]:
print("\n"*2+"="*80)
print("The test result of "+dataset_name+" :")
print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
for ckpt_path in [CDDFuse_path,CDDFuse_MIF_path]:
model_name=ckpt_path.split('/')[-1].split('.')[0]
test_folder=os.path.join('test_img',dataset_name)
test_out_folder=os.path.join('test_result',dataset_name)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
Encoder.load_state_dict(torch.load(ckpt_path)['DIDF_Encoder'])
Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
Encoder.eval()
Decoder.eval()
BaseFuseLayer.eval()
DetailFuseLayer.eval()
with torch.no_grad():
for img_name in os.listdir(os.path.join(test_folder,dataset_name.split('_')[0])):
data_IR=image_read_cv2(os.path.join(test_folder,dataset_name.split('_')[1],img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
data_VIS = image_read_cv2(os.path.join(test_folder,dataset_name.split('_')[0],img_name), mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS)
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
feature_V_B, feature_V_D, feature_V = Encoder(data_VIS)
feature_I_B, feature_I_D, feature_I = Encoder(data_IR)
feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
if ckpt_path==CDDFuse_path:
data_Fuse, _ = Decoder(data_IR+data_VIS, feature_F_B, feature_F_D)
else:
data_Fuse, _ = Decoder(None, feature_F_B, feature_F_D)
data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse))
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
img_save(fi, img_name.split(sep='.')[0], test_out_folder)
eval_folder=test_out_folder
ori_img_folder=test_folder
metric_result = np.zeros((8))
for img_name in os.listdir(os.path.join(ori_img_folder,dataset_name.split('_')[0])):
ir = image_read_cv2(os.path.join(ori_img_folder,dataset_name.split('_')[1], img_name), 'GRAY')
vi = image_read_cv2(os.path.join(ori_img_folder,dataset_name.split('_')[0], img_name), 'GRAY')
fi = image_read_cv2(os.path.join(eval_folder, img_name.split('.')[0]+".png"), 'GRAY')
metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi)
, Evaluator.SF(fi), Evaluator.MI(fi, ir, vi)
, Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi)
, Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])
metric_result /= len(os.listdir(eval_folder))
print(model_name+'\t'+str(np.round(metric_result[0], 2))+'\t'
+str(np.round(metric_result[1], 2))+'\t'
+str(np.round(metric_result[2], 2))+'\t'
+str(np.round(metric_result[3], 2))+'\t'
+str(np.round(metric_result[4], 2))+'\t'
+str(np.round(metric_result[5], 2))+'\t'
+str(np.round(metric_result[6], 2))+'\t'
+str(np.round(metric_result[7], 2))
)
print("="*80)

BIN
test_img/MRI_CT/CT/11.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_CT/CT/12.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_CT/CT/13.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_CT/CT/14.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_CT/CT/15.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_CT/CT/16.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_CT/CT/17.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

BIN
test_img/MRI_CT/CT/18.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

BIN
test_img/MRI_CT/CT/19.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

BIN
test_img/MRI_CT/CT/20.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

BIN
test_img/MRI_CT/CT/21.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

BIN
test_img/MRI_CT/CT/22.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_CT/CT/23.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

BIN
test_img/MRI_CT/CT/24.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

BIN
test_img/MRI_CT/CT/25.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

BIN
test_img/MRI_CT/CT/26.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_CT/CT/27.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

BIN
test_img/MRI_CT/CT/28.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

BIN
test_img/MRI_CT/CT/29.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

BIN
test_img/MRI_CT/CT/30.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

BIN
test_img/MRI_CT/CT/31.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

BIN
test_img/MRI_CT/MRI/11.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

BIN
test_img/MRI_CT/MRI/12.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

BIN
test_img/MRI_CT/MRI/13.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 60 KiB

BIN
test_img/MRI_CT/MRI/14.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

BIN
test_img/MRI_CT/MRI/15.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 63 KiB

BIN
test_img/MRI_CT/MRI/16.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

BIN
test_img/MRI_CT/MRI/17.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 60 KiB

BIN
test_img/MRI_CT/MRI/18.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 57 KiB

BIN
test_img/MRI_CT/MRI/19.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

BIN
test_img/MRI_CT/MRI/20.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

BIN
test_img/MRI_CT/MRI/21.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

BIN
test_img/MRI_CT/MRI/22.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

BIN
test_img/MRI_CT/MRI/23.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

BIN
test_img/MRI_CT/MRI/24.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

BIN
test_img/MRI_CT/MRI/25.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

BIN
test_img/MRI_CT/MRI/26.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

BIN
test_img/MRI_CT/MRI/27.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

BIN
test_img/MRI_CT/MRI/28.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

BIN
test_img/MRI_CT/MRI/29.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

BIN
test_img/MRI_CT/MRI/30.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

BIN
test_img/MRI_CT/MRI/31.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

BIN
test_img/MRI_PET/MRI/11.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/12.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/13.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

BIN
test_img/MRI_PET/MRI/14.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

BIN
test_img/MRI_PET/MRI/15.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/16.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/17.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/18.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/19.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

BIN
test_img/MRI_PET/MRI/20.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/21.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/22.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_PET/MRI/23.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

BIN
test_img/MRI_PET/MRI/24.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

BIN
test_img/MRI_PET/MRI/25.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

BIN
test_img/MRI_PET/MRI/26.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

BIN
test_img/MRI_PET/MRI/27.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

BIN
test_img/MRI_PET/MRI/28.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

BIN
test_img/MRI_PET/MRI/29.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
test_img/MRI_PET/MRI/30.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
test_img/MRI_PET/MRI/31.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

BIN
test_img/MRI_PET/MRI/32.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

BIN
test_img/MRI_PET/MRI/33.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

BIN
test_img/MRI_PET/MRI/34.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/35.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/36.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/37.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/38.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

BIN
test_img/MRI_PET/MRI/39.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

BIN
test_img/MRI_PET/MRI/40.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

BIN
test_img/MRI_PET/MRI/41.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

BIN
test_img/MRI_PET/MRI/42.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

BIN
test_img/MRI_PET/MRI/43.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

Some files were not shown because too many files have changed in this diff Show More