Compare commits
No commits in common. "base_修改SCSA分支" and "main" have entirely different histories.
base_修改SCS
...
main
8
.idea/.gitignore
vendored
@ -1,8 +0,0 @@
|
|||||||
# Default ignored files
|
|
||||||
/shelf/
|
|
||||||
/workspace.xml
|
|
||||||
# Editor-based HTTP Client requests
|
|
||||||
/httpRequests/
|
|
||||||
# Datasource local storage ignored files
|
|
||||||
/dataSources/
|
|
||||||
/dataSources.local.xml
|
|
@ -1,12 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<module type="PYTHON_MODULE" version="4">
|
|
||||||
<component name="NewModuleRootManager">
|
|
||||||
<content url="file://$MODULE_DIR$" />
|
|
||||||
<orderEntry type="jdk" jdkName="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" jdkType="Python SDK" />
|
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
|
||||||
</component>
|
|
||||||
<component name="PyDocumentationSettings">
|
|
||||||
<option name="format" value="PLAIN" />
|
|
||||||
<option name="myDocStringFormat" value="Plain" />
|
|
||||||
</component>
|
|
||||||
</module>
|
|
@ -1,78 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="PublishConfigData" autoUpload="Always" serverName="star@192.168.50.108:22 password (9)" remoteFilesAllowedToDisappearOnAutoupload="false">
|
|
||||||
<serverData>
|
|
||||||
<paths name="star@192.168.50.108:22 password">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (2)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (3)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (4)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (5)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (6)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (7)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (8)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping deploy="/home/star/whaiDir/PFCFuse" local="$PROJECT_DIR$" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (9)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping deploy="/home/star/whaiDir/PFCFuse" local="$PROJECT_DIR$" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="v100">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
</serverData>
|
|
||||||
<option name="myAutoUpload" value="ALWAYS" />
|
|
||||||
</component>
|
|
||||||
</project>
|
|
@ -1,15 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="GitToolBoxProjectSettings">
|
|
||||||
<option name="commitMessageIssueKeyValidationOverride">
|
|
||||||
<BoolValueOverride>
|
|
||||||
<option name="enabled" value="true" />
|
|
||||||
</BoolValueOverride>
|
|
||||||
</option>
|
|
||||||
<option name="commitMessageValidationEnabledOverride">
|
|
||||||
<BoolValueOverride>
|
|
||||||
<option name="enabled" value="true" />
|
|
||||||
</BoolValueOverride>
|
|
||||||
</option>
|
|
||||||
</component>
|
|
||||||
</project>
|
|
@ -1,264 +0,0 @@
|
|||||||
<component name="InspectionProjectProfileManager">
|
|
||||||
<profile version="1.0">
|
|
||||||
<option name="myName" value="Project Default" />
|
|
||||||
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
|
||||||
<Languages>
|
|
||||||
<language minSize="226" name="Python" />
|
|
||||||
</Languages>
|
|
||||||
</inspection_tool>
|
|
||||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
|
||||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
|
||||||
<option name="ignoredPackages">
|
|
||||||
<value>
|
|
||||||
<list size="245">
|
|
||||||
<item index="0" class="java.lang.String" itemvalue="numba" />
|
|
||||||
<item index="1" class="java.lang.String" itemvalue="tensorflow-estimator" />
|
|
||||||
<item index="2" class="java.lang.String" itemvalue="greenlet" />
|
|
||||||
<item index="3" class="java.lang.String" itemvalue="Babel" />
|
|
||||||
<item index="4" class="java.lang.String" itemvalue="scikit-learn" />
|
|
||||||
<item index="5" class="java.lang.String" itemvalue="testpath" />
|
|
||||||
<item index="6" class="java.lang.String" itemvalue="py" />
|
|
||||||
<item index="7" class="java.lang.String" itemvalue="gitdb" />
|
|
||||||
<item index="8" class="java.lang.String" itemvalue="torchvision" />
|
|
||||||
<item index="9" class="java.lang.String" itemvalue="patsy" />
|
|
||||||
<item index="10" class="java.lang.String" itemvalue="mccabe" />
|
|
||||||
<item index="11" class="java.lang.String" itemvalue="bleach" />
|
|
||||||
<item index="12" class="java.lang.String" itemvalue="lxml" />
|
|
||||||
<item index="13" class="java.lang.String" itemvalue="torchaudio" />
|
|
||||||
<item index="14" class="java.lang.String" itemvalue="jsonschema" />
|
|
||||||
<item index="15" class="java.lang.String" itemvalue="xlrd" />
|
|
||||||
<item index="16" class="java.lang.String" itemvalue="Werkzeug" />
|
|
||||||
<item index="17" class="java.lang.String" itemvalue="anaconda-project" />
|
|
||||||
<item index="18" class="java.lang.String" itemvalue="tensorboard-data-server" />
|
|
||||||
<item index="19" class="java.lang.String" itemvalue="typing-extensions" />
|
|
||||||
<item index="20" class="java.lang.String" itemvalue="click" />
|
|
||||||
<item index="21" class="java.lang.String" itemvalue="regex" />
|
|
||||||
<item index="22" class="java.lang.String" itemvalue="fastcache" />
|
|
||||||
<item index="23" class="java.lang.String" itemvalue="tensorboard" />
|
|
||||||
<item index="24" class="java.lang.String" itemvalue="imageio" />
|
|
||||||
<item index="25" class="java.lang.String" itemvalue="pytest-remotedata" />
|
|
||||||
<item index="26" class="java.lang.String" itemvalue="matplotlib" />
|
|
||||||
<item index="27" class="java.lang.String" itemvalue="idna" />
|
|
||||||
<item index="28" class="java.lang.String" itemvalue="Bottleneck" />
|
|
||||||
<item index="29" class="java.lang.String" itemvalue="rsa" />
|
|
||||||
<item index="30" class="java.lang.String" itemvalue="networkx" />
|
|
||||||
<item index="31" class="java.lang.String" itemvalue="pycurl" />
|
|
||||||
<item index="32" class="java.lang.String" itemvalue="smmap" />
|
|
||||||
<item index="33" class="java.lang.String" itemvalue="pluggy" />
|
|
||||||
<item index="34" class="java.lang.String" itemvalue="cffi" />
|
|
||||||
<item index="35" class="java.lang.String" itemvalue="pep8" />
|
|
||||||
<item index="36" class="java.lang.String" itemvalue="numpy" />
|
|
||||||
<item index="37" class="java.lang.String" itemvalue="jdcal" />
|
|
||||||
<item index="38" class="java.lang.String" itemvalue="alabaster" />
|
|
||||||
<item index="39" class="java.lang.String" itemvalue="jupyter" />
|
|
||||||
<item index="40" class="java.lang.String" itemvalue="pyOpenSSL" />
|
|
||||||
<item index="41" class="java.lang.String" itemvalue="PyWavelets" />
|
|
||||||
<item index="42" class="java.lang.String" itemvalue="prompt-toolkit" />
|
|
||||||
<item index="43" class="java.lang.String" itemvalue="QtAwesome" />
|
|
||||||
<item index="44" class="java.lang.String" itemvalue="msgpack-python" />
|
|
||||||
<item index="45" class="java.lang.String" itemvalue="Flask-Cors" />
|
|
||||||
<item index="46" class="java.lang.String" itemvalue="glob2" />
|
|
||||||
<item index="47" class="java.lang.String" itemvalue="Send2Trash" />
|
|
||||||
<item index="48" class="java.lang.String" itemvalue="imagesize" />
|
|
||||||
<item index="49" class="java.lang.String" itemvalue="et-xmlfile" />
|
|
||||||
<item index="50" class="java.lang.String" itemvalue="pathlib2" />
|
|
||||||
<item index="51" class="java.lang.String" itemvalue="docker-pycreds" />
|
|
||||||
<item index="52" class="java.lang.String" itemvalue="importlib-resources" />
|
|
||||||
<item index="53" class="java.lang.String" itemvalue="pathtools" />
|
|
||||||
<item index="54" class="java.lang.String" itemvalue="spyder" />
|
|
||||||
<item index="55" class="java.lang.String" itemvalue="pylint" />
|
|
||||||
<item index="56" class="java.lang.String" itemvalue="statsmodels" />
|
|
||||||
<item index="57" class="java.lang.String" itemvalue="tensorboardX" />
|
|
||||||
<item index="58" class="java.lang.String" itemvalue="isort" />
|
|
||||||
<item index="59" class="java.lang.String" itemvalue="ruamel_yaml" />
|
|
||||||
<item index="60" class="java.lang.String" itemvalue="pytz" />
|
|
||||||
<item index="61" class="java.lang.String" itemvalue="unicodecsv" />
|
|
||||||
<item index="62" class="java.lang.String" itemvalue="pytest-astropy" />
|
|
||||||
<item index="63" class="java.lang.String" itemvalue="traitlets" />
|
|
||||||
<item index="64" class="java.lang.String" itemvalue="absl-py" />
|
|
||||||
<item index="65" class="java.lang.String" itemvalue="protobuf" />
|
|
||||||
<item index="66" class="java.lang.String" itemvalue="nltk" />
|
|
||||||
<item index="67" class="java.lang.String" itemvalue="partd" />
|
|
||||||
<item index="68" class="java.lang.String" itemvalue="promise" />
|
|
||||||
<item index="69" class="java.lang.String" itemvalue="gast" />
|
|
||||||
<item index="70" class="java.lang.String" itemvalue="filelock" />
|
|
||||||
<item index="71" class="java.lang.String" itemvalue="numpydoc" />
|
|
||||||
<item index="72" class="java.lang.String" itemvalue="pyzmq" />
|
|
||||||
<item index="73" class="java.lang.String" itemvalue="oauthlib" />
|
|
||||||
<item index="74" class="java.lang.String" itemvalue="astropy" />
|
|
||||||
<item index="75" class="java.lang.String" itemvalue="keras" />
|
|
||||||
<item index="76" class="java.lang.String" itemvalue="entrypoints" />
|
|
||||||
<item index="77" class="java.lang.String" itemvalue="bkcharts" />
|
|
||||||
<item index="78" class="java.lang.String" itemvalue="pyparsing" />
|
|
||||||
<item index="79" class="java.lang.String" itemvalue="munch" />
|
|
||||||
<item index="80" class="java.lang.String" itemvalue="sphinxcontrib-websupport" />
|
|
||||||
<item index="81" class="java.lang.String" itemvalue="beautifulsoup4" />
|
|
||||||
<item index="82" class="java.lang.String" itemvalue="path.py" />
|
|
||||||
<item index="83" class="java.lang.String" itemvalue="clyent" />
|
|
||||||
<item index="84" class="java.lang.String" itemvalue="navigator-updater" />
|
|
||||||
<item index="85" class="java.lang.String" itemvalue="tifffile" />
|
|
||||||
<item index="86" class="java.lang.String" itemvalue="cryptography" />
|
|
||||||
<item index="87" class="java.lang.String" itemvalue="pygdal" />
|
|
||||||
<item index="88" class="java.lang.String" itemvalue="fastrlock" />
|
|
||||||
<item index="89" class="java.lang.String" itemvalue="widgetsnbextension" />
|
|
||||||
<item index="90" class="java.lang.String" itemvalue="multipledispatch" />
|
|
||||||
<item index="91" class="java.lang.String" itemvalue="numexpr" />
|
|
||||||
<item index="92" class="java.lang.String" itemvalue="jupyter-core" />
|
|
||||||
<item index="93" class="java.lang.String" itemvalue="ipython_genutils" />
|
|
||||||
<item index="94" class="java.lang.String" itemvalue="yapf" />
|
|
||||||
<item index="95" class="java.lang.String" itemvalue="rope" />
|
|
||||||
<item index="96" class="java.lang.String" itemvalue="wcwidth" />
|
|
||||||
<item index="97" class="java.lang.String" itemvalue="cupy-cuda110" />
|
|
||||||
<item index="98" class="java.lang.String" itemvalue="llvmlite" />
|
|
||||||
<item index="99" class="java.lang.String" itemvalue="Jinja2" />
|
|
||||||
<item index="100" class="java.lang.String" itemvalue="pycrypto" />
|
|
||||||
<item index="101" class="java.lang.String" itemvalue="Keras-Preprocessing" />
|
|
||||||
<item index="102" class="java.lang.String" itemvalue="ptflops" />
|
|
||||||
<item index="103" class="java.lang.String" itemvalue="cupy-cuda111" />
|
|
||||||
<item index="104" class="java.lang.String" itemvalue="cupy-cuda114" />
|
|
||||||
<item index="105" class="java.lang.String" itemvalue="some-package" />
|
|
||||||
<item index="106" class="java.lang.String" itemvalue="wandb" />
|
|
||||||
<item index="107" class="java.lang.String" itemvalue="netaddr" />
|
|
||||||
<item index="108" class="java.lang.String" itemvalue="sortedcollections" />
|
|
||||||
<item index="109" class="java.lang.String" itemvalue="six" />
|
|
||||||
<item index="110" class="java.lang.String" itemvalue="timm" />
|
|
||||||
<item index="111" class="java.lang.String" itemvalue="pyflakes" />
|
|
||||||
<item index="112" class="java.lang.String" itemvalue="asn1crypto" />
|
|
||||||
<item index="113" class="java.lang.String" itemvalue="parso" />
|
|
||||||
<item index="114" class="java.lang.String" itemvalue="pytest-doctestplus" />
|
|
||||||
<item index="115" class="java.lang.String" itemvalue="ipython" />
|
|
||||||
<item index="116" class="java.lang.String" itemvalue="xlwt" />
|
|
||||||
<item index="117" class="java.lang.String" itemvalue="packaging" />
|
|
||||||
<item index="118" class="java.lang.String" itemvalue="chardet" />
|
|
||||||
<item index="119" class="java.lang.String" itemvalue="jupyterlab-launcher" />
|
|
||||||
<item index="120" class="java.lang.String" itemvalue="click-plugins" />
|
|
||||||
<item index="121" class="java.lang.String" itemvalue="PyYAML" />
|
|
||||||
<item index="122" class="java.lang.String" itemvalue="pickleshare" />
|
|
||||||
<item index="123" class="java.lang.String" itemvalue="pycparser" />
|
|
||||||
<item index="124" class="java.lang.String" itemvalue="pyasn1-modules" />
|
|
||||||
<item index="125" class="java.lang.String" itemvalue="tables" />
|
|
||||||
<item index="126" class="java.lang.String" itemvalue="Pygments" />
|
|
||||||
<item index="127" class="java.lang.String" itemvalue="sentry-sdk" />
|
|
||||||
<item index="128" class="java.lang.String" itemvalue="docutils" />
|
|
||||||
<item index="129" class="java.lang.String" itemvalue="gevent" />
|
|
||||||
<item index="130" class="java.lang.String" itemvalue="shortuuid" />
|
|
||||||
<item index="131" class="java.lang.String" itemvalue="qtconsole" />
|
|
||||||
<item index="132" class="java.lang.String" itemvalue="terminado" />
|
|
||||||
<item index="133" class="java.lang.String" itemvalue="GitPython" />
|
|
||||||
<item index="134" class="java.lang.String" itemvalue="distributed" />
|
|
||||||
<item index="135" class="java.lang.String" itemvalue="jupyter-client" />
|
|
||||||
<item index="136" class="java.lang.String" itemvalue="pexpect" />
|
|
||||||
<item index="137" class="java.lang.String" itemvalue="ipykernel" />
|
|
||||||
<item index="138" class="java.lang.String" itemvalue="nbconvert" />
|
|
||||||
<item index="139" class="java.lang.String" itemvalue="attrs" />
|
|
||||||
<item index="140" class="java.lang.String" itemvalue="psutil" />
|
|
||||||
<item index="141" class="java.lang.String" itemvalue="simplejson" />
|
|
||||||
<item index="142" class="java.lang.String" itemvalue="jedi" />
|
|
||||||
<item index="143" class="java.lang.String" itemvalue="flatbuffers" />
|
|
||||||
<item index="144" class="java.lang.String" itemvalue="cytoolz" />
|
|
||||||
<item index="145" class="java.lang.String" itemvalue="odo" />
|
|
||||||
<item index="146" class="java.lang.String" itemvalue="decorator" />
|
|
||||||
<item index="147" class="java.lang.String" itemvalue="pandocfilters" />
|
|
||||||
<item index="148" class="java.lang.String" itemvalue="backports.shutil-get-terminal-size" />
|
|
||||||
<item index="149" class="java.lang.String" itemvalue="pycodestyle" />
|
|
||||||
<item index="150" class="java.lang.String" itemvalue="pycosat" />
|
|
||||||
<item index="151" class="java.lang.String" itemvalue="pyasn1" />
|
|
||||||
<item index="152" class="java.lang.String" itemvalue="requests" />
|
|
||||||
<item index="153" class="java.lang.String" itemvalue="bitarray" />
|
|
||||||
<item index="154" class="java.lang.String" itemvalue="kornia" />
|
|
||||||
<item index="155" class="java.lang.String" itemvalue="mkl-fft" />
|
|
||||||
<item index="156" class="java.lang.String" itemvalue="tensorflow" />
|
|
||||||
<item index="157" class="java.lang.String" itemvalue="XlsxWriter" />
|
|
||||||
<item index="158" class="java.lang.String" itemvalue="seaborn" />
|
|
||||||
<item index="159" class="java.lang.String" itemvalue="tensorboard-plugin-wit" />
|
|
||||||
<item index="160" class="java.lang.String" itemvalue="blaze" />
|
|
||||||
<item index="161" class="java.lang.String" itemvalue="zipp" />
|
|
||||||
<item index="162" class="java.lang.String" itemvalue="pkginfo" />
|
|
||||||
<item index="163" class="java.lang.String" itemvalue="cached-property" />
|
|
||||||
<item index="164" class="java.lang.String" itemvalue="torchstat" />
|
|
||||||
<item index="165" class="java.lang.String" itemvalue="datashape" />
|
|
||||||
<item index="166" class="java.lang.String" itemvalue="itsdangerous" />
|
|
||||||
<item index="167" class="java.lang.String" itemvalue="ipywidgets" />
|
|
||||||
<item index="168" class="java.lang.String" itemvalue="scipy" />
|
|
||||||
<item index="169" class="java.lang.String" itemvalue="thop" />
|
|
||||||
<item index="170" class="java.lang.String" itemvalue="tornado" />
|
|
||||||
<item index="171" class="java.lang.String" itemvalue="google-auth-oauthlib" />
|
|
||||||
<item index="172" class="java.lang.String" itemvalue="opencv-python" />
|
|
||||||
<item index="173" class="java.lang.String" itemvalue="torch" />
|
|
||||||
<item index="174" class="java.lang.String" itemvalue="singledispatch" />
|
|
||||||
<item index="175" class="java.lang.String" itemvalue="sortedcontainers" />
|
|
||||||
<item index="176" class="java.lang.String" itemvalue="mistune" />
|
|
||||||
<item index="177" class="java.lang.String" itemvalue="pandas" />
|
|
||||||
<item index="178" class="java.lang.String" itemvalue="termcolor" />
|
|
||||||
<item index="179" class="java.lang.String" itemvalue="clang" />
|
|
||||||
<item index="180" class="java.lang.String" itemvalue="toolz" />
|
|
||||||
<item index="181" class="java.lang.String" itemvalue="Sphinx" />
|
|
||||||
<item index="182" class="java.lang.String" itemvalue="mpmath" />
|
|
||||||
<item index="183" class="java.lang.String" itemvalue="jupyter-console" />
|
|
||||||
<item index="184" class="java.lang.String" itemvalue="bokeh" />
|
|
||||||
<item index="185" class="java.lang.String" itemvalue="cachetools" />
|
|
||||||
<item index="186" class="java.lang.String" itemvalue="gmpy2" />
|
|
||||||
<item index="187" class="java.lang.String" itemvalue="setproctitle" />
|
|
||||||
<item index="188" class="java.lang.String" itemvalue="webencodings" />
|
|
||||||
<item index="189" class="java.lang.String" itemvalue="html5lib" />
|
|
||||||
<item index="190" class="java.lang.String" itemvalue="colorlog" />
|
|
||||||
<item index="191" class="java.lang.String" itemvalue="python-dateutil" />
|
|
||||||
<item index="192" class="java.lang.String" itemvalue="QtPy" />
|
|
||||||
<item index="193" class="java.lang.String" itemvalue="astroid" />
|
|
||||||
<item index="194" class="java.lang.String" itemvalue="cycler" />
|
|
||||||
<item index="195" class="java.lang.String" itemvalue="mkl-random" />
|
|
||||||
<item index="196" class="java.lang.String" itemvalue="pytest-arraydiff" />
|
|
||||||
<item index="197" class="java.lang.String" itemvalue="locket" />
|
|
||||||
<item index="198" class="java.lang.String" itemvalue="heapdict" />
|
|
||||||
<item index="199" class="java.lang.String" itemvalue="snowballstemmer" />
|
|
||||||
<item index="200" class="java.lang.String" itemvalue="contextlib2" />
|
|
||||||
<item index="201" class="java.lang.String" itemvalue="certifi" />
|
|
||||||
<item index="202" class="java.lang.String" itemvalue="Markdown" />
|
|
||||||
<item index="203" class="java.lang.String" itemvalue="sympy" />
|
|
||||||
<item index="204" class="java.lang.String" itemvalue="notebook" />
|
|
||||||
<item index="205" class="java.lang.String" itemvalue="pyodbc" />
|
|
||||||
<item index="206" class="java.lang.String" itemvalue="boto" />
|
|
||||||
<item index="207" class="java.lang.String" itemvalue="cligj" />
|
|
||||||
<item index="208" class="java.lang.String" itemvalue="h5py" />
|
|
||||||
<item index="209" class="java.lang.String" itemvalue="wrapt" />
|
|
||||||
<item index="210" class="java.lang.String" itemvalue="kiwisolver" />
|
|
||||||
<item index="211" class="java.lang.String" itemvalue="pytest-openfiles" />
|
|
||||||
<item index="212" class="java.lang.String" itemvalue="anaconda-client" />
|
|
||||||
<item index="213" class="java.lang.String" itemvalue="backcall" />
|
|
||||||
<item index="214" class="java.lang.String" itemvalue="PySocks" />
|
|
||||||
<item index="215" class="java.lang.String" itemvalue="charset-normalizer" />
|
|
||||||
<item index="216" class="java.lang.String" itemvalue="typing" />
|
|
||||||
<item index="217" class="java.lang.String" itemvalue="dask" />
|
|
||||||
<item index="218" class="java.lang.String" itemvalue="enum34" />
|
|
||||||
<item index="219" class="java.lang.String" itemvalue="torchsummary" />
|
|
||||||
<item index="220" class="java.lang.String" itemvalue="scikit-image" />
|
|
||||||
<item index="221" class="java.lang.String" itemvalue="ptyprocess" />
|
|
||||||
<item index="222" class="java.lang.String" itemvalue="more-itertools" />
|
|
||||||
<item index="223" class="java.lang.String" itemvalue="SQLAlchemy" />
|
|
||||||
<item index="224" class="java.lang.String" itemvalue="tblib" />
|
|
||||||
<item index="225" class="java.lang.String" itemvalue="cloudpickle" />
|
|
||||||
<item index="226" class="java.lang.String" itemvalue="importlib-metadata" />
|
|
||||||
<item index="227" class="java.lang.String" itemvalue="simplegeneric" />
|
|
||||||
<item index="228" class="java.lang.String" itemvalue="zict" />
|
|
||||||
<item index="229" class="java.lang.String" itemvalue="urllib3" />
|
|
||||||
<item index="230" class="java.lang.String" itemvalue="jupyterlab" />
|
|
||||||
<item index="231" class="java.lang.String" itemvalue="Cython" />
|
|
||||||
<item index="232" class="java.lang.String" itemvalue="Flask" />
|
|
||||||
<item index="233" class="java.lang.String" itemvalue="nose" />
|
|
||||||
<item index="234" class="java.lang.String" itemvalue="pytorch-msssim" />
|
|
||||||
<item index="235" class="java.lang.String" itemvalue="pytest" />
|
|
||||||
<item index="236" class="java.lang.String" itemvalue="nbformat" />
|
|
||||||
<item index="237" class="java.lang.String" itemvalue="matmul" />
|
|
||||||
<item index="238" class="java.lang.String" itemvalue="tqdm" />
|
|
||||||
<item index="239" class="java.lang.String" itemvalue="lazy-object-proxy" />
|
|
||||||
<item index="240" class="java.lang.String" itemvalue="colorama" />
|
|
||||||
<item index="241" class="java.lang.String" itemvalue="grpcio" />
|
|
||||||
<item index="242" class="java.lang.String" itemvalue="ply" />
|
|
||||||
<item index="243" class="java.lang.String" itemvalue="google-auth" />
|
|
||||||
<item index="244" class="java.lang.String" itemvalue="openpyxl" />
|
|
||||||
</list>
|
|
||||||
</value>
|
|
||||||
</option>
|
|
||||||
</inspection_tool>
|
|
||||||
</profile>
|
|
||||||
</component>
|
|
@ -1,6 +0,0 @@
|
|||||||
<component name="InspectionProjectProfileManager">
|
|
||||||
<settings>
|
|
||||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
|
||||||
<version value="1.0" />
|
|
||||||
</settings>
|
|
||||||
</component>
|
|
@ -1,16 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="Black">
|
|
||||||
<option name="sdkName" value="Python 3.9 (flaskTest)" />
|
|
||||||
</component>
|
|
||||||
<component name="MavenImportPreferences">
|
|
||||||
<option name="generalSettings">
|
|
||||||
<MavenGeneralSettings>
|
|
||||||
<option name="localRepository" value="E:\maven\repository" />
|
|
||||||
<option name="showDialogWithAdvancedSettings" value="true" />
|
|
||||||
<option name="userSettingsFile" value="E:\maven\settings.xml" />
|
|
||||||
</MavenGeneralSettings>
|
|
||||||
</option>
|
|
||||||
</component>
|
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" project-jdk-type="Python SDK" />
|
|
||||||
</project>
|
|
@ -1,8 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="ProjectModuleManager">
|
|
||||||
<modules>
|
|
||||||
<module fileurl="file://$PROJECT_DIR$/.idea/PFCFuse.iml" filepath="$PROJECT_DIR$/.idea/PFCFuse.iml" />
|
|
||||||
</modules>
|
|
||||||
</component>
|
|
||||||
</project>
|
|
@ -1,6 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="VcsDirectoryMappings">
|
|
||||||
<mapping directory="" vcs="Git" />
|
|
||||||
</component>
|
|
||||||
</project>
|
|
@ -1,171 +0,0 @@
|
|||||||
# https://github.com/cecret3350/DEA-Net/blob/main/code/model/modules/deconv.py
|
|
||||||
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from einops.layers.torch import Rearrange
|
|
||||||
|
|
||||||
"""
|
|
||||||
在深度学习和图像处理领域,"vanilla" 和 "difference" 卷积是两种不同的卷积操作,它们各自有不同的特性和应用场景。DEConv(细节增强卷积)的设计思想是结合这两种卷积的特性来增强模型的性能,尤其是在图像去雾等任务中。
|
|
||||||
|
|
||||||
Vanilla Convolution(标准卷积)
|
|
||||||
"Vanilla" 卷积是最基本的卷积类型,通常仅称为卷积。它是卷积神经网络(CNN)中最常用的组件,用于提取输入数据(如图像)的特征。
|
|
||||||
标准卷积通过在输入数据上滑动小的、可学习的过滤器(或称为核),并计算过滤器与数据的局部区域之间的点乘,来工作。通过这种方式,它能够捕获输入数据的局部模式和特征。
|
|
||||||
|
|
||||||
Difference Convolution(差分卷积)
|
|
||||||
差分卷积是一种特殊类型的卷积,它专注于捕捉输入数据中的局部差异信息,例如边缘或纹理的变化。
|
|
||||||
它通过修改标准卷积核的权重或者通过特殊的操作来实现,使得网络更加关注于图像的高频信息,即图像中的细节和纹理变化。在图像处理任务中,如图像去雾、图像增强、边缘检测等,捕获这种高频信息非常重要,因为它们往往包含了关于物体边界和结构的关键信息。
|
|
||||||
|
|
||||||
重参数化技术
|
|
||||||
重参数化技术是一种参数转换方法,它允许模型在不增加额外参数和计算代价的情况下,实现更复杂的功能或改善性能。在DEConv的上下文中,重参数化技术使得将vanilla卷积和difference卷积结合起来的操作,可以等价地转换成一个标准的卷积操作。
|
|
||||||
这意味着DEConv可以在不增加额外参数和计算成本的情况下,通过巧妙地设计卷积核权重,同时利用标准卷积和差分卷积的优势,从而增强模型处理图像的能力。
|
|
||||||
具体来说,通过重参数化,可以将差分卷积的效果整合到一个卷积核中,使得这个卷积核既能捕获图像的基本特征(通过标准卷积部分),也能强调图像的细节和差异信息(通过差分卷积部分)。
|
|
||||||
这种方法特别适用于那些需要同时考虑全局内容和局部细节信息的任务,如图像去雾,其中既需要理解图像的整体结构,也需要恢复由于雾导致的细节丢失。
|
|
||||||
重参数化技术的关键优势在于,它允许模型在维持参数数量和计算复杂度不变的前提下,实现更为复杂或更为精细的功能。
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2d_cd(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
|
||||||
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
|
|
||||||
super(Conv2d_cd, self).__init__()
|
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
dilation=dilation, groups=groups, bias=bias)
|
|
||||||
self.theta = theta
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
conv_weight = self.conv.weight
|
|
||||||
conv_shape = conv_weight.shape
|
|
||||||
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
|
|
||||||
# conv_weight_cd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_cd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_cd[:, :, :] = conv_weight[:, :, :]
|
|
||||||
conv_weight_cd[:, :, 4] = conv_weight[:, :, 4] - conv_weight[:, :, :].sum(2)
|
|
||||||
conv_weight_cd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
|
|
||||||
conv_weight_cd)
|
|
||||||
return conv_weight_cd, self.conv.bias
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2d_ad(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
|
||||||
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
|
|
||||||
super(Conv2d_ad, self).__init__()
|
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
dilation=dilation, groups=groups, bias=bias)
|
|
||||||
self.theta = theta
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
conv_weight = self.conv.weight
|
|
||||||
conv_shape = conv_weight.shape
|
|
||||||
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
|
|
||||||
conv_weight_ad = conv_weight - self.theta * conv_weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]
|
|
||||||
conv_weight_ad = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
|
|
||||||
conv_weight_ad)
|
|
||||||
return conv_weight_ad, self.conv.bias
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2d_rd(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
|
||||||
padding=2, dilation=1, groups=1, bias=False, theta=1.0):
|
|
||||||
|
|
||||||
super(Conv2d_rd, self).__init__()
|
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
dilation=dilation, groups=groups, bias=bias)
|
|
||||||
self.theta = theta
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
|
|
||||||
if math.fabs(self.theta - 0.0) < 1e-8:
|
|
||||||
out_normal = self.conv(x)
|
|
||||||
return out_normal
|
|
||||||
else:
|
|
||||||
conv_weight = self.conv.weight
|
|
||||||
conv_shape = conv_weight.shape
|
|
||||||
if conv_weight.is_cuda:
|
|
||||||
conv_weight_rd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 5 * 5).fill_(0)
|
|
||||||
else:
|
|
||||||
conv_weight_rd = torch.zeros(conv_shape[0], conv_shape[1], 5 * 5)
|
|
||||||
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
|
|
||||||
conv_weight_rd[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = conv_weight[:, :, 1:]
|
|
||||||
conv_weight_rd[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -conv_weight[:, :, 1:] * self.theta
|
|
||||||
conv_weight_rd[:, :, 12] = conv_weight[:, :, 0] * (1 - self.theta)
|
|
||||||
conv_weight_rd = conv_weight_rd.view(conv_shape[0], conv_shape[1], 5, 5)
|
|
||||||
out_diff = nn.functional.conv2d(input=x, weight=conv_weight_rd, bias=self.conv.bias,
|
|
||||||
stride=self.conv.stride, padding=self.conv.padding, groups=self.conv.groups)
|
|
||||||
|
|
||||||
return out_diff
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2d_hd(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
|
||||||
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
|
|
||||||
super(Conv2d_hd, self).__init__()
|
|
||||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
dilation=dilation, groups=groups, bias=bias)
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
conv_weight = self.conv.weight
|
|
||||||
conv_shape = conv_weight.shape
|
|
||||||
# conv_weight_hd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_hd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_hd[:, :, [0, 3, 6]] = conv_weight[:, :, :]
|
|
||||||
conv_weight_hd[:, :, [2, 5, 8]] = -conv_weight[:, :, :]
|
|
||||||
conv_weight_hd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
|
|
||||||
conv_weight_hd)
|
|
||||||
return conv_weight_hd, self.conv.bias
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2d_vd(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
|
||||||
padding=1, dilation=1, groups=1, bias=False):
|
|
||||||
super(Conv2d_vd, self).__init__()
|
|
||||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
dilation=dilation, groups=groups, bias=bias)
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
conv_weight = self.conv.weight
|
|
||||||
conv_shape = conv_weight.shape
|
|
||||||
# conv_weight_vd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_vd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_vd[:, :, [0, 1, 2]] = conv_weight[:, :, :]
|
|
||||||
conv_weight_vd[:, :, [6, 7, 8]] = -conv_weight[:, :, :]
|
|
||||||
conv_weight_vd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
|
|
||||||
conv_weight_vd)
|
|
||||||
return conv_weight_vd, self.conv.bias
|
|
||||||
|
|
||||||
|
|
||||||
class DEConv(nn.Module):
|
|
||||||
def __init__(self, dim):
|
|
||||||
super(DEConv, self).__init__()
|
|
||||||
self.conv1_1 = Conv2d_cd(dim, dim, 3, bias=True)
|
|
||||||
self.conv1_2 = Conv2d_hd(dim, dim, 3, bias=True)
|
|
||||||
self.conv1_3 = Conv2d_vd(dim, dim, 3, bias=True)
|
|
||||||
self.conv1_4 = Conv2d_ad(dim, dim, 3, bias=True)
|
|
||||||
self.conv1_5 = nn.Conv2d(dim, dim, 3, padding=1, bias=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
w1, b1 = self.conv1_1.get_weight()
|
|
||||||
w2, b2 = self.conv1_2.get_weight()
|
|
||||||
w3, b3 = self.conv1_3.get_weight()
|
|
||||||
w4, b4 = self.conv1_4.get_weight()
|
|
||||||
w5, b5 = self.conv1_5.weight, self.conv1_5.bias
|
|
||||||
|
|
||||||
w = w1 + w2 + w3 + w4 + w5
|
|
||||||
b = b1 + b2 + b3 + b4 + b5
|
|
||||||
res = nn.functional.conv2d(input=x, weight=w, bias=b, stride=1, padding=1, groups=1)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# 初始化DEConv模块,dim为输入和输出的通道数
|
|
||||||
block = DEConv(dim=16)
|
|
||||||
# 创建一个随机输入张量,假设输入尺寸为(batch_size, channels, height, width)
|
|
||||||
input_tensor = torch.rand(4, 16, 64, 64)
|
|
||||||
# 将输入传递给DEConv模块
|
|
||||||
output_tensor = block(input_tensor)
|
|
||||||
# 打印输入和输出张量的尺寸
|
|
||||||
print("输入尺寸:", input_tensor.size())
|
|
||||||
print("输出尺寸:", output_tensor.size())
|
|
||||||
|
|
||||||
|
|
@ -1,116 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from timm.layers.helpers import to_2tuple
|
|
||||||
|
|
||||||
"""
|
|
||||||
配备多头自注意力 (MHSA) 的模型在计算机视觉方面取得了显着的性能。它们的计算复杂度与输入特征图中的二次像素数成正比,导致处理速度缓慢,尤其是在处理高分辨率图像时。
|
|
||||||
为了规避这个问题,提出了一种新型的代币混合器作为MHSA的替代方案:基于FFT的代币混合器涉及类似于MHSA的全局操作,但计算复杂度较低。
|
|
||||||
在这里,我们提出了一种名为动态过滤器的新型令牌混合器以缩小上述差距。
|
|
||||||
DynamicFilter 模块通过频域滤波和动态调整滤波器权重,能够对图像进行复杂的增强和处理。
|
|
||||||
"""
|
|
||||||
|
|
||||||
class StarReLU(nn.Module):
|
|
||||||
"""
|
|
||||||
StarReLU: s * relu(x) ** 2 + b
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, scale_value=1.0, bias_value=0.0,
|
|
||||||
scale_learnable=True, bias_learnable=True,
|
|
||||||
mode=None, inplace=False):
|
|
||||||
super().__init__()
|
|
||||||
self.inplace = inplace
|
|
||||||
self.relu = nn.ReLU(inplace=inplace)
|
|
||||||
self.scale = nn.Parameter(scale_value * torch.ones(1),
|
|
||||||
requires_grad=scale_learnable)
|
|
||||||
self.bias = nn.Parameter(bias_value * torch.ones(1),
|
|
||||||
requires_grad=bias_learnable)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.scale * self.relu(x) ** 2 + self.bias
|
|
||||||
|
|
||||||
class Mlp(nn.Module):
|
|
||||||
""" MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks.
|
|
||||||
Mostly copied from timm.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,
|
|
||||||
bias=False, **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
in_features = dim
|
|
||||||
out_features = out_features or in_features
|
|
||||||
hidden_features = int(mlp_ratio * in_features)
|
|
||||||
drop_probs = to_2tuple(drop)
|
|
||||||
|
|
||||||
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
|
||||||
self.act = act_layer()
|
|
||||||
self.drop1 = nn.Dropout(drop_probs[0])
|
|
||||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
|
||||||
self.drop2 = nn.Dropout(drop_probs[1])
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.fc1(x)
|
|
||||||
x = self.act(x)
|
|
||||||
x = self.drop1(x)
|
|
||||||
x = self.fc2(x)
|
|
||||||
x = self.drop2(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicFilter(nn.Module):
|
|
||||||
def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,
|
|
||||||
act1_layer=StarReLU, act2_layer=nn.Identity,
|
|
||||||
bias=False, num_filters=4, size=14, weight_resize=False,
|
|
||||||
**kwargs):
|
|
||||||
super().__init__()
|
|
||||||
size = to_2tuple(size)
|
|
||||||
self.size = size[0]
|
|
||||||
self.filter_size = size[1] // 2 + 1
|
|
||||||
self.num_filters = num_filters
|
|
||||||
self.dim = dim
|
|
||||||
self.med_channels = int(expansion_ratio * dim)
|
|
||||||
self.weight_resize = weight_resize
|
|
||||||
self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)
|
|
||||||
self.act1 = act1_layer()
|
|
||||||
self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)
|
|
||||||
self.complex_weights = nn.Parameter(
|
|
||||||
torch.randn(self.size, self.filter_size, num_filters, 2,
|
|
||||||
dtype=torch.float32) * 0.02)
|
|
||||||
self.act2 = act2_layer()
|
|
||||||
self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
B, H, W, _ = x.shape
|
|
||||||
|
|
||||||
routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,
|
|
||||||
-1).softmax(dim=1)
|
|
||||||
x = self.pwconv1(x)
|
|
||||||
x = self.act1(x)
|
|
||||||
x = x.to(torch.float32)
|
|
||||||
x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
|
|
||||||
|
|
||||||
if self.weight_resize:
|
|
||||||
complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],
|
|
||||||
x.shape[2])
|
|
||||||
complex_weights = torch.view_as_complex(complex_weights.contiguous())
|
|
||||||
else:
|
|
||||||
complex_weights = torch.view_as_complex(self.complex_weights)
|
|
||||||
routeing = routeing.to(torch.complex64)
|
|
||||||
weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)
|
|
||||||
if self.weight_resize:
|
|
||||||
weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)
|
|
||||||
else:
|
|
||||||
weight = weight.view(-1, self.size, self.filter_size, self.med_channels)
|
|
||||||
x = x * weight
|
|
||||||
x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
|
|
||||||
|
|
||||||
x = self.act2(x)
|
|
||||||
x = self.pwconv2(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
block = DynamicFilter(32, size=64) # size==H,W
|
|
||||||
input = torch.rand(3, 64, 64, 32)
|
|
||||||
output = block(input)
|
|
||||||
print(input.size())
|
|
||||||
print(output.size())
|
|
@ -1,156 +0,0 @@
|
|||||||
import typing as t
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from einops.einops import rearrange
|
|
||||||
from mmengine.model import BaseModule
|
|
||||||
__all__ = ['SCSA']
|
|
||||||
|
|
||||||
"""SCSA:探索空间注意力和通道注意力之间的协同作用
|
|
||||||
通道和空间注意力分别在为各种下游视觉任务提取特征依赖性和空间结构关系方面带来了显着的改进。
|
|
||||||
虽然它们的结合更有利于发挥各自的优势,但通道和空间注意力之间的协同作用尚未得到充分探索,缺乏充分利用多语义信息的协同潜力来进行特征引导和缓解语义差异。
|
|
||||||
我们的研究试图在多个语义层面揭示空间和通道注意力之间的协同关系,提出了一种新颖的空间和通道协同注意力模块(SCSA)。我们的SCSA由两部分组成:可共享的多语义空间注意力(SMSA)和渐进式通道自注意力(PCSA)。
|
|
||||||
SMSA 集成多语义信息并利用渐进式压缩策略将判别性空间先验注入 PCSA 的通道自注意力中,有效地指导通道重新校准。此外,PCSA 中基于自注意力机制的稳健特征交互进一步缓解了 SMSA 中不同子特征之间多语义信息的差异。
|
|
||||||
我们在七个基准数据集上进行了广泛的实验,包括 ImageNet-1K 上的分类、MSCOCO 2017 上的对象检测、ADE20K 上的分割以及其他四个复杂场景检测数据集。我们的结果表明,我们提出的 SCSA 不仅超越了当前最先进的注意力机制,
|
|
||||||
而且在各种任务场景中表现出增强的泛化能力。
|
|
||||||
"""
|
|
||||||
|
|
||||||
class SCSA(BaseModule):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
head_num: int,
|
|
||||||
window_size: int = 7,
|
|
||||||
group_kernel_sizes: t.List[int] = [3, 5, 7, 9],
|
|
||||||
qkv_bias: bool = False,
|
|
||||||
fuse_bn: bool = False,
|
|
||||||
norm_cfg: t.Dict = dict(type='BN'),
|
|
||||||
act_cfg: t.Dict = dict(type='ReLU'),
|
|
||||||
down_sample_mode: str = 'avg_pool',
|
|
||||||
attn_drop_ratio: float = 0.,
|
|
||||||
gate_layer: str = 'sigmoid',
|
|
||||||
):
|
|
||||||
super(SCSA, self).__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.head_num = head_num
|
|
||||||
self.head_dim = dim // head_num
|
|
||||||
self.scaler = self.head_dim ** -0.5
|
|
||||||
self.group_kernel_sizes = group_kernel_sizes
|
|
||||||
self.window_size = window_size
|
|
||||||
self.qkv_bias = qkv_bias
|
|
||||||
self.fuse_bn = fuse_bn
|
|
||||||
self.down_sample_mode = down_sample_mode
|
|
||||||
|
|
||||||
assert self.dim // 4, 'The dimension of input feature should be divisible by 4.'
|
|
||||||
self.group_chans = group_chans = self.dim // 4
|
|
||||||
|
|
||||||
self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],
|
|
||||||
padding=group_kernel_sizes[0] // 2, groups=group_chans)
|
|
||||||
self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],
|
|
||||||
padding=group_kernel_sizes[1] // 2, groups=group_chans)
|
|
||||||
self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],
|
|
||||||
padding=group_kernel_sizes[2] // 2, groups=group_chans)
|
|
||||||
self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],
|
|
||||||
padding=group_kernel_sizes[3] // 2, groups=group_chans)
|
|
||||||
self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()
|
|
||||||
self.norm_h = nn.GroupNorm(4, dim)
|
|
||||||
self.norm_w = nn.GroupNorm(4, dim)
|
|
||||||
|
|
||||||
self.conv_d = nn.Identity()
|
|
||||||
self.norm = nn.GroupNorm(1, dim)
|
|
||||||
self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
|
|
||||||
self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
|
|
||||||
self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
|
|
||||||
self.attn_drop = nn.Dropout(attn_drop_ratio)
|
|
||||||
self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()
|
|
||||||
|
|
||||||
if window_size == -1:
|
|
||||||
self.down_func = nn.AdaptiveAvgPool2d((1, 1))
|
|
||||||
else:
|
|
||||||
if down_sample_mode == 'recombination':
|
|
||||||
self.down_func = self.space_to_chans
|
|
||||||
# dimensionality reduction
|
|
||||||
self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False)
|
|
||||||
elif down_sample_mode == 'avg_pool':
|
|
||||||
self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)
|
|
||||||
elif down_sample_mode == 'max_pool':
|
|
||||||
self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
The dim of x is (B, C, H, W)
|
|
||||||
"""
|
|
||||||
# Spatial attention priority calculation
|
|
||||||
b, c, h_, w_ = x.size()
|
|
||||||
# (B, C, H)
|
|
||||||
x_h = x.mean(dim=3)
|
|
||||||
l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1)
|
|
||||||
# (B, C, W)
|
|
||||||
x_w = x.mean(dim=2)
|
|
||||||
l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1)
|
|
||||||
|
|
||||||
x_h_attn = self.sa_gate(self.norm_h(torch.cat((
|
|
||||||
self.local_dwc(l_x_h),
|
|
||||||
self.global_dwc_s(g_x_h_s),
|
|
||||||
self.global_dwc_m(g_x_h_m),
|
|
||||||
self.global_dwc_l(g_x_h_l),
|
|
||||||
), dim=1)))
|
|
||||||
x_h_attn = x_h_attn.view(b, c, h_, 1)
|
|
||||||
|
|
||||||
x_w_attn = self.sa_gate(self.norm_w(torch.cat((
|
|
||||||
self.local_dwc(l_x_w),
|
|
||||||
self.global_dwc_s(g_x_w_s),
|
|
||||||
self.global_dwc_m(g_x_w_m),
|
|
||||||
self.global_dwc_l(g_x_w_l)
|
|
||||||
), dim=1)))
|
|
||||||
x_w_attn = x_w_attn.view(b, c, 1, w_)
|
|
||||||
|
|
||||||
x = x * x_h_attn * x_w_attn
|
|
||||||
|
|
||||||
# Channel attention based on self attention
|
|
||||||
# reduce calculations
|
|
||||||
y = self.down_func(x)
|
|
||||||
y = self.conv_d(y)
|
|
||||||
_, _, h_, w_ = y.size()
|
|
||||||
|
|
||||||
# normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and v
|
|
||||||
y = self.norm(y)
|
|
||||||
q = self.q(y)
|
|
||||||
k = self.k(y)
|
|
||||||
v = self.v(y)
|
|
||||||
# (B, C, H, W) -> (B, head_num, head_dim, N)
|
|
||||||
q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
|
|
||||||
head_dim=int(self.head_dim))
|
|
||||||
k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
|
|
||||||
head_dim=int(self.head_dim))
|
|
||||||
v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
|
|
||||||
head_dim=int(self.head_dim))
|
|
||||||
|
|
||||||
# (B, head_num, head_dim, head_dim)
|
|
||||||
attn = q @ k.transpose(-2, -1) * self.scaler
|
|
||||||
attn = self.attn_drop(attn.softmax(dim=-1))
|
|
||||||
# (B, head_num, head_dim, N)
|
|
||||||
attn = attn @ v
|
|
||||||
# (B, C, H_, W_)
|
|
||||||
attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_))
|
|
||||||
# (B, C, 1, 1)
|
|
||||||
attn = attn.mean((2, 3), keepdim=True)
|
|
||||||
attn = self.ca_gate(attn)
|
|
||||||
return attn * x
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
block = SCSA(
|
|
||||||
dim=256,
|
|
||||||
head_num=8,
|
|
||||||
)
|
|
||||||
|
|
||||||
input_tensor = torch.rand(1, 256, 32, 32)
|
|
||||||
|
|
||||||
# 调用模块进行前向传播
|
|
||||||
output_tensor = block(input_tensor)
|
|
||||||
|
|
||||||
# 打印输入和输出张量的大小
|
|
||||||
print("Input size:", input_tensor.size())
|
|
||||||
print("Output size:", output_tensor.size())
|
|
@ -1,37 +0,0 @@
|
|||||||
'''-------------一、SE模块-----------------------------'''
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
# 全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
|
|
||||||
class SE_Block(nn.Module):
|
|
||||||
def __init__(self, inchannel, ratio=16):
|
|
||||||
super(SE_Block, self).__init__()
|
|
||||||
# 全局平均池化(Fsq操作)
|
|
||||||
self.gap = nn.AdaptiveAvgPool2d((1, 1))
|
|
||||||
# 两个全连接层(Fex操作)
|
|
||||||
self.fc = nn.Sequential(
|
|
||||||
nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c
|
|
||||||
nn.Sigmoid()
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# 读取批数据图片数量及通道数
|
|
||||||
b, c, h, w = x.size()
|
|
||||||
# Fsq操作:经池化后输出b*c的矩阵
|
|
||||||
y = self.gap(x).view(b, c)
|
|
||||||
# Fex操作:经全连接层输出(b,c,1,1)矩阵
|
|
||||||
y = self.fc(y).view(b, c, 1, 1)
|
|
||||||
# Fscale操作:将得到的权重乘以原来的特征图x
|
|
||||||
return x * y.expand_as(x)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
input = torch.randn(1, 64, 32, 32)
|
|
||||||
seblock = SE_Block(64)
|
|
||||||
print(seblock)
|
|
||||||
output = seblock(input)
|
|
||||||
print(input.shape)
|
|
||||||
print(output.shape)
|
|
||||||
|
|
@ -1,65 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
"""ECCV2024(https://github.com/Zheng-MJ/SMFANet)
|
|
||||||
基于Transformer的恢复方法取得了显著的效果,因为Transformer的自注意力机制(SA)可以探索非局部信息,从而实现更好的高分辨率图像重建。然而,关键的点积自注意力需要大量的计算资源,这限制了其在低功耗设备上的应用。
|
|
||||||
此外,自注意力机制的低通滤波特性限制了其捕捉局部细节的能力,从而导致重建结果过于平滑。为了解决这些问题,我们提出了一种自调制特征聚合(SMFA)模块,协同利用局部和非局部特征交互,以实现更精确的重建。
|
|
||||||
具体而言,SMFA模块采用了高效的自注意力近似(EASA)分支来建模非局部信息,并使用局部细节估计(LDE)分支来捕捉局部细节。此外,我们还引入了基于部分卷积的前馈网络(PCFN),以进一步优化从SMFA提取的代表性特征。
|
|
||||||
大量实验表明,所提出的SMFANet系列在公共基准数据集上实现了更好的重建性能与计算效率的平衡。
|
|
||||||
特别是,与SwinIR-light的×4放大相比,SMFANet+在五个公共测试集上的平均性能提高了0.14dB,运行速度提升了约10倍,且模型复杂度(如FLOPs)仅为其约43%。
|
|
||||||
"""
|
|
||||||
|
|
||||||
class DMlp(nn.Module):
|
|
||||||
def __init__(self, dim, growth_rate=2.0):
|
|
||||||
super().__init__()
|
|
||||||
hidden_dim = int(dim * growth_rate)
|
|
||||||
self.conv_0 = nn.Sequential(
|
|
||||||
nn.Conv2d(dim, hidden_dim, 3, 1, 1, groups=dim),
|
|
||||||
nn.Conv2d(hidden_dim, hidden_dim, 1, 1, 0)
|
|
||||||
)
|
|
||||||
self.act = nn.GELU()
|
|
||||||
self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.conv_0(x)
|
|
||||||
x = self.act(x)
|
|
||||||
x = self.conv_1(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class SMFA(nn.Module):
|
|
||||||
def __init__(self, dim=36):
|
|
||||||
super(SMFA, self).__init__()
|
|
||||||
self.linear_0 = nn.Conv2d(dim, dim * 2, 1, 1, 0)
|
|
||||||
self.linear_1 = nn.Conv2d(dim, dim, 1, 1, 0)
|
|
||||||
self.linear_2 = nn.Conv2d(dim, dim, 1, 1, 0)
|
|
||||||
|
|
||||||
self.lde = DMlp(dim, 2)
|
|
||||||
|
|
||||||
self.dw_conv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)
|
|
||||||
|
|
||||||
self.gelu = nn.GELU()
|
|
||||||
self.down_scale = 8
|
|
||||||
|
|
||||||
self.alpha = nn.Parameter(torch.ones((1, dim, 1, 1)))
|
|
||||||
self.belt = nn.Parameter(torch.zeros((1, dim, 1, 1)))
|
|
||||||
|
|
||||||
def forward(self, f):
|
|
||||||
_, _, h, w = f.shape
|
|
||||||
y, x = self.linear_0(f).chunk(2, dim=1)
|
|
||||||
x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
|
|
||||||
x_v = torch.var(x, dim=(-2, -1), keepdim=True)
|
|
||||||
x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h, w),
|
|
||||||
mode='nearest')
|
|
||||||
y_d = self.lde(y)
|
|
||||||
return self.linear_2(x_l + y_d)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
block = SMFA(dim=36)
|
|
||||||
input = torch.randn(3, 36, 64, 64)
|
|
||||||
output = block(input)
|
|
||||||
print(input.size())
|
|
||||||
print(output.size())
|
|
@ -1,110 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
"""Elsevier2024
|
|
||||||
变化检测 (CD) 是地球观测中一种重要的监测方法,尤其适用于土地利用分析、城市管理和灾害损失评估。然而,在星座互联和空天协作时代,感兴趣区域 (ROI) 的变化由于几何透视旋转和时间风格差异而导致许多错误检测。
|
|
||||||
为了应对这些挑战,我们引入了 CDNeXt,该框架阐明了一种稳健而有效的方法,用于将基于预训练主干的 Siamese 网络与用于遥感图像的创新时空交互注意模块 (TIAM) 相结合。
|
|
||||||
CDNeXt 可分为四个主要组件:编码器、交互器、解码器和检测器。值得注意的是,由 TIAM 提供支持的交互器从编码器提取的二进制时间特征中查询和重建空间透视依赖关系和时间风格相关性,以扩大 ROI 变化的差异。
|
|
||||||
最后,检测器集成解码器生成的分层特征,随后生成二进制变化掩码。
|
|
||||||
"""
|
|
||||||
|
|
||||||
class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
|
|
||||||
def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
|
|
||||||
super(SpatiotemporalAttentionFullNotWeightShared, self).__init__()
|
|
||||||
assert dimension in [2, ]
|
|
||||||
self.dimension = dimension
|
|
||||||
self.sub_sample = sub_sample
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.inter_channels = inter_channels
|
|
||||||
|
|
||||||
if self.inter_channels is None:
|
|
||||||
self.inter_channels = in_channels // 2
|
|
||||||
if self.inter_channels == 0:
|
|
||||||
self.inter_channels = 1
|
|
||||||
|
|
||||||
self.g1 = nn.Sequential(
|
|
||||||
nn.BatchNorm2d(self.in_channels),
|
|
||||||
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0)
|
|
||||||
)
|
|
||||||
self.g2 = nn.Sequential(
|
|
||||||
nn.BatchNorm2d(self.in_channels),
|
|
||||||
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.W1 = nn.Sequential(
|
|
||||||
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0),
|
|
||||||
nn.BatchNorm2d(self.in_channels)
|
|
||||||
)
|
|
||||||
self.W2 = nn.Sequential(
|
|
||||||
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0),
|
|
||||||
nn.BatchNorm2d(self.in_channels)
|
|
||||||
)
|
|
||||||
self.theta = nn.Sequential(
|
|
||||||
nn.BatchNorm2d(self.in_channels),
|
|
||||||
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0),
|
|
||||||
)
|
|
||||||
self.phi = nn.Sequential(
|
|
||||||
nn.BatchNorm2d(self.in_channels),
|
|
||||||
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x1, x2):
|
|
||||||
"""
|
|
||||||
:param x: (b, c, h, w)
|
|
||||||
:param return_nl_map: if True return z, nl_map, else only return z.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
batch_size = x1.size(0)
|
|
||||||
g_x11 = self.g1(x1).reshape(batch_size, self.inter_channels, -1)
|
|
||||||
g_x12 = g_x11.permute(0, 2, 1)
|
|
||||||
g_x21 = self.g2(x2).reshape(batch_size, self.inter_channels, -1)
|
|
||||||
g_x22 = g_x21.permute(0, 2, 1)
|
|
||||||
|
|
||||||
theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)
|
|
||||||
theta_x2 = theta_x1.permute(0, 2, 1)
|
|
||||||
|
|
||||||
phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)
|
|
||||||
phi_x2 = phi_x1.permute(0, 2, 1)
|
|
||||||
|
|
||||||
energy_time_1 = torch.matmul(theta_x1, phi_x2)
|
|
||||||
energy_time_2 = energy_time_1.permute(0, 2, 1)
|
|
||||||
energy_space_1 = torch.matmul(theta_x2, phi_x1)
|
|
||||||
energy_space_2 = energy_space_1.permute(0, 2, 1)
|
|
||||||
|
|
||||||
energy_time_1s = F.softmax(energy_time_1, dim=-1)
|
|
||||||
energy_time_2s = F.softmax(energy_time_2, dim=-1)
|
|
||||||
energy_space_2s = F.softmax(energy_space_1, dim=-2)
|
|
||||||
energy_space_1s = F.softmax(energy_space_2, dim=-2)
|
|
||||||
# C1*S(C2) energy_time_1s * C1*H1W1 g_x12 * energy_space_1s S(H2W2)*H1W1 -> C1*H1W1
|
|
||||||
y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous() # C2*H2W2
|
|
||||||
# C2*S(C1) energy_time_2s * C2*H2W2 g_x21 * energy_space_2s S(H1W1)*H2W2 -> C2*H2W2
|
|
||||||
y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous() # C1*H1W1
|
|
||||||
y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
|
|
||||||
y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
|
|
||||||
return x1 + self.W1(y1), x2 + self.W2(y2)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
in_channels = 64
|
|
||||||
batch_size = 8
|
|
||||||
height = 32
|
|
||||||
width = 32
|
|
||||||
|
|
||||||
block = SpatiotemporalAttentionFullNotWeightShared(in_channels=in_channels)
|
|
||||||
|
|
||||||
input1 = torch.rand(batch_size, in_channels, height, width)
|
|
||||||
input2 = torch.rand(batch_size, in_channels, height, width)
|
|
||||||
|
|
||||||
output1, output2 = block(input1, input2)
|
|
||||||
|
|
||||||
print(f"Input1 size: {input1.size()}")
|
|
||||||
print(f"Input2 size: {input2.size()}")
|
|
||||||
print(f"Output1 size: {output1.size()}")
|
|
||||||
print(f"Output2 size: {output2.size()}")
|
|
@ -1,110 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
"""Elsevier2024
|
|
||||||
变化检测 (CD) 是地球观测中一种重要的监测方法,尤其适用于土地利用分析、城市管理和灾害损失评估。然而,在星座互联和空天协作时代,感兴趣区域 (ROI) 的变化由于几何透视旋转和时间风格差异而导致许多错误检测。
|
|
||||||
为了应对这些挑战,我们引入了 CDNeXt,该框架阐明了一种稳健而有效的方法,用于将基于预训练主干的 Siamese 网络与用于遥感图像的创新时空交互注意模块 (TIAM) 相结合。
|
|
||||||
CDNeXt 可分为四个主要组件:编码器、交互器、解码器和检测器。值得注意的是,由 TIAM 提供支持的交互器从编码器提取的二进制时间特征中查询和重建空间透视依赖关系和时间风格相关性,以扩大 ROI 变化的差异。
|
|
||||||
最后,检测器集成解码器生成的分层特征,随后生成二进制变化掩码。
|
|
||||||
"""
|
|
||||||
|
|
||||||
class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
|
|
||||||
def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
|
|
||||||
super(SpatiotemporalAttentionFullNotWeightShared, self).__init__()
|
|
||||||
assert dimension in [2, ]
|
|
||||||
self.dimension = dimension
|
|
||||||
self.sub_sample = sub_sample
|
|
||||||
self.in_channels = in_channels
|
|
||||||
self.inter_channels = inter_channels
|
|
||||||
|
|
||||||
if self.inter_channels is None:
|
|
||||||
self.inter_channels = in_channels // 2
|
|
||||||
if self.inter_channels == 0:
|
|
||||||
self.inter_channels = 1
|
|
||||||
|
|
||||||
self.g1 = nn.Sequential(
|
|
||||||
nn.BatchNorm2d(self.in_channels),
|
|
||||||
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0)
|
|
||||||
)
|
|
||||||
self.g2 = nn.Sequential(
|
|
||||||
nn.BatchNorm2d(self.in_channels),
|
|
||||||
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.W1 = nn.Sequential(
|
|
||||||
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0),
|
|
||||||
nn.BatchNorm2d(self.in_channels)
|
|
||||||
)
|
|
||||||
self.W2 = nn.Sequential(
|
|
||||||
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0),
|
|
||||||
nn.BatchNorm2d(self.in_channels)
|
|
||||||
)
|
|
||||||
self.theta = nn.Sequential(
|
|
||||||
nn.BatchNorm2d(self.in_channels),
|
|
||||||
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0),
|
|
||||||
)
|
|
||||||
self.phi = nn.Sequential(
|
|
||||||
nn.BatchNorm2d(self.in_channels),
|
|
||||||
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
|
|
||||||
kernel_size=1, stride=1, padding=0),
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x1, x2):
|
|
||||||
"""
|
|
||||||
:param x: (b, c, h, w)
|
|
||||||
:param return_nl_map: if True return z, nl_map, else only return z.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
batch_size = x1.size(0)
|
|
||||||
g_x11 = self.g1(x1).reshape(batch_size, self.inter_channels, -1)
|
|
||||||
g_x12 = g_x11.permute(0, 2, 1)
|
|
||||||
g_x21 = self.g2(x2).reshape(batch_size, self.inter_channels, -1)
|
|
||||||
g_x22 = g_x21.permute(0, 2, 1)
|
|
||||||
|
|
||||||
theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)
|
|
||||||
theta_x2 = theta_x1.permute(0, 2, 1)
|
|
||||||
|
|
||||||
phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)
|
|
||||||
phi_x2 = phi_x1.permute(0, 2, 1)
|
|
||||||
|
|
||||||
energy_time_1 = torch.matmul(theta_x1, phi_x2)
|
|
||||||
energy_time_2 = energy_time_1.permute(0, 2, 1)
|
|
||||||
energy_space_1 = torch.matmul(theta_x2, phi_x1)
|
|
||||||
energy_space_2 = energy_space_1.permute(0, 2, 1)
|
|
||||||
|
|
||||||
energy_time_1s = F.softmax(energy_time_1, dim=-1)
|
|
||||||
energy_time_2s = F.softmax(energy_time_2, dim=-1)
|
|
||||||
energy_space_2s = F.softmax(energy_space_1, dim=-2)
|
|
||||||
energy_space_1s = F.softmax(energy_space_2, dim=-2)
|
|
||||||
# C1*S(C2) energy_time_1s * C1*H1W1 g_x12 * energy_space_1s S(H2W2)*H1W1 -> C1*H1W1
|
|
||||||
y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous() # C2*H2W2
|
|
||||||
# C2*S(C1) energy_time_2s * C2*H2W2 g_x21 * energy_space_2s S(H1W1)*H2W2 -> C2*H2W2
|
|
||||||
y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous() # C1*H1W1
|
|
||||||
y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
|
|
||||||
y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
|
|
||||||
return x1 + self.W1(y1), x2 + self.W2(y2)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
in_channels = 64
|
|
||||||
batch_size = 8
|
|
||||||
height = 32
|
|
||||||
width = 32
|
|
||||||
|
|
||||||
block = SpatiotemporalAttentionFullNotWeightShared(in_channels=in_channels)
|
|
||||||
|
|
||||||
input1 = torch.rand(batch_size, in_channels, height, width)
|
|
||||||
input2 = torch.rand(batch_size, in_channels, height, width)
|
|
||||||
|
|
||||||
output1, output2 = block(input1, input2)
|
|
||||||
|
|
||||||
print(f"Input1 size: {input1.size()}")
|
|
||||||
print(f"Input2 size: {input2.size()}")
|
|
||||||
print(f"Output1 size: {output1.size()}")
|
|
||||||
print(f"Output2 size: {output2.size()}")
|
|
@ -1,123 +0,0 @@
|
|||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
"""ICCV2023
|
|
||||||
最近提出的图像修复方法 LaMa 以快速傅里叶卷积 (FFC) 为基础构建了其网络,该网络最初是为图像分类等高级视觉任务而提出的。
|
|
||||||
FFC 使全卷积网络在其早期层中拥有全局感受野。得益于 FFC 模块的独特特性,LaMa 能够生成稳健的重复纹理,
|
|
||||||
这是以前的修复方法无法实现的。但是,原始 FFC 模块是否适合图像修复等低级视觉任务?
|
|
||||||
在本文中,我们分析了在图像修复中使用 FFC 的基本缺陷,即 1) 频谱偏移、2) 意外的空间激活和 3) 频率感受野有限。
|
|
||||||
这些缺陷使得基于 FFC 的修复框架难以生成复杂纹理并执行完美重建。
|
|
||||||
基于以上分析,我们提出了一种新颖的无偏快速傅里叶卷积 (UFFC) 模块,该模块通过
|
|
||||||
1) 范围变换和逆变换、2) 绝对位置嵌入、3) 动态跳过连接和 4) 自适应剪辑对原始 FFC 模块进行了修改,以克服这些缺陷,
|
|
||||||
实现更好的修复效果。在多个基准数据集上进行的大量实验证明了我们方法的有效性,在纹理捕捉能力和表现力方面均优于最先进的方法。
|
|
||||||
"""
|
|
||||||
|
|
||||||
class FourierUnit_modified(nn.Module):
|
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear',
|
|
||||||
spectral_pos_encoding=False, use_se=False, ffc3d=False, fft_norm='ortho'):
|
|
||||||
# bn_layer not used
|
|
||||||
super(FourierUnit_modified, self).__init__()
|
|
||||||
self.groups = groups
|
|
||||||
|
|
||||||
self.input_shape = 32 # change!!!!!it!!!!!!manually!!!!!!
|
|
||||||
self.in_channels = in_channels
|
|
||||||
|
|
||||||
self.locMap = nn.Parameter(torch.rand(self.input_shape, self.input_shape // 2 + 1))
|
|
||||||
|
|
||||||
self.lambda_base = nn.Parameter(torch.tensor(0.), requires_grad=True)
|
|
||||||
|
|
||||||
self.conv_layer_down55 = torch.nn.Conv2d(in_channels=in_channels * 2 + 1, # +1 for locmap
|
|
||||||
out_channels=out_channels * 2,
|
|
||||||
kernel_size=1, stride=1, padding=0, dilation=1, groups=self.groups,
|
|
||||||
bias=False, padding_mode='reflect')
|
|
||||||
self.conv_layer_down55_shift = torch.nn.Conv2d(in_channels=in_channels * 2 + 1, # +1 for locmap
|
|
||||||
out_channels=out_channels * 2,
|
|
||||||
kernel_size=3, stride=1, padding=2, dilation=2,
|
|
||||||
groups=self.groups, bias=False, padding_mode='reflect')
|
|
||||||
|
|
||||||
self.norm = nn.BatchNorm2d(out_channels)
|
|
||||||
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
|
|
||||||
self.spatial_scale_factor = spatial_scale_factor
|
|
||||||
self.spatial_scale_mode = spatial_scale_mode
|
|
||||||
self.spectral_pos_encoding = spectral_pos_encoding
|
|
||||||
self.ffc3d = ffc3d
|
|
||||||
self.fft_norm = fft_norm
|
|
||||||
|
|
||||||
self.img_freq = None
|
|
||||||
self.distill = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
batch = x.shape[0]
|
|
||||||
|
|
||||||
if self.spatial_scale_factor is not None:
|
|
||||||
orig_size = x.shape[-2:]
|
|
||||||
x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode,
|
|
||||||
align_corners=False)
|
|
||||||
|
|
||||||
fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
|
|
||||||
ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
|
|
||||||
ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
|
|
||||||
ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
|
|
||||||
ffted = ffted.view((batch, -1,) + ffted.size()[3:])
|
|
||||||
|
|
||||||
locMap = self.locMap.expand_as(ffted[:, :1, :, :]) # B 1 H' W'
|
|
||||||
ffted_copy = ffted.clone()
|
|
||||||
|
|
||||||
cat_img_mask_freq = torch.cat((ffted[:, :self.in_channels, :, :],
|
|
||||||
ffted[:, self.in_channels:, :, :],
|
|
||||||
locMap), dim=1)
|
|
||||||
|
|
||||||
ffted = self.conv_layer_down55(cat_img_mask_freq)
|
|
||||||
ffted = torch.fft.fftshift(ffted, dim=-2)
|
|
||||||
|
|
||||||
ffted = self.relu(ffted)
|
|
||||||
|
|
||||||
locMap_shift = torch.fft.fftshift(locMap, dim=-2) ## ONLY IF NOT SHIFT BACK
|
|
||||||
|
|
||||||
# REPEAT CONV
|
|
||||||
cat_img_mask_freq1 = torch.cat((ffted[:, :self.in_channels, :, :],
|
|
||||||
ffted[:, self.in_channels:, :, :],
|
|
||||||
locMap_shift), dim=1)
|
|
||||||
|
|
||||||
ffted = self.conv_layer_down55_shift(cat_img_mask_freq1)
|
|
||||||
ffted = torch.fft.fftshift(ffted, dim=-2)
|
|
||||||
|
|
||||||
lambda_base = torch.sigmoid(self.lambda_base)
|
|
||||||
|
|
||||||
ffted = ffted_copy * lambda_base + ffted * (1 - lambda_base)
|
|
||||||
|
|
||||||
# irfft
|
|
||||||
ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(
|
|
||||||
0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2)
|
|
||||||
ffted = torch.complex(ffted[..., 0], ffted[..., 1])
|
|
||||||
|
|
||||||
ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
|
|
||||||
output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)
|
|
||||||
|
|
||||||
if self.spatial_scale_factor is not None:
|
|
||||||
output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False)
|
|
||||||
|
|
||||||
epsilon = 0.5
|
|
||||||
output = output - torch.mean(output) + torch.mean(x)
|
|
||||||
output = torch.clip(output, float(x.min() - epsilon), float(x.max() + epsilon))
|
|
||||||
|
|
||||||
self.distill = output # for self perc
|
|
||||||
return output
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
in_channels = 16
|
|
||||||
out_channels = 16
|
|
||||||
|
|
||||||
block = FourierUnit_modified(in_channels=in_channels, out_channels=out_channels)
|
|
||||||
|
|
||||||
input_tensor = torch.rand(8, in_channels, 32, 32)
|
|
||||||
|
|
||||||
output = block(input_tensor)
|
|
||||||
|
|
||||||
print("Input size:", input_tensor.size())
|
|
||||||
print("Output size:", output.size())
|
|
@ -1,42 +0,0 @@
|
|||||||
import os
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
def transfer(input_path, quality=20, resize_factor=0.1):
|
|
||||||
# 打开TIFF图像
|
|
||||||
# img = Image.open(input_path)
|
|
||||||
#
|
|
||||||
# # 保存为JPEG,并设置压缩质量
|
|
||||||
# img.save(output_path, 'JPEG', quality=quality)
|
|
||||||
|
|
||||||
# input_path = os.path.join(input_folder, filename)
|
|
||||||
# 获取input_path的文件名
|
|
||||||
|
|
||||||
# 使用os.path.splitext获取文件名和后缀的元组
|
|
||||||
# 使用os.path.basename获取文件名(包含后缀)
|
|
||||||
filename_with_extension = os.path.basename(input_path)
|
|
||||||
filename, file_extension = os.path.splitext(filename_with_extension)
|
|
||||||
|
|
||||||
# 使用os.path.dirname获取文件所在的目录路径
|
|
||||||
output_folder = os.path.dirname(input_path)
|
|
||||||
|
|
||||||
output_path = os.path.join(output_folder, filename + '.jpg')
|
|
||||||
|
|
||||||
img = Image.open(input_path)
|
|
||||||
|
|
||||||
# 将图像缩小到原来的一半
|
|
||||||
new_width = int(img.width * resize_factor)
|
|
||||||
new_height = int(img.height * resize_factor)
|
|
||||||
resized_img = img.resize((new_width, new_height))
|
|
||||||
|
|
||||||
# 保存为JPEG,并设置压缩质量
|
|
||||||
# 转换为RGB模式,丢弃透明通道
|
|
||||||
rgb_img = resized_img.convert('RGB')
|
|
||||||
|
|
||||||
# 保存为JPEG,并设置压缩质量
|
|
||||||
# 压缩
|
|
||||||
rgb_img.save(output_path, 'JPEG', quality=quality)
|
|
||||||
|
|
||||||
print(f'{output_path} 转换完成')
|
|
||||||
|
|
||||||
return output_path
|
|
@ -1,35 +0,0 @@
|
|||||||
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
|
|
||||||
|
|
||||||
# base pcffuse
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 2.39 33.82 11.32 0.81 0.8 0.12 0.07 0.11
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
Process finished with exit code 0
|
|
@ -1,33 +0,0 @@
|
|||||||
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.01 40.67 15.39 1.53 1.76 0.64 0.53 0.95
|
|
||||||
================================================================================
|
|
@ -1,89 +0,0 @@
|
|||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.14 46.48 13.18 2.22 1.76 0.79 0.56 1.02
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of RoadScene :
|
|
||||||
FLIR_07206.jpg
|
|
||||||
FLIR_08202.jpg
|
|
||||||
FLIR_05893.jpg
|
|
||||||
FLIR_06974.jpg
|
|
||||||
FLIR_04424.jpg
|
|
||||||
FLIR_08284.jpg
|
|
||||||
FLIR_07786.jpg
|
|
||||||
FLIR_08021.jpg
|
|
||||||
FLIR_07968.jpg
|
|
||||||
FLIR_01130.jpg
|
|
||||||
FLIR_06993.jpg
|
|
||||||
FLIR_07190.jpg
|
|
||||||
FLIR_06570.jpg
|
|
||||||
FLIR_07809.jpg
|
|
||||||
FLIR_06430.jpg
|
|
||||||
FLIR_08592.jpg
|
|
||||||
FLIR_00211.jpg
|
|
||||||
FLIR_08721.jpg
|
|
||||||
FLIR_05955.jpg
|
|
||||||
FLIR_04688.jpg
|
|
||||||
FLIR_07732.jpg
|
|
||||||
FLIR_06392.jpg
|
|
||||||
FLIR_00977.jpg
|
|
||||||
FLIR_05105.jpg
|
|
||||||
FLIR_04269.jpg
|
|
||||||
FLIR_07970.jpg
|
|
||||||
FLIR_05005.jpg
|
|
||||||
FLIR_07209.jpg
|
|
||||||
FLIR_07555.jpg
|
|
||||||
FLIR_06325.jpg
|
|
||||||
FLIR_04943.jpg
|
|
||||||
FLIR_video_02829.jpg
|
|
||||||
FLIR_08248.jpg
|
|
||||||
FLIR_04484.jpg
|
|
||||||
FLIR_08058.jpg
|
|
||||||
FLIR_06795.jpg
|
|
||||||
FLIR_06995.jpg
|
|
||||||
FLIR_05879.jpg
|
|
||||||
FLIR_04593.jpg
|
|
||||||
FLIR_08094.jpg
|
|
||||||
FLIR_08526.jpg
|
|
||||||
FLIR_08858.jpg
|
|
||||||
FLIR_09465.jpg
|
|
||||||
FLIR_05064.jpg
|
|
||||||
FLIR_05857.jpg
|
|
||||||
FLIR_05914.jpg
|
|
||||||
FLIR_04722.jpg
|
|
||||||
FLIR_06506.jpg
|
|
||||||
FLIR_06282.jpg
|
|
||||||
FLIR_04512.jpg
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.41 52.99 15.81 2.37 1.78 0.71 0.55 0.96
|
|
||||||
================================================================================
|
|
@ -1,24 +0,0 @@
|
|||||||
2.4.1+cu121
|
|
||||||
True
|
|
||||||
Model: PFCFuse
|
|
||||||
Number of epochs: 60
|
|
||||||
Epoch gap: 40
|
|
||||||
Learning rate: 0.0001
|
|
||||||
Weight decay: 0
|
|
||||||
Batch size: 1
|
|
||||||
GPU number: 0
|
|
||||||
Coefficient of MSE loss VF: 1.0
|
|
||||||
Coefficient of MSE loss IF: 1.0
|
|
||||||
Coefficient of RMI loss VF: 1.0
|
|
||||||
Coefficient of RMI loss IF: 1.0
|
|
||||||
Coefficient of Cosine loss VF: 1.0
|
|
||||||
Coefficient of Cosine loss IF: 1.0
|
|
||||||
Coefficient of Decomposition loss: 2.0
|
|
||||||
Coefficient of Total Variation loss: 5.0
|
|
||||||
Clip gradient norm value: 0.01
|
|
||||||
Optimization step: 20
|
|
||||||
Optimization gamma: 0.5
|
|
||||||
[Epoch 0/60] [Batch 0/6487] [loss: 10.193036] ETA: 9 days, 21
[Epoch 0/60] [Batch 1/6487] [loss: 4.166963] ETA: 11:49:56.5
[Epoch 0/60] [Batch 2/6487] [loss: 10.681509] ETA: 10:23:19.1
[Epoch 0/60] [Batch 3/6487] [loss: 6.257133] ETA: 10:31:48.3
[Epoch 0/60] [Batch 4/6487] [loss: 13.018341] ETA: 10:32:54.2
[Epoch 0/60] [Batch 5/6487] [loss: 11.268185] ETA: 10:27:32.2
[Epoch 0/60] [Batch 6/6487] [loss: 6.920656] ETA: 10:34:01.5
[Epoch 0/60] [Batch 7/6487] [loss: 4.666215] ETA: 10:32:45.3
[Epoch 0/60] [Batch 8/6487] [loss: 10.787085] ETA: 10:26:01.9
[Epoch 0/60] [Batch 9/6487] [loss: 5.754866] ETA: 10:34:34.2
[Epoch 0/60] [Batch 10/6487] [loss: 28.760792] ETA: 10:36:32.6
[Epoch 0/60] [Batch 11/6487] [loss: 8.672796] ETA: 10:25:11.9
[Epoch 0/60] [Batch 12/6487] [loss: 14.300608] ETA: 10:28:19.9
[Epoch 0/60] [Batch 13/6487] [loss: 11.821722] ETA: 10:34:18.6
[Epoch 0/60] [Batch 14/6487] [loss: 7.627745] ETA: 10:31:44.8
[Epoch 0/60] [Batch 15/6487] [loss: 5.722600] ETA: 10:34:17.4
[Epoch 0/60] [Batch 16/6487] [loss: 10.423873] ETA: 11:33:27.1
[Epoch 0/60] [Batch 17/6487] [loss: 4.454098] ETA: 9:37:13.67
[Epoch 0/60] [Batch 18/6487] [loss: 3.820719] ETA: 9:33:57.42
[Epoch 0/60] [Batch 19/6487] [loss: 6.564124] ETA: 9:41:22.09
[Epoch 0/60] [Batch 20/6487] [loss: 5.406681] ETA: 9:47:30.11
[Epoch 0/60] [Batch 21/6487] [loss: 25.275440] ETA: 9:39:29.91
[Epoch 0/60] [Batch 22/6487] [loss: 4.228334] ETA: 9:42:15.45
[Epoch 0/60] [Batch 23/6487] [loss: 22.508118] ETA: 9:38:18.10
[Epoch 0/60] [Batch 24/6487] [loss: 5.062001] ETA: 9:46:09.29
[Epoch 0/60] [Batch 25/6487] [loss: 3.157355] ETA: 9:41:30.09
[Epoch 0/60] [Batch 26/6487] [loss: 6.438435] ETA: 10:02:51.9
[Epoch 0/60] [Batch 27/6487] [loss: 7.430470] ETA: 9:18:12.94
[Epoch 0/60] [Batch 28/6487] [loss: 3.783903] ETA: 10:41:13.9
[Epoch 0/60] [Batch 29/6487] [loss: 2.954306] ETA: 9:44:25.10
[Epoch 0/60] [Batch 30/6487] [loss: 5.863827] ETA: 9:35:13.84
[Epoch 0/60] [Batch 31/6487] [loss: 6.467144] ETA: 9:46:19.80
[Epoch 0/60] [Batch 32/6487] [loss: 4.801052] ETA: 9:32:17.18
[Epoch 0/60] [Batch 33/6487] [loss: 5.658401] ETA: 9:31:10.28
[Epoch 0/60] [Batch 34/6487] [loss: 2.085633] ETA: 9:36:47.39
[Epoch 0/60] [Batch 35/6487] [loss: 15.402915] ETA: 9:40:51.43
[Epoch 0/60] [Batch 36/6487] [loss: 3.181264] ETA: 9:33:06.65
[Epoch 0/60] [Batch 37/6487] [loss: 3.883055] ETA: 9:42:29.60
[Epoch 0/60] [Batch 38/6487] [loss: 3.342676] ETA: 10:07:02.2
[Epoch 0/60] [Batch 39/6487] [loss: 2.589705] ETA: 9:36:43.32
[Epoch 0/60] [Batch 40/6487] [loss: 3.742121] ETA: 9:42:57.54
[Epoch 0/60] [Batch 41/6487] [loss: 2.732829] ETA: 9:36:54.65
[Epoch 0/60] [Batch 42/6487] [loss: 6.655626] ETA: 9:42:20.71
[Epoch 0/60] [Batch 43/6487] [loss: 1.822412] ETA: 9:38:02.02
[Epoch 0/60] [Batch 44/6487] [loss: 2.875143] ETA: 9:41:02.96
[Epoch 0/60] [Batch 45/6487] [loss: 2.319836] ETA: 9:38:16.23
[Epoch 0/60] [Batch 46/6487] [loss: 2.354790] ETA: 9:39:08.93
[Epoch 0/60] [Batch 47/6487] [loss: 1.986412] ETA: 9:52:11.40
[Epoch 0/60] [Batch 48/6487] [loss: 2.154071] ETA: 10:08:20.0
[Epoch 0/60] [Batch 49/6487] [loss: 1.425418] ETA: 9:54:04.42
[Epoch 0/60] [Batch 50/6487] [loss: 1.988360] ETA: 9:30:25.08
[Epoch 0/60] [Batch 51/6487] [loss: 4.090429] ETA: 9:43:53.52
[Epoch 0/60] [Batch 52/6487] [loss: 1.924778] ETA: 9:46:19.38
[Epoch 0/60] [Batch 53/6487] [loss: 2.191964] ETA: 9:46:59.93
[Epoch 0/60] [Batch 54/6487] [loss: 2.032799] ETA: 9:46:14.01
[Epoch 0/60] [Batch 55/6487] [loss: 1.923933] ETA: 9:44:21.65
[Epoch 0/60] [Batch 56/6487] [loss: 1.656838] ETA: 9:56:15.90
[Epoch 0/60] [Batch 57/6487] [loss: 1.656845] ETA: 10:21:26.1
[Epoch 0/60] [Batch 58/6487] [loss: 1.157820] ETA: 10:44:47.3
[Epoch 0/60] [Batch 59/6487] [loss: 1.652715] ETA: 10:46:39.9
[Epoch 0/60] [Batch 60/6487] [loss: 1.633865] ETA: 10:23:34.7
[Epoch 0/60] [Batch 61/6487] [loss: 1.290259] ETA: 9:24:06.12Traceback (most recent call last):
|
|
||||||
File "/home/star/whaiDir/PFCFuse/train.py", line 232, in <module>
|
|
||||||
loss.item(),
|
|
||||||
KeyboardInterrupt
|
|
@ -1,38 +0,0 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.01 40.4 15.51 1.55 1.75 0.66 0.54 0.96
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
Process finished with exit code 0
|
|
165
net.py
@ -6,8 +6,10 @@ import torch.utils.checkpoint as checkpoint
|
|||||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from componets.SCSA import SCSA
|
from componets.WTConvCV2 import WTConv2d
|
||||||
|
|
||||||
|
|
||||||
|
# 以一定概率随机丢弃输入张量中的路径,用于正则化模型
|
||||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||||
if drop_prob == 0. or not training:
|
if drop_prob == 0. or not training:
|
||||||
return x
|
return x
|
||||||
@ -32,6 +34,9 @@ class DropPath(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return drop_path(x, self.drop_prob, self.training)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 改点,使用Pooling替换AttentionBase
|
# 改点,使用Pooling替换AttentionBase
|
||||||
class Pooling(nn.Module):
|
class Pooling(nn.Module):
|
||||||
def __init__(self, kernel_size=3):
|
def __init__(self, kernel_size=3):
|
||||||
@ -92,47 +97,54 @@ class PoolMlp(nn.Module):
|
|||||||
x = self.drop(x)
|
x = self.drop(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class BaseFeatureExtractionSAR(nn.Module):
|
|
||||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
|
||||||
act_layer=nn.GELU,
|
|
||||||
# norm_layer=nn.LayerNorm,
|
|
||||||
drop=0., drop_path=0.,
|
|
||||||
use_layer_scale=True, layer_scale_init_value=1e-5):
|
|
||||||
|
|
||||||
super().__init__()
|
# class BaseFeatureExtraction1(nn.Module):
|
||||||
|
# def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
# act_layer=nn.GELU,
|
||||||
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
# # norm_layer=nn.LayerNorm,
|
||||||
self.token_mixer = SCSA(dim=dim, head_num=8)
|
# drop=0., drop_path=0.,
|
||||||
self.norm2 = LayerNorm(dim, 'WithBias')
|
# use_layer_scale=True, layer_scale_init_value=1e-5):
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
#
|
||||||
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
# super().__init__()
|
||||||
act_layer=act_layer, drop=drop)
|
#
|
||||||
|
# self.norm1 = LayerNorm(dim, 'WithBias')
|
||||||
# The following two techniques are useful to train deep PoolFormers.
|
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
# self.norm2 = LayerNorm(dim, 'WithBias')
|
||||||
else nn.Identity()
|
# mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
self.use_layer_scale = use_layer_scale
|
# self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
||||||
|
# act_layer=act_layer, drop=drop)
|
||||||
if use_layer_scale:
|
#
|
||||||
self.layer_scale_1 = nn.Parameter(
|
# # The following two techniques are useful to train deep PoolFormers.
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
# self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
||||||
|
# else nn.Identity()
|
||||||
self.layer_scale_2 = nn.Parameter(
|
# self.use_layer_scale = use_layer_scale
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
#
|
||||||
|
# if use_layer_scale:
|
||||||
def forward(self, x):
|
# self.layer_scale_1 = nn.Parameter(
|
||||||
if self.use_layer_scale:
|
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
x = x + self.drop_path(
|
#
|
||||||
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
|
# self.layer_scale_2 = nn.Parameter(
|
||||||
* self.token_mixer(self.norm1(x)))
|
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
x = x + self.drop_path(
|
#
|
||||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
# def forward(self, x): # 1 64 128 128
|
||||||
* self.poolmlp(self.norm2(x)))
|
# if self.use_layer_scale:
|
||||||
else:
|
# # self.layer_scale_1(64,)
|
||||||
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
# tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
||||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
# normal = self.norm1(x) # 1 64 128 128
|
||||||
return x
|
# token_mix = self.token_mixer(normal) # 1 64 128 128
|
||||||
|
# x = (x +
|
||||||
|
# self.drop_path(
|
||||||
|
# tmp1 * token_mix
|
||||||
|
# )
|
||||||
|
# # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
||||||
|
# )
|
||||||
|
# x = x + self.drop_path(
|
||||||
|
# self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||||
|
# * self.poolmlp(self.norm2(x)))
|
||||||
|
# else:
|
||||||
|
# x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
||||||
|
# x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
|
# return x
|
||||||
|
|
||||||
class BaseFeatureExtraction(nn.Module):
|
class BaseFeatureExtraction(nn.Module):
|
||||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||||
@ -143,6 +155,7 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.WTConv2d = WTConv2d(dim, dim)
|
||||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
self.norm1 = LayerNorm(dim, 'WithBias')
|
||||||
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||||
self.norm2 = LayerNorm(dim, 'WithBias')
|
self.norm2 = LayerNorm(dim, 'WithBias')
|
||||||
@ -162,11 +175,21 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
self.layer_scale_2 = nn.Parameter(
|
self.layer_scale_2 = nn.Parameter(
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x): # 1 64 128 128
|
||||||
if self.use_layer_scale:
|
if self.use_layer_scale:
|
||||||
x = x + self.drop_path(
|
# self.layer_scale_1(64,)
|
||||||
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
|
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
||||||
* self.token_mixer(self.norm1(x)))
|
normal = self.norm1(x) # 1 64 128 128
|
||||||
|
token_mix = self.token_mixer(normal) # 1 64 128 128
|
||||||
|
|
||||||
|
x = self.WTConv2d(x)
|
||||||
|
|
||||||
|
x = (x +
|
||||||
|
self.drop_path(
|
||||||
|
tmp1 * token_mix
|
||||||
|
)
|
||||||
|
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
||||||
|
)
|
||||||
x = x + self.drop_path(
|
x = x + self.drop_path(
|
||||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||||
* self.poolmlp(self.norm2(x)))
|
* self.poolmlp(self.norm2(x)))
|
||||||
@ -175,7 +198,6 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class InvertedResidualBlock(nn.Module):
|
class InvertedResidualBlock(nn.Module):
|
||||||
def __init__(self, inp, oup, expand_ratio):
|
def __init__(self, inp, oup, expand_ratio):
|
||||||
super(InvertedResidualBlock, self).__init__()
|
super(InvertedResidualBlock, self).__init__()
|
||||||
@ -198,6 +220,8 @@ class InvertedResidualBlock(nn.Module):
|
|||||||
return self.bottleneckBlock(x)
|
return self.bottleneckBlock(x)
|
||||||
|
|
||||||
class DetailNode(nn.Module):
|
class DetailNode(nn.Module):
|
||||||
|
|
||||||
|
# <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > '
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(DetailNode, self).__init__()
|
super(DetailNode, self).__init__()
|
||||||
|
|
||||||
@ -219,24 +243,25 @@ class DetailNode(nn.Module):
|
|||||||
return z1, z2
|
return z1, z2
|
||||||
|
|
||||||
|
|
||||||
class DetailFeatureExtractionSAR(nn.Module):
|
|
||||||
def __init__(self, num_layers=3):
|
|
||||||
super(DetailFeatureExtractionSAR, self).__init__()
|
|
||||||
INNmodules = [DetailNode() for _ in range(num_layers)]
|
|
||||||
self.net = nn.Sequential(*INNmodules)
|
|
||||||
def forward(self, x):
|
|
||||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
|
|
||||||
for layer in self.net:
|
|
||||||
z1, z2 = layer(z1, z2)
|
|
||||||
return torch.cat((z1, z2), dim=1)
|
|
||||||
|
|
||||||
class DetailFeatureExtraction(nn.Module):
|
class DetailFeatureExtraction(nn.Module):
|
||||||
def __init__(self, num_layers=3):
|
def __init__(self, num_layers=3):
|
||||||
super(DetailFeatureExtraction, self).__init__()
|
super(DetailFeatureExtraction, self).__init__()
|
||||||
INNmodules = [DetailNode() for _ in range(num_layers)]
|
INNmodules = [DetailNode() for _ in range(num_layers)]
|
||||||
self.net = nn.Sequential(*INNmodules)
|
self.net = nn.Sequential(*INNmodules)
|
||||||
def forward(self, x):
|
self.enhancement_module = nn.Sequential(
|
||||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
|
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x): # 1 64 128 128
|
||||||
|
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
||||||
|
# 增强并添加残差连接
|
||||||
|
enhanced_z1 = self.enhancement_module(z1)
|
||||||
|
enhanced_z2 = self.enhancement_module(z2)
|
||||||
|
# 残差连接
|
||||||
|
z1 = z1 + enhanced_z1
|
||||||
|
z2 = z2 + enhanced_z2
|
||||||
for layer in self.net:
|
for layer in self.net:
|
||||||
z1, z2 = layer(z1, z2)
|
z1, z2 = layer(z1, z2)
|
||||||
return torch.cat((z1, z2), dim=1)
|
return torch.cat((z1, z2), dim=1)
|
||||||
@ -425,21 +450,12 @@ class Restormer_Encoder(nn.Module):
|
|||||||
|
|
||||||
self.detailFeature = DetailFeatureExtraction()
|
self.detailFeature = DetailFeatureExtraction()
|
||||||
|
|
||||||
self.baseFeature_sar = BaseFeatureExtractionSAR(dim=dim)
|
def forward(self, inp_img):
|
||||||
self.detailFeature_sar = DetailFeatureExtractionSAR()
|
|
||||||
|
|
||||||
def forward(self, inp_img,is_sar=False):
|
|
||||||
inp_enc_level1 = self.patch_embed(inp_img)
|
inp_enc_level1 = self.patch_embed(inp_img)
|
||||||
out_enc_level1 = self.encoder_level1(inp_enc_level1)
|
out_enc_level1 = self.encoder_level1(inp_enc_level1)
|
||||||
if is_sar:
|
base_feature = self.baseFeature(out_enc_level1)
|
||||||
base_feature = self.baseFeature_sar(out_enc_level1) # 1 64 128 128
|
detail_feature = self.detailFeature(out_enc_level1)
|
||||||
detail_feature = self.detailFeature_sar(out_enc_level1) # 1 64 128 128
|
return base_feature, detail_feature, out_enc_level1
|
||||||
return base_feature, detail_feature, out_enc_level1 # 1 64 128 128
|
|
||||||
|
|
||||||
else:
|
|
||||||
base_feature = self.baseFeature(out_enc_level1) # 1 64 128 128
|
|
||||||
detail_feature = self.detailFeature(out_enc_level1) # 1 64 128 128
|
|
||||||
return base_feature, detail_feature, out_enc_level1 # 1 64 128 128
|
|
||||||
|
|
||||||
|
|
||||||
class Restormer_Decoder(nn.Module):
|
class Restormer_Decoder(nn.Module):
|
||||||
@ -456,8 +472,7 @@ class Restormer_Decoder(nn.Module):
|
|||||||
|
|
||||||
super(Restormer_Decoder, self).__init__()
|
super(Restormer_Decoder, self).__init__()
|
||||||
self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
|
self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
|
||||||
self.encoder_level2 = nn.Sequential(
|
self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
|
||||||
*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
|
|
||||||
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
||||||
self.output = nn.Sequential(
|
self.output = nn.Sequential(
|
||||||
nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
|
nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
|
||||||
@ -484,3 +499,5 @@ if __name__ == '__main__':
|
|||||||
window_size = 8
|
window_size = 8
|
||||||
modelE = Restormer_Encoder().cuda()
|
modelE = Restormer_Encoder().cuda()
|
||||||
modelD = Restormer_Decoder().cuda()
|
modelD = Restormer_Decoder().cuda()
|
||||||
|
print(modelE)
|
||||||
|
print(modelD)
|
||||||
|
136
status.md
@ -1,136 +0,0 @@
|
|||||||
PFCFuse
|
|
||||||
|
|
||||||
```angular2html
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.14 46.48 13.18 2.22 1.76 0.79 0.56 1.02
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of RoadScene :
|
|
||||||
FLIR_07206.jpg
|
|
||||||
FLIR_08202.jpg
|
|
||||||
FLIR_05893.jpg
|
|
||||||
FLIR_06974.jpg
|
|
||||||
FLIR_04424.jpg
|
|
||||||
FLIR_08284.jpg
|
|
||||||
FLIR_07786.jpg
|
|
||||||
FLIR_08021.jpg
|
|
||||||
FLIR_07968.jpg
|
|
||||||
FLIR_01130.jpg
|
|
||||||
FLIR_06993.jpg
|
|
||||||
FLIR_07190.jpg
|
|
||||||
FLIR_06570.jpg
|
|
||||||
FLIR_07809.jpg
|
|
||||||
FLIR_06430.jpg
|
|
||||||
FLIR_08592.jpg
|
|
||||||
FLIR_00211.jpg
|
|
||||||
FLIR_08721.jpg
|
|
||||||
FLIR_05955.jpg
|
|
||||||
FLIR_04688.jpg
|
|
||||||
FLIR_07732.jpg
|
|
||||||
FLIR_06392.jpg
|
|
||||||
FLIR_00977.jpg
|
|
||||||
FLIR_05105.jpg
|
|
||||||
FLIR_04269.jpg
|
|
||||||
FLIR_07970.jpg
|
|
||||||
FLIR_05005.jpg
|
|
||||||
FLIR_07209.jpg
|
|
||||||
FLIR_07555.jpg
|
|
||||||
FLIR_06325.jpg
|
|
||||||
FLIR_04943.jpg
|
|
||||||
FLIR_video_02829.jpg
|
|
||||||
FLIR_08248.jpg
|
|
||||||
FLIR_04484.jpg
|
|
||||||
FLIR_08058.jpg
|
|
||||||
FLIR_06795.jpg
|
|
||||||
FLIR_06995.jpg
|
|
||||||
FLIR_05879.jpg
|
|
||||||
FLIR_04593.jpg
|
|
||||||
FLIR_08094.jpg
|
|
||||||
FLIR_08526.jpg
|
|
||||||
FLIR_08858.jpg
|
|
||||||
FLIR_09465.jpg
|
|
||||||
FLIR_05064.jpg
|
|
||||||
FLIR_05857.jpg
|
|
||||||
FLIR_05914.jpg
|
|
||||||
FLIR_04722.jpg
|
|
||||||
FLIR_06506.jpg
|
|
||||||
FLIR_06282.jpg
|
|
||||||
FLIR_04512.jpg
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.41 52.99 15.81 2.37 1.78 0.71 0.55 0.96
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
20241008
|
|
||||||
```
|
|
||||||
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.01 40.4 15.51 1.55 1.75 0.66 0.54 0.96
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
Process finished with exit code 0
|
|
||||||
|
|
||||||
```
|
|
15
test_IVF.py
@ -1,5 +1,3 @@
|
|||||||
import datetime
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
||||||
import os
|
import os
|
||||||
@ -13,18 +11,16 @@ import logging
|
|||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
logging.basicConfig(level=logging.CRITICAL)
|
logging.basicConfig(level=logging.CRITICAL)
|
||||||
|
|
||||||
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
|
|
||||||
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-15-11-11.pth"
|
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-20-46.pth"
|
||||||
|
|
||||||
for dataset_name in ["sar"]:
|
for dataset_name in ["TNO"]:
|
||||||
print("\n"*2+"="*80)
|
print("\n"*2+"="*80)
|
||||||
model_name="whai 修改SCSA分支 "
|
model_name="PFCFuse "
|
||||||
print("The test result of "+dataset_name+' :')
|
print("The test result of "+dataset_name+' :')
|
||||||
test_folder = os.path.join('test_img', dataset_name)
|
test_folder=os.path.join('/home/star/whaiDir/CDDFuse/test_img/',dataset_name)
|
||||||
test_out_folder=os.path.join('test_result',current_time,dataset_name)
|
test_out_folder=os.path.join('test_result',dataset_name)
|
||||||
|
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
||||||
@ -43,6 +39,7 @@ for dataset_name in ["sar"]:
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for img_name in os.listdir(os.path.join(test_folder,"ir")):
|
for img_name in os.listdir(os.path.join(test_folder,"ir")):
|
||||||
|
print(img_name)
|
||||||
|
|
||||||
data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
|
data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
|
||||||
data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0
|
data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0
|
||||||
|
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 39 KiB |
Before Width: | Height: | Size: 37 KiB |
Before Width: | Height: | Size: 34 KiB |
Before Width: | Height: | Size: 33 KiB |
Before Width: | Height: | Size: 38 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 44 KiB |
Before Width: | Height: | Size: 44 KiB |
Before Width: | Height: | Size: 43 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 40 KiB |
Before Width: | Height: | Size: 40 KiB |
Before Width: | Height: | Size: 40 KiB |
Before Width: | Height: | Size: 39 KiB |
Before Width: | Height: | Size: 37 KiB |
Before Width: | Height: | Size: 59 KiB |
Before Width: | Height: | Size: 58 KiB |
Before Width: | Height: | Size: 60 KiB |
Before Width: | Height: | Size: 66 KiB |
Before Width: | Height: | Size: 63 KiB |
Before Width: | Height: | Size: 61 KiB |
Before Width: | Height: | Size: 60 KiB |
Before Width: | Height: | Size: 57 KiB |
Before Width: | Height: | Size: 54 KiB |
Before Width: | Height: | Size: 52 KiB |
Before Width: | Height: | Size: 51 KiB |
Before Width: | Height: | Size: 53 KiB |
Before Width: | Height: | Size: 54 KiB |
Before Width: | Height: | Size: 56 KiB |
Before Width: | Height: | Size: 55 KiB |
Before Width: | Height: | Size: 52 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 49 KiB |
Before Width: | Height: | Size: 48 KiB |
Before Width: | Height: | Size: 46 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 43 KiB |
Before Width: | Height: | Size: 43 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 43 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 39 KiB |
Before Width: | Height: | Size: 38 KiB |
Before Width: | Height: | Size: 37 KiB |
Before Width: | Height: | Size: 36 KiB |
Before Width: | Height: | Size: 34 KiB |
Before Width: | Height: | Size: 32 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 26 KiB |