Skip to content

Commit ab87df3

Browse files
committed
change from keras_core to keras 3.0
1 parent eab9358 commit ab87df3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+284
-191
lines changed

docs/environment.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- pip>=20.1 # pip is needed as dependency
1313
- pandoc
1414
- pip:
15-
- keras-core
15+
- keras>=3.0.0
1616
- jinja2==3.0.3
1717
- numpy
1818
- tensorflow-cpu

docs/requirements.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
keras-core
1+
keras>=3.0.0
22
sphinx==3.4.3
33
sphinx_rtd_theme
44
jinja2==3.0.0
@@ -8,7 +8,7 @@ scipy
88
matplotlib
99
pandas
1010
scikit-learn
11-
rdkit-pypi
11+
rdkit
1212
pyyaml
1313
pymatgen
1414
ase

kgcnn/backend/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import tree
2-
from keras_core import KerasTensor
3-
from keras_core.backend import backend
2+
from keras import KerasTensor
3+
from keras.backend import backend
44

55

66
def any_symbolic_tensors(args=None, kwargs=None):

kgcnn/initializers/initializers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import keras_core as ks
2-
from keras_core import ops
1+
import keras as ks
2+
from keras import ops
33

44

55
def _compute_fans(shape):

kgcnn/io/loader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import keras_core as ks
1+
import keras as ks
22
import numpy as np
33
import tensorflow as tf
44

kgcnn/layers/activ.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
import keras_core as ks
2+
import keras as ks
33
# import keras_core.saving
44

55

kgcnn/layers/aggr.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import keras_core as ks
1+
import keras as ks
22
# import keras_core.saving
3-
from keras_core.layers import Layer
4-
from keras_core import ops
3+
from keras.layers import Layer
4+
from keras import ops
55
from kgcnn.ops.scatter import (
66
scatter_reduce_min, scatter_reduce_mean, scatter_reduce_max, scatter_reduce_sum, scatter_reduce_softmax)
77
from kgcnn import __indices_axis__ as global_axis_indices

kgcnn/layers/attention.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# import keras_core as ks
22
from kgcnn.layers.gather import GatherNodesIngoing, GatherNodesOutgoing
3-
from keras_core.layers import Dense, Concatenate, Activation, Average, Layer
3+
from keras.layers import Dense, Concatenate, Activation, Average, Layer
44
from kgcnn.layers.aggr import AggregateLocalEdgesAttention
5-
from keras_core import ops
5+
from keras import ops
66

77

88
class AttentionHeadGAT(Layer): # noqa

kgcnn/layers/casting.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import keras_core as ks
2-
from keras_core.layers import Layer
3-
from keras_core import ops
1+
import keras as ks
2+
from keras.layers import Layer
3+
from keras import ops
44
from kgcnn.ops.core import repeat_static_length, decompose_ragged_tensor
55
from kgcnn.ops.scatter import scatter_reduce_sum
66
from kgcnn import __indices_axis__ as global_axis_indices

kgcnn/layers/conv.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from keras_core.layers import Layer, Dense, Activation, Add, Multiply
1+
from keras.layers import Layer, Dense, Activation, Add, Multiply
22
from kgcnn.layers.aggr import AggregateWeightedLocalEdges, AggregateLocalEdges
33
from kgcnn.layers.gather import GatherNodesOutgoing
4-
from keras_core import ops
4+
from keras import ops
55
from kgcnn.ops.activ import shifted_softplus
66

77

kgcnn/layers/gather.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Union
2-
from keras_core.layers import Layer, Concatenate
3-
from keras_core import ops
2+
from keras.layers import Layer, Concatenate
3+
from keras import ops
44
from kgcnn import __indices_axis__ as global_axis_indices
55
from kgcnn import __index_send__ as global_index_send
66
from kgcnn import __index_receive__ as global_index_receive

kgcnn/layers/geom.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import math
22
import numpy as np
33
from typing import Union
4-
import keras_core as ks
5-
from keras_core import ops
6-
from keras_core.layers import Layer, Subtract, Multiply, Add, Subtract
4+
import keras as ks
5+
from keras import ops
6+
from keras.layers import Layer, Subtract, Multiply, Add, Subtract
77
from kgcnn.layers.gather import GatherNodes, GatherState
88
from kgcnn.ops.axis import get_positive_axis
99

@@ -991,7 +991,7 @@ def call(self, inputs, **kwargs):
991991
Returns:
992992
Tensor: Displacement vector for edges of shape `([M], 3)`.
993993
"""
994-
frac_coords, edge_indices, cell_translations = inputs[0], inputs[1], inputs[2]
994+
frac_coords, edge_indices, cell_translations = inputs
995995
# Gather sending and receiving coordinates.
996996
in_frac_coords, out_frac_coords = self.gather_node_positions([frac_coords, edge_indices], **kwargs)
997997
# Cell translation
@@ -1039,7 +1039,9 @@ def call(self, inputs, **kwargs):
10391039
"""
10401040
frac_coords, lattice_matrices, batch_id_edge = inputs
10411041
# lattice_matrices_ = ops.repeat(lattice_matrices, row_lengths, axis=0)
1042-
lattice_matrices_ = self.gather_state()([lattice_matrices, batch_id_edge])
1042+
lattice_matrices_ = self.gather_state([lattice_matrices, batch_id_edge])
1043+
# frac_to_real = ops.sum(
1044+
# ops.cast(lattice_matrices_, dtype=frac_coords.dtype) * ops.expand_dims(frac_coords, axis=-1), axis=1)
10431045
frac_to_real = ops.einsum('ij,ijk->ik', frac_coords, lattice_matrices_)
10441046
# frac_to_real_coords = ks.backend.batch_dot(frac_coords, lattice_matrices_)
10451047
return frac_to_real

kgcnn/layers/message.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import keras_core as ks
2-
from keras_core import ops
1+
import keras as ks
2+
from keras import ops
33
from kgcnn.layers.gather import GatherNodes
44
from kgcnn.layers.aggr import AggregateLocalEdges
55

kgcnn/layers/mlp.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import keras_core as ks
2-
from keras_core.layers import Dense, Layer, Activation, Dropout
3-
from keras_core.layers import LayerNormalization, GroupNormalization, BatchNormalization, UnitNormalization
1+
import keras as ks
2+
from keras.layers import Dense, Layer, Activation, Dropout
3+
from keras.layers import LayerNormalization, GroupNormalization, BatchNormalization, UnitNormalization
44
from kgcnn.layers.norm import (GraphNormalization, GraphInstanceNormalization,
55
GraphBatchNormalization, GraphLayerNormalization)
66
from kgcnn.layers.norm import global_normalization_args as global_normalization_args_graph

kgcnn/layers/modules.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import keras_core as ks
2-
from keras_core import ops
1+
import keras as ks
2+
from keras import ops
33

44

55
class Embedding(ks.layers.Layer):

kgcnn/layers/norm.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
import keras_core as ks
2-
from keras_core.layers import Layer
3-
from keras_core import ops
4-
from keras_core import InputSpec
1+
import keras as ks
2+
from keras.layers import Layer
3+
from keras import ops
4+
from keras import InputSpec
55
from kgcnn.ops.scatter import scatter_reduce_sum
6-
from keras_core.layers import LayerNormalization as _LayerNormalization
7-
from keras_core.layers import BatchNormalization as _BatchNormalization
6+
from keras.layers import LayerNormalization as _LayerNormalization
7+
from keras.layers import BatchNormalization as _BatchNormalization
88

99
global_normalization_args = {
1010
"GraphNormalization": (

kgcnn/layers/pooling.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import keras_core as ks
2-
from keras_core.layers import Layer, Dense, Concatenate, GRUCell, Activation
1+
import keras as ks
2+
from keras.layers import Layer, Dense, Concatenate, GRUCell, Activation
33
from kgcnn.layers.gather import GatherState
4-
from keras_core import ops
4+
from keras import ops
55
from kgcnn.ops.scatter import scatter_reduce_softmax
66
from kgcnn.layers.aggr import Aggregate
77

kgcnn/layers/scale.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import keras_core as ks
1+
import keras as ks
22
from typing import Union
33
from kgcnn.layers.pooling import PoolingNodes
44
import numpy as np
5-
from keras_core import ops
5+
from keras import ops
66

77

88
class StandardLabelScaler(ks.layers.Layer): # noqa

kgcnn/layers/set2set.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import keras_core as ks
2-
from keras_core import ops
1+
import keras as ks
2+
from keras import ops
33
from kgcnn.ops.scatter import scatter_reduce_sum, scatter_reduce_max
44
from kgcnn.layers.aggr import Aggregate
55

kgcnn/layers/update.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import keras_core as ks
2-
from keras_core.layers import Dense, Add, Layer
1+
import keras as ks
2+
from keras.layers import Dense, Add, Layer
33

44

55
class GRUUpdate(Layer):

kgcnn/literature/AttentiveFP/_make.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import keras_core as ks
1+
import keras as ks
22
from kgcnn.layers.scale import get as get_scaler
33
from kgcnn.models.utils import update_model_kwargs
44
from kgcnn.models.casting import template_cast_output, template_cast_list_input
5-
from keras_core.backend import backend as backend_to_use
5+
from keras.backend import backend as backend_to_use
66
from kgcnn.layers.modules import Input
77
from ._model import model_disjoint
88

kgcnn/literature/AttentiveFP/_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import keras_core as ks
1+
import keras as ks
22
from kgcnn.ops.activ import *
33
from kgcnn.layers.attention import AttentiveHeadFP
44
from kgcnn.layers.mlp import MLP, GraphMLP

kgcnn/literature/CGCNN/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from ._make import make_crystal_model, model_crystal_default
2+
3+
__all__ = [
4+
"make_crystal_model",
5+
"model_crystal_default"
6+
]

kgcnn/literature/CGCNN/_layers.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import keras_core as ks
1+
import keras as ks
22
from kgcnn.layers.message import MessagePassingBase
33
from kgcnn.layers.norm import GraphBatchNormalization
4-
from keras_core.layers import Activation, Multiply, Concatenate, Add, Dense
4+
from keras.layers import Activation, Multiply, Concatenate, Add, Dense
55

66

77
class CGCNNLayer(MessagePassingBase):
@@ -41,10 +41,11 @@ def __init__(self, units: int = 64,
4141
activity_regularizer=None,
4242
kernel_constraint=None,
4343
bias_constraint=None,
44+
pooling_method: str = "scatter_mean",
4445
kernel_initializer='glorot_uniform',
4546
bias_initializer='zeros',
4647
**kwargs):
47-
super(CGCNNLayer, self).__init__(use_id_tensors=4, **kwargs)
48+
super(CGCNNLayer, self).__init__(use_id_tensors=4, pooling_method=pooling_method, **kwargs)
4849
self.units = units
4950
self.use_bias = use_bias
5051
self.padded_disjoint = padded_disjoint
@@ -64,7 +65,7 @@ def __init__(self, units: int = 64,
6465
self.s = Dense(self.units, activation="linear", use_bias=use_bias, **kernel_args)
6566
self.lazy_mult = Multiply()
6667
self.lazy_add = Add()
67-
self.lazy_concat = Concatenate(axis=2)
68+
self.lazy_concat = Concatenate(axis=-1)
6869

6970
def message_function(self, inputs, **kwargs):
7071
r"""Prepare messages.
@@ -83,7 +84,6 @@ def message_function(self, inputs, **kwargs):
8384
Returns:
8485
Tensor: Messages for updates of shape `([M], units)`.
8586
"""
86-
8787
nodes_in = inputs[0] # shape: (batch_size, M, F)
8888
nodes_out = inputs[1] # shape: (batch_size, M, F)
8989
edge_features = inputs[2] # shape: (batch_size, M, E)

kgcnn/literature/CGCNN/_make.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import keras_core as ks
1+
import keras as ks
22
from kgcnn.layers.modules import Input
33
from kgcnn.models.utils import update_model_kwargs
4-
from keras_core.backend import backend as backend_to_use
4+
from keras.backend import backend as backend_to_use
55
from kgcnn.layers.scale import get as get_scaler
66
from kgcnn.models.casting import template_cast_output, template_cast_list_input
77
from ._model import model_disjoint_crystal
@@ -25,12 +25,12 @@
2525
'name': 'CGCNN',
2626
'inputs': [
2727
{'shape': (None,), 'name': 'node_number', 'dtype': 'int64'},
28-
# {'shape': (None, 1), 'name': 'multiplicities', 'dtype': 'float32'}, # For asu"
2928
{'shape': (None, 3), 'name': 'node_frac_coordinates', 'dtype': 'float64'},
30-
# {'shape': (None, 4, 4), 'name': 'symmops', 'dtype': 'float64'},
3129
{'shape': (None, 2), 'name': 'edge_indices', 'dtype': 'int64'},
3230
{'shape': (None, 3), 'name': 'cell_translations', 'dtype': 'float32'},
3331
{'shape': (3, 3), 'name': 'lattice_matrix', 'dtype': 'float64'},
32+
# {'shape': (None, 1), 'name': 'multiplicities', 'dtype': 'float32'}, # For asu"
33+
# {'shape': (None, 4, 4), 'name': 'symmops', 'dtype': 'float64'},
3434
{"shape": (), "name": "total_nodes", "dtype": "int64"},
3535
{"shape": (), "name": "total_edges", "dtype": "int64"}
3636
],
@@ -75,6 +75,7 @@ def make_crystal_model(inputs: list = None,
7575
verbose: int = None, # noqa
7676
gauss_args: dict = None,
7777
node_pooling_args: dict = None,
78+
output_to_tensor: dict = None, # noqa
7879
output_mlp: dict = None,
7980
output_embedding: str = None,
8081
output_scaling: dict = None,
@@ -88,14 +89,14 @@ def make_crystal_model(inputs: list = None,
8889
Model uses the list template of inputs and standard output template.
8990
Model supports :obj:`[node_attributes, node_frac_coordinates, bond_indices, lattice, cell_translations, ...]`
9091
if representation='unit'` and `make_distances=True` or
91-
:obj:`[node_attributes, symmops, node_frac_coords, multiplicities, bond_indices, lattice, cell_translations, ...]`
92+
:obj:`[node_attributes, node_frac_coords, bond_indices, lattice, cell_translations, multiplicities, symmops, ...]`
9293
if `representation='asu'` and `make_distances=True`
9394
or :obj:`[node_attributes, edge_distance, bond_indices, ...]`
9495
if `make_distances=False` .
9596
The optional tensor :obj:`multiplicities` is a node-like feature tensor with a single value that gives
9697
the multiplicity for each node.
9798
The optional tensor :obj:`symmops` is an edge-like feature tensor with a matrix of shape `(4, 4)` for each edge
98-
that defines the symmerty operation.
99+
that defines the symmetry operation.
99100
100101
%s
101102
@@ -125,6 +126,7 @@ def make_crystal_model(inputs: list = None,
125126
Defines number of model outputs and activation.
126127
output_scaling (dict): Dictionary of layer arguments unpacked in scaling layers. Default is None.
127128
output_tensor_type (str): Output type of graph tensors such as nodes or edges. Default is "padded".
129+
output_to_tensor (bool): Deprecated in favour of `output_tensor_type` .
128130
129131
Returns:
130132
:obj:`keras.models.Model`
@@ -136,13 +138,13 @@ def make_crystal_model(inputs: list = None,
136138
model_inputs,
137139
input_tensor_type=input_tensor_type,
138140
cast_disjoint_kwargs=cast_disjoint_kwargs,
139-
has_edges=int(not make_distances) + int(representation == "asu"),
140-
has_nodes=1 + int(make_distances) + int(representation == "asu"),
141-
has_crystal_input=2
141+
has_edges=int(not make_distances),
142+
has_nodes=1 + int(make_distances),
143+
has_crystal_input=2+2*int(representation == "asu")
142144
)
143145

144146
if representation == "asu":
145-
n, m, x, sym, djx, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = d_in
147+
n, x, djx, img, lattice, m, sym, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = d_in
146148
else:
147149
n, x, djx, img, lattice, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = d_in
148150
m, sym = None, None

kgcnn/literature/CGCNN/_model.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import keras_core as ks
1+
import keras as ks
2+
from kgcnn.ops.activ import *
23
from kgcnn.layers.geom import (
34
DisplacementVectorsUnitCell,
45
DisplacementVectorsASU, NodePosition, FracToRealCoordinates,
@@ -39,7 +40,7 @@ def model_disjoint_crystal(
3940
x_in, x_out = NodePosition()([frac_coords, edge_indices])
4041
displacement_vectors = ks.layers.Subtract()([x_out, x_in])
4142

42-
displacement_vectors = FracToRealCoordinates()([displacement_vectors, lattice_matrix])
43+
displacement_vectors = FracToRealCoordinates()([displacement_vectors, lattice_matrix, batch_id_edge])
4344

4445
edge_distances = EuclideanNorm(axis=-1, keepdims=True)(displacement_vectors)
4546

@@ -51,7 +52,7 @@ def model_disjoint_crystal(
5152

5253
# embedding, if no feature dimension
5354
if use_node_embedding:
54-
n = Embedding(**input_node_embedding['node'])(atom_attributes)
55+
n = Embedding(**input_node_embedding)(atom_attributes)
5556
else:
5657
n = atom_attributes
5758

kgcnn/literature/DMPNN/_layers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import keras_core as ks
1+
import keras as ks
22
from kgcnn.layers.gather import GatherNodesOutgoing, GatherEdgesPairs
33
from kgcnn.layers.aggr import AggregateLocalEdges
4-
from keras_core.layers import Subtract
4+
from keras.layers import Subtract
55

66

77
class DMPNNPPoolingEdgesDirected(ks.layers.Layer): # noqa

0 commit comments

Comments
 (0)