@@ -33,8 +33,7 @@ def call(self, inputs):
33
33
34
34
35
35
def get_probabilistic_layer (
36
- output_size ,
37
- probabilistic_layer : Union [str , dict ]
36
+ output_size , probabilistic_layer : Union [str , dict ]
38
37
) -> Callable :
39
38
"""Get the probabilistic layer."""
40
39
@@ -47,14 +46,23 @@ def get_probabilistic_layer(
47
46
48
47
if hasattr (probabilistic_layers , probabilistic_layer_name ):
49
48
_LOGGER .info (f"Using custom probabilistic layer: { probabilistic_layer_name } " )
50
- probabilistic_layer_obj = getattr (probabilistic_layers , probabilistic_layer_name )
51
- n_params = getattr (probabilistic_layers , probabilistic_layer_name ).params_size (output_size )
49
+ probabilistic_layer_obj = getattr (
50
+ probabilistic_layers , probabilistic_layer_name
51
+ )
52
+ n_params = getattr (probabilistic_layers , probabilistic_layer_name ).params_size (
53
+ output_size
54
+ )
52
55
probabilistic_layer = (
53
- probabilistic_layer_obj (output_size , name = "output" , ** probabilistic_layer_options ) if isinstance (probabilistic_layer_obj , type )
56
+ probabilistic_layer_obj (
57
+ output_size , name = "output" , ** probabilistic_layer_options
58
+ )
59
+ if isinstance (probabilistic_layer_obj , type )
54
60
else probabilistic_layer_obj (output_size , name = "output" )
55
61
)
56
62
else :
57
- raise KeyError (f"The probabilistic layer { probabilistic_layer_name } is not available." )
63
+ raise KeyError (
64
+ f"The probabilistic layer { probabilistic_layer_name } is not available."
65
+ )
58
66
59
67
return probabilistic_layer , n_params
60
68
@@ -94,7 +102,9 @@ def _build_fcn_block(
94
102
def _build_fcn_output (x , output_size , probabilistic_layer , out_bias_init ):
95
103
# probabilistic prediction
96
104
if probabilistic_layer :
97
- probabilistic_layer , n_params = get_probabilistic_layer (output_size , probabilistic_layer )
105
+ probabilistic_layer , n_params = get_probabilistic_layer (
106
+ output_size , probabilistic_layer
107
+ )
98
108
if isinstance (out_bias_init , np .ndarray ):
99
109
out_bias_init = np .hstack (
100
110
[out_bias_init , [0.0 ] * (n_params - out_bias_init .shape [0 ])]
@@ -405,7 +415,9 @@ def deep_cross_network(
405
415
406
416
# probabilistic prediction
407
417
if probabilistic_layer :
408
- probabilistic_layer , n_params = get_probabilistic_layer (output_size , probabilistic_layer )
418
+ probabilistic_layer , n_params = get_probabilistic_layer (
419
+ output_size , probabilistic_layer
420
+ )
409
421
if isinstance (out_bias_init , np .ndarray ):
410
422
out_bias_init = np .hstack (
411
423
[out_bias_init , [0.0 ] * (n_params - out_bias_init .shape [0 ])]
0 commit comments