Compare commits

..

16 Commits

Author SHA1 Message Date
zjut
7ad3d631ff test(PFCFuse): 更新模型权重并修改测试名称
- 更新模型权重文件路径为 whaiFusion11-16-21-39.pth
- 修改模型测试名称为 "PFCFuse Enhance 不同FusionLayer"
2024-11-17 09:55:03 +08:00
zjut
aae81d97fd feat(net): 新增基础特征融合和细节特征融合模块
- 添加了 BaseFeatureFusion 和 DetailFeatureFusion 两个新类
- 更新了 train.py 中的导入和实例化语句
2024-11-16 21:39:34 +08:00
zjut
9224f9b640 test: 更新测试模型路径
- 修改了 test_IVF.py 文件中的模型路径
- 将旧路径 whaiFusion11-15-22-09.pth 更改为新路径 whaiFusion11-16-11-20.pth
2024-11-16 21:37:02 +08:00
zjut
e8a0212bbb feat(net): 增加 SAR 图像处理支持
- 新增 BaseFeatureExtractionSAR 和 DetailFeatureExtractionSAR 模块
- 修改 DIDF_Encoder 类,支持 SAR 图像输入
- 更新测试和训练脚本,增加 SAR 图像处理相关逻辑
2024-11-16 12:59:07 +08:00
zjut
0ef5760d76 test(test_IVF.py): 更新测试模型并调整测试数据集
- 更新模型路径为 whaiFusion11-15-17-48.pth
- 将测试数据集从 ["TNO","RoadScene"] 修改为 ["sar"]- 修改模型名称为 "PFCFuse 最基本版本"
2024-11-15 21:48:32 +08:00
zjut
344de69cb2 refactor(net): 重构网络结构并移除未使用的代码
- 移除了未使用的导入语句和冗余代码
- 重构了某些类和方法,提高了代码可读性
- 删除了未使用的变量和注释掉的代码块
- 简化了部分代码结构,提高了运行效率
2024-11-15 17:48:42 +08:00
c023c0801d refactor(net): 注释掉 DetailFeatureExtraction、DetailFeatureFusion 和 DetailFeatureExtractionSAR 类中的 enhancement_module
- 在三个类中注释掉了 enhancement_module 的定义
- 该改动可能是为了暂时禁用增强模块的功能或进行调试
2024-11-15 09:14:15 +08:00
555515c2dc feat(net): 移除 WTConv2d,添加 DEConv 模块
- 删除了 WTConv2d 相关代码
- 新增了 DEConv 模块,包括多种卷积类型
- 更新了 net.py 中的相关调用,移除了 WTConv2d
2024-11-14 17:15:00 +08:00
30bbfdf86e feat(components): 添加 DEConv 和 SEBlock 组件
- 新增 DEConv 组件,用于细节增强卷积
- 新增 SEBlock组件,用于通道注意力机制
- 更新 net.py 中的 DetailNode 结构
- 调整 train.py 中的模型初始化
2024-11-14 16:59:11 +08:00
zjut
c1eed72f24 feat(net): 重构特征融合模块并添加新组件
- 新增 BaseFeatureFusion 和 DetailFeatureFusioin 类,用于特征融合
- 更新 ProjectRootManager 配置,使用本地 Python 3.8 环境
- 修改训练数据集路径
- 优化训练日志输出格式
2024-11-14 16:02:05 +08:00
b6486dbaf4 添加 .idea/ 和 status.md到 .gitignore 文件,避免个人配置和状态文件被跟踪。在测试脚本中移除了不必要的打印语句。新增了测试日志和成功运行的日志文件。 2024-10-26 18:37:15 +08:00
7d6d629786 添加 .idea/ 和 status.md到 .gitignore 文件,避免个人配置和状态文件被跟踪。在测试脚本中移除了不必要的打印语句。新增了测试日志和成功运行的日志文件。 2024-10-09 12:04:46 +08:00
15eb10b512 添加 .idea/ 和 status.md到 .gitignore 文件,避免个人配置和状态文件被跟踪。在测试脚本中移除了不必要的打印语句。新增了测试日志和成功运行的日志文件。 2024-10-09 11:57:57 +08:00
5e561ab6f7 修改代码结构,提高可读性和可维护性;调整训练输出频率。
改进 self.enhancement_module 为
        self.enhancement_module = WTConv2d(32, 32)
2024-10-09 11:35:06 +08:00
96ce7d5fda 修改代码结构,提高可读性和可维护性;调整训练输出频率。
改进 self.enhancement_module 为
        self.enhancement_module = WTConv2d(32, 32)
2024-10-08 16:50:11 +08:00
afd55abe9e 模型结构
DetailFeatureExtraction增加了一个增强残差
BaseFeatureExtraction增加了
x = self.WTConv2d(x)
2024-10-07 15:24:33 +08:00
588 changed files with 1706 additions and 142 deletions

8
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

12
.idea/PFCFuse.iml Normal file
View File

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

78
.idea/deployment.xml Normal file
View File

@ -0,0 +1,78 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="star@192.168.50.108:22 password (9)" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="star@192.168.50.108:22 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (6)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (7)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (8)">
<serverdata>
<mappings>
<mapping deploy="/home/star/whaiDir/PFCFuse" local="$PROJECT_DIR$" />
</mappings>
</serverdata>
</paths>
<paths name="star@192.168.50.108:22 password (9)">
<serverdata>
<mappings>
<mapping deploy="/home/star/whaiDir/PFCFuse" local="$PROJECT_DIR$" />
</mappings>
</serverdata>
</paths>
<paths name="v100">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
<option name="myAutoUpload" value="ALWAYS" />
</component>
</project>

15
.idea/git_toolbox_prj.xml Normal file
View File

@ -0,0 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="GitToolBoxProjectSettings">
<option name="commitMessageIssueKeyValidationOverride">
<BoolValueOverride>
<option name="enabled" value="true" />
</BoolValueOverride>
</option>
<option name="commitMessageValidationEnabledOverride">
<BoolValueOverride>
<option name="enabled" value="true" />
</BoolValueOverride>
</option>
</component>
</project>

View File

@ -0,0 +1,264 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<Languages>
<language minSize="226" name="Python" />
</Languages>
</inspection_tool>
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="245">
<item index="0" class="java.lang.String" itemvalue="numba" />
<item index="1" class="java.lang.String" itemvalue="tensorflow-estimator" />
<item index="2" class="java.lang.String" itemvalue="greenlet" />
<item index="3" class="java.lang.String" itemvalue="Babel" />
<item index="4" class="java.lang.String" itemvalue="scikit-learn" />
<item index="5" class="java.lang.String" itemvalue="testpath" />
<item index="6" class="java.lang.String" itemvalue="py" />
<item index="7" class="java.lang.String" itemvalue="gitdb" />
<item index="8" class="java.lang.String" itemvalue="torchvision" />
<item index="9" class="java.lang.String" itemvalue="patsy" />
<item index="10" class="java.lang.String" itemvalue="mccabe" />
<item index="11" class="java.lang.String" itemvalue="bleach" />
<item index="12" class="java.lang.String" itemvalue="lxml" />
<item index="13" class="java.lang.String" itemvalue="torchaudio" />
<item index="14" class="java.lang.String" itemvalue="jsonschema" />
<item index="15" class="java.lang.String" itemvalue="xlrd" />
<item index="16" class="java.lang.String" itemvalue="Werkzeug" />
<item index="17" class="java.lang.String" itemvalue="anaconda-project" />
<item index="18" class="java.lang.String" itemvalue="tensorboard-data-server" />
<item index="19" class="java.lang.String" itemvalue="typing-extensions" />
<item index="20" class="java.lang.String" itemvalue="click" />
<item index="21" class="java.lang.String" itemvalue="regex" />
<item index="22" class="java.lang.String" itemvalue="fastcache" />
<item index="23" class="java.lang.String" itemvalue="tensorboard" />
<item index="24" class="java.lang.String" itemvalue="imageio" />
<item index="25" class="java.lang.String" itemvalue="pytest-remotedata" />
<item index="26" class="java.lang.String" itemvalue="matplotlib" />
<item index="27" class="java.lang.String" itemvalue="idna" />
<item index="28" class="java.lang.String" itemvalue="Bottleneck" />
<item index="29" class="java.lang.String" itemvalue="rsa" />
<item index="30" class="java.lang.String" itemvalue="networkx" />
<item index="31" class="java.lang.String" itemvalue="pycurl" />
<item index="32" class="java.lang.String" itemvalue="smmap" />
<item index="33" class="java.lang.String" itemvalue="pluggy" />
<item index="34" class="java.lang.String" itemvalue="cffi" />
<item index="35" class="java.lang.String" itemvalue="pep8" />
<item index="36" class="java.lang.String" itemvalue="numpy" />
<item index="37" class="java.lang.String" itemvalue="jdcal" />
<item index="38" class="java.lang.String" itemvalue="alabaster" />
<item index="39" class="java.lang.String" itemvalue="jupyter" />
<item index="40" class="java.lang.String" itemvalue="pyOpenSSL" />
<item index="41" class="java.lang.String" itemvalue="PyWavelets" />
<item index="42" class="java.lang.String" itemvalue="prompt-toolkit" />
<item index="43" class="java.lang.String" itemvalue="QtAwesome" />
<item index="44" class="java.lang.String" itemvalue="msgpack-python" />
<item index="45" class="java.lang.String" itemvalue="Flask-Cors" />
<item index="46" class="java.lang.String" itemvalue="glob2" />
<item index="47" class="java.lang.String" itemvalue="Send2Trash" />
<item index="48" class="java.lang.String" itemvalue="imagesize" />
<item index="49" class="java.lang.String" itemvalue="et-xmlfile" />
<item index="50" class="java.lang.String" itemvalue="pathlib2" />
<item index="51" class="java.lang.String" itemvalue="docker-pycreds" />
<item index="52" class="java.lang.String" itemvalue="importlib-resources" />
<item index="53" class="java.lang.String" itemvalue="pathtools" />
<item index="54" class="java.lang.String" itemvalue="spyder" />
<item index="55" class="java.lang.String" itemvalue="pylint" />
<item index="56" class="java.lang.String" itemvalue="statsmodels" />
<item index="57" class="java.lang.String" itemvalue="tensorboardX" />
<item index="58" class="java.lang.String" itemvalue="isort" />
<item index="59" class="java.lang.String" itemvalue="ruamel_yaml" />
<item index="60" class="java.lang.String" itemvalue="pytz" />
<item index="61" class="java.lang.String" itemvalue="unicodecsv" />
<item index="62" class="java.lang.String" itemvalue="pytest-astropy" />
<item index="63" class="java.lang.String" itemvalue="traitlets" />
<item index="64" class="java.lang.String" itemvalue="absl-py" />
<item index="65" class="java.lang.String" itemvalue="protobuf" />
<item index="66" class="java.lang.String" itemvalue="nltk" />
<item index="67" class="java.lang.String" itemvalue="partd" />
<item index="68" class="java.lang.String" itemvalue="promise" />
<item index="69" class="java.lang.String" itemvalue="gast" />
<item index="70" class="java.lang.String" itemvalue="filelock" />
<item index="71" class="java.lang.String" itemvalue="numpydoc" />
<item index="72" class="java.lang.String" itemvalue="pyzmq" />
<item index="73" class="java.lang.String" itemvalue="oauthlib" />
<item index="74" class="java.lang.String" itemvalue="astropy" />
<item index="75" class="java.lang.String" itemvalue="keras" />
<item index="76" class="java.lang.String" itemvalue="entrypoints" />
<item index="77" class="java.lang.String" itemvalue="bkcharts" />
<item index="78" class="java.lang.String" itemvalue="pyparsing" />
<item index="79" class="java.lang.String" itemvalue="munch" />
<item index="80" class="java.lang.String" itemvalue="sphinxcontrib-websupport" />
<item index="81" class="java.lang.String" itemvalue="beautifulsoup4" />
<item index="82" class="java.lang.String" itemvalue="path.py" />
<item index="83" class="java.lang.String" itemvalue="clyent" />
<item index="84" class="java.lang.String" itemvalue="navigator-updater" />
<item index="85" class="java.lang.String" itemvalue="tifffile" />
<item index="86" class="java.lang.String" itemvalue="cryptography" />
<item index="87" class="java.lang.String" itemvalue="pygdal" />
<item index="88" class="java.lang.String" itemvalue="fastrlock" />
<item index="89" class="java.lang.String" itemvalue="widgetsnbextension" />
<item index="90" class="java.lang.String" itemvalue="multipledispatch" />
<item index="91" class="java.lang.String" itemvalue="numexpr" />
<item index="92" class="java.lang.String" itemvalue="jupyter-core" />
<item index="93" class="java.lang.String" itemvalue="ipython_genutils" />
<item index="94" class="java.lang.String" itemvalue="yapf" />
<item index="95" class="java.lang.String" itemvalue="rope" />
<item index="96" class="java.lang.String" itemvalue="wcwidth" />
<item index="97" class="java.lang.String" itemvalue="cupy-cuda110" />
<item index="98" class="java.lang.String" itemvalue="llvmlite" />
<item index="99" class="java.lang.String" itemvalue="Jinja2" />
<item index="100" class="java.lang.String" itemvalue="pycrypto" />
<item index="101" class="java.lang.String" itemvalue="Keras-Preprocessing" />
<item index="102" class="java.lang.String" itemvalue="ptflops" />
<item index="103" class="java.lang.String" itemvalue="cupy-cuda111" />
<item index="104" class="java.lang.String" itemvalue="cupy-cuda114" />
<item index="105" class="java.lang.String" itemvalue="some-package" />
<item index="106" class="java.lang.String" itemvalue="wandb" />
<item index="107" class="java.lang.String" itemvalue="netaddr" />
<item index="108" class="java.lang.String" itemvalue="sortedcollections" />
<item index="109" class="java.lang.String" itemvalue="six" />
<item index="110" class="java.lang.String" itemvalue="timm" />
<item index="111" class="java.lang.String" itemvalue="pyflakes" />
<item index="112" class="java.lang.String" itemvalue="asn1crypto" />
<item index="113" class="java.lang.String" itemvalue="parso" />
<item index="114" class="java.lang.String" itemvalue="pytest-doctestplus" />
<item index="115" class="java.lang.String" itemvalue="ipython" />
<item index="116" class="java.lang.String" itemvalue="xlwt" />
<item index="117" class="java.lang.String" itemvalue="packaging" />
<item index="118" class="java.lang.String" itemvalue="chardet" />
<item index="119" class="java.lang.String" itemvalue="jupyterlab-launcher" />
<item index="120" class="java.lang.String" itemvalue="click-plugins" />
<item index="121" class="java.lang.String" itemvalue="PyYAML" />
<item index="122" class="java.lang.String" itemvalue="pickleshare" />
<item index="123" class="java.lang.String" itemvalue="pycparser" />
<item index="124" class="java.lang.String" itemvalue="pyasn1-modules" />
<item index="125" class="java.lang.String" itemvalue="tables" />
<item index="126" class="java.lang.String" itemvalue="Pygments" />
<item index="127" class="java.lang.String" itemvalue="sentry-sdk" />
<item index="128" class="java.lang.String" itemvalue="docutils" />
<item index="129" class="java.lang.String" itemvalue="gevent" />
<item index="130" class="java.lang.String" itemvalue="shortuuid" />
<item index="131" class="java.lang.String" itemvalue="qtconsole" />
<item index="132" class="java.lang.String" itemvalue="terminado" />
<item index="133" class="java.lang.String" itemvalue="GitPython" />
<item index="134" class="java.lang.String" itemvalue="distributed" />
<item index="135" class="java.lang.String" itemvalue="jupyter-client" />
<item index="136" class="java.lang.String" itemvalue="pexpect" />
<item index="137" class="java.lang.String" itemvalue="ipykernel" />
<item index="138" class="java.lang.String" itemvalue="nbconvert" />
<item index="139" class="java.lang.String" itemvalue="attrs" />
<item index="140" class="java.lang.String" itemvalue="psutil" />
<item index="141" class="java.lang.String" itemvalue="simplejson" />
<item index="142" class="java.lang.String" itemvalue="jedi" />
<item index="143" class="java.lang.String" itemvalue="flatbuffers" />
<item index="144" class="java.lang.String" itemvalue="cytoolz" />
<item index="145" class="java.lang.String" itemvalue="odo" />
<item index="146" class="java.lang.String" itemvalue="decorator" />
<item index="147" class="java.lang.String" itemvalue="pandocfilters" />
<item index="148" class="java.lang.String" itemvalue="backports.shutil-get-terminal-size" />
<item index="149" class="java.lang.String" itemvalue="pycodestyle" />
<item index="150" class="java.lang.String" itemvalue="pycosat" />
<item index="151" class="java.lang.String" itemvalue="pyasn1" />
<item index="152" class="java.lang.String" itemvalue="requests" />
<item index="153" class="java.lang.String" itemvalue="bitarray" />
<item index="154" class="java.lang.String" itemvalue="kornia" />
<item index="155" class="java.lang.String" itemvalue="mkl-fft" />
<item index="156" class="java.lang.String" itemvalue="tensorflow" />
<item index="157" class="java.lang.String" itemvalue="XlsxWriter" />
<item index="158" class="java.lang.String" itemvalue="seaborn" />
<item index="159" class="java.lang.String" itemvalue="tensorboard-plugin-wit" />
<item index="160" class="java.lang.String" itemvalue="blaze" />
<item index="161" class="java.lang.String" itemvalue="zipp" />
<item index="162" class="java.lang.String" itemvalue="pkginfo" />
<item index="163" class="java.lang.String" itemvalue="cached-property" />
<item index="164" class="java.lang.String" itemvalue="torchstat" />
<item index="165" class="java.lang.String" itemvalue="datashape" />
<item index="166" class="java.lang.String" itemvalue="itsdangerous" />
<item index="167" class="java.lang.String" itemvalue="ipywidgets" />
<item index="168" class="java.lang.String" itemvalue="scipy" />
<item index="169" class="java.lang.String" itemvalue="thop" />
<item index="170" class="java.lang.String" itemvalue="tornado" />
<item index="171" class="java.lang.String" itemvalue="google-auth-oauthlib" />
<item index="172" class="java.lang.String" itemvalue="opencv-python" />
<item index="173" class="java.lang.String" itemvalue="torch" />
<item index="174" class="java.lang.String" itemvalue="singledispatch" />
<item index="175" class="java.lang.String" itemvalue="sortedcontainers" />
<item index="176" class="java.lang.String" itemvalue="mistune" />
<item index="177" class="java.lang.String" itemvalue="pandas" />
<item index="178" class="java.lang.String" itemvalue="termcolor" />
<item index="179" class="java.lang.String" itemvalue="clang" />
<item index="180" class="java.lang.String" itemvalue="toolz" />
<item index="181" class="java.lang.String" itemvalue="Sphinx" />
<item index="182" class="java.lang.String" itemvalue="mpmath" />
<item index="183" class="java.lang.String" itemvalue="jupyter-console" />
<item index="184" class="java.lang.String" itemvalue="bokeh" />
<item index="185" class="java.lang.String" itemvalue="cachetools" />
<item index="186" class="java.lang.String" itemvalue="gmpy2" />
<item index="187" class="java.lang.String" itemvalue="setproctitle" />
<item index="188" class="java.lang.String" itemvalue="webencodings" />
<item index="189" class="java.lang.String" itemvalue="html5lib" />
<item index="190" class="java.lang.String" itemvalue="colorlog" />
<item index="191" class="java.lang.String" itemvalue="python-dateutil" />
<item index="192" class="java.lang.String" itemvalue="QtPy" />
<item index="193" class="java.lang.String" itemvalue="astroid" />
<item index="194" class="java.lang.String" itemvalue="cycler" />
<item index="195" class="java.lang.String" itemvalue="mkl-random" />
<item index="196" class="java.lang.String" itemvalue="pytest-arraydiff" />
<item index="197" class="java.lang.String" itemvalue="locket" />
<item index="198" class="java.lang.String" itemvalue="heapdict" />
<item index="199" class="java.lang.String" itemvalue="snowballstemmer" />
<item index="200" class="java.lang.String" itemvalue="contextlib2" />
<item index="201" class="java.lang.String" itemvalue="certifi" />
<item index="202" class="java.lang.String" itemvalue="Markdown" />
<item index="203" class="java.lang.String" itemvalue="sympy" />
<item index="204" class="java.lang.String" itemvalue="notebook" />
<item index="205" class="java.lang.String" itemvalue="pyodbc" />
<item index="206" class="java.lang.String" itemvalue="boto" />
<item index="207" class="java.lang.String" itemvalue="cligj" />
<item index="208" class="java.lang.String" itemvalue="h5py" />
<item index="209" class="java.lang.String" itemvalue="wrapt" />
<item index="210" class="java.lang.String" itemvalue="kiwisolver" />
<item index="211" class="java.lang.String" itemvalue="pytest-openfiles" />
<item index="212" class="java.lang.String" itemvalue="anaconda-client" />
<item index="213" class="java.lang.String" itemvalue="backcall" />
<item index="214" class="java.lang.String" itemvalue="PySocks" />
<item index="215" class="java.lang.String" itemvalue="charset-normalizer" />
<item index="216" class="java.lang.String" itemvalue="typing" />
<item index="217" class="java.lang.String" itemvalue="dask" />
<item index="218" class="java.lang.String" itemvalue="enum34" />
<item index="219" class="java.lang.String" itemvalue="torchsummary" />
<item index="220" class="java.lang.String" itemvalue="scikit-image" />
<item index="221" class="java.lang.String" itemvalue="ptyprocess" />
<item index="222" class="java.lang.String" itemvalue="more-itertools" />
<item index="223" class="java.lang.String" itemvalue="SQLAlchemy" />
<item index="224" class="java.lang.String" itemvalue="tblib" />
<item index="225" class="java.lang.String" itemvalue="cloudpickle" />
<item index="226" class="java.lang.String" itemvalue="importlib-metadata" />
<item index="227" class="java.lang.String" itemvalue="simplegeneric" />
<item index="228" class="java.lang.String" itemvalue="zict" />
<item index="229" class="java.lang.String" itemvalue="urllib3" />
<item index="230" class="java.lang.String" itemvalue="jupyterlab" />
<item index="231" class="java.lang.String" itemvalue="Cython" />
<item index="232" class="java.lang.String" itemvalue="Flask" />
<item index="233" class="java.lang.String" itemvalue="nose" />
<item index="234" class="java.lang.String" itemvalue="pytorch-msssim" />
<item index="235" class="java.lang.String" itemvalue="pytest" />
<item index="236" class="java.lang.String" itemvalue="nbformat" />
<item index="237" class="java.lang.String" itemvalue="matmul" />
<item index="238" class="java.lang.String" itemvalue="tqdm" />
<item index="239" class="java.lang.String" itemvalue="lazy-object-proxy" />
<item index="240" class="java.lang.String" itemvalue="colorama" />
<item index="241" class="java.lang.String" itemvalue="grpcio" />
<item index="242" class="java.lang.String" itemvalue="ply" />
<item index="243" class="java.lang.String" itemvalue="google-auth" />
<item index="244" class="java.lang.String" itemvalue="openpyxl" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

16
.idea/misc.xml Normal file
View File

@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.9 (flaskTest)" />
</component>
<component name="MavenImportPreferences">
<option name="generalSettings">
<MavenGeneralSettings>
<option name="localRepository" value="E:\maven\repository" />
<option name="showDialogWithAdvancedSettings" value="true" />
<option name="userSettingsFile" value="E:\maven\settings.xml" />
</MavenGeneralSettings>
</option>
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/PFCFuse.iml" filepath="$PROJECT_DIR$/.idea/PFCFuse.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

171
componets/DEConv.py Normal file
View File

@ -0,0 +1,171 @@
# https://github.com/cecret3350/DEA-Net/blob/main/code/model/modules/deconv.py
import math
import torch
from torch import nn
from einops.layers.torch import Rearrange
"""
在深度学习和图像处理领域"vanilla" "difference" 卷积是两种不同的卷积操作它们各自有不同的特性和应用场景DEConv细节增强卷积的设计思想是结合这两种卷积的特性来增强模型的性能尤其是在图像去雾等任务中
Vanilla Convolution标准卷积
"Vanilla" 卷积是最基本的卷积类型通常仅称为卷积它是卷积神经网络CNN中最常用的组件用于提取输入数据如图像的特征
标准卷积通过在输入数据上滑动小的可学习的过滤器或称为核并计算过滤器与数据的局部区域之间的点乘来工作通过这种方式它能够捕获输入数据的局部模式和特征
Difference Convolution差分卷积
差分卷积是一种特殊类型的卷积它专注于捕捉输入数据中的局部差异信息例如边缘或纹理的变化
它通过修改标准卷积核的权重或者通过特殊的操作来实现使得网络更加关注于图像的高频信息即图像中的细节和纹理变化在图像处理任务中如图像去雾图像增强边缘检测等捕获这种高频信息非常重要因为它们往往包含了关于物体边界和结构的关键信息
重参数化技术
重参数化技术是一种参数转换方法它允许模型在不增加额外参数和计算代价的情况下实现更复杂的功能或改善性能在DEConv的上下文中重参数化技术使得将vanilla卷积和difference卷积结合起来的操作可以等价地转换成一个标准的卷积操作
这意味着DEConv可以在不增加额外参数和计算成本的情况下通过巧妙地设计卷积核权重同时利用标准卷积和差分卷积的优势从而增强模型处理图像的能力
具体来说通过重参数化可以将差分卷积的效果整合到一个卷积核中使得这个卷积核既能捕获图像的基本特征通过标准卷积部分也能强调图像的细节和差异信息通过差分卷积部分
这种方法特别适用于那些需要同时考虑全局内容和局部细节信息的任务如图像去雾其中既需要理解图像的整体结构也需要恢复由于雾导致的细节丢失
重参数化技术的关键优势在于它允许模型在维持参数数量和计算复杂度不变的前提下实现更为复杂或更为精细的功能
"""
class Conv2d_cd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_cd, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
# conv_weight_cd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_cd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_cd[:, :, :] = conv_weight[:, :, :]
conv_weight_cd[:, :, 4] = conv_weight[:, :, 4] - conv_weight[:, :, :].sum(2)
conv_weight_cd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
conv_weight_cd)
return conv_weight_cd, self.conv.bias
class Conv2d_ad(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_ad, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
conv_weight_ad = conv_weight - self.theta * conv_weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]
conv_weight_ad = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
conv_weight_ad)
return conv_weight_ad, self.conv.bias
class Conv2d_rd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=2, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_rd, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.theta = theta
def forward(self, x):
if math.fabs(self.theta - 0.0) < 1e-8:
out_normal = self.conv(x)
return out_normal
else:
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
if conv_weight.is_cuda:
conv_weight_rd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 5 * 5).fill_(0)
else:
conv_weight_rd = torch.zeros(conv_shape[0], conv_shape[1], 5 * 5)
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
conv_weight_rd[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = conv_weight[:, :, 1:]
conv_weight_rd[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -conv_weight[:, :, 1:] * self.theta
conv_weight_rd[:, :, 12] = conv_weight[:, :, 0] * (1 - self.theta)
conv_weight_rd = conv_weight_rd.view(conv_shape[0], conv_shape[1], 5, 5)
out_diff = nn.functional.conv2d(input=x, weight=conv_weight_rd, bias=self.conv.bias,
stride=self.conv.stride, padding=self.conv.padding, groups=self.conv.groups)
return out_diff
class Conv2d_hd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
super(Conv2d_hd, self).__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
# conv_weight_hd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_hd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_hd[:, :, [0, 3, 6]] = conv_weight[:, :, :]
conv_weight_hd[:, :, [2, 5, 8]] = -conv_weight[:, :, :]
conv_weight_hd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
conv_weight_hd)
return conv_weight_hd, self.conv.bias
class Conv2d_vd(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
padding=1, dilation=1, groups=1, bias=False):
super(Conv2d_vd, self).__init__()
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
def get_weight(self):
conv_weight = self.conv.weight
conv_shape = conv_weight.shape
# conv_weight_vd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_vd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
conv_weight_vd[:, :, [0, 1, 2]] = conv_weight[:, :, :]
conv_weight_vd[:, :, [6, 7, 8]] = -conv_weight[:, :, :]
conv_weight_vd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
conv_weight_vd)
return conv_weight_vd, self.conv.bias
class DEConv(nn.Module):
def __init__(self, dim):
super(DEConv, self).__init__()
self.conv1_1 = Conv2d_cd(dim, dim, 3, bias=True)
self.conv1_2 = Conv2d_hd(dim, dim, 3, bias=True)
self.conv1_3 = Conv2d_vd(dim, dim, 3, bias=True)
self.conv1_4 = Conv2d_ad(dim, dim, 3, bias=True)
self.conv1_5 = nn.Conv2d(dim, dim, 3, padding=1, bias=True)
def forward(self, x):
w1, b1 = self.conv1_1.get_weight()
w2, b2 = self.conv1_2.get_weight()
w3, b3 = self.conv1_3.get_weight()
w4, b4 = self.conv1_4.get_weight()
w5, b5 = self.conv1_5.weight, self.conv1_5.bias
w = w1 + w2 + w3 + w4 + w5
b = b1 + b2 + b3 + b4 + b5
res = nn.functional.conv2d(input=x, weight=w, bias=b, stride=1, padding=1, groups=1)
return res
if __name__ == '__main__':
# 初始化DEConv模块dim为输入和输出的通道数
block = DEConv(dim=16)
# 创建一个随机输入张量,假设输入尺寸为(batch_size, channels, height, width)
input_tensor = torch.rand(4, 16, 64, 64)
# 将输入传递给DEConv模块
output_tensor = block(input_tensor)
# 打印输入和输出张量的尺寸
print("输入尺寸:", input_tensor.size())
print("输出尺寸:", output_tensor.size())

37
componets/SEBlock.py Normal file
View File

@ -0,0 +1,37 @@
'''-------------一、SE模块-----------------------------'''
import torch
from torch import nn
# 全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
class SE_Block(nn.Module):
def __init__(self, inchannel, ratio=16):
super(SE_Block, self).__init__()
# 全局平均池化(Fsq操作)
self.gap = nn.AdaptiveAvgPool2d((1, 1))
# 两个全连接层(Fex操作)
self.fc = nn.Sequential(
nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r
nn.ReLU(),
nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c
nn.Sigmoid()
)
def forward(self, x):
# 读取批数据图片数量及通道数
b, c, h, w = x.size()
# Fsq操作经池化后输出b*c的矩阵
y = self.gap(x).view(b, c)
# Fex操作经全连接层输出bc11矩阵
y = self.fc(y).view(b, c, 1, 1)
# Fscale操作将得到的权重乘以原来的特征图x
return x * y.expand_as(x)
if __name__ == '__main__':
input = torch.randn(1, 64, 32, 32)
seblock = SE_Block(64)
print(seblock)
output = seblock(input)
print(input.shape)
print(output.shape)

110
componets/TIAM(CV).py Normal file
View File

@ -0,0 +1,110 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
"""Elsevier2024
变化检测 (CD) 是地球观测中一种重要的监测方法尤其适用于土地利用分析城市管理和灾害损失评估然而在星座互联和空天协作时代感兴趣区域 (ROI) 的变化由于几何透视旋转和时间风格差异而导致许多错误检测
为了应对这些挑战我们引入了 CDNeXt该框架阐明了一种稳健而有效的方法用于将基于预训练主干的 Siamese 网络与用于遥感图像的创新时空交互注意模块 (TIAM) 相结合
CDNeXt 可分为四个主要组件编码器交互器解码器和检测器值得注意的是 TIAM 提供支持的交互器从编码器提取的二进制时间特征中查询和重建空间透视依赖关系和时间风格相关性以扩大 ROI 变化的差异
最后检测器集成解码器生成的分层特征随后生成二进制变化掩码
"""
class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
super(SpatiotemporalAttentionFullNotWeightShared, self).__init__()
assert dimension in [2, ]
self.dimension = dimension
self.sub_sample = sub_sample
self.in_channels = in_channels
self.inter_channels = inter_channels
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
self.g1 = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
)
self.g2 = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.W1 = nn.Sequential(
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.in_channels)
)
self.W2 = nn.Sequential(
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.in_channels)
)
self.theta = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.phi = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
def forward(self, x1, x2):
"""
:param x: (b, c, h, w)
:param return_nl_map: if True return z, nl_map, else only return z.
:return:
"""
batch_size = x1.size(0)
g_x11 = self.g1(x1).reshape(batch_size, self.inter_channels, -1)
g_x12 = g_x11.permute(0, 2, 1)
g_x21 = self.g2(x2).reshape(batch_size, self.inter_channels, -1)
g_x22 = g_x21.permute(0, 2, 1)
theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)
theta_x2 = theta_x1.permute(0, 2, 1)
phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)
phi_x2 = phi_x1.permute(0, 2, 1)
energy_time_1 = torch.matmul(theta_x1, phi_x2)
energy_time_2 = energy_time_1.permute(0, 2, 1)
energy_space_1 = torch.matmul(theta_x2, phi_x1)
energy_space_2 = energy_space_1.permute(0, 2, 1)
energy_time_1s = F.softmax(energy_time_1, dim=-1)
energy_time_2s = F.softmax(energy_time_2, dim=-1)
energy_space_2s = F.softmax(energy_space_1, dim=-2)
energy_space_1s = F.softmax(energy_space_2, dim=-2)
# C1*S(C2) energy_time_1s * C1*H1W1 g_x12 * energy_space_1s S(H2W2)*H1W1 -> C1*H1W1
y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous() # C2*H2W2
# C2*S(C1) energy_time_2s * C2*H2W2 g_x21 * energy_space_2s S(H1W1)*H2W2 -> C2*H2W2
y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous() # C1*H1W1
y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
return x1 + self.W1(y1), x2 + self.W2(y2)
if __name__ == '__main__':
in_channels = 64
batch_size = 8
height = 32
width = 32
block = SpatiotemporalAttentionFullNotWeightShared(in_channels=in_channels)
input1 = torch.rand(batch_size, in_channels, height, width)
input2 = torch.rand(batch_size, in_channels, height, width)
output1, output2 = block(input1, input2)
print(f"Input1 size: {input1.size()}")
print(f"Input2 size: {input2.size()}")
print(f"Output1 size: {output1.size()}")
print(f"Output2 size: {output2.size()}")

42
componets/whaiutil.py Normal file
View File

@ -0,0 +1,42 @@
import os
from PIL import Image
def transfer(input_path, quality=20, resize_factor=0.1):
# 打开TIFF图像
# img = Image.open(input_path)
#
# # 保存为JPEG并设置压缩质量
# img.save(output_path, 'JPEG', quality=quality)
# input_path = os.path.join(input_folder, filename)
# 获取input_path的文件名
# 使用os.path.splitext获取文件名和后缀的元组
# 使用os.path.basename获取文件名包含后缀
filename_with_extension = os.path.basename(input_path)
filename, file_extension = os.path.splitext(filename_with_extension)
# 使用os.path.dirname获取文件所在的目录路径
output_folder = os.path.dirname(input_path)
output_path = os.path.join(output_folder, filename + '.jpg')
img = Image.open(input_path)
# 将图像缩小到原来的一半
new_width = int(img.width * resize_factor)
new_height = int(img.height * resize_factor)
resized_img = img.resize((new_width, new_height))
# 保存为JPEG并设置压缩质量
# 转换为RGB模式丢弃透明通道
rgb_img = resized_img.convert('RGB')
# 保存为JPEG并设置压缩质量
# 压缩
rgb_img.save(output_path, 'JPEG', quality=quality)
print(f'{output_path} 转换完成')
return output_path

35
logs/20241005.log Normal file
View File

@ -0,0 +1,35 @@
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
# base pcffuse
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 2.39 33.82 11.32 0.81 0.8 0.12 0.07 0.11
================================================================================
Process finished with exit code 0

33
logs/20241007_whai.log Normal file
View File

@ -0,0 +1,33 @@
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.01 40.67 15.39 1.53 1.76 0.64 0.53 0.95
================================================================================

View File

@ -0,0 +1,89 @@
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.14 46.48 13.18 2.22 1.76 0.79 0.56 1.02
================================================================================
================================================================================
The test result of RoadScene :
FLIR_07206.jpg
FLIR_08202.jpg
FLIR_05893.jpg
FLIR_06974.jpg
FLIR_04424.jpg
FLIR_08284.jpg
FLIR_07786.jpg
FLIR_08021.jpg
FLIR_07968.jpg
FLIR_01130.jpg
FLIR_06993.jpg
FLIR_07190.jpg
FLIR_06570.jpg
FLIR_07809.jpg
FLIR_06430.jpg
FLIR_08592.jpg
FLIR_00211.jpg
FLIR_08721.jpg
FLIR_05955.jpg
FLIR_04688.jpg
FLIR_07732.jpg
FLIR_06392.jpg
FLIR_00977.jpg
FLIR_05105.jpg
FLIR_04269.jpg
FLIR_07970.jpg
FLIR_05005.jpg
FLIR_07209.jpg
FLIR_07555.jpg
FLIR_06325.jpg
FLIR_04943.jpg
FLIR_video_02829.jpg
FLIR_08248.jpg
FLIR_04484.jpg
FLIR_08058.jpg
FLIR_06795.jpg
FLIR_06995.jpg
FLIR_05879.jpg
FLIR_04593.jpg
FLIR_08094.jpg
FLIR_08526.jpg
FLIR_08858.jpg
FLIR_09465.jpg
FLIR_05064.jpg
FLIR_05857.jpg
FLIR_05914.jpg
FLIR_04722.jpg
FLIR_06506.jpg
FLIR_06282.jpg
FLIR_04512.jpg
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.41 52.99 15.81 2.37 1.78 0.71 0.55 0.96
================================================================================

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,24 @@
2.4.1+cu121
True
Model: PFCFuse
Number of epochs: 60
Epoch gap: 40
Learning rate: 0.0001
Weight decay: 0
Batch size: 1
GPU number: 0
Coefficient of MSE loss VF: 1.0
Coefficient of MSE loss IF: 1.0
Coefficient of RMI loss VF: 1.0
Coefficient of RMI loss IF: 1.0
Coefficient of Cosine loss VF: 1.0
Coefficient of Cosine loss IF: 1.0
Coefficient of Decomposition loss: 2.0
Coefficient of Total Variation loss: 5.0
Clip gradient norm value: 0.01
Optimization step: 20
Optimization gamma: 0.5
[Epoch 0/60] [Batch 0/6487] [loss: 10.193036] ETA: 9 days, 21 [Epoch 0/60] [Batch 1/6487] [loss: 4.166963] ETA: 11:49:56.5 [Epoch 0/60] [Batch 2/6487] [loss: 10.681509] ETA: 10:23:19.1 [Epoch 0/60] [Batch 3/6487] [loss: 6.257133] ETA: 10:31:48.3 [Epoch 0/60] [Batch 4/6487] [loss: 13.018341] ETA: 10:32:54.2 [Epoch 0/60] [Batch 5/6487] [loss: 11.268185] ETA: 10:27:32.2 [Epoch 0/60] [Batch 6/6487] [loss: 6.920656] ETA: 10:34:01.5 [Epoch 0/60] [Batch 7/6487] [loss: 4.666215] ETA: 10:32:45.3 [Epoch 0/60] [Batch 8/6487] [loss: 10.787085] ETA: 10:26:01.9 [Epoch 0/60] [Batch 9/6487] [loss: 5.754866] ETA: 10:34:34.2 [Epoch 0/60] [Batch 10/6487] [loss: 28.760792] ETA: 10:36:32.6 [Epoch 0/60] [Batch 11/6487] [loss: 8.672796] ETA: 10:25:11.9 [Epoch 0/60] [Batch 12/6487] [loss: 14.300608] ETA: 10:28:19.9 [Epoch 0/60] [Batch 13/6487] [loss: 11.821722] ETA: 10:34:18.6 [Epoch 0/60] [Batch 14/6487] [loss: 7.627745] ETA: 10:31:44.8 [Epoch 0/60] [Batch 15/6487] [loss: 5.722600] ETA: 10:34:17.4 [Epoch 0/60] [Batch 16/6487] [loss: 10.423873] ETA: 11:33:27.1 [Epoch 0/60] [Batch 17/6487] [loss: 4.454098] ETA: 9:37:13.67 [Epoch 0/60] [Batch 18/6487] [loss: 3.820719] ETA: 9:33:57.42 [Epoch 0/60] [Batch 19/6487] [loss: 6.564124] ETA: 9:41:22.09 [Epoch 0/60] [Batch 20/6487] [loss: 5.406681] ETA: 9:47:30.11 [Epoch 0/60] [Batch 21/6487] [loss: 25.275440] ETA: 9:39:29.91 [Epoch 0/60] [Batch 22/6487] [loss: 4.228334] ETA: 9:42:15.45 [Epoch 0/60] [Batch 23/6487] [loss: 22.508118] ETA: 9:38:18.10 [Epoch 0/60] [Batch 24/6487] [loss: 5.062001] ETA: 9:46:09.29 [Epoch 0/60] [Batch 25/6487] [loss: 3.157355] ETA: 9:41:30.09 [Epoch 0/60] [Batch 26/6487] [loss: 6.438435] ETA: 10:02:51.9 [Epoch 0/60] [Batch 27/6487] [loss: 7.430470] ETA: 9:18:12.94 [Epoch 0/60] [Batch 28/6487] [loss: 3.783903] ETA: 10:41:13.9 [Epoch 0/60] [Batch 29/6487] [loss: 2.954306] ETA: 9:44:25.10 [Epoch 0/60] [Batch 30/6487] [loss: 5.863827] ETA: 9:35:13.84 [Epoch 0/60] [Batch 31/6487] [loss: 6.467144] ETA: 9:46:19.80 [Epoch 0/60] [Batch 32/6487] [loss: 4.801052] ETA: 9:32:17.18 [Epoch 0/60] [Batch 33/6487] [loss: 5.658401] ETA: 9:31:10.28 [Epoch 0/60] [Batch 34/6487] [loss: 2.085633] ETA: 9:36:47.39 [Epoch 0/60] [Batch 35/6487] [loss: 15.402915] ETA: 9:40:51.43 [Epoch 0/60] [Batch 36/6487] [loss: 3.181264] ETA: 9:33:06.65 [Epoch 0/60] [Batch 37/6487] [loss: 3.883055] ETA: 9:42:29.60 [Epoch 0/60] [Batch 38/6487] [loss: 3.342676] ETA: 10:07:02.2 [Epoch 0/60] [Batch 39/6487] [loss: 2.589705] ETA: 9:36:43.32 [Epoch 0/60] [Batch 40/6487] [loss: 3.742121] ETA: 9:42:57.54 [Epoch 0/60] [Batch 41/6487] [loss: 2.732829] ETA: 9:36:54.65 [Epoch 0/60] [Batch 42/6487] [loss: 6.655626] ETA: 9:42:20.71 [Epoch 0/60] [Batch 43/6487] [loss: 1.822412] ETA: 9:38:02.02 [Epoch 0/60] [Batch 44/6487] [loss: 2.875143] ETA: 9:41:02.96 [Epoch 0/60] [Batch 45/6487] [loss: 2.319836] ETA: 9:38:16.23 [Epoch 0/60] [Batch 46/6487] [loss: 2.354790] ETA: 9:39:08.93 [Epoch 0/60] [Batch 47/6487] [loss: 1.986412] ETA: 9:52:11.40 [Epoch 0/60] [Batch 48/6487] [loss: 2.154071] ETA: 10:08:20.0 [Epoch 0/60] [Batch 49/6487] [loss: 1.425418] ETA: 9:54:04.42 [Epoch 0/60] [Batch 50/6487] [loss: 1.988360] ETA: 9:30:25.08 [Epoch 0/60] [Batch 51/6487] [loss: 4.090429] ETA: 9:43:53.52 [Epoch 0/60] [Batch 52/6487] [loss: 1.924778] ETA: 9:46:19.38 [Epoch 0/60] [Batch 53/6487] [loss: 2.191964] ETA: 9:46:59.93 [Epoch 0/60] [Batch 54/6487] [loss: 2.032799] ETA: 9:46:14.01 [Epoch 0/60] [Batch 55/6487] [loss: 1.923933] ETA: 9:44:21.65 [Epoch 0/60] [Batch 56/6487] [loss: 1.656838] ETA: 9:56:15.90 [Epoch 0/60] [Batch 57/6487] [loss: 1.656845] ETA: 10:21:26.1 [Epoch 0/60] [Batch 58/6487] [loss: 1.157820] ETA: 10:44:47.3 [Epoch 0/60] [Batch 59/6487] [loss: 1.652715] ETA: 10:46:39.9 [Epoch 0/60] [Batch 60/6487] [loss: 1.633865] ETA: 10:23:34.7 [Epoch 0/60] [Batch 61/6487] [loss: 1.290259] ETA: 9:24:06.12Traceback (most recent call last):
File "/home/star/whaiDir/PFCFuse/train.py", line 232, in <module>
loss.item(),
KeyboardInterrupt

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,38 @@
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.01 40.4 15.51 1.55 1.75 0.66 0.54 0.96
================================================================================
Process finished with exit code 0

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

247
net.py
View File

@ -6,10 +6,7 @@ import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange from einops import rearrange
from componets.WTConvCV2 import WTConv2d
# 以一定概率随机丢弃输入张量中的路径,用于正则化模型
def drop_path(x, drop_prob: float = 0., training: bool = False): def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training: if drop_prob == 0. or not training:
return x return x
@ -35,9 +32,6 @@ class DropPath(nn.Module):
def forward(self, x): def forward(self, x):
return drop_path(x, self.drop_prob, self.training) return drop_path(x, self.drop_prob, self.training)
# 改点使用Pooling替换AttentionBase
class Pooling(nn.Module): class Pooling(nn.Module):
def __init__(self, kernel_size=3): def __init__(self, kernel_size=3):
super().__init__() super().__init__()
@ -50,8 +44,8 @@ class Pooling(nn.Module):
class PoolMlp(nn.Module): class PoolMlp(nn.Module):
""" """
实现基于1x1卷积的MLP模块 Implementation of MLP with 1*1 convolutions.
输入形状为[B, C, H, W]的张量 Input: tensor with shape [B, C, H, W]
""" """
def __init__(self, def __init__(self,
@ -61,17 +55,6 @@ class PoolMlp(nn.Module):
act_layer=nn.GELU, act_layer=nn.GELU,
bias=False, bias=False,
drop=0.): drop=0.):
"""
初始化PoolMlp模块
参数:
in_features (int): 输入特征的数量
hidden_features (int, 可选): 隐藏层特征的数量默认为None设置为与in_features相同
out_features (int, 可选): 输出特征的数量默认为None设置为与in_features相同
act_layer (nn.Module, 可选): 使用的激活层默认为nn.GELU
bias (bool, 可选): 是否在卷积层中包含偏置项默认为False
drop (float, 可选): Dropout比率默认为0
"""
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
hidden_features = hidden_features or in_features hidden_features = hidden_features or in_features
@ -81,15 +64,6 @@ class PoolMlp(nn.Module):
self.drop = nn.Dropout(drop) self.drop = nn.Dropout(drop)
def forward(self, x): def forward(self, x):
"""
通过PoolMlp模块的前向传播
参数:
x (torch.Tensor): 形状为[B, C, H, W]的输入张量
返回:
torch.Tensor: 形状为[B, C, H, W]的输出张量
"""
x = self.fc1(x) # (B, C, H, W) --> (B, C, H, W) x = self.fc1(x) # (B, C, H, W) --> (B, C, H, W)
x = self.act(x) x = self.act(x)
x = self.drop(x) x = self.drop(x)
@ -98,55 +72,7 @@ class PoolMlp(nn.Module):
return x return x
# class BaseFeatureExtraction1(nn.Module): class BaseFeatureFusion(nn.Module):
# def __init__(self, dim, pool_size=3, mlp_ratio=4.,
# act_layer=nn.GELU,
# # norm_layer=nn.LayerNorm,
# drop=0., drop_path=0.,
# use_layer_scale=True, layer_scale_init_value=1e-5):
#
# super().__init__()
#
# self.norm1 = LayerNorm(dim, 'WithBias')
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
# self.norm2 = LayerNorm(dim, 'WithBias')
# mlp_hidden_dim = int(dim * mlp_ratio)
# self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
# act_layer=act_layer, drop=drop)
#
# # The following two techniques are useful to train deep PoolFormers.
# self.drop_path = DropPath(drop_path) if drop_path > 0. \
# else nn.Identity()
# self.use_layer_scale = use_layer_scale
#
# if use_layer_scale:
# self.layer_scale_1 = nn.Parameter(
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
#
# self.layer_scale_2 = nn.Parameter(
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
#
# def forward(self, x): # 1 64 128 128
# if self.use_layer_scale:
# # self.layer_scale_1(64,)
# tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
# normal = self.norm1(x) # 1 64 128 128
# token_mix = self.token_mixer(normal) # 1 64 128 128
# x = (x +
# self.drop_path(
# tmp1 * token_mix
# )
# # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
# )
# x = x + self.drop_path(
# self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
# * self.poolmlp(self.norm2(x)))
# else:
# x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
# x = x + self.drop_path(self.poolmlp(self.norm2(x)))
# return x
class BaseFeatureExtraction(nn.Module):
def __init__(self, dim, pool_size=3, mlp_ratio=4., def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU, act_layer=nn.GELU,
# norm_layer=nn.LayerNorm, # norm_layer=nn.LayerNorm,
@ -155,7 +81,6 @@ class BaseFeatureExtraction(nn.Module):
super().__init__() super().__init__()
self.WTConv2d = WTConv2d(dim, dim)
self.norm1 = LayerNorm(dim, 'WithBias') self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代 self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias') self.norm2 = LayerNorm(dim, 'WithBias')
@ -175,29 +100,103 @@ class BaseFeatureExtraction(nn.Module):
self.layer_scale_2 = nn.Parameter( self.layer_scale_2 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value) torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
def forward(self, x): # 1 64 128 128 def forward(self, x):
if self.use_layer_scale: if self.use_layer_scale:
# self.layer_scale_1(64,) x = x + self.drop_path(
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1 self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
normal = self.norm1(x) # 1 64 128 128 * self.token_mixer(self.norm1(x)))
token_mix = self.token_mixer(normal) # 1 64 128 128
x = self.WTConv2d(x)
x = (x +
self.drop_path(
tmp1 * token_mix
)
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
)
x = x + self.drop_path( x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.poolmlp(self.norm2(x))) * self.poolmlp(self.norm2(x)))
else: else:
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse x = x + self.drop_path(self.token_mixer(self.norm1(x)))
x = x + self.drop_path(self.poolmlp(self.norm2(x))) x = x + self.drop_path(self.poolmlp(self.norm2(x)))
return x return x
class BaseFeatureExtraction(nn.Module):
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU,
# norm_layer=nn.LayerNorm,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio)
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
# The following two techniques are useful to train deep PoolFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
def forward(self, x):
if self.use_layer_scale:
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
* self.token_mixer(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.poolmlp(self.norm2(x)))
else:
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
return x
class BaseFeatureExtractionSAR(nn.Module):
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU,
# norm_layer=nn.LayerNorm,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
super().__init__()
self.norm1 = LayerNorm(dim, 'WithBias')
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msaMLPs是mlp这个用pool来替代
self.norm2 = LayerNorm(dim, 'WithBias')
mlp_hidden_dim = int(dim * mlp_ratio)
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
# The following two techniques are useful to train deep PoolFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
def forward(self, x):
if self.use_layer_scale:
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
* self.token_mixer(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.poolmlp(self.norm2(x)))
else:
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
return x
class InvertedResidualBlock(nn.Module): class InvertedResidualBlock(nn.Module):
def __init__(self, inp, oup, expand_ratio): def __init__(self, inp, oup, expand_ratio):
super(InvertedResidualBlock, self).__init__() super(InvertedResidualBlock, self).__init__()
@ -216,12 +215,12 @@ class InvertedResidualBlock(nn.Module):
nn.Conv2d(hidden_dim, oup, 1, bias=False), nn.Conv2d(hidden_dim, oup, 1, bias=False),
# nn.BatchNorm2d(oup), # nn.BatchNorm2d(oup),
) )
def forward(self, x): def forward(self, x):
return self.bottleneckBlock(x) return self.bottleneckBlock(x)
class DetailNode(nn.Module):
# <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > ' class DetailNode(nn.Module):
def __init__(self): def __init__(self):
super(DetailNode, self).__init__() super(DetailNode, self).__init__()
@ -242,30 +241,44 @@ class DetailNode(nn.Module):
z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2) z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
return z1, z2 return z1, z2
class DetailFeatureFusion(nn.Module):
def __init__(self, num_layers=3):
super(DetailFeatureFusion, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules)
def forward(self, x):
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
for layer in self.net:
z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1)
class DetailFeatureExtraction(nn.Module): class DetailFeatureExtraction(nn.Module):
def __init__(self, num_layers=3): def __init__(self, num_layers=3):
super(DetailFeatureExtraction, self).__init__() super(DetailFeatureExtraction, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)] INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules) self.net = nn.Sequential(*INNmodules)
self.enhancement_module = nn.Sequential(
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
)
def forward(self, x): # 1 64 128 128 def forward(self, x):
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128 z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
# 增强并添加残差连接
enhanced_z1 = self.enhancement_module(z1)
enhanced_z2 = self.enhancement_module(z2)
# 残差连接
z1 = z1 + enhanced_z1
z2 = z2 + enhanced_z2
for layer in self.net: for layer in self.net:
z1, z2 = layer(z1, z2) z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1) return torch.cat((z1, z2), dim=1)
class DetailFeatureExtractionSAR(nn.Module):
def __init__(self, num_layers=3):
super(DetailFeatureExtractionSAR, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules)
def forward(self, x):
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
for layer in self.net:
z1, z2 = layer(z1, z2)
return torch.cat((z1, z2), dim=1)
# ============================================================================= # =============================================================================
# ============================================================================= # =============================================================================
@ -447,14 +460,23 @@ class Restormer_Encoder(nn.Module):
*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
self.baseFeature = BaseFeatureExtraction(dim=dim) self.baseFeature = BaseFeatureExtraction(dim=dim)
self.detailFeature = DetailFeatureExtraction() self.detailFeature = DetailFeatureExtraction()
def forward(self, inp_img): self.baseFeatureSar= BaseFeatureExtractionSAR(dim=dim)
self.detailFeatureSar = DetailFeatureExtractionSAR()
def forward(self, inp_img, sar_img=False):
inp_enc_level1 = self.patch_embed(inp_img) inp_enc_level1 = self.patch_embed(inp_img)
out_enc_level1 = self.encoder_level1(inp_enc_level1) out_enc_level1 = self.encoder_level1(inp_enc_level1)
if sar_img:
base_feature = self.baseFeature(out_enc_level1) base_feature = self.baseFeature(out_enc_level1)
detail_feature = self.detailFeature(out_enc_level1) detail_feature = self.detailFeature(out_enc_level1)
else:
base_feature= self.baseFeature(out_enc_level1)
detail_feature = self.detailFeature(out_enc_level1)
return base_feature, detail_feature, out_enc_level1 return base_feature, detail_feature, out_enc_level1
@ -472,7 +494,8 @@ class Restormer_Decoder(nn.Module):
super(Restormer_Decoder, self).__init__() super(Restormer_Decoder, self).__init__()
self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias) self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, self.encoder_level2 = nn.Sequential(
*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
self.output = nn.Sequential( self.output = nn.Sequential(
nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3, nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
@ -499,5 +522,3 @@ if __name__ == '__main__':
window_size = 8 window_size = 8
modelE = Restormer_Encoder().cuda() modelE = Restormer_Encoder().cuda()
modelD = Restormer_Decoder().cuda() modelD = Restormer_Decoder().cuda()
print(modelE)
print(modelD)

136
status.md Normal file
View File

@ -0,0 +1,136 @@
PFCFuse
```angular2html
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.14 46.48 13.18 2.22 1.76 0.79 0.56 1.02
================================================================================
================================================================================
The test result of RoadScene :
FLIR_07206.jpg
FLIR_08202.jpg
FLIR_05893.jpg
FLIR_06974.jpg
FLIR_04424.jpg
FLIR_08284.jpg
FLIR_07786.jpg
FLIR_08021.jpg
FLIR_07968.jpg
FLIR_01130.jpg
FLIR_06993.jpg
FLIR_07190.jpg
FLIR_06570.jpg
FLIR_07809.jpg
FLIR_06430.jpg
FLIR_08592.jpg
FLIR_00211.jpg
FLIR_08721.jpg
FLIR_05955.jpg
FLIR_04688.jpg
FLIR_07732.jpg
FLIR_06392.jpg
FLIR_00977.jpg
FLIR_05105.jpg
FLIR_04269.jpg
FLIR_07970.jpg
FLIR_05005.jpg
FLIR_07209.jpg
FLIR_07555.jpg
FLIR_06325.jpg
FLIR_04943.jpg
FLIR_video_02829.jpg
FLIR_08248.jpg
FLIR_04484.jpg
FLIR_08058.jpg
FLIR_06795.jpg
FLIR_06995.jpg
FLIR_05879.jpg
FLIR_04593.jpg
FLIR_08094.jpg
FLIR_08526.jpg
FLIR_08858.jpg
FLIR_09465.jpg
FLIR_05064.jpg
FLIR_05857.jpg
FLIR_05914.jpg
FLIR_04722.jpg
FLIR_06506.jpg
FLIR_06282.jpg
FLIR_04512.jpg
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.41 52.99 15.81 2.37 1.78 0.71 0.55 0.96
================================================================================
```
20241008
```
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
================================================================================
The test result of TNO :
19.png
05.png
21.png
18.png
15.png
22.png
14.png
13.png
08.png
01.png
02.png
03.png
25.png
17.png
11.png
16.png
06.png
07.png
09.png
10.png
12.png
23.png
24.png
20.png
04.png
EN SD SF MI SCD VIF Qabf SSIM
PFCFuse 7.01 40.4 15.51 1.55 1.75 0.66 0.54 0.96
================================================================================
Process finished with exit code 0
```

View File

@ -1,3 +1,5 @@
import datetime
import cv2 import cv2
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
import os import os
@ -11,16 +13,18 @@ import logging
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL) logging.basicConfig(level=logging.CRITICAL)
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
os.environ["CUDA_VISIBLE_DEVICES"] = "0" os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-20-46.pth" ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-16-21-39.pth"
for dataset_name in ["TNO"]: for dataset_name in ["sar"]:
print("\n"*2+"="*80) print("\n"*2+"="*80)
model_name="PFCFuse " model_name="PFCFuse Enhance 不同FusionLayer"
print("The test result of "+dataset_name+' :') print("The test result of "+dataset_name+' :')
test_folder=os.path.join('/home/star/whaiDir/CDDFuse/test_img/',dataset_name) test_folder = os.path.join('test_img', dataset_name)
test_out_folder=os.path.join('test_result',dataset_name) test_out_folder=os.path.join('test_result',current_time,dataset_name)
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
Encoder = nn.DataParallel(Restormer_Encoder()).to(device) Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
@ -39,7 +43,6 @@ for dataset_name in ["TNO"]:
with torch.no_grad(): with torch.no_grad():
for img_name in os.listdir(os.path.join(test_folder,"ir")): for img_name in os.listdir(os.path.join(test_folder,"ir")):
print(img_name)
data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0 data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0 data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0

BIN
test_img/MRI_CT/CT/11.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_CT/CT/12.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_CT/CT/13.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_CT/CT/14.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_CT/CT/15.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_CT/CT/16.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_CT/CT/17.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

BIN
test_img/MRI_CT/CT/18.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

BIN
test_img/MRI_CT/CT/19.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

BIN
test_img/MRI_CT/CT/20.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

BIN
test_img/MRI_CT/CT/21.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

BIN
test_img/MRI_CT/CT/22.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_CT/CT/23.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

BIN
test_img/MRI_CT/CT/24.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

BIN
test_img/MRI_CT/CT/25.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

BIN
test_img/MRI_CT/CT/26.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_CT/CT/27.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

BIN
test_img/MRI_CT/CT/28.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

BIN
test_img/MRI_CT/CT/29.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

BIN
test_img/MRI_CT/CT/30.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

BIN
test_img/MRI_CT/CT/31.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

BIN
test_img/MRI_CT/MRI/11.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

BIN
test_img/MRI_CT/MRI/12.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 58 KiB

BIN
test_img/MRI_CT/MRI/13.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 60 KiB

BIN
test_img/MRI_CT/MRI/14.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

BIN
test_img/MRI_CT/MRI/15.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 63 KiB

BIN
test_img/MRI_CT/MRI/16.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

BIN
test_img/MRI_CT/MRI/17.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 60 KiB

BIN
test_img/MRI_CT/MRI/18.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 57 KiB

BIN
test_img/MRI_CT/MRI/19.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

BIN
test_img/MRI_CT/MRI/20.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

BIN
test_img/MRI_CT/MRI/21.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 51 KiB

BIN
test_img/MRI_CT/MRI/22.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

BIN
test_img/MRI_CT/MRI/23.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

BIN
test_img/MRI_CT/MRI/24.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

BIN
test_img/MRI_CT/MRI/25.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 55 KiB

BIN
test_img/MRI_CT/MRI/26.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

BIN
test_img/MRI_CT/MRI/27.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

BIN
test_img/MRI_CT/MRI/28.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

BIN
test_img/MRI_CT/MRI/29.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

BIN
test_img/MRI_CT/MRI/30.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

BIN
test_img/MRI_CT/MRI/31.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

BIN
test_img/MRI_PET/MRI/11.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/12.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/13.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

BIN
test_img/MRI_PET/MRI/14.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

BIN
test_img/MRI_PET/MRI/15.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/16.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/17.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/18.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/19.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

BIN
test_img/MRI_PET/MRI/20.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/21.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/22.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 41 KiB

BIN
test_img/MRI_PET/MRI/23.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

BIN
test_img/MRI_PET/MRI/24.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

BIN
test_img/MRI_PET/MRI/25.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

BIN
test_img/MRI_PET/MRI/26.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

BIN
test_img/MRI_PET/MRI/27.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

BIN
test_img/MRI_PET/MRI/28.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

BIN
test_img/MRI_PET/MRI/29.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

BIN
test_img/MRI_PET/MRI/30.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

BIN
test_img/MRI_PET/MRI/31.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

BIN
test_img/MRI_PET/MRI/32.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

BIN
test_img/MRI_PET/MRI/33.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

BIN
test_img/MRI_PET/MRI/34.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/35.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/36.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

BIN
test_img/MRI_PET/MRI/37.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

Some files were not shown because too many files have changed in this diff Show More