Compare commits
No commits in common. "b0030fe87fa2744fa7884a5072ebbfa2de8a4433" and "3b7b64c915f55f6d4790231fcc279c41f825ffee" have entirely different histories.
b0030fe87f
...
3b7b64c915
9
net.py
9
net.py
@ -7,7 +7,6 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from componets.SCSA import SCSA
|
from componets.SCSA import SCSA
|
||||||
from componets.WTConvCV2 import WTConv2d
|
|
||||||
|
|
||||||
|
|
||||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||||
@ -249,10 +248,6 @@ class DetailNode(nn.Module):
|
|||||||
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
self.theta_eta = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
elif useBlock == 2:
|
|
||||||
self.theta_phi = WTConv2d(in_channels=32, out_channels=32)
|
|
||||||
self.theta_rho = WTConv2d(in_channels=32, out_channels=32)
|
|
||||||
self.theta_eta = WTConv2d(in_channels=32, out_channels=32)
|
|
||||||
else:
|
else:
|
||||||
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
self.theta_phi = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
self.theta_rho = InvertedResidualBlock(inp=32, oup=32, expand_ratio=2)
|
||||||
@ -274,7 +269,7 @@ class DetailNode(nn.Module):
|
|||||||
class DetailFeatureFusion(nn.Module):
|
class DetailFeatureFusion(nn.Module):
|
||||||
def __init__(self, num_layers=3):
|
def __init__(self, num_layers=3):
|
||||||
super(DetailFeatureFusion, self).__init__()
|
super(DetailFeatureFusion, self).__init__()
|
||||||
INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)]
|
INNmodules = [DetailNode() for _ in range(num_layers)]
|
||||||
self.net = nn.Sequential(*INNmodules)
|
self.net = nn.Sequential(*INNmodules)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -286,7 +281,7 @@ class DetailFeatureFusion(nn.Module):
|
|||||||
class DetailFeatureExtraction(nn.Module):
|
class DetailFeatureExtraction(nn.Module):
|
||||||
def __init__(self, num_layers=3):
|
def __init__(self, num_layers=3):
|
||||||
super(DetailFeatureExtraction, self).__init__()
|
super(DetailFeatureExtraction, self).__init__()
|
||||||
INNmodules = [DetailNode(useBlock=2) for _ in range(num_layers)]
|
INNmodules = [DetailNode() for _ in range(num_layers)]
|
||||||
self.net = nn.Sequential(*INNmodules)
|
self.net = nn.Sequential(*INNmodules)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
Loading…
Reference in New Issue
Block a user