1
- import keras_core as ks
1
+ import keras as ks
2
2
from kgcnn .layers .modules import Input
3
3
from kgcnn .models .utils import update_model_kwargs
4
- from keras_core .backend import backend as backend_to_use
4
+ from keras .backend import backend as backend_to_use
5
5
from kgcnn .layers .scale import get as get_scaler
6
6
from kgcnn .models .casting import template_cast_output , template_cast_list_input
7
7
from ._model import model_disjoint_crystal
25
25
'name' : 'CGCNN' ,
26
26
'inputs' : [
27
27
{'shape' : (None ,), 'name' : 'node_number' , 'dtype' : 'int64' },
28
- # {'shape': (None, 1), 'name': 'multiplicities', 'dtype': 'float32'}, # For asu"
29
28
{'shape' : (None , 3 ), 'name' : 'node_frac_coordinates' , 'dtype' : 'float64' },
30
- # {'shape': (None, 4, 4), 'name': 'symmops', 'dtype': 'float64'},
31
29
{'shape' : (None , 2 ), 'name' : 'edge_indices' , 'dtype' : 'int64' },
32
30
{'shape' : (None , 3 ), 'name' : 'cell_translations' , 'dtype' : 'float32' },
33
31
{'shape' : (3 , 3 ), 'name' : 'lattice_matrix' , 'dtype' : 'float64' },
32
+ # {'shape': (None, 1), 'name': 'multiplicities', 'dtype': 'float32'}, # For asu"
33
+ # {'shape': (None, 4, 4), 'name': 'symmops', 'dtype': 'float64'},
34
34
{"shape" : (), "name" : "total_nodes" , "dtype" : "int64" },
35
35
{"shape" : (), "name" : "total_edges" , "dtype" : "int64" }
36
36
],
@@ -75,6 +75,7 @@ def make_crystal_model(inputs: list = None,
75
75
verbose : int = None , # noqa
76
76
gauss_args : dict = None ,
77
77
node_pooling_args : dict = None ,
78
+ output_to_tensor : dict = None , # noqa
78
79
output_mlp : dict = None ,
79
80
output_embedding : str = None ,
80
81
output_scaling : dict = None ,
@@ -88,14 +89,14 @@ def make_crystal_model(inputs: list = None,
88
89
Model uses the list template of inputs and standard output template.
89
90
Model supports :obj:`[node_attributes, node_frac_coordinates, bond_indices, lattice, cell_translations, ...]`
90
91
if representation='unit'` and `make_distances=True` or
91
- :obj:`[node_attributes, symmops, node_frac_coords, multiplicities, bond_indices, lattice, cell_translations, ...]`
92
+ :obj:`[node_attributes, node_frac_coords, bond_indices, lattice, cell_translations, multiplicities, symmops , ...]`
92
93
if `representation='asu'` and `make_distances=True`
93
94
or :obj:`[node_attributes, edge_distance, bond_indices, ...]`
94
95
if `make_distances=False` .
95
96
The optional tensor :obj:`multiplicities` is a node-like feature tensor with a single value that gives
96
97
the multiplicity for each node.
97
98
The optional tensor :obj:`symmops` is an edge-like feature tensor with a matrix of shape `(4, 4)` for each edge
98
- that defines the symmerty operation.
99
+ that defines the symmetry operation.
99
100
100
101
%s
101
102
@@ -125,6 +126,7 @@ def make_crystal_model(inputs: list = None,
125
126
Defines number of model outputs and activation.
126
127
output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None.
127
128
output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".
129
+ output_to_tensor (bool): Deprecated in favour of `output_tensor_type` .
128
130
129
131
Returns:
130
132
:obj:`keras.models.Model`
@@ -136,13 +138,13 @@ def make_crystal_model(inputs: list = None,
136
138
model_inputs ,
137
139
input_tensor_type = input_tensor_type ,
138
140
cast_disjoint_kwargs = cast_disjoint_kwargs ,
139
- has_edges = int (not make_distances ) + int ( representation == "asu" ) ,
140
- has_nodes = 1 + int (make_distances ) + int ( representation == "asu" ) ,
141
- has_crystal_input = 2
141
+ has_edges = int (not make_distances ),
142
+ has_nodes = 1 + int (make_distances ),
143
+ has_crystal_input = 2 + 2 * int ( representation == "asu" )
142
144
)
143
145
144
146
if representation == "asu" :
145
- n , m , x , sym , djx , img , lattice , batch_id_node , batch_id_edge , node_id , edge_id , count_nodes , count_edges = d_in
147
+ n , x , djx , img , lattice , m , sym , batch_id_node , batch_id_edge , node_id , edge_id , count_nodes , count_edges = d_in
146
148
else :
147
149
n , x , djx , img , lattice , batch_id_node , batch_id_edge , node_id , edge_id , count_nodes , count_edges = d_in
148
150
m , sym = None , None
0 commit comments