@@ -346,13 +346,11 @@ async def run(
346
346
return graph
347
347
348
348
def validate_chunk (
349
- self ,
350
- chunk_graph : Neo4jGraph ,
351
- schema : SchemaConfig
349
+ self , chunk_graph : Neo4jGraph , schema : SchemaConfig
352
350
) -> Neo4jGraph :
353
351
"""
354
- Perform validation after entity and relation extraction:
355
- - Enforce schema if schema enforcement mode is on and schema is provided
352
+ Perform validation after entity and relation extraction:
353
+ - Enforce schema if schema enforcement mode is on and schema is provided
356
354
"""
357
355
if self .enforce_schema != SchemaEnforcementMode .NONE :
358
356
if not schema or not schema .entities : # schema is not provided
@@ -365,9 +363,9 @@ def validate_chunk(
365
363
return chunk_graph
366
364
367
365
def _clean_graph (
368
- self ,
369
- graph : Neo4jGraph ,
370
- schema : SchemaConfig ,
366
+ self ,
367
+ graph : Neo4jGraph ,
368
+ schema : SchemaConfig ,
371
369
) -> Neo4jGraph :
372
370
"""
373
371
Verify that the graph conforms to the provided schema.
@@ -389,17 +387,15 @@ def _clean_graph(
389
387
return Neo4jGraph (nodes = filtered_nodes , relationships = filtered_rels )
390
388
391
389
def _enforce_nodes (
392
- self ,
393
- extracted_nodes : List [Neo4jNode ],
394
- schema : SchemaConfig
390
+ self , extracted_nodes : List [Neo4jNode ], schema : SchemaConfig
395
391
) -> List [Neo4jNode ]:
396
392
"""
397
- Filter extracted nodes to be conformant to the schema.
393
+ Filter extracted nodes to be conformant to the schema.
398
394
399
- Keep only those whose label is in schema.
400
- For each valid node, filter out properties not present in the schema.
401
- Remove a node if it ends up with no valid properties.
402
- """
395
+ Keep only those whose label is in schema.
396
+ For each valid node, filter out properties not present in the schema.
397
+ Remove a node if it ends up with no valid properties.
398
+ """
403
399
if self .enforce_schema != SchemaEnforcementMode .STRICT :
404
400
return extracted_nodes
405
401
@@ -424,10 +420,10 @@ def _enforce_nodes(
424
420
return valid_nodes
425
421
426
422
def _enforce_relationships (
427
- self ,
428
- extracted_relationships : List [Neo4jRelationship ],
429
- filtered_nodes : List [Neo4jNode ],
430
- schema : SchemaConfig
423
+ self ,
424
+ extracted_relationships : List [Neo4jRelationship ],
425
+ filtered_nodes : List [Neo4jNode ],
426
+ schema : SchemaConfig ,
431
427
) -> List [Neo4jRelationship ]:
432
428
"""
433
429
Filter extracted nodes to be conformant to the schema.
@@ -451,8 +447,10 @@ def _enforce_relationships(
451
447
if not schema_relation :
452
448
continue
453
449
454
- if (rel .start_node_id not in valid_nodes or
455
- rel .end_node_id not in valid_nodes ):
450
+ if (
451
+ rel .start_node_id not in valid_nodes
452
+ or rel .end_node_id not in valid_nodes
453
+ ):
456
454
continue
457
455
458
456
start_label = valid_nodes [rel .start_node_id ]
@@ -461,8 +459,11 @@ def _enforce_relationships(
461
459
tuple_valid = True
462
460
if potential_schema :
463
461
tuple_valid = (start_label , rel .type , end_label ) in potential_schema
464
- reverse_tuple_valid = ((end_label , rel .type , start_label ) in
465
- potential_schema )
462
+ reverse_tuple_valid = (
463
+ end_label ,
464
+ rel .type ,
465
+ start_label ,
466
+ ) in potential_schema
466
467
467
468
if not tuple_valid and not reverse_tuple_valid :
468
469
continue
@@ -483,18 +484,13 @@ def _enforce_relationships(
483
484
return valid_rels
484
485
485
486
def _enforce_properties (
486
- self ,
487
- properties : Dict [str , Any ],
488
- valid_properties : List [Dict [str , Any ]]
487
+ self , properties : Dict [str , Any ], valid_properties : List [Dict [str , Any ]]
489
488
) -> Dict [str , Any ]:
490
489
"""
491
490
Filter properties.
492
491
Keep only those that exist in schema (i.e., valid properties).
493
492
"""
494
493
valid_prop_names = {prop ["name" ] for prop in valid_properties }
495
494
return {
496
- key : value
497
- for key , value in properties .items ()
498
- if key in valid_prop_names
495
+ key : value for key , value in properties .items () if key in valid_prop_names
499
496
}
500
-
0 commit comments