Skip to content

Commit 433783b

Browse files
committed
Dataset builder: add _to_athena_query method
1 parent 342fbbc commit 433783b

File tree

1 file changed

+64
-40
lines changed

1 file changed

+64
-40
lines changed

src/sagemaker/feature_store/dataset_builder.py

+64-40
Original file line numberDiff line numberDiff line change
@@ -438,53 +438,16 @@ def to_csv_file(self) -> Tuple[str, str]:
438438
os.remove(local_file_name)
439439
temp_table_name = f'dataframe_{temp_id.replace("-", "_")}'
440440
self._create_temp_table(temp_table_name, desired_s3_folder)
441-
base_features = list(self._base.columns)
442-
event_time_identifier_feature_dtype = self._base[
443-
self._event_time_identifier_feature_name
444-
].dtypes
445-
self._event_time_identifier_feature_type = (
446-
FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
447-
str(event_time_identifier_feature_dtype), None
448-
)
449-
)
450-
query_string = self._construct_query_string(
451-
FeatureGroupToBeMerged(
452-
base_features,
453-
self._included_feature_names if self._included_feature_names else base_features,
454-
self._included_feature_names if self._included_feature_names else base_features,
455-
_DEFAULT_CATALOG,
456-
_DEFAULT_DATABASE,
457-
temp_table_name,
458-
self._record_identifier_feature_name,
459-
FeatureDefinition(
460-
self._event_time_identifier_feature_name,
461-
self._event_time_identifier_feature_type,
462-
),
463-
None,
464-
TableType.DATA_FRAME,
465-
)
441+
query_result = self._run_query(
442+
**self._to_athena_query(temp_table_name=temp_table_name)
466443
)
467-
query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
468444
# TODO: cleanup temp table, need more clarification, keep it for now
469445
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
470446
"OutputLocation", None
471447
), query_result.get("QueryExecution", {}).get("Query", None)
472448
if isinstance(self._base, FeatureGroup):
473-
base_feature_group = construct_feature_group_to_be_merged(
474-
self._base, self._included_feature_names
475-
)
476-
self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
477-
self._event_time_identifier_feature_name = (
478-
base_feature_group.event_time_identifier_feature.feature_name
479-
)
480-
self._event_time_identifier_feature_type = (
481-
base_feature_group.event_time_identifier_feature.feature_type
482-
)
483-
query_string = self._construct_query_string(base_feature_group)
484449
query_result = self._run_query(
485-
query_string,
486-
base_feature_group.catalog,
487-
base_feature_group.database,
450+
**self._to_athena_query()
488451
)
489452
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
490453
"OutputLocation", None
@@ -1058,6 +1021,67 @@ def _construct_athena_table_column_string(self, column: str) -> str:
10581021
raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
10591022
return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}"
10601023

1024+
def _to_athena_query(self, temp_table_name: str = None) -> Tuple[str, str, str]:
    """Internal method for constructing an Athena query.

    Besides building the query, this method updates builder state: for a
    pandas.DataFrame base it sets ``_event_time_identifier_feature_type`` from the
    dtype of the event time column; for a FeatureGroup base it sets
    ``_record_identifier_feature_name``, ``_event_time_identifier_feature_name`` and
    ``_event_time_identifier_feature_type`` from the base feature group.

    The result is a plain tuple ordered like the positional parameters of
    ``_run_query(query_string, catalog, database)``, so callers forwarding it must
    unpack it positionally (``*``); it is not a mapping and cannot be unpacked
    with ``**``.

    Args:
        temp_table_name (str): The temporary Athena table name of the base
            pandas.DataFrame. Only used when the base is a pandas.DataFrame.
            Defaults to None.

    Returns:
        The query string.
        The name of the catalog to be used in the query execution.
        The database to be used in the query execution.

    Raises:
        ValueError: temp_table_name must be provided if the base is a pandas.DataFrame.
        ValueError: The base is neither a pandas.DataFrame nor a FeatureGroup.
    """
    if isinstance(self._base, pd.DataFrame):
        if temp_table_name is None:
            raise ValueError("temp_table_name must be provided for a pandas.DataFrame base.")
        base_features = list(self._base.columns)
        event_time_identifier_feature_dtype = self._base[
            self._event_time_identifier_feature_name
        ].dtypes
        # Dtypes missing from the map yield None as the feature type.
        self._event_time_identifier_feature_type = (
            FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
                str(event_time_identifier_feature_dtype), None
            )
        )
        catalog = _DEFAULT_CATALOG
        database = _DEFAULT_DATABASE
        # With no explicit selection, every column of the DataFrame is included.
        included_features = self._included_feature_names or base_features
        query_string = self._construct_query_string(
            FeatureGroupToBeMerged(
                base_features,
                included_features,
                included_features,
                catalog,
                database,
                temp_table_name,
                self._record_identifier_feature_name,
                FeatureDefinition(
                    self._event_time_identifier_feature_name,
                    self._event_time_identifier_feature_type,
                ),
                None,
                TableType.DATA_FRAME,
            )
        )
    elif isinstance(self._base, FeatureGroup):
        base_feature_group = construct_feature_group_to_be_merged(
            self._base, self._included_feature_names
        )
        self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
        self._event_time_identifier_feature_name = (
            base_feature_group.event_time_identifier_feature.feature_name
        )
        self._event_time_identifier_feature_type = (
            base_feature_group.event_time_identifier_feature.feature_type
        )
        catalog = base_feature_group.catalog
        database = base_feature_group.database
        query_string = self._construct_query_string(base_feature_group)
    else:
        # Without this branch the return below would fail with UnboundLocalError,
        # hiding the real cause from the caller.
        raise ValueError("Base must be either a FeatureGroup or a DataFrame.")
    return query_string, catalog, database
1084+
10611085
def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
10621086
"""Internal method for execute Athena query, wait for query finish and get query result.
10631087

0 commit comments

Comments
 (0)