Compare commits

..

2 Commits

Author SHA1 Message Date
zjut
b0030fe87f refactor(net): 修改 DetailNode 的使用方式
- 在 DetailFeatureFusion 类中,将 DetailNode 的 useBlock 参数从 1 修改为 2
- 在 DetailFeatureExtraction 类中,将 DetailNode 的 useBlock 参数从 0 修改为 2
2024-11-17 16:09:05 +08:00
zjut
260e3aa760 feat(net): 添加 WTConv2d 层并修改 DetailNode 使用- 在 net.py 中添加了 WTConv2d 层的导入- 修改了 DetailNode 类的构造函数,增加了 useBlock 参数
- 根据 useBlock 参数的值,选择使用 WTConv2d层或 InvertedResidualBlock- 更新了 DetailFeatureFusion 和 DetailFeatureExtraction 类,指定了 DetailNode 的 useBlock 参数
2024-11-17 15:57:41 +08:00

9
net.py
View File

@ -7,6 +7,7 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange from einops import rearrange
from componets.SCSA import SCSA from componets.SCSA import SCSA
from componets.WTConvCV2 import WTConv2d
def drop_path(x, drop_prob: float = 0., training: bool = False): def drop_path(x, drop_prob: float = 0., training: bool = False):
@ -248,6 +249,10 @@ class DetailNode(nn.Module):
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
elif useBlock == 2:
self.theta_phi = WTConv2d(in_channels=32, out_channels=32)
self.theta_rho = WTConv2d(in_channels=32, out_channels=32)
self.theta_eta = WTConv2d(in_channels=32, out_channels=32)
else: else:
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2) self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
@ -269,7 +274,7 @@ class DetailNode(nn.Module):
class DetailFeatureFusion(nn.Module): class DetailFeatureFusion(nn.Module):
def __init__(self, num_layers=3): def __init__(self, num_layers=3):
super(DetailFeatureFusion, self).__init__() super(DetailFeatureFusion, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)] INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules) self.net = nn.Sequential(*INNmodules)
def forward(self, x): def forward(self, x):
@ -281,7 +286,7 @@ class DetailFeatureFusion(nn.Module):
class DetailFeatureExtraction(nn.Module): class DetailFeatureExtraction(nn.Module):
def __init__(self, num_layers=3): def __init__(self, num_layers=3):
super(DetailFeatureExtraction, self).__init__() super(DetailFeatureExtraction, self).__init__()
INNmodules = [DetailNode() for _ in range(num_layers)] INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)]
self.net = nn.Sequential(*INNmodules) self.net = nn.Sequential(*INNmodules)
def forward(self, x): def forward(self, x):