diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 67a0813959178..756370e075dbc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1100,3 +1100,10 @@ repos: language: python files: .*test.*\.py$ pass_filenames: true + - id: ktlint + name: Run ktlint format + description: "Use ktlint (via Gradle) to format Kotlin and Java files" + entry: ./java-sdk/gradlew -p ./java-sdk ktlintFormat + language: system + pass_filenames: false + files: ^java-sdk/.*$ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 021ae94b2f4fe..c744b684f9b65 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1129,6 +1129,7 @@ opsgenie Optimise optimise optimizationObjective +OptIn optionality ora oracledb diff --git a/java-sdk/.editorconfig b/java-sdk/.editorconfig new file mode 100644 index 0000000000000..1b89a6e999824 --- /dev/null +++ b/java-sdk/.editorconfig @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +root = true + +[*] +end_of_line = lf +insert_final_newline = true +indent_style = space +indent_size = 4 + +[*.java] +indent_size = 2 + +[*.kt] +indent_size = 2 diff --git a/java-sdk/.gitattributes b/java-sdk/.gitattributes new file mode 100644 index 0000000000000..a87d264c425cf --- /dev/null +++ b/java-sdk/.gitattributes @@ -0,0 +1,11 @@ +# +# https://help.github.com/articles/dealing-with-line-endings/ +# +# Linux start script should use lf +/gradlew text eol=lf + +# These are Windows script files and should use crlf +*.bat text eol=crlf + +# Binary files should be left untouched +*.jar binary diff --git a/java-sdk/.gitignore b/java-sdk/.gitignore new file mode 100644 index 0000000000000..bf1f44332ebc8 --- /dev/null +++ b/java-sdk/.gitignore @@ -0,0 +1,56 @@ +.gradle +build/ +!gradle/wrapper/gradle-wrapper.jar +!**/src/main/**/build/ +!**/src/test/**/build/ +.kotlin + +### IntelliJ IDEA ### +.idea/modules.xml +.idea/jarRepositories.xml +.idea/compiler.xml +.idea/libraries/ +*.iws +*.iml +*.ipr +out/ +!**/src/main/**/out/ +!**/src/test/**/out/ + +### Eclipse ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache +bin/ +!**/src/main/**/bin/ +!**/src/test/**/bin/ + +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ + +### VS Code ### +.vscode/ + +### Mac OS ### +.DS_Store +# Ignore Gradle build output directory +build + +### Artifacts of airflow standalone command ### +airflow.cfg +airflow.db +simple_auth_manager_passwords.json.generated +logs/dag_id=* +logs/dag_processor + +### Compatibility Test Results ### +validation/serialization/serialized_java.json +validation/serialization/serialized_python.json diff --git a/java-sdk/README.md b/java-sdk/README.md new file mode 100644 index 0000000000000..92ef9ebdc5dd8 --- /dev/null +++ b/java-sdk/README.md @@ -0,0 +1,74 @@ + + +# Airflow Java SDK + +A **JVM** SDK for Apache Airflow. 
You can use any JVM-compatible language to write +workflow bundles, and have Airflow consume the result. + +The SDK and execution-time logic is implemented in Kotlin. +An example is bundled showing how the SDK can be used in Java. + +## Building + +```bash +./gradlew build +``` + +## Technical Details + +The Java program is launched as a subprocess by the Airflow worker and communicates +through TCP sockets. The Java program accepts flags `--comm` and `--logs` from the +command line. + +The Java program "parses" DAGs on launch, and then connects to the specified TCP servers. +The rest is similar to the standard Airflow: + +* DAG-parsing: + 1. On connection, the parent immediately sends a DagParsingRequest through the socket. + 2. The Java program sends back a DagParsingResult to the parent. + 3. The Java program exits. +* Execution: + 1. On connection, the parent immediately sends a StartupDetails through the socket. + 2. The Java program uses the information to find the relevant task to execute. + 3. The task is run. + 4. The Java program tells the parent to update the task's terminal state. + 5. The Java program exits. + +Communication uses the same formats as the Python-based processes. + +## Serialization Validation + +Workflow: + +```bash +# 1. Generate Java output (runs as part of normal test suite) +# More specifically, the test `sdk/src/test/kotlin/org/apache/airflow/sdk/execution/SerializationCompatibilityTest.kt` generates the output file `validation/serialization/serialized_java.json`. +./gradlew sdk:test + +# 2. Generate Python output (requires Airflow env) +uv run validation/serialization/serialize_python.py \ + validation/serialization/test_dags.yaml \ + validation/serialization/serialized_python.json + +# 3. 
Compare +uv run validation/serialization/compare.py \ + validation/serialization/serialized_python.json \ + validation/serialization/serialized_java.json +``` diff --git a/java-sdk/adr/0001-java-sdk-airflow-integration.md b/java-sdk/adr/0001-java-sdk-airflow-integration.md new file mode 100644 index 0000000000000..f5735b0d7b55a --- /dev/null +++ b/java-sdk/adr/0001-java-sdk-airflow-integration.md @@ -0,0 +1,425 @@ + + +# ADR-0001: Java SDK Airflow Integration + +## Status + +Accepted + +## Context + +Airflow's current execution model is Python-only: DAGs are Python files, tasks are Python callables, and the task runner forks a Python process. To support DAGs and tasks authored in other languages (starting with Java), we need an architecture that: + +- Allows entire DAGs to be written in a non-Python language (pure Java DAG). +- Allows non-Python tasks to coexist with Python tasks in the same DAG (`@task.stub`). +- Reuses the existing task-runner two-layer design (task-runner process + forked child process) so Airflow extensions (XCom backends, connections, variables) stay in Python. +- Is extensible to other languages (Go, Rust, etc.) without per-language changes to Airflow Core. + +The existing task runner already uses a two-layer design. When an executor wants to run a task, it starts a task-runner process that talks to Airflow Core through the Execution API, and forks another process that talks to the task-runner through TCP to run the actual task code. All the Airflow extensions simply go into the task-runner process, keeping them in Python. + +The only thing missing is a way for the task-runner process to run tasks in another language. + +## Decision + +### Writing a Non-Python Task + +There is one way to write a non-Python task: implement the language SDK's task interface. For Java, this is the `Task` interface with a single `execute(Client client)` method. The `Client` provides access to Airflow services (connections, variables, XCom). 
+ +### Two Ways to Integrate Non-Python Tasks into a DAG + +We provide two approaches for integrating non-Python tasks into a DAG: + +**a) Pure Java DAG** — define the entire DAG in Java, with no Python file at all. +The Java SDK provides `BundleBuilder`, `Dag`, and `Task` interfaces: + +```java +public class JavaExampleBuilder implements BundleBuilder { + + public static class Extract implements Task { + public void execute(Client client) throws Exception { + var connection = client.getConnection("test_http"); + client.setXCom(new Date().getTime()); + } + } + + public static class Transform implements Task { + public void execute(Client client) { + var extract_xcom = client.getXCom("extract"); + client.setXCom(new Date().getTime()); + } + } + + @Override + public Dag build() { + var dag = new Dag("java_example", null, "@daily"); + dag.addTask("extract", Extract.class, List.of()); + dag.addTask("transform", Transform.class, List.of("extract")); + return dag; + } +} +``` + +**b) `@task.stub` in a Python DAG** — for mixed-language pipelines where Python and +Java tasks coexist in the same DAG. The `@task.stub` syntax is already supported for +the Go SDK; the same pattern applies to Java: + +```python +@task() +def python_task_1(ti): + ti.xcom_push(value="from-python", key="return_value") + + +@task.stub(queue="java") +def extract(): ... + + +@task.stub(queue="java") +def transform(): ... + + +@dag(dag_id="java_example") +def simple_dag(): + python_task_1() >> extract() >> transform() +``` + +Both approaches are supported in parallel. A pure Java DAG needs no Python at all for authoring. A `@task.stub` DAG requires a Python file but lets you mix Python operators and non-Python tasks in a single pipeline. + +> **Note:** The current `BundleBuilder` interface used in pure Java DAGs is subject to review before the SDK reaches 1.0. Subclassing `Dag` directly may be a more natural fit and is being considered for post-OSS-integration. 
+ +### Public API Surface: `Client` and `Context` + +The Java task interface is `void execute(Client client)`. Two design choices warrant explanation. + +**Why `Client`, not `Context`?** The Java SDK exposes two objects, mirroring the Go SDK: + +| Object | Holds | Lifecycle | +|---|---|---| +| `Context` | Static run-time data (`ds`, `ti`, logical date, run-id, etc.) | Populated once from `StartupDetails`, read-only during execution | +| `Client` | Active accessors that perform Execution API calls (connections, variables, XCom) | Each method call is a synchronous request/response over the comm channel | + +In Python, magic objects on the context (e.g., `outlet_events`) can perform Execution API calls transparently because of the language's flexibility. Java is more rigid; making `Context` itself perform background API calls would require significantly more wiring without much user-visible benefit. Splitting the two surfaces makes the API call boundary explicit at the type level. + +**Why is `execute` `void`?** Returning a value from `execute` would imply an automatic XCom push. Java's static type system does not have a clean equivalent of Python's "return any object, get a default-keyed XCom" pattern, and explicit `client.setXCom(...)` calls keep the wire-level behavior obvious. This is a deliberate departure from Python's `@task` semantics, not an oversight. + +### Coordinator Interface: Subprocess-Based by Design + +`BaseCoordinator` exposes both **low-level** hooks (`dag_parsing_cmd`, `task_execution_cmd`) and **high-level** lifecycle methods (`run_dag_parsing`, `run_task_execution`). Subclasses normally implement only the `*_cmd` callbacks; the base class owns the TCP servers, the subprocess spawn, and the I/O bridge. + +This is deliberately tight coupling to a subprocess model. 
The reasoning: + +- **DAG files written in a programming language have side effects.** Airflow already isolates Python parsing and task execution into child processes; the coordinator interface preserves that invariant for any non-Python language. +- **`*_cmd` is the smallest possible contract for a new language.** A new SDK only needs to translate "you're being asked to parse this file / run this task" into an OS-level launch command. Everything else (TCP plumbing, framing, byte forwarding) is shared. +- **High-level overrides are still available.** A coordinator that wants to bypass the subprocess model entirely (in-process JVM via JNI, REST call to a remote DAG repository, etc.) can override `run_dag_parsing` / `run_task_execution` directly and ignore the `*_cmd` hooks. The two-tier interface is intentional. + +A complementary, **out-of-scope** future direction is parsing static (non-programming-language) DAG sources such as YAML (e.g., `dag-factory`). Those do not need a child process at all — but the decision to launch a child is currently made one layer above the coordinator (`DagFileProcessorManager` → `DagFileProcessorProcess`). Hooking in a YAML parser would need a separate extension point at the manager layer; it is not blocked by this design but is also not solved by it. A follow-up AIP is expected to formalize a general "any-source DAG parser" plugin model. + +### The Coordinator Layer + +We introduce a **Coordinator** layer. When a DAG bundle is loaded, it not only tells Airflow how to find the DAGs (and tasks in them), but also how to *run* each task. Current Python tasks use a Python code path that runs them by forking. A new **Java Coordinator** instructs the task runner how to run tasks in JAR files. + +The base interface (`BaseCoordinator`) lives in `airflow.sdk.execution_time`. Concrete coordinators ship as standalone distributions — **not** as Airflow providers — under the shared `airflow.sdk.coordinators` namespace package. 
The Java coordinator ships as `apache-airflow-coordinators-java` and resolves to `airflow.sdk.coordinators.java.JavaCoordinator`. New language coordinators follow the same pattern: `apache-airflow-coordinators-` → `airflow.sdk.coordinators..Coordinator`. + +Coordinators are instantiated from the `[sdk] coordinators` Airflow configuration (see [Coordinator Registration](#coordinator-registration) below). Both Airflow Core (DAG processor) and Task SDK (task runner) read that config and use `import_string()` to load the configured `classpath` — no provider plumbing is involved. Decoupling coordinators from the provider system is the direction agreed in [ADR-0005](0005-coordinator-packaging.md) and tracked in [apache/airflow#66451](https://github.com/apache/airflow/issues/66451), which also motivates the per-instance `kwargs` (multiple JDK versions, JVM flags, etc.) that a class-only registration could not express. + +### Architecture Overview + +``` + Airflow Backend Language Runtime Subprocess (Java in this example) + ─────────────── ────────────────────────────────────────────────── + + ┌──────────────────────────────┐ + │ DAG File (Python or JAR) │ + │ │ + │ @task.stub(queue="java") │ + │ def my_java_task(): │ + │ ... │ + └──────────────┬───────────────┘ + │ + ┌──────────────▼───────────────┐ ┌──────────────────────────────┐ + │ DAG File Processor │ │ Runtime Subprocess (Java) │ + │ │ can_handle_dag │ │ + │ For each file in bundle: │ _file() == True │ dag_parsing_cmd() │ + │ ┌ coordinator handles it? ──┼───────────────────►│ │ + │ │ Yes ──► delegate parse │ │ Java SDK parses JAR, builds │ + │ │ No ──► Python path │ SDK Serialized │ SDK-compatible Serialized │ + │ │ │◄─── DAG JSON ──────┤ DAG JSON (sdk, tasks, etc.) 
│ + │ └ │ │ │ + └──────────────┬───────────────┘ └──────────────────────────────┘ + │ + ┌──────────────▼───────────────┐ + │ Metadata DB │ + │ │ + │ serialized_dag: { │ Stored as-is from the language runtime's + │ "relative_fileloc": │ SDK Serialized DAG JSON + │ "path/to/example.jar" │ + │ } │ + │ task_instance.queue │ + └──────────────┬───────────────┘ + │ + ┌──────────────▼───────────────┐ + │ Scheduler │ + │ │ + │ Reads queue from TI │ + │ ──► ExecuteTask workload │ + │ (includes queue) │ + └──────────────┬───────────────┘ + │ + ┌──────────────▼───────────────┐ ┌──────────────────────────────┐ + │ Execution API │ │ Runtime Subprocess (Java) │ + │ │ │ │ + │ TI.queue ──► Startup │ │ task_execution_cmd() │ + │ Details │ │ Executes task in JVM │ + └──────────────┬───────────────┘ │ │ + │ └──────────────▲───────────────┘ + ┌──────────────▼───────────────┐ │ + │ Task Runner │ │ + │ │ │ + │ QueueToCoordinatorMapper │ │ + │ resolves queue via `[sdk] │ │ + │ queue_to_coordinator` ──────┼───────────────────────────────────┘ + │ to a coordinator instance │ + │ from `[sdk] coordinators` │ + └──────────────────────────────┘ +``` + +### The `BaseCoordinator` Interface + +This is the central abstraction that language SDKs implement. It lives in the Task SDK (`task-sdk/src/airflow/sdk/execution_time/coordinator.py`) and handles both DAG parsing and task execution for a specific language runtime. + +```python +class BaseCoordinator: + """ + Base coordinator for runtime-specific DAG file processing and task execution. + + Subclasses represent a specific language runtime (Java, Go, etc.) and are + instantiated by Airflow Core (DAG processor) and Task SDK (task runner) + from the ``[sdk] coordinators`` Airflow configuration. Each entry in that + config carries an instance ``name``, an importable ``classpath``, and + free-form ``kwargs`` that the subclass accepts in ``__init__`` — this is + how operators express runtime variants (multiple JDK versions, custom JVM + flags, etc.) 
without needing one subclass per variant. + + The base class owns the full bridge lifecycle: TCP servers, subprocess + management, selector-based I/O loop, and cleanup. + """ + + name: str # Instance name from [sdk] coordinators (e.g. "jdk-11", "jdk-17") + + def __init__(self, *, name: str, **kwargs) -> None: + """Accept the per-instance ``kwargs`` declared in ``[sdk] coordinators``.""" + ... + + # Discovery (called by DAG File Processor) + + def can_handle_dag_file(self, bundle_name: str, path: str | os.PathLike) -> bool: + """Return True if this coordinator should parse the file at *path*.""" + ... + + def get_code_from_file(self, fileloc: str) -> str: + """Return the actual DAG code (the content of JavaExampleBuilder.java in this case).""" + ... + + # DAG Parsing (called in forked DagFileProcessor child process) + + def dag_parsing_cmd( + self, + *, + dag_file_path: str, # Absolute path to DAG file + bundle_name: str, # Name of the DAG bundle + bundle_path: str, # Root path of the bundle + comm_addr: str, # host:port for msgpack comm channel + logs_addr: str, # host:port for structured JSON log channel + ) -> list[str]: + """Return the subprocess command for DAG file parsing.""" + ... + + # Task Execution (called in forked worker child process) + + def task_execution_cmd( + self, + *, + what: TaskInstance, + dag_rel_path: str | os.PathLike, # Relative path to DAG file within bundle + bundle_info: BundleInfo, + comm_addr: str, + logs_addr: str, + ) -> list[str]: + """Return the subprocess command for task execution.""" + ... + + # Lifecycle (owned by base class, not overridden) + + def run_dag_parsing(self, *, path, bundle_name, bundle_path) -> None: ... + + def run_task_execution(self, *, what, dag_rel_path, bundle_info, startup_details) -> None: ... +``` + +### Coordinator Registration + +Coordinators are registered through Airflow configuration, not through `provider.yaml` or any provider-discovery mechanism. 
The Java coordinator ships as the standalone distribution `apache-airflow-coordinators-java`, which contributes the `airflow.sdk.coordinators.java` subpackage to the namespace package owned by the Task SDK. As long as the distribution is on `PYTHONPATH`, both Airflow Core and the Task SDK can resolve `airflow.sdk.coordinators.java.JavaCoordinator` via `import_string()`. + +Operators wire concrete instances in `airflow.cfg`: + +```ini +[sdk] +coordinators = [ + { + "name": "jdk-11", + "classpath": "airflow.sdk.coordinators.java.JavaCoordinator", + "kwargs": { + "java_executable": "/usr/lib/jvm/java-11-openjdk-amd64/bin/java", + "jvm_args": ["-Xmx512m"], + "jdk_home": "/usr/lib/jvm/java-11-openjdk-amd64" + } + }, + { + "name": "jdk-17", + "classpath": "airflow.sdk.coordinators.java.JavaCoordinator", + "kwargs": { + "java_executable": "/usr/lib/jvm/java-17-openjdk-amd64/bin/java", + "jvm_args": ["-Xmx1024m", "-Xms256m"], + "jdk_home": "/usr/lib/jvm/java-17-openjdk-amd64" + } + } +] + +queue_to_coordinator = {"legacy-java-queue": "jdk-11", "modern-java-queue": "jdk-17"} +``` + +The same `JavaCoordinator` class can back several instances with different runtime configuration; the routing key is the instance `name`, not the class. This shape is the resolution to the packaging and registration questions originally raised in [ADR-0005](0005-coordinator-packaging.md), motivated by [apache/airflow#66451](https://github.com/apache/airflow/issues/66451) (multi-JDK and JVM-flag support). + +### Implementation Language: Kotlin (with a Java-First Public API) + +The user-facing API surface (`Task`, `Client`, `Context`, `Dag`, `DagBundle`) is published as Java types and is the contract bundle authors program against. The SDK *implementation* — `CoordinatorComm`, `Serde`, `TaskSdkFrames`, `Server`, `Supervisor`, `TaskRunner`, `DagParser` — is written in Kotlin. 
+ +Kotlin compiles to the same JVM bytecode as Java and is fully interoperable, so this choice is invisible to bundle authors at runtime. The practical reasons for using Kotlin internally: + +- **Null safety** is part of the type system, removing a large class of latent NPEs in the comm/serde paths. +- **Coroutines and structured I/O** simplify the synchronous-over-async pattern used by `Client.getVariable()` and friends. +- **Less boilerplate** in serialization and frame encoding code, which is the bulk of the SDK. + +Because the user-facing API is Java, "Java SDK" remains the accurate name from a DAG-author perspective. A future rename to "JVM SDK" has been floated but is not adopted here; it can be revisited if/when Scala or other JVM-language bindings are proposed. + +### Example: `JavaCoordinator` + +```python +# Shipped as ``apache-airflow-coordinators-java``; +# resolves to ``airflow.sdk.coordinators.java.JavaCoordinator``. +class JavaCoordinator(BaseCoordinator): + def __init__( + self, + *, + name: str, + java_executable: str = "java", + jvm_args: list[str] | None = None, + jdk_home: str | None = None, + ) -> None: + self.name = name + self.java_executable = java_executable + self.jvm_args = list(jvm_args or []) + self.jdk_home = jdk_home + + def can_handle_dag_file(self, bundle_name, path): + """True when path is a JAR with a Main-Class manifest entry.""" + ... 
+ + def dag_parsing_cmd(self, *, dag_file_path, bundle_name, bundle_path, comm_addr, logs_addr): + main_class = find_main_class(Path(dag_file_path)) + return [ + self.java_executable, + *self.jvm_args, + "-classpath", + f"{bundle_path}/*", + main_class, + f"--comm={comm_addr}", + f"--logs={logs_addr}", + ] + + def task_execution_cmd(self, *, what, dag_rel_path, bundle_info, comm_addr, logs_addr): + jar_path = Path(dag_rel_path) + main_class = find_main_class(jar_path) + return [ + self.java_executable, + *self.jvm_args, + "-classpath", + f"{jar_path.parent}/*", + main_class, + f"--comm={comm_addr}", + f"--logs={logs_addr}", + ] +``` + +### Integration Points — Required Changes + +**1. Decorator — DAG Author Interface** + +DAG authors declare a non-Python task using `@task.stub` and specify a queue: + +```python +@task.stub(queue="java") +def my_java_task(): ... +``` + +**2. Serialization — Each Language SDK Produces SDK-Compatible Serialized DAG JSON** + +Serialization is the language runtime's responsibility, not Airflow Core's. Each language SDK implements its own serializer that understands the language-specific DAG and task structure and produces a Task SDK-compatible Serialized DAG JSON — the same schema that the Python SDK's `SerializedDAG` produces. + +The language runtime subprocess returns this JSON to the DAG File Processor through the msgpack comm channel. The DAG File Processor and Airflow Core treat it identically to Python-serialized DAGs — it is stored as-is in the metadata DB. + +We have already added compatibility validation between the Python SDK and Java SDK serialized DAG JSON formats to ensure both produce structurally equivalent output. + +**3. 
Execution API — Task Queues Routed to the Worker** + +A new pair of configurations registers coordinator instances and maps each task's `queue` to one of them: + +```ini +[sdk] +coordinators = [ + {"name": "jdk-17", "classpath": "airflow.sdk.coordinators.java.JavaCoordinator", "kwargs": {"java_executable": "java"}} +] +queue_to_coordinator = {"java": "jdk-17"} +``` + +Tasks scheduled to the `java` queue are routed to the coordinator instance named `jdk-17`. Multiple instances of the same class may coexist (e.g., `jdk-11` and `jdk-17`) and bind to different queues — see [Coordinator Registration](#coordinator-registration). + +## Consequences + +### New Interfaces + +| Component | New Interface | Change Type | +|-----------|--------------|-------------| +| `BaseCoordinator` | Abstract base defined in Task SDK | New class | +| `airflow.sdk.coordinators` | Namespace package contributed to by `apache-airflow-coordinators-<language>` distributions | New namespace | +| `@task.stub` decorator | `queue: str \| None` parameter | Additive | +| `[sdk] coordinators` | Airflow configuration listing instances (`name`, `classpath`, `kwargs`) | New option | +| `[sdk] queue_to_coordinator` | Airflow configuration mapping queue → instance name | New option | +| `_resolve_runtime_entrypoint` | Route by `queue` → coordinator instance from `[sdk] coordinators` | Behavioral | + +### What Becomes Easier + +- Adding a new language runtime requires only a `BaseCoordinator` subclass shipped as `apache-airflow-coordinators-<language>` and a corresponding entry in `[sdk] coordinators` — no changes to Airflow Core and no provider plumbing. +- DAG authors can mix Python and non-Python tasks in the same pipeline. +- The existing task-runner two-layer design is preserved, keeping all Airflow extensions in Python. + +### What Becomes Harder + +- Each language SDK must independently produce compatible serialized DAG JSON, which requires cross-language validation infrastructure. 
+- The coordinator subprocess bridge adds a TCP hop and process management overhead per non-Python task. +- Debugging non-Python tasks requires understanding the bridge layer between the task runner and the language runtime. diff --git a/java-sdk/adr/0002-dag-parsing.md b/java-sdk/adr/0002-dag-parsing.md new file mode 100644 index 0000000000000..5e69e4bb4a496 --- /dev/null +++ b/java-sdk/adr/0002-dag-parsing.md @@ -0,0 +1,412 @@ + + +# ADR-0002: DAG Parsing — Language-Specific DAG File Processing + +## Status + +Accepted + +## Context + +Airflow's standard DAG file processor only understands Python files. To support DAGs defined in other languages (Java, Go, Rust, etc.), the pipeline needs an extension point where a language-specific processor can intercept the parsing request, delegate to an external runtime, and return a result in the same format the Airflow scheduler expects. + +This ADR details the DAG parsing side of the coordinator architecture described in [ADR-0001](0001-java-sdk-airflow-integration.md). It starts with the generic model — the abstract contracts and expected behavior that any language must implement — then walks through Java as a concrete example. + +## Decision + +### Extension Point: `BaseCoordinator` + +A single abstract base class — `BaseCoordinator` — handles both DAG parsing and task execution. Concrete subclasses ship as standalone distributions (`apache-airflow-coordinators-`) under the shared `airflow.sdk.coordinators` namespace package; they are **not** Airflow providers and are not registered through `provider.yaml`. For DAG parsing, a subclass must implement two methods: + +| Method | Signature | Responsibility | +|---|---|---| +| `can_handle_dag_file` | `(bundle_name, path) -> bool` | Return `True` if this coordinator should handle the given file. Default returns `False`; subclasses add language-specific checks (e.g., "is this a JAR with a Main-Class?"). 
| +| `dag_parsing_cmd` | `(dag_file_path, bundle_name, bundle_path, comm_addr, logs_addr) -> list[str]` | Return the full command to launch the language runtime. `comm_addr` and `logs_addr` are `host:port` strings the process must connect to. | + +### Registration + +Coordinators are configured in `airflow.cfg` (see [ADR-0001 — Coordinator Registration](0001-java-sdk-airflow-integration.md#coordinator-registration)). Each entry names a coordinator instance, points at an importable class via `classpath`, and supplies per-instance `kwargs`: + +```ini +[sdk] +coordinators = [ + { + "name": "jdk-17", + "classpath": "airflow.sdk.coordinators.java.JavaCoordinator", + "kwargs": {"java_executable": "/usr/lib/jvm/java-17/bin/java"} + } +] +``` + +A single instance entry covers both DAG parsing and task execution — there are no separate registries for the two roles. + +**Per-host opt-in.** A coordinator becomes available on a given DAG processor host only when its distribution is installed there *and* its instance appears in the host's `[sdk] coordinators`. A deployment can run a Python-only DAG processor pool and a separate Java-capable DAG processor pool by simply *not* installing `apache-airflow-coordinators-java` (or omitting the instance from config) on the Python-only hosts. The same applies to workers ([ADR-0003](0003-workload-execution.md)). There is no requirement that every parser carry a JDK; coordinators are opt-in per host by package install plus config entry. 
+ +### Discovery: `_resolve_processor_target()` + +When `DagFileProcessorProcess.start()` needs to parse a file: + +``` +_resolve_processor_target(path, bundle_name, bundle_path) + for entry in conf.get("sdk", "coordinators"): + coordinator_cls = import_string(entry["classpath"]) + coordinator = coordinator_cls(name=entry["name"], **entry.get("kwargs", {})) + if coordinator.can_handle_dag_file(bundle_name, path): + return functools.partial(coordinator.run_dag_parsing, path=..., bundle_name=..., bundle_path=...) + return None # fall back to default Python parser +``` + +The first coordinator instance whose `can_handle_dag_file()` returns `True` wins. If none match, the default Python `_parse_file_entrypoint` runs. Instances are constructed lazily from `[sdk] coordinators` and cached for the lifetime of the host process. + +### Transport: Why msgpack over TCP Loopback + +A natural reviewer question is "why a custom-looking framed-msgpack protocol over `127.0.0.1:`, and not Unix sockets / gRPC / HTTP REST?" Two clarifications are important: + +1. **The protocol is not new for the Java SDK.** Length-prefixed msgpack frames are the existing transport between the Airflow supervisor and the Python task runner (see `task-sdk/src/airflow/sdk/execution_time/supervisor.py` and `comms.py`). The coordinator bridge wires the language-runtime sockets onto that same byte stream — it does not define a new wire format. Migrating it would be a separate, pan-SDK change. +2. **Forward-compat for IPC messages is treated as a contract**, not as a transport choice. The decoder rules that all SDKs must follow are stated in [ADR-0003 — IPC Forward-Compatibility Contract](0003-workload-execution.md#ipc-forward-compatibility-contract). + +#### Alternatives considered + +| Option | Why not (today) | +|---|---| +| **Unix domain sockets** instead of TCP loopback | Avoids the IPv6/dual-stack concern with `127.0.0.1`, and matches conventions like Docker's `/var/run/docker.sock`. 
Worth revisiting once a formal IPC AIP lands; not adopted now because it would diverge from the existing Python supervisor transport, which is also TCP loopback. | +| **gRPC / Protocol Buffers** | Would require defining an intermediate IDL for `DagFileParseRequest`, `StartupDetails`, etc. The internal serialization that the language runtime returns (DagSerialization v3) is *not* expressible as a flat ProtoBuf without losing information — see "Cross-SDK serialization compatibility" below. gRPC would replace one custom-looking layer with two: ProtoBuf for transport plus a separate JSON-shaped DAG payload nested inside it. | +| **HTTP REST** | Adds an HTTP server in every language runtime and an HTTP client in the supervisor for a strictly local, single-peer connection. None of HTTP's value (intermediaries, caching, content negotiation) applies. The Java SDK's `Supervisor.kt` already does HTTP for the *Execution API* (Edge-worker path); the comm channel between supervisor and language runtime is intentionally lower-level. | +| **Keep msgpack-over-TCP** (chosen) | Reuses the existing supervisor transport unchanged; the bridge is a pure byte forwarder. New language SDKs only need a length-prefixed-msgpack codec, which exists in every target language. | + +A formal AIP for the supervisor-to-runtime comm protocol is expected as a follow-up once two or more language SDKs (Java, Go) are in tree; that AIP is the natural place to revisit transport and framing. + +### Cross-SDK Serialization Compatibility + +The `DagFileParsingResult` payload that a language runtime returns is the *Airflow internal* serialized DAG format, not an SDK-defined schema. The authoritative reference is `airflow-core/src/airflow/serialization/schema.json`, which describes `LazyDeserializedDAG` (see `airflow-core/src/airflow/dag_processing/processor.py` and `airflow-core/src/airflow/serialization/serialized_objects.py`). 
The scheduler reads this format directly into its internal model — any divergence is a parsing failure.

**Why a per-language reimplementation rather than codegen?** The first attempt was to generate POJOs from `schema.json` (similar to how Pydantic models are generated from OpenAPI specs). That approach was abandoned because the generated types miss the wrapping/unwrapping rules that distinguish "decorated" fields (kept as `{"__type", "__var"}`) from "non-decorated" fields (unwrapped to the bare value), as well as the timetable/task encoding rules listed below. Wiring an extra translation layer on top of generated types added more code than implementing the serializer directly per language.

**Compatibility strategy.** Each language SDK ships its own serializer plus a cross-SDK validator:

- A shared `test_dags.yaml` defines logical fixtures.
- Python emits `serialized_python.json` via `DagSerialization.serialize_dag()`.
- Each language SDK emits `serialized_<language>.json` via its own serializer.
- `compare.py` does a field-by-field comparison and fails on divergence.

This validator is planned to run as a CI gate (PR #65959). A complementary direction (suggested by reviewers, deferred): publish JSON schemas for the IPC envelope types themselves (`DagFileParsingResult`, `StartupDetails`, `TaskInstance`), which are currently undocumented because they were Python-to-Python only. That work is out of scope for the Java SDK PR but is a sensible next step once a second language SDK is in tree.

### What the Base Class Handles Automatically

The matched coordinator's `run_dag_parsing()` (a concrete method on `BaseCoordinator`) delegates to `_runtime_subprocess_entrypoint()`, which handles all the TCP/process plumbing:

1. Creates two TCP servers on `127.0.0.1` with random ports (comm + logs)
2. Creates a stderr socketpair
3. Calls `dag_parsing_cmd()` to get the command
4. Spawns the subprocess with `stdin=DEVNULL` (does NOT inherit fd 0)
5.
Accepts TCP connections from the subprocess +6. Wraps fd 0 as `supervisor_comm` via `os.dup(0)` +7. Runs `_bridge()` — a raw byte forwarder between fd 0 and the TCP comm socket + +### Expected E2E Flow + +``` +Airflow Dag-Processor + │ + ▼ +DagFileProcessorProcess.start(path, bundle_name, bundle_path) + │ + ├─ _resolve_processor_target() + │ └─ iterates instances from [sdk] coordinators (airflow.cfg) + │ └─ first can_handle_dag_file() == True wins + │ + ▼ +WatchedSubprocess.start(target=coordinator.run_dag_parsing) + │ + [fork — child process gets fd 0 as Unix domain socket to supervisor] + │ + ▼ (in child) +Coordinator.run_dag_parsing(path, bundle_name, bundle_path) + │ + ▼ +BaseCoordinator._runtime_subprocess_entrypoint(DagParsingInfo) + │ + ├─ 1. Create TCP comm_server + logs_server on 127.0.0.1:random + ├─ 2. Create stderr socketpair + ├─ 3. Call dag_parsing_cmd() → get launch command + ├─ 4. Popen(cmd, stdin=DEVNULL, stderr=child_stderr) + ├─ 5. Accept TCP connections from the language runtime + ├─ 6. supervisor_comm = socket(fileno=os.dup(0)) + └─ 7. _bridge() — raw byte forwarding until process exits +``` + +### Expected Message Sequence + +Once the bridge is running, the Airflow supervisor and the language runtime communicate directly through the bridge (raw bytes, no re-encoding): + +``` +Airflow Supervisor Bridge Language Runtime + │ │ │ + ├── DagFileParseRequest ──────────┼──────────────────────►│ + │ [4-byte len][msgpack frame] │ raw byte forward │ + │ │ │ + │ │ ├── parse DAGs from + │ │ │ bundle/file + │ │ │ + │◄── DagFileParsingResult ────────┼───────────────────────┤ + │ [4-byte len][msgpack frame] │ raw byte forward │ + │ │ │ + │ │ └── exit(0) + │ │ + │ └── drain remaining bytes (5s deadline) + │ close all sockets +``` + +### DagFileParsingResult Format + +The language runtime must produce a `DagFileParsingResult` that matches Python Airflow's DagSerialization format exactly. 
The Airflow scheduler deserializes this into its internal model — any divergence causes parsing failures.

**Envelope:**

```
{
  "type": "DagFileParsingResult",
  "fileloc": "<source file path>",
  "serialized_dags": [
    {
      "data": {
        "__version": 3,
        "dag": { <serialized DAG fields — see below> }
      }
    },
    ...
  ]
}
```

**Serialized DAG structure** (version 3):

| Field | Type | Required | Description |
|---|---|---|---|
| `dag_id` | string | yes | Unique identifier |
| `fileloc` | string | yes | Source file path (can be empty) |
| `relative_fileloc` | string | yes | Relative source path (can be empty) |
| `timezone` | string | yes | Always `"UTC"` |
| `timetable` | `{__type, __var}` | yes | Schedule timetable (see below) |
| `tasks` | list | yes | Serialized task list |
| `dag_dependencies` | list | yes | Empty list for non-Python DAGs |
| `task_group` | map | yes | Flat root task group |
| `edge_info` | map | yes | Empty map |
| `params` | list | yes | DAG-level parameters |
| `description` | string | if set | |
| `start_date` | float (epoch) | if set | Unwrapped from `__type`/`__var` |
| `end_date` | float (epoch) | if set | Unwrapped from `__type`/`__var` |
| `tags` | list | if non-empty | Unwrapped from `__type`/`__var` |
| `catchup` | bool | if `true` | |
| `max_active_tasks` | int | if non-default | |
| `max_active_runs` | int | if non-default | |

**Timetable encoding:**

| Schedule | `__type` | `__var` |
|---|---|---|
| `null` | `airflow.timetables.simple.NullTimetable` | `{}` |
| `@once` | `airflow.timetables.simple.OnceTimetable` | `{}` |
| `@continuous` | `airflow.timetables.simple.ContinuousTimetable` | `{}` |
| cron expr | `airflow.timetables.trigger.CronTriggerTimetable` | `{expression, timezone, interval, run_immediately}` |

**Task encoding:**

```
{
  "__type": "operator",
  "__var": {
    "task_id": "<task id>",
    "task_type": "<operator class name>",
    "_task_module": "<operator module>",
    "downstream_task_ids": ["<task id>", ...]  // only if non-empty
  }
}
```

**Value type encoding** (for complex
fields):

| Type | Encoding |
|---|---|
| datetime | `{"__type": "datetime", "__var": <seconds-since-epoch float>}` |
| timedelta | `{"__type": "timedelta", "__var": <total-seconds float>}` |
| dict | `{"__type": "dict", "__var": {k: serialize(v), ...}}` |
| set | `{"__type": "set", "__var": [sorted_items]}` |
| list | `[serialize(item), ...]` (no wrapper) |
| primitives | pass through unchanged |

**Non-decorated vs decorated fields:** Some fields (like `start_date`, `end_date`, `tags`) are "non-decorated" — they are serialized with `__type`/`__var` wrapping but then unwrapped to just the `__var` value before inclusion in the DAG dict. Other fields (like `default_args`, `access_control`) are "decorated" — they keep the `__type`/`__var` wrapper. This matches Python's `serialize_to_json` behavior.

### What a Language Provider Must Implement

For DAG parsing, a new language provider needs:

1. **A `BaseCoordinator` subclass** with:
   - `can_handle_dag_file()` — language-specific file detection (e.g., "is this a JAR?", "is this a .go file?")
   - `dag_parsing_cmd()` — returns the command to launch the runtime

2. **A runtime process** that:
   - Accepts `--comm=host:port` and `--logs=host:port` CLI arguments
   - Connects to both TCP addresses
   - Reads a `DagFileParseRequest` msgpack frame from the comm channel
   - Parses the DAGs from the bundle
   - Serializes the result to DagSerialization v3 format
   - Sends back a `DagFileParsingResult` msgpack frame
   - Exits

3.
**Registration** as an entry in `[sdk] coordinators` in `airflow.cfg`, pointing `classpath` at the importable subclass under `airflow.sdk.coordinators.` + +### Java as a Concrete Example + +**JavaCoordinator:** + +The Java SDK implements all DAG-parsing contracts in a single `BaseCoordinator` subclass shipped as `apache-airflow-coordinators-java`: + +```python +# Distribution: apache-airflow-coordinators-java +# Module: airflow.sdk.coordinators.java.coordinator +class JavaCoordinator(BaseCoordinator): + def __init__(self, *, name, java_executable="java", jvm_args=None, jdk_home=None): + self.name = name + self.java_executable = java_executable + self.jvm_args = list(jvm_args or []) + self.jdk_home = jdk_home + + def can_handle_dag_file(self, bundle_name, path) -> bool: + # Returns True when path is a JAR with a Main-Class manifest entry + with contextlib.suppress(FileNotFoundError): + return find_main_class(Path(path)) is not None + return False + + def dag_parsing_cmd(self, *, dag_file_path, bundle_name, bundle_path, comm_addr, logs_addr): + main_class = find_main_class(Path(dag_file_path)) + return [ + self.java_executable, + *self.jvm_args, + "-classpath", + f"{bundle_path}/*", + main_class, + f"--comm={comm_addr}", + f"--logs={logs_addr}", + ] +``` + +`can_handle_dag_file()` checks that the file is a JAR with a `Main-Class` in its manifest. This ensures the coordinator only claims files it can actually handle. + +The classpath is `/*` — a wildcard that includes all JARs in the directory (the application JAR plus its dependencies). The `java_executable` and `jvm_args` come from the per-instance `kwargs` declared in `[sdk] coordinators`, so multiple instances (e.g., `jdk-11`, `jdk-17`) can launch different JVMs with different flags from the same class. + +No separate `JavaDagFileProcessor` class is needed — `BaseCoordinator` consolidates file detection, DAG parsing, and task execution into a single extension point. 
+ +**Java SDK Bundle Process:** + +The Java bundle process (`Server.kt`) starts, connects to both TCP servers, and enters `CoordinatorComm.startProcessing()`. When it receives a `DagFileParseRequest`: + +``` +CoordinatorComm.handleIncoming(frame) + │ + ├── frame.body is DagFileParseRequest + │ file: String ← the path from the request + │ + ▼ +DagParser(request.file).parse(bundle) + │ + ├── Returns DagParsingResult(fileloc=file, dags=bundle.dags) + │ The DAGs were already loaded into the Bundle at startup + │ via BundleBuilder.getDags() + │ + ▼ +sendMessage(frame.id, result) + │ + ├── CoordinatorComm.encode(OutgoingFrame(id, result)) + │ ├── detects DagParsingResult type + │ └── calls result.serialize() ← Serde.kt + │ + ├── DagParsingResult.serialize() + │ ├── Wraps each DAG: {"data": {"__version": 3, "dag": dag.serialize(id)}} + │ ├── Dag.serialize() produces the full v3 format: + │ │ timetable, tasks, task_group, params, optional fields... + │ ├── Task.serialize() wraps as {"__type": "operator", "__var": {...}} + │ └── serializeValue() handles datetime/timedelta/dict/set encoding + │ + ├── TaskSdkFrames.encodeRequest(id, serializedMap) + │ ├── Converts map to msgpack: [id, body] + │ └── Returns byte array + │ + └── Writes [4-byte length prefix][msgpack payload] to comm channel + +shutDownRequested = true ← one-shot, process will exit +``` + +**Java SDK BundleBuilder Interface:** + +Bundle authors implement `BundleBuilder` to define their DAGs: + +```java +public class ExampleBundleBuilder implements BundleBuilder { + @Override + public List getDags() { + var dag = new Dag("java_example", null, "@daily"); + dag.addTask("extract", Extract.class, List.of()); + dag.addTask("transform", Transform.class, List.of("extract")); + dag.addTask("load", Load.class, List.of("transform")); + return List.of(dag); + } + + public static void main(String[] args) { + var bundle = new ExampleBundleBuilder().build(); + Server.create(args).serve(bundle); + } +} +``` + +The `Dag` class 
provides a fluent API: + +- `dagId`, `description`, `schedule` (cron or preset), `startDate`, `endDate`, and all standard Airflow DAG parameters +- `addTask(id, taskClass, dependsOn)` — registers a task and its upstream dependencies +- Dependencies are stored as a `dependants` map (parent → set of children), serialized as `downstream_task_ids` + +**Java SDK Serialization Compatibility:** + +The serialization in `Serde.kt` is validated against Python's output: + +```bash +# 1. Java generates serialized output +./gradlew sdk:test +# → writes validation/serialization/serialized_java.json + +# 2. Python generates the same DAGs +uv run validation/serialization/serialize_python.py \ + validation/serialization/test_dags.yaml \ + validation/serialization/serialized_python.json + +# 3. Field-by-field comparison +uv run validation/serialization/compare.py \ + validation/serialization/serialized_python.json \ + validation/serialization/serialized_java.json +``` + +Both share test cases defined in `test_dags.yaml`, ensuring the Java SDK produces byte-identical output to Python's `DagSerialization.serialize_dag()` for the same inputs. + +## Consequences + +- The DAG file processor can be extended to any language without modifying Airflow Core — only a `BaseCoordinator` subclass distributed as `apache-airflow-coordinators-` plus an entry in `[sdk] coordinators` is needed. +- The language runtime must produce exact DagSerialization v3 JSON, requiring cross-language validation infrastructure (e.g., `test_dags.yaml` + `compare.py`). +- The base class absorbs all TCP/process plumbing, so language providers only implement two methods for DAG parsing. +- The subprocess bridge adds latency and a process boundary; DAG parsing for non-Python files is inherently slower than in-process Python parsing. 
diff --git a/java-sdk/adr/0003-workload-execution.md b/java-sdk/adr/0003-workload-execution.md new file mode 100644 index 0000000000000..d1c698cd44af2 --- /dev/null +++ b/java-sdk/adr/0003-workload-execution.md @@ -0,0 +1,519 @@ + + +# ADR-0003: Workload Execution — Language-Specific Task Execution + +## Status + +Accepted + +## Context + +Airflow's standard task runner executes Python callables. To support tasks written in other languages, the pipeline needs an extension point where a language-specific coordinator can intercept the execution, delegate to an external runtime process, and bridge the Task SDK protocol so the external process can access Airflow services (connections, variables, XCom) during execution. + +This ADR details the task execution side of the coordinator architecture described in [ADR-0001](0001-java-sdk-airflow-integration.md). It starts with the generic model — the abstract contracts and expected behavior that any language must implement — then walks through Java as a concrete example. + +## Decision + +### Extension Point: `BaseCoordinator` + +The same `BaseCoordinator` base class that handles DAG parsing also handles task execution. Concrete subclasses ship as standalone distributions (`apache-airflow-coordinators-`, contributing to the `airflow.sdk.coordinators` namespace package) and are activated through `[sdk] coordinators` in `airflow.cfg` — there is no `provider.yaml` involvement. For task execution, a subclass must implement: + +| Method | Signature | Responsibility | +|---|---|---| +| `task_execution_cmd` | `(what, dag_rel_path, bundle_info, comm_addr, logs_addr) -> list[str]` | Return the full command to launch the language runtime for task execution. `comm_addr` and `logs_addr` are `host:port` strings the process must connect to. | + +The base class provides `run_task_execution()` as a concrete method that handles all TCP/process plumbing automatically (same pattern as `run_dag_parsing()` for the DAG parsing side). 
+ +**Parameters passed to `run_task_execution()`:** + +| Parameter | Type | Description | +|---|---|---| +| `what` | `TaskInstance` | The task instance to execute (id, dag_id, task_id, run_id, try_number, etc.) | +| `dag_rel_path` | `str \| PathLike` | Relative path to the DAG file / bundle within the bundle root | +| `bundle_info` | `BundleInfo` | Bundle name and version | +| `startup_details` | `StartupDetails` | Full startup context (task instance, DAG rel path, bundle info, run context, start date) — already consumed from fd 0 | + +### Registration + +The same `[sdk] coordinators` entry covers both DAG parsing and task execution — no separate registration needed (see [ADR-0001 — Coordinator Registration](0001-java-sdk-airflow-integration.md#coordinator-registration)): + +```ini +[sdk] +coordinators = [ + {"name": "jdk-17", "classpath": "airflow.sdk.coordinators.java.JavaCoordinator", "kwargs": {"java_executable": "/usr/lib/jvm/java-17/bin/java", "jvm_args": ["-Xmx1024m"]}} +] +queue_to_coordinator = {"java": "jdk-17"} +``` + +### Discovery: `_resolve_runtime_entrypoint()` + +When `task_runner.main()` starts, before any Python task execution: + +``` +task_runner.main() + → startup_details = get_startup_details() # reads from fd 0 + → _resolve_runtime_entrypoint(startup_details) + coord_name = conf.get("sdk", "queue_to_coordinator").get(startup_details.ti.queue) + if coord_name is None: + return None # fall back to default Python execution + entry = next(e for e in conf.get("sdk", "coordinators") if e["name"] == coord_name) + coordinator = import_string(entry["classpath"])(name=coord_name, **entry.get("kwargs", {})) + return functools.partial(coordinator.run_task_execution, + what=..., dag_rel_path=..., bundle_info=..., startup_details=...) 
+ + → if runtime_entrypoint is not None: + runtime_entrypoint() # language-specific execution + return # short-circuit — skip Python execution entirely +``` + +> **Note:** `QueueToCoordinatorMapper` resolves the task's `queue` against `[sdk] queue_to_coordinator` to pick the coordinator instance name, then looks that name up in `[sdk] coordinators` and instantiates the `classpath` with the entry's `kwargs`. Two queues mapped to two different instances of the same class (e.g., `jdk-11` and `jdk-17`) execute on different JVMs with different flags. + +### Expected E2E Flow + +``` +Airflow Executor (dispatches task) + │ + ▼ +WatchedSubprocess.start(target=task_runner.main) + │ + [fork — child process gets fd 0 as Unix domain socket to supervisor] + │ + ▼ (in child) +task_runner.main() + │ + ├─ get_startup_details() ← reads StartupDetails from fd 0 + │ + ├─ _resolve_runtime_entrypoint() + │ └─ resolves queue → instance name via [sdk] queue_to_coordinator + │ └─ instantiates the matching entry from [sdk] coordinators + │ + ▼ +Coordinator.run_task_execution(what, dag_rel_path, bundle_info, startup_details) + │ + ▼ +BaseCoordinator._runtime_subprocess_entrypoint(TaskExecutionInfo) + │ + ├─ 1. Create TCP comm_server + logs_server on 127.0.0.1:random + ├─ 2. Create stderr socketpair + ├─ 3. Call task_execution_cmd() → get launch command + ├─ 4. Popen(cmd, stdin=DEVNULL, stderr=child_stderr) + ├─ 5. Accept TCP connections from the language runtime + ├─ 6. _send_startup_details(runtime_comm, startup_details) + │ └─ re-serializes with model_dump(mode="json") to avoid + │ msgpack extension types non-Python decoders can't handle + ├─ 7. supervisor_comm = socket(fileno=os.dup(0)) + └─ 8. _bridge() — raw byte forwarding until process exits +``` + +Key difference from DAG parsing: In task execution, `task_runner.main()` has already consumed `StartupDetails` from fd 0. 
The bridge must re-send `StartupDetails` to the language runtime over TCP before starting the byte-forwarding bridge. This is done via `_send_startup_details()`, which re-serializes using JSON mode to avoid msgpack extension types (like `Timestamp`) that non-Python decoders may not support. + +### Expected Message Sequence + +Task execution is a multi-round conversation, unlike DAG parsing's single request/response: + +``` +Airflow Supervisor Bridge Language Runtime + │ │ │ + │ [StartupDetails sent by bridge directly] │ + │ ├── StartupDetails ────►│ + │ │ │ + │ │ ├── Look up task + │ │ │ from bundle + │ │ │ + │ │ ┌───────────────────┤ + │ │ │ Task code runs │ + │ │ │ and may request: │ + │ │ │ │ + │◄── GetConnection(conn_id) ──────┼───┤ │ + │ │ │ │ + ├── ConnectionResult ─────────────┼──►│ │ + │ │ │ │ + │◄── GetVariable(key) ────────────┼───┤ │ + │ │ │ │ + ├── VariableResult ───────────────┼──►│ │ + │ │ │ │ + │◄── GetXCom(key, dag_id, ...) ───┼───┤ │ + │ │ │ │ + ├── XComResult ───────────────────┼──►│ │ + │ │ │ │ + │◄── SetXCom(key, value, ...) 
────┼───┤ │ + │ │ │ │ + ├── (empty response) ─────────────┼──►│ │ + │ │ │ │ + │ │ └───────────────────┤ + │ │ │ + │◄── SucceedTask / TaskState ─────┼───────────────────────┤ + │ (terminal — no response) │ │ + │ │ └── exit(0) + │ │ + │ └── drain, close sockets +``` + +### Task SDK Protocol Messages + +The language runtime exchanges these message types with the Airflow supervisor: + +**Runtime → Supervisor (requests):** + +| Message | Fields | Purpose | +|---|---|---| +| `GetConnection` | `conn_id` | Fetch an Airflow connection by ID | +| `GetVariable` | `key` | Fetch an Airflow variable by key | +| `GetXCom` | `key`, `dag_id`, `task_id`, `run_id`, `map_index?`, `include_prior_dates?` | Fetch an XCom value | +| `SetXCom` | `key`, `value`, `dag_id`, `task_id`, `run_id`, `map_index`, `mapped_length?` | Store an XCom value | +| `SucceedTask` | `end_date`, `task_outlets?`, `outlet_events?` | Terminal: task succeeded | +| `TaskState` | `state` (`"failed"`, `"removed"`, `"skipped"`), `end_date` | Terminal: task ended non-successfully | + +**Supervisor → Runtime (responses):** + +| Message | Fields | In response to | +|---|---|---| +| `ConnectionResult` | `conn_id`, `conn_type`, `host`, `schema`, `login`, `password`, `port`, `extra` | `GetConnection` | +| `VariableResult` | `key`, `value` | `GetVariable` | +| `XComResult` | `key`, `value` | `GetXCom` | +| (empty) | | `SetXCom` | +| `ErrorResponse` | `error`, `detail` | Any request that failed server-side | + +**Framing:** Every message is a length-prefixed msgpack frame. Requests are `[id, body]` (2-element array); responses are `[id, body, error]` (3-element array). The `id` field correlates request/response pairs. + +### Request/Response Semantics + +The task execution follows a synchronous request/response pattern from the runtime's perspective: + +1. The runtime sends a request frame (e.g., `GetVariable`) with an incrementing `id` +2. 
The supervisor reads the frame, fulfills the request (e.g., calls the Execution API), and sends back a response with the same `id` +3. The runtime blocks until it receives the response +4. This repeats for each Airflow service call the task code makes +5. When the task finishes, the runtime sends a terminal message (`SucceedTask` or `TaskState`) — no response is expected, and the process exits + +### IPC Forward-Compatibility Contract + +The supervisor-to-runtime IPC schema (the messages enumerated above plus `StartupDetails` and `DagFileParseRequest` from [ADR-0002](0002-dag-parsing.md)) is shared between Airflow Core (Python) and every language SDK. A formal AIP for this protocol is expected as follow-up work; until then, this section pins down the rules that the Java SDK assumes and that any future SDK (Go, Rust, …) must follow. + +**Codec rule (load-bearing).** Every SDK MUST configure its decoder to ignore unknown fields: + +- Python side: `msgspec` / Pydantic models are forward-compatible by default. +- Java side: `TaskSdkFrames.kt` configures the Jackson `ObjectMapper` with `FAIL_ON_UNKNOWN_PROPERTIES = false`. A short comment at that call site documents that this is contract, not preference — flipping it back to the Jackson default would break forward compatibility with Core. +- Any new SDK: pick a codec configuration that mirrors this (silent drop of unknown fields). + +This rule is what makes additive Core changes safe to ship without bumping a version on every SDK. The analogous trap — generated clients that emit their *own* allowlist check before the configured mapper sees the bytes — has bitten downstream Java consumers in unrelated systems; flagging the contract here makes it visible to future SDK authors. + +**Change classification.** + +| Change to a message | Status | Required action | +|---|---|---| +| Add a new optional field | **Non-breaking.** Decoders ignore it; old SDKs unaffected. | None. Just ship it. 
| +| Add a new required field | Breaking. | Deprecation cycle: ship as optional first, populate from Core, wait for SDKs to consume it, then tighten. | +| Rename a field | Breaking. | Deprecation cycle: emit both names from Core during transition. | +| Change a field's type | Breaking. | Deprecation cycle, typically via a new field name + parallel emission. | +| Remove a required field | Breaking. **Especially dangerous in Java**: `lateinit var` properties on `StartupDetails` deserialize silently and only throw `UninitializedPropertyAccessException` on first access, so the failure surfaces inside user task code rather than at the protocol boundary. | Deprecation cycle. Prefer making the field optional first, then remove after a release in which all SDKs have absorbed the change. | + +**Recommended testing.** A small contract test on the SDK side should feed the decoder synthetic frames that exercise the rules above — an unknown field, a missing optional field, a `null` in an optional position — so that a future codec-config regression is caught before it reaches users. `SerializationCompatibilityTest` already covers DAG-payload divergence (see [ADR-0002 — Cross-SDK Serialization Compatibility](0002-dag-parsing.md#cross-sdk-serialization-compatibility)); the IPC-envelope tests are complementary and currently in the follow-up bucket. + +### Runtime Lifecycle and Worker Capability + +The language runtime is **ephemeral and one-process-per-task**: + +- Each task instance launches its own `java -classpath /* --comm=… --logs=…` (or the equivalent for another language). The lifetime of that process is the lifetime of the task. There is no pooling or warm-pool reuse. +- Parallelism on a single worker therefore equals the number of concurrently running task processes. Five concurrent Java tasks on one worker means five JVMs. +- DAG parsing has the same shape: each `DagFileProcessorProcess` child handles one parse request and exits. 
The language runtime spawned underneath it inherits that ephemerality.

**Worker capability is opt-in.** A worker can run a non-Python task only if the corresponding `apache-airflow-coordinators-<language>` distribution is installed, the matching coordinator instance is declared in `[sdk] coordinators`, and the language toolchain (e.g., a JRE) is on the host. There is no requirement that every worker support every language. Routing relies on:

| Layer | Mechanism |
|---|---|
| Author intent | Operator / `@task.stub` declares `queue="java"` (or any custom queue) |
| Worker selection | The executor (Celery, Kubernetes, etc.) routes the task to a worker that consumes that queue, exactly as it does for Python tasks today |
| Runtime selection | Inside the task runner, `[sdk] queue_to_coordinator` maps the queue name to a coordinator instance name; that name is resolved against `[sdk] coordinators` to instantiate the configured class with its `kwargs`; `_resolve_runtime_entrypoint` then dispatches into `<instance>.run_task_execution` |

The deployment model is the same one that already applies to Python providers: install what your DAGs need, on the hosts they run on. Multi-language workers are possible (install both providers and both toolchains) but not required.

**JAR / artifact version compatibility.** The Java SDK embeds its version in the bundle JAR via the `Airflow-Java-SDK-Version` manifest attribute (see [ADR-0004](0004-pure-java-dags.md)). Validating that a bundle's SDK version matches the installed `JavaCoordinator` version at execution time is planned but not yet wired in; this is a follow-up to add before promoting the SDK out of preview.
+ +### StartupDetails + +The first message the runtime receives is `StartupDetails`, which provides full context for the task: + +| Field | Type | Description | +|---|---|---| +| `ti` | `TaskInstance` | id, task_id, dag_id, run_id, try_number, dag_version_id, map_index, context_carrier | +| `dag_rel_path` | string | Relative path to the DAG file / bundle | +| `bundle_info` | `BundleInfo` | name, version | +| `start_date` | datetime | When this task attempt started | +| `ti_context` | `TIRunContext` | DAG run context (logical date, data interval, etc.) | +| `sentry_integration` | string | Sentry DSN for error reporting (optional) | + +### What a Language SDK Must Implement + +For task execution, a new language SDK needs: + +1. **A `BaseCoordinator` subclass** with: + - An `__init__` that accepts the kwargs the operator will declare in `[sdk] coordinators` (e.g., interpreter path, language-specific runtime flags) + - `task_execution_cmd()` — returns the command to launch the runtime, typically using attributes set in `__init__` + - (This is the same subclass that implements `can_handle_dag_file()` and `dag_parsing_cmd()` for DAG parsing — one class covers both) + +2. **A runtime process** that: + - Accepts `--comm=host:port` and `--logs=host:port` CLI arguments + - Connects to both TCP addresses + - Reads a `StartupDetails` msgpack frame from the comm channel + - Looks up the task to execute from its bundle using `ti.dag_id` and `ti.task_id` + - Executes the task, making `GetConnection`/`GetVariable`/`GetXCom`/`SetXCom` requests as needed + - Sends `SucceedTask` on success or `TaskState("failed")` on failure + - Exits + +3. **A task interface** that user code implements (analogous to Python's `@task` decorator or `BaseOperator`) + +4. **A client API** that wraps the socket protocol behind a simple interface (get_connection, get_variable, get_xcom, set_xcom) so task authors don't deal with framing + +5. 
**Distribution** as `apache-airflow-coordinators-`, contributing the subclass under `airflow.sdk.coordinators.` (same module path as the DAG-parsing entry — one class, one import path) + +### Java as a Concrete Example + +**JavaCoordinator (Python side):** + +The same `JavaCoordinator` that handles DAG parsing also handles task execution — no separate `JavaTaskCoordinator` class is needed: + +```python +# Distribution: apache-airflow-coordinators-java +# Module: airflow.sdk.coordinators.java.coordinator +class JavaCoordinator(BaseCoordinator): + def __init__(self, *, name, java_executable="java", jvm_args=None, jdk_home=None): + self.name = name + self.java_executable = java_executable + self.jvm_args = list(jvm_args or []) + self.jdk_home = jdk_home + + def can_handle_dag_file(self, bundle_name, path) -> bool: + with contextlib.suppress(FileNotFoundError): + return find_main_class(Path(path)) is not None + return False + + def dag_parsing_cmd(self, *, dag_file_path, bundle_name, bundle_path, comm_addr, logs_addr): + main_class = find_main_class(Path(dag_file_path)) + return [ + self.java_executable, + *self.jvm_args, + "-classpath", + f"{bundle_path}/*", + main_class, + f"--comm={comm_addr}", + f"--logs={logs_addr}", + ] + + def task_execution_cmd(self, *, what, dag_rel_path, bundle_info, comm_addr, logs_addr): + jar_path = Path(dag_rel_path) + main_class = find_main_class(jar_path) + return [ + self.java_executable, + *self.jvm_args, + "-classpath", + f"{jar_path.parent}/*", + main_class, + f"--comm={comm_addr}", + f"--logs={logs_addr}", + ] +``` + +One class, one importable `classpath`, covers both DAG parsing and task execution. Operators register it once per JVM variant in `[sdk] coordinators` and route queues to those instances via `[sdk] queue_to_coordinator`. 
+ +**Java SDK Task Interface:** + +User task code implements a single-method interface: + +```java +// sdk: org.apache.airflow.sdk.Task +public interface Task { + void execute(Client client) throws Exception; +} +``` + +The `Client` provides access to Airflow services: + +```java +// sdk: org.apache.airflow.sdk.Client +public class Client { + // Access task metadata + public StartupDetails getDetails(); + + // Airflow services + public Connection getConnection(String id); + public Object getVariable(String key); + public Object getXCom(String key, String dagId, String taskId, String runId, ...); + public void setXCom(String key, Object value); // defaults: key="return_value", dagId/taskId/runId from current task +} +``` + +**Java SDK Task Execution Flow:** + +When the bundle process receives `StartupDetails`: + +``` +CoordinatorComm.handleIncoming(frame) + │ + ├── frame.body is StartupDetails + │ ti: TaskInstance (id, dagId, taskId, runId, tryNumber, ...) + │ dagRelPath, bundleInfo, startDate, tiContext + │ + ▼ +TaskRunner.run(bundle, request, comm) + │ + ├── Create Client(request, CoordinatorClient(comm)) + │ CoordinatorClient wraps the comm channel behind the Client interface + │ + ├── Look up task class: + │ bundle.dags[request.ti.dagId]?.tasks[request.ti.taskId] + │ └── if not found → return TaskState("removed") + │ + ├── Instantiate task: + │ task.getDeclaredConstructor().newInstance() + │ + ├── Execute: + │ try { + │ instance.execute(client) ← USER TASK CODE RUNS HERE + │ return SucceedTask() + │ } catch (Exception e) { + │ return TaskState("failed") + │ } + │ + ▼ +sendMessage(frame.id, result) ← sends SucceedTask or TaskState back +shutDownRequested = true ← one-shot, process will exit +``` + +**Java SDK Airflow Service Access:** + +When user task code calls `client.getVariable("my_key")`, the call chain is: + +``` +client.getVariable("my_key") // Client.kt (public SDK) + │ + └── impl.getVariable("my_key") // CoordinatorClient (execution) + │ + └── 
runBlocking { // blocks the calling thread + comm.communicate( // CoordinatorComm + GetVariable(key = "my_key") + ) + } + │ + ├── sendMessage(nextId++, GetVariable) // encode + write to comm socket + │ ├── encode: [id, {"type": "GetVariable", "key": "my_key"}] + │ └── write: [4-byte len][msgpack] + │ + ├── processOnce(::handle) // block until response arrives + │ ├── read 4-byte length prefix + │ ├── read payload + │ └── decode: [id, {"type": "VariableResult", ...}, null] + │ + └── return response.value // unwrap VariableResponse +``` + +This is fully synchronous from the task code's perspective — `getVariable()` blocks until the supervisor responds. + +**Java SDK Example Task Implementation:** + +```java +public static class Extract implements Task { + public void execute(Client client) throws Exception { + // Read XCom from a Python task in the same DAG + var pythonXcom = client.getXCom("python_task_1"); + + // Access Airflow connections + var connection = client.getConnection("test_http"); + + // Do work... 
+ Thread.sleep(6000); + + // Push XCom for downstream tasks (Java or Python) + client.setXCom(new Date().getTime()); + } +} + +public static class Transform implements Task { + public void execute(Client client) { + // Read XCom from upstream Java task + var extractXcom = client.getXCom("extract"); + + // Access Airflow variables + var variable = client.getVariable("my_variable"); + + // Push XCom (readable by downstream Python tasks) + client.setXCom(new Date().getTime()); + } +} + +public static class Load implements Task { + public void execute(Client client) { + var xcom = client.getXCom("transform"); + throw new RuntimeException("I failed"); + // Exception → TaskRunner catches → sends TaskState("failed") + } +} +``` + +**Java SDK Complete Bundle Entry Point:** + +```java +public class ExampleBundleBuilder implements BundleBuilder { + @Override + public List getDags() { + var dag = JavaExampleBuilder.build(); + return List.of(dag); + } + + public static void main(String[] args) { + var bundle = new ExampleBundleBuilder().build(); + Server.create(args).serve(bundle); // parses --comm/--logs, connects, enters message loop + } +} +``` + +The same `main()` entry point handles both DAG parsing and task execution — the first message received (`DagFileParseRequest` or `StartupDetails`) determines the mode. + +**Java SDK Java-side Supervisor (Alternative Execution Path):** + +The Java SDK also provides `Supervisor.kt` for execution contexts where there is no Python process (e.g., the Edge Worker). 
In this path, the Supervisor terminates the protocol directly instead of bridging: + +``` +Supervisor.run(request) + │ + ├── Create TCP comm + logs servers + ├── Spawn Java bundle process with --comm/--logs + ├── Accept connections + ├── HTTP PATCH task → running state + ├── Send StartupDetails to bundle via comm socket + │ + └── serveTaskSdkRequests() loop: + Read frame from bundle + ├── GetConnection → HTTP GET /connections/{id} → send response + ├── GetVariable → HTTP GET /variables/{key} → send response + ├── GetXCom → HTTP GET /xcom/... → send response + ├── SetXCom → HTTP POST /xcom/... → send response + └── SucceedTask/TaskState → HTTP PATCH terminal state → exit loop +``` + +The bundle process behaves identically in both paths — it is unaware of whether its comm channel leads to a Python bridge or a Java Supervisor. This is the core design invariant of the Java SDK. + +## Consequences + +- Task execution for any language reuses the same coordinator + bridge pattern as DAG parsing, keeping the extension surface small. +- The multi-round protocol (GetConnection, GetVariable, etc.) means the language runtime has full access to Airflow services without reimplementing them — they stay in Python. +- The synchronous request/response model is simple for language SDK authors but adds a round-trip per service call. +- The Java-side Supervisor (`Supervisor.kt`) provides an alternative execution path for environments without Python, but requires the Java SDK to implement HTTP calls to the Execution API directly. +- Task authors interact with a simple `Client` interface, completely abstracted from the underlying socket protocol. 
diff --git a/java-sdk/adr/0004-pure-java-dags.md b/java-sdk/adr/0004-pure-java-dags.md new file mode 100644 index 0000000000000..c55bf5505630b --- /dev/null +++ b/java-sdk/adr/0004-pure-java-dags.md @@ -0,0 +1,249 @@ + + +# ADR-0004: Pure Java DAGs — Build-Time Packaging and Code Visibility + +## Status + +Accepted + +## Context + +[ADR-0001](0001-java-sdk-airflow-integration.md) introduces two ways to integrate non-Python tasks: `@task.stub` (mixed Python+Java DAGs) and pure Java DAGs (entire DAG in Java via `BundleBuilder`). [ADR-0002](0002-dag-parsing.md) and [ADR-0003](0003-workload-execution.md) describe the coordinator infrastructure for DAG parsing and task execution respectively. + +This ADR focuses on the Java-SDK-specific concerns that make pure Java DAGs work end-to-end — build-time metadata generation, source code packaging for UI visibility, and JAR manifest conventions — rather than the shared coordinator infrastructure already covered in those ADRs. + +The central challenge is that Airflow Core expects to read DAG metadata and source code from files on disk or from the metadata DB. A JAR is an opaque binary — Airflow cannot `open()` it and read Python source. The Java SDK must bridge this gap at build time by embedding machine-readable metadata and human-readable source into the JAR itself. 
+ +## Decision + +### JAR Manifest Conventions + +The JAR manifest (`META-INF/MANIFEST.MF`) carries three SDK-specific attributes that Airflow and the Java SDK use to bootstrap a bundle: + +| Attribute | Example Value | Purpose | +|---|---------------------------------------------------|---| +| `Main-Class` | `org.apache.airflow.example.ExampleBundleBuilder` | Standard Java attribute; the coordinator uses it to launch the JVM | +| `Airflow-Java-SDK-Metadata` | `airflow-metadata.yaml` | Points to the embedded metadata file (dag IDs, task IDs) | +| `Airflow-Java-SDK-Dag-Code` | `JavaExampleBuilder.java` | Points to the embedded source file for Airflow UI display | + +These attributes are set in the Gradle build (see [Build-Time Packaging](#build-time-packaging-gradle) below). The Python-side coordinator reads `Main-Class` to construct the launch command; `BundleScanner` reads `Airflow-Java-SDK-Metadata` to discover DAG IDs without launching the JVM. + +### Build-Time Metadata: `airflow-metadata.yaml` + +At build time, the SDK runs `BundleInspector` — a build-time utility that reflectively instantiates the user's `BundleBuilder` class, calls `getDags()`, and writes a YAML file listing every DAG ID and its task IDs: + +```yaml +dags: + java_example: + tasks: + - extract + - transform + - load +``` + +This file is embedded in the JAR root and referenced by the `Airflow-Java-SDK-Metadata` manifest attribute. + +**Why build-time, not runtime?** The metadata must be available before the JVM starts. `BundleScanner` reads it from the JAR to discover which DAG IDs a bundle contains — this is used for `@task.stub` routing (mapping a `dag_id` to the correct bundle's classpath) without paying JVM startup cost. For pure Java DAGs, the coordinator already knows the bundle path, but the metadata is still useful for validation and tooling. 
**`BundleInspector`:**

```kotlin
object BundleInspector {
  @JvmStatic
  fun main(args: Array<String>) {
    val className = args[0]
    val outputPath = args[1]
    val clazz = Class.forName(className)
    val instance = clazz.getDeclaredConstructor().newInstance() as? BundleBuilder
      ?: error("$className does not implement BundleBuilder")
    val dags = instance.getDags()
    File(outputPath).apply { parentFile.mkdirs() }.writeText(toYaml(dags))
  }

  internal fun toYaml(dags: List<Dag>): String = buildString {
    appendLine("dags:")
    for (dag in dags) {
      appendLine("  ${dag.dagId}:")
      appendLine("    tasks:")
      for (taskId in dag.tasks.keys) {
        appendLine("      - $taskId")
      }
    }
  }
}
```

### Source Code Packaging for UI Visibility

Airflow stores DAG source code in the `dag_code` table and displays it in the web UI. For Python DAGs this is trivial — `DagCode.write_code()` reads the `.py` file from disk. For a JAR, the raw bytecode is not human-readable.

The solution: pack the original `.java` source file into the JAR at build time. The `Airflow-Java-SDK-Dag-Code` manifest attribute tells the coordinator which file to extract.

On the Python side, `get_code_from_file()` on the coordinator:

1. Opens the JAR as a ZIP
2. Reads the `Airflow-Java-SDK-Dag-Code` attribute from the manifest
3. Extracts and returns the raw `.java` source

This lets Airflow's existing `DagCode` infrastructure store and display Java source code with no changes to Airflow Core.

### Build-Time Packaging (Gradle)

The `example/build.gradle.kts` shows the complete packaging pattern:

```kotlin
val bundleMainClass = application.mainClass.get()
val metadataFileName = "airflow-metadata.yaml"
val metadataOutputDir = layout.buildDirectory.dir("airflow-metadata")
val dagCodeSourcePath = bundleMainClass.replace('.', '/') + ".java"
val dagCodeFileName = bundleMainClass.substringAfterLast('.') + ".java"

// 1.
Run BundleInspector at compile time to generate metadata
val inspectBundle = tasks.register<JavaExec>("inspectBundle") {
  dependsOn("classes")
  classpath = sourceSets.main.get().runtimeClasspath
  mainClass.set("org.apache.airflow.sdk.BundleInspector")
  args = listOf(bundleMainClass, metadataOutputDir.get().file(metadataFileName).asFile.absolutePath)
}

// 2. Pack metadata + source into the JAR
tasks.withType<Jar> {
  dependsOn(inspectBundle)
  from(metadataOutputDir) // airflow-metadata.yaml
  from("src/java/$dagCodeSourcePath") // raw .java source file
  manifest {
    attributes(
      "Main-Class" to bundleMainClass,
      "Airflow-Java-SDK-Version" to project.version,
      "Airflow-Java-SDK-Metadata" to metadataFileName,
      "Airflow-Java-SDK-Dag-Code" to dagCodeFileName,
    )
  }
}
```

The resulting JAR contains:

```
example.jar
├── META-INF/MANIFEST.MF (Main-Class, SDK attributes)
├── airflow-metadata.yaml (dag IDs + task IDs)
├── JavaExampleBuilder.java (raw source for UI display)
├── org/apache/airflow/example/
│   ├── JavaExampleBuilder.class (compiled bundle entry point)
│   ├── JavaExampleBuilder$Extract.class
│   ├── JavaExampleBuilder$Transform.class
│   └── JavaExampleBuilder$Load.class
└── ... (SDK + dependency classes)
```

### `BundleScanner` — Runtime Bundle Discovery

`BundleScanner` reads JAR manifests at runtime to discover bundles without launching the JVM. This is used by the `@task.stub` path to resolve which bundle contains a given `dag_id`.
```kotlin
data class ResolvedBundle(
  val mainClass: String, // From Main-Class manifest attribute
  val classpath: String, // All JARs in bundle directory, colon-separated
)

fun scanBundles(bundlesDir: Path): Map<String, ResolvedBundle>
```

It supports two directory layouts:

- **Nested**: each subdirectory of `bundlesDir` is a bundle home (e.g., `bundles/my-app/lib/*.jar`)
- **Flat**: `bundlesDir` itself contains the JARs (e.g., `bundles/*.jar`)

For each JAR, it reads the `Airflow-Java-SDK-Metadata` manifest attribute, extracts the referenced YAML, parses DAG IDs, and returns a mapping from `dag_id` to `ResolvedBundle`.

### The BundleBuilder Authoring API

Bundle authors implement builder classes to define their DAGs:

```java
public class JavaExampleBuilder {

  public static class Extract implements Task {
    public void execute(Client client) throws Exception {
      var connection = client.getConnection("test_http");
      client.setXCom(new Date().getTime());
    }
  }

  public static class Transform implements Task {
    public void execute(Client client) {
      var extract_xcom = client.getXCom("extract");
      client.setXCom(new Date().getTime());
    }
  }

  public static Dag build() {
    var dag = new Dag("java_example", null, "@daily");
    dag.addTask("extract", Extract.class, List.of());
    dag.addTask("transform", Transform.class, List.of("extract"));
    return dag;
  }
}
```

and then collect DAGs with a BundleBuilder:

```java
public class ExampleBundleBuilder implements BundleBuilder {
  public Iterable<Dag> getDags() {
    return List.of(JavaExampleBuilder.build());
  }

  public static void main(String[] args) {
    var bundle = new ExampleBundleBuilder().build();
    Server.create(args).serve(bundle);
  }
}
```

The `main()` method is the JVM entry point that the coordinator launches.
It wires the `BundleBuilder` to the SDK's TCP communication layer (`Server` → `CoordinatorComm`), which handles DAG parsing requests and task execution commands as described in [ADR-0002](0002-dag-parsing.md) and [ADR-0003](0003-workload-execution.md). + +> **Note:** The current `BundleBuilder` interface is subject to review before the SDK reaches 1.0. Subclassing `Dag` directly may be a more natural fit and is being considered for post-OSS-integration. + +### Deployment and Updates + +A reasonable concern about JAR-based DAGs is whether updating a bundle requires draining or restarting the DAG processor / workers — Python source files are flexible because everything is read fresh on each parse, but a long-lived JVM holding a JAR open could pin an old version. + +The design avoids this by leaning on the same ephemerality that Python uses: + +- **DAG processor.** `DagFileProcessorManager` is long-lived, but each `DagFileProcessorProcess` child is one-shot and exits after returning a `DagFileParseRequest`. The Java runtime spawned underneath it (`java -classpath /* …`) shares that lifetime — it loads the JAR fresh on every parse, then exits. Replacing the JAR on disk takes effect on the next scheduled parse with no manager restart. +- **Workers.** Each task instance launches its own JVM ([ADR-0003 — Runtime Lifecycle and Worker Capability](0003-workload-execution.md#runtime-lifecycle-and-worker-capability)). The classloader is process-scoped; a swapped JAR is picked up the next time a task starts. There is no warm JVM pool to invalidate. + +In practice, "updating a Java DAG bundle" is the same shape as "updating a Python DAG file": drop the new file (or directory of JARs) into the bundle location and let normal scheduling pick it up. The version that runs a given task instance is determined at task start, not at worker start. 
+ +Two operational details worth flagging: + +- **Atomic swap.** Writing a JAR in place while a task happens to be loading it can yield a corrupted read. Operators should prefer the standard "write to a temp name, rename into place" pattern, which the file system handles atomically on POSIX. This is the same guidance that already applies to Python file-system bundles. +- **Mid-run version skew.** Because the version is resolved per task launch, a long-running DAG run can in principle observe one bundle version for an upstream task and a different version for a downstream task if a swap happens between them. Bundle-version validation against `Airflow-Java-SDK-Bundle-Version` (planned — distinct from `Airflow-Java-SDK-Version`, which identifies the SDK toolkit; see [ADR-0003](0003-workload-execution.md#runtime-lifecycle-and-worker-capability)) gives operators a way to detect skew if it matters; the data-plane consequences (XCom shape changes, etc.) are the bundle author's responsibility, exactly as with Python. + +## Consequences + +- JAR bundles are self-contained: metadata, source, and compiled code are all in one artifact, simplifying deployment (copy one directory of JARs). +- Build-time metadata generation means DAG IDs can be discovered without JVM startup — important for `BundleScanner` and tooling. +- Source code packaging enables Airflow UI display with no changes to Airflow Core's `DagCode` infrastructure. +- The manifest convention (`Airflow-Java-SDK-*` attributes) is extensible — future attributes can carry additional metadata without breaking existing tooling. +- The build-time `BundleInspector` step adds a compile-time dependency on the SDK and requires the `BundleBuilder` class to be instantiable without side effects (no I/O, no connections in the constructor). +- Bundle authors must follow the Gradle packaging pattern (or replicate it in Maven/other build tools) — this is SDK-specific boilerplate that doesn't exist for Python DAGs. 
diff --git a/java-sdk/adr/0005-coordinator-packaging.md b/java-sdk/adr/0005-coordinator-packaging.md new file mode 100644 index 0000000000000..5015d5ce4fa1f --- /dev/null +++ b/java-sdk/adr/0005-coordinator-packaging.md @@ -0,0 +1,125 @@ + + +# ADR-0005: Coordinator Packaging, Module Layout, and Registration + +## Status + +Accepted — coordinators are a new distribution type, **not** Airflow providers, and are activated through Airflow configuration rather than `provider.yaml`. Tracked operationally in [apache/airflow#66451](https://github.com/apache/airflow/issues/66451). + +## Context + +[ADR-0001](0001-java-sdk-airflow-integration.md) introduces a coordinator extension point. Reviewers on PR #65958 raised three related but separable questions: + +1. **PyPI package name.** Should the Java coordinator ship as `apache-airflow-providers-sdk-java` (consistent with every other provider) or as `apache-airflow-coordinators-java` (recognizing that "language coordinator" is a structurally new kind of distribution that does not behave like operators/hooks/sensors)? +2. **Source-tree module layout.** Should it live under `providers/sdk/java/` alongside other providers, or as a new top-level peer to `providers/`, `airflow-core/`, and `task-sdk/`? +3. **Discovery / registration mechanism.** Should coordinator classes be discovered through the existing `ProvidersManager` (and its task-runtime equivalent `ProvidersManagerTaskRuntime`), or through some other mechanism? + +A second concern, raised separately, is **runtime configuration**: a single `JavaCoordinator` class is not enough to express "use JDK 11 for the legacy queue and JDK 17 for the modern queue, with different `-Xmx` values." Class-only registration forces operators to subclass for every variant or hardcode environment lookups, which the issue calls out explicitly: + +> How can I use different JDK version? How can I use different JVM arguments? 
We hardcoded the subprocess cmd … so users have to subclass another Coordinator to override the Java config. +> — [apache/airflow#66451](https://github.com/apache/airflow/issues/66451) + +The existing `[sdk] queue_to_sdk` config (introduced in [ADR-0001](0001-java-sdk-airflow-integration.md)) maps a queue to a *language*, not to a *runtime variant*, and is therefore insufficient for this need. + +## Decision + +### A. Distribution name: `apache-airflow-coordinators-` + +Coordinators are not Airflow providers; they are a separate distribution type. The Java coordinator ships as **`apache-airflow-coordinators-java`**. New language coordinators follow the same pattern (`apache-airflow-coordinators-go`, `apache-airflow-coordinators-rust`, …). + +A coordinator distribution exposes: + +- A `BaseCoordinator` subclass under `airflow.sdk.coordinators.`. +- No operators, hooks, sensors, triggers, or `provider.yaml`. + +### B. Module layout: namespace package under `airflow.sdk.coordinators` + +Each coordinator distribution contributes a subpackage to the **namespace package** `airflow.sdk.coordinators`. The Task SDK owns the namespace; concrete coordinator distributions add `airflow.sdk.coordinators.`. + +The Java coordinator therefore resolves as: + +```python +from airflow.utils.module_loading import import_string + +JavaCoordinator = import_string("airflow.sdk.coordinators.java.JavaCoordinator") +``` + +Both Airflow Core (DAG processor) and the Task SDK (task runner) import coordinators by this path. As long as `apache-airflow-coordinators-java` is installed on a host, that `import_string` call resolves correctly without any registry lookup. + +### C. Discovery via `[sdk] coordinators` (Airflow configuration) + +Coordinators are **not** discovered through `ProvidersManager` / `ProvidersManagerTaskRuntime`, and there is no `coordinators` key in `provider.yaml`. 
They are registered as instance entries in `airflow.cfg`: + +```ini +[sdk] +coordinators = [ + { + "name": "jdk-11", + "classpath": "airflow.sdk.coordinators.java.JavaCoordinator", + "kwargs": { + "java_executable": "/usr/lib/jvm/java-11-openjdk-amd64/bin/java", + "jvm_args": ["-Xmx512m"], + "jdk_home": "/usr/lib/jvm/java-11-openjdk-amd64" + } + }, + { + "name": "jdk-17", + "classpath": "airflow.sdk.coordinators.java.JavaCoordinator", + "kwargs": { + "java_executable": "/usr/lib/jvm/java-17-openjdk-amd64/bin/java", + "jvm_args": ["-Xmx1024m", "-Xms256m"], + "jdk_home": "/usr/lib/jvm/java-17-openjdk-amd64" + } + } +] + +queue_to_coordinator = {"legacy-java-queue": "jdk-11", "modern-java-queue": "jdk-17"} +``` + +The shape is intentionally similar to `AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST`: a list of self-describing entries with `name`, `classpath`, and free-form `kwargs`. + +**Renames vs ADR-0001's earlier draft:** + +| Old (`[sdk] queue_to_sdk`) | New (`[sdk] queue_to_coordinator`) | +|---|---| +| Maps queue → language tag (e.g., `"java"`) | Maps queue → coordinator instance name (e.g., `"jdk-17"`) | +| One coordinator per language | Many coordinator instances per language, distinguished by `kwargs` | + +`queue_to_coordinator` replaces `queue_to_sdk` everywhere. + +### Why not `provider.yaml` / `ProvidersManager`? + +Coordinators are not providers in the Airflow sense: + +- They expose no operators / hooks / sensors / triggers. +- They are consumed by both Airflow Core (in the DAG processor) **and** the Task SDK (in the task runner). The provider system is not designed to be loaded from inside a worker subprocess that intentionally has no Airflow-Core import. +- They need **per-instance** runtime configuration (interpreter path, JVM flags, …). `provider.yaml` registers classes, not instances, and bolting kwargs onto provider entries would distort the provider data model. 
+- A coordinator is the only thing in this distribution; there is no benefit to sharing the provider's discoverability surface (registry listings, `airflow providers list`, etc.). On the contrary, listing `apache-airflow-providers-sdk-java` next to AWS/GCP providers is misleading for users. + +Putting the registry in `airflow.cfg` keeps the data model honest (instances, with their kwargs) and makes the per-host opt-in (install + config-edit) explicit rather than implicit (install-implies-active). + +## Consequences + +- **`apache-airflow-coordinators-java`** ships as a new distribution type with its own release docs and constraints handling, distinct from providers. +- **`airflow.sdk.coordinators`** is a namespace package owned by the Task SDK; concrete coordinator distributions contribute subpackages to it. Multiple coordinator distributions can be installed side by side without colliding. +- **`[sdk] coordinators`** carries instance-level configuration; **`[sdk] queue_to_coordinator`** carries queue → instance routing. `[sdk] queue_to_sdk` is removed. +- Operators can register multiple instances of the same coordinator class (e.g., `jdk-11` and `jdk-17`) and bind different queues to them — solving the multi-JDK and JVM-flag use cases raised in [apache/airflow#66451](https://github.com/apache/airflow/issues/66451) without subclassing. +- The provider registry no longer shows coordinators, removing the "Java appears, Go does not" asymmetry that earlier drafts of this ADR flagged as a transitional UX wart. +- Future static-source DAG parsers (e.g., YAML / `dag-factory`) that fit the same coordinator shape can use the same `[sdk] coordinators` registry without inventing a new extension point. 
diff --git a/java-sdk/build.gradle.kts b/java-sdk/build.gradle.kts new file mode 100644 index 0000000000000..b2b367fc4fd91 --- /dev/null +++ b/java-sdk/build.gradle.kts @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import com.diffplug.gradle.spotless.SpotlessExtension +import org.jetbrains.kotlin.gradle.dsl.JvmTarget + +plugins { + kotlin("jvm") version "2.3.0" + id("com.diffplug.spotless") version "7.2.1" // Last version supporting JDK 11. 
+ id("org.jlleitschuh.gradle.ktlint") version "14.0.1" +} + +allprojects { + apply(plugin = "com.diffplug.spotless") + apply(plugin = "org.jetbrains.kotlin.jvm") + apply(plugin = "org.jlleitschuh.gradle.ktlint") + + repositories { mavenCentral() } + + java { + toolchain { + languageVersion.set(JavaLanguageVersion.of(11)) + } + sourceCompatibility = JavaVersion.VERSION_11 + } + kotlin { compilerOptions { jvmTarget = JvmTarget.JVM_11 } } + + configure { + java { + target("**/*.java") + googleJavaFormat().formatJavadoc(false) + trimTrailingWhitespace() + endWithNewline() + } + } +} diff --git a/java-sdk/dags/stub_dag.py b/java-sdk/dags/stub_dag.py new file mode 100644 index 0000000000000..bd606e58e5976 --- /dev/null +++ b/java-sdk/dags/stub_dag.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.sdk import dag, task + + +@task() +def python_task_1(ti): + print("python_task_1") + print("Push Python Task 'python_task_1' XCom:") + ti.xcom_push(value="value-pushed-from-python_task_1", key="return_value") + + +@task.stub(sdk="java") +def extract(): ... + + +@task.stub(sdk="java") +def transform(): ... 
+ + +@task() +def python_task_2(ti): + print("python_task_2") + print("Pull Java Task 'transform' XCom:") + print(ti.xcom_pull(task_ids="transform")) + + +@dag(dag_id="java_example") +def simple_dag(): + + python_task_1() >> extract() >> transform() >> python_task_2() + + +simple_dag() diff --git a/java-sdk/example/build.gradle.kts b/java-sdk/example/build.gradle.kts new file mode 100644 index 0000000000000..d1565a53bae58 --- /dev/null +++ b/java-sdk/example/build.gradle.kts @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +plugins { + application +} + +dependencies { + annotationProcessor(project(":sdk")) + implementation(project(":sdk")) + implementation("org.slf4j:slf4j-simple:2.0.17") +} + +sourceSets { + main { + java.srcDir("src/java") + } +} + +application { + mainClass = "org.apache.airflow.example.ExampleBundleBuilder" +} + +val bundleMainClass = application.mainClass.get() +val metadataFileName = "airflow-metadata.yaml" +val metadataOutputDir = layout.buildDirectory.dir("airflow-metadata") +val dagCodeSourcePath = bundleMainClass.replace('.', '/') + ".java" +val dagCodeFileName = bundleMainClass.substringAfterLast('.') + ".java" + +val inspectBundle = + tasks.register("inspectBundle") { + description = "Collect Dag structures by inspecting the Dag bundle" + dependsOn("classes") + classpath = sourceSets.main.get().runtimeClasspath + mainClass.set("org.apache.airflow.sdk.BundleInspector") + args = + listOf( + bundleMainClass, + metadataOutputDir + .get() + .file(metadataFileName) + .asFile.absolutePath, + ) + } + +tasks.withType { + dependsOn(inspectBundle) + from(metadataOutputDir) + from("src/java/$dagCodeSourcePath") + manifest { + attributes( + "Main-Class" to bundleMainClass, + "Airflow-Java-SDK-Version" to project.version, + "Airflow-Java-SDK-Metadata" to metadataFileName, + "Airflow-Java-SDK-Dag-Code" to dagCodeFileName, + "Implementation-Title" to "Example Java bundle", + "Implementation-Version" to "1", + ) + } +} diff --git a/java-sdk/example/src/java/org/apache/airflow/example/AnnotationExample.java b/java-sdk/example/src/java/org/apache/airflow/example/AnnotationExample.java new file mode 100644 index 0000000000000..4068f36ef0253 --- /dev/null +++ b/java-sdk/example/src/java/org/apache/airflow/example/AnnotationExample.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.example; + +import org.apache.airflow.sdk.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Date; + +@Builder.Dag(id = "java_annotation_example") +public class AnnotationExample { + private static final Logger logger = LoggerFactory.getLogger(AnnotationExample.class); + + @Builder.Task(id = "extract") + public long extractValue(Client client) throws InterruptedException { + logger.info("Hello from task"); + + var pythonXcom = client.getXCom("python_task_1"); + logger.info("Got XCom from Python Task 'python_task_1' {}", pythonXcom); + + var connection = client.getConnection("test_http"); + logger.info("Got con {}", connection); + + for (var i = 0; i < 3; i++) { + logger.info("Beep {}, next time will be {}", i, new Date()); + Thread.sleep(2 * 1000); + } + + logger.info("Goodbye from task"); + return new Date().getTime(); + } + + @Builder.Task(id = "transform", depends = {"extract"}) + public long transformValue(Client client, @Builder.XCom(task = "extract") long extracted) { + logger.info("Got XCom from 'extract' {}", extracted); + + var variable = client.getVariable("my_variable"); + logger.info("Got variable {}", variable); + + logger.info("Push XCom to python task 2"); + return new Date().getTime(); + } + + 
@Builder.Task(depends = {"transform"}) + public void load(@Builder.XCom(task = "transform") long transformed) { + logger.info("Got XCom from 'transform' {}", transformed); + throw new RuntimeException("I failed"); + } +} diff --git a/java-sdk/example/src/java/org/apache/airflow/example/ExampleBundleBuilder.java b/java-sdk/example/src/java/org/apache/airflow/example/ExampleBundleBuilder.java new file mode 100644 index 0000000000000..0aa729d00030e --- /dev/null +++ b/java-sdk/example/src/java/org/apache/airflow/example/ExampleBundleBuilder.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.example; + +import org.apache.airflow.sdk.*; +import org.jetbrains.annotations.NotNull; +import java.util.List; + +public class ExampleBundleBuilder implements BundleBuilder { + @NotNull + @Override + public Iterable getDags() { + return List.of(InterfaceExampleBuilder.build(), AnnotationExampleBuilder.build()); + } + + public static void main(String[] args) { + var bundle = new ExampleBundleBuilder().build(); + Server.create(args).serve(bundle); + } +} diff --git a/java-sdk/example/src/java/org/apache/airflow/example/InterfaceExampleBuilder.java b/java-sdk/example/src/java/org/apache/airflow/example/InterfaceExampleBuilder.java new file mode 100644 index 0000000000000..2599d90cc4317 --- /dev/null +++ b/java-sdk/example/src/java/org/apache/airflow/example/InterfaceExampleBuilder.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.example; + +import java.util.Date; +import java.util.List; +import org.apache.airflow.sdk.*; +import org.jetbrains.annotations.NotNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class InterfaceExampleBuilder { + private static final Logger logger = LoggerFactory.getLogger(InterfaceExampleBuilder.class); + + public static class Extract implements Task { + public void execute(@NotNull Context context, Client client) throws Exception { + logger.info("Hello from task"); + + var pythonInput = client.getXCom("python_task_1"); + logger.info("Got XCom from Python Task 'python_task_1' {}", pythonInput); + + var connection = client.getConnection("test_http"); + logger.info("Got con {}", connection); + + for (var i = 0; i < 3; i++) { + logger.info("Beep {}, next time will be {}", i, new Date()); + Thread.sleep(2 * 1000); + } + + client.setXCom(new Date().getTime()); + logger.info("Goodbye from task"); + } + } + + public static class Transform implements Task { + public void execute(@NotNull Context context, Client client) { + var extracted = client.getXCom("extract"); + logger.info("Got XCom from 'extract' {}", extracted); + + var variable = client.getVariable("my_variable"); + logger.info("Got variable {}", variable); + + logger.info("Push XCom to python task 2"); + client.setXCom(new Date().getTime()); + } + } + + public static class Load implements Task { + public void execute(@NotNull Context context, Client client) { + var transformed = client.getXCom("transform"); + logger.info("Got XCom from 'transform' {}", transformed); + throw new RuntimeException("I failed"); + } + } + + public static Dag build() { + var dag = new Dag("java_example"); + dag.addTask("extract", Extract.class, List.of()); + dag.addTask("transform", Transform.class, List.of("extract")); + dag.addTask("load", Load.class, List.of("transform")); + return dag; + } +} diff --git a/java-sdk/gradle.properties b/java-sdk/gradle.properties new file 
mode 100644 index 0000000000000..7ec9aa8974afa --- /dev/null +++ b/java-sdk/gradle.properties @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was generated by the Gradle 'init' task. +# https://docs.gradle.org/current/userguide/build_environment.html#sec:gradle_configuration_properties + +org.gradle.configuration-cache=true + +airflowExecApiVersion=2025-11-05 diff --git a/java-sdk/gradle/libs.versions.toml b/java-sdk/gradle/libs.versions.toml new file mode 100644 index 0000000000000..a0ac505d527fe --- /dev/null +++ b/java-sdk/gradle/libs.versions.toml @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file was generated by the Gradle 'init' task. +# https://docs.gradle.org/current/userguide/platforms.html#sub::toml-dependencies-format diff --git a/java-sdk/gradle/wrapper/gradle-wrapper.jar b/java-sdk/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000..f8e1ee3125fe0 Binary files /dev/null and b/java-sdk/gradle/wrapper/gradle-wrapper.jar differ diff --git a/java-sdk/gradle/wrapper/gradle-wrapper.properties b/java-sdk/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000000..bc660c8e3d572 --- /dev/null +++ b/java-sdk/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +#Mon Jan 19 21:06:09 CST 2026 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/java-sdk/gradlew b/java-sdk/gradlew new file mode 100755 index 0000000000000..adff685a0348c --- /dev/null +++ b/java-sdk/gradlew @@ -0,0 +1,248 @@ +#!/bin/sh + +# +# Copyright © 2015 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». 
+# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s\n' "$PWD" ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. 
+MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. 
+ # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. 
+ +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -jar "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/java-sdk/gradlew.bat b/java-sdk/gradlew.bat new file mode 100644 index 0000000000000..c4bdd3ab8e3cc --- /dev/null +++ b/java-sdk/gradlew.bat @@ -0,0 +1,93 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. 
+@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 
1>&2 + +goto fail + +:execute +@rem Setup the command line + + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/java-sdk/sdk/build.gradle.kts b/java-sdk/sdk/build.gradle.kts new file mode 100644 index 0000000000000..b494a72ce1b22 --- /dev/null +++ b/java-sdk/sdk/build.gradle.kts @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import org.jetbrains.kotlin.gradle.tasks.KotlinCompile +import org.jlleitschuh.gradle.ktlint.tasks.KtLintCheckTask +import java.time.ZonedDateTime + +buildscript { + repositories { + mavenCentral() + } +} + +val airflowExecApiVersion: String by project + +plugins { + kotlin("plugin.serialization") version "2.3.0" + id("org.openapi.generator") version "7.19.0" +} + +val constantsDir = layout.buildDirectory.dir("generate-constants/main/src/main/kotlin") + +dependencies { + compileOnly("com.github.spotbugs:spotbugs-annotations:4.9.8") + compileOnly("javax.annotation:javax.annotation-api:1.3.2") + compileOnly("org.apache.oltu.oauth2:org.apache.oltu.oauth2.client:1.0.1") + + implementation("com.fasterxml.jackson.core:jackson-annotations:2.21") + implementation("com.fasterxml.jackson.core:jackson-core:2.21.1") + implementation("com.fasterxml.jackson.core:jackson-databind:2.21.0") + implementation("com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.21.0") + implementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.21.0") + implementation("com.squareup:javapoet:1.13.0") + implementation("com.squareup.retrofit2:converter-jackson:3.0.0") + implementation("com.squareup.retrofit2:converter-scalars:3.0.0") + implementation("com.squareup.retrofit2:retrofit:3.0.0") + implementation("com.xenomachina:kotlin-argparser:2.0.7") + implementation("io.ktor:ktor-network:3.3.3") + implementation("javax.ws.rs:javax.ws.rs-api:2.0") + implementation("org.jetbrains.kotlinx:kotlinx-coroutines-core:1.10.2") + implementation("org.jetbrains.kotlinx:kotlinx-datetime:0.7.1") + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.10.0") + implementation("org.msgpack:msgpack-core:0.9.11") + implementation("org.msgpack:jackson-dataformat-msgpack:0.9.11") + + testImplementation(kotlin("test")) + testImplementation("com.google.testing.compile:compile-testing:0.23.0") + testImplementation("com.squareup.okhttp3:mockwebserver:4.12.0") +} + +openApiGenerate { + 
generatorName = "java" + library = "retrofit2" + + remoteInputSpec = "https://airflow.apache.org/schemas/execution-api/$airflowExecApiVersion.json" + apiPackage = "org.apache.airflow.sdk.execution.api.route" + modelPackage = "org.apache.airflow.sdk.execution.api.model" + invokerPackage = "org.apache.airflow.sdk.execution.api.client" + + generateApiDocumentation = false + generateApiTests = false + generateModelDocumentation = false + generateModelTests = false + + // The spec on arbitrary mapping (e.g. 'extra') causes the OpenAPI generator to output JsonValue. + // We should probably fix the spec instead, but this should work before that. + // Suggested fix: + // type: object + // additionalProperties: true + schemaMappings.put("JsonValue", "java.lang.Object") + + additionalProperties = + mapOf( + "dateLibrary" to "java8", + "openApiNullable" to false, + "serializationLibrary" to "jackson", + "withXml" to false, + ) +} + +sourceSets { + main { + java.srcDir(layout.buildDirectory.dir("generate-resources/main/src/main/java")) + kotlin.srcDir(constantsDir) + } +} + +abstract class GenerateConstantsTask : DefaultTask() { + @get:Input + abstract val airflowExecApiVersionProp: Property + + @get:OutputDirectory + abstract val outputDirProp: DirectoryProperty + + @TaskAction + fun generate() { + val dir = outputDirProp.get().asFile.resolve("org/apache/airflow/sdk/execution") + dir.mkdirs() + dir.resolve("BuildConstants.kt").writeText( + """ + // File generated at ${ZonedDateTime.now()} + package org.apache.airflow.sdk.execution + + const val AIRFLOW_EXEC_API_VERSION = "${airflowExecApiVersionProp.get()}" + """.trimIndent() + "\n", + ) + } +} + +tasks.register("generateConstants") { + description = "Generate constants to use in code from build configurations." 
+ airflowExecApiVersionProp = airflowExecApiVersion + outputDirProp = constantsDir +} + +tasks.named("compileJava") { + dependsOn("openApiGenerate") +} + +tasks.named("compileKotlin") { + dependsOn("openApiGenerate") + dependsOn("generateConstants") +} + +tasks.named("runKtlintCheckOverMainSourceSet") { + dependsOn("openApiGenerate") + dependsOn("generateConstants") +} + +tasks.named("test") { + useJUnitPlatform() +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Builder.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Builder.kt new file mode 100644 index 0000000000000..ff278e9ff569c --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Builder.kt @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +@file:Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") + +package org.apache.airflow.sdk + +import com.squareup.javapoet.ClassName +import com.squareup.javapoet.JavaFile +import com.squareup.javapoet.MethodSpec +import com.squareup.javapoet.TypeName +import com.squareup.javapoet.TypeSpec +import javax.annotation.processing.AbstractProcessor +import javax.annotation.processing.ProcessingEnvironment +import javax.annotation.processing.RoundEnvironment +import javax.annotation.processing.SupportedAnnotationTypes +import javax.annotation.processing.SupportedSourceVersion +import javax.lang.model.SourceVersion +import javax.lang.model.element.ExecutableElement +import javax.lang.model.element.Modifier +import javax.lang.model.element.TypeElement +import javax.lang.model.type.TypeKind +import javax.lang.model.type.TypeMirror +import javax.tools.Diagnostic + +class Builder internal constructor() { + /** + * Annotation to automate a Dag-builder pattern. + * + * When applied on a class Foo, this generates a FooBuilder class with a static build method + * to create the Dag structure automatically. + * + * @param id Override the Dag ID. If empty or not provided, the annotated class's name is used by default. + * @param to Name of the Dag-builder class. If empty or not provided, use the annotated class name + "Builder". + */ + @Target(AnnotationTarget.CLASS) + @MustBeDocumented + annotation class Dag( + val id: String = "", + val to: String = "", + ) + + /** + * Annotation to automate task definition in a Dag-builder pattern. + * + * @param id Override the task ID. If empty or not provided, the annotated function's name is used by default. + * @param depends List of task IDs this task depends on. + */ + @Target(AnnotationTarget.FUNCTION) + @MustBeDocumented + annotation class Task( + val id: String = "", + val depends: Array = [], + ) + + /** + * Annotation to mark a task definition's method parameter as an XCom input. + * + * @param task The task ID to pull. 
If empty or not given, the annotated parameter's name is used by default. + * @param key The XCom key to pull. Defaults to the task's return value. + */ + @Target(AnnotationTarget.VALUE_PARAMETER) + @MustBeDocumented + annotation class XCom( + val task: String = "", + val key: String = Client.XCOM_RETURN_KEY, + ) +} + +@SupportedAnnotationTypes("org.apache.airflow.sdk.Builder.Dag") +@SupportedSourceVersion(SourceVersion.RELEASE_11) +class BuilderProcessor : AbstractProcessor() { + override fun process( + annotations: Set, + roundEnv: RoundEnvironment, + ): Boolean { + if (annotations.isEmpty()) return false + roundEnv.getElementsAnnotatedWith(Builder.Dag::class.java).filterIsInstance().forEach { el -> + with(processingEnv) { + runCatching { + JavaFile + .builder( + elementUtils.getPackageOf(el).qualifiedName.toString(), + buildDag(el), + ).build() + .writeTo(filer) + }.onFailure { e -> + messager.printMessage( + Diagnostic.Kind.ERROR, + e.message ?: "Unknown error", + el, + ) + } + } + } + return true + } + + private fun buildDag(el: TypeElement): TypeSpec { + val ann = el.getAnnotation(Builder.Dag::class.java)!! 
+ + val builderClass = + TypeSpec + .classBuilder(ann.to.ifBlank { "${el.simpleName}Builder" }) + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + + val buildMethod = + MethodSpec + .methodBuilder("build") + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .returns(ClassName.get(Dag::class.java)) + .addStatement($$"var dag = new $T($S)", ClassName.get(Dag::class.java), ann.id.ifBlank { el.simpleName }) + + for (inner in el.enclosedElements) { + if (inner !is ExecutableElement) continue + if (inner.isVarArgs) throw IllegalArgumentException("Cannot create task from vararg function ${inner.simpleName}") + + val ann = inner.getAnnotation(Builder.Task::class.java) ?: continue + val innerName = inner.simpleName.toString().replaceFirstChar(Char::uppercase) + + val task = buildTask(innerName, inner, el) + builderClass.addType(task.spec) + + val depends = + task.required + .map { it.taskId } + .plus(ann.depends) + .toTypedArray() + buildMethod.addStatement( + if (depends.isEmpty()) { + $$"dag.addTask($S, $L.class)" + } else { + $$"dag.addTask($S, $L.class, new String[]{$${depends.joinToString { $$"$S" }}})" + }, + ann.id.ifBlank { inner.simpleName }, + innerName, + *depends, + ) + } + + buildMethod.addStatement("return dag") + builderClass.addMethod(buildMethod.build()) + return builderClass.build() + } + + private fun buildTask( + name: String, + inner: ExecutableElement, + parent: TypeElement, + ): BuildTaskResult { + val clientType = ClassName.get(Client::class.java) + val contextType = ClassName.get(Context::class.java) + + val executeSpec = + MethodSpec + .methodBuilder("execute") + .addAnnotation(Override::class.java) + .addModifiers(Modifier.PUBLIC) + .returns(TypeName.VOID) + .addParameter(contextType, "context") + .addParameter(clientType, "client") + .addException(Exception::class.java) + + val required = mutableListOf() + val innerArgs = + with(processingEnv) { + inner.parameters.joinToString { param -> + val anno = param.getAnnotation(Builder.XCom::class.java) 
+ val type = param.asType() + when { + anno != null -> + param.simpleName.toString().also { + required += RequiredXCom(type, it, anno.task.ifBlank { it }) + } + isType(type, clientType) -> "client" + isType(type, contextType) -> "context" + else -> throw IllegalArgumentException("Unsupported task parameter '${param.simpleName}' with type: $type") + } + } + } + required.forEach { + executeSpec.addStatement( + $$"var $L = ($T) client.getXCom($S)", + it.paramName, + with(TypeName.get(it.paramType)) { if (isPrimitive) box() else this }, + it.taskId, + ) + } + if (inner.returnType.kind == TypeKind.VOID) { + $$"new $T().$L($L)" + } else { + $$"client.setXCom(new $T().$L($L))" + }.also { + executeSpec.addStatement( + it, + ClassName.get(parent), + inner.simpleName, + innerArgs, + ) + } + + val spec = + TypeSpec + .classBuilder(name) + .addSuperinterface(Task::class.java) + .addModifiers(Modifier.PUBLIC, Modifier.FINAL, Modifier.STATIC) + .addMethod(executeSpec.build()) + .build() + return BuildTaskResult(spec, required) + } +} + +private fun ProcessingEnvironment.isType( + t: TypeMirror, + c: ClassName, +): Boolean = typeUtils.isSameType(t, elementUtils.getTypeElement(c.canonicalName()).asType()) + +private data class RequiredXCom( + val paramType: TypeMirror, + val paramName: String, + val taskId: String, +) + +private data class BuildTaskResult( + val spec: TypeSpec, + val required: List, +) diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Bundle.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Bundle.kt new file mode 100644 index 0000000000000..edf962fd14578 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Bundle.kt @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk + +class Bundle( + val version: String, + dags: Iterable, +) { + val dags: Map = dags.associateByDagId() +} + +private fun Iterable.associateByDagId(): Map { + val dagMap = linkedMapOf() + for (dag in this) { + require(dagMap.putIfAbsent(dag.id, dag) == null) { + "Dags in bundle have duplicate ID: ${dag.id}" + } + } + return dagMap +} + +/** + * Interface for declaring DAGs in a bundle. + * + *

Implement this interface in the class specified as {@code Main-Class} in your JAR manifest. + * The build system instantiates this class at compile time to extract dag_ids and task_ids + * into the JAR manifest, enabling inspection of bundled DAGs without running the full process. + */ +interface BundleBuilder { + fun getDags(): Iterable + + fun build(): Bundle = Bundle(this::class.java.`package`.implementationVersion ?: "0", getDags()) +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/BundleInspector.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/BundleInspector.kt new file mode 100644 index 0000000000000..ae180c0d91e91 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/BundleInspector.kt @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk + +import java.io.File + +/** + * Build-time utility that inspects a [BundleBuilder] implementation and writes + * dag_ids and task_ids to a YAML metadata file for inclusion in the JAR. 
+ * + * Usage: {@code java -cp org.apache.airflow.sdk.BundleInspector } + */ +object BundleInspector { + @JvmStatic + fun main(args: Array) { + require(args.size == 2) { "Usage: BundleInspector " } + val className = args[0] + val outputPath = args[1] + + val clazz = Class.forName(className) + val instance = + clazz.getDeclaredConstructor().newInstance() as? BundleBuilder + ?: error("$className does not implement ${BundleBuilder::class.qualifiedName}") + val dags = instance.getDags() + + val outputFile = File(outputPath) + outputFile.parentFile.mkdirs() + outputFile.writeText(toYaml(dags)) + } + + private fun toYaml(dags: Iterable): String = + buildString { + appendLine("dags:") + for (dag in dags) { + appendLine(" ${dag.id}:") + appendLine(" tasks:") + for (id in dag.tasks.keys) { + appendLine(" - $id") + } + } + } +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/BundleScanner.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/BundleScanner.kt new file mode 100644 index 0000000000000..e73698bd9f6cf --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/BundleScanner.kt @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk + +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory +import org.apache.airflow.sdk.execution.containsJars +import org.apache.airflow.sdk.execution.isJarFile +import org.apache.airflow.sdk.execution.jarFiles +import java.io.File +import java.nio.file.Files +import java.nio.file.Path +import java.util.jar.JarFile + +const val METADATA_MANIFEST_KEY = "Airflow-Java-SDK-Metadata" + +private val yamlMapper = ObjectMapper(YAMLFactory()) + +/** + * A fully resolved bundle: everything needed to start the bundle process. + */ +data class ResolvedBundle( + val mainClass: String, + val classpath: String, +) + +/** + * Scans [bundlesDir] for Java DAG bundles by checking JAR manifests for the + * [METADATA_MANIFEST_KEY] attribute and reading the referenced YAML metadata. + * + * Supports two layouts: + * - **Nested**: each immediate subdirectory of [bundlesDir] is a bundle home. + * - **Flat**: [bundlesDir] itself contains the bundle JARs. + * + * Returns a mapping from dag_id to a [ResolvedBundle] with mainClass and classpath. + */ +fun scanBundles(bundlesDir: Path): Map { + if (!Files.isDirectory(bundlesDir)) return emptyMap() + val result = mutableMapOf() + + // Check each immediate subdirectory as a potential bundle home. + Files.list(bundlesDir).use { paths -> + paths.filter { Files.isDirectory(it) }.forEach { candidate -> + collectBundleDags(candidate, result) + } + } + + // Also check bundlesDir itself (flat layout). + collectBundleDags(bundlesDir, result) + + return result +} + +private fun collectBundleDags( + candidate: Path, + result: MutableMap, +) { + val bundleHome = normalizeBundleHome(candidate) + val resolved = resolveBundle(bundleHome) ?: return + for (dagId in resolved.first) { + result.putIfAbsent(dagId, resolved.second) + } +} + +/** + * Inspects JARs in [bundleHome] for [METADATA_MANIFEST_KEY] and Main-Class. 
+ * Returns (dagIds, ResolvedBundle) or null if no JAR carries the metadata attribute. + */ +private fun resolveBundle(bundleHome: Path): Pair, ResolvedBundle>? { + val jars = jarFiles(bundleHome) + if (jars.isEmpty()) return null + + for (jarPath in jars) { + JarFile(jarPath.toFile()).use { jar -> + val attrs = jar.manifest?.mainAttributes ?: return@use + val metadataFile = attrs.getValue(METADATA_MANIFEST_KEY) ?: return@use + val mainClass = attrs.getValue("Main-Class") ?: return@use + val entry = jar.getJarEntry(metadataFile) ?: return@use + val content = jar.getInputStream(entry).bufferedReader().readText() + val dagIds = parseDagIdsFromYaml(content) + if (dagIds.isEmpty()) return@use + + val classpath = + jars + .map { it.toAbsolutePath().normalize().toString() } + .joinToString(File.pathSeparator) + + return dagIds to ResolvedBundle(mainClass, classpath) + } + } + return null +} + +fun readBundleDagIds(bundleHome: Path): Set { + for (jarPath in jarFiles(bundleHome)) { + JarFile(jarPath.toFile()).use { jar -> + val metadataFile = jar.manifest?.mainAttributes?.getValue(METADATA_MANIFEST_KEY) ?: return@use + val entry = jar.getJarEntry(metadataFile) ?: return@use + val content = jar.getInputStream(entry).bufferedReader().readText() + return parseDagIdsFromYaml(content) + } + } + return emptySet() +} + +fun parseDagIdsFromYaml(yaml: String): Set { + val root = yamlMapper.readTree(yaml) + val dagsNode = root.get("dags") ?: return emptySet() + val dagIds = mutableSetOf() + dagsNode.fieldNames().forEachRemaining { dagIds.add(it) } + return dagIds +} + +private fun normalizeBundleHome(path: Path): Path { + val normalized = path.toAbsolutePath().normalize() + if (normalized.isJarFile()) return normalized.parent + val lib = normalized.resolve("lib") + return if (containsJars(lib)) lib else normalized +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Client.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Client.kt new file mode 100644 index 
0000000000000..d6507b8345d55 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Client.kt @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk + +import org.apache.airflow.sdk.execution.Client +import org.apache.airflow.sdk.execution.StartupDetails + +class Client( + val details: StartupDetails, + val impl: Client, +) { + companion object { + const val XCOM_RETURN_KEY = "return_value" + } + + fun getConnection(id: String): Connection = + with(impl.getConnection(id)) { + Connection( + id = connId, + type = connType, + host = host, + schema = schema, + login = login, + password = password, + port = port, + extra = extra, + ) + } + + fun getVariable(key: String): Any? = impl.getVariable(key).value + + @JvmOverloads fun getXCom( + key: String = XCOM_RETURN_KEY, + dagId: String = details.ti.dagId, + taskId: String, + runId: String = details.ti.runId, + mapIndex: Int? = null, + includePriorDates: Boolean = false, + ): Any? 
= + impl + .getXCom( + key = key, + dagId = dagId, + taskId = taskId, + runId = runId, + mapIndex = mapIndex, + includePriorDates = includePriorDates, + ).value + + @JvmOverloads fun setXCom( + key: String = XCOM_RETURN_KEY, + value: Any, + ) = impl.setXCom( + key = key, + value = value, + dagId = details.ti.dagId, + taskId = details.ti.taskId, + runId = details.ti.runId, + mapIndex = details.ti.mapIndex ?: -1, + ) +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Config.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Config.kt new file mode 100644 index 0000000000000..ad2c48dfd41b6 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Config.kt @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk + +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory +import java.nio.file.Files +import java.nio.file.Path +import java.nio.file.Paths + +private const val CONFIG_FILE_NAME = "java-sdk.yaml" + +open class WorkerError( + message: String, +) : IllegalStateException(message) + +class NoBody : WorkerError("No body") + +/** + * SDK configuration resolved from environment variables and an optional YAML config file. + * + * Resolution order (highest priority first): + * 1. Environment variable `AIRFLOW__

__` (uppercase, double-underscore delimited) + * 2. YAML config file value at `
.` (lowercase) + * 3. Default value (where applicable) + * + * Only the canonical `AIRFLOW__
__` env var form is recognised. + * Single-underscore variants (`AIRFLOW_SECTION_KEY`) are **not** supported — use the + * YAML file for a more readable alternative. + * + * The YAML file is loaded from `$AIRFLOW_HOME/java-sdk.yaml` when present. + * + * ```yaml + * core: + * execution_api_server_url: "http://localhost:8080/execution/" + * + * sdk: + * bundles_dir: "./bin" + * + * api_auth: + * jwt_secret: "your-secret-key" + * jwt_issuer: "airflow" + * jwt_expiration_time: 30 + * ``` + * + * Each YAML key corresponds directly to the env-var option name: + * `core.execution_api_server_url` ↔ `AIRFLOW__CORE__EXECUTION_API_SERVER_URL`. + */ +class SdkConfig( + private val env: Map = System.getenv(), + yamlOverride: Path? = null, +) { + @Suppress("UNCHECKED_CAST") + private val yaml: Map> = + run { + val path = yamlOverride ?: resolveConfigPath(env) + if (path != null && Files.isRegularFile(path)) { + val raw = ObjectMapper(YAMLFactory()).readValue(path.toFile(), Map::class.java) as? Map ?: emptyMap() + raw.entries.associate { (k, v) -> + k to ((v as? Map<*, *>)?.entries?.associate { (ik, iv) -> ik.toString() to iv } ?: emptyMap()) + } + } else { + emptyMap() + } + } + + /** + * Look up a config value by section and key. + * Checks `AIRFLOW__
__` env var first, then YAML `
.`. + */ + fun get( + section: String, + key: String, + ): String? { + val envKey = "AIRFLOW__${section.uppercase()}__${key.uppercase()}" + env[envKey]?.takeIf { it.isNotBlank() }?.let { return it } + return yaml[section]?.get(key)?.toString()?.takeIf { it.isNotBlank() } + } + + /** Like [get] but throws [WorkerError] when the value is missing. */ + fun require( + section: String, + key: String, + ): String = + get(section, key) + ?: throw WorkerError( + "$section.$key must be configured " + + "(AIRFLOW__${section.uppercase()}__${key.uppercase()} or $CONFIG_FILE_NAME)", + ) + + /** Resolve a positive long, falling back to [default]. */ + fun getPositiveLong( + section: String, + key: String, + default: Long, + ): Long { + val raw = get(section, key) ?: return default + val parsed = + raw.toLongOrNull() + ?: throw WorkerError("$section.$key must be an integer") + if (parsed <= 0) throw WorkerError("$section.$key must be greater than 0") + return parsed + } + + // -- Execution API -- + + val executionApiUrl: String + get() { + val url = + get("core", "execution_api_server_url") + ?: get("execution", "api_url") + return url?.ensureTrailingSlash() + ?: throw WorkerError( + "core.execution_api_server_url must be configured " + + "(AIRFLOW__CORE__EXECUTION_API_SERVER_URL or $CONFIG_FILE_NAME)", + ) + } + + // -- JWT -- + + val jwtSecret: String get() = require("api_auth", "jwt_secret") + val jwtIssuer: String? get() = get("api_auth", "jwt_issuer") + val jwtExpirationTime: Long get() = getPositiveLong("api_auth", "jwt_expiration_time", 30) + + // -- Bundle resolution -- + + val bundlesDir: Path? + get() = get("sdk", "bundles_dir")?.let(Paths::get) + + companion object { + private fun resolveConfigPath(env: Map): Path? 
{ + val home = env["AIRFLOW_HOME"]?.takeIf { it.isNotBlank() } ?: return null + return Paths.get(home, CONFIG_FILE_NAME) + } + } +} + +internal fun String.ensureTrailingSlash() = if (endsWith('/')) this else "$this/" diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Connection.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Connection.kt new file mode 100644 index 0000000000000..16f3b72d0e198 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Connection.kt @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk + +data class Connection( + val id: String, + val type: String?, + val host: String?, + val schema: String?, + val login: String?, + val password: String?, + val port: Int?, + val extra: String?, +) diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Context.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Context.kt new file mode 100644 index 0000000000000..0ab48f09b86c2 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Context.kt @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk + +import org.apache.airflow.sdk.execution.StartupDetails + +data class DagRun( + @JvmField val dagId: String, + @JvmField val runId: String, +) + +data class TaskInstance( + @JvmField val dagId: String, + @JvmField val runId: String, + @JvmField val taskId: String, + @JvmField val mapIndex: Int?, + @JvmField val tryNumber: Int, +) + +data class Context( + @JvmField val dagRun: DagRun, + @JvmField val ti: TaskInstance, +) { + internal companion object { + fun from(request: StartupDetails): Context = + Context( + dagRun = with(request.tiContext.dagRun) { DagRun(dagId, runId) }, + ti = with(request.ti) { TaskInstance(dagId, runId, taskId, mapIndex, tryNumber) }, + ) + } +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Dag.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Dag.kt new file mode 100644 index 0000000000000..dd1c878f13281 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Dag.kt @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk + +/** + * Collection of tasks with directional dependencies. + * + * @param id The Dag's id. This must consist exclusively of alphanumeric characters, + * dashes, dots and underscores (all ASCII). + */ +class Dag( + // TODO: charset check? + val id: String, +) { + internal var tasks = mutableMapOf>() + internal var dependants = mutableMapOf>() + + @JvmOverloads + fun addTask( + id: String, + definition: Class, + dependsOn: Iterable = emptyList(), + ) { + // TODO: Check duplicate key. + tasks[id] = definition + for (parent in dependsOn) { + dependants.getOrPut(parent) { mutableSetOf() }.add(id) + } + } + + fun addTask( + id: String, + definition: Class, + dependsOn: Array, + ) = addTask(id, definition, dependsOn.toSet()) +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Server.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Server.kt new file mode 100644 index 0000000000000..741442a7f1a4c --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Server.kt @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk + +import com.xenomachina.argparser.ArgParser +import io.ktor.network.selector.SelectorManager +import io.ktor.network.sockets.InetSocketAddress +import io.ktor.network.sockets.aSocket +import io.ktor.network.sockets.openReadChannel +import io.ktor.network.sockets.openWriteChannel +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import org.apache.airflow.sdk.execution.CoordinatorComm +import org.apache.airflow.sdk.execution.LogSender +import org.apache.airflow.sdk.execution.Logger +import kotlin.text.substringAfterLast +import kotlin.text.substringBeforeLast + +private class Args( + parser: ArgParser, +) { + private fun parseAddress(address: String): InetSocketAddress = + InetSocketAddress( + address.substringBeforeLast(':'), + address.substringAfterLast(':').toInt(), + ) + + val comm by parser.storing("--comm", help = "Address (host:port) to communicate with parent") { + parseAddress(this) + } + val logs by parser.storing("--logs", help = "Address (host:port) to send Airflow logs to") { + parseAddress(this) + } +} + +class ApiError( + message: String, +) : IllegalStateException(message) + +class Server( + private val comm: InetSocketAddress, + private val logs: InetSocketAddress, +) { + companion object { + 
@JvmStatic + fun create(args: Array): Server { + val args = ArgParser(args).parseInto(::Args) + return Server(args.comm, args.logs) + } + } + + private val logger = Logger(Server::class) + + fun serve(bundle: Bundle) { + runBlocking { + launch { + awaitAll( + async { + aSocket(SelectorManager(Dispatchers.IO)).tcp().connect(comm).use { socket -> + logger.debug("Connected comm", mapOf("addr" to comm)) + CoordinatorComm( + bundle, + socket.openReadChannel(), + socket.openWriteChannel(autoFlush = true), + ).startProcessing() + } + }, + async { + aSocket(SelectorManager(Dispatchers.IO)).tcp().connect(logs).use { socket -> + logger.debug("Connected logs", mapOf("addr" to logs)) + LogSender.configure(socket.openWriteChannel(autoFlush = true)) + } + }, + ) + } + } + } +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Task.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Task.kt new file mode 100644 index 0000000000000..b24ad75668eab --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/Task.kt @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk + +import kotlin.Throws + +interface Task { + @Throws(Exception::class) + fun execute( + context: Context, + client: Client, + ) +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Client.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Client.kt new file mode 100644 index 0000000000000..f22af53868890 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Client.kt @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk.execution + +import kotlinx.coroutines.runBlocking +import org.apache.airflow.sdk.execution.api.client.ApiClient +import org.apache.airflow.sdk.execution.api.model.ConnectionResponse +import org.apache.airflow.sdk.execution.api.model.VariableResponse +import org.apache.airflow.sdk.execution.api.model.XComResponse +import org.apache.airflow.sdk.execution.api.route.ConnectionsApi +import org.apache.airflow.sdk.execution.api.route.VariablesApi +import org.apache.airflow.sdk.execution.api.route.XComsApi +import java.time.LocalDate + +interface Client { + fun getConnection(id: String): ConnectionResponse + + fun getVariable(key: String): VariableResponse + + fun getXCom( + key: String, + dagId: String, + taskId: String, + runId: String, + mapIndex: Int? = null, + includePriorDates: Boolean = false, + ): XComResponse + + fun setXCom( + key: String, + value: Any, + dagId: String, + taskId: String, + runId: String, + mapIndex: Int, + ) +} + +class CoordinatorClient( + val exec: CoordinatorComm, +) : Client { + override fun getConnection(id: String) = runBlocking { exec.communicate(GetConnection(id)) } + + override fun getVariable(key: String) = runBlocking { exec.communicate(GetVariable(key)) } + + override fun setXCom( + key: String, + value: Any, + dagId: String, + taskId: String, + runId: String, + mapIndex: Int, + ) { + val message = + SetXCom( + key = key, + value = value, + dagId = dagId, + taskId = taskId, + runId = runId, + mapIndex = mapIndex, + ) + runBlocking { exec.communicate(message) } + } + + override fun getXCom( + key: String, + dagId: String, + taskId: String, + runId: String, + mapIndex: Int?, + includePriorDates: Boolean, + ): XComResponse { + val message = + GetXCom( + key = key, + dagId = dagId, + taskId = taskId, + runId = runId, + mapIndex = mapIndex, + includePriorDates = includePriorDates, + ) + return runBlocking { exec.communicate(message) } + } +} + +class HttpExecApiClient( + val http: ApiClient, +) : 
Client { + companion object { + val version: LocalDate = LocalDate.parse(AIRFLOW_EXEC_API_VERSION) + } + + override fun getConnection(id: String) = + http.communicate { + getConnection(id, version) + } + + override fun getVariable(key: String) = + http.communicate { + getVariable(key, version) + } + + override fun getXCom( + key: String, + dagId: String, + taskId: String, + runId: String, + mapIndex: Int?, + includePriorDates: Boolean, + ) = http.communicate { + getXcom( + dagId, + runId, + taskId, + key, + mapIndex, + includePriorDates, + 0, + version, + ) + } + + override fun setXCom( + key: String, + value: Any, + dagId: String, + taskId: String, + runId: String, + mapIndex: Int, + ) { + http.communicate { + setXcom( + dagId, + runId, + taskId, + key, + mapIndex, + null, + version, + value, + ) + } + } +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Comms.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Comms.kt new file mode 100644 index 0000000000000..0b5e50c4523c6 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Comms.kt @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk.execution + +import com.fasterxml.jackson.annotation.JsonProperty +import com.fasterxml.jackson.annotation.JsonPropertyOrder +import io.ktor.utils.io.ByteReadChannel +import io.ktor.utils.io.ByteWriteChannel +import io.ktor.utils.io.readByteArray +import io.ktor.utils.io.writeByteArray +import org.apache.airflow.sdk.ApiError +import org.apache.airflow.sdk.Bundle +import org.apache.airflow.sdk.execution.api.client.ApiClient +import org.apache.airflow.sdk.execution.api.model.AssetProfile +import org.apache.airflow.sdk.execution.api.model.BundleInfo +import org.apache.airflow.sdk.execution.api.model.TIRunContext +import org.apache.airflow.sdk.execution.api.model.TISuccessStatePayload +import org.apache.airflow.sdk.execution.api.model.TaskInstance +import retrofit2.Call +import java.time.OffsetDateTime +import kotlin.concurrent.atomics.AtomicInt +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.system.exitProcess + +data class IncomingFrame( + val id: Int, + val body: Any?, +) + +data class OutgoingFrame( + val id: Int, + val body: Any, +) + +class ErrorResponse { + @JsonProperty("error") + var error: String = "" // TODO: Use enum. + + @JsonProperty("detail") + var detail: Any? 
= null +} + +class DagFileParseRequest { + var file: String = "" + + @JsonProperty("bundle_path") + var bundlePath: String = "" +} + +class StartupDetails { + @JsonProperty("ti") + lateinit var ti: TaskInstance + + @JsonProperty("dag_rel_path") + var dagRelPath: String = "" + + @JsonProperty("bundle_info") + lateinit var bundleInfo: BundleInfo + + @JsonProperty("start_date") + lateinit var startDate: OffsetDateTime + + @JsonProperty("ti_context") + lateinit var tiContext: TIRunContext + + @JsonProperty("sentry_integration") + var sentryIntegration: String = "" +} + +class SucceedTask : TISuccessStatePayload { + constructor( + endDate: OffsetDateTime = OffsetDateTime.now(), + taskOutlets: List = emptyList(), + outletEvents: List> = emptyList(), + ) { + endDate(endDate) + taskOutlets(taskOutlets) + outletEvents(outletEvents) + } + + val type = "SucceedTask" +} + +@JsonPropertyOrder(value = ["state", "end_date", "type"]) +data class TaskState( + val state: String, // TODO: Use enum (failed, removed, skipped) and custom serialization. + @get:JsonProperty("end_date") val endDate: OffsetDateTime = OffsetDateTime.now(), +) { + val type = "TaskState" +} + +data class GetConnection( + @get:JsonProperty("conn_id") val id: String, +) { + val type = "GetConnection" +} + +data class GetVariable( + val key: String, +) { + val type = "GetVariable" +} + +data class GetXCom( + val key: String, + @get:JsonProperty("dag_id") val dagId: String, + @get:JsonProperty("task_id") val taskId: String, + @get:JsonProperty("run_id") val runId: String, + @get:JsonProperty("map_index") val mapIndex: Int? 
= null, + @get:JsonProperty("include_prior_dates") val includePriorDates: Boolean = false, +) { + val type = "GetXCom" +} + +data class SetXCom( + val key: String, + val value: Any, + @get:JsonProperty("dag_id") val dagId: String, + @get:JsonProperty("task_id") val taskId: String, + @get:JsonProperty("run_id") val runId: String, + @get:JsonProperty("map_index") val mapIndex: Int, + @get:JsonProperty("mapped_length") val mappedLength: Int? = null, +) { + val type = "SetXCom" +} + +@OptIn(ExperimentalAtomicApi::class) +class CoordinatorComm( + private val bundle: Bundle, + private val reader: ByteReadChannel, + private val writer: ByteWriteChannel, +) { + internal companion object { + private val logger = Logger(CoordinatorComm::class) + + fun encode(outgoing: OutgoingFrame): ByteArray { + val body = + when (val message = outgoing.body) { + is DagParsingResult -> message.serialize() + else -> message + } + return TaskSdkFrames.encodeRequest(outgoing.id, body) + } + + fun decode(bytes: ByteArray): IncomingFrame = TaskSdkFrames.decode(bytes, TaskSdkFrames.toBundleProcessTypes) + } + + private val nextId = AtomicInt(0) + private var shutDownRequested = false + + suspend fun startProcessing() { + while (!shutDownRequested) { + processOnce(::handleIncoming) + } + logger.debug("Goodbye") + } + + private suspend fun processOnce(handle: suspend (IncomingFrame) -> Unit) { + val prefix = reader.readByteArray(4) // First 4 bytes as length. + if (prefix.size != 4) { // Something is terribly wrong. Let's bail. + logger.error("Need 4 prefix bytes", mapOf("actual" to prefix.size)) + shutDownRequested = true + return + } + + val payloadLength = TaskSdkFrames.parseLengthPrefix(prefix) + val payload = reader.readByteArray(payloadLength) + if (payload.size != payloadLength) { // Something is terribly wrong. Let's bail. 
+ logger.error( + "Payload length not right", + mapOf("expect" to payloadLength, "receive" to payload.size), + ) + shutDownRequested = true + return + } + val frame = decode(payload) + logger.debug("Handling", mapOf("id" to frame.id)) + handle(frame) + } + + private suspend fun sendMessage( + id: Int, + body: Any, + ) { + val data = encode(OutgoingFrame(id, body)) + logger.debug("Sending", mapOf("id" to id, "body" to body)) + writer.writeByteArray(TaskSdkFrames.lengthPrefix(data.size)) + writer.writeByteArray(data) + } + + suspend fun handleIncoming(frame: IncomingFrame) { + when (val request = frame.body) { + null -> {} + is ErrorResponse -> { + println("Error!! id=${frame.id} [${request.error}] ${request.detail}") // TODO: Handle error. + exitProcess(1) + } + is DagFileParseRequest -> { + val body = DagParser(request.file, request.bundlePath).parse(bundle) + sendMessage(frame.id, body) + shutDownRequested = true + } + is StartupDetails -> { + sendMessage(frame.id, TaskRunner.run(bundle, request, this)) + shutDownRequested = true + } + } + } + + @Throws(ApiError::class) + suspend fun communicateImpl(body: Any): Any { + var frame: IncomingFrame? 
= null + + suspend fun handle(f: IncomingFrame) { + frame = f + } + sendMessage(nextId.fetchAndAdd(1), body) + processOnce(::handle) + if (frame == null) { + throw ApiError("No response received") + } + return frame.body ?: Unit + } + + @Throws(ApiError::class) + suspend inline fun communicate(request: Any): T { + when (val response = communicateImpl(request)) { + is ErrorResponse -> throw ApiError("[${response.error}] ${response.detail}") + is T -> return response + else -> throw ApiError("Unexpected response type ${response::class.java}") + } + } +} + +internal inline fun ApiClient.communicate(block: S.() -> Call): R { + val service = createService(S::class.java) + val response = block(service).execute() + if (!response.isSuccessful) { + throw ApiError("[${response.message()}] $response (from $service") + } + return response.body() ?: throw ApiError("No body") +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/DagParser.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/DagParser.kt new file mode 100644 index 0000000000000..a8f6eb2713a2e --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/DagParser.kt @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk.execution + +import org.apache.airflow.sdk.Bundle +import org.apache.airflow.sdk.Dag + +data class DagParsingResult( + val fileloc: String, + val bundlePath: String, + val dags: Map, +) + +class DagParser( + val file: String, + val bundlePath: String, +) { + fun parse(bundle: Bundle): DagParsingResult = DagParsingResult(file, bundlePath, bundle.dags) +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/JarUtils.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/JarUtils.kt new file mode 100644 index 0000000000000..7536f0d49f7f8 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/JarUtils.kt @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk.execution + +import java.nio.file.Files +import java.nio.file.Path + +/** True when [this] points to a regular file whose name ends with `.jar`. */ +fun Path.isJarFile(): Boolean = Files.isRegularFile(this) && fileName.toString().endsWith(".jar") + +/** Lists JAR files in [directory], sorted by path name. 
*/ +fun jarFiles(directory: Path): List { + if (!Files.isDirectory(directory)) return emptyList() + val jars = mutableListOf() + Files.list(directory).use { paths -> + paths + .filter { it.isJarFile() } + .sorted() + .forEach { jars.add(it) } + } + return jars +} + +/** True when [directory] contains at least one JAR file. */ +fun containsJars(directory: Path): Boolean = + Files.isDirectory(directory) && + Files.list(directory).use { paths -> paths.anyMatch { it.isJarFile() } } diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Logger.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Logger.kt new file mode 100644 index 0000000000000..f769ed3325324 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Logger.kt @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk.execution + +import io.ktor.utils.io.ByteWriteChannel +import io.ktor.utils.io.writeString +import kotlinx.coroutines.runBlocking +import kotlinx.datetime.LocalDateTime +import kotlinx.datetime.TimeZone +import kotlinx.datetime.toLocalDateTime +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.buildJsonObject +import kotlin.reflect.KClass +import kotlin.time.Clock + +enum class Level { ERROR, DEBUG, } + +internal data class LogMessage( + val event: String, + val arguments: Map, + val logger: Logger, + val level: Level, + val timestamp: LocalDateTime = Clock.System.now().toLocalDateTime(TimeZone.currentSystemDefault()), +) + +internal class Logger( + cls: KClass<*>, +) { + val name: String? = cls.java.typeName + + // TODO: Actually implement level filtering. + @Suppress("UNUSED_PARAMETER") + fun isEnabledForLevel(level: Level): Boolean = true + + fun debug( + message: String, + arguments: Map = emptyMap(), + ) { + log(Level.DEBUG, message, arguments) + } + + fun error( + message: String, + arguments: Map = emptyMap(), + ) { + log(Level.ERROR, message, arguments) + } + + private fun log( + level: Level, + event: String, + arguments: Map, + ) { + if (!isEnabledForLevel(level)) return + LogSender.send(LogMessage(event, arguments, this, level)) + } +} + +internal object LogSender { + private var writer: ByteWriteChannel? 
= null + val messages: MutableList = mutableListOf() + + fun configure(channel: ByteWriteChannel) { + writer = channel + if (!channel.isClosedForWrite) { + while (messages.isNotEmpty()) { + sendTo(channel, messages.removeAt(0)) + } + } + } + + fun send(message: LogMessage) { + val channel = writer + if (channel == null || channel.isClosedForWrite) { + messages.add(message) + } else { + sendTo(channel, message) + } + } + + private fun sendTo( + writer: ByteWriteChannel, + message: LogMessage, + ) { + val map = message.arguments.toMutableMap() + map["event"] = message.event + map["level"] = message.level.name.lowercase() + map["logger"] = message.logger.name ?: "(java)" + map["timestamp"] = message.timestamp + // TODO: Can this be done asynchronously instead? + runBlocking { writer.writeString("${map.toJsonElement()}\n") } + } +} + +private fun Any?.toJsonElement(): JsonElement = + when (this) { + is JsonElement -> this + is Map<*, *> -> + buildJsonObject { + forEach { (k, v) -> put(k.toString(), v.toJsonElement()) } + } + is Iterable<*> -> buildJsonArray { forEach { add(it.toJsonElement()) } } + is Number -> JsonPrimitive(this) + is String -> JsonPrimitive(this) + null -> JsonNull + else -> JsonPrimitive(toString()) // Also correctly handles Kotlinx DateTime. + } diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/MsgPack.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/MsgPack.kt new file mode 100644 index 0000000000000..fb9a28542fcdc --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/MsgPack.kt @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk.execution + +import com.fasterxml.jackson.core.JsonParser +import com.fasterxml.jackson.core.JsonToken +import com.fasterxml.jackson.core.util.JacksonFeatureSet +import com.fasterxml.jackson.databind.DeserializationContext +import com.fasterxml.jackson.databind.deser.std.StdDeserializer +import com.fasterxml.jackson.databind.module.SimpleModule +import com.fasterxml.jackson.datatype.jsr310.JavaTimeFeature +import com.fasterxml.jackson.datatype.jsr310.deser.InstantDeserializer +import org.msgpack.core.ExtensionTypeHeader +import org.msgpack.core.MessagePack +import org.msgpack.core.MessagePacker +import org.msgpack.core.MessageUnpacker +import org.msgpack.jackson.dataformat.MessagePackExtensionType +import org.msgpack.value.ArrayValue +import org.msgpack.value.MapValue +import org.msgpack.value.Value +import org.msgpack.value.ValueType +import java.math.BigInteger +import java.time.OffsetDateTime +import java.time.ZoneOffset + +private fun MessagePacker.packByteArray(data: ByteArray) { + packBinaryHeader(data.size) + data.forEach { packByte(it) } +} + +private fun MessagePacker.packMap(data: Map<*, *>) { + packMapHeader(data.size) + data.forEach { (k, v) -> + check(k is String) + packString(k) + packAny(v) + } +} + +private fun MessagePacker.packCollection(data: Collection<*>) { + packArrayHeader(data.size) + data.forEach { packAny(it) } +} + +fun 
MessagePacker.packAny(data: Any?) { + when (data) { + null -> packNil() + is Boolean -> packBoolean(data) + is Byte -> packByte(data) + is Short -> packShort(data) + is Int -> packInt(data) + is Long -> packLong(data) + is BigInteger -> packBigInteger(data) + is Float -> packFloat(data) + is Double -> packDouble(data) + is ByteArray -> packByteArray(data) + is String -> packString(data) + is Map<*, *> -> packMap(data) + is Collection<*> -> packCollection(data) + else -> throw IllegalArgumentException("Unsupported data type: $data") + } +} + +private fun ArrayValue.decodeArray(): List<*> = + mutableListOf().also { + iterator().forEach { v -> it.add(v.decode()) } + } + +private fun MapValue.decodeMap(): Map<*, *> = + mutableMapOf().also { + entrySet().forEach { (k, v) -> it[k.asStringValue().asString()] = v.decode() } + } + +private fun Value.decode(): Any? = + when (valueType) { + ValueType.NIL -> null + ValueType.BOOLEAN -> asBooleanValue().boolean + ValueType.INTEGER -> + with(asIntegerValue()) { + if (isInLongRange) asLong() else asBigInteger() + } + ValueType.FLOAT -> asFloatValue().toDouble() + ValueType.STRING -> asStringValue().asString() + ValueType.BINARY -> asBinaryValue().asByteArray() + ValueType.ARRAY -> asArrayValue().decodeArray() + ValueType.MAP -> asMapValue().decodeMap() + else -> throw IllegalArgumentException("Unsupported data type: $this") + } + +fun MessageUnpacker.unpackAny(): Any? 
= unpackValue().decode() + +class TimestampToJavaOffsetDateTimeModule : SimpleModule() { + companion object { + const val EXT_TYPE: Byte = -1 + } + + class OffsetDateTimeDeserializer : StdDeserializer(OffsetDateTime::class.java) { + val instantDeserializer = + InstantDeserializer.OFFSET_DATE_TIME.withFeatures( + JacksonFeatureSet.fromDefaults(JavaTimeFeature.entries.toTypedArray()), + ) + + override fun deserialize( + p: JsonParser, + ctxt: DeserializationContext, + ): OffsetDateTime { + if (p.currentToken == JsonToken.VALUE_EMBEDDED_OBJECT) { + deserializeMsgPackTimestamp(p)?.let { return it } + } + return instantDeserializer.deserialize(p, ctxt) + } + + private fun deserializeMsgPackTimestamp(p: JsonParser): OffsetDateTime? { + val ext = p.readValueAs(MessagePackExtensionType::class.java) + if (ext.type != EXT_TYPE) { + return null + } + val unpacker = MessagePack.newDefaultUnpacker(ext.data) + val instant = unpacker.unpackTimestamp(ExtensionTypeHeader(EXT_TYPE, ext.data.size)) + return instant.atOffset(ZoneOffset.UTC) + } + } + + init { + addDeserializer(OffsetDateTime::class.java, OffsetDateTimeDeserializer()) + } +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Serde.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Serde.kt new file mode 100644 index 0000000000000..9962da20fc9af --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Serde.kt @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk.execution + +import org.apache.airflow.sdk.Dag +import org.apache.airflow.sdk.Task +import java.nio.file.Path +import java.time.Duration +import java.time.Instant + +/** + * Serialization logic decoupled from user-facing SDK classes. + * + * Produces output compatible with Python Airflow's DagSerialization format (version 3). + */ +typealias Serialized = Map + +private object SerdeScope + +private val logger = Logger(SerdeScope::class) + +// --------------------------------------------------------------------------- +// Value encoding — matches Python's BaseSerialization.serialize +// --------------------------------------------------------------------------- + +/** + * Recursively serialize a value with Airflow's type/var encoding. + * + * Primitives pass through; complex types are wrapped in {"__type": ..., "__var": ...}. + * This matches the Python BaseSerialization.serialize() output exactly: + * - Map -> {"__type": "dict", "__var": {k: serialize(v), ...}} + * - Set -> {"__type": "set", "__var": [sorted items]} + * - Datetime -> {"__type": "datetime", "__var": epoch_seconds} + * - Timedelta -> {"__type": "timedelta", "__var": total_seconds} + */ +internal fun serializeValue(value: Any?): Any? 
= + when (value) { + null -> null + is String, is Boolean, is Int, is Long, is Float, is Double -> value + is Instant -> + mapOf( + "__type" to "datetime", + "__var" to (value.epochSecond.toDouble() + value.nano.toDouble() / 1_000_000_000.0), + ) + is Duration -> + mapOf( + "__type" to "timedelta", + "__var" to (value.toMillis().toDouble() / 1000.0), + ) + is Map<*, *> -> + mapOf( + "__type" to "dict", + "__var" to value.entries.associate { (k, v) -> k.toString() to serializeValue(v) }, + ) + is Set<*> -> { + val items = value.map { serializeValue(it) } + mapOf( + "__type" to "set", + "__var" to + try { + items.sortedBy { it?.toString() ?: "" } + } catch (_: Exception) { + items + }, + ) + } + is List<*> -> value.map { serializeValue(it) } + else -> value.toString() + } + +// --------------------------------------------------------------------------- +// Timetable serialization +// --------------------------------------------------------------------------- + +private fun serializeTimetable(): Serialized = + mapOf( + "__type" to "airflow.timetables.simple.NullTimetable", + "__var" to emptyMap(), + ) + +// --------------------------------------------------------------------------- +// Task serialization +// --------------------------------------------------------------------------- + +private fun Class.serialize( + id: String, + dependants: Collection?, +): Serialized { + val data = + mutableMapOf( + "task_id" to id, + "task_type" to simpleName, + "_task_module" to name.substringBeforeLast('.'), + "sdk" to "java", + ) + if (!dependants.isNullOrEmpty()) { + data["downstream_task_ids"] = dependants.sorted() + } + return mapOf("__type" to "operator", "__var" to data) +} + +// --------------------------------------------------------------------------- +// Task group serialization (flat root group from task list) +// --------------------------------------------------------------------------- + +private fun serializeTaskGroup(taskIds: Collection): Serialized = + mapOf( + 
"_group_id" to null, + "group_display_name" to "", + "prefix_group_id" to true, + "tooltip" to "", + "ui_color" to "CornflowerBlue", + "ui_fgcolor" to "#000", + "children" to taskIds.associateWith { listOf("operator", it) }, + "upstream_group_ids" to emptyList(), + "downstream_group_ids" to emptyList(), + "upstream_task_ids" to emptyList(), + "downstream_task_ids" to emptyList(), + ) + +// --------------------------------------------------------------------------- +// Params serialization +// --------------------------------------------------------------------------- + +private fun serializeParams(params: Map): List> = + params.entries.map { (k, v) -> + listOf( + k, + mapOf( + "__class" to "airflow.sdk.definitions.param.Param", + "default" to serializeValue(v), + "description" to null, + "schema" to serializeValue(emptyMap()), + "source" to null, + ), + ) + } + +// --------------------------------------------------------------------------- +// DAG serialization — matches Python's DagSerialization.serialize_dag +// --------------------------------------------------------------------------- + +private fun Dag.serialize( + id: String, + fileloc: String, + relativeFileloc: String, +): Serialized = + mutableMapOf( + // Required fields (always present) + "dag_id" to id, + "fileloc" to fileloc, + "relative_fileloc" to relativeFileloc, + // Always serialized + "timezone" to "UTC", + "timetable" to serializeTimetable(), + "tasks" to tasks.entries.map { (taskId, task) -> task.serialize(taskId, dependants[taskId]) }, + "dag_dependencies" to emptyList(), + "task_group" to serializeTaskGroup(tasks.keys), + "edge_info" to emptyMap(), + "params" to serializeParams(emptyMap()), + "deadline" to null, + "allowed_run_types" to null, + ) + +/** Serialize a single DAG to a dict. Exposed for cross-language validation testing. 
*/ +internal fun serializeDag(dag: Dag): Serialized = dag.serialize(dag.id, "", "") + +// --------------------------------------------------------------------------- +// Top-level envelope — matches Python's DagSerialization.to_dict +// --------------------------------------------------------------------------- + +private fun computeRelativeFileloc( + fileloc: String, + bundlePath: String, +): String { + if (fileloc.isEmpty()) return "" + if (bundlePath.isEmpty()) return "." + val rel = Path.of(bundlePath).relativize(Path.of(fileloc)).toString() + return rel.ifEmpty { "." } +} + +internal fun DagParsingResult.serialize(): Serialized { + val relativeFileloc = computeRelativeFileloc(fileloc, bundlePath) + val result = + mapOf( + "type" to "DagFileParsingResult", + "fileloc" to fileloc, + "serialized_dags" to + dags.entries.map { (id, d) -> + mapOf("data" to mapOf("__version" to 3, "dag" to d.serialize(id, fileloc, relativeFileloc))) + }, + ) + + logger.debug("Serialized DAG parsing result", mapOf("result" to result)) + + return result +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Supervisor.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Supervisor.kt new file mode 100644 index 0000000000000..cad988eab59d2 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/Supervisor.kt @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk.execution + +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import org.apache.airflow.sdk.ensureTrailingSlash +import org.apache.airflow.sdk.execution.api.client.ApiClient +import org.apache.airflow.sdk.execution.api.model.BundleInfo +import org.apache.airflow.sdk.execution.api.model.TIEnterRunningPayload +import org.apache.airflow.sdk.execution.api.model.TIRunContext +import org.apache.airflow.sdk.execution.api.model.TISuccessStatePayload +import org.apache.airflow.sdk.execution.api.model.TITerminalStatePayload +import org.apache.airflow.sdk.execution.api.model.TaskInstance +import org.apache.airflow.sdk.execution.api.model.TaskInstanceState +import org.apache.airflow.sdk.execution.api.model.TerminalStateNonSuccess +import org.apache.airflow.sdk.execution.api.route.TaskInstancesApi +import org.apache.airflow.sdk.execution.api.route.XComsApi +import retrofit2.Call +import java.io.InputStream +import java.io.OutputStream +import java.net.InetAddress +import java.net.ServerSocket +import java.net.Socket +import java.time.LocalDate +import java.time.OffsetDateTime +import java.util.UUID + +data class SupervisorTaskInstance( + val id: UUID, + val taskId: String, + val dagId: String, + val runId: String, + val tryNumber: Int, + val dagVersionId: UUID, + val mapIndex: Int?, + val contextCarrier: Map? 
= null, +) + +data class SupervisorBundleInfo( + val name: String, + val version: String?, +) + +data class SupervisorRequest( + val mainClass: String, + val classpath: String, + val executionApiBaseUrl: String, + val token: String, + val workerName: String, + val userName: String, + val dagRelPath: String, + val bundleInfo: SupervisorBundleInfo, + val taskInstance: SupervisorTaskInstance, + val sentryIntegration: String = "", + val onLogLine: suspend (String) -> Unit = {}, +) + +data class SupervisorResult( + val finalState: TaskInstanceState, + val exitCode: Int, +) + +/** + * Retrofit interface for reporting task instance terminal state to the Execution API. + * + * Mirrors the Python SDK's `TaskInstanceOperations.succeed()` and `.finish()` methods + * (see `airflow/sdk/api/client.py`), both of which call `PATCH /task-instances/{id}/state` + * with [TISuccessStatePayload] or [TITerminalStatePayload] respectively. + * + * Why not use the generated [TaskInstancesApi.tiUpdateState]? + * The OpenAPI code generator flattens the endpoint's `oneOf` discriminated union into a single + * class [org.apache.airflow.sdk.execution.api.model.TiPatchPayload] whose `StateEnum` only + * contains `UP_FOR_RETRY`. It cannot represent `"success"`, `"failed"`, or `"skipped"`, and its + * method signature does not accept [TISuccessStatePayload] or [TITerminalStatePayload]. + * This interface works around that limitation by binding the same endpoint with the correct + * payload types that the generator *did* produce correctly as standalone classes. 
+ */ +private interface TaskInstanceStateApi { + @retrofit2.http.Headers("Content-Type:application/json") + @retrofit2.http.PATCH("task-instances/{task_instance_id}/state") + fun succeed( + @retrofit2.http.Path("task_instance_id") id: UUID, + @retrofit2.http.Body payload: TISuccessStatePayload, + @retrofit2.http.Header("Airflow-API-Version") version: LocalDate?, + ): Call + + @retrofit2.http.Headers("Content-Type:application/json") + @retrofit2.http.PATCH("task-instances/{task_instance_id}/state") + fun finish( + @retrofit2.http.Path("task_instance_id") id: UUID, + @retrofit2.http.Body payload: TITerminalStatePayload, + @retrofit2.http.Header("Airflow-API-Version") version: LocalDate?, + ): Call +} + +object Supervisor { + private const val CONNECT_TIMEOUT_MS = 15_000 + private val loopback: InetAddress = InetAddress.getByName("127.0.0.1") + + suspend fun run(request: SupervisorRequest): SupervisorResult { + val execApi = executionApiClient(request.executionApiBaseUrl, request.token) + val execClient = HttpExecApiClient(execApi) + val startDate = OffsetDateTime.now() + + return withContext(Dispatchers.IO) { + coroutineScope { + ServerSocket(0, 1, loopback).use { commServer -> + ServerSocket(0, 1, loopback).use { logsServer -> + commServer.soTimeout = CONNECT_TIMEOUT_MS + logsServer.soTimeout = CONNECT_TIMEOUT_MS + + val process = startBundleProcess(request.classpath, request.mainClass, commServer.localPort, logsServer.localPort) + val stdoutPump = + launch(Dispatchers.IO) { + streamLines(process.inputStream, request.onLogLine) + } + val stderrPump = + launch(Dispatchers.IO) { + streamLines(process.errorStream, request.onLogLine) + } + try { + val (commSocket, logsSocket) = acceptConnections(commServer, logsServer) + + commSocket.use { comm -> + logsSocket.use { logs -> + val logPump = + launch(Dispatchers.IO) { + streamLines(logs.getInputStream(), request.onLogLine) + } + + val taskInstance = request.taskInstance.toExecutionTaskInstance(request.workerName) + val 
tiContext = startTask(execApi, taskInstance, startDate, process, request.workerName, request.userName) + + TaskSdkFrames.writeRequest( + comm.getOutputStream(), + 0, + request.toStartupDetails(taskInstance, tiContext, startDate), + ) + + val finalState = serveTaskSdkRequests(comm, execApi, execClient, taskInstance.id) + val exitCode = process.waitFor() + logPump.join() + stdoutPump.join() + stderrPump.join() + + SupervisorResult( + finalState = if (exitCode == 0) finalState else TaskInstanceState.FAILED, + exitCode = exitCode, + ) + } + } + } catch (e: Exception) { + process.destroy() + throw e + } + } + } + } + } + } + + internal suspend fun streamLines( + input: InputStream, + onLogLine: suspend (String) -> Unit, + ) { + withContext(Dispatchers.IO) { + input.bufferedReader().useLines { lines -> + for (line in lines) { + onLogLine(line) + } + } + } + } + + private fun serveTaskSdkRequests( + comm: Socket, + execApi: ApiClient, + execClient: HttpExecApiClient, + taskInstanceId: UUID, + ): TaskInstanceState { + val input = comm.getInputStream() + val output = comm.getOutputStream() + + while (true) { + val frame = TaskSdkFrames.readFrame(input, TaskSdkFrames.toSupervisorTypes) + when (val message = frame.body ?: return TaskInstanceState.FAILED) { + is GetConnection -> + reply(frame.id, output) { + execClient.getConnection(message.id) + } + is GetVariable -> + reply(frame.id, output) { + execClient.getVariable(message.key) + } + is GetXCom -> + reply(frame.id, output) { + execClient.getXCom( + key = message.key, + dagId = message.dagId, + taskId = message.taskId, + runId = message.runId, + mapIndex = message.mapIndex, + includePriorDates = message.includePriorDates, + ) + } + is SetXCom -> + reply(frame.id, output) { + setXCom(execApi, message) + null + } + is SucceedTask -> { + succeed(execApi, taskInstanceId, message) + return TaskInstanceState.SUCCESS + } + is TaskState -> { + finish(execApi, taskInstanceId, message) + return 
TaskInstanceState.fromValue(message.state) + } + is ErrorResponse -> throw IllegalStateException("[${message.error}] ${message.detail}") + else -> throw IllegalStateException("Unsupported Task SDK message type ${message::class.java.name}") + } + } + } + + private fun succeed( + execApi: ApiClient, + taskInstanceId: UUID, + message: SucceedTask, + ) { + execApi.send { + succeed( + taskInstanceId, + TISuccessStatePayload() + .endDate(message.endDate) + .taskOutlets(message.taskOutlets) + .outletEvents(message.outletEvents), + HttpExecApiClient.version, + ) + } + } + + private fun finish( + execApi: ApiClient, + taskInstanceId: UUID, + message: TaskState, + ) { + execApi.send { + finish( + taskInstanceId, + TITerminalStatePayload() + .state(TerminalStateNonSuccess.fromValue(message.state)) + .endDate(message.endDate), + HttpExecApiClient.version, + ) + } + } + + private fun reply( + requestId: Int, + output: OutputStream, + block: () -> Any?, + ) { + try { + TaskSdkFrames.writeResponse(output, requestId, body = block()) + } catch (e: Exception) { + TaskSdkFrames.writeResponse( + output, + requestId, + error = + ErrorResponse().also { + it.error = "generic_error" + it.detail = mapOf("message" to (e.message ?: e::class.java.name)) + }, + ) + } + } + + private suspend fun acceptConnections( + commServer: ServerSocket, + logsServer: ServerSocket, + ): Pair = + coroutineScope { + val comm = async(Dispatchers.IO) { commServer.accept() } + val logs = async(Dispatchers.IO) { logsServer.accept() } + comm.await() to logs.await() + } + + private fun startBundleProcess( + classpath: String, + mainClass: String, + commPort: Int, + logsPort: Int, + ): Process { + val command = + listOf( + "java", + "-classpath", + classpath, + mainClass, + "--comm=${loopback.hostAddress}:$commPort", + "--logs=${loopback.hostAddress}:$logsPort", + ) + return ProcessBuilder(command) + .redirectOutput(ProcessBuilder.Redirect.PIPE) + .redirectError(ProcessBuilder.Redirect.PIPE) + .start() + } + + 
private fun executionApiClient( + baseUrl: String, + token: String, + ) = ApiClient("JWTBearer").apply { + setBearerToken(token) + adapterBuilder.baseUrl(baseUrl.ensureTrailingSlash()) + } + + private fun setXCom( + execApi: ApiClient, + request: SetXCom, + ) { + execApi.send { + setXcom( + request.dagId, + request.runId, + request.taskId, + request.key, + request.mapIndex, + null, + HttpExecApiClient.version, + request.value, + ) + } + } + + private fun startTask( + api: ApiClient, + taskInstance: TaskInstance, + startDate: OffsetDateTime, + process: Process, + workerName: String, + userName: String, + ): TIRunContext = + api.communicate { + tiRun( + taskInstance.id, + TIEnterRunningPayload() + .hostname(workerName) + .unixname(userName) + .pid(process.pid().toInt()) + .startDate(startDate), + HttpExecApiClient.version, + ) + } + + private fun SupervisorTaskInstance.toExecutionTaskInstance(workerName: String) = + TaskInstance().also { + it.id = id + it.taskId = taskId + it.dagId = dagId + it.runId = runId + it.tryNumber = tryNumber + it.dagVersionId = dagVersionId + it.mapIndex = mapIndex + it.hostname = workerName + it.contextCarrier = contextCarrier + } + + private fun SupervisorRequest.toStartupDetails( + taskInstance: TaskInstance, + tiContext: TIRunContext, + startDate: OffsetDateTime, + ) = StartupDetails().also { + it.ti = taskInstance + it.dagRelPath = dagRelPath + it.bundleInfo = + BundleInfo().also { info -> + info.name = bundleInfo.name + info.version = bundleInfo.version + } + it.tiContext = tiContext + it.startDate = startDate + it.sentryIntegration = sentryIntegration + } +} + +private inline fun ApiClient.send(block: Q.() -> Call<*>) { + val service = createService(Q::class.java) + val response = block(service).execute() + if (!response.isSuccessful) { + throw IllegalStateException("[${response.message()}] $response (from $service)") + } +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/TaskRunner.kt 
b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/TaskRunner.kt new file mode 100644 index 0000000000000..c88620540072a --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/TaskRunner.kt @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk.execution + +import org.apache.airflow.sdk.Bundle +import org.apache.airflow.sdk.Client +import org.apache.airflow.sdk.Context + +object TaskRunner { + fun run( + bundle: Bundle, + request: StartupDetails, + comm: CoordinatorComm, + ): Any = run(bundle, request, Client(request, CoordinatorClient(comm))) + + internal fun run( + bundle: Bundle, + request: StartupDetails, + client: Client, + ): Any { + val task = bundle.dags[request.ti.dagId]?.tasks[request.ti.taskId] ?: return TaskState("removed") + val instance = task.getDeclaredConstructor().newInstance() + return try { + instance.execute(Context.from(request), client) + SucceedTask() + } catch (e: Exception) { + e.printStackTrace() + TaskState("failed") + } + } +} diff --git a/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/TaskSdkFrames.kt b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/TaskSdkFrames.kt new file mode 100644 index 0000000000000..010c29a40f1d6 --- /dev/null +++ b/java-sdk/sdk/src/main/kotlin/org/apache/airflow/sdk/execution/TaskSdkFrames.kt @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk.execution + +import com.fasterxml.jackson.databind.DeserializationFeature +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.SerializationFeature +import com.fasterxml.jackson.databind.util.StdDateFormat +import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule +import org.apache.airflow.sdk.execution.api.model.ConnectionResponse +import org.apache.airflow.sdk.execution.api.model.VariableResponse +import org.apache.airflow.sdk.execution.api.model.XComResponse +import org.msgpack.core.MessagePack +import java.io.ByteArrayOutputStream +import java.io.EOFException +import java.io.InputStream +import java.io.OutputStream + +typealias TaskSdkMessageDecoder = (Map<*, *>) -> Any + +object TaskSdkFrames { + private val mapper = + ObjectMapper().apply { + configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS) + registerModule(JavaTimeModule()) + registerModule(TimestampToJavaOffsetDateTimeModule()) + setDateFormat(StdDateFormat().withColonInTimeZone(true)) + } + + private val inferredTypes = + mapOf( + ConnectionResponse::class to "ConnectionResult", + ErrorResponse::class to "ErrorResponse", + StartupDetails::class to "StartupDetails", + VariableResponse::class to "VariableResult", + XComResponse::class to "XComResult", + ) + + private val toBundleClientTypes: Map = + mapOf( + "ConnectionResult" to mapperDecoder(ConnectionResponse::class.java), + "ErrorResponse" to mapperDecoder(ErrorResponse::class.java), + "VariableResult" to mapperDecoder(VariableResponse::class.java), + "XComResult" to mapperDecoder(XComResponse::class.java), + ) + + val toDagProcessorTypes: Map = + toBundleClientTypes + + mapOf( + "DagFileParseRequest" to mapperDecoder(DagFileParseRequest::class.java), + ) + + val toTaskTypes: Map = + toBundleClientTypes + + mapOf( + "StartupDetails" to mapperDecoder(StartupDetails::class.java), + ) + + // The Java 
bundle process can act as either Python's DagProcessor or Task runtime, so + // its inbound decoder is the union of both message sets. + val toBundleProcessTypes: Map = toDagProcessorTypes + toTaskTypes + + val toSupervisorTypes: Map = + mapOf( + "ErrorResponse" to mapperDecoder(ErrorResponse::class.java), + "GetConnection" to { body -> GetConnection(id = body.string("conn_id")) }, + "GetVariable" to { body -> GetVariable(key = body.string("key")) }, + "GetXCom" to { + GetXCom( + key = it.string("key"), + dagId = it.string("dag_id"), + taskId = it.string("task_id"), + runId = it.string("run_id"), + mapIndex = it.intOrNull("map_index"), + includePriorDates = it.boolean("include_prior_dates", default = false), + ) + }, + "SetXCom" to { + SetXCom( + key = it.string("key"), + value = it["value"] ?: error("Missing 'value'"), + dagId = it.string("dag_id"), + taskId = it.string("task_id"), + runId = it.string("run_id"), + mapIndex = it.int("map_index"), + ) + }, + "SucceedTask" to { SucceedTask() }, + "TaskState" to { body -> TaskState(state = body.string("state")) }, + ) + + fun encodeRequest( + id: Int, + body: Any, + ): ByteArray = encodeFrame(id, body, error = null, isResponse = false) + + fun encodeResponse( + id: Int, + body: Any? = null, + error: ErrorResponse? = null, + ): ByteArray = encodeFrame(id, body, error = error, isResponse = true) + + fun writeRequest( + output: OutputStream, + id: Int, + body: Any, + ) = writeFrame(output, encodeRequest(id, body)) + + fun writeResponse( + output: OutputStream, + id: Int, + body: Any? = null, + error: ErrorResponse? 
= null, + ) = writeFrame(output, encodeResponse(id, body, error)) + + fun decode( + bytes: ByteArray, + bodyTypes: Map, + ): IncomingFrame { + val unpacker = MessagePack.newDefaultUnpacker(bytes) + val headerSize = unpacker.unpackArrayHeader() + check(headerSize >= 2) { "Unexpected Task SDK frame arity $headerSize" } + + val id = unpacker.unpackInt() + val rawBody = unpacker.unpackAny() + val rawError = if (headerSize >= 3) unpacker.unpackAny() else null + unpacker.close() + + val body = + decodeMessage(rawError, bodyTypes = mapOf("ErrorResponse" to mapperDecoder(ErrorResponse::class.java))) + ?: decodeMessage(rawBody, bodyTypes) + + return IncomingFrame(id, body) + } + + fun readFrame( + input: InputStream, + bodyTypes: Map, + ): IncomingFrame = decode(readBytes(input, readLengthPrefix(input)), bodyTypes) + + fun lengthPrefix(length: Int) = + byteArrayOf( + (length shr 24).toByte(), + (length shr 16).toByte(), + (length shr 8).toByte(), + length.toByte(), + ) + + fun readLengthPrefix(input: InputStream): Int = parseLengthPrefix(readBytes(input, 4)) + + fun parseLengthPrefix(prefix: ByteArray): Int { + check(prefix.size == 4) { "Need 4 prefix bytes" } + return prefix.fold(0) { acc, byte -> (acc shl 8) or (byte.toInt() and 0xff) } + } + + fun readBytes( + input: InputStream, + length: Int, + ): ByteArray { + val bytes = input.readNBytes(length) + if (bytes.size != length) { + throw EOFException("Expected $length bytes but only received ${bytes.size}") + } + return bytes + } + + private fun writeFrame( + output: OutputStream, + payload: ByteArray, + ) { + output.write(lengthPrefix(payload.size)) + output.write(payload) + output.flush() + } + + private fun encodeFrame( + id: Int, + body: Any?, + error: ErrorResponse?, + isResponse: Boolean, + ): ByteArray { + val payload = ByteArrayOutputStream() + val packer = MessagePack.newDefaultPacker(payload) + packer.packArrayHeader(if (isResponse) 3 else 2) + packer.packInt(id) + packer.packAny(body?.let(::toBody)) + if 
(isResponse) { + packer.packAny(error?.let(::toBody)) + } + packer.close() + return payload.toByteArray() + } + + private fun decodeMessage( + raw: Any?, + bodyTypes: Map, + ): Any? { + val body = raw as? Map<*, *> ?: return raw + val typeName = body["type"] as? String ?: return body + val decoder = bodyTypes[typeName] ?: error("Unsupported Task SDK message type $typeName") + return decoder(body) + } + + private fun mapperDecoder(targetType: Class<*>): TaskSdkMessageDecoder = { body -> mapper.convertValue(body, targetType) } + + @Suppress("UNCHECKED_CAST") + private fun toBody(value: Any): Map = + when (value) { + is Map<*, *> -> value as Map + else -> + (mapper.convertValue(value, MutableMap::class.java) as MutableMap).also { body -> + inferredTypes[value::class]?.let { typeName -> body.putIfAbsent("type", typeName) } + } + } + + private fun Map<*, *>.string(key: String): String = this[key] as? String ?: error("Missing '$key'") + + private fun Map<*, *>.int(key: String): Int = intOrNull(key) ?: error("Missing integer '$key'") + + private fun Map<*, *>.intOrNull(key: String): Int? 
= + when (val value = this[key]) { + null -> null + is Number -> value.toInt() + else -> error("Expected integer '$key', got ${value::class.java}") + } + + private fun Map<*, *>.boolean( + key: String, + default: Boolean, + ): Boolean = + when (val value = this[key]) { + null -> default + is Boolean -> value + else -> error("Expected boolean '$key', got ${value::class.java}") + } +} diff --git a/java-sdk/sdk/src/main/resources/META-INF/services/javax.annotation.processing.Processor b/java-sdk/sdk/src/main/resources/META-INF/services/javax.annotation.processing.Processor new file mode 100644 index 0000000000000..f9d6d12ad0cd7 --- /dev/null +++ b/java-sdk/sdk/src/main/resources/META-INF/services/javax.annotation.processing.Processor @@ -0,0 +1 @@ +org.apache.airflow.sdk.BuilderProcessor diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/BuilderTest.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/BuilderTest.kt new file mode 100644 index 0000000000000..681ba7d2eabf4 --- /dev/null +++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/BuilderTest.kt @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
 */

package org.apache.airflow.sdk

import com.google.testing.compile.CompilationSubject.assertThat
import com.google.testing.compile.Compiler
import com.google.testing.compile.JavaFileObjectSubject
import com.google.testing.compile.JavaFileObjects
import org.junit.jupiter.api.DisplayName
import org.junit.jupiter.api.Test

// Run javac with the BuilderProcessor over a single in-memory Java source.
private fun compile(source: String) =
    Compiler.javac().withProcessors(BuilderProcessor()).compile(
        JavaFileObjects.forSourceString("org.apache.airflow.example.TestExample", source),
    )

// Convenience overload: compare against an expected source given as a string.
private fun JavaFileObjectSubject.hasSourceEquivalentTo(
    qual: String,
    source: String,
) = hasSourceEquivalentTo(
    JavaFileObjects.forSourceString(qual, source),
)

// Golden tests for the annotation processor that turns @Builder.Dag classes
// into generated *Builder sources wiring each @Builder.Task into a Dag.
class BuilderTest {
    @Test
    @DisplayName("generate builder for dag class")
    fun generateBuilderForDagClass() {
        val compilation =
            compile(
                """
                package org.apache.airflow.example;

                import org.apache.airflow.sdk.Builder;
                import org.apache.airflow.sdk.Client;
                import org.apache.airflow.sdk.Context;

                @Builder.Dag
                public class TestExample {
                    @Builder.Task
                    public void t1() {}

                    @Builder.Task(depends = {"t1"})
                    public int t2(Client client) {
                        return (Integer) client.getXCom("t0");
                    }

                    @Builder.Task(depends = {"t1", "t2"})
                    public void t3(Context ctx, @Builder.XCom(task = "t2") int value) {
                        System.out.println(String.format("%s %s", ctx.ti, value));
                    }
                }
                """,
            )

        // NOTE(review): the expected t3 deps are {"t2", "t1", "t2"} — XCom
        // parameter tasks appear to be prepended to the declared depends
        // without de-duplication; this pins the processor's current behavior.
        assertThat(compilation)
            .generatedSourceFile("org.apache.airflow.example.TestExampleBuilder")
            .hasSourceEquivalentTo(
                "org.apache.airflow.example.TestExampleBuilder",
                """
                package org.apache.airflow.example;

                import java.lang.Exception;
                import java.lang.Integer;
                import java.lang.Override;
                import org.apache.airflow.sdk.Client;
                import org.apache.airflow.sdk.Context;
                import org.apache.airflow.sdk.Dag;
                import org.apache.airflow.sdk.Task;

                public final class TestExampleBuilder {
                    public static Dag build() {
                        var dag = new Dag("TestExample");
                        dag.addTask("t1", T1.class);
                        dag.addTask("t2", T2.class, new String[]{"t1"});
                        dag.addTask("t3", T3.class, new String[]{"t2", "t1", "t2"});
                        return dag;
                    }
                    public static final class T1 implements Task {
                        @Override
                        public void execute(Context context, Client client) throws Exception {
                            new TestExample().t1();
                        }
                    }
                    public static final class T2 implements Task {
                        @Override
                        public void execute(Context context, Client client) throws Exception {
                            client.setXCom(new TestExample().t2(client));
                        }
                    }
                    public static final class T3 implements Task {
                        @Override
                        public void execute(Context context, Client client) throws Exception {
                            var value = (Integer) client.getXCom("t2");
                            new TestExample().t3(context, value);
                        }
                    }
                }
                """,
            )
    }

    @Test
    @DisplayName("generate builder for dag class with custom dag id")
    fun generateBuilderWithCustomDagId() {
        val compilation =
            compile(
                """
                package org.apache.airflow.example;
                import org.apache.airflow.sdk.Builder;
                @Builder.Dag(id = "foo") public class TestExample {}
                """,
            )
        assertThat(compilation)
            .generatedSourceFile("org.apache.airflow.example.TestExampleBuilder")
            .hasSourceEquivalentTo(
                "org.apache.airflow.example.TestExampleBuilder",
                """
                package org.apache.airflow.example;
                import org.apache.airflow.sdk.Dag;
                public final class TestExampleBuilder { public static Dag build() { var dag = new Dag("foo"); return dag; } }
                """,
            )
    }

    @Test
    @DisplayName("generate builder for dag class with custom class name")
    fun generateBuilderWithCustomClassName() {
        val compilation =
            compile(
                """
                package org.apache.airflow.example;
                import org.apache.airflow.sdk.Builder;
                @Builder.Dag(to = "Foo") public class TestExample {}
                """,
            )
        assertThat(compilation)
            .generatedSourceFile("org.apache.airflow.example.Foo")
            .hasSourceEquivalentTo(
                "org.apache.airflow.example.Foo",
                """
                package org.apache.airflow.example;
                import org.apache.airflow.sdk.Dag;
                public final class Foo { public static Dag build() { var dag = new Dag("TestExample"); return dag; } }
                """,
            )
    }

    @Test
    @DisplayName("generate builder for dag class with custom task name")
    fun generateBuilderForDagClassWithCustomTaskName() {
        val compilation =
            compile(
                """
                package org.apache.airflow.example;
                import org.apache.airflow.sdk.Builder;
                @Builder.Dag
                public class TestExample { @Builder.Task(id = "foo") public void t1() {} }
                """,
            )

        assertThat(compilation)
            .generatedSourceFile("org.apache.airflow.example.TestExampleBuilder")
            .hasSourceEquivalentTo(
                "org.apache.airflow.example.TestExampleBuilder",
                """
                package org.apache.airflow.example;
                import java.lang.Exception;
                import java.lang.Override;
                import org.apache.airflow.sdk.Client;
                import org.apache.airflow.sdk.Context;
                import org.apache.airflow.sdk.Dag;
                import org.apache.airflow.sdk.Task;
                public final class TestExampleBuilder {
                    public static Dag build() {
                        var dag = new Dag("TestExample");
                        dag.addTask("foo", T1.class);
                        return dag;
                    }
                    public static final class T1 implements Task {
                        @Override public void execute(Context context, Client client) throws Exception { new TestExample().t1(); }
                    }
                }
                """,
            )
    }

    @Test
    @DisplayName("generate builder for dag class with invalid task parameter")
    fun generateBuilderForDagClassWithInvalidTaskParameter() {
        val compilation =
            compile(
                """
                package org.apache.airflow.example;
                import org.apache.airflow.sdk.Builder;
                @Builder.Dag
                public class TestExample { @Builder.Task(id = "foo") public void t1(String client) {} }
                """,
            )
        // Only Client/Context/@XCom parameters are injectable; anything else is rejected.
        assertThat(compilation).failed()
        assertThat(compilation).hadErrorContaining(
            "Unsupported task parameter 'client' with type: java.lang.String",
        )
    }

    @Test
    @DisplayName("generate builder for dag class with varargs task parameter")
    fun generateBuilderForDagClassWithVarArgsTaskParameter() {
        val compilation =
            compile(
                """
                package org.apache.airflow.example;
                import org.apache.airflow.sdk.Builder;
                @Builder.Dag
                public class TestExample { @Builder.Task(id = "foo") public void t1(String... client) {} }
                """,
            )
        assertThat(compilation).failed()
        assertThat(compilation).hadErrorContaining(
            "Cannot create task from vararg function t1",
        )
    }
}
diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/BundleScannerTest.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/BundleScannerTest.kt
new file mode 100644
index 0000000000000..1d84487e6a04d
--- /dev/null
+++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/BundleScannerTest.kt
@@ -0,0 +1,216 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
+ */ + +package org.apache.airflow.sdk + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir +import java.io.File +import java.nio.file.Files +import java.nio.file.Path +import java.util.jar.JarOutputStream +import java.util.jar.Manifest +import java.util.zip.ZipEntry + +private const val STUB_MAIN_CLASS = "com.example.Main" + +class BundleScannerTest { + @Test + @DisplayName("parseDagIdsFromYaml extracts dag ids from metadata YAML") + fun parseDagIdsFromYaml() { + val yaml = + """ + dags: + java_example: + tasks: + - extract + - transform + - load + another_dag: + tasks: + - task_a + """.trimIndent() + + assertEquals(setOf("java_example", "another_dag"), parseDagIdsFromYaml(yaml)) + } + + @Test + @DisplayName("parseDagIdsFromYaml returns empty set for missing dags key") + fun parseDagIdsFromYamlEmpty() { + assertEquals(emptySet(), parseDagIdsFromYaml("other_key: value")) + } + + @Test + @DisplayName("readBundleDagIds reads metadata from JAR with Airflow-Java-SDK-Metadata manifest") + fun readBundleDagIdsFromJar( + @TempDir tempDir: Path, + ) { + createBundleJar(tempDir, mapOf("my_dag" to listOf("t1", "t2"))) + + assertEquals(setOf("my_dag"), readBundleDagIds(tempDir)) + } + + @Test + @DisplayName("readBundleDagIds returns empty set when no JAR has metadata") + fun readBundleDagIdsNoMetadata( + @TempDir tempDir: Path, + ) { + val manifest = Manifest() + manifest.mainAttributes.putValue("Manifest-Version", "1.0") + JarOutputStream(Files.newOutputStream(tempDir.resolve("plain.jar")), manifest).use {} + + assertEquals(emptySet(), readBundleDagIds(tempDir)) + } + + @Test + @DisplayName("scanBundles discovers bundles in subdirectories") + fun scanBundlesNestedLayout( + @TempDir tempDir: Path, + ) { + val bundleA = 
Files.createDirectory(tempDir.resolve("bundle-a")) + createBundleJar(bundleA, mapOf("dag_a" to listOf("t1"))) + + val bundleB = Files.createDirectory(tempDir.resolve("bundle-b")) + createBundleJar(bundleB, mapOf("dag_b" to listOf("t2"), "dag_c" to listOf("t3"))) + + val result = scanBundles(tempDir) + + assertEquals(STUB_MAIN_CLASS, result["dag_a"]?.mainClass) + assertEquals(STUB_MAIN_CLASS, result["dag_b"]?.mainClass) + assertEquals(STUB_MAIN_CLASS, result["dag_c"]?.mainClass) + // dag_a classpath should point to bundle-a JARs, not bundle-b + assertTrue(result["dag_a"]!!.classpath.contains("bundle-a")) + assertTrue(result["dag_b"]!!.classpath.contains("bundle-b")) + } + + @Test + @DisplayName("scanBundles supports flat layout where bundlesDir itself contains JARs") + fun scanBundlesFlatLayout( + @TempDir tempDir: Path, + ) { + createBundleJar(tempDir, mapOf("flat_dag" to listOf("t1"))) + + val result = scanBundles(tempDir) + + assertNotNull(result["flat_dag"]) + assertEquals(STUB_MAIN_CLASS, result["flat_dag"]!!.mainClass) + } + + @Test + @DisplayName("scanBundles finds metadata JAR among many dependency JARs") + fun scanBundlesFlatWithDependencyJars( + @TempDir tempDir: Path, + ) { + // Simulate installDist layout: one bundle JAR with metadata among plain dependency JARs. + val plainManifest = Manifest() + plainManifest.mainAttributes.putValue("Manifest-Version", "1.0") + JarOutputStream(Files.newOutputStream(tempDir.resolve("aaa-dep.jar")), plainManifest).use {} + JarOutputStream(Files.newOutputStream(tempDir.resolve("zzz-dep.jar")), plainManifest).use {} + + // A JAR with no manifest at all. + JarOutputStream(Files.newOutputStream(tempDir.resolve("no-manifest.jar"))).use {} + + createBundleJar(tempDir, mapOf("my_dag" to listOf("t1"))) + + val result = scanBundles(tempDir) + + assertNotNull(result["my_dag"]) + assertEquals(STUB_MAIN_CLASS, result["my_dag"]!!.mainClass) + // All 4 JARs should be on the classpath. 
+ val cpEntries = result["my_dag"]!!.classpath.split(File.pathSeparator) + assertEquals(4, cpEntries.size) + } + + @Test + @DisplayName("scanBundles resolves distZip layout where bundlesDir is the lib directory") + fun scanBundlesDistZipLibDir( + @TempDir tempDir: Path, + ) { + // Simulate: unzip example.zip → example/lib/*.jar, BUNDLES_DIR=.../example/lib + val libDir = Files.createDirectories(tempDir.resolve("example").resolve("lib")) + + // 30 plain dependency JARs + val plainManifest = Manifest() + plainManifest.mainAttributes.putValue("Manifest-Version", "1.0") + for (name in listOf( + "annotations-23.0.0", + "converter-jackson-3.0.0", + "jackson-core-2.21.1", + "jackson-databind-2.21.1", + "kotlin-stdlib-2.3.0", + "kotlinx-coroutines-core-jvm-1.10.2", + "msgpack-core-0.9.11", + "okhttp-4.12.0", + "retrofit-3.0.0", + "sdk", + )) { + JarOutputStream(Files.newOutputStream(libDir.resolve("$name.jar")), plainManifest).use {} + } + + // The bundle JAR with metadata, named "example.jar" (alphabetically after some deps) + createBundleJar(libDir, mapOf("java_example" to listOf("extract", "transform", "load")), "example.jar") + + // bundlesDir points directly at lib/ + val result = scanBundles(libDir) + + assertNotNull(result["java_example"], "java_example should be discovered in flat lib/ layout") + assertEquals(STUB_MAIN_CLASS, result["java_example"]!!.mainClass) + assertEquals(11, result["java_example"]!!.classpath.split(File.pathSeparator).size) + } + + @Test + @DisplayName("scanBundles returns empty map for nonexistent directory") + fun scanBundlesNonexistentDir() { + assertEquals(emptyMap(), scanBundles(Path.of("/nonexistent/dir"))) + } + + private fun createBundleJar( + dir: Path, + dags: Map>, + fileName: String = "bundle.jar", + ): Path { + val manifest = Manifest() + manifest.mainAttributes.putValue("Manifest-Version", "1.0") + manifest.mainAttributes.putValue("Main-Class", STUB_MAIN_CLASS) + manifest.mainAttributes.putValue(METADATA_MANIFEST_KEY, 
"airflow-metadata.yaml") + + val jarPath = dir.resolve(fileName) + JarOutputStream(Files.newOutputStream(jarPath), manifest).use { jos -> + jos.putNextEntry(ZipEntry("airflow-metadata.yaml")) + val yaml = + buildString { + appendLine("dags:") + for ((dagId, tasks) in dags) { + appendLine(" $dagId:") + appendLine(" tasks:") + for (task in tasks) { + appendLine(" - $task") + } + } + } + jos.write(yaml.toByteArray()) + jos.closeEntry() + } + return jarPath + } +} diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/BundleTest.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/BundleTest.kt new file mode 100644 index 0000000000000..48bae797cbe20 --- /dev/null +++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/BundleTest.kt @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
 */

package org.apache.airflow.sdk

import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.DisplayName
import org.junit.jupiter.api.Test

// Tests for Bundle's dag indexing: dags are keyed by their id, and duplicate
// ids are rejected at construction time.
internal class BundleTest {
    @Test
    @DisplayName("Should index dags by dagId")
    fun shouldIndexDagsByDagId() {
        val dag = Dag("dag")

        val bundle = Bundle("0", listOf(dag))

        // The map value is the same Dag instance, not a copy.
        Assertions.assertEquals(mapOf("dag" to dag), bundle.dags)
    }

    @Test
    @DisplayName("Should reject duplicate dag ids")
    fun shouldRejectDuplicateDagIds() {
        val error =
            Assertions.assertThrows(IllegalArgumentException::class.java) {
                Bundle("0", listOf(Dag("dag"), Dag("dag")))
            }

        Assertions.assertEquals("Dags in bundle have duplicate ID: dag", error.message)
    }
}
diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/ConfigTest.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/ConfigTest.kt
new file mode 100644
index 0000000000000..5cd9c33d3835d
--- /dev/null
+++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/ConfigTest.kt
@@ -0,0 +1,157 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.airflow.sdk

import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.DisplayName
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.io.TempDir
import java.nio.file.Files
import java.nio.file.Path

// Tests for SdkConfig resolution order: environment variables win over the
// java-sdk.yaml file (located explicitly or via AIRFLOW_HOME), which wins
// over built-in defaults.
class ConfigTest {
    // -- SdkConfig: env var resolution --

    @Test
    @DisplayName("executionApiUrl uses AIRFLOW__CORE__EXECUTION_API_SERVER_URL env var")
    fun executionApiUrlFromEnv() {
        val config = SdkConfig(env = mapOf("AIRFLOW__CORE__EXECUTION_API_SERVER_URL" to "http://127.0.0.1:8080/execution"))

        // A trailing slash is appended during normalization.
        assertEquals("http://127.0.0.1:8080/execution/", config.executionApiUrl)
    }

    @Test
    @DisplayName("executionApiUrl throws when missing")
    fun executionApiUrlThrowsWhenMissing() {
        val config = SdkConfig(env = emptyMap())

        // The error message names the missing option so operators can fix it.
        val error = assertThrows(WorkerError::class.java) { config.executionApiUrl }
        assertTrue(error.message!!.contains("execution_api_server_url"))
    }

    @Test
    @DisplayName("executionApiUrl falls back to execution.api_url")
    fun executionApiUrlFallback() {
        val config = SdkConfig(env = mapOf("AIRFLOW__EXECUTION__API_URL" to "http://127.0.0.1:8080/execution"))

        assertEquals("http://127.0.0.1:8080/execution/", config.executionApiUrl)
    }

    @Test
    @DisplayName("jwtExpirationTime defaults to 30 seconds")
    fun jwtExpirationTimeDefault() {
        val config = SdkConfig(env = emptyMap())

        assertEquals(30, config.jwtExpirationTime)
    }

    // -- SdkConfig: YAML resolution --

    @Test
    @DisplayName("config values are loaded from YAML file")
    fun yamlConfigLoading(
        @TempDir tempDir: Path,
    ) {
        // NOTE(review): the YAML literal's indentation was lost in the patch
        // collapse; reconstructed with conventional 2-space nesting — confirm.
        val yamlContent =
            """
            core:
              execution_api_server_url: "http://yaml-host:8080/execution/"

            sdk:
              bundles_dir: "./bundles"

            api_auth:
              jwt_secret: "yaml-secret"
              jwt_issuer: "yaml-issuer"
              jwt_expiration_time: 45
            """.trimIndent()

        val yamlPath = tempDir.resolve("java-sdk.yaml")
        Files.writeString(yamlPath, yamlContent)

        // yamlOverride bypasses the AIRFLOW_HOME lookup for the test.
        val config = SdkConfig(env = emptyMap(), yamlOverride = yamlPath)

        assertEquals("http://yaml-host:8080/execution/", config.executionApiUrl)
        assertEquals("yaml-secret", config.jwtSecret)
        assertEquals("yaml-issuer", config.jwtIssuer)
        assertEquals(45, config.jwtExpirationTime)
        assertEquals(Path.of("./bundles"), config.bundlesDir)
    }

    @Test
    @DisplayName("env vars take precedence over YAML values")
    fun envTakesPrecedenceOverYaml(
        @TempDir tempDir: Path,
    ) {
        val yamlContent =
            """
            core:
              execution_api_server_url: "http://yaml-host:8080/execution/"
            api_auth:
              jwt_secret: "yaml-secret"
            """.trimIndent()

        val yamlPath = tempDir.resolve("java-sdk.yaml")
        Files.writeString(yamlPath, yamlContent)

        val config =
            SdkConfig(
                env =
                    mapOf(
                        "AIRFLOW__CORE__EXECUTION_API_SERVER_URL" to "http://env-host:9090/execution/",
                        "AIRFLOW__API_AUTH__JWT_SECRET" to "env-secret",
                    ),
                yamlOverride = yamlPath,
            )

        assertEquals("http://env-host:9090/execution/", config.executionApiUrl)
        assertEquals("env-secret", config.jwtSecret)
    }

    @Test
    @DisplayName("config works with no YAML file and no env vars for optional values")
    fun noYamlFile() {
        val config = SdkConfig(env = emptyMap())

        assertEquals(30, config.jwtExpirationTime)
        // bundlesDir is optional and has no default.
        assertNull(config.bundlesDir)
    }

    @Test
    @DisplayName("YAML file is resolved from AIRFLOW_HOME")
    fun yamlFromAirflowHome(
        @TempDir tempDir: Path,
    ) {
        val yamlContent =
            """
            api_auth:
              jwt_secret: "home-secret"
            """.trimIndent()

        Files.writeString(tempDir.resolve("java-sdk.yaml"), yamlContent)

        // No yamlOverride: the file must be found at $AIRFLOW_HOME/java-sdk.yaml.
        val config = SdkConfig(env = mapOf("AIRFLOW_HOME" to tempDir.toString()))

        assertEquals("home-secret", config.jwtSecret)
    }
}
diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/CoordinatorCommTest.kt
b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/CoordinatorCommTest.kt new file mode 100644 index 0000000000000..455de10cc68e0 --- /dev/null +++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/CoordinatorCommTest.kt @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk + +import io.ktor.utils.io.ByteChannel +import io.ktor.utils.io.availableForRead +import io.ktor.utils.io.readAvailable +import kotlinx.coroutines.runBlocking +import org.apache.airflow.sdk.execution.CoordinatorComm +import org.apache.airflow.sdk.execution.DagFileParseRequest +import org.apache.airflow.sdk.execution.IncomingFrame +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test +import kotlin.text.split + +fun byteArrayFromHexString(hexString: String): ByteArray = + hexString + .split(' ', '\r', '\n') + .filter { it.isNotEmpty() } + .map { it.toUByte(16).toByte() } + .toByteArray() + +@OptIn(ExperimentalUnsignedTypes::class) +internal class CoordinatorCommTest { + lateinit var comm: CoordinatorComm + lateinit var reader: ByteChannel + lateinit var writer: ByteChannel + + @BeforeEach + fun setUp() { + reader = ByteChannel(autoFlush = true) + writer = ByteChannel(autoFlush = true) + comm = CoordinatorComm(Bundle("0", listOf(Dag("dag"))), reader, writer) + } + + @Test + @DisplayName("handleIncoming should produce parse result") + fun handleIncomingShouldProduceParseResult() { + val frame = IncomingFrame(0, DagFileParseRequest().apply { file = ":memory:" }) + + // prefix + DagFileParsingResult payload for a minimal DAG. 
+ + /* prefix + + [ + 0, + { + "type": "DagFileParsingResult", + "fileloc": ":memory:", + "serialized_dags": [ + { + "data": { + "__version": 3, + "dag": { + "dag_id": "dag", + "fileloc": ":memory:", + "relative_fileloc": ".", + "timezone": "UTC", + "timetable": { + "__type": "airflow.timetables.simple.NullTimetable", + "__var": {} + }, + "tasks": [] + } + } + } + ] + } + ] + */ + val expected = + byteArrayFromHexString( + """ + | 00 00 01 e6 + | 92 00 83 a4 74 79 70 65 b4 44 61 67 46 69 6c 65 50 61 72 73 69 6e 67 52 + | 65 73 75 6c 74 a7 66 69 6c 65 6c 6f 63 a8 3a 6d 65 6d 6f 72 79 3a af 73 + | 65 72 69 61 6c 69 7a 65 64 5f 64 61 67 73 91 81 a4 64 61 74 61 82 a9 5f + | 5f 76 65 72 73 69 6f 6e 03 a3 64 61 67 8c a6 64 61 67 5f 69 64 a3 64 61 + | 67 a7 66 69 6c 65 6c 6f 63 a8 3a 6d 65 6d 6f 72 79 3a b0 72 65 6c 61 74 + | 69 76 65 5f 66 69 6c 65 6c 6f 63 a1 2e a8 74 69 6d 65 7a 6f 6e 65 a3 55 + | 54 43 a9 74 69 6d 65 74 61 62 6c 65 82 a6 5f 5f 74 79 70 65 d9 27 61 69 + | 72 66 6c 6f 77 2e 74 69 6d 65 74 61 62 6c 65 73 2e 73 69 6d 70 6c 65 2e + | 4e 75 6c 6c 54 69 6d 65 74 61 62 6c 65 a5 5f 5f 76 61 72 80 a5 74 61 73 + | 6b 73 90 b0 64 61 67 5f 64 65 70 65 6e 64 65 6e 63 69 65 73 90 aa 74 61 + | 73 6b 5f 67 72 6f 75 70 8b a9 5f 67 72 6f 75 70 5f 69 64 c0 b2 67 72 6f + | 75 70 5f 64 69 73 70 6c 61 79 5f 6e 61 6d 65 a0 af 70 72 65 66 69 78 5f + | 67 72 6f 75 70 5f 69 64 c3 a7 74 6f 6f 6c 74 69 70 a0 a8 75 69 5f 63 6f + | 6c 6f 72 ae 43 6f 72 6e 66 6c 6f 77 65 72 42 6c 75 65 aa 75 69 5f 66 67 + | 63 6f 6c 6f 72 a4 23 30 30 30 a8 63 68 69 6c 64 72 65 6e 80 b2 75 70 73 + | 74 72 65 61 6d 5f 67 72 6f 75 70 5f 69 64 73 90 b4 64 6f 77 6e 73 74 72 + | 65 61 6d 5f 67 72 6f 75 70 5f 69 64 73 90 b1 75 70 73 74 72 65 61 6d 5f + | 74 61 73 6b 5f 69 64 73 90 b3 64 6f 77 6e 73 74 72 65 61 6d 5f 74 61 73 + | 6b 5f 69 64 73 90 a9 65 64 67 65 5f 69 6e 66 6f 80 a6 70 61 72 61 6d 73 + | 90 a8 64 65 61 64 6c 69 6e 65 c0 b1 61 6c 6c 6f 77 65 64 5f 72 75 6e 5f + | 74 79 70 65 73 c0 + 
""".trimMargin(), + ) + + val buffer = ByteArray(1024) { 0 } // Change ByteArray size if assertTrue below fails. + var count = 0 + runBlocking { + comm.handleIncoming(frame) + if (writer.availableForRead > 0) { + count = writer.readAvailable(buffer) + } + } + Assertions.assertTrue(count < buffer.size, "Please increase buffer size above") + + Assertions.assertEquals(expected.size, count) + + val received = buffer.sliceArray(0.until(count)) + Assertions.assertArrayEquals(expected, received) + } +} diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/CommsTest.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/CommsTest.kt new file mode 100644 index 0000000000000..3b91662751a5e --- /dev/null +++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/CommsTest.kt @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk.execution + +import org.apache.airflow.sdk.byteArrayFromHexString +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test +import java.time.OffsetDateTime +import java.time.ZoneOffset + +class CommsTest { + @Test + @DisplayName("Should decode startup details") + fun shouldDecodeStartupDetails() { + // [2, msg, null] with msg coming from + // https://github.com/astronomer/airflow/blob/f39c8da8/task-sdk/tests/task_sdk/execution_time/test_comms.py#L73-L108 + val data = + """ + 92 02 88 a4 74 79 70 65 ae 53 74 61 72 74 75 70 44 65 74 61 69 6c 73 a2 74 69 86 a2 69 64 d9 24 + 34 64 38 32 38 61 36 32 2d 61 34 31 37 2d 34 39 33 36 2d 61 37 61 36 2d 32 62 33 66 61 62 61 63 + 65 63 61 62 a7 74 61 73 6b 5f 69 64 a1 61 aa 74 72 79 5f 6e 75 6d 62 65 72 01 a6 72 75 6e 5f 69 + 64 a1 62 a6 64 61 67 5f 69 64 a1 63 ae 64 61 67 5f 76 65 72 73 69 6f 6e 5f 69 64 d9 24 34 64 38 + 32 38 61 36 32 2d 61 34 31 37 2d 34 39 33 36 2d 61 37 61 36 2d 32 62 33 66 61 62 61 63 65 63 61 + 62 aa 74 69 5f 63 6f 6e 74 65 78 74 85 a7 64 61 67 5f 72 75 6e 8c a6 64 61 67 5f 69 64 a1 63 a6 + 72 75 6e 5f 69 64 a1 62 ac 6c 6f 67 69 63 61 6c 5f 64 61 74 65 b4 32 30 32 34 2d 31 32 2d 30 31 + 54 30 31 3a 30 30 3a 30 30 5a b3 64 61 74 61 5f 69 6e 74 65 72 76 61 6c 5f 73 74 61 72 74 b4 32 + 30 32 34 2d 31 32 2d 30 31 54 30 30 3a 30 30 3a 30 30 5a b1 64 61 74 61 5f 69 6e 74 65 72 76 61 + 6c 5f 65 6e 64 b4 32 30 32 34 2d 31 32 2d 30 31 54 30 31 3a 30 30 3a 30 30 5a aa 73 74 61 72 74 + 5f 64 61 74 65 b4 32 30 32 34 2d 31 32 2d 30 31 54 30 31 3a 30 30 3a 30 30 5a a9 72 75 6e 5f 61 + 66 74 65 72 b4 32 30 32 34 2d 31 32 2d 30 31 54 30 31 3a 30 30 3a 30 30 5a a8 65 6e 64 5f 64 61 + 74 65 c0 a8 72 75 6e 5f 74 79 70 65 a6 6d 61 6e 75 61 6c a5 73 74 61 74 65 a7 73 75 63 63 65 73 + 73 a4 63 6f 6e 66 c0 b5 63 6f 6e 73 75 6d 65 64 5f 61 73 73 65 74 5f 65 76 65 6e 74 73 90 a9 6d + 61 78 5f 74 72 69 65 73 00 ac 73 68 6f 75 6c 64 
5f 72 65 74 72 79 c2 a9 76 61 72 69 61 62 6c 65 + 73 c0 ab 63 6f 6e 6e 65 63 74 69 6f 6e 73 c0 a4 66 69 6c 65 a9 2f 64 65 76 2f 6e 75 6c 6c aa 73 + 74 61 72 74 5f 64 61 74 65 b4 32 30 32 34 2d 31 32 2d 30 31 54 30 31 3a 30 30 3a 30 30 5a ac 64 + 61 67 5f 72 65 6c 5f 70 61 74 68 a9 2f 64 65 76 2f 6e 75 6c 6c ab 62 75 6e 64 6c 65 5f 69 6e 66 + 6f 82 a4 6e 61 6d 65 a8 61 6e 79 2d 6e 61 6d 65 a7 76 65 72 73 69 6f 6e ab 61 6e 79 2d 76 65 72 + 73 69 6f 6e b2 73 65 6e 74 72 79 5f 69 6e 74 65 67 72 61 74 69 6f 6e a0 c0 + """.trimIndent() + val result = CoordinatorComm.decode(byteArrayFromHexString(data)) + Assertions.assertInstanceOf(IncomingFrame::class.java, result) + Assertions.assertInstanceOf(StartupDetails::class.java, result.body) + } + + @Test + @DisplayName("Should serialize all fields") + fun shouldEncodeSucceedTask() { + val endDate = OffsetDateTime.of(2024, 12, 1, 1, 0, 0, 0, ZoneOffset.UTC) + val bytes = CoordinatorComm.encode(OutgoingFrame(3, SucceedTask(endDate))) + val actual = bytes.toHexString(HexFormat { bytes { byteSeparator = " " } }) + + val expected = + """ + 92 03 86 a5 73 74 61 74 65 a7 73 75 63 63 65 73 73 a8 65 6e 64 5f 64 61 74 65 b4 32 30 32 34 2d + 31 32 2d 30 31 54 30 31 3a 30 30 3a 30 30 5a ac 74 61 73 6b 5f 6f 75 74 6c 65 74 73 90 ad 6f 75 + 74 6c 65 74 5f 65 76 65 6e 74 73 90 b2 72 65 6e 64 65 72 65 64 5f 6d 61 70 5f 69 6e 64 65 78 c0 + a4 74 79 70 65 ab 53 75 63 63 65 65 64 54 61 73 6b + """.trimIndent().replace('\n', ' ') + + Assertions.assertEquals(expected, actual) + } + + @Test + @DisplayName("Should decode requests to the supervisor") + fun shouldDecodeSupervisorRequest() { + val result = TaskSdkFrames.decode(TaskSdkFrames.encodeRequest(5, GetVariable("demo")), TaskSdkFrames.toSupervisorTypes) + + Assertions.assertEquals(5, result.id) + Assertions.assertEquals(GetVariable("demo"), result.body) + } + + @Test + @DisplayName("Should decode protocol errors from the response error slot") + fun shouldDecodeErrorResponseFromErrorSlot() { + 
val error = + ErrorResponse().also { + it.error = "generic_error" + it.detail = mapOf("message" to "boom") + } + + val result = CoordinatorComm.decode(TaskSdkFrames.encodeResponse(7, error = error)) + + Assertions.assertEquals(7, result.id) + Assertions.assertInstanceOf(ErrorResponse::class.java, result.body) + Assertions.assertEquals("generic_error", (result.body as ErrorResponse).error) + } +} diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/DagParserTest.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/DagParserTest.kt new file mode 100644 index 0000000000000..5e2edb549c45c --- /dev/null +++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/DagParserTest.kt @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk.execution + +import org.apache.airflow.sdk.Bundle +import org.apache.airflow.sdk.Dag +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test + +internal class DagParserTest { + lateinit var parser: DagParser + + @BeforeEach + fun setUp() { + parser = DagParser(":memory:", "") + } + + @Test + @DisplayName("Should produce serialized dag") + fun shouldProduceSerializedDag() { + val bundle = Bundle("0", listOf(Dag("dag"))) + val result = parser.parse(bundle) + Assertions.assertEquals( + DagParsingResult(":memory:", "", bundle.dags), + result, + ) + } +} diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/SerializationCompatibilityTest.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/SerializationCompatibilityTest.kt new file mode 100644 index 0000000000000..74307445ca6ee --- /dev/null +++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/SerializationCompatibilityTest.kt @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.airflow.sdk.execution + +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.SerializationFeature +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory +import org.apache.airflow.sdk.Dag +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.DynamicTest +import org.junit.jupiter.api.TestFactory +import java.io.File + +/** + * Reads test_dags.yaml, constructs Dags from the parameters, serializes each + * one with the Java SDK, and writes the result to serialized_java.json for + * cross-language comparison with the Python output. + * + * Each YAML test-case is turned into a JUnit 5 dynamic test, so failures are + * reported individually. + * + * After running: + * python validation/serialization/compare.py \ + * validation/serialization/serialized_python.json \ + * validation/serialization/serialized_java.json + */ +class SerializationCompatibilityTest { + companion object { + private val yamlMapper = ObjectMapper(YAMLFactory()) + private val jsonMapper = + ObjectMapper().apply { + enable(SerializationFeature.INDENT_OUTPUT) + configure(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS, true) + } + + /** Resolve a project-relative path that works from any Gradle working dir. 
*/ + private fun projectFile(relative: String): File { + // Gradle may run from the repo root or from sdk/ + var dir = File(System.getProperty("user.dir")) + while (dir.parentFile != null) { + val candidate = File(dir, relative) + if (candidate.exists()) return candidate + dir = dir.parentFile + } + // Fallback: try relative to cwd + return File(relative) + } + } + + // ----------------------------------------------------------------------- + // YAML → Dag construction + // ----------------------------------------------------------------------- + + @Suppress("UNCHECKED_CAST") + private fun constructDag(params: Map): Dag = Dag(params["dag_id"] as String) + + // ----------------------------------------------------------------------- + // Dynamic test generation + // ----------------------------------------------------------------------- + + @Suppress("UNCHECKED_CAST") + @TestFactory + fun `serialise all YAML test cases`(): List { + val yamlFile = projectFile("validation/serialization/test_dags.yaml") + if (!yamlFile.exists()) { + return listOf( + DynamicTest.dynamicTest("test_dags.yaml not found — skipping") { + println("WARNING: ${yamlFile.absolutePath} not found, skipping serialisation tests") + }, + ) + } + + val root = yamlMapper.readValue(yamlFile, Map::class.java) as Map + val testCases = root["test_cases"] as List> + + // Accumulate results for JSON output + val allResults = mutableMapOf() + + val tests = + testCases.map { case -> + val name = case["name"] as String + val params = case["params"] as Map + + DynamicTest.dynamicTest(name) { + val dag = constructDag(params) + val serialized = serializeDag(dag) + + // Basic assertions + assertNotNull(serialized["dag_id"], "dag_id must be present") + assertNotNull(serialized["timetable"], "timetable must be present") + assertNotNull(serialized["tasks"], "tasks must be present") + assertFalse( + serialized.containsKey("__error"), + "serialisation must not produce an error entry", + ) + + allResults[name] = serialized 
+ } + } + + // After all dynamic tests, write the combined JSON. + // We add a final "meta" test that writes the file. + val writeTest = + DynamicTest.dynamicTest("_write_serialized_java_json") { + val outputDir = projectFile("validation/serialization") + outputDir.mkdirs() + val outputFile = File(outputDir, "serialized_java.json") + jsonMapper.writeValue(outputFile, allResults.toSortedMap()) + println("Wrote ${allResults.size} serialised DAGs -> ${outputFile.absolutePath}") + } + + return tests + writeTest + } +} diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/SupervisorTest.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/SupervisorTest.kt new file mode 100644 index 0000000000000..6d0551b0e0b83 --- /dev/null +++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/SupervisorTest.kt @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.airflow.sdk.execution + +import kotlinx.coroutines.runBlocking +import okhttp3.mockwebserver.Dispatcher +import okhttp3.mockwebserver.MockResponse +import okhttp3.mockwebserver.MockWebServer +import okhttp3.mockwebserver.RecordedRequest +import org.apache.airflow.sdk.execution.api.model.TaskInstanceState +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test +import java.io.ByteArrayInputStream +import java.util.UUID +import java.util.concurrent.CopyOnWriteArrayList + +class SupervisorTest { + // streamLines() tests + + @Test + @DisplayName("streamLines: empty stream produces no callbacks") + fun streamLinesEmptyStream() = + runBlocking { + val lines = mutableListOf() + Supervisor.streamLines(ByteArrayInputStream(ByteArray(0))) { lines.add(it) } + assertTrue(lines.isEmpty()) + } + + @Test + @DisplayName("streamLines: single line") + fun streamLinesSingleLine() = + runBlocking { + val lines = mutableListOf() + Supervisor.streamLines(ByteArrayInputStream("hello\n".toByteArray())) { lines.add(it) } + assertEquals(listOf("hello"), lines) + } + + @Test + @DisplayName("streamLines: multiple lines") + fun streamLinesMultipleLines() = + runBlocking { + val input = "line1\nline2\nline3\n".toByteArray() + val lines = mutableListOf() + Supervisor.streamLines(ByteArrayInputStream(input)) { lines.add(it) } + assertEquals(listOf("line1", "line2", "line3"), lines) + } + + @Test + @DisplayName("streamLines: preserves blank lines between content") + fun streamLinesWithBlankLines() = + runBlocking { + val input = "first\n\nsecond\n".toByteArray() + val lines = mutableListOf() + Supervisor.streamLines(ByteArrayInputStream(input)) { lines.add(it) } + assertEquals(listOf("first", "", "second"), lines) 
+ } + + @Test + @DisplayName("streamLines: handles line without trailing newline") + fun streamLinesNoTrailingNewline() = + runBlocking { + val lines = mutableListOf() + Supervisor.streamLines(ByteArrayInputStream("no-newline".toByteArray())) { lines.add(it) } + assertEquals(listOf("no-newline"), lines) + } + + @Test + @DisplayName("streamLines: handles large number of lines") + fun streamLinesManyLines() = + runBlocking { + val count = 10_000 + val input = (1..count).joinToString("\n") { "line-$it" }.toByteArray() + val lines = CopyOnWriteArrayList() + Supervisor.streamLines(ByteArrayInputStream(input)) { lines.add(it) } + assertEquals(count, lines.size) + assertEquals("line-1", lines.first()) + assertEquals("line-$count", lines.last()) + } + + // Data class tests + + @Test + @DisplayName("SupervisorTaskInstance: all fields populated") + fun supervisorTaskInstanceAllFields() { + val id = UUID.randomUUID() + val dagVersionId = UUID.randomUUID() + val carrier = mapOf("trace" to "abc") + val ti = + SupervisorTaskInstance( + id = id, + taskId = "my_task", + dagId = "my_dag", + runId = "run_1", + tryNumber = 2, + dagVersionId = dagVersionId, + mapIndex = 5, + contextCarrier = carrier, + ) + assertEquals(id, ti.id) + assertEquals("my_task", ti.taskId) + assertEquals("my_dag", ti.dagId) + assertEquals("run_1", ti.runId) + assertEquals(2, ti.tryNumber) + assertEquals(dagVersionId, ti.dagVersionId) + assertEquals(5, ti.mapIndex) + assertEquals(carrier, ti.contextCarrier) + } + + @Test + @DisplayName("SupervisorTaskInstance: null optional fields") + fun supervisorTaskInstanceNullOptionals() { + val ti = + SupervisorTaskInstance( + id = UUID.randomUUID(), + taskId = "t", + dagId = "d", + runId = "r", + tryNumber = 1, + dagVersionId = UUID.randomUUID(), + mapIndex = null, + ) + assertEquals(null, ti.mapIndex) + assertEquals(null, ti.contextCarrier) + } + + @Test + @DisplayName("SupervisorTaskInstance: data class equality") + fun supervisorTaskInstanceEquality() { + val id = 
UUID.randomUUID() + val dvId = UUID.randomUUID() + val a = SupervisorTaskInstance(id, "t", "d", "r", 1, dvId, null) + val b = SupervisorTaskInstance(id, "t", "d", "r", 1, dvId, null) + assertEquals(a, b) + assertEquals(a.hashCode(), b.hashCode()) + } + + @Test + @DisplayName("SupervisorBundleInfo: with and without version") + fun supervisorBundleInfo() { + val withVersion = SupervisorBundleInfo("my-bundle", "v2") + assertEquals("my-bundle", withVersion.name) + assertEquals("v2", withVersion.version) + + val withoutVersion = SupervisorBundleInfo("my-bundle", null) + assertEquals(null, withoutVersion.version) + } + + @Test + @DisplayName("SupervisorResult: success and failure states") + fun supervisorResult() { + val success = SupervisorResult(TaskInstanceState.SUCCESS, 0) + assertEquals(TaskInstanceState.SUCCESS, success.finalState) + assertEquals(0, success.exitCode) + + val failure = SupervisorResult(TaskInstanceState.FAILED, 1) + assertEquals(TaskInstanceState.FAILED, failure.finalState) + assertEquals(1, failure.exitCode) + } + + @Test + @DisplayName("SupervisorRequest: default values") + fun supervisorRequestDefaults() { + val request = + SupervisorRequest( + mainClass = "com.example.Main", + classpath = "/app/lib/*", + executionApiBaseUrl = "http://localhost:8080/execution/", + token = "test-token", + workerName = "worker-1", + userName = "airflow", + dagRelPath = "dags/my_dag.jar", + bundleInfo = SupervisorBundleInfo("bundle", "1"), + taskInstance = + SupervisorTaskInstance( + UUID.randomUUID(), + "task", + "dag", + "run", + 1, + UUID.randomUUID(), + null, + ), + ) + assertEquals("", request.sentryIntegration) + } + + // Integration tests: Supervisor.run() with real subprocess + MockWebServer + + private lateinit var mockServer: MockWebServer + + @BeforeEach + fun setUp() { + mockServer = MockWebServer() + } + + @AfterEach + fun tearDown() { + mockServer.shutdown() + } + + /** Minimal valid JSON for TIRunContext that Jackson can deserialize with unknown-props 
disabled. */ + private val tiRunContextJson = + """ + { + "dag_run": { + "dag_id": "test_dag", + "run_id": "run_1", + "logical_date": "2026-01-01T00:00:00Z", + "data_interval_start": "2026-01-01T00:00:00Z", + "data_interval_end": "2026-01-01T01:00:00Z", + "start_date": "2026-01-01T00:00:00Z", + "run_after": "2026-01-01T00:00:00Z", + "run_type": "manual" + }, + "max_tries": 0, + "should_retry": false + } + """.trimIndent() + + private fun request(mainClass: String = TestSucceedSubprocess::class.java.name): SupervisorRequest { + val classpath = System.getProperty("java.class.path") + return SupervisorRequest( + mainClass = mainClass, + classpath = classpath, + executionApiBaseUrl = mockServer.url("/execution/").toString(), + token = "test-jwt-token", + workerName = "test-worker", + userName = "testuser", + dagRelPath = "dags/test.jar", + bundleInfo = SupervisorBundleInfo("test-bundle", "1"), + taskInstance = + SupervisorTaskInstance( + id = UUID.randomUUID(), + taskId = "my_task", + dagId = "test_dag", + runId = "run_1", + tryNumber = 1, + dagVersionId = UUID.randomUUID(), + mapIndex = null, + ), + sentryIntegration = "", + onLogLine = {}, + ) + } + + /** + * A dispatcher that returns a TIRunContext for the /run endpoint and 200 OK for state updates. + * Also handles variable/connection/xcom API calls for the more complex test scenarios. 
+ */ + private fun apiDispatcher(): Dispatcher = + object : Dispatcher() { + override fun dispatch(request: RecordedRequest): MockResponse { + val path = request.path ?: return MockResponse().setResponseCode(404) + return when { + // tiRun: PATCH .../task-instances/{id}/run + path.contains("/run") && request.method == "PATCH" -> + MockResponse() + .setResponseCode(200) + .setHeader("Content-Type", "application/json") + .setBody(tiRunContextJson) + + // succeed/finish: PATCH .../task-instances/{id}/state + path.contains("/state") && request.method == "PATCH" -> + MockResponse().setResponseCode(200) + + // getVariable: GET .../variables/{key} + path.contains("/variables/") && request.method == "GET" -> + MockResponse() + .setResponseCode(200) + .setHeader("Content-Type", "application/json") + .setBody("""{"key": "test_var", "value": "hello"}""") + + // getConnection: GET .../connections/{id} + path.contains("/connections/") && request.method == "GET" -> + MockResponse() + .setResponseCode(200) + .setHeader("Content-Type", "application/json") + .setBody("""{"conn_id": "test_conn", "conn_type": "http"}""") + + // setXcom: POST .../xcoms/... 
+ path.contains("/xcoms/") && request.method == "POST" -> + MockResponse() + .setResponseCode(200) + .setHeader("Content-Type", "application/json") + .setBody("{}") + + else -> + MockResponse().setResponseCode(404).setBody("Not found: $path") + } + } + } + + @Test + @DisplayName("run: successful task execution returns SUCCESS with exit code 0") + fun runSuccessfulTask() = + runBlocking { + mockServer.dispatcher = apiDispatcher() + mockServer.start() + + val result = Supervisor.run(request()) + + assertEquals(TaskInstanceState.SUCCESS, result.finalState) + assertEquals(0, result.exitCode) + } + + @Test + @DisplayName("run: task reporting failed state returns FAILED") + fun runFailedTask() = + runBlocking { + mockServer.dispatcher = apiDispatcher() + mockServer.start() + + val result = Supervisor.run(request(mainClass = TestFailSubprocess::class.java.name)) + + assertEquals(TaskInstanceState.FAILED, result.finalState) + assertEquals(0, result.exitCode) // process exits cleanly, but reports failed state + } + + @Test + @DisplayName("run: task requesting a variable before succeeding") + fun runTaskWithGetVariable() = + runBlocking { + mockServer.dispatcher = apiDispatcher() + mockServer.start() + + val result = Supervisor.run(request(mainClass = TestGetVariableSubprocess::class.java.name)) + + assertEquals(TaskInstanceState.SUCCESS, result.finalState) + assertEquals(0, result.exitCode) + + // Verify the variable request was made to the mock server. 
+ val requests = (1..mockServer.requestCount).map { mockServer.takeRequest() } + assertTrue(requests.any { it.path?.contains("/variables/") == true }) + } + + @Test + @DisplayName("run: task requesting a connection before succeeding") + fun runTaskWithGetConnection() = + runBlocking { + mockServer.dispatcher = apiDispatcher() + mockServer.start() + + val result = Supervisor.run(request(mainClass = TestGetConnectionSubprocess::class.java.name)) + + assertEquals(TaskInstanceState.SUCCESS, result.finalState) + assertEquals(0, result.exitCode) + + val requests = (1..mockServer.requestCount).map { mockServer.takeRequest() } + assertTrue(requests.any { it.path?.contains("/connections/") == true }) + } + + @Test + @DisplayName("run: reports task as running to execution API with correct payload") + fun runReportsRunningState() = + runBlocking { + mockServer.dispatcher = apiDispatcher() + mockServer.start() + + Supervisor.run(request()) + + val requests = (1..mockServer.requestCount).map { mockServer.takeRequest() } + val runRequest = requests.first { it.path?.contains("/run") == true && it.method == "PATCH" } + assertEquals("PATCH", runRequest.method) + val body = runRequest.body.readUtf8() + assertTrue(body.contains("test-worker"), "Should contain hostname") + assertTrue(body.contains("testuser"), "Should contain unix name") + } + + @Test + @DisplayName("run: reports terminal state to execution API") + fun runReportsTerminalState() = + runBlocking { + mockServer.dispatcher = apiDispatcher() + mockServer.start() + + Supervisor.run(request()) + + val requests = (1..mockServer.requestCount).map { mockServer.takeRequest() } + val stateRequest = requests.first { it.path?.contains("/state") == true } + assertEquals("PATCH", stateRequest.method) + } + + @Test + @DisplayName("run: sends bearer token in all API requests") + fun runSendsBearerToken() = + runBlocking { + mockServer.dispatcher = apiDispatcher() + mockServer.start() + + Supervisor.run(request()) + + val requests = 
(1..mockServer.requestCount).map { mockServer.takeRequest() }
+      for (req in requests) {
+        val auth = req.getHeader("Authorization")
+        assertNotNull(auth, "Authorization header should be present on ${req.path}")
+        assertTrue(auth!!.startsWith("Bearer "), "Should use Bearer auth on ${req.path}")
+      }
+    }
+
+  @Test
+  @DisplayName("run: collects stdout and stderr from subprocess")
+  fun runCollectsLogLines() =
+    runBlocking {
+      mockServer.dispatcher = apiDispatcher()
+      mockServer.start()
+
+      val logLines = CopyOnWriteArrayList<String>()
+      val req =
+        request(mainClass = TestStdoutSubprocess::class.java.name).copy(
+          onLogLine = { logLines.add(it) },
+        )
+
+      Supervisor.run(req)
+
+      assertTrue(logLines.any { it == "stdout-line-1" }, "Should capture stdout: $logLines")
+      assertTrue(logLines.any { it == "stdout-line-2" }, "Should capture stdout: $logLines")
+      assertTrue(logLines.any { it == "stderr-line-1" }, "Should capture stderr: $logLines")
+    }
+
+  @Test
+  @DisplayName("run: non-zero exit code overrides final state to FAILED")
+  fun runNonZeroExitCodeOverridesState() =
+    runBlocking {
+      mockServer.dispatcher = apiDispatcher()
+      mockServer.start()
+
+      // TestSucceedThenCrashSubprocess sends SucceedTask (which would normally yield SUCCESS)
+      // but then exits with code 42. Supervisor should override the final state to FAILED.
+      val result = Supervisor.run(request(mainClass = TestSucceedThenCrashSubprocess::class.java.name))
+
+      assertEquals(TaskInstanceState.FAILED, result.finalState)
+      assertEquals(42, result.exitCode)
+    }
+}
diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/TaskRunnerTest.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/TaskRunnerTest.kt
new file mode 100644
index 0000000000000..0f2bb53bc8d78
--- /dev/null
+++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/TaskRunnerTest.kt
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.airflow.sdk.execution + +import org.apache.airflow.sdk.Bundle +import org.apache.airflow.sdk.Client +import org.apache.airflow.sdk.Context +import org.apache.airflow.sdk.Dag +import org.apache.airflow.sdk.Task +import org.apache.airflow.sdk.execution.api.model.BundleInfo +import org.apache.airflow.sdk.execution.api.model.TIRunContext +import org.apache.airflow.sdk.execution.api.model.TaskInstance +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.DisplayName +import org.junit.jupiter.api.Test +import java.time.OffsetDateTime +import java.util.UUID + +class TaskRunnerTest { + @Test + @DisplayName("Should execute task and return success") + fun shouldExecuteTaskAndReturnSuccess() { + val result = TaskRunner.run(bundleWith("success", SuccessTask::class.java), startupDetails(taskId = "success"), noOpClient()) + + Assertions.assertInstanceOf(SucceedTask::class.java, result) + } + + @Test + @DisplayName("Should return removed when task is missing") + fun shouldReturnRemovedWhenTaskIsMissing() { + val result = TaskRunner.run(bundleWith("other", SuccessTask::class.java), startupDetails(taskId = "missing"), noOpClient()) + + Assertions.assertInstanceOf(TaskState::class.java, result) + Assertions.assertEquals("removed", (result as 
TaskState).state)
+  }
+
+  @Test
+  @DisplayName("Should return failed when task throws")
+  fun shouldReturnFailedWhenTaskThrows() {
+    val result = TaskRunner.run(bundleWith("failing", FailingTask::class.java), startupDetails(taskId = "failing"), noOpClient())
+
+    Assertions.assertInstanceOf(TaskState::class.java, result)
+    Assertions.assertEquals("failed", (result as TaskState).state)
+  }
+
+  private fun bundleWith(
+    taskId: String,
+    taskClass: Class<out Task>,
+  ): Bundle {
+    val dag = Dag("test_dag")
+    dag.addTask(taskId, taskClass)
+    return Bundle("1", listOf(dag))
+  }
+
+  private fun startupDetails(taskId: String): StartupDetails =
+    StartupDetails().also {
+      it.ti =
+        TaskInstance().also { taskInstance ->
+          taskInstance.id = UUID.randomUUID()
+          taskInstance.taskId = taskId
+          taskInstance.dagId = "test_dag"
+          taskInstance.runId = "manual__2026-03-31T00:00:00+00:00"
+          taskInstance.tryNumber = 1
+          taskInstance.dagVersionId = UUID.randomUUID()
+        }
+      it.dagRelPath = "/dev/null"
+      it.bundleInfo =
+        BundleInfo().also { info ->
+          info.name = "bundle"
+          info.version = "1"
+        }
+      it.startDate = OffsetDateTime.parse("2026-03-31T00:00:00Z")
+      it.tiContext = TIRunContext()
+      it.sentryIntegration = ""
+    }
+
+  private fun noOpClient() =
+    Client(
+      startupDetails(taskId = "unused"),
+      object : org.apache.airflow.sdk.execution.Client {
+        override fun getConnection(id: String) = throw UnsupportedOperationException("not used in test")
+
+        override fun getVariable(key: String) = throw UnsupportedOperationException("not used in test")
+
+        override fun getXCom(
+          key: String,
+          dagId: String,
+          taskId: String,
+          runId: String,
+          mapIndex: Int?,
+          includePriorDates: Boolean,
+        ) = throw UnsupportedOperationException("not used in test")
+
+        override fun setXCom(
+          key: String,
+          value: Any,
+          dagId: String,
+          taskId: String,
+          runId: String,
+          mapIndex: Int,
+        ): Unit = throw UnsupportedOperationException("not used in test")
+      },
+    )
+
+  class SuccessTask : Task {
+    override fun execute(
context: Context, + client: Client, + ) { + } + } + + class FailingTask : Task { + override fun execute( + context: Context, + client: Client, + ): Unit = throw IllegalStateException("boom") + } +} diff --git a/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/TestTaskSubprocess.kt b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/TestTaskSubprocess.kt new file mode 100644 index 0000000000000..0bd392279fc8c --- /dev/null +++ b/java-sdk/sdk/src/test/kotlin/org/apache/airflow/sdk/execution/TestTaskSubprocess.kt @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.airflow.sdk.execution + +import java.io.InputStream +import java.io.OutputStream +import java.net.Socket + +/** + * Family of minimal subprocesses that simulate a Java task bundle process for integration testing + * of [Supervisor]. Each object has a `main` method that can be spawned by [Supervisor.run]. + * + * Protocol: connect to comm + logs sockets (via `--comm` / `--logs` CLI args), read the + * [StartupDetails] frame from the supervisor, perform a behavior-specific action, then exit. 
+ */
+private fun connectAndProcess(
+  args: Array<String>,
+  onFrame: (InputStream, OutputStream, IncomingFrame) -> Unit,
+) {
+  val commAddr = args.first { it.startsWith("--comm=") }.removePrefix("--comm=")
+  val logsAddr = args.first { it.startsWith("--logs=") }.removePrefix("--logs=")
+  val (commHost, commPort) = commAddr.split(":")
+  val (logsHost, logsPort) = logsAddr.split(":")
+
+  val commSocket = Socket(commHost, commPort.toInt())
+  val logsSocket = Socket(logsHost, logsPort.toInt())
+  try {
+    val commIn = commSocket.getInputStream()
+    val commOut = commSocket.getOutputStream()
+    val frame = TaskSdkFrames.readFrame(commIn, TaskSdkFrames.toTaskTypes)
+    onFrame(commIn, commOut, frame)
+  } finally {
+    commSocket.close()
+    logsSocket.close()
+  }
+}
+
+/** Reads StartupDetails and immediately sends [SucceedTask]. */
+object TestSucceedSubprocess {
+  @JvmStatic
+  fun main(args: Array<String>) =
+    connectAndProcess(args) { _, output, frame ->
+      TaskSdkFrames.writeRequest(output, frame.id, SucceedTask())
+    }
+}
+
+/** Reads StartupDetails and sends [TaskState] with state=failed. */
+object TestFailSubprocess {
+  @JvmStatic
+  fun main(args: Array<String>) =
+    connectAndProcess(args) { _, output, frame ->
+      TaskSdkFrames.writeRequest(output, frame.id, TaskState(state = "failed"))
+    }
+}
+
+/** Sends a [GetVariable] request, reads the response, then sends [SucceedTask]. */
+object TestGetVariableSubprocess {
+  @JvmStatic
+  fun main(args: Array<String>) =
+    connectAndProcess(args) { input, output, frame ->
+      TaskSdkFrames.writeRequest(output, 10, GetVariable("test_var"))
+      TaskSdkFrames.readFrame(input, TaskSdkFrames.toBundleProcessTypes)
+      TaskSdkFrames.writeRequest(output, frame.id, SucceedTask())
+    }
+}
+
+/** Sends a [GetConnection] request, reads the response, then sends [SucceedTask].
*/
+object TestGetConnectionSubprocess {
+  @JvmStatic
+  fun main(args: Array<String>) =
+    connectAndProcess(args) { input, output, frame ->
+      TaskSdkFrames.writeRequest(output, 10, GetConnection("test_conn"))
+      TaskSdkFrames.readFrame(input, TaskSdkFrames.toBundleProcessTypes)
+      TaskSdkFrames.writeRequest(output, frame.id, SucceedTask())
+    }
+}
+
+/** Writes a message to stdout before succeeding, to verify log collection. */
+object TestStdoutSubprocess {
+  @JvmStatic
+  fun main(args: Array<String>) {
+    println("stdout-line-1")
+    println("stdout-line-2")
+    System.err.println("stderr-line-1")
+    connectAndProcess(args) { _, output, frame ->
+      TaskSdkFrames.writeRequest(output, frame.id, SucceedTask())
+    }
+  }
+}
+
+/** Sends [SucceedTask] but exits with non-zero code — tests exit-code override logic. */
+object TestSucceedThenCrashSubprocess {
+  @JvmStatic
+  fun main(args: Array<String>) {
+    connectAndProcess(args) { _, output, frame ->
+      TaskSdkFrames.writeRequest(output, frame.id, SucceedTask())
+    }
+    // Force non-zero exit after the protocol completes cleanly.
+    Runtime.getRuntime().halt(42)
+  }
+}
diff --git a/java-sdk/settings.gradle.kts b/java-sdk/settings.gradle.kts
new file mode 100644
index 0000000000000..0892437c13d3e
--- /dev/null
+++ b/java-sdk/settings.gradle.kts
@@ -0,0 +1,13 @@
+/*
+ * This file was generated by the Gradle 'init' task.
+ *
+ * The settings file is used to specify which projects to include in your build.
+ * For more detailed information on multi-project builds, please refer to https://docs.gradle.org/9.2.1/userguide/multi_project_builds.html in the Gradle documentation.
+ */ + +plugins { + id("org.gradle.toolchains.foojay-resolver-convention").version("0.10.0") +} + +rootProject.name = "airflow-java-sdk" +include("example", "sdk") diff --git a/java-sdk/validation/serialization/compare.py b/java-sdk/validation/serialization/compare.py new file mode 100644 index 0000000000000..464f247852c40 --- /dev/null +++ b/java-sdk/validation/serialization/compare.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Compare serialised DAG output from Python and Java SDKs. + +Usage: + python compare.py serialized_python.json serialized_java.json +""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path + +# Fields that are inherently environment-specific and should be ignored. +IGNORED_FIELDS = frozenset( + { + "fileloc", + "relative_fileloc", + "_processor_dags_folder", + } +) + +# Floating-point tolerance for timestamp / duration comparisons. 
+FLOAT_TOLERANCE = 1e-6 + + +# --------------------------------------------------------------------------- +# Normalisation +# --------------------------------------------------------------------------- + + +def _normalise(obj, *, _depth: int = 0): + """Recursively normalise an object for comparison.""" + if isinstance(obj, dict): + return { + k: _normalise(v, _depth=_depth + 1) for k, v in sorted(obj.items()) if k not in IGNORED_FIELDS + } + if isinstance(obj, list): + return [_normalise(item, _depth=_depth + 1) for item in obj] + if isinstance(obj, float): + return round(obj, 6) + return obj + + +# --------------------------------------------------------------------------- +# Deep diff +# --------------------------------------------------------------------------- + + +def _deep_diff(python_obj, java_obj, path: str = "") -> list[str]: + """Return a list of human-readable difference descriptions.""" + diffs: list[str] = [] + + if type(python_obj) is not type(java_obj): + # Allow int ↔ float (e.g. 
0 vs 0.0)
+        if isinstance(python_obj, (int, float)) and isinstance(java_obj, (int, float)):
+            if abs(float(python_obj) - float(java_obj)) > FLOAT_TOLERANCE:
+                diffs.append(f"{path}: {python_obj!r} != {java_obj!r}")
+            return diffs
+        diffs.append(
+            f"{path}: type mismatch — Python {type(python_obj).__name__}"
+            f" vs Java {type(java_obj).__name__}"
+            f" (py={python_obj!r}, java={java_obj!r})"
+        )
+        return diffs
+
+    if isinstance(python_obj, dict):
+        all_keys = set(python_obj) | set(java_obj)
+        for key in sorted(all_keys):
+            child_path = f"{path}.{key}" if path else key
+            if key not in java_obj:
+                diffs.append(f"{child_path}: present in Python but missing in Java")
+            elif key not in python_obj:
+                diffs.append(f"{child_path}: present in Java but missing in Python")
+            else:
+                diffs.extend(_deep_diff(python_obj[key], java_obj[key], child_path))
+    elif isinstance(python_obj, list):
+        if len(python_obj) != len(java_obj):
+            diffs.append(f"{path}: list length — Python {len(python_obj)} vs Java {len(java_obj)}")
+        for i, (p, j) in enumerate(zip(python_obj, java_obj)):
+            diffs.extend(_deep_diff(p, j, f"{path}[{i}]"))
+    elif isinstance(python_obj, float):
+        if abs(python_obj - java_obj) > FLOAT_TOLERANCE:
+            diffs.append(f"{path}: {python_obj!r} != {java_obj!r}")
+    elif python_obj != java_obj:
+        diffs.append(f"{path}: {python_obj!r} != {java_obj!r}")
+
+    return diffs
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+
+def main() -> None:
+    if len(sys.argv) != 3:
+        print(f"Usage: {sys.argv[0]} <serialized_python.json> <serialized_java.json>")
+        sys.exit(1)
+
+    python_path = Path(sys.argv[1])
+    java_path = Path(sys.argv[2])
+
+    with open(python_path) as fh:
+        python_data: dict = json.load(fh)
+    with open(java_path) as fh:
+        java_data: dict = json.load(fh)
+
+    all_names = sorted(set(python_data) | set(java_data))
+    total = len(all_names)
+    passed = 0
+    failed = 0
+
+    for name in all_names:
+        if name not in
python_data: + print(f"SKIP {name} (missing in Python output)") + continue + if name not in java_data: + print(f"SKIP {name} (missing in Java output)") + continue + + py_dag = python_data[name] + jv_dag = java_data[name] + + # Skip error entries + if isinstance(py_dag, dict) and "__error" in py_dag: + print(f"SKIP {name} (Python error: {py_dag['__error']})") + continue + if isinstance(jv_dag, dict) and "__error" in jv_dag: + print(f"SKIP {name} (Java error: {jv_dag['__error']})") + continue + + py_norm = _normalise(py_dag) + jv_norm = _normalise(jv_dag) + + diffs = _deep_diff(py_norm, jv_norm) + if diffs: + failed += 1 + print(f"FAIL {name}") + for d in diffs: + print(f" {d}") + else: + passed += 1 + print(f"PASS {name}") + + print(f"\n{'=' * 60}") + print(f"Total: {total} | Passed: {passed} | Failed: {failed}") + if failed: + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/java-sdk/validation/serialization/serialize_python.py b/java-sdk/validation/serialization/serialize_python.py new file mode 100644 index 0000000000000..7cc6ca00eb3d4 --- /dev/null +++ b/java-sdk/validation/serialization/serialize_python.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +Serialize DAGs using the Python Airflow SDK for cross-language comparison. + +Prerequisites: + - Airflow core and task-sdk installed (pip install -e …) + - PyYAML installed (pip install pyyaml) + +Usage: + python serialize_python.py test_dags.yaml serialized_python.json +""" + +from __future__ import annotations + +import json +import sys +from datetime import datetime, timedelta +from pathlib import Path + +import yaml + +# --------------------------------------------------------------------------- +# YAML params → Python DAG constructor kwargs +# --------------------------------------------------------------------------- + + +def _yaml_params_to_dag_kwargs(params: dict) -> dict: + """Convert language-agnostic YAML params to Python DAG constructor kwargs.""" + kwargs: dict = {} + for key, value in params.items(): + if key in ("start_date", "end_date") and isinstance(value, str): + kwargs[key] = datetime.fromisoformat(value) + elif key == "dagrun_timeout_seconds": + kwargs["dagrun_timeout"] = timedelta(seconds=value) + elif key == "tags" and isinstance(value, list): + kwargs["tags"] = set(value) + elif key == "access_control" and isinstance(value, dict): + # Convert innermost lists → sets (permissions) + kwargs["access_control"] = { + role: { + resource: set(perms) if isinstance(perms, list) else perms + for resource, perms in resources.items() + } + for role, resources in value.items() + } + elif key == "params": + kwargs["params"] = value + else: + kwargs[key] = value + return kwargs + + +# --------------------------------------------------------------------------- +# JSON helper +# --------------------------------------------------------------------------- + + +def _make_json_safe(obj): + """Handle types that json.dumps cannot serialise natively.""" + if isinstance(obj, (set, frozenset)): + return sorted(obj) + if isinstance(obj, bytes): + return obj.decode("utf-8") + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + +# 
---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+
+def main() -> None:
+    if len(sys.argv) != 3:
+        print(f"Usage: {sys.argv[0]} <test_dags.yaml> <output.json>")
+        sys.exit(1)
+
+    yaml_path = Path(sys.argv[1])
+    output_path = Path(sys.argv[2])
+
+    with open(yaml_path) as fh:
+        test_data = yaml.safe_load(fh)
+
+    # Lazy-import Airflow so the script fails fast on missing args first.
+    from airflow.sdk import DAG
+    from airflow.serialization.serialized_objects import DagSerialization
+
+    results: dict[str, dict] = {}
+    for case in test_data["test_cases"]:
+        name = case["name"]
+        kwargs = _yaml_params_to_dag_kwargs(case["params"])
+        print(f"  [{name}] ", end="")
+        try:
+            dag = DAG(**kwargs)
+            serialized = DagSerialization.serialize_dag(dag)
+            results[name] = serialized
+            print("OK")
+        except Exception as exc:
+            print(f"ERROR: {exc}")
+            results[name] = {"__error": str(exc)}
+
+    with open(output_path, "w") as fh:
+        json.dump(results, fh, indent=2, sort_keys=True, default=_make_json_safe)
+
+    print(f"\nWrote {len(results)} serialised DAGs → {output_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/java-sdk/validation/serialization/test_dags.yaml b/java-sdk/validation/serialization/test_dags.yaml
new file mode 100644
index 0000000000000..99d3b6a8e1374
--- /dev/null
+++ b/java-sdk/validation/serialization/test_dags.yaml
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Cross-language DAG serialization test cases. +# +# Each entry defines DAG constructor parameters in a language-agnostic way. +# Both the Python and Java serialization scripts read this file, construct +# DAGs from the params, and write the serialized output for comparison. +# +# Type conventions: +# start_date / end_date — ISO-8601 string, parsed to datetime / Instant +# dagrun_timeout_seconds — number of seconds, parsed to timedelta / Duration +# tags — list of strings, converted to set on both sides +# access_control — nested map; innermost lists become sets +--- +test_cases: + # ---- schedule variants ------------------------------------------------ + - name: "minimal_dag" + params: + dag_id: "example_dag" + + - name: "schedule_daily" + params: + dag_id: "example_dag" + schedule: "@daily" + + - name: "schedule_hourly" + params: + dag_id: "example_dag" + schedule: "@hourly" + + - name: "schedule_once" + params: + dag_id: "example_dag" + schedule: "@once" + + - name: "schedule_continuous" + params: + dag_id: "example_dag" + schedule: "@continuous" + max_active_runs: 1 + + - name: "schedule_cron" + params: + dag_id: "example_dag" + schedule: "0 0 * * *" + + - name: "schedule_cron_complex" + params: + dag_id: "example_dag" + schedule: "30 2 */3 * 1-5" + + # ---- simple scalar fields -------------------------------------------- + - name: "with_description" + params: + dag_id: "example_dag" + description: "This is an example DAG for testing serialization." 
+ + - name: "with_doc_md" + params: + dag_id: "example_dag" + doc_md: "# Example DAG\n\nThis is **markdown** documentation." + + - name: "with_dag_display_name" + params: + dag_id: "example_dag" + dag_display_name: "My Example Pipeline" + + - name: "with_dag_display_name_same_as_id" + params: + dag_id: "example_dag" + dag_display_name: "example_dag" + + # ---- boolean / numeric fields ----------------------------------------- + - name: "with_catchup_and_start_date" + params: + dag_id: "example_dag" + schedule: "@daily" + start_date: "2024-01-01T00:00:00Z" + catchup: true + + - name: "with_fail_fast" + params: + dag_id: "example_dag" + fail_fast: true + + - name: "with_render_template_as_native_obj" + params: + dag_id: "example_dag" + render_template_as_native_obj: true + + - name: "with_is_paused_upon_creation_true" + params: + dag_id: "example_dag" + is_paused_upon_creation: true + + - name: "with_is_paused_upon_creation_false" + params: + dag_id: "example_dag" + is_paused_upon_creation: false + + - name: "with_max_active_tasks" + params: + dag_id: "example_dag" + max_active_tasks: 32 + max_active_runs: 8 + max_consecutive_failed_dag_runs: 5 + + - name: "with_dagrun_timeout" + params: + dag_id: "example_dag" + dagrun_timeout_seconds: 3600 + + - name: "with_start_and_end_date" + params: + dag_id: "example_dag" + start_date: "2024-01-01T00:00:00Z" + end_date: "2024-12-31T23:59:59Z" + + # ---- collection fields ------------------------------------------------ + - name: "with_tags" + params: + dag_id: "example_dag" + tags: ["alpha", "beta", "gamma"] + + - name: "with_owner_links" + params: + dag_id: "example_dag" + owner_links: + data_team: "https://example.com/data-team" + dev_team: "https://example.com/dev-team" + + # ---- decorated / typed fields ----------------------------------------- + - name: "with_default_args" + params: + dag_id: "example_dag" + default_args: + retries: 3 + owner: "test_owner" + + - name: "with_params" + params: + dag_id: "example_dag" + 
params: + param_string: "value1" + param_int: 42 + param_bool: true + + - name: "with_access_control" + params: + dag_id: "example_dag" + access_control: + viewer_role: + DAGs: + - "can_read" + editor_role: + DAGs: + - "can_read" + - "can_edit" + + # ---- complex combined ------------------------------------------------- + - name: "complex_dag" + params: + dag_id: "complex_dag" + schedule: "0 */6 * * *" + description: "A complex DAG with many parameters" + start_date: "2024-06-15T10:30:00Z" + catchup: true + max_active_tasks: 32 + max_active_runs: 8 + max_consecutive_failed_dag_runs: 3 + dagrun_timeout_seconds: 7200 + tags: ["complex", "production", "etl"] + owner_links: + data_team: "https://example.com/data-team" + fail_fast: true + dag_display_name: "Complex ETL Pipeline" + doc_md: "# Complex Pipeline\nHandles ETL processing." + default_args: + retries: 2 + owner: "data_engineering" + params: + env: "production" + batch_size: 1000