materialize-iceberg: serialize pyspark command inputs to a file #2499

Merged 1 commit on Mar 7, 2025
Changes from all commits
26 changes: 16 additions & 10 deletions materialize-iceberg/emr.go
@@ -84,10 +84,10 @@ func ensureEmrSecret(ctx context.Context, client *ssm.Client, parameterName, wan
return nil
}

func (t *transactor) runEmrJob(ctx context.Context, jobName string, input any, statusOutputPrefix, entryPointUri string) error {
func (t *transactor) runEmrJob(ctx context.Context, jobName string, input any, workingPrefix, entryPointUri string) error {
/***
Available arguments to the pyspark script:
| --input | Input for the program, as serialized JSON | Required |
| --input-uri | Input for the program, as an s3 URI, to be parsed by the script | Required |
| --status-output | Location where the final status object will be written. | Required |
| --catalog-url | The catalog URL | Required |
| --warehouse | REST Warehouse | Required |
@@ -97,7 +97,7 @@ func (t *transactor) runEmrJob(ctx context.Context, jobName string, input any, s
***/
getStatus := func() (*python.StatusOutput, error) {
var status python.StatusOutput
statusKey := path.Join(statusOutputPrefix, statusFile)
statusKey := path.Join(workingPrefix, statusFile)
if statusObj, err := t.s3Client.GetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(t.cfg.Compute.Bucket),
Key: aws.String(statusKey),
@@ -109,14 +109,20 @@ func (t *transactor) runEmrJob(ctx context.Context, jobName string, input any, s
return &status, nil
}

encodedInput, err := encodeInput(input)
if err != nil {
inputKey := path.Join(workingPrefix, "input.json")
if inputBytes, err := encodeInput(input); err != nil {
return fmt.Errorf("encoding input: %w", err)
} else if _, err := t.s3Client.PutObject(ctx, &s3.PutObjectInput{
Bucket: aws.String(t.cfg.Compute.Bucket),
Key: aws.String(inputKey),
Body: bytes.NewReader(inputBytes),
}); err != nil {
return fmt.Errorf("putting input file object: %w", err)
}

args := []string{
"--input", encodedInput,
"--status-output", "s3://" + path.Join(t.cfg.Compute.Bucket, statusOutputPrefix, statusFile),
"--input-uri", "s3://" + path.Join(t.cfg.Compute.Bucket, inputKey),
"--status-output", "s3://" + path.Join(t.cfg.Compute.Bucket, workingPrefix, statusFile),
"--catalog-url", t.cfg.URL,
"--warehouse", t.cfg.Warehouse,
"--region", t.cfg.Compute.Region,
@@ -184,14 +190,14 @@ func (t *transactor) runEmrJob(ctx context.Context, jobName string, input any, s
}
}

func encodeInput(in any) (string, error) {
func encodeInput(in any) ([]byte, error) {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.SetAppendNewline(false)
enc.SetEscapeHTML(false)
if err := enc.Encode(in); err != nil {
return "", err
return nil, err
}

return buf.String(), nil
return buf.Bytes(), nil
}
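
Taken together, the Go side now uploads the serialized input as `input.json` under the working prefix and points the job at it with `--input-uri`, while still telling it where to write its status with `--status-output`. Below is a condensed, self-contained Python sketch of what a job script does with those two arguments; the argument names and URI handling mirror the diff, everything else is illustrative, and the real entrypoints delegate all of this to `run_with_status` in `common.py` (next file).

```python
# Sketch of the job-side contract set up by runEmrJob above. Assumes boto3
# credentials are available in the EMR environment; work done with the input
# is illustrative only.
import argparse
import json
from urllib.parse import urlparse

import boto3

parser = argparse.ArgumentParser()
parser.add_argument("--input-uri", required=True)
parser.add_argument("--status-output", required=True)
# The real invocation passes more flags (--catalog-url, --warehouse, ...),
# so only parse the ones this sketch cares about.
args, _ = parser.parse_known_args()

s3 = boto3.client("s3")

# Fetch and parse the serialized input uploaded by runEmrJob.
in_uri = urlparse(args.input_uri)
obj = s3.get_object(Bucket=in_uri.netloc, Key=in_uri.path.lstrip("/"))
job_input = json.loads(obj["Body"].read().decode("utf-8"))

# ... do work with job_input ...

# Report success where the connector's getStatus will look for it.
out_uri = urlparse(args.status_output)
s3.put_object(
    Bucket=out_uri.netloc,
    Key=out_uri.path.lstrip("/"),
    Body=json.dumps({"success": True}),
)
```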
33 changes: 22 additions & 11 deletions materialize-iceberg/python/common.py
@@ -10,7 +10,6 @@
from pyspark.sql import SparkSession
from pyspark.sql.types import (
ArrayType,
BinaryType,
BooleanType,
DataType,
DateType,
@@ -68,7 +67,9 @@ def fields_to_struct(fields: list[NestedField]) -> StructType:
def common_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument(
"--input", required=True, help="Input for the program, as serialized JSON."
"--input-uri",
required=True,
help="Location of the program input, as serialized JSON.",
)
parser.add_argument(
"--status-output",
@@ -140,23 +141,33 @@ def get_spark_session(args: argparse.Namespace) -> SparkSession:
def run_with_status(
parsed_args: argparse.Namespace,
fn,
*args,
**kwargs,
):
parsed_url = urlparse(parsed_args.status_output)
bucket_name = parsed_url.netloc
file_path = parsed_url.path.lstrip("/")
input_uri = urlparse(parsed_args.input_uri)
input_bucket_name = input_uri.netloc
input_file_path = input_uri.path.lstrip("/")

output_uri = urlparse(parsed_args.status_output)
output_bucket_name = output_uri.netloc
output_file_path = output_uri.path.lstrip("/")

s3 = boto3.client("s3")

try:
fn(*args, **kwargs)
input = s3.get_object(Bucket=input_bucket_name, Key=input_file_path)
with input["Body"] as body:
input = json.loads(body.read().decode("utf-8"))
s3.delete_object(Bucket=input_bucket_name, Key=input_file_path)

fn(input)
s3.put_object(
Bucket=bucket_name, Key=file_path, Body=json.dumps({"success": True})
Bucket=output_bucket_name,
Key=output_file_path,
Body=json.dumps({"success": True}),
)
except Exception as e:
s3.put_object(
Bucket=bucket_name,
Key=file_path,
Bucket=output_bucket_name,
Key=output_file_path,
Body=json.dumps({"success": False, "error": str(e)}),
)
raise
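
With this change, `run_with_status` takes over fetching (and deleting) the input object, and the work function now receives the parsed input dict instead of closing over module-level globals. A minimal sketch of an entrypoint under the new convention, mirroring the pattern `merge.py` adopts below; the `"bindings"`/`"query"` fields come from the diff, the `spark.sql` call is illustrative:

```python
# Hypothetical entrypoint using the new run_with_status convention.
from common import common_args, get_spark_session, run_with_status

args = common_args()
spark = get_spark_session(args)


def run(input):
    # `input` is the dict that run_with_status fetched from --input-uri.
    for binding in input["bindings"]:
        spark.sql(binding["query"])


run_with_status(args, run)
```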
12 changes: 4 additions & 8 deletions materialize-iceberg/python/load.py
@@ -1,5 +1,3 @@
import json

from common import (
NestedField,
common_args,
@@ -11,14 +9,12 @@
args = common_args()
spark = get_spark_session(args)

input = json.loads(args.input)
query = input["query"]
bindings = input["bindings"]
output_location = input["output_location"]
status_output = args.status_output

def run(input):
query = input["query"]
bindings = input["bindings"]
output_location = input["output_location"]

def run():
for binding in bindings:
bindingIdx: int = binding["binding"]
keys: list[NestedField] = [NestedField(**key) for key in binding["keys"]]
7 changes: 2 additions & 5 deletions materialize-iceberg/python/merge.py
@@ -11,12 +11,9 @@
args = common_args()
spark = get_spark_session(args)

input = json.loads(args.input)
bindings = input["bindings"]


def run():
for binding in bindings:
def run(input):
for binding in input["bindings"]:
bindingIdx: int = binding["binding"]
query: str = binding["query"]
columns: list[NestedField] = [NestedField(**col) for col in binding["columns"]]
3 changes: 1 addition & 2 deletions materialize-iceberg/python/python.go
@@ -26,8 +26,7 @@ type MergeBinding struct {
}

type MergeInput struct {
Bindings []MergeBinding `json:"bindings"`
OutputLocation string `json:"output_location"`
Bindings []MergeBinding `json:"bindings"`
}

type StatusOutput struct {
8 changes: 5 additions & 3 deletions materialize-iceberg/transactor.go
@@ -237,7 +237,7 @@ func (t *transactor) Store(it *m.StoreIterator) (m.StartCommitFunc, error) {
func (t *transactor) Acknowledge(ctx context.Context) (*pf.ConnectorState, error) {
outputPrefix := path.Join(t.cfg.Compute.BucketPath, uuid.NewString())
checkpointClear := make(map[string]*python.MergeBinding)
mergeInput := python.MergeInput{OutputLocation: "s3://" + path.Join(t.cfg.Compute.Bucket, outputPrefix)}
var mergeInput python.MergeInput
var allFileUris []string

for _, b := range t.bindings {
@@ -281,7 +281,6 @@ func (t *transactor) Acknowledge(ctx context.Context) (*pf.ConnectorState, error
cleanupStatus := t.cleanPrefixOnceFn(ctx, t.cfg.Compute.Bucket, outputPrefix)
defer cleanupStatus()

stateUpdate = &pf.ConnectorState{MergePatch: true}
if err := t.runEmrJob(ctx, fmt.Sprintf("store for: %s", t.materializationName), mergeInput, outputPrefix, t.pyFiles.mergeURI); err != nil {
return nil, fmt.Errorf("store merge job failed: %w", err)
} else if err := cleanupStatus(); err != nil {
@@ -291,7 +290,10 @@ func (t *transactor) Acknowledge(ctx context.Context) (*pf.ConnectorState, error
} else if cpUpdate, err := json.Marshal(checkpointClear); err != nil {
return nil, fmt.Errorf("encoding checkpoint update: %w", err)
} else {
stateUpdate.UpdatedJson = cpUpdate
stateUpdate = &pf.ConnectorState{
UpdatedJson: cpUpdate,
MergePatch: true,
}
}
}