Compare commits
No commits in common. "base_vi(inn)+sar(wtconv)" and "main" have entirely different histories.
base_vi(in
...
main
8
.idea/.gitignore
vendored
@ -1,8 +0,0 @@
|
|||||||
# Default ignored files
|
|
||||||
/shelf/
|
|
||||||
/workspace.xml
|
|
||||||
# Editor-based HTTP Client requests
|
|
||||||
/httpRequests/
|
|
||||||
# Datasource local storage ignored files
|
|
||||||
/dataSources/
|
|
||||||
/dataSources.local.xml
|
|
@ -1,12 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<module type="PYTHON_MODULE" version="4">
|
|
||||||
<component name="NewModuleRootManager">
|
|
||||||
<content url="file://$MODULE_DIR$" />
|
|
||||||
<orderEntry type="inheritedJdk" />
|
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
|
||||||
</component>
|
|
||||||
<component name="PyDocumentationSettings">
|
|
||||||
<option name="format" value="PLAIN" />
|
|
||||||
<option name="myDocStringFormat" value="Plain" />
|
|
||||||
</component>
|
|
||||||
</module>
|
|
@ -1,78 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="PublishConfigData" autoUpload="Always" serverName="star@192.168.50.108:22 password (9)" remoteFilesAllowedToDisappearOnAutoupload="false">
|
|
||||||
<serverData>
|
|
||||||
<paths name="star@192.168.50.108:22 password">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (2)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (3)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (4)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (5)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (6)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (7)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (8)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping deploy="/home/star/whaiDir/PFCFuse" local="$PROJECT_DIR$" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="star@192.168.50.108:22 password (9)">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping deploy="/home/star/whaiDir/PFCFuse" local="$PROJECT_DIR$" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
<paths name="v100">
|
|
||||||
<serverdata>
|
|
||||||
<mappings>
|
|
||||||
<mapping local="$PROJECT_DIR$" web="/" />
|
|
||||||
</mappings>
|
|
||||||
</serverdata>
|
|
||||||
</paths>
|
|
||||||
</serverData>
|
|
||||||
<option name="myAutoUpload" value="ALWAYS" />
|
|
||||||
</component>
|
|
||||||
</project>
|
|
@ -1,15 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="GitToolBoxProjectSettings">
|
|
||||||
<option name="commitMessageIssueKeyValidationOverride">
|
|
||||||
<BoolValueOverride>
|
|
||||||
<option name="enabled" value="true" />
|
|
||||||
</BoolValueOverride>
|
|
||||||
</option>
|
|
||||||
<option name="commitMessageValidationEnabledOverride">
|
|
||||||
<BoolValueOverride>
|
|
||||||
<option name="enabled" value="true" />
|
|
||||||
</BoolValueOverride>
|
|
||||||
</option>
|
|
||||||
</component>
|
|
||||||
</project>
|
|
@ -1,264 +0,0 @@
|
|||||||
<component name="InspectionProjectProfileManager">
|
|
||||||
<profile version="1.0">
|
|
||||||
<option name="myName" value="Project Default" />
|
|
||||||
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
|
||||||
<Languages>
|
|
||||||
<language minSize="226" name="Python" />
|
|
||||||
</Languages>
|
|
||||||
</inspection_tool>
|
|
||||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
|
||||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
|
||||||
<option name="ignoredPackages">
|
|
||||||
<value>
|
|
||||||
<list size="245">
|
|
||||||
<item index="0" class="java.lang.String" itemvalue="numba" />
|
|
||||||
<item index="1" class="java.lang.String" itemvalue="tensorflow-estimator" />
|
|
||||||
<item index="2" class="java.lang.String" itemvalue="greenlet" />
|
|
||||||
<item index="3" class="java.lang.String" itemvalue="Babel" />
|
|
||||||
<item index="4" class="java.lang.String" itemvalue="scikit-learn" />
|
|
||||||
<item index="5" class="java.lang.String" itemvalue="testpath" />
|
|
||||||
<item index="6" class="java.lang.String" itemvalue="py" />
|
|
||||||
<item index="7" class="java.lang.String" itemvalue="gitdb" />
|
|
||||||
<item index="8" class="java.lang.String" itemvalue="torchvision" />
|
|
||||||
<item index="9" class="java.lang.String" itemvalue="patsy" />
|
|
||||||
<item index="10" class="java.lang.String" itemvalue="mccabe" />
|
|
||||||
<item index="11" class="java.lang.String" itemvalue="bleach" />
|
|
||||||
<item index="12" class="java.lang.String" itemvalue="lxml" />
|
|
||||||
<item index="13" class="java.lang.String" itemvalue="torchaudio" />
|
|
||||||
<item index="14" class="java.lang.String" itemvalue="jsonschema" />
|
|
||||||
<item index="15" class="java.lang.String" itemvalue="xlrd" />
|
|
||||||
<item index="16" class="java.lang.String" itemvalue="Werkzeug" />
|
|
||||||
<item index="17" class="java.lang.String" itemvalue="anaconda-project" />
|
|
||||||
<item index="18" class="java.lang.String" itemvalue="tensorboard-data-server" />
|
|
||||||
<item index="19" class="java.lang.String" itemvalue="typing-extensions" />
|
|
||||||
<item index="20" class="java.lang.String" itemvalue="click" />
|
|
||||||
<item index="21" class="java.lang.String" itemvalue="regex" />
|
|
||||||
<item index="22" class="java.lang.String" itemvalue="fastcache" />
|
|
||||||
<item index="23" class="java.lang.String" itemvalue="tensorboard" />
|
|
||||||
<item index="24" class="java.lang.String" itemvalue="imageio" />
|
|
||||||
<item index="25" class="java.lang.String" itemvalue="pytest-remotedata" />
|
|
||||||
<item index="26" class="java.lang.String" itemvalue="matplotlib" />
|
|
||||||
<item index="27" class="java.lang.String" itemvalue="idna" />
|
|
||||||
<item index="28" class="java.lang.String" itemvalue="Bottleneck" />
|
|
||||||
<item index="29" class="java.lang.String" itemvalue="rsa" />
|
|
||||||
<item index="30" class="java.lang.String" itemvalue="networkx" />
|
|
||||||
<item index="31" class="java.lang.String" itemvalue="pycurl" />
|
|
||||||
<item index="32" class="java.lang.String" itemvalue="smmap" />
|
|
||||||
<item index="33" class="java.lang.String" itemvalue="pluggy" />
|
|
||||||
<item index="34" class="java.lang.String" itemvalue="cffi" />
|
|
||||||
<item index="35" class="java.lang.String" itemvalue="pep8" />
|
|
||||||
<item index="36" class="java.lang.String" itemvalue="numpy" />
|
|
||||||
<item index="37" class="java.lang.String" itemvalue="jdcal" />
|
|
||||||
<item index="38" class="java.lang.String" itemvalue="alabaster" />
|
|
||||||
<item index="39" class="java.lang.String" itemvalue="jupyter" />
|
|
||||||
<item index="40" class="java.lang.String" itemvalue="pyOpenSSL" />
|
|
||||||
<item index="41" class="java.lang.String" itemvalue="PyWavelets" />
|
|
||||||
<item index="42" class="java.lang.String" itemvalue="prompt-toolkit" />
|
|
||||||
<item index="43" class="java.lang.String" itemvalue="QtAwesome" />
|
|
||||||
<item index="44" class="java.lang.String" itemvalue="msgpack-python" />
|
|
||||||
<item index="45" class="java.lang.String" itemvalue="Flask-Cors" />
|
|
||||||
<item index="46" class="java.lang.String" itemvalue="glob2" />
|
|
||||||
<item index="47" class="java.lang.String" itemvalue="Send2Trash" />
|
|
||||||
<item index="48" class="java.lang.String" itemvalue="imagesize" />
|
|
||||||
<item index="49" class="java.lang.String" itemvalue="et-xmlfile" />
|
|
||||||
<item index="50" class="java.lang.String" itemvalue="pathlib2" />
|
|
||||||
<item index="51" class="java.lang.String" itemvalue="docker-pycreds" />
|
|
||||||
<item index="52" class="java.lang.String" itemvalue="importlib-resources" />
|
|
||||||
<item index="53" class="java.lang.String" itemvalue="pathtools" />
|
|
||||||
<item index="54" class="java.lang.String" itemvalue="spyder" />
|
|
||||||
<item index="55" class="java.lang.String" itemvalue="pylint" />
|
|
||||||
<item index="56" class="java.lang.String" itemvalue="statsmodels" />
|
|
||||||
<item index="57" class="java.lang.String" itemvalue="tensorboardX" />
|
|
||||||
<item index="58" class="java.lang.String" itemvalue="isort" />
|
|
||||||
<item index="59" class="java.lang.String" itemvalue="ruamel_yaml" />
|
|
||||||
<item index="60" class="java.lang.String" itemvalue="pytz" />
|
|
||||||
<item index="61" class="java.lang.String" itemvalue="unicodecsv" />
|
|
||||||
<item index="62" class="java.lang.String" itemvalue="pytest-astropy" />
|
|
||||||
<item index="63" class="java.lang.String" itemvalue="traitlets" />
|
|
||||||
<item index="64" class="java.lang.String" itemvalue="absl-py" />
|
|
||||||
<item index="65" class="java.lang.String" itemvalue="protobuf" />
|
|
||||||
<item index="66" class="java.lang.String" itemvalue="nltk" />
|
|
||||||
<item index="67" class="java.lang.String" itemvalue="partd" />
|
|
||||||
<item index="68" class="java.lang.String" itemvalue="promise" />
|
|
||||||
<item index="69" class="java.lang.String" itemvalue="gast" />
|
|
||||||
<item index="70" class="java.lang.String" itemvalue="filelock" />
|
|
||||||
<item index="71" class="java.lang.String" itemvalue="numpydoc" />
|
|
||||||
<item index="72" class="java.lang.String" itemvalue="pyzmq" />
|
|
||||||
<item index="73" class="java.lang.String" itemvalue="oauthlib" />
|
|
||||||
<item index="74" class="java.lang.String" itemvalue="astropy" />
|
|
||||||
<item index="75" class="java.lang.String" itemvalue="keras" />
|
|
||||||
<item index="76" class="java.lang.String" itemvalue="entrypoints" />
|
|
||||||
<item index="77" class="java.lang.String" itemvalue="bkcharts" />
|
|
||||||
<item index="78" class="java.lang.String" itemvalue="pyparsing" />
|
|
||||||
<item index="79" class="java.lang.String" itemvalue="munch" />
|
|
||||||
<item index="80" class="java.lang.String" itemvalue="sphinxcontrib-websupport" />
|
|
||||||
<item index="81" class="java.lang.String" itemvalue="beautifulsoup4" />
|
|
||||||
<item index="82" class="java.lang.String" itemvalue="path.py" />
|
|
||||||
<item index="83" class="java.lang.String" itemvalue="clyent" />
|
|
||||||
<item index="84" class="java.lang.String" itemvalue="navigator-updater" />
|
|
||||||
<item index="85" class="java.lang.String" itemvalue="tifffile" />
|
|
||||||
<item index="86" class="java.lang.String" itemvalue="cryptography" />
|
|
||||||
<item index="87" class="java.lang.String" itemvalue="pygdal" />
|
|
||||||
<item index="88" class="java.lang.String" itemvalue="fastrlock" />
|
|
||||||
<item index="89" class="java.lang.String" itemvalue="widgetsnbextension" />
|
|
||||||
<item index="90" class="java.lang.String" itemvalue="multipledispatch" />
|
|
||||||
<item index="91" class="java.lang.String" itemvalue="numexpr" />
|
|
||||||
<item index="92" class="java.lang.String" itemvalue="jupyter-core" />
|
|
||||||
<item index="93" class="java.lang.String" itemvalue="ipython_genutils" />
|
|
||||||
<item index="94" class="java.lang.String" itemvalue="yapf" />
|
|
||||||
<item index="95" class="java.lang.String" itemvalue="rope" />
|
|
||||||
<item index="96" class="java.lang.String" itemvalue="wcwidth" />
|
|
||||||
<item index="97" class="java.lang.String" itemvalue="cupy-cuda110" />
|
|
||||||
<item index="98" class="java.lang.String" itemvalue="llvmlite" />
|
|
||||||
<item index="99" class="java.lang.String" itemvalue="Jinja2" />
|
|
||||||
<item index="100" class="java.lang.String" itemvalue="pycrypto" />
|
|
||||||
<item index="101" class="java.lang.String" itemvalue="Keras-Preprocessing" />
|
|
||||||
<item index="102" class="java.lang.String" itemvalue="ptflops" />
|
|
||||||
<item index="103" class="java.lang.String" itemvalue="cupy-cuda111" />
|
|
||||||
<item index="104" class="java.lang.String" itemvalue="cupy-cuda114" />
|
|
||||||
<item index="105" class="java.lang.String" itemvalue="some-package" />
|
|
||||||
<item index="106" class="java.lang.String" itemvalue="wandb" />
|
|
||||||
<item index="107" class="java.lang.String" itemvalue="netaddr" />
|
|
||||||
<item index="108" class="java.lang.String" itemvalue="sortedcollections" />
|
|
||||||
<item index="109" class="java.lang.String" itemvalue="six" />
|
|
||||||
<item index="110" class="java.lang.String" itemvalue="timm" />
|
|
||||||
<item index="111" class="java.lang.String" itemvalue="pyflakes" />
|
|
||||||
<item index="112" class="java.lang.String" itemvalue="asn1crypto" />
|
|
||||||
<item index="113" class="java.lang.String" itemvalue="parso" />
|
|
||||||
<item index="114" class="java.lang.String" itemvalue="pytest-doctestplus" />
|
|
||||||
<item index="115" class="java.lang.String" itemvalue="ipython" />
|
|
||||||
<item index="116" class="java.lang.String" itemvalue="xlwt" />
|
|
||||||
<item index="117" class="java.lang.String" itemvalue="packaging" />
|
|
||||||
<item index="118" class="java.lang.String" itemvalue="chardet" />
|
|
||||||
<item index="119" class="java.lang.String" itemvalue="jupyterlab-launcher" />
|
|
||||||
<item index="120" class="java.lang.String" itemvalue="click-plugins" />
|
|
||||||
<item index="121" class="java.lang.String" itemvalue="PyYAML" />
|
|
||||||
<item index="122" class="java.lang.String" itemvalue="pickleshare" />
|
|
||||||
<item index="123" class="java.lang.String" itemvalue="pycparser" />
|
|
||||||
<item index="124" class="java.lang.String" itemvalue="pyasn1-modules" />
|
|
||||||
<item index="125" class="java.lang.String" itemvalue="tables" />
|
|
||||||
<item index="126" class="java.lang.String" itemvalue="Pygments" />
|
|
||||||
<item index="127" class="java.lang.String" itemvalue="sentry-sdk" />
|
|
||||||
<item index="128" class="java.lang.String" itemvalue="docutils" />
|
|
||||||
<item index="129" class="java.lang.String" itemvalue="gevent" />
|
|
||||||
<item index="130" class="java.lang.String" itemvalue="shortuuid" />
|
|
||||||
<item index="131" class="java.lang.String" itemvalue="qtconsole" />
|
|
||||||
<item index="132" class="java.lang.String" itemvalue="terminado" />
|
|
||||||
<item index="133" class="java.lang.String" itemvalue="GitPython" />
|
|
||||||
<item index="134" class="java.lang.String" itemvalue="distributed" />
|
|
||||||
<item index="135" class="java.lang.String" itemvalue="jupyter-client" />
|
|
||||||
<item index="136" class="java.lang.String" itemvalue="pexpect" />
|
|
||||||
<item index="137" class="java.lang.String" itemvalue="ipykernel" />
|
|
||||||
<item index="138" class="java.lang.String" itemvalue="nbconvert" />
|
|
||||||
<item index="139" class="java.lang.String" itemvalue="attrs" />
|
|
||||||
<item index="140" class="java.lang.String" itemvalue="psutil" />
|
|
||||||
<item index="141" class="java.lang.String" itemvalue="simplejson" />
|
|
||||||
<item index="142" class="java.lang.String" itemvalue="jedi" />
|
|
||||||
<item index="143" class="java.lang.String" itemvalue="flatbuffers" />
|
|
||||||
<item index="144" class="java.lang.String" itemvalue="cytoolz" />
|
|
||||||
<item index="145" class="java.lang.String" itemvalue="odo" />
|
|
||||||
<item index="146" class="java.lang.String" itemvalue="decorator" />
|
|
||||||
<item index="147" class="java.lang.String" itemvalue="pandocfilters" />
|
|
||||||
<item index="148" class="java.lang.String" itemvalue="backports.shutil-get-terminal-size" />
|
|
||||||
<item index="149" class="java.lang.String" itemvalue="pycodestyle" />
|
|
||||||
<item index="150" class="java.lang.String" itemvalue="pycosat" />
|
|
||||||
<item index="151" class="java.lang.String" itemvalue="pyasn1" />
|
|
||||||
<item index="152" class="java.lang.String" itemvalue="requests" />
|
|
||||||
<item index="153" class="java.lang.String" itemvalue="bitarray" />
|
|
||||||
<item index="154" class="java.lang.String" itemvalue="kornia" />
|
|
||||||
<item index="155" class="java.lang.String" itemvalue="mkl-fft" />
|
|
||||||
<item index="156" class="java.lang.String" itemvalue="tensorflow" />
|
|
||||||
<item index="157" class="java.lang.String" itemvalue="XlsxWriter" />
|
|
||||||
<item index="158" class="java.lang.String" itemvalue="seaborn" />
|
|
||||||
<item index="159" class="java.lang.String" itemvalue="tensorboard-plugin-wit" />
|
|
||||||
<item index="160" class="java.lang.String" itemvalue="blaze" />
|
|
||||||
<item index="161" class="java.lang.String" itemvalue="zipp" />
|
|
||||||
<item index="162" class="java.lang.String" itemvalue="pkginfo" />
|
|
||||||
<item index="163" class="java.lang.String" itemvalue="cached-property" />
|
|
||||||
<item index="164" class="java.lang.String" itemvalue="torchstat" />
|
|
||||||
<item index="165" class="java.lang.String" itemvalue="datashape" />
|
|
||||||
<item index="166" class="java.lang.String" itemvalue="itsdangerous" />
|
|
||||||
<item index="167" class="java.lang.String" itemvalue="ipywidgets" />
|
|
||||||
<item index="168" class="java.lang.String" itemvalue="scipy" />
|
|
||||||
<item index="169" class="java.lang.String" itemvalue="thop" />
|
|
||||||
<item index="170" class="java.lang.String" itemvalue="tornado" />
|
|
||||||
<item index="171" class="java.lang.String" itemvalue="google-auth-oauthlib" />
|
|
||||||
<item index="172" class="java.lang.String" itemvalue="opencv-python" />
|
|
||||||
<item index="173" class="java.lang.String" itemvalue="torch" />
|
|
||||||
<item index="174" class="java.lang.String" itemvalue="singledispatch" />
|
|
||||||
<item index="175" class="java.lang.String" itemvalue="sortedcontainers" />
|
|
||||||
<item index="176" class="java.lang.String" itemvalue="mistune" />
|
|
||||||
<item index="177" class="java.lang.String" itemvalue="pandas" />
|
|
||||||
<item index="178" class="java.lang.String" itemvalue="termcolor" />
|
|
||||||
<item index="179" class="java.lang.String" itemvalue="clang" />
|
|
||||||
<item index="180" class="java.lang.String" itemvalue="toolz" />
|
|
||||||
<item index="181" class="java.lang.String" itemvalue="Sphinx" />
|
|
||||||
<item index="182" class="java.lang.String" itemvalue="mpmath" />
|
|
||||||
<item index="183" class="java.lang.String" itemvalue="jupyter-console" />
|
|
||||||
<item index="184" class="java.lang.String" itemvalue="bokeh" />
|
|
||||||
<item index="185" class="java.lang.String" itemvalue="cachetools" />
|
|
||||||
<item index="186" class="java.lang.String" itemvalue="gmpy2" />
|
|
||||||
<item index="187" class="java.lang.String" itemvalue="setproctitle" />
|
|
||||||
<item index="188" class="java.lang.String" itemvalue="webencodings" />
|
|
||||||
<item index="189" class="java.lang.String" itemvalue="html5lib" />
|
|
||||||
<item index="190" class="java.lang.String" itemvalue="colorlog" />
|
|
||||||
<item index="191" class="java.lang.String" itemvalue="python-dateutil" />
|
|
||||||
<item index="192" class="java.lang.String" itemvalue="QtPy" />
|
|
||||||
<item index="193" class="java.lang.String" itemvalue="astroid" />
|
|
||||||
<item index="194" class="java.lang.String" itemvalue="cycler" />
|
|
||||||
<item index="195" class="java.lang.String" itemvalue="mkl-random" />
|
|
||||||
<item index="196" class="java.lang.String" itemvalue="pytest-arraydiff" />
|
|
||||||
<item index="197" class="java.lang.String" itemvalue="locket" />
|
|
||||||
<item index="198" class="java.lang.String" itemvalue="heapdict" />
|
|
||||||
<item index="199" class="java.lang.String" itemvalue="snowballstemmer" />
|
|
||||||
<item index="200" class="java.lang.String" itemvalue="contextlib2" />
|
|
||||||
<item index="201" class="java.lang.String" itemvalue="certifi" />
|
|
||||||
<item index="202" class="java.lang.String" itemvalue="Markdown" />
|
|
||||||
<item index="203" class="java.lang.String" itemvalue="sympy" />
|
|
||||||
<item index="204" class="java.lang.String" itemvalue="notebook" />
|
|
||||||
<item index="205" class="java.lang.String" itemvalue="pyodbc" />
|
|
||||||
<item index="206" class="java.lang.String" itemvalue="boto" />
|
|
||||||
<item index="207" class="java.lang.String" itemvalue="cligj" />
|
|
||||||
<item index="208" class="java.lang.String" itemvalue="h5py" />
|
|
||||||
<item index="209" class="java.lang.String" itemvalue="wrapt" />
|
|
||||||
<item index="210" class="java.lang.String" itemvalue="kiwisolver" />
|
|
||||||
<item index="211" class="java.lang.String" itemvalue="pytest-openfiles" />
|
|
||||||
<item index="212" class="java.lang.String" itemvalue="anaconda-client" />
|
|
||||||
<item index="213" class="java.lang.String" itemvalue="backcall" />
|
|
||||||
<item index="214" class="java.lang.String" itemvalue="PySocks" />
|
|
||||||
<item index="215" class="java.lang.String" itemvalue="charset-normalizer" />
|
|
||||||
<item index="216" class="java.lang.String" itemvalue="typing" />
|
|
||||||
<item index="217" class="java.lang.String" itemvalue="dask" />
|
|
||||||
<item index="218" class="java.lang.String" itemvalue="enum34" />
|
|
||||||
<item index="219" class="java.lang.String" itemvalue="torchsummary" />
|
|
||||||
<item index="220" class="java.lang.String" itemvalue="scikit-image" />
|
|
||||||
<item index="221" class="java.lang.String" itemvalue="ptyprocess" />
|
|
||||||
<item index="222" class="java.lang.String" itemvalue="more-itertools" />
|
|
||||||
<item index="223" class="java.lang.String" itemvalue="SQLAlchemy" />
|
|
||||||
<item index="224" class="java.lang.String" itemvalue="tblib" />
|
|
||||||
<item index="225" class="java.lang.String" itemvalue="cloudpickle" />
|
|
||||||
<item index="226" class="java.lang.String" itemvalue="importlib-metadata" />
|
|
||||||
<item index="227" class="java.lang.String" itemvalue="simplegeneric" />
|
|
||||||
<item index="228" class="java.lang.String" itemvalue="zict" />
|
|
||||||
<item index="229" class="java.lang.String" itemvalue="urllib3" />
|
|
||||||
<item index="230" class="java.lang.String" itemvalue="jupyterlab" />
|
|
||||||
<item index="231" class="java.lang.String" itemvalue="Cython" />
|
|
||||||
<item index="232" class="java.lang.String" itemvalue="Flask" />
|
|
||||||
<item index="233" class="java.lang.String" itemvalue="nose" />
|
|
||||||
<item index="234" class="java.lang.String" itemvalue="pytorch-msssim" />
|
|
||||||
<item index="235" class="java.lang.String" itemvalue="pytest" />
|
|
||||||
<item index="236" class="java.lang.String" itemvalue="nbformat" />
|
|
||||||
<item index="237" class="java.lang.String" itemvalue="matmul" />
|
|
||||||
<item index="238" class="java.lang.String" itemvalue="tqdm" />
|
|
||||||
<item index="239" class="java.lang.String" itemvalue="lazy-object-proxy" />
|
|
||||||
<item index="240" class="java.lang.String" itemvalue="colorama" />
|
|
||||||
<item index="241" class="java.lang.String" itemvalue="grpcio" />
|
|
||||||
<item index="242" class="java.lang.String" itemvalue="ply" />
|
|
||||||
<item index="243" class="java.lang.String" itemvalue="google-auth" />
|
|
||||||
<item index="244" class="java.lang.String" itemvalue="openpyxl" />
|
|
||||||
</list>
|
|
||||||
</value>
|
|
||||||
</option>
|
|
||||||
</inspection_tool>
|
|
||||||
</profile>
|
|
||||||
</component>
|
|
@ -1,6 +0,0 @@
|
|||||||
<component name="InspectionProjectProfileManager">
|
|
||||||
<settings>
|
|
||||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
|
||||||
<version value="1.0" />
|
|
||||||
</settings>
|
|
||||||
</component>
|
|
@ -1,16 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="Black">
|
|
||||||
<option name="sdkName" value="Python 3.9 (flaskTest)" />
|
|
||||||
</component>
|
|
||||||
<component name="MavenImportPreferences">
|
|
||||||
<option name="generalSettings">
|
|
||||||
<MavenGeneralSettings>
|
|
||||||
<option name="localRepository" value="E:\maven\repository" />
|
|
||||||
<option name="showDialogWithAdvancedSettings" value="true" />
|
|
||||||
<option name="userSettingsFile" value="E:\maven\settings.xml" />
|
|
||||||
</MavenGeneralSettings>
|
|
||||||
</option>
|
|
||||||
</component>
|
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pfcfuse)" project-jdk-type="Python SDK" />
|
|
||||||
</project>
|
|
@ -1,8 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="ProjectModuleManager">
|
|
||||||
<modules>
|
|
||||||
<module fileurl="file://$PROJECT_DIR$/.idea/PFCFuse.iml" filepath="$PROJECT_DIR$/.idea/PFCFuse.iml" />
|
|
||||||
</modules>
|
|
||||||
</component>
|
|
||||||
</project>
|
|
@ -1,6 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="VcsDirectoryMappings">
|
|
||||||
<mapping directory="" vcs="Git" />
|
|
||||||
</component>
|
|
||||||
</project>
|
|
@ -1,171 +0,0 @@
|
|||||||
# https://github.com/cecret3350/DEA-Net/blob/main/code/model/modules/deconv.py
|
|
||||||
|
|
||||||
import math
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from einops.layers.torch import Rearrange
|
|
||||||
|
|
||||||
"""
|
|
||||||
在深度学习和图像处理领域,"vanilla" 和 "difference" 卷积是两种不同的卷积操作,它们各自有不同的特性和应用场景。DEConv(细节增强卷积)的设计思想是结合这两种卷积的特性来增强模型的性能,尤其是在图像去雾等任务中。
|
|
||||||
|
|
||||||
Vanilla Convolution(标准卷积)
|
|
||||||
"Vanilla" 卷积是最基本的卷积类型,通常仅称为卷积。它是卷积神经网络(CNN)中最常用的组件,用于提取输入数据(如图像)的特征。
|
|
||||||
标准卷积通过在输入数据上滑动小的、可学习的过滤器(或称为核),并计算过滤器与数据的局部区域之间的点乘,来工作。通过这种方式,它能够捕获输入数据的局部模式和特征。
|
|
||||||
|
|
||||||
Difference Convolution(差分卷积)
|
|
||||||
差分卷积是一种特殊类型的卷积,它专注于捕捉输入数据中的局部差异信息,例如边缘或纹理的变化。
|
|
||||||
它通过修改标准卷积核的权重或者通过特殊的操作来实现,使得网络更加关注于图像的高频信息,即图像中的细节和纹理变化。在图像处理任务中,如图像去雾、图像增强、边缘检测等,捕获这种高频信息非常重要,因为它们往往包含了关于物体边界和结构的关键信息。
|
|
||||||
|
|
||||||
重参数化技术
|
|
||||||
重参数化技术是一种参数转换方法,它允许模型在不增加额外参数和计算代价的情况下,实现更复杂的功能或改善性能。在DEConv的上下文中,重参数化技术使得将vanilla卷积和difference卷积结合起来的操作,可以等价地转换成一个标准的卷积操作。
|
|
||||||
这意味着DEConv可以在不增加额外参数和计算成本的情况下,通过巧妙地设计卷积核权重,同时利用标准卷积和差分卷积的优势,从而增强模型处理图像的能力。
|
|
||||||
具体来说,通过重参数化,可以将差分卷积的效果整合到一个卷积核中,使得这个卷积核既能捕获图像的基本特征(通过标准卷积部分),也能强调图像的细节和差异信息(通过差分卷积部分)。
|
|
||||||
这种方法特别适用于那些需要同时考虑全局内容和局部细节信息的任务,如图像去雾,其中既需要理解图像的整体结构,也需要恢复由于雾导致的细节丢失。
|
|
||||||
重参数化技术的关键优势在于,它允许模型在维持参数数量和计算复杂度不变的前提下,实现更为复杂或更为精细的功能。
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2d_cd(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
|
||||||
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
|
|
||||||
super(Conv2d_cd, self).__init__()
|
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
dilation=dilation, groups=groups, bias=bias)
|
|
||||||
self.theta = theta
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
conv_weight = self.conv.weight
|
|
||||||
conv_shape = conv_weight.shape
|
|
||||||
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
|
|
||||||
# conv_weight_cd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_cd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_cd[:, :, :] = conv_weight[:, :, :]
|
|
||||||
conv_weight_cd[:, :, 4] = conv_weight[:, :, 4] - conv_weight[:, :, :].sum(2)
|
|
||||||
conv_weight_cd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
|
|
||||||
conv_weight_cd)
|
|
||||||
return conv_weight_cd, self.conv.bias
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2d_ad(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
|
||||||
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
|
|
||||||
super(Conv2d_ad, self).__init__()
|
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
dilation=dilation, groups=groups, bias=bias)
|
|
||||||
self.theta = theta
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
conv_weight = self.conv.weight
|
|
||||||
conv_shape = conv_weight.shape
|
|
||||||
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
|
|
||||||
conv_weight_ad = conv_weight - self.theta * conv_weight[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]
|
|
||||||
conv_weight_ad = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[3])(
|
|
||||||
conv_weight_ad)
|
|
||||||
return conv_weight_ad, self.conv.bias
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2d_rd(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
|
||||||
padding=2, dilation=1, groups=1, bias=False, theta=1.0):
|
|
||||||
|
|
||||||
super(Conv2d_rd, self).__init__()
|
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
dilation=dilation, groups=groups, bias=bias)
|
|
||||||
self.theta = theta
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
|
|
||||||
if math.fabs(self.theta - 0.0) < 1e-8:
|
|
||||||
out_normal = self.conv(x)
|
|
||||||
return out_normal
|
|
||||||
else:
|
|
||||||
conv_weight = self.conv.weight
|
|
||||||
conv_shape = conv_weight.shape
|
|
||||||
if conv_weight.is_cuda:
|
|
||||||
conv_weight_rd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 5 * 5).fill_(0)
|
|
||||||
else:
|
|
||||||
conv_weight_rd = torch.zeros(conv_shape[0], conv_shape[1], 5 * 5)
|
|
||||||
conv_weight = Rearrange('c_in c_out k1 k2 -> c_in c_out (k1 k2)')(conv_weight)
|
|
||||||
conv_weight_rd[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = conv_weight[:, :, 1:]
|
|
||||||
conv_weight_rd[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -conv_weight[:, :, 1:] * self.theta
|
|
||||||
conv_weight_rd[:, :, 12] = conv_weight[:, :, 0] * (1 - self.theta)
|
|
||||||
conv_weight_rd = conv_weight_rd.view(conv_shape[0], conv_shape[1], 5, 5)
|
|
||||||
out_diff = nn.functional.conv2d(input=x, weight=conv_weight_rd, bias=self.conv.bias,
|
|
||||||
stride=self.conv.stride, padding=self.conv.padding, groups=self.conv.groups)
|
|
||||||
|
|
||||||
return out_diff
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2d_hd(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
|
||||||
padding=1, dilation=1, groups=1, bias=False, theta=1.0):
|
|
||||||
super(Conv2d_hd, self).__init__()
|
|
||||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
dilation=dilation, groups=groups, bias=bias)
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
conv_weight = self.conv.weight
|
|
||||||
conv_shape = conv_weight.shape
|
|
||||||
# conv_weight_hd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_hd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_hd[:, :, [0, 3, 6]] = conv_weight[:, :, :]
|
|
||||||
conv_weight_hd[:, :, [2, 5, 8]] = -conv_weight[:, :, :]
|
|
||||||
conv_weight_hd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
|
|
||||||
conv_weight_hd)
|
|
||||||
return conv_weight_hd, self.conv.bias
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2d_vd(nn.Module):
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
|
|
||||||
padding=1, dilation=1, groups=1, bias=False):
|
|
||||||
super(Conv2d_vd, self).__init__()
|
|
||||||
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
dilation=dilation, groups=groups, bias=bias)
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
conv_weight = self.conv.weight
|
|
||||||
conv_shape = conv_weight.shape
|
|
||||||
# conv_weight_vd = torch.cuda.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_vd = torch.FloatTensor(conv_shape[0], conv_shape[1], 3 * 3).fill_(0)
|
|
||||||
conv_weight_vd[:, :, [0, 1, 2]] = conv_weight[:, :, :]
|
|
||||||
conv_weight_vd[:, :, [6, 7, 8]] = -conv_weight[:, :, :]
|
|
||||||
conv_weight_vd = Rearrange('c_in c_out (k1 k2) -> c_in c_out k1 k2', k1=conv_shape[2], k2=conv_shape[2])(
|
|
||||||
conv_weight_vd)
|
|
||||||
return conv_weight_vd, self.conv.bias
|
|
||||||
|
|
||||||
|
|
||||||
class DEConv(nn.Module):
|
|
||||||
def __init__(self, dim):
|
|
||||||
super(DEConv, self).__init__()
|
|
||||||
self.conv1_1 = Conv2d_cd(dim, dim, 3, bias=True)
|
|
||||||
self.conv1_2 = Conv2d_hd(dim, dim, 3, bias=True)
|
|
||||||
self.conv1_3 = Conv2d_vd(dim, dim, 3, bias=True)
|
|
||||||
self.conv1_4 = Conv2d_ad(dim, dim, 3, bias=True)
|
|
||||||
self.conv1_5 = nn.Conv2d(dim, dim, 3, padding=1, bias=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
w1, b1 = self.conv1_1.get_weight()
|
|
||||||
w2, b2 = self.conv1_2.get_weight()
|
|
||||||
w3, b3 = self.conv1_3.get_weight()
|
|
||||||
w4, b4 = self.conv1_4.get_weight()
|
|
||||||
w5, b5 = self.conv1_5.weight, self.conv1_5.bias
|
|
||||||
|
|
||||||
w = w1 + w2 + w3 + w4 + w5
|
|
||||||
b = b1 + b2 + b3 + b4 + b5
|
|
||||||
res = nn.functional.conv2d(input=x, weight=w, bias=b, stride=1, padding=1, groups=1)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
# 初始化DEConv模块,dim为输入和输出的通道数
|
|
||||||
block = DEConv(dim=16)
|
|
||||||
# 创建一个随机输入张量,假设输入尺寸为(batch_size, channels, height, width)
|
|
||||||
input_tensor = torch.rand(4, 16, 64, 64)
|
|
||||||
# 将输入传递给DEConv模块
|
|
||||||
output_tensor = block(input_tensor)
|
|
||||||
# 打印输入和输出张量的尺寸
|
|
||||||
print("输入尺寸:", input_tensor.size())
|
|
||||||
print("输出尺寸:", output_tensor.size())
|
|
||||||
|
|
||||||
|
|
@ -1,156 +0,0 @@
|
|||||||
import typing as t
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from einops.einops import rearrange
|
|
||||||
from mmengine.model import BaseModule
|
|
||||||
__all__ = ['SCSA']
|
|
||||||
|
|
||||||
"""SCSA:探索空间注意力和通道注意力之间的协同作用
|
|
||||||
通道和空间注意力分别在为各种下游视觉任务提取特征依赖性和空间结构关系方面带来了显着的改进。
|
|
||||||
虽然它们的结合更有利于发挥各自的优势,但通道和空间注意力之间的协同作用尚未得到充分探索,缺乏充分利用多语义信息的协同潜力来进行特征引导和缓解语义差异。
|
|
||||||
我们的研究试图在多个语义层面揭示空间和通道注意力之间的协同关系,提出了一种新颖的空间和通道协同注意力模块(SCSA)。我们的SCSA由两部分组成:可共享的多语义空间注意力(SMSA)和渐进式通道自注意力(PCSA)。
|
|
||||||
SMSA 集成多语义信息并利用渐进式压缩策略将判别性空间先验注入 PCSA 的通道自注意力中,有效地指导通道重新校准。此外,PCSA 中基于自注意力机制的稳健特征交互进一步缓解了 SMSA 中不同子特征之间多语义信息的差异。
|
|
||||||
我们在七个基准数据集上进行了广泛的实验,包括 ImageNet-1K 上的分类、MSCOCO 2017 上的对象检测、ADE20K 上的分割以及其他四个复杂场景检测数据集。我们的结果表明,我们提出的 SCSA 不仅超越了当前最先进的注意力机制,
|
|
||||||
而且在各种任务场景中表现出增强的泛化能力。
|
|
||||||
"""
|
|
||||||
|
|
||||||
class SCSA(BaseModule):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
head_num: int,
|
|
||||||
window_size: int = 7,
|
|
||||||
group_kernel_sizes: t.List[int] = [3, 5, 7, 9],
|
|
||||||
qkv_bias: bool = False,
|
|
||||||
fuse_bn: bool = False,
|
|
||||||
norm_cfg: t.Dict = dict(type='BN'),
|
|
||||||
act_cfg: t.Dict = dict(type='ReLU'),
|
|
||||||
down_sample_mode: str = 'avg_pool',
|
|
||||||
attn_drop_ratio: float = 0.,
|
|
||||||
gate_layer: str = 'sigmoid',
|
|
||||||
):
|
|
||||||
super(SCSA, self).__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.head_num = head_num
|
|
||||||
self.head_dim = dim // head_num
|
|
||||||
self.scaler = self.head_dim ** -0.5
|
|
||||||
self.group_kernel_sizes = group_kernel_sizes
|
|
||||||
self.window_size = window_size
|
|
||||||
self.qkv_bias = qkv_bias
|
|
||||||
self.fuse_bn = fuse_bn
|
|
||||||
self.down_sample_mode = down_sample_mode
|
|
||||||
|
|
||||||
assert self.dim // 4, 'The dimension of input feature should be divisible by 4.'
|
|
||||||
self.group_chans = group_chans = self.dim // 4
|
|
||||||
|
|
||||||
self.local_dwc = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[0],
|
|
||||||
padding=group_kernel_sizes[0] // 2, groups=group_chans)
|
|
||||||
self.global_dwc_s = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[1],
|
|
||||||
padding=group_kernel_sizes[1] // 2, groups=group_chans)
|
|
||||||
self.global_dwc_m = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[2],
|
|
||||||
padding=group_kernel_sizes[2] // 2, groups=group_chans)
|
|
||||||
self.global_dwc_l = nn.Conv1d(group_chans, group_chans, kernel_size=group_kernel_sizes[3],
|
|
||||||
padding=group_kernel_sizes[3] // 2, groups=group_chans)
|
|
||||||
self.sa_gate = nn.Softmax(dim=2) if gate_layer == 'softmax' else nn.Sigmoid()
|
|
||||||
self.norm_h = nn.GroupNorm(4, dim)
|
|
||||||
self.norm_w = nn.GroupNorm(4, dim)
|
|
||||||
|
|
||||||
self.conv_d = nn.Identity()
|
|
||||||
self.norm = nn.GroupNorm(1, dim)
|
|
||||||
self.q = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
|
|
||||||
self.k = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
|
|
||||||
self.v = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1, bias=qkv_bias, groups=dim)
|
|
||||||
self.attn_drop = nn.Dropout(attn_drop_ratio)
|
|
||||||
self.ca_gate = nn.Softmax(dim=1) if gate_layer == 'softmax' else nn.Sigmoid()
|
|
||||||
|
|
||||||
if window_size == -1:
|
|
||||||
self.down_func = nn.AdaptiveAvgPool2d((1, 1))
|
|
||||||
else:
|
|
||||||
if down_sample_mode == 'recombination':
|
|
||||||
self.down_func = self.space_to_chans
|
|
||||||
# dimensionality reduction
|
|
||||||
self.conv_d = nn.Conv2d(in_channels=dim * window_size ** 2, out_channels=dim, kernel_size=1, bias=False)
|
|
||||||
elif down_sample_mode == 'avg_pool':
|
|
||||||
self.down_func = nn.AvgPool2d(kernel_size=(window_size, window_size), stride=window_size)
|
|
||||||
elif down_sample_mode == 'max_pool':
|
|
||||||
self.down_func = nn.MaxPool2d(kernel_size=(window_size, window_size), stride=window_size)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
The dim of x is (B, C, H, W)
|
|
||||||
"""
|
|
||||||
# Spatial attention priority calculation
|
|
||||||
b, c, h_, w_ = x.size()
|
|
||||||
# (B, C, H)
|
|
||||||
x_h = x.mean(dim=3)
|
|
||||||
l_x_h, g_x_h_s, g_x_h_m, g_x_h_l = torch.split(x_h, self.group_chans, dim=1)
|
|
||||||
# (B, C, W)
|
|
||||||
x_w = x.mean(dim=2)
|
|
||||||
l_x_w, g_x_w_s, g_x_w_m, g_x_w_l = torch.split(x_w, self.group_chans, dim=1)
|
|
||||||
|
|
||||||
x_h_attn = self.sa_gate(self.norm_h(torch.cat((
|
|
||||||
self.local_dwc(l_x_h),
|
|
||||||
self.global_dwc_s(g_x_h_s),
|
|
||||||
self.global_dwc_m(g_x_h_m),
|
|
||||||
self.global_dwc_l(g_x_h_l),
|
|
||||||
), dim=1)))
|
|
||||||
x_h_attn = x_h_attn.view(b, c, h_, 1)
|
|
||||||
|
|
||||||
x_w_attn = self.sa_gate(self.norm_w(torch.cat((
|
|
||||||
self.local_dwc(l_x_w),
|
|
||||||
self.global_dwc_s(g_x_w_s),
|
|
||||||
self.global_dwc_m(g_x_w_m),
|
|
||||||
self.global_dwc_l(g_x_w_l)
|
|
||||||
), dim=1)))
|
|
||||||
x_w_attn = x_w_attn.view(b, c, 1, w_)
|
|
||||||
|
|
||||||
x = x * x_h_attn * x_w_attn
|
|
||||||
|
|
||||||
# Channel attention based on self attention
|
|
||||||
# reduce calculations
|
|
||||||
y = self.down_func(x)
|
|
||||||
y = self.conv_d(y)
|
|
||||||
_, _, h_, w_ = y.size()
|
|
||||||
|
|
||||||
# normalization first, then reshape -> (B, H, W, C) -> (B, C, H * W) and generate q, k and v
|
|
||||||
y = self.norm(y)
|
|
||||||
q = self.q(y)
|
|
||||||
k = self.k(y)
|
|
||||||
v = self.v(y)
|
|
||||||
# (B, C, H, W) -> (B, head_num, head_dim, N)
|
|
||||||
q = rearrange(q, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
|
|
||||||
head_dim=int(self.head_dim))
|
|
||||||
k = rearrange(k, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
|
|
||||||
head_dim=int(self.head_dim))
|
|
||||||
v = rearrange(v, 'b (head_num head_dim) h w -> b head_num head_dim (h w)', head_num=int(self.head_num),
|
|
||||||
head_dim=int(self.head_dim))
|
|
||||||
|
|
||||||
# (B, head_num, head_dim, head_dim)
|
|
||||||
attn = q @ k.transpose(-2, -1) * self.scaler
|
|
||||||
attn = self.attn_drop(attn.softmax(dim=-1))
|
|
||||||
# (B, head_num, head_dim, N)
|
|
||||||
attn = attn @ v
|
|
||||||
# (B, C, H_, W_)
|
|
||||||
attn = rearrange(attn, 'b head_num head_dim (h w) -> b (head_num head_dim) h w', h=int(h_), w=int(w_))
|
|
||||||
# (B, C, 1, 1)
|
|
||||||
attn = attn.mean((2, 3), keepdim=True)
|
|
||||||
attn = self.ca_gate(attn)
|
|
||||||
return attn * x
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
block = SCSA(
|
|
||||||
dim=256,
|
|
||||||
head_num=8,
|
|
||||||
)
|
|
||||||
|
|
||||||
input_tensor = torch.rand(1, 256, 32, 32)
|
|
||||||
|
|
||||||
# 调用模块进行前向传播
|
|
||||||
output_tensor = block(input_tensor)
|
|
||||||
|
|
||||||
# 打印输入和输出张量的大小
|
|
||||||
print("Input size:", input_tensor.size())
|
|
||||||
print("Output size:", output_tensor.size())
|
|
@ -1,37 +0,0 @@
|
|||||||
'''-------------一、SE模块-----------------------------'''
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
# 全局平均池化+1*1卷积核+ReLu+1*1卷积核+Sigmoid
|
|
||||||
class SE_Block(nn.Module):
|
|
||||||
def __init__(self, inchannel, ratio=16):
|
|
||||||
super(SE_Block, self).__init__()
|
|
||||||
# 全局平均池化(Fsq操作)
|
|
||||||
self.gap = nn.AdaptiveAvgPool2d((1, 1))
|
|
||||||
# 两个全连接层(Fex操作)
|
|
||||||
self.fc = nn.Sequential(
|
|
||||||
nn.Linear(inchannel, inchannel // ratio, bias=False), # 从 c -> c/r
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(inchannel // ratio, inchannel, bias=False), # 从 c/r -> c
|
|
||||||
nn.Sigmoid()
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# 读取批数据图片数量及通道数
|
|
||||||
b, c, h, w = x.size()
|
|
||||||
# Fsq操作:经池化后输出b*c的矩阵
|
|
||||||
y = self.gap(x).view(b, c)
|
|
||||||
# Fex操作:经全连接层输出(b,c,1,1)矩阵
|
|
||||||
y = self.fc(y).view(b, c, 1, 1)
|
|
||||||
# Fscale操作:将得到的权重乘以原来的特征图x
|
|
||||||
return x * y.expand_as(x)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
input = torch.randn(1, 64, 32, 32)
|
|
||||||
seblock = SE_Block(64)
|
|
||||||
print(seblock)
|
|
||||||
output = seblock(input)
|
|
||||||
print(input.shape)
|
|
||||||
print(output.shape)
|
|
||||||
|
|
@ -1,42 +0,0 @@
|
|||||||
import os
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
def transfer(input_path, quality=20, resize_factor=0.1):
|
|
||||||
# 打开TIFF图像
|
|
||||||
# img = Image.open(input_path)
|
|
||||||
#
|
|
||||||
# # 保存为JPEG,并设置压缩质量
|
|
||||||
# img.save(output_path, 'JPEG', quality=quality)
|
|
||||||
|
|
||||||
# input_path = os.path.join(input_folder, filename)
|
|
||||||
# 获取input_path的文件名
|
|
||||||
|
|
||||||
# 使用os.path.splitext获取文件名和后缀的元组
|
|
||||||
# 使用os.path.basename获取文件名(包含后缀)
|
|
||||||
filename_with_extension = os.path.basename(input_path)
|
|
||||||
filename, file_extension = os.path.splitext(filename_with_extension)
|
|
||||||
|
|
||||||
# 使用os.path.dirname获取文件所在的目录路径
|
|
||||||
output_folder = os.path.dirname(input_path)
|
|
||||||
|
|
||||||
output_path = os.path.join(output_folder, filename + '.jpg')
|
|
||||||
|
|
||||||
img = Image.open(input_path)
|
|
||||||
|
|
||||||
# 将图像缩小到原来的一半
|
|
||||||
new_width = int(img.width * resize_factor)
|
|
||||||
new_height = int(img.height * resize_factor)
|
|
||||||
resized_img = img.resize((new_width, new_height))
|
|
||||||
|
|
||||||
# 保存为JPEG,并设置压缩质量
|
|
||||||
# 转换为RGB模式,丢弃透明通道
|
|
||||||
rgb_img = resized_img.convert('RGB')
|
|
||||||
|
|
||||||
# 保存为JPEG,并设置压缩质量
|
|
||||||
# 压缩
|
|
||||||
rgb_img.save(output_path, 'JPEG', quality=quality)
|
|
||||||
|
|
||||||
print(f'{output_path} 转换完成')
|
|
||||||
|
|
||||||
return output_path
|
|
@ -1,35 +0,0 @@
|
|||||||
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
|
|
||||||
|
|
||||||
# base pcffuse
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 2.39 33.82 11.32 0.81 0.8 0.12 0.07 0.11
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
Process finished with exit code 0
|
|
@ -1,33 +0,0 @@
|
|||||||
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.01 40.67 15.39 1.53 1.76 0.64 0.53 0.95
|
|
||||||
================================================================================
|
|
@ -1,89 +0,0 @@
|
|||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.14 46.48 13.18 2.22 1.76 0.79 0.56 1.02
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of RoadScene :
|
|
||||||
FLIR_07206.jpg
|
|
||||||
FLIR_08202.jpg
|
|
||||||
FLIR_05893.jpg
|
|
||||||
FLIR_06974.jpg
|
|
||||||
FLIR_04424.jpg
|
|
||||||
FLIR_08284.jpg
|
|
||||||
FLIR_07786.jpg
|
|
||||||
FLIR_08021.jpg
|
|
||||||
FLIR_07968.jpg
|
|
||||||
FLIR_01130.jpg
|
|
||||||
FLIR_06993.jpg
|
|
||||||
FLIR_07190.jpg
|
|
||||||
FLIR_06570.jpg
|
|
||||||
FLIR_07809.jpg
|
|
||||||
FLIR_06430.jpg
|
|
||||||
FLIR_08592.jpg
|
|
||||||
FLIR_00211.jpg
|
|
||||||
FLIR_08721.jpg
|
|
||||||
FLIR_05955.jpg
|
|
||||||
FLIR_04688.jpg
|
|
||||||
FLIR_07732.jpg
|
|
||||||
FLIR_06392.jpg
|
|
||||||
FLIR_00977.jpg
|
|
||||||
FLIR_05105.jpg
|
|
||||||
FLIR_04269.jpg
|
|
||||||
FLIR_07970.jpg
|
|
||||||
FLIR_05005.jpg
|
|
||||||
FLIR_07209.jpg
|
|
||||||
FLIR_07555.jpg
|
|
||||||
FLIR_06325.jpg
|
|
||||||
FLIR_04943.jpg
|
|
||||||
FLIR_video_02829.jpg
|
|
||||||
FLIR_08248.jpg
|
|
||||||
FLIR_04484.jpg
|
|
||||||
FLIR_08058.jpg
|
|
||||||
FLIR_06795.jpg
|
|
||||||
FLIR_06995.jpg
|
|
||||||
FLIR_05879.jpg
|
|
||||||
FLIR_04593.jpg
|
|
||||||
FLIR_08094.jpg
|
|
||||||
FLIR_08526.jpg
|
|
||||||
FLIR_08858.jpg
|
|
||||||
FLIR_09465.jpg
|
|
||||||
FLIR_05064.jpg
|
|
||||||
FLIR_05857.jpg
|
|
||||||
FLIR_05914.jpg
|
|
||||||
FLIR_04722.jpg
|
|
||||||
FLIR_06506.jpg
|
|
||||||
FLIR_06282.jpg
|
|
||||||
FLIR_04512.jpg
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.41 52.99 15.81 2.37 1.78 0.71 0.55 0.96
|
|
||||||
================================================================================
|
|
@ -1,24 +0,0 @@
|
|||||||
2.4.1+cu121
|
|
||||||
True
|
|
||||||
Model: PFCFuse
|
|
||||||
Number of epochs: 60
|
|
||||||
Epoch gap: 40
|
|
||||||
Learning rate: 0.0001
|
|
||||||
Weight decay: 0
|
|
||||||
Batch size: 1
|
|
||||||
GPU number: 0
|
|
||||||
Coefficient of MSE loss VF: 1.0
|
|
||||||
Coefficient of MSE loss IF: 1.0
|
|
||||||
Coefficient of RMI loss VF: 1.0
|
|
||||||
Coefficient of RMI loss IF: 1.0
|
|
||||||
Coefficient of Cosine loss VF: 1.0
|
|
||||||
Coefficient of Cosine loss IF: 1.0
|
|
||||||
Coefficient of Decomposition loss: 2.0
|
|
||||||
Coefficient of Total Variation loss: 5.0
|
|
||||||
Clip gradient norm value: 0.01
|
|
||||||
Optimization step: 20
|
|
||||||
Optimization gamma: 0.5
|
|
||||||
[Epoch 0/60] [Batch 0/6487] [loss: 10.193036] ETA: 9 days, 21
[Epoch 0/60] [Batch 1/6487] [loss: 4.166963] ETA: 11:49:56.5
[Epoch 0/60] [Batch 2/6487] [loss: 10.681509] ETA: 10:23:19.1
[Epoch 0/60] [Batch 3/6487] [loss: 6.257133] ETA: 10:31:48.3
[Epoch 0/60] [Batch 4/6487] [loss: 13.018341] ETA: 10:32:54.2
[Epoch 0/60] [Batch 5/6487] [loss: 11.268185] ETA: 10:27:32.2
[Epoch 0/60] [Batch 6/6487] [loss: 6.920656] ETA: 10:34:01.5
[Epoch 0/60] [Batch 7/6487] [loss: 4.666215] ETA: 10:32:45.3
[Epoch 0/60] [Batch 8/6487] [loss: 10.787085] ETA: 10:26:01.9
[Epoch 0/60] [Batch 9/6487] [loss: 5.754866] ETA: 10:34:34.2
[Epoch 0/60] [Batch 10/6487] [loss: 28.760792] ETA: 10:36:32.6
[Epoch 0/60] [Batch 11/6487] [loss: 8.672796] ETA: 10:25:11.9
[Epoch 0/60] [Batch 12/6487] [loss: 14.300608] ETA: 10:28:19.9
[Epoch 0/60] [Batch 13/6487] [loss: 11.821722] ETA: 10:34:18.6
[Epoch 0/60] [Batch 14/6487] [loss: 7.627745] ETA: 10:31:44.8
[Epoch 0/60] [Batch 15/6487] [loss: 5.722600] ETA: 10:34:17.4
[Epoch 0/60] [Batch 16/6487] [loss: 10.423873] ETA: 11:33:27.1
[Epoch 0/60] [Batch 17/6487] [loss: 4.454098] ETA: 9:37:13.67
[Epoch 0/60] [Batch 18/6487] [loss: 3.820719] ETA: 9:33:57.42
[Epoch 0/60] [Batch 19/6487] [loss: 6.564124] ETA: 9:41:22.09
[Epoch 0/60] [Batch 20/6487] [loss: 5.406681] ETA: 9:47:30.11
[Epoch 0/60] [Batch 21/6487] [loss: 25.275440] ETA: 9:39:29.91
[Epoch 0/60] [Batch 22/6487] [loss: 4.228334] ETA: 9:42:15.45
[Epoch 0/60] [Batch 23/6487] [loss: 22.508118] ETA: 9:38:18.10
[Epoch 0/60] [Batch 24/6487] [loss: 5.062001] ETA: 9:46:09.29
[Epoch 0/60] [Batch 25/6487] [loss: 3.157355] ETA: 9:41:30.09
[Epoch 0/60] [Batch 26/6487] [loss: 6.438435] ETA: 10:02:51.9
[Epoch 0/60] [Batch 27/6487] [loss: 7.430470] ETA: 9:18:12.94
[Epoch 0/60] [Batch 28/6487] [loss: 3.783903] ETA: 10:41:13.9
[Epoch 0/60] [Batch 29/6487] [loss: 2.954306] ETA: 9:44:25.10
[Epoch 0/60] [Batch 30/6487] [loss: 5.863827] ETA: 9:35:13.84
[Epoch 0/60] [Batch 31/6487] [loss: 6.467144] ETA: 9:46:19.80
[Epoch 0/60] [Batch 32/6487] [loss: 4.801052] ETA: 9:32:17.18
[Epoch 0/60] [Batch 33/6487] [loss: 5.658401] ETA: 9:31:10.28
[Epoch 0/60] [Batch 34/6487] [loss: 2.085633] ETA: 9:36:47.39
[Epoch 0/60] [Batch 35/6487] [loss: 15.402915] ETA: 9:40:51.43
[Epoch 0/60] [Batch 36/6487] [loss: 3.181264] ETA: 9:33:06.65
[Epoch 0/60] [Batch 37/6487] [loss: 3.883055] ETA: 9:42:29.60
[Epoch 0/60] [Batch 38/6487] [loss: 3.342676] ETA: 10:07:02.2
[Epoch 0/60] [Batch 39/6487] [loss: 2.589705] ETA: 9:36:43.32
[Epoch 0/60] [Batch 40/6487] [loss: 3.742121] ETA: 9:42:57.54
[Epoch 0/60] [Batch 41/6487] [loss: 2.732829] ETA: 9:36:54.65
[Epoch 0/60] [Batch 42/6487] [loss: 6.655626] ETA: 9:42:20.71
[Epoch 0/60] [Batch 43/6487] [loss: 1.822412] ETA: 9:38:02.02
[Epoch 0/60] [Batch 44/6487] [loss: 2.875143] ETA: 9:41:02.96
[Epoch 0/60] [Batch 45/6487] [loss: 2.319836] ETA: 9:38:16.23
[Epoch 0/60] [Batch 46/6487] [loss: 2.354790] ETA: 9:39:08.93
[Epoch 0/60] [Batch 47/6487] [loss: 1.986412] ETA: 9:52:11.40
[Epoch 0/60] [Batch 48/6487] [loss: 2.154071] ETA: 10:08:20.0
[Epoch 0/60] [Batch 49/6487] [loss: 1.425418] ETA: 9:54:04.42
[Epoch 0/60] [Batch 50/6487] [loss: 1.988360] ETA: 9:30:25.08
[Epoch 0/60] [Batch 51/6487] [loss: 4.090429] ETA: 9:43:53.52
[Epoch 0/60] [Batch 52/6487] [loss: 1.924778] ETA: 9:46:19.38
[Epoch 0/60] [Batch 53/6487] [loss: 2.191964] ETA: 9:46:59.93
[Epoch 0/60] [Batch 54/6487] [loss: 2.032799] ETA: 9:46:14.01
[Epoch 0/60] [Batch 55/6487] [loss: 1.923933] ETA: 9:44:21.65
[Epoch 0/60] [Batch 56/6487] [loss: 1.656838] ETA: 9:56:15.90
[Epoch 0/60] [Batch 57/6487] [loss: 1.656845] ETA: 10:21:26.1
[Epoch 0/60] [Batch 58/6487] [loss: 1.157820] ETA: 10:44:47.3
[Epoch 0/60] [Batch 59/6487] [loss: 1.652715] ETA: 10:46:39.9
[Epoch 0/60] [Batch 60/6487] [loss: 1.633865] ETA: 10:23:34.7
[Epoch 0/60] [Batch 61/6487] [loss: 1.290259] ETA: 9:24:06.12Traceback (most recent call last):
|
|
||||||
File "/home/star/whaiDir/PFCFuse/train.py", line 232, in <module>
|
|
||||||
loss.item(),
|
|
||||||
KeyboardInterrupt
|
|
@ -1,38 +0,0 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.01 40.4 15.51 1.55 1.75 0.66 0.54 0.96
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
Process finished with exit code 0
|
|
285
net.py
@ -6,10 +6,10 @@ import torch.utils.checkpoint as checkpoint
|
|||||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from componets.SCSA import SCSA
|
|
||||||
from componets.WTConvCV2 import WTConv2d
|
from componets.WTConvCV2 import WTConv2d
|
||||||
|
|
||||||
|
|
||||||
|
# 以一定概率随机丢弃输入张量中的路径,用于正则化模型
|
||||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||||
if drop_prob == 0. or not training:
|
if drop_prob == 0. or not training:
|
||||||
return x
|
return x
|
||||||
@ -35,6 +35,9 @@ class DropPath(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return drop_path(x, self.drop_prob, self.training)
|
return drop_path(x, self.drop_prob, self.training)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 改点,使用Pooling替换AttentionBase
|
||||||
class Pooling(nn.Module):
|
class Pooling(nn.Module):
|
||||||
def __init__(self, kernel_size=3):
|
def __init__(self, kernel_size=3):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -47,8 +50,8 @@ class Pooling(nn.Module):
|
|||||||
|
|
||||||
class PoolMlp(nn.Module):
|
class PoolMlp(nn.Module):
|
||||||
"""
|
"""
|
||||||
Implementation of MLP with 1*1 convolutions.
|
实现基于1x1卷积的MLP模块。
|
||||||
Input: tensor with shape [B, C, H, W]
|
输入:形状为[B, C, H, W]的张量。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -58,6 +61,17 @@ class PoolMlp(nn.Module):
|
|||||||
act_layer=nn.GELU,
|
act_layer=nn.GELU,
|
||||||
bias=False,
|
bias=False,
|
||||||
drop=0.):
|
drop=0.):
|
||||||
|
"""
|
||||||
|
初始化PoolMlp模块。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
in_features (int): 输入特征的数量。
|
||||||
|
hidden_features (int, 可选): 隐藏层特征的数量。默认为None,设置为与in_features相同。
|
||||||
|
out_features (int, 可选): 输出特征的数量。默认为None,设置为与in_features相同。
|
||||||
|
act_layer (nn.Module, 可选): 使用的激活层。默认为nn.GELU。
|
||||||
|
bias (bool, 可选): 是否在卷积层中包含偏置项。默认为False。
|
||||||
|
drop (float, 可选): Dropout比率。默认为0。
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
out_features = out_features or in_features
|
out_features = out_features or in_features
|
||||||
hidden_features = hidden_features or in_features
|
hidden_features = hidden_features or in_features
|
||||||
@ -67,6 +81,15 @@ class PoolMlp(nn.Module):
|
|||||||
self.drop = nn.Dropout(drop)
|
self.drop = nn.Dropout(drop)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
通过PoolMlp模块的前向传播。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
x (torch.Tensor): 形状为[B, C, H, W]的输入张量。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
torch.Tensor: 形状为[B, C, H, W]的输出张量。
|
||||||
|
"""
|
||||||
x = self.fc1(x) # (B, C, H, W) --> (B, C, H, W)
|
x = self.fc1(x) # (B, C, H, W) --> (B, C, H, W)
|
||||||
x = self.act(x)
|
x = self.act(x)
|
||||||
x = self.drop(x)
|
x = self.drop(x)
|
||||||
@ -75,46 +98,53 @@ class PoolMlp(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class BaseFeatureFusion(nn.Module):
|
# class BaseFeatureExtraction1(nn.Module):
|
||||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
# def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||||
act_layer=nn.GELU,
|
# act_layer=nn.GELU,
|
||||||
# norm_layer=nn.LayerNorm,
|
# # norm_layer=nn.LayerNorm,
|
||||||
drop=0., drop_path=0.,
|
# drop=0., drop_path=0.,
|
||||||
use_layer_scale=True, layer_scale_init_value=1e-5):
|
# use_layer_scale=True, layer_scale_init_value=1e-5):
|
||||||
|
#
|
||||||
super().__init__()
|
# super().__init__()
|
||||||
|
#
|
||||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
# self.norm1 = LayerNorm(dim, 'WithBias')
|
||||||
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||||
self.norm2 = LayerNorm(dim, 'WithBias')
|
# self.norm2 = LayerNorm(dim, 'WithBias')
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
# mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
# self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
||||||
act_layer=act_layer, drop=drop)
|
# act_layer=act_layer, drop=drop)
|
||||||
|
#
|
||||||
# The following two techniques are useful to train deep PoolFormers.
|
# # The following two techniques are useful to train deep PoolFormers.
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
# self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
||||||
else nn.Identity()
|
# else nn.Identity()
|
||||||
self.use_layer_scale = use_layer_scale
|
# self.use_layer_scale = use_layer_scale
|
||||||
|
#
|
||||||
if use_layer_scale:
|
# if use_layer_scale:
|
||||||
self.layer_scale_1 = nn.Parameter(
|
# self.layer_scale_1 = nn.Parameter(
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
|
#
|
||||||
self.layer_scale_2 = nn.Parameter(
|
# self.layer_scale_2 = nn.Parameter(
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
# torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
|
#
|
||||||
def forward(self, x):
|
# def forward(self, x): # 1 64 128 128
|
||||||
if self.use_layer_scale:
|
# if self.use_layer_scale:
|
||||||
x = x + self.drop_path(
|
# # self.layer_scale_1(64,)
|
||||||
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
|
# tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
||||||
* self.token_mixer(self.norm1(x)))
|
# normal = self.norm1(x) # 1 64 128 128
|
||||||
x = x + self.drop_path(
|
# token_mix = self.token_mixer(normal) # 1 64 128 128
|
||||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
# x = (x +
|
||||||
* self.poolmlp(self.norm2(x)))
|
# self.drop_path(
|
||||||
else:
|
# tmp1 * token_mix
|
||||||
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
|
# )
|
||||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
# # 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
||||||
return x
|
# )
|
||||||
|
# x = x + self.drop_path(
|
||||||
|
# self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||||
|
# * self.poolmlp(self.norm2(x)))
|
||||||
|
# else:
|
||||||
|
# x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
||||||
|
# x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
|
# return x
|
||||||
|
|
||||||
class BaseFeatureExtraction(nn.Module):
|
class BaseFeatureExtraction(nn.Module):
|
||||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
||||||
@ -125,6 +155,7 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.WTConv2d = WTConv2d(dim, dim)
|
||||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
self.norm1 = LayerNorm(dim, 'WithBias')
|
||||||
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
||||||
self.norm2 = LayerNorm(dim, 'WithBias')
|
self.norm2 = LayerNorm(dim, 'WithBias')
|
||||||
@ -144,63 +175,29 @@ class BaseFeatureExtraction(nn.Module):
|
|||||||
self.layer_scale_2 = nn.Parameter(
|
self.layer_scale_2 = nn.Parameter(
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x): # 1 64 128 128
|
||||||
if self.use_layer_scale:
|
if self.use_layer_scale:
|
||||||
x = x + self.drop_path(
|
# self.layer_scale_1(64,)
|
||||||
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
|
tmp1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) # 64 1 1
|
||||||
* self.token_mixer(self.norm1(x)))
|
normal = self.norm1(x) # 1 64 128 128
|
||||||
|
token_mix = self.token_mixer(normal) # 1 64 128 128
|
||||||
|
|
||||||
|
x = self.WTConv2d(x)
|
||||||
|
|
||||||
|
x = (x +
|
||||||
|
self.drop_path(
|
||||||
|
tmp1 * token_mix
|
||||||
|
)
|
||||||
|
# 该表达式将 self.layer_scale_1 这个一维张量(或变量)在维度末尾添加两个新的维度,使其从一维变为三维。这通常用于使其能够与三维的特征图进行广播操作,如元素相乘。具体用途可能包括调整卷积层或注意力机制中的权重。
|
||||||
|
)
|
||||||
x = x + self.drop_path(
|
x = x + self.drop_path(
|
||||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
||||||
* self.poolmlp(self.norm2(x)))
|
* self.poolmlp(self.norm2(x)))
|
||||||
else:
|
else:
|
||||||
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
|
x = x + self.drop_path(self.token_mixer(self.norm1(x))) # 匹配cddfuse
|
||||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class BaseFeatureExtractionSAR(nn.Module):
|
|
||||||
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
|
|
||||||
act_layer=nn.GELU,
|
|
||||||
# norm_layer=nn.LayerNorm,
|
|
||||||
drop=0., drop_path=0.,
|
|
||||||
use_layer_scale=True, layer_scale_init_value=1e-5):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.norm1 = LayerNorm(dim, 'WithBias')
|
|
||||||
self.token_mixer = SCSA(dim=dim,head_num=8)
|
|
||||||
# self.token_mixer = Pooling(kernel_size=pool_size) # vits是msa,MLPs是mlp,这个用pool来替代
|
|
||||||
self.norm2 = LayerNorm(dim, 'WithBias')
|
|
||||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
||||||
self.poolmlp = PoolMlp(in_features=dim, hidden_features=mlp_hidden_dim,
|
|
||||||
act_layer=act_layer, drop=drop)
|
|
||||||
|
|
||||||
# The following two techniques are useful to train deep PoolFormers.
|
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0. \
|
|
||||||
else nn.Identity()
|
|
||||||
self.use_layer_scale = use_layer_scale
|
|
||||||
|
|
||||||
if use_layer_scale:
|
|
||||||
self.layer_scale_1 = nn.Parameter(
|
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
|
||||||
|
|
||||||
self.layer_scale_2 = nn.Parameter(
|
|
||||||
torch.ones(dim, dtype=torch.float32) * layer_scale_init_value)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.use_layer_scale:
|
|
||||||
x = x + self.drop_path(
|
|
||||||
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
|
|
||||||
* self.token_mixer(self.norm1(x)))
|
|
||||||
x = x + self.drop_path(
|
|
||||||
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
|
|
||||||
* self.poolmlp(self.norm2(x)))
|
|
||||||
else:
|
|
||||||
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
|
|
||||||
x = x + self.drop_path(self.poolmlp(self.norm2(x)))
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class InvertedResidualBlock(nn.Module):
|
class InvertedResidualBlock(nn.Module):
|
||||||
def __init__(self, inp, oup, expand_ratio):
|
def __init__(self, inp, oup, expand_ratio):
|
||||||
super(InvertedResidualBlock, self).__init__()
|
super(InvertedResidualBlock, self).__init__()
|
||||||
@ -219,44 +216,18 @@ class InvertedResidualBlock(nn.Module):
|
|||||||
nn.Conv2d(hidden_dim, oup, 1, bias=False),
|
nn.Conv2d(hidden_dim, oup, 1, bias=False),
|
||||||
# nn.BatchNorm2d(oup),
|
# nn.BatchNorm2d(oup),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.bottleneckBlock(x)
|
return self.bottleneckBlock(x)
|
||||||
|
|
||||||
class DepthwiseSeparableConvBlock(nn.Module):
|
|
||||||
def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1):
|
|
||||||
super(DepthwiseSeparableConvBlock, self).__init__()
|
|
||||||
self.depthwise = nn.Conv2d(inp, inp, kernel_size, stride, padding, groups=inp, bias=False)
|
|
||||||
self.pointwise = nn.Conv2d(inp, oup, 1, bias=False)
|
|
||||||
self.bn = nn.BatchNorm2d(oup)
|
|
||||||
self.relu = nn.ReLU(inplace=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.depthwise(x)
|
|
||||||
x = self.pointwise(x)
|
|
||||||
x = self.bn(x)
|
|
||||||
x = self.relu(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
class DetailNode(nn.Module):
|
class DetailNode(nn.Module):
|
||||||
def __init__(self,useBlock=0):
|
|
||||||
|
# <img src = "http://42.192.130.83:9000/picgo/imgs/小绿鲸英文文献阅读器_ELTITYqm5G.png" / > '
|
||||||
|
def __init__(self):
|
||||||
super(DetailNode, self).__init__()
|
super(DetailNode, self).__init__()
|
||||||
if useBlock == 0:
|
|
||||||
self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)
|
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
self.theta_rho = DepthwiseSeparableConvBlock(inp=32, oup=32)
|
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
self.theta_eta = DepthwiseSeparableConvBlock(inp=32, oup=32)
|
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
elif useBlock == 1:
|
|
||||||
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
|
||||||
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
|
||||||
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
|
||||||
elif useBlock == 2:
|
|
||||||
self.theta_phi = WTConv2d(in_channels=32, out_channels=32)
|
|
||||||
self.theta_rho = WTConv2d(in_channels=32, out_channels=32)
|
|
||||||
self.theta_eta = WTConv2d(in_channels=32, out_channels=32)
|
|
||||||
else:
|
|
||||||
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
|
||||||
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
|
||||||
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
|
||||||
self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
|
self.shffleconv = nn.Conv2d(64, 64, kernel_size=1,
|
||||||
stride=1, padding=0, bias=True)
|
stride=1, padding=0, bias=True)
|
||||||
|
|
||||||
@ -271,44 +242,30 @@ class DetailNode(nn.Module):
|
|||||||
z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
|
z1 = z1 * torch.exp(self.theta_rho(z2)) + self.theta_eta(z2)
|
||||||
return z1, z2
|
return z1, z2
|
||||||
|
|
||||||
class DetailFeatureFusion(nn.Module):
|
|
||||||
def __init__(self, num_layers=3):
|
|
||||||
super(DetailFeatureFusion, self).__init__()
|
|
||||||
INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)]
|
|
||||||
self.net = nn.Sequential(*INNmodules)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
|
|
||||||
for layer in self.net:
|
|
||||||
z1, z2 = layer(z1, z2)
|
|
||||||
return torch.cat((z1, z2), dim=1)
|
|
||||||
|
|
||||||
class DetailFeatureExtraction(nn.Module):
|
class DetailFeatureExtraction(nn.Module):
|
||||||
def __init__(self, num_layers=3):
|
def __init__(self, num_layers=3):
|
||||||
super(DetailFeatureExtraction, self).__init__()
|
super(DetailFeatureExtraction, self).__init__()
|
||||||
INNmodules = [DetailNode(useBlock=1) for _ in range(num_layers)]
|
INNmodules = [DetailNode() for _ in range(num_layers)]
|
||||||
self.net = nn.Sequential(*INNmodules)
|
self.net = nn.Sequential(*INNmodules)
|
||||||
|
self.enhancement_module = nn.Sequential(
|
||||||
|
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=True),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x): # 1 64 128 128
|
||||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
|
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]] # 1 32 128 128
|
||||||
|
# 增强并添加残差连接
|
||||||
|
enhanced_z1 = self.enhancement_module(z1)
|
||||||
|
enhanced_z2 = self.enhancement_module(z2)
|
||||||
|
# 残差连接
|
||||||
|
z1 = z1 + enhanced_z1
|
||||||
|
z2 = z2 + enhanced_z2
|
||||||
for layer in self.net:
|
for layer in self.net:
|
||||||
z1, z2 = layer(z1, z2)
|
z1, z2 = layer(z1, z2)
|
||||||
return torch.cat((z1, z2), dim=1)
|
return torch.cat((z1, z2), dim=1)
|
||||||
|
|
||||||
|
|
||||||
class DetailFeatureExtractionSAR(nn.Module):
|
|
||||||
def __init__(self, num_layers=3):
|
|
||||||
super(DetailFeatureExtractionSAR, self).__init__()
|
|
||||||
INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)]
|
|
||||||
self.net = nn.Sequential(*INNmodules)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
z1, z2 = x[:, :x.shape[1] // 2], x[:, x.shape[1] // 2:x.shape[1]]
|
|
||||||
for layer in self.net:
|
|
||||||
z1, z2 = layer(z1, z2)
|
|
||||||
return torch.cat((z1, z2), dim=1)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -490,23 +447,14 @@ class Restormer_Encoder(nn.Module):
|
|||||||
*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
|
*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
|
||||||
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
|
||||||
self.baseFeature = BaseFeatureExtraction(dim=dim)
|
self.baseFeature = BaseFeatureExtraction(dim=dim)
|
||||||
|
|
||||||
self.detailFeature = DetailFeatureExtraction()
|
self.detailFeature = DetailFeatureExtraction()
|
||||||
|
|
||||||
self.baseFeatureSar= BaseFeatureExtractionSAR(dim=dim)
|
def forward(self, inp_img):
|
||||||
self.detailFeatureSar = DetailFeatureExtractionSAR()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, inp_img, sar_img=False):
|
|
||||||
inp_enc_level1 = self.patch_embed(inp_img)
|
inp_enc_level1 = self.patch_embed(inp_img)
|
||||||
out_enc_level1 = self.encoder_level1(inp_enc_level1)
|
out_enc_level1 = self.encoder_level1(inp_enc_level1)
|
||||||
|
base_feature = self.baseFeature(out_enc_level1)
|
||||||
if sar_img:
|
detail_feature = self.detailFeature(out_enc_level1)
|
||||||
base_feature = self.baseFeature(out_enc_level1)
|
|
||||||
detail_feature = self.detailFeature(out_enc_level1)
|
|
||||||
else:
|
|
||||||
base_feature= self.baseFeature(out_enc_level1)
|
|
||||||
detail_feature = self.detailFeature(out_enc_level1)
|
|
||||||
return base_feature, detail_feature, out_enc_level1
|
return base_feature, detail_feature, out_enc_level1
|
||||||
|
|
||||||
|
|
||||||
@ -524,8 +472,7 @@ class Restormer_Decoder(nn.Module):
|
|||||||
|
|
||||||
super(Restormer_Decoder, self).__init__()
|
super(Restormer_Decoder, self).__init__()
|
||||||
self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
|
self.reduce_channel = nn.Conv2d(int(dim * 2), int(dim), kernel_size=1, bias=bias)
|
||||||
self.encoder_level2 = nn.Sequential(
|
self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
|
||||||
*[TransformerBlock(dim=dim, num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
|
|
||||||
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
|
||||||
self.output = nn.Sequential(
|
self.output = nn.Sequential(
|
||||||
nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
|
nn.Conv2d(int(dim), int(dim) // 2, kernel_size=3,
|
||||||
@ -552,3 +499,5 @@ if __name__ == '__main__':
|
|||||||
window_size = 8
|
window_size = 8
|
||||||
modelE = Restormer_Encoder().cuda()
|
modelE = Restormer_Encoder().cuda()
|
||||||
modelD = Restormer_Decoder().cuda()
|
modelD = Restormer_Decoder().cuda()
|
||||||
|
print(modelE)
|
||||||
|
print(modelD)
|
||||||
|
136
status.md
@ -1,136 +0,0 @@
|
|||||||
PFCFuse
|
|
||||||
|
|
||||||
```angular2html
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.14 46.48 13.18 2.22 1.76 0.79 0.56 1.02
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of RoadScene :
|
|
||||||
FLIR_07206.jpg
|
|
||||||
FLIR_08202.jpg
|
|
||||||
FLIR_05893.jpg
|
|
||||||
FLIR_06974.jpg
|
|
||||||
FLIR_04424.jpg
|
|
||||||
FLIR_08284.jpg
|
|
||||||
FLIR_07786.jpg
|
|
||||||
FLIR_08021.jpg
|
|
||||||
FLIR_07968.jpg
|
|
||||||
FLIR_01130.jpg
|
|
||||||
FLIR_06993.jpg
|
|
||||||
FLIR_07190.jpg
|
|
||||||
FLIR_06570.jpg
|
|
||||||
FLIR_07809.jpg
|
|
||||||
FLIR_06430.jpg
|
|
||||||
FLIR_08592.jpg
|
|
||||||
FLIR_00211.jpg
|
|
||||||
FLIR_08721.jpg
|
|
||||||
FLIR_05955.jpg
|
|
||||||
FLIR_04688.jpg
|
|
||||||
FLIR_07732.jpg
|
|
||||||
FLIR_06392.jpg
|
|
||||||
FLIR_00977.jpg
|
|
||||||
FLIR_05105.jpg
|
|
||||||
FLIR_04269.jpg
|
|
||||||
FLIR_07970.jpg
|
|
||||||
FLIR_05005.jpg
|
|
||||||
FLIR_07209.jpg
|
|
||||||
FLIR_07555.jpg
|
|
||||||
FLIR_06325.jpg
|
|
||||||
FLIR_04943.jpg
|
|
||||||
FLIR_video_02829.jpg
|
|
||||||
FLIR_08248.jpg
|
|
||||||
FLIR_04484.jpg
|
|
||||||
FLIR_08058.jpg
|
|
||||||
FLIR_06795.jpg
|
|
||||||
FLIR_06995.jpg
|
|
||||||
FLIR_05879.jpg
|
|
||||||
FLIR_04593.jpg
|
|
||||||
FLIR_08094.jpg
|
|
||||||
FLIR_08526.jpg
|
|
||||||
FLIR_08858.jpg
|
|
||||||
FLIR_09465.jpg
|
|
||||||
FLIR_05064.jpg
|
|
||||||
FLIR_05857.jpg
|
|
||||||
FLIR_05914.jpg
|
|
||||||
FLIR_04722.jpg
|
|
||||||
FLIR_06506.jpg
|
|
||||||
FLIR_06282.jpg
|
|
||||||
FLIR_04512.jpg
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.41 52.99 15.81 2.37 1.78 0.71 0.55 0.96
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
20241008
|
|
||||||
```
|
|
||||||
/home/star/anaconda3/envs/pfcfuse/bin/python /home/star/whaiDir/PFCFuse/test_IVF.py
|
|
||||||
|
|
||||||
|
|
||||||
================================================================================
|
|
||||||
The test result of TNO :
|
|
||||||
19.png
|
|
||||||
05.png
|
|
||||||
21.png
|
|
||||||
18.png
|
|
||||||
15.png
|
|
||||||
22.png
|
|
||||||
14.png
|
|
||||||
13.png
|
|
||||||
08.png
|
|
||||||
01.png
|
|
||||||
02.png
|
|
||||||
03.png
|
|
||||||
25.png
|
|
||||||
17.png
|
|
||||||
11.png
|
|
||||||
16.png
|
|
||||||
06.png
|
|
||||||
07.png
|
|
||||||
09.png
|
|
||||||
10.png
|
|
||||||
12.png
|
|
||||||
23.png
|
|
||||||
24.png
|
|
||||||
20.png
|
|
||||||
04.png
|
|
||||||
EN SD SF MI SCD VIF Qabf SSIM
|
|
||||||
PFCFuse 7.01 40.4 15.51 1.55 1.75 0.66 0.54 0.96
|
|
||||||
================================================================================
|
|
||||||
|
|
||||||
Process finished with exit code 0
|
|
||||||
|
|
||||||
```
|
|
17
test_IVF.py
@ -1,5 +1,3 @@
|
|||||||
import datetime
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
from net import Restormer_Encoder, Restormer_Decoder, BaseFeatureExtraction, DetailFeatureExtraction
|
||||||
import os
|
import os
|
||||||
@ -13,20 +11,16 @@ import logging
|
|||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
logging.basicConfig(level=logging.CRITICAL)
|
logging.basicConfig(level=logging.CRITICAL)
|
||||||
|
|
||||||
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
|
|
||||||
|
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||||
pth_path = "whaiFusion11-18-10-04"
|
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/PFCFusion10-05-20-46.pth"
|
||||||
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/"+pth_path+".pth"
|
|
||||||
print("path_pth:{}".format(ckpt_path))
|
|
||||||
|
|
||||||
for dataset_name in ["sar"]:
|
for dataset_name in ["TNO"]:
|
||||||
print("\n"*2+"="*80)
|
print("\n"*2+"="*80)
|
||||||
model_name=pth_path
|
model_name="PFCFuse "
|
||||||
print("The test result of "+dataset_name+' :')
|
print("The test result of "+dataset_name+' :')
|
||||||
test_folder = os.path.join('test_img', dataset_name)
|
test_folder=os.path.join('/home/star/whaiDir/CDDFuse/test_img/',dataset_name)
|
||||||
test_out_folder=os.path.join('test_result',pth_path,dataset_name)
|
test_out_folder=os.path.join('test_result',dataset_name)
|
||||||
|
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)
|
||||||
@ -45,6 +39,7 @@ for dataset_name in ["sar"]:
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for img_name in os.listdir(os.path.join(test_folder,"ir")):
|
for img_name in os.listdir(os.path.join(test_folder,"ir")):
|
||||||
|
print(img_name)
|
||||||
|
|
||||||
data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
|
data_IR=image_read_cv2(os.path.join(test_folder,"ir",img_name),mode='GRAY')[np.newaxis,np.newaxis, ...]/255.0
|
||||||
data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0
|
data_VIS = cv2.split(image_read_cv2(os.path.join(test_folder, "vi", img_name), mode='YCrCb'))[0][np.newaxis, np.newaxis, ...] / 255.0
|
||||||
|
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 39 KiB |
Before Width: | Height: | Size: 37 KiB |
Before Width: | Height: | Size: 34 KiB |
Before Width: | Height: | Size: 33 KiB |
Before Width: | Height: | Size: 38 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 44 KiB |
Before Width: | Height: | Size: 44 KiB |
Before Width: | Height: | Size: 43 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 40 KiB |
Before Width: | Height: | Size: 40 KiB |
Before Width: | Height: | Size: 40 KiB |
Before Width: | Height: | Size: 39 KiB |
Before Width: | Height: | Size: 37 KiB |
Before Width: | Height: | Size: 59 KiB |
Before Width: | Height: | Size: 58 KiB |
Before Width: | Height: | Size: 60 KiB |
Before Width: | Height: | Size: 66 KiB |
Before Width: | Height: | Size: 63 KiB |
Before Width: | Height: | Size: 61 KiB |
Before Width: | Height: | Size: 60 KiB |
Before Width: | Height: | Size: 57 KiB |
Before Width: | Height: | Size: 54 KiB |
Before Width: | Height: | Size: 52 KiB |
Before Width: | Height: | Size: 51 KiB |
Before Width: | Height: | Size: 53 KiB |
Before Width: | Height: | Size: 54 KiB |
Before Width: | Height: | Size: 56 KiB |
Before Width: | Height: | Size: 55 KiB |
Before Width: | Height: | Size: 52 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 49 KiB |
Before Width: | Height: | Size: 48 KiB |
Before Width: | Height: | Size: 46 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 43 KiB |
Before Width: | Height: | Size: 43 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 43 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 41 KiB |
Before Width: | Height: | Size: 39 KiB |
Before Width: | Height: | Size: 38 KiB |
Before Width: | Height: | Size: 37 KiB |
Before Width: | Height: | Size: 36 KiB |
Before Width: | Height: | Size: 34 KiB |
Before Width: | Height: | Size: 32 KiB |
Before Width: | Height: | Size: 29 KiB |
Before Width: | Height: | Size: 26 KiB |
Before Width: | Height: | Size: 21 KiB |
Before Width: | Height: | Size: 39 KiB |
Before Width: | Height: | Size: 40 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |
Before Width: | Height: | Size: 42 KiB |