Create custom Graph Extension Type #420
There is `tf.experimental.BatchableExtensionType`, which seems pretty straightforward to use. I'm using it for graphs generated by my pipeline.
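For anyone landing here, a minimal sketch of the batching behavior, assuming a TensorFlow version that ships `tf.experimental.BatchableExtensionType` and graphs that all have the same number of nodes (dense fields must share a shape to be batched; graphs of different sizes would need ragged fields instead). `SmallGraph` is a hypothetical name used only for illustration.

```python
import tensorflow as tf


class SmallGraph(tf.experimental.BatchableExtensionType):
    """Hypothetical fixed-size graph: dense node features and adjacency."""
    features: tf.Tensor   # (n_nodes, n_features)
    adjacency: tf.Tensor  # (n_nodes, n_nodes)


# Two 3-node graphs with 2 features per node, stacked along a leading axis
# so that from_tensor_slices can split them into individual graphs.
stacked = SmallGraph(
    features=tf.random.normal([2, 3, 2]),
    adjacency=tf.constant([[[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]],
                           [[0., 1., 1.], [1., 0., 0.], [1., 0., 0.]]]),
)

dataset = tf.data.Dataset.from_tensor_slices(stacked)  # yields SmallGraph values
for graph in dataset.take(1):
    print(graph.features.shape, graph.adjacency.shape)  # (3, 2) (3, 3)

batched = dataset.batch(2)  # every field gains a leading batch dimension
for batch in batched:
    print(batch.features.shape, batch.adjacency.shape)  # (2, 3, 2) (2, 3, 3)
```

Batching adds a leading dimension to every field, so a whole batch of graphs is still a single `SmallGraph` value.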
Hey, do you have a self-contained example of using this type with Spektral? Or does it require rewriting the layers? Cheers
This would likely just be another data mode. Here is what I have, although integration into the library would probably look a little different:

```python
import tensorflow as tf
from spektral.layers import GCNConv


class TensorGraph(tf.experimental.BatchableExtensionType):
    """A collection of nodes with associated feature vectors."""
    features: tf.RaggedTensor
    adjacency: tf.RaggedTensor
    # TODO: Validation functions etc...


class WrappedGCN(tf.keras.layers.Layer):
    def __init__(self, features, *args, **kwargs):
        super().__init__()
        self.features = features
        self.layer = GCNConv(features, *args, **kwargs)

    def hook(self, graph):
        # Run GCNConv in batch mode on a single (unbatched) graph.
        # Note: GCNConv expects a normalized adjacency (see Spektral's gcn_filter).
        features = tf.expand_dims(graph.features.to_tensor(), 0)   # (1, N, F)
        adj = tf.cast(graph.adjacency.to_tensor(), tf.float32)
        adj = tf.expand_dims(adj, 0)                                # (1, N, N)
        out = self.layer([features, adj])                           # (1, N, features)
        # Squeeze only the batch axis, so 1-node graphs keep their shape.
        return tf.RaggedTensor.from_tensor(tf.squeeze(out, axis=0))

    # Overriding __call__ (rather than call) skips Keras' own input handling,
    # so the TensorGraph reaches this method unchanged.
    def __call__(self, graph):
        if isinstance(graph, TensorGraph):
            # Apply the GCN to each graph in the batch.
            features = tf.map_fn(
                self.hook, graph,
                fn_output_signature=tf.RaggedTensorSpec(
                    shape=(None, self.features), dtype=tf.float32))
            return TensorGraph(features=features, adjacency=graph.adjacency)
        return self.layer(graph)


# adjs, features: per-graph lists of tensors from the pipeline (not shown)
x0 = TensorGraph(adjacency=tf.ragged.stack(adjs),
                 features=tf.ragged.stack(features))
x1 = WrappedGCN(6)(x0)
x2 = WrappedGCN(6)(x1)
```

It could probably use Sparse instead of Ragged tensors, but I'm using Ragged here because I need the dense adjacency matrices.
A "Graph" Type is even defined in the example:
This lets the user carry graph information around as a single TensorFlow object, and even allows for ops at the object level. I think it would also support batches of graphs, which would be great.
This would be a little nicer than carrying around multiple arrays for the adjacency matrix, etc.
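As a rough illustration of object-level ops, here is a minimal sketch assuming a single graph stored as dense tensors. `DenseGraph`, `num_nodes`, `normalized_adjacency` and `with_features` are hypothetical names, not TensorFlow or Spektral APIs. The normalization shown is the usual GCN form `D^-1/2 (A + I) D^-1/2`, i.e. the kind of preprocessed adjacency Spektral's `GCNConv` expects as input.

```python
import tensorflow as tf


class DenseGraph(tf.experimental.ExtensionType):
    """Hypothetical single graph carrying its own tensors and graph-level ops."""
    features: tf.Tensor   # (n_nodes, n_features)
    adjacency: tf.Tensor  # (n_nodes, n_nodes), float32

    @property
    def num_nodes(self):
        return tf.shape(self.adjacency)[0]

    def normalized_adjacency(self):
        """GCN-style normalization: D^-1/2 (A + I) D^-1/2."""
        a_hat = self.adjacency + tf.eye(self.num_nodes, dtype=self.adjacency.dtype)
        d_inv_sqrt = tf.math.rsqrt(tf.reduce_sum(a_hat, axis=1))  # degrees of A + I
        return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]

    def with_features(self, new_features):
        """Return a new graph with the same adjacency and updated features."""
        return DenseGraph(features=new_features, adjacency=self.adjacency)


g = DenseGraph(
    features=tf.ones([3, 2]),
    adjacency=tf.constant([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]]),
)
print(g.num_nodes.numpy())  # 3
a_norm = g.normalized_adjacency()
g2 = g.with_features(tf.zeros([3, 4]))  # e.g. the output of a conv layer
```

Because the methods live on the type, a layer can accept one `DenseGraph` and return another instead of threading separate feature and adjacency arrays through every call.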