Compare commits

...

2 Commits

Author SHA1 Message Date
zjut
ad0b45e198 feat(train): 添加模型输入功能- 在训练过程中增加了用户输入模型名称的功能
- 打印出用户输入的模型名称
2024-11-17 16:15:26 +08:00
zjut
231f42e924 refactor(net): 修改 DetailNode 类的默认参数
- 将 DetailNode 类的 useBlock 参数默认值从 0 改为 1
- 这个修改可能会影响网络结构的选择和初始化

两分支wtconv
2024-11-17 16:13:59 +08:00
2 changed files with 4 additions and 1 deletions

2
net.py
View File

@ -239,7 +239,7 @@ class DepthwiseSeparableConvBlock(nn.Module):
return x return x
class DetailNode(nn.Module): class DetailNode(nn.Module):
def __init__(self,useBlock=0): def __init__(self,useBlock=1):
super(DetailNode, self).__init__() super(DetailNode, self).__init__()
if useBlock == 0: if useBlock == 0:
self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32) self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)

View File

@ -80,6 +80,9 @@ print(f"Clip gradient norm value: {clip_grad_norm_value}")
print(f"Optimization step: {optim_step}") print(f"Optimization step: {optim_step}")
print(f"Optimization gamma: {optim_gamma}") print(f"Optimization gamma: {optim_gamma}")
# 控制台输入
model_str = input("Model: ")
print(f"Model: {model_str}")
# Model # Model
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'