@@ -438,53 +438,16 @@ def to_csv_file(self) -> Tuple[str, str]:
             os.remove(local_file_name)
             temp_table_name = f'dataframe_{temp_id.replace("-", "_")}'
             self._create_temp_table(temp_table_name, desired_s3_folder)
-            base_features = list(self._base.columns)
-            event_time_identifier_feature_dtype = self._base[
-                self._event_time_identifier_feature_name
-            ].dtypes
-            self._event_time_identifier_feature_type = (
-                FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
-                    str(event_time_identifier_feature_dtype), None
-                )
-            )
-            query_string = self._construct_query_string(
-                FeatureGroupToBeMerged(
-                    base_features,
-                    self._included_feature_names if self._included_feature_names else base_features,
-                    self._included_feature_names if self._included_feature_names else base_features,
-                    _DEFAULT_CATALOG,
-                    _DEFAULT_DATABASE,
-                    temp_table_name,
-                    self._record_identifier_feature_name,
-                    FeatureDefinition(
-                        self._event_time_identifier_feature_name,
-                        self._event_time_identifier_feature_type,
-                    ),
-                    None,
-                    TableType.DATA_FRAME,
-                )
+            query_result = self._run_query(
+                *self._to_athena_query(temp_table_name=temp_table_name)
             )
-            query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
             # TODO: cleanup temp table, need more clarification, keep it for now
             return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
                 "OutputLocation", None
             ), query_result.get("QueryExecution", {}).get("Query", None)
         if isinstance(self._base, FeatureGroup):
-            base_feature_group = construct_feature_group_to_be_merged(
-                self._base, self._included_feature_names
-            )
-            self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
-            self._event_time_identifier_feature_name = (
-                base_feature_group.event_time_identifier_feature.feature_name
-            )
-            self._event_time_identifier_feature_type = (
-                base_feature_group.event_time_identifier_feature.feature_type
-            )
-            query_string = self._construct_query_string(base_feature_group)
             query_result = self._run_query(
-                query_string,
-                base_feature_group.catalog,
-                base_feature_group.database,
+                *self._to_athena_query()
             )
             return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
                 "OutputLocation", None
@@ -1058,6 +1021,67 @@ def _construct_athena_table_column_string(self, column: str) -> str:
             raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
         return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}"
 
+    def _to_athena_query(self, temp_table_name: str = None) -> Tuple[str, str, str]:
+        """Internal method for constructing an Athena query.
+
+        Args:
+            temp_table_name (str): The temporary Athena table name of the base pandas.DataFrame. Defaults to None.
+
+        Returns:
+            The query string.
+            The name of the catalog to be used in the query execution.
+            The database to be used in the query execution.
+
+        Raises:
+            ValueError: temp_table_name must be provided if the base is a pandas.DataFrame.
+        """
+        if isinstance(self._base, pd.DataFrame):
+            if temp_table_name is None:
+                raise ValueError("temp_table_name must be provided for a pandas.DataFrame base.")
+            base_features = list(self._base.columns)
+            event_time_identifier_feature_dtype = self._base[
+                self._event_time_identifier_feature_name
+            ].dtypes
+            self._event_time_identifier_feature_type = (
+                FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
+                    str(event_time_identifier_feature_dtype), None
+                )
+            )
+            catalog = _DEFAULT_CATALOG
+            database = _DEFAULT_DATABASE
+            query_string = self._construct_query_string(
+                FeatureGroupToBeMerged(
+                    base_features,
+                    self._included_feature_names if self._included_feature_names else base_features,
+                    self._included_feature_names if self._included_feature_names else base_features,
+                    catalog,
+                    database,
+                    temp_table_name,
+                    self._record_identifier_feature_name,
+                    FeatureDefinition(
+                        self._event_time_identifier_feature_name,
+                        self._event_time_identifier_feature_type,
+                    ),
+                    None,
+                    TableType.DATA_FRAME,
+                )
+            )
+        if isinstance(self._base, FeatureGroup):
+            base_feature_group = construct_feature_group_to_be_merged(
+                self._base, self._included_feature_names
+            )
+            self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
+            self._event_time_identifier_feature_name = (
+                base_feature_group.event_time_identifier_feature.feature_name
+            )
+            self._event_time_identifier_feature_type = (
+                base_feature_group.event_time_identifier_feature.feature_type
+            )
+            catalog = base_feature_group.catalog
+            database = base_feature_group.database
+            query_string = self._construct_query_string(base_feature_group)
+        return query_string, catalog, database
+
     def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
         """Internal method for execute Athena query, wait for query finish and get query result.
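In the pandas.DataFrame branch, the event-time feature type is derived from the column's dtype through a dtype-to-feature-type lookup. The snippet below sketches that lookup idea with a hypothetical mapping; the real values live in `FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP` inside the SDK and may differ.

```python
# Hypothetical mapping for illustration only; not the SDK's actual
# DTYPE_TO_FEATURE_DEFINITION_CLS_MAP contents.
import pandas as pd

HYPOTHETICAL_DTYPE_MAP = {
    "int64": "Integral",
    "float64": "Fractional",
    "object": "String",
}

df = pd.DataFrame({"record_id": [1, 2], "event_time": [1700000000.0, 1700000100.0]})
event_time_dtype = df["event_time"].dtypes            # dtype of the event-time column
feature_type = HYPOTHETICAL_DTYPE_MAP.get(str(event_time_dtype), None)
print(str(event_time_dtype), "->", feature_type)      # float64 -> Fractional
```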