添加数据处理脚本和H5数据集类
- 新增dataprocessing.py脚本,实现图像数据处理功能,包括文件读取、格式转换、低对比度筛选等 - 新增H5Dataset类,用于加载和访问H5格式的图像数据集 - 在项目中配置远程服务器部署和代码自动上传 - 添加IDE配置文件,包括项目路径、模块管理、代码检查等设置
8
.idea/.gitignore
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
8
.idea/CDDFuse.iml
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="jdk" jdkName="Remote Python 3.12.4 (sftp://star@192.168.50.108:22/home/star/anaconda3/bin/python)" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
57
.idea/deployment.xml
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="PublishConfigData" autoUpload="Always" serverName="star@192.168.50.108:22 password (6)" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||||
|
<serverData>
|
||||||
|
<paths name="star@192.168.50.108:22 password">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="star@192.168.50.108:22 password (2)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="star@192.168.50.108:22 password (3)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="star@192.168.50.108:22 password (4)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="star@192.168.50.108:22 password (5)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="star@192.168.50.108:22 password (6)">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping deploy="/home/star/whaiDir/CDDFuse" local="$PROJECT_DIR$" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
<paths name="v100">
|
||||||
|
<serverdata>
|
||||||
|
<mappings>
|
||||||
|
<mapping local="$PROJECT_DIR$" web="/" />
|
||||||
|
</mappings>
|
||||||
|
</serverdata>
|
||||||
|
</paths>
|
||||||
|
</serverData>
|
||||||
|
<option name="myAutoUpload" value="ALWAYS" />
|
||||||
|
</component>
|
||||||
|
</project>
|
264
.idea/inspectionProfiles/Project_Default.xml
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||||
|
<Languages>
|
||||||
|
<language minSize="226" name="Python" />
|
||||||
|
</Languages>
|
||||||
|
</inspection_tool>
|
||||||
|
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||||
|
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredPackages">
|
||||||
|
<value>
|
||||||
|
<list size="245">
|
||||||
|
<item index="0" class="java.lang.String" itemvalue="numba" />
|
||||||
|
<item index="1" class="java.lang.String" itemvalue="tensorflow-estimator" />
|
||||||
|
<item index="2" class="java.lang.String" itemvalue="greenlet" />
|
||||||
|
<item index="3" class="java.lang.String" itemvalue="Babel" />
|
||||||
|
<item index="4" class="java.lang.String" itemvalue="scikit-learn" />
|
||||||
|
<item index="5" class="java.lang.String" itemvalue="testpath" />
|
||||||
|
<item index="6" class="java.lang.String" itemvalue="py" />
|
||||||
|
<item index="7" class="java.lang.String" itemvalue="gitdb" />
|
||||||
|
<item index="8" class="java.lang.String" itemvalue="torchvision" />
|
||||||
|
<item index="9" class="java.lang.String" itemvalue="patsy" />
|
||||||
|
<item index="10" class="java.lang.String" itemvalue="mccabe" />
|
||||||
|
<item index="11" class="java.lang.String" itemvalue="bleach" />
|
||||||
|
<item index="12" class="java.lang.String" itemvalue="lxml" />
|
||||||
|
<item index="13" class="java.lang.String" itemvalue="torchaudio" />
|
||||||
|
<item index="14" class="java.lang.String" itemvalue="jsonschema" />
|
||||||
|
<item index="15" class="java.lang.String" itemvalue="xlrd" />
|
||||||
|
<item index="16" class="java.lang.String" itemvalue="Werkzeug" />
|
||||||
|
<item index="17" class="java.lang.String" itemvalue="anaconda-project" />
|
||||||
|
<item index="18" class="java.lang.String" itemvalue="tensorboard-data-server" />
|
||||||
|
<item index="19" class="java.lang.String" itemvalue="typing-extensions" />
|
||||||
|
<item index="20" class="java.lang.String" itemvalue="click" />
|
||||||
|
<item index="21" class="java.lang.String" itemvalue="regex" />
|
||||||
|
<item index="22" class="java.lang.String" itemvalue="fastcache" />
|
||||||
|
<item index="23" class="java.lang.String" itemvalue="tensorboard" />
|
||||||
|
<item index="24" class="java.lang.String" itemvalue="imageio" />
|
||||||
|
<item index="25" class="java.lang.String" itemvalue="pytest-remotedata" />
|
||||||
|
<item index="26" class="java.lang.String" itemvalue="matplotlib" />
|
||||||
|
<item index="27" class="java.lang.String" itemvalue="idna" />
|
||||||
|
<item index="28" class="java.lang.String" itemvalue="Bottleneck" />
|
||||||
|
<item index="29" class="java.lang.String" itemvalue="rsa" />
|
||||||
|
<item index="30" class="java.lang.String" itemvalue="networkx" />
|
||||||
|
<item index="31" class="java.lang.String" itemvalue="pycurl" />
|
||||||
|
<item index="32" class="java.lang.String" itemvalue="smmap" />
|
||||||
|
<item index="33" class="java.lang.String" itemvalue="pluggy" />
|
||||||
|
<item index="34" class="java.lang.String" itemvalue="cffi" />
|
||||||
|
<item index="35" class="java.lang.String" itemvalue="pep8" />
|
||||||
|
<item index="36" class="java.lang.String" itemvalue="numpy" />
|
||||||
|
<item index="37" class="java.lang.String" itemvalue="jdcal" />
|
||||||
|
<item index="38" class="java.lang.String" itemvalue="alabaster" />
|
||||||
|
<item index="39" class="java.lang.String" itemvalue="jupyter" />
|
||||||
|
<item index="40" class="java.lang.String" itemvalue="pyOpenSSL" />
|
||||||
|
<item index="41" class="java.lang.String" itemvalue="PyWavelets" />
|
||||||
|
<item index="42" class="java.lang.String" itemvalue="prompt-toolkit" />
|
||||||
|
<item index="43" class="java.lang.String" itemvalue="QtAwesome" />
|
||||||
|
<item index="44" class="java.lang.String" itemvalue="msgpack-python" />
|
||||||
|
<item index="45" class="java.lang.String" itemvalue="Flask-Cors" />
|
||||||
|
<item index="46" class="java.lang.String" itemvalue="glob2" />
|
||||||
|
<item index="47" class="java.lang.String" itemvalue="Send2Trash" />
|
||||||
|
<item index="48" class="java.lang.String" itemvalue="imagesize" />
|
||||||
|
<item index="49" class="java.lang.String" itemvalue="et-xmlfile" />
|
||||||
|
<item index="50" class="java.lang.String" itemvalue="pathlib2" />
|
||||||
|
<item index="51" class="java.lang.String" itemvalue="docker-pycreds" />
|
||||||
|
<item index="52" class="java.lang.String" itemvalue="importlib-resources" />
|
||||||
|
<item index="53" class="java.lang.String" itemvalue="pathtools" />
|
||||||
|
<item index="54" class="java.lang.String" itemvalue="spyder" />
|
||||||
|
<item index="55" class="java.lang.String" itemvalue="pylint" />
|
||||||
|
<item index="56" class="java.lang.String" itemvalue="statsmodels" />
|
||||||
|
<item index="57" class="java.lang.String" itemvalue="tensorboardX" />
|
||||||
|
<item index="58" class="java.lang.String" itemvalue="isort" />
|
||||||
|
<item index="59" class="java.lang.String" itemvalue="ruamel_yaml" />
|
||||||
|
<item index="60" class="java.lang.String" itemvalue="pytz" />
|
||||||
|
<item index="61" class="java.lang.String" itemvalue="unicodecsv" />
|
||||||
|
<item index="62" class="java.lang.String" itemvalue="pytest-astropy" />
|
||||||
|
<item index="63" class="java.lang.String" itemvalue="traitlets" />
|
||||||
|
<item index="64" class="java.lang.String" itemvalue="absl-py" />
|
||||||
|
<item index="65" class="java.lang.String" itemvalue="protobuf" />
|
||||||
|
<item index="66" class="java.lang.String" itemvalue="nltk" />
|
||||||
|
<item index="67" class="java.lang.String" itemvalue="partd" />
|
||||||
|
<item index="68" class="java.lang.String" itemvalue="promise" />
|
||||||
|
<item index="69" class="java.lang.String" itemvalue="gast" />
|
||||||
|
<item index="70" class="java.lang.String" itemvalue="filelock" />
|
||||||
|
<item index="71" class="java.lang.String" itemvalue="numpydoc" />
|
||||||
|
<item index="72" class="java.lang.String" itemvalue="pyzmq" />
|
||||||
|
<item index="73" class="java.lang.String" itemvalue="oauthlib" />
|
||||||
|
<item index="74" class="java.lang.String" itemvalue="astropy" />
|
||||||
|
<item index="75" class="java.lang.String" itemvalue="keras" />
|
||||||
|
<item index="76" class="java.lang.String" itemvalue="entrypoints" />
|
||||||
|
<item index="77" class="java.lang.String" itemvalue="bkcharts" />
|
||||||
|
<item index="78" class="java.lang.String" itemvalue="pyparsing" />
|
||||||
|
<item index="79" class="java.lang.String" itemvalue="munch" />
|
||||||
|
<item index="80" class="java.lang.String" itemvalue="sphinxcontrib-websupport" />
|
||||||
|
<item index="81" class="java.lang.String" itemvalue="beautifulsoup4" />
|
||||||
|
<item index="82" class="java.lang.String" itemvalue="path.py" />
|
||||||
|
<item index="83" class="java.lang.String" itemvalue="clyent" />
|
||||||
|
<item index="84" class="java.lang.String" itemvalue="navigator-updater" />
|
||||||
|
<item index="85" class="java.lang.String" itemvalue="tifffile" />
|
||||||
|
<item index="86" class="java.lang.String" itemvalue="cryptography" />
|
||||||
|
<item index="87" class="java.lang.String" itemvalue="pygdal" />
|
||||||
|
<item index="88" class="java.lang.String" itemvalue="fastrlock" />
|
||||||
|
<item index="89" class="java.lang.String" itemvalue="widgetsnbextension" />
|
||||||
|
<item index="90" class="java.lang.String" itemvalue="multipledispatch" />
|
||||||
|
<item index="91" class="java.lang.String" itemvalue="numexpr" />
|
||||||
|
<item index="92" class="java.lang.String" itemvalue="jupyter-core" />
|
||||||
|
<item index="93" class="java.lang.String" itemvalue="ipython_genutils" />
|
||||||
|
<item index="94" class="java.lang.String" itemvalue="yapf" />
|
||||||
|
<item index="95" class="java.lang.String" itemvalue="rope" />
|
||||||
|
<item index="96" class="java.lang.String" itemvalue="wcwidth" />
|
||||||
|
<item index="97" class="java.lang.String" itemvalue="cupy-cuda110" />
|
||||||
|
<item index="98" class="java.lang.String" itemvalue="llvmlite" />
|
||||||
|
<item index="99" class="java.lang.String" itemvalue="Jinja2" />
|
||||||
|
<item index="100" class="java.lang.String" itemvalue="pycrypto" />
|
||||||
|
<item index="101" class="java.lang.String" itemvalue="Keras-Preprocessing" />
|
||||||
|
<item index="102" class="java.lang.String" itemvalue="ptflops" />
|
||||||
|
<item index="103" class="java.lang.String" itemvalue="cupy-cuda111" />
|
||||||
|
<item index="104" class="java.lang.String" itemvalue="cupy-cuda114" />
|
||||||
|
<item index="105" class="java.lang.String" itemvalue="some-package" />
|
||||||
|
<item index="106" class="java.lang.String" itemvalue="wandb" />
|
||||||
|
<item index="107" class="java.lang.String" itemvalue="netaddr" />
|
||||||
|
<item index="108" class="java.lang.String" itemvalue="sortedcollections" />
|
||||||
|
<item index="109" class="java.lang.String" itemvalue="six" />
|
||||||
|
<item index="110" class="java.lang.String" itemvalue="timm" />
|
||||||
|
<item index="111" class="java.lang.String" itemvalue="pyflakes" />
|
||||||
|
<item index="112" class="java.lang.String" itemvalue="asn1crypto" />
|
||||||
|
<item index="113" class="java.lang.String" itemvalue="parso" />
|
||||||
|
<item index="114" class="java.lang.String" itemvalue="pytest-doctestplus" />
|
||||||
|
<item index="115" class="java.lang.String" itemvalue="ipython" />
|
||||||
|
<item index="116" class="java.lang.String" itemvalue="xlwt" />
|
||||||
|
<item index="117" class="java.lang.String" itemvalue="packaging" />
|
||||||
|
<item index="118" class="java.lang.String" itemvalue="chardet" />
|
||||||
|
<item index="119" class="java.lang.String" itemvalue="jupyterlab-launcher" />
|
||||||
|
<item index="120" class="java.lang.String" itemvalue="click-plugins" />
|
||||||
|
<item index="121" class="java.lang.String" itemvalue="PyYAML" />
|
||||||
|
<item index="122" class="java.lang.String" itemvalue="pickleshare" />
|
||||||
|
<item index="123" class="java.lang.String" itemvalue="pycparser" />
|
||||||
|
<item index="124" class="java.lang.String" itemvalue="pyasn1-modules" />
|
||||||
|
<item index="125" class="java.lang.String" itemvalue="tables" />
|
||||||
|
<item index="126" class="java.lang.String" itemvalue="Pygments" />
|
||||||
|
<item index="127" class="java.lang.String" itemvalue="sentry-sdk" />
|
||||||
|
<item index="128" class="java.lang.String" itemvalue="docutils" />
|
||||||
|
<item index="129" class="java.lang.String" itemvalue="gevent" />
|
||||||
|
<item index="130" class="java.lang.String" itemvalue="shortuuid" />
|
||||||
|
<item index="131" class="java.lang.String" itemvalue="qtconsole" />
|
||||||
|
<item index="132" class="java.lang.String" itemvalue="terminado" />
|
||||||
|
<item index="133" class="java.lang.String" itemvalue="GitPython" />
|
||||||
|
<item index="134" class="java.lang.String" itemvalue="distributed" />
|
||||||
|
<item index="135" class="java.lang.String" itemvalue="jupyter-client" />
|
||||||
|
<item index="136" class="java.lang.String" itemvalue="pexpect" />
|
||||||
|
<item index="137" class="java.lang.String" itemvalue="ipykernel" />
|
||||||
|
<item index="138" class="java.lang.String" itemvalue="nbconvert" />
|
||||||
|
<item index="139" class="java.lang.String" itemvalue="attrs" />
|
||||||
|
<item index="140" class="java.lang.String" itemvalue="psutil" />
|
||||||
|
<item index="141" class="java.lang.String" itemvalue="simplejson" />
|
||||||
|
<item index="142" class="java.lang.String" itemvalue="jedi" />
|
||||||
|
<item index="143" class="java.lang.String" itemvalue="flatbuffers" />
|
||||||
|
<item index="144" class="java.lang.String" itemvalue="cytoolz" />
|
||||||
|
<item index="145" class="java.lang.String" itemvalue="odo" />
|
||||||
|
<item index="146" class="java.lang.String" itemvalue="decorator" />
|
||||||
|
<item index="147" class="java.lang.String" itemvalue="pandocfilters" />
|
||||||
|
<item index="148" class="java.lang.String" itemvalue="backports.shutil-get-terminal-size" />
|
||||||
|
<item index="149" class="java.lang.String" itemvalue="pycodestyle" />
|
||||||
|
<item index="150" class="java.lang.String" itemvalue="pycosat" />
|
||||||
|
<item index="151" class="java.lang.String" itemvalue="pyasn1" />
|
||||||
|
<item index="152" class="java.lang.String" itemvalue="requests" />
|
||||||
|
<item index="153" class="java.lang.String" itemvalue="bitarray" />
|
||||||
|
<item index="154" class="java.lang.String" itemvalue="kornia" />
|
||||||
|
<item index="155" class="java.lang.String" itemvalue="mkl-fft" />
|
||||||
|
<item index="156" class="java.lang.String" itemvalue="tensorflow" />
|
||||||
|
<item index="157" class="java.lang.String" itemvalue="XlsxWriter" />
|
||||||
|
<item index="158" class="java.lang.String" itemvalue="seaborn" />
|
||||||
|
<item index="159" class="java.lang.String" itemvalue="tensorboard-plugin-wit" />
|
||||||
|
<item index="160" class="java.lang.String" itemvalue="blaze" />
|
||||||
|
<item index="161" class="java.lang.String" itemvalue="zipp" />
|
||||||
|
<item index="162" class="java.lang.String" itemvalue="pkginfo" />
|
||||||
|
<item index="163" class="java.lang.String" itemvalue="cached-property" />
|
||||||
|
<item index="164" class="java.lang.String" itemvalue="torchstat" />
|
||||||
|
<item index="165" class="java.lang.String" itemvalue="datashape" />
|
||||||
|
<item index="166" class="java.lang.String" itemvalue="itsdangerous" />
|
||||||
|
<item index="167" class="java.lang.String" itemvalue="ipywidgets" />
|
||||||
|
<item index="168" class="java.lang.String" itemvalue="scipy" />
|
||||||
|
<item index="169" class="java.lang.String" itemvalue="thop" />
|
||||||
|
<item index="170" class="java.lang.String" itemvalue="tornado" />
|
||||||
|
<item index="171" class="java.lang.String" itemvalue="google-auth-oauthlib" />
|
||||||
|
<item index="172" class="java.lang.String" itemvalue="opencv-python" />
|
||||||
|
<item index="173" class="java.lang.String" itemvalue="torch" />
|
||||||
|
<item index="174" class="java.lang.String" itemvalue="singledispatch" />
|
||||||
|
<item index="175" class="java.lang.String" itemvalue="sortedcontainers" />
|
||||||
|
<item index="176" class="java.lang.String" itemvalue="mistune" />
|
||||||
|
<item index="177" class="java.lang.String" itemvalue="pandas" />
|
||||||
|
<item index="178" class="java.lang.String" itemvalue="termcolor" />
|
||||||
|
<item index="179" class="java.lang.String" itemvalue="clang" />
|
||||||
|
<item index="180" class="java.lang.String" itemvalue="toolz" />
|
||||||
|
<item index="181" class="java.lang.String" itemvalue="Sphinx" />
|
||||||
|
<item index="182" class="java.lang.String" itemvalue="mpmath" />
|
||||||
|
<item index="183" class="java.lang.String" itemvalue="jupyter-console" />
|
||||||
|
<item index="184" class="java.lang.String" itemvalue="bokeh" />
|
||||||
|
<item index="185" class="java.lang.String" itemvalue="cachetools" />
|
||||||
|
<item index="186" class="java.lang.String" itemvalue="gmpy2" />
|
||||||
|
<item index="187" class="java.lang.String" itemvalue="setproctitle" />
|
||||||
|
<item index="188" class="java.lang.String" itemvalue="webencodings" />
|
||||||
|
<item index="189" class="java.lang.String" itemvalue="html5lib" />
|
||||||
|
<item index="190" class="java.lang.String" itemvalue="colorlog" />
|
||||||
|
<item index="191" class="java.lang.String" itemvalue="python-dateutil" />
|
||||||
|
<item index="192" class="java.lang.String" itemvalue="QtPy" />
|
||||||
|
<item index="193" class="java.lang.String" itemvalue="astroid" />
|
||||||
|
<item index="194" class="java.lang.String" itemvalue="cycler" />
|
||||||
|
<item index="195" class="java.lang.String" itemvalue="mkl-random" />
|
||||||
|
<item index="196" class="java.lang.String" itemvalue="pytest-arraydiff" />
|
||||||
|
<item index="197" class="java.lang.String" itemvalue="locket" />
|
||||||
|
<item index="198" class="java.lang.String" itemvalue="heapdict" />
|
||||||
|
<item index="199" class="java.lang.String" itemvalue="snowballstemmer" />
|
||||||
|
<item index="200" class="java.lang.String" itemvalue="contextlib2" />
|
||||||
|
<item index="201" class="java.lang.String" itemvalue="certifi" />
|
||||||
|
<item index="202" class="java.lang.String" itemvalue="Markdown" />
|
||||||
|
<item index="203" class="java.lang.String" itemvalue="sympy" />
|
||||||
|
<item index="204" class="java.lang.String" itemvalue="notebook" />
|
||||||
|
<item index="205" class="java.lang.String" itemvalue="pyodbc" />
|
||||||
|
<item index="206" class="java.lang.String" itemvalue="boto" />
|
||||||
|
<item index="207" class="java.lang.String" itemvalue="cligj" />
|
||||||
|
<item index="208" class="java.lang.String" itemvalue="h5py" />
|
||||||
|
<item index="209" class="java.lang.String" itemvalue="wrapt" />
|
||||||
|
<item index="210" class="java.lang.String" itemvalue="kiwisolver" />
|
||||||
|
<item index="211" class="java.lang.String" itemvalue="pytest-openfiles" />
|
||||||
|
<item index="212" class="java.lang.String" itemvalue="anaconda-client" />
|
||||||
|
<item index="213" class="java.lang.String" itemvalue="backcall" />
|
||||||
|
<item index="214" class="java.lang.String" itemvalue="PySocks" />
|
||||||
|
<item index="215" class="java.lang.String" itemvalue="charset-normalizer" />
|
||||||
|
<item index="216" class="java.lang.String" itemvalue="typing" />
|
||||||
|
<item index="217" class="java.lang.String" itemvalue="dask" />
|
||||||
|
<item index="218" class="java.lang.String" itemvalue="enum34" />
|
||||||
|
<item index="219" class="java.lang.String" itemvalue="torchsummary" />
|
||||||
|
<item index="220" class="java.lang.String" itemvalue="scikit-image" />
|
||||||
|
<item index="221" class="java.lang.String" itemvalue="ptyprocess" />
|
||||||
|
<item index="222" class="java.lang.String" itemvalue="more-itertools" />
|
||||||
|
<item index="223" class="java.lang.String" itemvalue="SQLAlchemy" />
|
||||||
|
<item index="224" class="java.lang.String" itemvalue="tblib" />
|
||||||
|
<item index="225" class="java.lang.String" itemvalue="cloudpickle" />
|
||||||
|
<item index="226" class="java.lang.String" itemvalue="importlib-metadata" />
|
||||||
|
<item index="227" class="java.lang.String" itemvalue="simplegeneric" />
|
||||||
|
<item index="228" class="java.lang.String" itemvalue="zict" />
|
||||||
|
<item index="229" class="java.lang.String" itemvalue="urllib3" />
|
||||||
|
<item index="230" class="java.lang.String" itemvalue="jupyterlab" />
|
||||||
|
<item index="231" class="java.lang.String" itemvalue="Cython" />
|
||||||
|
<item index="232" class="java.lang.String" itemvalue="Flask" />
|
||||||
|
<item index="233" class="java.lang.String" itemvalue="nose" />
|
||||||
|
<item index="234" class="java.lang.String" itemvalue="pytorch-msssim" />
|
||||||
|
<item index="235" class="java.lang.String" itemvalue="pytest" />
|
||||||
|
<item index="236" class="java.lang.String" itemvalue="nbformat" />
|
||||||
|
<item index="237" class="java.lang.String" itemvalue="matmul" />
|
||||||
|
<item index="238" class="java.lang.String" itemvalue="tqdm" />
|
||||||
|
<item index="239" class="java.lang.String" itemvalue="lazy-object-proxy" />
|
||||||
|
<item index="240" class="java.lang.String" itemvalue="colorama" />
|
||||||
|
<item index="241" class="java.lang.String" itemvalue="grpcio" />
|
||||||
|
<item index="242" class="java.lang.String" itemvalue="ply" />
|
||||||
|
<item index="243" class="java.lang.String" itemvalue="google-auth" />
|
||||||
|
<item index="244" class="java.lang.String" itemvalue="openpyxl" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
</profile>
|
||||||
|
</component>
|
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
13
.idea/misc.xml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="MavenImportPreferences">
|
||||||
|
<option name="generalSettings">
|
||||||
|
<MavenGeneralSettings>
|
||||||
|
<option name="localRepository" value="E:\maven\repository" />
|
||||||
|
<option name="showDialogWithAdvancedSettings" value="true" />
|
||||||
|
<option name="userSettingsFile" value="E:\maven\settings.xml" />
|
||||||
|
</MavenGeneralSettings>
|
||||||
|
</option>
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.12.4 (sftp://star@192.168.50.108:22/home/star/anaconda3/bin/python)" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
8
.idea/modules.xml
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/CDDFuse.iml" filepath="$PROJECT_DIR$/.idea/CDDFuse.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
6
.idea/vcs.xml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
1
MSRS_train/readme.md
Normal file
@ -0,0 +1 @@
|
|||||||
|
Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it here.
|
186
README.md
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
# CDDFuse
|
||||||
|
Codes for ***CDDFuse: Correlation-Driven Dual-Branch Feature Decomposition for Multi-Modality Image Fusion. (CVPR 2023)***
|
||||||
|
|
||||||
|
[Zixiang Zhao](https://zhaozixiang1228.github.io/), [Haowen Bai](), [Jiangshe Zhang](http://gr.xjtu.edu.cn/web/jszhang), [Yulun Zhang](https://yulunzhang.com/), [Shuang Xu](https://shuangxu96.github.io/), [Zudi Lin](https://zudi-lin.github.io/), [Radu Timofte](https://www.informatik.uni-wuerzburg.de/computervision/home/) and [Luc Van Gool](https://vision.ee.ethz.ch/people-details.OTAyMzM=.TGlzdC8zMjQ4LC0xOTcxNDY1MTc4.html).
|
||||||
|
|
||||||
|
-[*[Paper]*](https://openaccess.thecvf.com/content/CVPR2023/html/Zhao_CDDFuse_Correlation-Driven_Dual-Branch_Feature_Decomposition_for_Multi-Modality_Image_Fusion_CVPR_2023_paper.html)
|
||||||
|
-[*[ArXiv]*](https://arxiv.org/abs/2211.14461)
|
||||||
|
-[*[Supplementary Materials]*](https://openaccess.thecvf.com/content/CVPR2023/supplemental/Zhao_CDDFuse_Correlation-Driven_Dual-Branch_CVPR_2023_supplemental.pdf)
|
||||||
|
|
||||||
|
|
||||||
|
## Update
|
||||||
|
- [2023/6] Training codes and config files are publicly available.
|
||||||
|
- [2023/4] Release inference code for infrared-visible image fusion and medical image fusion.
|
||||||
|
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
```
|
||||||
|
@InProceedings{Zhao_2023_CVPR,
|
||||||
|
author = {Zhao, Zixiang and Bai, Haowen and Zhang, Jiangshe and Zhang, Yulun and Xu, Shuang and Lin, Zudi and Timofte, Radu and Van Gool, Luc},
|
||||||
|
title = {CDDFuse: Correlation-Driven Dual-Branch Feature Decomposition for Multi-Modality Image Fusion},
|
||||||
|
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||||
|
month = {June},
|
||||||
|
year = {2023},
|
||||||
|
pages = {5906-5916}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Abstract
|
||||||
|
|
||||||
|
Multi-modality (MM) image fusion aims to render fused images that maintain the merits of different modalities, e.g., functional highlight and detailed textures. To tackle the challenge in modeling cross-modality features and decomposing desirable modality-specific and modality-shared features, we propose a novel Correlation-Driven feature Decomposition Fusion (CDDFuse) network. Firstly, CDDFuse uses Restormer blocks to extract cross-modality shallow features. We then introduce a dual-branch Transformer-CNN feature extractor with Lite Transformer (LT) blocks leveraging long-range attention to handle low-frequency global features and Invertible Neural Networks (INN) blocks focusing on extracting high-frequency local information. A correlation-driven loss is further proposed to make the low-frequency features correlated while the high-frequency features uncorrelated based on the embedded information. Then, the LT-based global fusion and INN-based local fusion layers output the fused image. Extensive experiments demonstrate that our CDDFuse achieves promising results in multiple fusion tasks, including infrared-visible image fusion and medical image fusion. We also show that CDDFuse can boost the performance in downstream infrared-visible semantic segmentation and object detection in a unified benchmark.
|
||||||
|
|
||||||
|
## 🌐 Usage
|
||||||
|
|
||||||
|
### ⚙ Network Architecture
|
||||||
|
|
||||||
|
Our CDDFuse is implemented in ``net.py``.
|
||||||
|
|
||||||
|
### 🏊 Training
|
||||||
|
**1. Virtual Environment**
|
||||||
|
```
|
||||||
|
# create virtual environment
|
||||||
|
conda create -n cddfuse python=3.8.10
|
||||||
|
conda activate cddfuse
|
||||||
|
# select pytorch version yourself
|
||||||
|
# install cddfuse requirements
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Data Preparation**
|
||||||
|
|
||||||
|
Download the MSRS dataset from [this link](https://github.com/Linfeng-Tang/MSRS) and place it in the folder ``'./MSRS_train/'``.
|
||||||
|
|
||||||
|
**3. Pre-Processing**
|
||||||
|
|
||||||
|
Run
|
||||||
|
```
|
||||||
|
python dataprocessing.py
|
||||||
|
```
|
||||||
|
and the processed training dataset is in ``'./data/MSRS_train_imgsize_128_stride_200.h5'``.
|
||||||
|
|
||||||
|
**4. CDDFuse Training**
|
||||||
|
|
||||||
|
Run
|
||||||
|
```
|
||||||
|
python train.py
|
||||||
|
```
|
||||||
|
and the trained model is available in ``'./models/'``.
|
||||||
|
|
||||||
|
### 🏄 Testing
|
||||||
|
|
||||||
|
**1. Pretrained models**
|
||||||
|
|
||||||
|
Pretrained models are available in ``'./models/CDDFuse_IVF.pth'`` and ``'./models/CDDFuse_MIF.pth'``, which are responsible for the Infrared-Visible Fusion (IVF) and Medical Image Fusion (MIF) tasks, respectively.
|
||||||
|
|
||||||
|
**2. Test datasets**
|
||||||
|
|
||||||
|
The test datasets used in the paper have been stored in ``'./test_img/RoadScene'``, ``'./test_img/TNO'`` for IVF, ``'./test_img/MRI_CT'``, ``'./test_img/MRI_PET'`` and ``'./test_img/MRI_SPECT'`` for MIF.
|
||||||
|
|
||||||
|
Unfortunately, since the size of the **MSRS dataset** for IVF is 500+MB, we cannot upload it for exhibition. It can be downloaded via [this link](https://github.com/Linfeng-Tang/MSRS). The other datasets contain all the test images.
|
||||||
|
|
||||||
|
**3. Results in Our Paper**
|
||||||
|
|
||||||
|
If you want to infer with our CDDFuse and obtain the fusion results in our paper, please run
|
||||||
|
```
|
||||||
|
python test_IVF.py
|
||||||
|
```
|
||||||
|
for Infrared-Visible Fusion and
|
||||||
|
```
|
||||||
|
python test_MIF.py
|
||||||
|
```
|
||||||
|
for Medical Image Fusion.
|
||||||
|
|
||||||
|
The testing results will be printed in the terminal.
|
||||||
|
|
||||||
|
The output for ``'test_IVF.py'`` is:
|
||||||
|
|
||||||
|
```
|
||||||
|
================================================================================
|
||||||
|
The test result of TNO :
|
||||||
|
EN SD SF MI SCD VIF Qabf SSIM
|
||||||
|
CDDFuse 7.12 46.0 13.15 2.19 1.76 0.77 0.54 1.03
|
||||||
|
================================================================================
|
||||||
|
|
||||||
|
================================================================================
|
||||||
|
The test result of RoadScene :
|
||||||
|
EN SD SF MI SCD VIF Qabf SSIM
|
||||||
|
CDDFuse 7.44 54.67 16.36 2.3 1.81 0.69 0.52 0.98
|
||||||
|
================================================================================
|
||||||
|
```
|
||||||
|
which can match the results in Table 1 in our original paper.
|
||||||
|
|
||||||
|
The output for ``'test_MIF.py'`` is:
|
||||||
|
|
||||||
|
```
|
||||||
|
================================================================================
|
||||||
|
The test result of MRI_CT :
|
||||||
|
EN SD SF MI SCD VIF Qabf SSIM
|
||||||
|
CDDFuse_IVF 4.83 88.59 33.83 2.24 1.74 0.5 0.59 1.31
|
||||||
|
CDDFuse_MIF 4.88 79.17 38.14 2.61 1.41 0.61 0.68 1.34
|
||||||
|
================================================================================
|
||||||
|
|
||||||
|
================================================================================
|
||||||
|
The test result of MRI_PET :
|
||||||
|
EN SD SF MI SCD VIF Qabf SSIM
|
||||||
|
CDDFuse_IVF 4.23 81.69 28.04 1.87 1.82 0.66 0.65 1.46
|
||||||
|
CDDFuse_MIF 4.22 70.74 29.57 2.03 1.69 0.71 0.71 1.49
|
||||||
|
================================================================================
|
||||||
|
|
||||||
|
================================================================================
|
||||||
|
The test result of MRI_SPECT :
|
||||||
|
EN SD SF MI SCD VIF Qabf SSIM
|
||||||
|
CDDFuse_IVF 3.91 71.81 20.66 1.9 1.87 0.65 0.68 1.45
|
||||||
|
CDDFuse_MIF 3.9 58.31 20.87 2.49 1.35 0.97 0.78 1.48
|
||||||
|
================================================================================
|
||||||
|
```
|
||||||
|
which can match the results in Table 5 in our original paper.
|
||||||
|
|
||||||
|
## 🙌 CDDFuse
|
||||||
|
|
||||||
|
### Illustration of our CDDFuse model.
|
||||||
|
|
||||||
|
<img src="image//Workflow.png" width="90%" align=center />
|
||||||
|
|
||||||
|
### Qualitative fusion results.
|
||||||
|
|
||||||
|
<img src="image//IVF1.png" width="90%" align=center />
|
||||||
|
|
||||||
|
<img src="image//IVF2.png" width="90%" align=center />
|
||||||
|
|
||||||
|
<img src="image//MIF.png" width="60%" align=center />
|
||||||
|
|
||||||
|
### Quantitative fusion results.
|
||||||
|
|
||||||
|
Infrared-Visible Image Fusion
|
||||||
|
|
||||||
|
<img src="image//Quantitative_IVF.png" width="60%" align=center />
|
||||||
|
|
||||||
|
Medical Image Fusion
|
||||||
|
|
||||||
|
<img src="image//Quantitative_MIF.png" width="60%" align=center />
|
||||||
|
|
||||||
|
MM detection
|
||||||
|
|
||||||
|
<img src="image//MMDet.png" width="60%" align=center />
|
||||||
|
|
||||||
|
MM segmentation
|
||||||
|
|
||||||
|
<img src="image//MMSeg.png" width="60%" align=center />
|
||||||
|
|
||||||
|
|
||||||
|
## 📖 Related Work
|
||||||
|
|
||||||
|
- Zixiang Zhao, Haowen Bai, Jiangshe Zhang, Yulun Zhang, Kai Zhang, Shuang Xu, Dongdong Chen, Radu Timofte, Luc Van Gool. *Equivariant Multi-Modality Image Fusion.* **CVPR 2024**, https://arxiv.org/abs/2305.11443
|
||||||
|
|
||||||
|
- Zixiang Zhao, Haowen Bai, Yuanzhi Zhu, Jiangshe Zhang, Shuang Xu, Yulun Zhang, Kai Zhang, Deyu Meng, Radu Timofte, Luc Van Gool.
|
||||||
|
*DDFM: Denoising Diffusion Model for Multi-Modality Image Fusion.* **ICCV 2023 (Oral)**, https://arxiv.org/abs/2303.06840
|
||||||
|
|
||||||
|
- Zixiang Zhao, Shuang Xu, Chunxia Zhang, Junmin Liu, Jiangshe Zhang and Pengfei Li. *DIDFuse: Deep Image Decomposition for Infrared and Visible Image Fusion.* **IJCAI 2020**, https://www.ijcai.org/Proceedings/2020/135.
|
||||||
|
|
||||||
|
- Zixiang Zhao, Shuang Xu, Jiangshe Zhang, Chengyang Liang, Chunxia Zhang and Junmin Liu. *Efficient and Model-Based Infrared and Visible Image Fusion via Algorithm Unrolling.* **IEEE Transactions on Circuits and Systems for Video Technology 2021**, https://ieeexplore.ieee.org/document/9416456.
|
||||||
|
|
||||||
|
- Zixiang Zhao, Jiangshe Zhang, Haowen Bai, Yicheng Wang, Yukun Cui, Lilun Deng, Kai Sun, Chunxia Zhang, Junmin Liu, Shuang Xu. *Deep Convolutional Sparse Coding Networks for Interpretable Image Fusion.* **CVPR Workshop 2023**. https://robustart.github.io/long_paper/26.pdf.
|
||||||
|
|
||||||
|
- Zixiang Zhao, Shuang Xu, Chunxia Zhang, Junmin Liu, Jiangshe Zhang. *Bayesian fusion for infrared and visible images.* **Signal Processing**, https://doi.org/10.1016/j.sigpro.2020.107734.
|
||||||
|
|
93
dataprocessing.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import os
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from skimage.io import imread
|
||||||
|
|
||||||
|
|
||||||
|
def get_img_file(file_name):
    """Recursively collect image (and ``.npy``) file paths under a directory.

    Walks the whole tree rooted at ``file_name`` and keeps every file whose
    extension (case-insensitive) is a common raster format or ``.npy``.

    Returns:
        list[str]: full paths, in ``os.walk`` traversal order.
    """
    valid_ext = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm',
                 '.ppm', '.tif', '.tiff', '.npy')
    collected = []
    for root, _dirs, names in os.walk(file_name):
        collected.extend(
            os.path.join(root, name)
            for name in names
            if name.lower().endswith(valid_ext)
        )
    return collected
|
||||||
|
|
||||||
|
def rgb2y(img):
    """Convert a channel-first RGB array to its luma (Y) channel.

    Uses the ITU-R BT.601 weights (0.299, 0.587, 0.114).

    Args:
        img: array of shape (3, H, W), channels ordered R, G, B.

    Returns:
        array of shape (1, H, W): the weighted channel sum.
    """
    weights = (0.299000, 0.587000, 0.114000)
    return sum(w * img[c:c + 1, :, :] for c, w in enumerate(weights))
|
||||||
|
|
||||||
|
def Im2Patch(img, win, stride=1):
    """Extract dense square patches from a channel-first image.

    Args:
        img: array of shape (C, W, H).
        win: patch side length in pixels.
        stride: step between neighbouring patch origins.

    Returns:
        float32 array of shape (C, win, win, N): patch n covers
        ``img[:, r*stride : r*stride+win, c*stride : c*stride+win]`` for the
        n-th grid position in row-major order.
    """
    channels, width, height = img.shape[0], img.shape[1], img.shape[2]
    # The grid of valid top-left corners fixes the patch count.
    grid = img[:, 0:width - win + 1:stride, 0:height - win + 1:stride]
    n_patches = grid.shape[1] * grid.shape[2]
    out = np.zeros([channels, win * win, n_patches], np.float32)
    # Gather pixel (i, j) of every patch at once via strided slicing.
    for i in range(win):
        for j in range(win):
            shifted = img[:, i:width - win + i + 1:stride,
                          j:height - win + j + 1:stride]
            out[:, i * win + j, :] = shifted.reshape(channels, n_patches)
    return out.reshape([channels, win, win, n_patches])
|
||||||
|
|
||||||
|
def is_low_contrast(image, fraction_threshold=0.1, lower_percentile=10,
                    upper_percentile=90):
    """Determine if an image is low contrast.

    Compares the spread between the two percentiles, relative to the upper
    percentile value; when that ratio falls below ``fraction_threshold`` the
    image is considered low contrast.

    Args:
        image: array of intensity values (any shape).
        fraction_threshold: minimum relative spread for acceptable contrast.
        lower_percentile: percentile defining the low end of the spread.
        upper_percentile: percentile defining the high end of the spread.

    Returns:
        bool: True when the image is low contrast.
    """
    limits = np.percentile(image, [lower_percentile, upper_percentile])
    # Fix: the original divided by limits[1] unconditionally; for an
    # (almost) all-zero patch that yields ratio = nan (with a runtime
    # warning), and `nan < threshold` is False — i.e. a blank patch was
    # reported as having GOOD contrast. Treat zero upper percentile as
    # low contrast instead.
    if limits[1] == 0:
        return True
    ratio = (limits[1] - limits[0]) / limits[1]
    return bool(ratio < fraction_threshold)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Build an H5 training set of co-registered IR / visible patch pairs:
# read image pairs, convert visible to luma, crop into patches, drop
# low-contrast pairs, and store the survivors in one H5 file.
# ---------------------------------------------------------------------------
data_name = "MSRS_train"
img_size = 128   # patch size
stride = 200     # patch stride

IR_files = sorted(get_img_file(r"MSRS_train/ir"))
VIS_files = sorted(get_img_file(r"MSRS_train/vi"))

# The two folders must hold matching, identically-named image pairs.
assert len(IR_files) == len(VIS_files)

# Fix: the file was originally written under the Windows-only path '.\\data'
# but read back from 'data' — use one portable location for both, and make
# sure the directory exists before opening the file for writing.
out_dir = 'data'
os.makedirs(out_dir, exist_ok=True)
h5_path = os.path.join(
    out_dir,
    data_name + '_imgsize_' + str(img_size) + "_stride_" + str(stride) + '.h5')

h5f = h5py.File(h5_path, 'w')
h5_ir = h5f.create_group('ir_patchs')
h5_vis = h5f.create_group('vis_patchs')
train_num = 0
for i in tqdm(range(len(IR_files))):
    I_VIS = imread(VIS_files[i]).astype(np.float32).transpose(2, 0, 1) / 255.  # [3, H, W] uint8 -> float32 in [0, 1]
    I_VIS = rgb2y(I_VIS)  # [1, H, W] luma only
    I_IR = imread(IR_files[i]).astype(np.float32)[None, :, :] / 255.  # [1, H, W]

    # crop both modalities into the same dense patch grid
    I_IR_Patch_Group = Im2Patch(I_IR, img_size, stride)
    I_VIS_Patch_Group = Im2Patch(I_VIS, img_size, stride)

    for ii in range(I_IR_Patch_Group.shape[-1]):
        bad_IR = is_low_contrast(I_IR_Patch_Group[0, :, :, ii])
        bad_VIS = is_low_contrast(I_VIS_Patch_Group[0, :, :, ii])
        # keep only pairs where BOTH modalities have usable contrast
        if not (bad_IR or bad_VIS):
            avl_IR = I_IR_Patch_Group[0, :, :, ii][None, ...]    # [1, h, w]
            avl_VIS = I_VIS_Patch_Group[0, :, :, ii][None, ...]  # [1, h, w]
            h5_ir.create_dataset(str(train_num), data=avl_IR,
                                 dtype=avl_IR.dtype, shape=avl_IR.shape)
            h5_vis.create_dataset(str(train_num), data=avl_VIS,
                                  dtype=avl_VIS.dtype, shape=avl_VIS.shape)
            train_num += 1

h5f.close()

# Sanity check: reopen the file we just wrote and list its contents.
with h5py.File(h5_path, "r") as f:
    for key in f.keys():
        print(f[key], key, f[key].name)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
BIN
image/IVF1.png
Normal file
After Width: | Height: | Size: 868 KiB |
BIN
image/IVF2.png
Normal file
After Width: | Height: | Size: 1008 KiB |
BIN
image/MIF.png
Normal file
After Width: | Height: | Size: 1.1 MiB |
BIN
image/MMDet.png
Normal file
After Width: | Height: | Size: 122 KiB |
BIN
image/MMSeg.png
Normal file
After Width: | Height: | Size: 126 KiB |
BIN
image/Quantitative_IVF.png
Normal file
After Width: | Height: | Size: 378 KiB |
BIN
image/Quantitative_MIF.png
Normal file
After Width: | Height: | Size: 417 KiB |
BIN
image/Workflow.png
Normal file
After Width: | Height: | Size: 249 KiB |
BIN
models/CDDFuse_IVF.pth
Normal file
BIN
models/CDDFuse_MIF.pth
Normal file
403
net.py
Normal file
@ -0,0 +1,403 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import math
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    During training, each sample in the batch is independently zeroed with
    probability ``drop_prob``; surviving samples are rescaled by
    ``1 / (1 - drop_prob)`` so the expectation is unchanged. In eval mode,
    or with ``drop_prob == 0``, the input passes through untouched.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample: mask of shape (B, 1, 1, ...) broadcasts
    # over every non-batch dimension, so this works for any tensor rank.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * mask
|
||||||
|
|
||||||
|
class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    # NOTE(review): this class shadows the DropPath imported from
    # timm.models.layers at the top of the file; it delegates to the
    # module-level drop_path() function defined above.

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        # Per-sample drop probability. The None default only works in eval
        # mode, where drop_path() returns early without using it.
        self.drop_prob = drop_prob

    def forward(self, x):
        # self.training (from nn.Module) toggles the stochastic behaviour.
        return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBase(nn.Module):
    """Transposed (channel-wise) multi-head self-attention over a feature map.

    q/k/v are produced by a 1x1 conv followed by a 3x3 conv; attention is
    computed between channels, so the attention matrix per head has shape
    (C/heads, C/heads) regardless of spatial size. The softmax temperature
    is a learnable per-head scale.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,):
        super(AttentionBase, self).__init__()
        self.num_heads = num_heads
        self.scale = nn.Parameter(torch.ones(num_heads, 1, 1))  # learnable temperature
        self.qkv1 = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=qkv_bias)
        self.qkv2 = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, padding=1,
                              bias=qkv_bias)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1, bias=qkv_bias)

    def forward(self, x):
        batch, channels, height, width = x.shape
        q, k, v = self.qkv2(self.qkv1(x)).chunk(3, dim=1)

        # (B, C, H, W) -> (B, heads, C/heads, H*W); equivalent to
        # rearrange('b (head c) h w -> b head c (h w)').
        def split_heads(t):
            return t.reshape(batch, self.num_heads,
                             channels // self.num_heads, height * width)

        # L2-normalize q and k along the spatial axis so the product is a
        # cosine-similarity channel attention map.
        q = torch.nn.functional.normalize(split_heads(q), dim=-1)
        k = torch.nn.functional.normalize(split_heads(k), dim=-1)
        v = split_heads(v)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        out = attn @ v
        # (B, heads, C/heads, H*W) -> (B, C, H, W)
        out = out.reshape(batch, channels, height, width)
        return self.proj(out)
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
"""
|
||||||
|
MLP as used in Vision Transformer, MLP-Mixer and related networks
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
in_features,
|
||||||
|
hidden_features=None,
|
||||||
|
ffn_expansion_factor = 2,
|
||||||
|
bias = False):
|
||||||
|
super().__init__()
|
||||||
|
hidden_features = int(in_features*ffn_expansion_factor)
|
||||||
|
|
||||||
|
self.project_in = nn.Conv2d(
|
||||||
|
in_features, hidden_features*2, kernel_size=1, bias=bias)
|
||||||
|
|
||||||
|
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3,
|
||||||
|
stride=1, padding=1, groups=hidden_features, bias=bias)
|
||||||
|
|
||||||
|
self.project_out = nn.Conv2d(
|
||||||
|
hidden_features, in_features, kernel_size=1, bias=bias)
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.project_in(x)
|
||||||
|
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
||||||
|
x = F.gelu(x1) * x2
|
||||||
|
x = self.project_out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class BaseFeatureExtraction(nn.Module):
    """Transformer block (pre-norm channel attention + gated MLP) used as
    the "base" feature branch of the encoder and as the base fuse layer."""
    def __init__(self,
                 dim,
                 num_heads,
                 ffn_expansion_factor=1.,
                 qkv_bias=False,):
        super(BaseFeatureExtraction, self).__init__()
        self.norm1 = LayerNorm(dim, 'WithBias')
        self.attn = AttentionBase(dim, num_heads=num_heads, qkv_bias=qkv_bias,)
        self.norm2 = LayerNorm(dim, 'WithBias')
        self.mlp = Mlp(in_features=dim,
                       ffn_expansion_factor=ffn_expansion_factor,)
    def forward(self, x):
        # pre-norm residual attention, then pre-norm residual MLP
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class InvertedResidualBlock(nn.Module):
|
||||||
|
def __init__(self, inp, oup, expand_ratio):
|
||||||
|
super(InvertedResidualBlock, self).__init__()
|
||||||
|
hidden_dim = int(inp * expand_ratio)
|
||||||
|
self.bottleneckBlock = nn.Sequential(
|
||||||
|
# pw
|
||||||
|
nn.Conv2d(inp, hidden_dim, 1, bias=False),
|
||||||
|
# nn.BatchNorm2d(hidden_dim),
|
||||||
|
nn.ReLU6(inplace=True),
|
||||||
|
# dw
|
||||||
|
nn.ReflectionPad2d(1),
|
||||||
|
nn.Conv2d(hidden_dim, hidden_dim, 3, groups=hidden_dim, bias=False),
|
||||||
|
# nn.BatchNorm2d(hidden_dim),
|
||||||
|
nn.ReLU6(inplace=True),
|
||||||
|
# pw-linear
|
||||||
|
nn.Conv2d(hidden_dim, oup, 1, bias=False),
|
||||||
|
# nn.BatchNorm2d(oup),
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return self.bottleneckBlock(x)
|
||||||
|
|
||||||
|
class DetailNode(nn.Module):
    """One INN-style coupling step over two 32-channel feature halves.

    forward() first remixes the concatenated halves with a 1x1 conv, then
    applies an affine coupling: z2 is shifted by theta_phi(z1), and z1 is
    scaled by exp(theta_rho(z2)) and shifted by theta_eta(z2). The update
    order (z2 before z1) is part of the transform — do not reorder.
    """
    def __init__(self):
        super(DetailNode, self).__init__()
        # Scale is Ax + b, i.e. affine transformation
        self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
        self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
        self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
        # 1x1 conv that mixes channels across the two halves before re-split
        self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
                                    stride=1, padding=0, bias=True)
    def separateFeature(self, x):
        # split the channel dimension into two equal halves
        z1, z2 = x[:, :x.shape[1]//2], x[:, x.shape[1]//2:x.shape[1]]
        return z1, z2
    def forward(self, z1, z2):
        z1, z2 = self.separateFeature(
            self.shffleconv(torch.cat((z1, z2), dim=1)))
        # additive coupling for z2, then multiplicative+additive for z1
        z2 = z2 + self.theta_phi(z1)
        z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
        return z1, z2
|
||||||
|
|
||||||
|
class DetailFeatureExtraction(nn.Module):
    """Stack of DetailNode coupling steps applied to the two channel halves
    of the detail feature map; returns the halves re-concatenated."""

    def __init__(self, num_layers=3):
        super(DetailFeatureExtraction, self).__init__()
        self.net = nn.Sequential(*(DetailNode() for _ in range(num_layers)))

    def forward(self, x):
        half = x.shape[1] // 2
        z1, z2 = x[:, :half], x[:, half:]
        for coupling in self.net:
            z1, z2 = coupling(z1, z2)
        return torch.cat((z1, z2), dim=1)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
import numbers
|
||||||
|
##########################################################################
|
||||||
|
## Layer Norm
|
||||||
|
def to_3d(x):
    """(B, C, H, W) -> (B, H*W, C): flatten the spatial grid into tokens."""
    return x.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def to_4d(x, h, w):
    """(B, H*W, C) -> (B, C, H, W): restore the spatial grid from tokens."""
    return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], h, w)
|
||||||
|
|
||||||
|
|
||||||
|
class BiasFree_LayerNorm(nn.Module):
    """LayerNorm over the last dimension without mean subtraction or bias:
    y = x / sqrt(var(x) + 1e-5) * weight, using a biased variance estimate.
    """

    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        # Accept a bare int or a 1-element shape.
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        variance = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(variance + 1e-5) * self.weight
|
||||||
|
|
||||||
|
|
||||||
|
class WithBias_LayerNorm(nn.Module):
    """Standard LayerNorm over the last dimension (biased variance):
    y = (x - mean) / sqrt(var + 1e-5) * weight + bias.
    """

    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        # Accept a bare int or a 1-element shape.
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        scale = torch.sqrt(x.var(-1, keepdim=True, unbiased=False) + 1e-5)
        return centered / scale * self.weight + self.bias
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
    """Applies a BiasFree or WithBias LayerNorm over the channel dimension
    of a (B, C, H, W) tensor by round-tripping through (B, H*W, C) layout.
    """

    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        # Any value other than 'BiasFree' selects the with-bias variant,
        # matching the original if/else.
        body_cls = (BiasFree_LayerNorm if LayerNorm_type == 'BiasFree'
                    else WithBias_LayerNorm)
        self.body = body_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), height, width)
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
## Gated-Dconv Feed-Forward Network (GDFN)
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, ffn_expansion_factor, bias):
|
||||||
|
super(FeedForward, self).__init__()
|
||||||
|
|
||||||
|
hidden_features = int(dim*ffn_expansion_factor)
|
||||||
|
|
||||||
|
self.project_in = nn.Conv2d(
|
||||||
|
dim, hidden_features*2, kernel_size=1, bias=bias)
|
||||||
|
|
||||||
|
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3,
|
||||||
|
stride=1, padding=1, groups=hidden_features*2, bias=bias)
|
||||||
|
|
||||||
|
self.project_out = nn.Conv2d(
|
||||||
|
hidden_features, dim, kernel_size=1, bias=bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.project_in(x)
|
||||||
|
x1, x2 = self.dwconv(x).chunk(2, dim=1)
|
||||||
|
x = F.gelu(x1) * x2
|
||||||
|
x = self.project_out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
## Multi-DConv Head Transposed Self-Attention (MDTA)
|
||||||
|
class Attention(nn.Module):
    """Multi-DConv Head Transposed Self-Attention (MDTA) from Restormer.

    q/k/v come from a 1x1 conv followed by a depthwise 3x3 conv; attention
    is computed across the channel dimension (a (C/heads, C/heads) map per
    head), with a learnable per-head softmax temperature.
    """
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        # learnable softmax temperature, one scalar per head
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # (B, C, H, W) -> (B, heads, C/heads, H*W)
        q = rearrange(q, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)',
                      head=self.num_heads)

        # L2-normalize along the spatial axis so q @ k^T is a
        # cosine-similarity channel attention map
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)

        # back to (B, C, H, W)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w',
                        head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
class TransformerBlock(nn.Module):
    """Restormer transformer block: pre-norm MDTA attention followed by a
    pre-norm GDFN feed-forward, each with a residual connection."""
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        # pre-norm residual attention, then pre-norm residual FFN
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))

        return x
|
||||||
|
|
||||||
|
|
||||||
|
##########################################################################
|
||||||
|
## Overlapped image patch embedding with 3x3 Conv
|
||||||
|
class OverlapPatchEmbed(nn.Module):
    """Overlapped image patch embedding: a single stride-1 3x3 convolution
    that lifts ``in_c`` channels to ``embed_dim`` without changing H or W.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3,
                              stride=1, padding=1, bias=bias)

    def forward(self, x):
        return self.proj(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Restormer_Encoder(nn.Module):
    """Shared encoder: patch embedding, a stack of transformer blocks, then
    a split into a "base" branch (transformer) and a "detail" branch (INN).

    forward() returns (base_feature, detail_feature, out_enc_level1).
    """
    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=[4, 4],
                 heads=[8, 8, 8],
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):

        super(Restormer_Encoder, self).__init__()

        # 3x3 conv embedding; spatial size is preserved throughout.
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        # num_blocks[0] transformer blocks form the shared first stage.
        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                                                               bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        self.baseFeature = BaseFeatureExtraction(dim=dim, num_heads = heads[2])
        self.detailFeature = DetailFeatureExtraction()

    def forward(self, inp_img):
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        # both branches consume the same shared feature map
        base_feature = self.baseFeature(out_enc_level1)
        detail_feature = self.detailFeature(out_enc_level1)
        return base_feature, detail_feature, out_enc_level1
|
||||||
|
|
||||||
|
class Restormer_Decoder(nn.Module):
    """Decoder: concatenates base and detail features, reduces channels with
    a 1x1 conv, refines with transformer blocks, and maps to the output
    image through a small conv head.

    When ``inp_img`` is not None, the conv head output is added to it
    (residual reconstruction) before the sigmoid. forward() returns
    (sigmoid(image), reduced_feature_map).
    """
    def __init__(self,
                 inp_channels=1,
                 out_channels=1,
                 dim=64,
                 num_blocks=[4, 4],
                 heads=[8, 8, 8],
                 ffn_expansion_factor=2,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):

        super(Restormer_Decoder, self).__init__()
        # fuse the concatenated (2*dim) base+detail features back to dim
        self.reduce_channel = nn.Conv2d(int(dim*2), int(dim), kernel_size=1, bias=bias)
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
                                                               bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        # conv head: dim -> dim/2 -> out_channels
        self.output = nn.Sequential(
            nn.Conv2d(int(dim), int(dim)//2, kernel_size=3,
                      stride=1, padding=1, bias=bias),
            nn.LeakyReLU(),
            nn.Conv2d(int(dim)//2, out_channels, kernel_size=3,
                      stride=1, padding=1, bias=bias),)
        self.sigmoid = nn.Sigmoid()
    def forward(self, inp_img, base_feature, detail_feature):
        out_enc_level0 = torch.cat((base_feature, detail_feature), dim=1)
        out_enc_level0 = self.reduce_channel(out_enc_level0)
        out_enc_level1 = self.encoder_level2(out_enc_level0)
        if inp_img is not None:
            # residual path: predict a correction on top of the input image
            out_enc_level1 = self.output(out_enc_level1) + inp_img
        else:
            out_enc_level1 = self.output(out_enc_level1)
        return self.sigmoid(out_enc_level1), out_enc_level0
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Smoke test: instantiate the encoder and decoder on the GPU.
    height = 128
    width = 128
    window_size = 8  # NOTE(review): unused in the visible code below
    modelE = Restormer_Encoder().cuda()
    modelD = Restormer_Decoder().cuda()
|
||||||
|
|
10
requirements.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
einops==0.4.1
|
||||||
|
kornia==0.2.0
|
||||||
|
numpy==1.21.5
|
||||||
|
opencv_python==4.5.3.56
|
||||||
|
scikit_image==0.19.2
|
||||||
|
scikit_learn==1.1.3
|
||||||
|
scipy==1.7.3
|
||||||
|
tensorboardX==2.5.1
|
||||||
|
timm==0.4.12
|
||||||
|
torch==1.8.1+cu111
|
80
test_IVF.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# Evaluate CDDFuse on infrared-visible fusion benchmarks (TNO, RoadScene):
# fuse every IR/visible pair, save the fused images, then score them with
# eight fusion metrics.
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os
import numpy as np
from utils.Evaluator import Evaluator
import torch
import torch.nn as nn
from utils.img_read_save import img_save,image_read_cv2
import warnings
import logging
# silence library warnings/log noise during batch evaluation
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # pin to GPU 0
ckpt_path=r"models/CDDFuse_IVF.pth"
for dataset_name in ["TNO","RoadScene"]:
    print("\n"*2+"="*80)
    model_name="CDDFuse "
    print("The test result of "+dataset_name+' :')
    test_folder=os.path.join('test_img',dataset_name)
    test_out_folder=os.path.join('test_result',dataset_name)

    # Build the four sub-networks and load their weights from one checkpoint.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
    Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
    BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
    DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)

    Encoder.load_state_dict(torch.load(ckpt_path)['DIDF_Encoder'])
    Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
    BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
    DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
    Encoder.eval()
    Decoder.eval()
    BaseFuseLayer.eval()
    DetailFuseLayer.eval()

    with torch.no_grad():
        for img_name in os.listdir(os.path.join(test_folder,"ir")):

            # read both modalities as grayscale [1, 1, H, W] scaled to [0, 1]
            data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
            data_VIS = image_read_cv2(os.path.join(test_folder,"vi",img_name), mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0

            data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS)
            # NOTE(review): unconditional .cuda() even though `device` may
            # be 'cpu' — this script requires a GPU as written.
            data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()

            # encode each modality, fuse base/detail features by addition,
            # then decode with the visible image as the residual reference
            feature_V_B, feature_V_D, feature_V = Encoder(data_VIS)
            feature_I_B, feature_I_D, feature_I = Encoder(data_IR)
            feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
            feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
            data_Fuse, _ = Decoder(data_VIS, feature_F_B, feature_F_D)
            # min-max normalize to [0, 1] before writing out as 8-bit
            data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse))
            fi = np.squeeze((data_Fuse * 255).cpu().numpy())
            img_save(fi, img_name.split(sep='.')[0], test_out_folder)


    # Score the fused results against both source modalities.
    eval_folder=test_out_folder
    ori_img_folder=test_folder

    # accumulate EN, SD, SF, MI, SCD, VIF, Qabf, SSIM over all images
    metric_result = np.zeros((8))
    for img_name in os.listdir(os.path.join(ori_img_folder,"ir")):
        ir = image_read_cv2(os.path.join(ori_img_folder,"ir", img_name), 'GRAY')
        vi = image_read_cv2(os.path.join(ori_img_folder,"vi", img_name), 'GRAY')
        fi = image_read_cv2(os.path.join(eval_folder, img_name.split('.')[0]+".png"), 'GRAY')
        metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi)
                                   , Evaluator.SF(fi), Evaluator.MI(fi, ir, vi)
                                   , Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi)
                                   , Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])

    # average over the number of fused images and print a summary row
    metric_result /= len(os.listdir(eval_folder))
    print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
    print(model_name+'\t'+str(np.round(metric_result[0], 2))+'\t'
          +str(np.round(metric_result[1], 2))+'\t'
          +str(np.round(metric_result[2], 2))+'\t'
          +str(np.round(metric_result[3], 2))+'\t'
          +str(np.round(metric_result[4], 2))+'\t'
          +str(np.round(metric_result[5], 2))+'\t'
          +str(np.round(metric_result[6], 2))+'\t'
          +str(np.round(metric_result[7], 2))
          )
    print("="*80)
|
86
test_MIF.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from utils.Evaluator import Evaluator
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from utils.img_read_save import img_save,image_read_cv2
|
||||||
|
import warnings
|
||||||
|
import logging
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
logging.basicConfig(level=logging.CRITICAL)
|
||||||
|
import cv2
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
|
CDDFuse_path=r"models/CDDFuse_IVF.pth"
|
||||||
|
CDDFuse_MIF_path=r"models/CDDFuse_MIF.pth"
|
||||||
|
for dataset_name in ["MRI_CT","MRI_PET","MRI_SPECT"]:
|
||||||
|
print("\n"*2+"="*80)
|
||||||
|
print("The test result of "+dataset_name+" :")
|
||||||
|
print("\t\t EN\t SD\t SF\t MI\tSCD\tVIF\tQabf\tSSIM")
|
||||||
|
for ckpt_path in [CDDFuse_path,CDDFuse_MIF_path]:
|
||||||
|
model_name=ckpt_path.split('/')[-1].split('.')[0]
|
||||||
|
test_folder=os.path.join('test_img',dataset_name)
|
||||||
|
test_out_folder=os.path.join('test_result',dataset_name)
|
||||||
|
|
||||||
|
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
||||||
|
Decoder = nn.DataParallel(Restormer_Decoder()).to(device)
|
||||||
|
BaseFuseLayer = nn.DataParallel(BaseFeatureExtraction(dim=64, num_heads=8)).to(device)
|
||||||
|
DetailFuseLayer = nn.DataParallel(DetailFeatureExtraction(num_layers=1)).to(device)
|
||||||
|
|
||||||
|
Encoder.load_state_dict(torch.load(ckpt_path)['DIDF_Encoder'])
|
||||||
|
Decoder.load_state_dict(torch.load(ckpt_path)['DIDF_Decoder'])
|
||||||
|
BaseFuseLayer.load_state_dict(torch.load(ckpt_path)['BaseFuseLayer'])
|
||||||
|
DetailFuseLayer.load_state_dict(torch.load(ckpt_path)['DetailFuseLayer'])
|
||||||
|
Encoder.eval()
|
||||||
|
Decoder.eval()
|
||||||
|
BaseFuseLayer.eval()
|
||||||
|
DetailFuseLayer.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for img_name in os.listdir(os.path.join(test_folder,dataset_name.split('_')[0])):
|
||||||
|
data_IR=image_read_cv2(os.path.join(test_folder,dataset_name.split('_')[1],img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
|
||||||
|
data_VIS = image_read_cv2(os.path.join(test_folder,dataset_name.split('_')[0],img_name), mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
|
||||||
|
|
||||||
|
data_IR,data_VIS = torch.FloatTensor(data_IR),torch.FloatTensor(data_VIS)
|
||||||
|
data_VIS, data_IR = data_VIS.cuda(), data_IR.cuda()
|
||||||
|
|
||||||
|
feature_V_B, feature_V_D, feature_V = Encoder(data_VIS)
|
||||||
|
feature_I_B, feature_I_D, feature_I = Encoder(data_IR)
|
||||||
|
feature_F_B = BaseFuseLayer(feature_V_B + feature_I_B)
|
||||||
|
feature_F_D = DetailFuseLayer(feature_V_D + feature_I_D)
|
||||||
|
if ckpt_path==CDDFuse_path:
|
||||||
|
data_Fuse, _ = Decoder(data_IR+data_VIS, feature_F_B, feature_F_D)
|
||||||
|
else:
|
||||||
|
data_Fuse, _ = Decoder(None, feature_F_B, feature_F_D)
|
||||||
|
data_Fuse=(data_Fuse-torch.min(data_Fuse))/(torch.max(data_Fuse)-torch.min(data_Fuse))
|
||||||
|
fi = np.squeeze((data_Fuse * 255).cpu().numpy())
|
||||||
|
img_save(fi, img_name.split(sep='.')[0], test_out_folder)
|
||||||
|
eval_folder=test_out_folder
|
||||||
|
ori_img_folder=test_folder
|
||||||
|
|
||||||
|
metric_result = np.zeros((8))
|
||||||
|
for img_name in os.listdir(os.path.join(ori_img_folder,dataset_name.split('_')[0])):
|
||||||
|
ir = image_read_cv2(os.path.join(ori_img_folder,dataset_name.split('_')[1], img_name), 'GRAY')
|
||||||
|
vi = image_read_cv2(os.path.join(ori_img_folder,dataset_name.split('_')[0], img_name), 'GRAY')
|
||||||
|
fi = image_read_cv2(os.path.join(eval_folder, img_name.split('.')[0]+".png"), 'GRAY')
|
||||||
|
metric_result += np.array([Evaluator.EN(fi), Evaluator.SD(fi)
|
||||||
|
, Evaluator.SF(fi), Evaluator.MI(fi, ir, vi)
|
||||||
|
, Evaluator.SCD(fi, ir, vi), Evaluator.VIFF(fi, ir, vi)
|
||||||
|
, Evaluator.Qabf(fi, ir, vi), Evaluator.SSIM(fi, ir, vi)])
|
||||||
|
|
||||||
|
metric_result /= len(os.listdir(eval_folder))
|
||||||
|
|
||||||
|
print(model_name+'\t'+str(np.round(metric_result[0], 2))+'\t'
|
||||||
|
+str(np.round(metric_result[1], 2))+'\t'
|
||||||
|
+str(np.round(metric_result[2], 2))+'\t'
|
||||||
|
+str(np.round(metric_result[3], 2))+'\t'
|
||||||
|
+str(np.round(metric_result[4], 2))+'\t'
|
||||||
|
+str(np.round(metric_result[5], 2))+'\t'
|
||||||
|
+str(np.round(metric_result[6], 2))+'\t'
|
||||||
|
+str(np.round(metric_result[7], 2))
|
||||||
|
)
|
||||||
|
print("="*80)
|
||||||
|
|
||||||
|
|
BIN
test_img/MRI_CT/CT/11.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_CT/CT/12.png
Normal file
After Width: | Height: | Size: 41 KiB |
BIN
test_img/MRI_CT/CT/13.png
Normal file
After Width: | Height: | Size: 41 KiB |
BIN
test_img/MRI_CT/CT/14.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_CT/CT/15.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_CT/CT/16.png
Normal file
After Width: | Height: | Size: 41 KiB |
BIN
test_img/MRI_CT/CT/17.png
Normal file
After Width: | Height: | Size: 39 KiB |
BIN
test_img/MRI_CT/CT/18.png
Normal file
After Width: | Height: | Size: 37 KiB |
BIN
test_img/MRI_CT/CT/19.png
Normal file
After Width: | Height: | Size: 34 KiB |
BIN
test_img/MRI_CT/CT/20.png
Normal file
After Width: | Height: | Size: 33 KiB |
BIN
test_img/MRI_CT/CT/21.png
Normal file
After Width: | Height: | Size: 38 KiB |
BIN
test_img/MRI_CT/CT/22.png
Normal file
After Width: | Height: | Size: 41 KiB |
BIN
test_img/MRI_CT/CT/23.png
Normal file
After Width: | Height: | Size: 44 KiB |
BIN
test_img/MRI_CT/CT/24.png
Normal file
After Width: | Height: | Size: 44 KiB |
BIN
test_img/MRI_CT/CT/25.png
Normal file
After Width: | Height: | Size: 43 KiB |
BIN
test_img/MRI_CT/CT/26.png
Normal file
After Width: | Height: | Size: 41 KiB |
BIN
test_img/MRI_CT/CT/27.png
Normal file
After Width: | Height: | Size: 40 KiB |
BIN
test_img/MRI_CT/CT/28.png
Normal file
After Width: | Height: | Size: 40 KiB |
BIN
test_img/MRI_CT/CT/29.png
Normal file
After Width: | Height: | Size: 40 KiB |
BIN
test_img/MRI_CT/CT/30.png
Normal file
After Width: | Height: | Size: 39 KiB |
BIN
test_img/MRI_CT/CT/31.png
Normal file
After Width: | Height: | Size: 37 KiB |
BIN
test_img/MRI_CT/MRI/11.png
Normal file
After Width: | Height: | Size: 59 KiB |
BIN
test_img/MRI_CT/MRI/12.png
Normal file
After Width: | Height: | Size: 58 KiB |
BIN
test_img/MRI_CT/MRI/13.png
Normal file
After Width: | Height: | Size: 60 KiB |
BIN
test_img/MRI_CT/MRI/14.png
Normal file
After Width: | Height: | Size: 66 KiB |
BIN
test_img/MRI_CT/MRI/15.png
Normal file
After Width: | Height: | Size: 63 KiB |
BIN
test_img/MRI_CT/MRI/16.png
Normal file
After Width: | Height: | Size: 61 KiB |
BIN
test_img/MRI_CT/MRI/17.png
Normal file
After Width: | Height: | Size: 60 KiB |
BIN
test_img/MRI_CT/MRI/18.png
Normal file
After Width: | Height: | Size: 57 KiB |
BIN
test_img/MRI_CT/MRI/19.png
Normal file
After Width: | Height: | Size: 54 KiB |
BIN
test_img/MRI_CT/MRI/20.png
Normal file
After Width: | Height: | Size: 52 KiB |
BIN
test_img/MRI_CT/MRI/21.png
Normal file
After Width: | Height: | Size: 51 KiB |
BIN
test_img/MRI_CT/MRI/22.png
Normal file
After Width: | Height: | Size: 53 KiB |
BIN
test_img/MRI_CT/MRI/23.png
Normal file
After Width: | Height: | Size: 54 KiB |
BIN
test_img/MRI_CT/MRI/24.png
Normal file
After Width: | Height: | Size: 56 KiB |
BIN
test_img/MRI_CT/MRI/25.png
Normal file
After Width: | Height: | Size: 55 KiB |
BIN
test_img/MRI_CT/MRI/26.png
Normal file
After Width: | Height: | Size: 52 KiB |
BIN
test_img/MRI_CT/MRI/27.png
Normal file
After Width: | Height: | Size: 50 KiB |
BIN
test_img/MRI_CT/MRI/28.png
Normal file
After Width: | Height: | Size: 50 KiB |
BIN
test_img/MRI_CT/MRI/29.png
Normal file
After Width: | Height: | Size: 49 KiB |
BIN
test_img/MRI_CT/MRI/30.png
Normal file
After Width: | Height: | Size: 48 KiB |
BIN
test_img/MRI_CT/MRI/31.png
Normal file
After Width: | Height: | Size: 46 KiB |
BIN
test_img/MRI_PET/MRI/11.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/12.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/13.png
Normal file
After Width: | Height: | Size: 43 KiB |
BIN
test_img/MRI_PET/MRI/14.png
Normal file
After Width: | Height: | Size: 43 KiB |
BIN
test_img/MRI_PET/MRI/15.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/16.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/17.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/18.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/19.png
Normal file
After Width: | Height: | Size: 43 KiB |
BIN
test_img/MRI_PET/MRI/20.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/21.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/22.png
Normal file
After Width: | Height: | Size: 41 KiB |
BIN
test_img/MRI_PET/MRI/23.png
Normal file
After Width: | Height: | Size: 39 KiB |
BIN
test_img/MRI_PET/MRI/24.png
Normal file
After Width: | Height: | Size: 38 KiB |
BIN
test_img/MRI_PET/MRI/25.png
Normal file
After Width: | Height: | Size: 37 KiB |
BIN
test_img/MRI_PET/MRI/26.png
Normal file
After Width: | Height: | Size: 36 KiB |
BIN
test_img/MRI_PET/MRI/27.png
Normal file
After Width: | Height: | Size: 34 KiB |
BIN
test_img/MRI_PET/MRI/28.png
Normal file
After Width: | Height: | Size: 32 KiB |
BIN
test_img/MRI_PET/MRI/29.png
Normal file
After Width: | Height: | Size: 29 KiB |
BIN
test_img/MRI_PET/MRI/30.png
Normal file
After Width: | Height: | Size: 26 KiB |
BIN
test_img/MRI_PET/MRI/31.png
Normal file
After Width: | Height: | Size: 21 KiB |
BIN
test_img/MRI_PET/MRI/32.png
Normal file
After Width: | Height: | Size: 39 KiB |
BIN
test_img/MRI_PET/MRI/33.png
Normal file
After Width: | Height: | Size: 40 KiB |
BIN
test_img/MRI_PET/MRI/34.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/35.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/36.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/37.png
Normal file
After Width: | Height: | Size: 42 KiB |
BIN
test_img/MRI_PET/MRI/38.png
Normal file
After Width: | Height: | Size: 44 KiB |
BIN
test_img/MRI_PET/MRI/39.png
Normal file
After Width: | Height: | Size: 46 KiB |
BIN
test_img/MRI_PET/MRI/40.png
Normal file
After Width: | Height: | Size: 44 KiB |
BIN
test_img/MRI_PET/MRI/41.png
Normal file
After Width: | Height: | Size: 44 KiB |
BIN
test_img/MRI_PET/MRI/42.png
Normal file
After Width: | Height: | Size: 43 KiB |
BIN
test_img/MRI_PET/MRI/43.png
Normal file
After Width: | Height: | Size: 41 KiB |