@@ -346,13 +346,11 @@ async def run(
346
346
return graph
347
347
348
348
def validate_chunk (
349
- self ,
350
- chunk_graph : Neo4jGraph ,
351
- schema : SchemaConfig
349
+ self , chunk_graph : Neo4jGraph , schema : SchemaConfig
352
350
) -> Neo4jGraph :
353
351
"""
354
- Perform validation after entity and relation extraction:
355
- - Enforce schema if schema enforcement mode is on and schema is provided
352
+ Perform validation after entity and relation extraction:
353
+ - Enforce schema if schema enforcement mode is on and schema is provided
356
354
"""
357
355
if self .enforce_schema != SchemaEnforcementMode .NONE :
358
356
if not schema or not schema .entities : # schema is not provided
@@ -365,9 +363,9 @@ def validate_chunk(
365
363
return chunk_graph
366
364
367
365
def _clean_graph (
368
- self ,
369
- graph : Neo4jGraph ,
370
- schema : SchemaConfig ,
366
+ self ,
367
+ graph : Neo4jGraph ,
368
+ schema : SchemaConfig ,
371
369
) -> Neo4jGraph :
372
370
"""
373
371
Verify that the graph conforms to the provided schema.
@@ -389,17 +387,15 @@ def _clean_graph(
389
387
return Neo4jGraph (nodes = filtered_nodes , relationships = filtered_rels )
390
388
391
389
def _enforce_nodes (
392
- self ,
393
- extracted_nodes : List [Neo4jNode ],
394
- schema : SchemaConfig
390
+ self , extracted_nodes : List [Neo4jNode ], schema : SchemaConfig
395
391
) -> List [Neo4jNode ]:
396
392
"""
397
- Filter extracted nodes to be conformant to the schema.
393
+ Filter extracted nodes to be conformant to the schema.
398
394
399
- Keep only those whose label is in schema.
400
- For each valid node, filter out properties not present in the schema.
401
- Remove a node if it ends up with no valid properties.
402
- """
395
+ Keep only those whose label is in schema.
396
+ For each valid node, filter out properties not present in the schema.
397
+ Remove a node if it ends up with no valid properties.
398
+ """
403
399
if self .enforce_schema != SchemaEnforcementMode .STRICT :
404
400
return extracted_nodes
405
401
@@ -424,10 +420,10 @@ def _enforce_nodes(
424
420
return valid_nodes
425
421
426
422
def _enforce_relationships (
427
- self ,
428
- extracted_relationships : List [Neo4jRelationship ],
429
- filtered_nodes : List [Neo4jNode ],
430
- schema : SchemaConfig
423
+ self ,
424
+ extracted_relationships : List [Neo4jRelationship ],
425
+ filtered_nodes : List [Neo4jNode ],
426
+ schema : SchemaConfig ,
431
427
) -> List [Neo4jRelationship ]:
432
428
"""
433
429
Filter extracted nodes to be conformant to the schema.
@@ -447,12 +443,16 @@ def _enforce_relationships(
447
443
potential_schema = schema .potential_schema
448
444
449
445
for rel in extracted_relationships :
450
- schema_relation = schema .relations .get (rel .type )
446
+ schema_relation = (
447
+ schema .relations .get (rel .type ) if schema .relations else None
448
+ )
451
449
if not schema_relation :
452
450
continue
453
451
454
- if (rel .start_node_id not in valid_nodes or
455
- rel .end_node_id not in valid_nodes ):
452
+ if (
453
+ rel .start_node_id not in valid_nodes
454
+ or rel .end_node_id not in valid_nodes
455
+ ):
456
456
continue
457
457
458
458
start_label = valid_nodes [rel .start_node_id ]
@@ -461,8 +461,11 @@ def _enforce_relationships(
461
461
tuple_valid = True
462
462
if potential_schema :
463
463
tuple_valid = (start_label , rel .type , end_label ) in potential_schema
464
- reverse_tuple_valid = ((end_label , rel .type , start_label ) in
465
- potential_schema )
464
+ reverse_tuple_valid = (
465
+ end_label ,
466
+ rel .type ,
467
+ start_label ,
468
+ ) in potential_schema
466
469
467
470
if not tuple_valid and not reverse_tuple_valid :
468
471
continue
@@ -483,18 +486,13 @@ def _enforce_relationships(
483
486
return valid_rels
484
487
485
488
def _enforce_properties (
486
- self ,
487
- properties : Dict [str , Any ],
488
- valid_properties : List [Dict [str , Any ]]
489
+ self , properties : Dict [str , Any ], valid_properties : List [Dict [str , Any ]]
489
490
) -> Dict [str , Any ]:
490
491
"""
491
492
Filter properties.
492
493
Keep only those that exist in schema (i.e., valid properties).
493
494
"""
494
495
valid_prop_names = {prop ["name" ] for prop in valid_properties }
495
496
return {
496
- key : value
497
- for key , value in properties .items ()
498
- if key in valid_prop_names
497
+ key : value for key , value in properties .items () if key in valid_prop_names
499
498
}
500
-
0 commit comments