@@ -41,3 +41,34 @@ def get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import,
    network.apply(network.initialize)

    return network
+
+if __name__ == "__main__":
+    import torch
+
+    model = get_network_from_plans(
+        arch_class_name="dynamic_network_architectures.architectures.unet.ResidualEncoderUNet",
+        arch_kwargs={
+            "n_stages": 7,
+            "features_per_stage": [32, 64, 128, 256, 512, 512, 512],
+            "conv_op": "torch.nn.modules.conv.Conv2d",
+            "kernel_sizes": [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]],
+            "strides": [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]],
+            "n_blocks_per_stage": [1, 3, 4, 6, 6, 6, 6],
+            "n_conv_per_stage_decoder": [1, 1, 1, 1, 1, 1],
+            "conv_bias": True,
+            "norm_op": "torch.nn.modules.instancenorm.InstanceNorm2d",
+            "norm_op_kwargs": {"eps": 1e-05, "affine": True},
+            "dropout_op": None,
+            "dropout_op_kwargs": None,
+            "nonlin": "torch.nn.LeakyReLU",
+            "nonlin_kwargs": {"inplace": True},
+        },
+        arch_kwargs_req_import=["conv_op", "norm_op", "dropout_op", "nonlin"],
+        input_channels=1,
+        output_channels=4,
+        allow_init=True,
+        deep_supervision=True,
+    )
+    data = torch.rand((8, 1, 256, 256))
+    target = torch.rand(size=(8, 1, 256, 256))
+    outputs = model(data)  # this should be a list of torch.Tensor
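Following the new `__main__` block, here is a minimal sketch of how the deep-supervision outputs could be consumed, assuming `outputs` is the list produced above and that each entry holds 4-class logits at a progressively coarser resolution, highest resolution first. The integer label map, the per-scale weights, and the plain cross-entropy loss are illustrative assumptions, not the project's training code.

```python
import torch
import torch.nn.functional as F

# Hypothetical usage of the deep-supervision outputs produced by the __main__ block above.
# Assumption: each element of `outputs` is (8, 4, H_i, W_i) logits, highest resolution first.
labels = torch.randint(0, 4, (8, 256, 256))        # made-up 4-class label map
weights = [0.5 ** i for i in range(len(outputs))]  # weight coarser scales less (illustrative)

loss = torch.zeros(())
for w, logits in zip(weights, outputs):
    # Resize the labels to this output's spatial size; nearest keeps them integer-valued.
    lbl = F.interpolate(labels[:, None].float(), size=logits.shape[2:], mode="nearest")
    loss = loss + w * F.cross_entropy(logits, lbl[:, 0].long())

print(loss)
```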