Compare commits

...

11 Commits

Author SHA1 Message Date
zjut
f28bb255f4 build(.idea): 更新项目配置并调整测试脚本
- 更新 ProjectRootManager 配置,切换到本地 Python 3.8环境
- 修改模块配置,使用继承的 JDK
- 更新测试脚本,使用变量 pth_path 代替硬编码的模型路径- 优化测试结果输出目录结构
2024-11-18 09:47:16 +08:00
zjut
315c723399 Merge remote-tracking branch 'origin/base_vi(inn)+sar(wtconv)' into base_vi(inn)+sar(wtconv)
# Conflicts:
#	test_IVF.py
2024-11-18 09:46:16 +08:00
zjut
ef66a0321d refactor(net): 修改 DetailFeatureExtraction 和 DetailFeatureExtractionSAR 类中的 DetailNode 使用方式
- 将 DetailFeatureExtraction 类中的 DetailNode 使用方式从 useBlock=2 改为 useBlock=1
- 将 DetailFeatureExtractionSAR 类中的 DetailNode 使用方式从 useBlock=1 改为 useBlock=2
2024-11-18 09:45:51 +08:00
zjut
125a6bdf6f feat(net): 修改 DetailFeatureExtraction 和 DetailFeatureExtractionSAR 类
- 将 DetailFeatureExtraction 类中的 DetailNode 使用参数 useBlock=2
- 将 DetailFeatureExtractionSAR 类中的 DetailNode 使用参数 useBlock=1

vi wtconv
sar inn
2024-11-18 09:45:51 +08:00
zjut
be28d553fc feat(net): 添加 WTConv2d 层并修改 DetailNode 使用- 在 net.py 中添加了 WTConv2d 层的导入- 修改了 DetailNode 类的构造函数,增加了 useBlock 参数
- 根据 useBlock 参数的值,选择使用 WTConv2d层或 InvertedResidualBlock- 更新了 DetailFeatureFusion 和 DetailFeatureExtraction 类,指定了 DetailNode 的 useBlock 参数
2024-11-18 09:45:51 +08:00
zjut
ac4225c966 feat(net): 为 DetailNode模块添加可选卷积块
- 在 DetailNode 类中引入 useBlock 参数,用于选择不同的卷积块
- 新增 DepthwiseSeparableConvBlock 类,实现深度可分离卷积
- 根据 useBlock 的值,选择使用 DepthwiseSeparableConvBlock 或 InvertedResidualBlock
- 优化了网络结构,提供了更多的灵活性和选择性
2024-11-18 09:45:06 +08:00
zjut
8d99c2c4f8 feat(net): 替换 token_mixer 为 SCSA 模块
- 引入新的 SCSA(空间和通道协同注意力)模块
- 用 SCSA 替换原有的 Pooling层作为 token_mixer
- 删除了未使用的 SEBlock.py 文件- 移除了与当前项目无关的 TIAM(CV).py 文件
2024-11-18 09:45:01 +08:00
zjut
0cf1726eeb feat(net): 新增基础特征融合和细节特征融合模块
- 添加了 BaseFeatureFusion 和 DetailFeatureFusion 两个新类
- 更新了 train.py 中的导入和实例化语句
2024-11-18 09:45:01 +08:00
zjut
1bd418f0e4 feat(net): 增加 SAR 图像处理支持
- 新增 BaseFeatureExtractionSAR 和 DetailFeatureExtractionSAR 模块
- 修改 DIDF_Encoder 类,支持 SAR 图像输入
- 更新测试和训练脚本,增加 SAR 图像处理相关逻辑
2024-11-18 09:44:53 +08:00
zjut
f87a65e68e refactor(test_IVF): 重构测试代码以提高灵活性和可维护性
- 引入变量 pth_path 以动态构建模型权重路径
- 使用 pth_path替代直接使用时间戳创建输出文件夹
- 优化代码结构,提高可读性和可维护性
2024-11-18 09:31:58 +08:00
zjut
6738c9057d feat(train): 添加模型保存路径打印功能
- 优化模型保存逻辑,将保存路径存储在变量中
- 在保存模型后,打印模型的保存路径
- 这个改动可以帮助用户更容易地找到和管理训练好的模型文件
2024-11-18 09:29:10 +08:00
4 changed files with 10 additions and 6 deletions

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" jdkType="Python SDK" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">

View File

@ -12,5 +12,5 @@
</MavenGeneralSettings>
</option>
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.8.10 (sftp://star@192.168.50.108:22/home/star/anaconda3/envs/pfcfuse/bin/python)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pfcfuse)" project-jdk-type="Python SDK" />
</project>

View File

@ -17,14 +17,16 @@ current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/whaiFusion11-17-10-34.pth"
pth_path = "whaiFusion11-17-20-18"
ckpt_path= r"/home/star/whaiDir/PFCFuse/models/"+pth_path+".pth"
print("path_pth:{}".format(ckpt_path))
for dataset_name in ["sar"]:
print("\n"*2+"="*80)
model_name="PFCFuse Enhance 增加widthblock"
model_name=pth_path
print("The test result of "+dataset_name+' :')
test_folder = os.path.join('test_img', dataset_name)
test_out_folder=os.path.join('test_result',current_time,dataset_name)
test_out_folder=os.path.join('test_result',pth_path,dataset_name)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Encoder = nn.DataParallel(Restormer_Encoder()).to(device)

View File

@ -259,5 +259,7 @@ if True:
'BaseFuseLayer': BaseFuseLayer.state_dict(),
'DetailFuseLayer': DetailFuseLayer.state_dict(),
}
torch.save(checkpoint, os.path.join("models/whaiFusion"+timestamp+'.pth'))
savepth = os.path.join("models/whaiFusion" + timestamp + '.pth');
torch.save(checkpoint, savepth)
print("save model:{}".format(savepth))