Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def get_conn(self) -> Connection:
def _get_client(self) -> oss.Client:
config = oss.config.load_default()
config.region = self.region
config.endpoint = f"oss-{self.region}.aliyuncs.com"
# Prefer extra.endpoint (e.g. VPC internal endpoint) over default public endpoint
endpoint = self.oss_conn.extra_dejson.get("endpoint")
config.endpoint = endpoint or f"oss-{self.region}.aliyuncs.com"
config.credentials_provider = self.get_credential()
return oss.Client(config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ def close(self):
# Mark closed so we don't double write if close is called twice
self.closed = True

def oss_log_exists(self, remote_log_location: str) -> bool:
"""Check if remote_log_location exists in OSS, proxy to self.io."""
return self.io.oss_log_exists(remote_log_location)

def oss_read(self, remote_log_location: str, return_error: bool = False) -> str:
"""Read log content from OSS, proxy to self.io."""
return self.io.oss_read(remote_log_location, return_error=return_error)

def _read(self, ti, try_number, metadata=None):
"""
Read logs of given task instance and try_number from OSS remote storage.
Expand Down
26 changes: 26 additions & 0 deletions providers/alibaba/tests/unit/alibaba/cloud/hooks/test_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,32 @@ def test_get_credential(self, mock_provider):
self.hook.get_credential()
mock_provider.assert_called_once_with("mock_access_key_id", "mock_access_key_secret")

@mock.patch(OSS_STRING.format("oss.Client"))
@mock.patch(OSS_STRING.format("oss.config.load_default"))
@mock.patch(OSS_STRING.format("OSSHook.get_credential"))
def test_get_client_default_endpoint(self, mock_cred, mock_config, mock_client):
self.hook._get_client()
_, kwargs = mock_config.call_args
assert kwargs["endpoint"] == "oss-mock_region.aliyuncs.com"

@mock.patch(OSS_STRING.format("oss.Client"))
@mock.patch(OSS_STRING.format("oss.config.load_default"))
@mock.patch(OSS_STRING.format("OSSHook.get_credential"))
def test_get_client_custom_endpoint(self, mock_cred, mock_config, mock_client):
with mock.patch(
OSS_STRING.format("OSSHook.__init__"),
new=mock_oss_hook_default_project_id,
):
hook = OSSHook(
oss_conn_id=MOCK_OSS_CONN_ID,
region="mock_region",
)
# Simulate custom VPC internal endpoint via extra
hook.oss_conn.extra_dejson["endpoint"] = "oss-cn-hangzhou-internal.aliyuncs.com"
hook._get_client()
_, kwargs = mock_config.call_args
assert kwargs["endpoint"] == "oss-cn-hangzhou-internal.aliyuncs.com"

@mock.patch(OSS_STRING.format("OSSHook._get_client"))
def test_get_bucket(self, mock_get_client):
self.hook.get_bucket("mock_bucket_name")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,35 @@ def test_upload_passes_relative_path_to_oss_write(self, mock_oss_write, tmp_path
handler.io.upload(str(log_file), self.ti)
mock_oss_write.assert_called_once_with("test log content", relative_path)

@mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.oss_log_exists"))
@mock.patch(OSS_TASK_HANDLER_STRING.format("OSSTaskHandler.oss_read"))
def test_read_uses_oss_log_exists_and_oss_read(self, mock_oss_read, mock_oss_log_exists):
"""Test that _read calls oss_log_exists and oss_read on the handler itself."""
mock_oss_log_exists.return_value = True
mock_oss_read.return_value = "log content"

# _read should call self.oss_log_exists and self.oss_read (not self.io.*)
self.oss_task_handler._read(self.ti, try_number=1)

mock_oss_log_exists.assert_called_once()
mock_oss_read.assert_called_once()

@mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.oss_log_exists"))
def test_handler_oss_log_exists_proxies_to_io(self, mock_io_exists):
"""Test that OSSTaskHandler.oss_log_exists proxies to self.io.oss_log_exists."""
mock_io_exists.return_value = True
result = self.oss_task_handler.oss_log_exists("1.log")
mock_io_exists.assert_called_once_with("1.log")
assert result is True

@mock.patch(OSS_TASK_HANDLER_STRING.format("OSSRemoteLogIO.oss_read"))
def test_handler_oss_read_proxies_to_io(self, mock_io_read):
"""Test that OSSTaskHandler.oss_read proxies to self.io.oss_read."""
mock_io_read.return_value = "log content"
result = self.oss_task_handler.oss_read("1.log", return_error=True)
mock_io_read.assert_called_once_with("1.log", return_error=True)
assert result == "log content"

def test_filename_template_for_backward_compatibility(self):
# filename_template arg support for running the latest provider on airflow 2
OSSTaskHandler(self.base_log_folder, self.oss_log_folder, filename_template=None)
21 changes: 10 additions & 11 deletions providers/alibaba/tests/unit/alibaba/cloud/utils/oss_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,15 @@
OSS_PROJECT_ID_HOOK_UNIT_TEST = "example-project"


def mock_oss_hook_default_project_id(self, oss_conn_id="mock_oss_default", region="mock_region"):
def mock_oss_hook_default_project_id(self, oss_conn_id="mock_oss_default", region="mock_region", endpoint=None):
self.oss_conn_id = oss_conn_id
self.oss_conn = Connection(
extra=json.dumps(
{
"auth_type": "AK",
"access_key_id": "mock_access_key_id",
"access_key_secret": "mock_access_key_secret",
"region": "mock_region",
}
)
)
extra = {
"auth_type": "AK",
"access_key_id": "mock_access_key_id",
"access_key_secret": "mock_access_key_secret",
"region": region,
}
if endpoint is not None:
extra["endpoint"] = endpoint
self.oss_conn = Connection(extra=json.dumps(extra))
self.region = region