Compare commits
2 Commits
b0030fe87f
...
ad0b45e198
Author | SHA1 | Date | |
---|---|---|---|
|
ad0b45e198 | ||
|
231f42e924 |
2
net.py
2
net.py
@ -239,7 +239,7 @@ class DepthwiseSeparableConvBlock(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
class DetailNode(nn.Module):
|
class DetailNode(nn.Module):
|
||||||
def __init__(self,useBlock=0):
|
def __init__(self,useBlock=1):
|
||||||
super(DetailNode, self).__init__()
|
super(DetailNode, self).__init__()
|
||||||
if useBlock == 0:
|
if useBlock == 0:
|
||||||
self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)
|
self.theta_phi = DepthwiseSeparableConvBlock(inp=32, oup=32)
|
||||||
|
3
train.py
3
train.py
@ -80,6 +80,9 @@ print(f"Clip gradient norm value: {clip_grad_norm_value}")
|
|||||||
print(f"Optimization step: {optim_step}")
|
print(f"Optimization step: {optim_step}")
|
||||||
print(f"Optimization gamma: {optim_gamma}")
|
print(f"Optimization gamma: {optim_gamma}")
|
||||||
|
|
||||||
|
# 控制台输入
|
||||||
|
model_str = input("Model: ")
|
||||||
|
print(f"Model: {model_str}")
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
Loading…
Reference in New Issue
Block a user