diff --git a/.gitignore b/.gitignore index f0226c065..e9bd2b49f 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ var/ *.egg-info/ .installed.cfg *.egg +pip-wheel-metadata/ # PyInstaller # Usually these files are written by a python script from a template @@ -89,6 +90,7 @@ env.py # virtualenv venv/ ENV/ +.venv/ # Spyder project settings .spyderproject diff --git a/tableauserverclient/config.py b/tableauserverclient/config.py index 67a77f479..1a4a7dc37 100644 --- a/tableauserverclient/config.py +++ b/tableauserverclient/config.py @@ -7,7 +7,7 @@ # For when a datasource is over 64MB, break it into 5MB(standard chunk size) chunks CHUNK_SIZE_MB = 5 * 10 # 5MB felt too slow, upped it to 50 -DELAY_SLEEP_SECONDS = 10 +DELAY_SLEEP_SECONDS = 0.1 # The maximum size of a file that can be published in a single request is 64MB FILESIZE_LIMIT_MB = 64 diff --git a/tableauserverclient/helpers/headers.py b/tableauserverclient/helpers/headers.py new file mode 100644 index 000000000..2ed4a814d --- /dev/null +++ b/tableauserverclient/helpers/headers.py @@ -0,0 +1,17 @@ +from copy import deepcopy +from urllib.parse import unquote_plus + + +def fix_filename(params): + if "filename*" not in params: + return params + + params = deepcopy(params) + filename = params["filename*"] + prefix = "UTF-8''" + if filename.startswith(prefix): + filename = filename[len(prefix) :] + + params["filename"] = unquote_plus(filename) + del params["filename*"] + return params diff --git a/tableauserverclient/helpers/strings.py b/tableauserverclient/helpers/strings.py index e51a6611a..75534103b 100644 --- a/tableauserverclient/helpers/strings.py +++ b/tableauserverclient/helpers/strings.py @@ -9,8 +9,6 @@ T = TypeVar("T", str, bytes) -# usage: _redact_any_type("") -# -> b" def _redact_any_type(xml: T, sensitive_word: T, replacement: T, encoding=None) -> T: try: root = fromstring(xml) diff --git a/tableauserverclient/models/interval_item.py b/tableauserverclient/models/interval_item.py index 02b57591b..f2f159625 100644 --- a/tableauserverclient/models/interval_item.py +++ b/tableauserverclient/models/interval_item.py @@ -29,7 +29,12 @@ class HourlyInterval(object): def __init__(self, start_time, end_time, interval_value): self.start_time = start_time self.end_time = end_time - self.interval = interval_value + + # interval should be a tuple, if it is not, assign as a tuple with single value + if isinstance(interval_value, tuple): + self.interval = interval_value + else: + self.interval = (interval_value,) def __repr__(self): return f"<{self.__class__.__name__} start={self.start_time} end={self.end_time} interval={self.interval}>" @@ -63,25 +68,44 @@ def interval(self): return self._interval @interval.setter - def interval(self, interval): + def interval(self, intervals): VALID_INTERVALS = {0.25, 0.5, 1, 2, 4, 6, 8, 12} - if float(interval) not in VALID_INTERVALS: - error = "Invalid interval {} not in {}".format(interval, str(VALID_INTERVALS)) - raise ValueError(error) + for interval in intervals: + # if an hourly interval is a string, then it is a weekDay interval + if isinstance(interval, str) and not interval.isnumeric() and not hasattr(IntervalItem.Day, interval): + error = "Invalid weekDay interval {}".format(interval) + raise ValueError(error) + + # if an hourly interval is a number, it is an hours or minutes interval + if isinstance(interval, (int, float)) and float(interval) not in VALID_INTERVALS: + error = "Invalid interval {} not in {}".format(interval, str(VALID_INTERVALS)) + raise ValueError(error) - self._interval = interval + self._interval = intervals def _interval_type_pairs(self): - # We use fractional hours for the two minute-based intervals. - # Need to convert to minutes from hours here - if self.interval in {0.25, 0.5}: - calculated_interval = int(self.interval * 60) - interval_type = IntervalItem.Occurrence.Minutes - else: - calculated_interval = self.interval - interval_type = IntervalItem.Occurrence.Hours + interval_type_pairs = [] + for interval in self.interval: + # We use fractional hours for the two minute-based intervals. + # Need to convert to minutes from hours here + if interval in {0.25, 0.5}: + calculated_interval = int(interval * 60) + interval_type = IntervalItem.Occurrence.Minutes + + interval_type_pairs.append((interval_type, str(calculated_interval))) + else: + # if the interval is a non-numeric string, it will always be a weekDay + if isinstance(interval, str) and not interval.isnumeric(): + interval_type = IntervalItem.Occurrence.WeekDay + + interval_type_pairs.append((interval_type, str(interval))) + # otherwise the interval is hours + else: + interval_type = IntervalItem.Occurrence.Hours - return [(interval_type, str(calculated_interval))] + interval_type_pairs.append((interval_type, str(interval))) + + return interval_type_pairs class DailyInterval(object): @@ -111,8 +135,45 @@ def interval(self): return self._interval @interval.setter - def interval(self, interval): - self._interval = interval + def interval(self, intervals): + VALID_INTERVALS = {0.25, 0.5, 1, 2, 4, 6, 8, 12} + + for interval in intervals: + # if an hourly interval is a string, then it is a weekDay interval + if isinstance(interval, str) and not interval.isnumeric() and not hasattr(IntervalItem.Day, interval): + error = "Invalid weekDay interval {}".format(interval) + raise ValueError(error) + + # if an hourly interval is a number, it is an hours or minutes interval + if isinstance(interval, (int, float)) and float(interval) not in VALID_INTERVALS: + error = "Invalid interval {} not in {}".format(interval, str(VALID_INTERVALS)) + raise ValueError(error) + + self._interval = intervals + + def _interval_type_pairs(self): + interval_type_pairs = [] + for interval in self.interval: + # We use fractional hours for the two minute-based intervals. + # Need to convert to minutes from hours here + if interval in {0.25, 0.5}: + calculated_interval = int(interval * 60) + interval_type = IntervalItem.Occurrence.Minutes + + interval_type_pairs.append((interval_type, str(calculated_interval))) + else: + # if the interval is a non-numeric string, it will always be a weekDay + if isinstance(interval, str) and not interval.isnumeric(): + interval_type = IntervalItem.Occurrence.WeekDay + + interval_type_pairs.append((interval_type, str(interval))) + # otherwise the interval is hours + else: + interval_type = IntervalItem.Occurrence.Hours + + interval_type_pairs.append((interval_type, str(interval))) + + return interval_type_pairs class WeeklyInterval(object): @@ -155,7 +216,12 @@ def _interval_type_pairs(self): class MonthlyInterval(object): def __init__(self, start_time, interval_value): self.start_time = start_time - self.interval = str(interval_value) + + # interval should be a tuple, if it is not, assign as a tuple with single value + if isinstance(interval_value, tuple): + self.interval = interval_value + else: + self.interval = (interval_value,) def __repr__(self): return f"<{self.__class__.__name__} start={self.start_time} interval={self.interval}>" @@ -179,24 +245,24 @@ def interval(self): return self._interval @interval.setter - def interval(self, interval_value): - error = "Invalid interval value for a monthly frequency: {}.".format(interval_value) - + def interval(self, interval_values): # This is weird because the value could be a str or an int # The only valid str is 'LastDay' so we check that first. If that's not it # try to convert it to an int, if that fails because it's an incorrect string # like 'badstring' we catch and re-raise. Otherwise we convert to int and check # that it's in range 1-31 + for interval_value in interval_values: + error = "Invalid interval value for a monthly frequency: {}.".format(interval_value) - if interval_value != "LastDay": - try: - if not (1 <= int(interval_value) <= 31): - raise ValueError(error) - except ValueError: - if interval_value != "LastDay": - raise ValueError(error) + if interval_value != "LastDay": + try: + if not (1 <= int(interval_value) <= 31): + raise ValueError(error) + except ValueError: + if interval_value != "LastDay": + raise ValueError(error) - self._interval = str(interval_value) + self._interval = interval_values def _interval_type_pairs(self): return [(IntervalItem.Occurrence.MonthDay, self.interval)] diff --git a/tableauserverclient/models/project_item.py b/tableauserverclient/models/project_item.py index e7254ab5d..4918f1a14 100644 --- a/tableauserverclient/models/project_item.py +++ b/tableauserverclient/models/project_item.py @@ -163,9 +163,6 @@ def _set_default_permissions(self, permissions, content_type): attr, permissions, ) - fetch_call = getattr(self, attr) - logging.getLogger().info({"type": attr, "value": fetch_call()}) - return fetch_call() @classmethod def from_response(cls, resp, ns) -> List["ProjectItem"]: diff --git a/tableauserverclient/models/schedule_item.py b/tableauserverclient/models/schedule_item.py index dc0eca948..db187a5f9 100644 --- a/tableauserverclient/models/schedule_item.py +++ b/tableauserverclient/models/schedule_item.py @@ -254,25 +254,43 @@ def _parse_interval_item(parsed_response, frequency, ns): interval.extend(interval_elem.attrib.items()) if frequency == IntervalItem.Frequency.Daily: - return DailyInterval(start_time) + converted_intervals = [] + + for i in interval: + # We use fractional hours for the two minute-based intervals. + # Need to convert to hours from minutes here + if i[0] == IntervalItem.Occurrence.Minutes: + converted_intervals.append(float(i[1]) / 60) + elif i[0] == IntervalItem.Occurrence.Hours: + converted_intervals.append(float(i[1])) + else: + converted_intervals.append(i[1]) + + return DailyInterval(start_time, *converted_intervals) if frequency == IntervalItem.Frequency.Hourly: - interval_occurrence, interval_value = interval.pop() + converted_intervals = [] - # We use fractional hours for the two minute-based intervals. - # Need to convert to hours from minutes here - if interval_occurrence == IntervalItem.Occurrence.Minutes: - interval_value = float(interval_value) / 60 + for i in interval: + # We use fractional hours for the two minute-based intervals. + # Need to convert to hours from minutes here + if i[0] == IntervalItem.Occurrence.Minutes: + converted_intervals.append(float(i[1]) / 60) + elif i[0] == IntervalItem.Occurrence.Hours: + converted_intervals.append(i[1]) + else: + converted_intervals.append(i[1]) - return HourlyInterval(start_time, end_time, interval_value) + return HourlyInterval(start_time, end_time, tuple(converted_intervals)) if frequency == IntervalItem.Frequency.Weekly: interval_values = [i[1] for i in interval] return WeeklyInterval(start_time, *interval_values) if frequency == IntervalItem.Frequency.Monthly: - interval_occurrence, interval_value = interval.pop() - return MonthlyInterval(start_time, interval_value) + interval_values = [i[1] for i in interval] + + return MonthlyInterval(start_time, tuple(interval_values)) @staticmethod def _parse_element(schedule_xml, ns): diff --git a/tableauserverclient/models/task_item.py b/tableauserverclient/models/task_item.py index 159869b07..0ffc3bfab 100644 --- a/tableauserverclient/models/task_item.py +++ b/tableauserverclient/models/task_item.py @@ -1,8 +1,11 @@ +from datetime import datetime +from typing import List, Optional + from defusedxml.ElementTree import fromstring from tableauserverclient.datetime_helpers import parse_datetime -from .schedule_item import ScheduleItem -from .target import Target +from tableauserverclient.models.schedule_item import ScheduleItem +from tableauserverclient.models.target import Target class TaskItem(object): @@ -19,14 +22,14 @@ class Type: def __init__( self, - id_, - task_type, - priority, - consecutive_failed_count=0, - schedule_id=None, - schedule_item=None, - last_run_at=None, - target=None, + id_: str, + task_type: str, + priority: int, + consecutive_failed_count: int = 0, + schedule_id: Optional[str] = None, + schedule_item: Optional[ScheduleItem] = None, + last_run_at: Optional[datetime] = None, + target: Optional[Target] = None, ): self.id = id_ self.task_type = task_type @@ -37,14 +40,14 @@ def __init__( self.last_run_at = last_run_at self.target = target - def __repr__(self): + def __repr__(self) -> str: return ( "".format(**self.__dict__) ) @classmethod - def from_response(cls, xml, ns, task_type=Type.ExtractRefresh): + def from_response(cls, xml, ns, task_type=Type.ExtractRefresh) -> List["TaskItem"]: parsed_response = fromstring(xml) all_tasks_xml = parsed_response.findall(".//t:task/t:{}".format(task_type), namespaces=ns) @@ -62,8 +65,7 @@ def _parse_element(cls, element, ns): last_run_at_element = element.find(".//t:lastRunAt", namespaces=ns) schedule_item_list = ScheduleItem.from_element(element, ns) - if len(schedule_item_list) >= 1: - schedule_item = schedule_item_list[0] + schedule_item = next(iter(schedule_item_list), None) # according to the Tableau Server REST API documentation, # there should be only one of workbook or datasource @@ -87,14 +89,14 @@ def _parse_element(cls, element, ns): task_type, priority, consecutive_failed_count, - schedule_item.id, + schedule_item.id if schedule_item is not None else None, schedule_item, last_run_at, target, ) @staticmethod - def _translate_task_type(task_type): + def _translate_task_type(task_type: str) -> str: if task_type in TaskItem._TASK_TYPE_MAPPING: return TaskItem._TASK_TYPE_MAPPING[task_type] else: diff --git a/tableauserverclient/server/endpoint/datasources_endpoint.py b/tableauserverclient/server/endpoint/datasources_endpoint.py index c60f8f919..66ad9f710 100644 --- a/tableauserverclient/server/endpoint/datasources_endpoint.py +++ b/tableauserverclient/server/endpoint/datasources_endpoint.py @@ -8,6 +8,8 @@ from pathlib import Path from typing import List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from tableauserverclient.helpers.headers import fix_filename + if TYPE_CHECKING: from tableauserverclient.server import Server from tableauserverclient.models import PermissionsRule @@ -441,6 +443,7 @@ def download_revision( filepath.write(chunk) return_path = filepath else: + params = fix_filename(params) filename = to_filename(os.path.basename(params["filename"])) download_path = make_download_path(filepath, filename) with open(download_path, "wb") as f: diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index c11a3fb27..77a771288 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -2,7 +2,6 @@ from time import sleep from tableauserverclient import datetime_helpers as datetime -import requests from packaging.version import Version from functools import wraps from xml.etree.ElementTree import ParseError @@ -76,7 +75,7 @@ def set_user_agent(parameters): # return explicitly for testing only return parameters - def _blocking_request(self, method, url, parameters={}) -> Optional["Response"]: + def _blocking_request(self, method, url, parameters={}) -> Optional[Union["Response", Exception]]: self.async_response = None response = None logger.debug("[{}] Begin blocking request to {}".format(datetime.timestamp(), url)) @@ -95,39 +94,37 @@ def _blocking_request(self, method, url, parameters={}) -> Optional["Response"]: return self.async_response def send_request_while_show_progress_threaded( - self, method, url, parameters={}, request_timeout=0 - ) -> Optional["Response"]: + self, method, url, parameters={}, request_timeout=None + ) -> Optional[Union["Response", Exception]]: try: request_thread = Thread(target=self._blocking_request, args=(method, url, parameters)) - request_thread.async_response = -1 # type:ignore # this is an invented attribute for thread comms request_thread.start() except Exception as e: logger.debug("Error starting server request on separate thread: {}".format(e)) return None - seconds = 0 + seconds = 0.05 minutes = 0 - sleep(1) - if self.async_response != -1: + last_log_minute = 0 + sleep(seconds) + if self.async_response is not None: # a quick return for any immediate responses return self.async_response - while self.async_response == -1 and (request_timeout == 0 or seconds < request_timeout): - self.log_wait_time_then_sleep(minutes, seconds, url) + timed_out: bool = request_timeout is not None and seconds > request_timeout + while (self.async_response is None) and not timed_out: + sleep(DELAY_SLEEP_SECONDS) seconds = seconds + DELAY_SLEEP_SECONDS - if seconds >= 60: - seconds = 0 - minutes = minutes + 1 + minutes = int(seconds / 60) + last_log_minute = self.log_wait_time(minutes, last_log_minute, url) return self.async_response - def log_wait_time_then_sleep(self, minutes, seconds, url): + def log_wait_time(self, minutes, last_log_minute, url) -> int: logger.debug("{} Waiting....".format(datetime.timestamp())) - if seconds >= 60: # detailed log message ~every minute - if minutes % 5 == 0: - logger.info( - "[{}] Waiting ({} minutes so far) for request to {}".format(datetime.timestamp(), minutes, url) - ) - else: - logger.debug("[{}] Waiting for request to {}".format(datetime.timestamp(), url)) - sleep(DELAY_SLEEP_SECONDS) + if minutes > last_log_minute: # detailed log message ~every minute + logger.info("[{}] Waiting ({} minutes so far) for request to {}".format(datetime.timestamp(), minutes, url)) + last_log_minute = minutes + else: + logger.debug("[{}] Waiting for request to {}".format(datetime.timestamp(), url)) + return last_log_minute def _make_request( self, @@ -151,7 +148,7 @@ def _make_request( # a request can, for stuff like publishing, spin for ages waiting for a response. # we need some user-facing activity so they know it's not dead. request_timeout = self.parent_srv.http_options.get("timeout") or 0 - server_response: Optional["Response"] = self.send_request_while_show_progress_threaded( + server_response: Optional[Union["Response", Exception]] = self.send_request_while_show_progress_threaded( method, url, parameters, request_timeout ) logger.debug("[{}] Async request returned: received {}".format(datetime.timestamp(), server_response)) @@ -163,6 +160,8 @@ def _make_request( if server_response is None: logger.debug("[{}] Request failed".format(datetime.timestamp())) raise RuntimeError + if isinstance(server_response, Exception): + raise server_response self._check_status(server_response, url) loggable_response = self.log_response_safely(server_response) diff --git a/tableauserverclient/server/endpoint/flows_endpoint.py b/tableauserverclient/server/endpoint/flows_endpoint.py index ba8a152d7..21c16b1cc 100644 --- a/tableauserverclient/server/endpoint/flows_endpoint.py +++ b/tableauserverclient/server/endpoint/flows_endpoint.py @@ -7,6 +7,8 @@ from pathlib import Path from typing import Iterable, List, Optional, TYPE_CHECKING, Tuple, Union +from tableauserverclient.helpers.headers import fix_filename + from .dqw_endpoint import _DataQualityWarningEndpoint from .endpoint import QuerysetEndpoint, api from .exceptions import InternalServerError, MissingRequiredFieldError @@ -124,6 +126,7 @@ def download(self, flow_id: str, filepath: Optional[PathOrFileW] = None) -> Path filepath.write(chunk) return_path = filepath else: + params = fix_filename(params) filename = to_filename(os.path.basename(params["filename"])) download_path = make_download_path(filepath, filename) with open(download_path, "wb") as f: diff --git a/tableauserverclient/server/endpoint/tasks_endpoint.py b/tableauserverclient/server/endpoint/tasks_endpoint.py index 092597388..a727a515f 100644 --- a/tableauserverclient/server/endpoint/tasks_endpoint.py +++ b/tableauserverclient/server/endpoint/tasks_endpoint.py @@ -1,19 +1,23 @@ import logging +from typing import List, Optional, Tuple, TYPE_CHECKING -from .endpoint import Endpoint, api -from .exceptions import MissingRequiredFieldError +from tableauserverclient.server.endpoint.endpoint import Endpoint, api +from tableauserverclient.server.endpoint.exceptions import MissingRequiredFieldError from tableauserverclient.models import TaskItem, PaginationItem from tableauserverclient.server import RequestFactory from tableauserverclient.helpers.logging import logger +if TYPE_CHECKING: + from tableauserverclient.server.request_options import RequestOptions + class Tasks(Endpoint): @property - def baseurl(self): + def baseurl(self) -> str: return "{0}/sites/{1}/tasks".format(self.parent_srv.baseurl, self.parent_srv.site_id) - def __normalize_task_type(self, task_type): + def __normalize_task_type(self, task_type: str) -> str: """ The word for extract refresh used in API URL is "extractRefreshes". It is different than the tag "extractRefresh" used in the request body. @@ -24,11 +28,13 @@ def __normalize_task_type(self, task_type): return task_type @api(version="2.6") - def get(self, req_options=None, task_type=TaskItem.Type.ExtractRefresh): + def get( + self, req_options: Optional["RequestOptions"] = None, task_type: str = TaskItem.Type.ExtractRefresh + ) -> Tuple[List[TaskItem], PaginationItem]: if task_type == TaskItem.Type.DataAcceleration: self.parent_srv.assert_at_least_version("3.8", "Data Acceleration Tasks") - logger.info("Querying all {} tasks for the site".format(task_type)) + logger.info("Querying all %s tasks for the site", task_type) url = "{0}/{1}".format(self.baseurl, self.__normalize_task_type(task_type)) server_response = self.get_request(url, req_options) @@ -38,11 +44,11 @@ def get(self, req_options=None, task_type=TaskItem.Type.ExtractRefresh): return all_tasks, pagination_item @api(version="2.6") - def get_by_id(self, task_id): + def get_by_id(self, task_id: str) -> TaskItem: if not task_id: error = "No Task ID provided" raise ValueError(error) - logger.info("Querying a single task by id ({})".format(task_id)) + logger.info("Querying a single task by id %s", task_id) url = "{}/{}/{}".format( self.baseurl, self.__normalize_task_type(TaskItem.Type.ExtractRefresh), @@ -56,14 +62,14 @@ def create(self, extract_item: TaskItem) -> TaskItem: if not extract_item: error = "No extract refresh provided" raise ValueError(error) - logger.info("Creating an extract refresh ({})".format(extract_item)) + logger.info("Creating an extract refresh %s", extract_item) url = "{0}/{1}".format(self.baseurl, self.__normalize_task_type(TaskItem.Type.ExtractRefresh)) create_req = RequestFactory.Task.create_extract_req(extract_item) server_response = self.post_request(url, create_req) return server_response.content @api(version="2.6") - def run(self, task_item): + def run(self, task_item: TaskItem) -> bytes: if not task_item.id: error = "Task item missing ID." raise MissingRequiredFieldError(error) @@ -79,7 +85,7 @@ def run(self, task_item): # Delete 1 task by id @api(version="3.6") - def delete(self, task_id, task_type=TaskItem.Type.ExtractRefresh): + def delete(self, task_id: str, task_type: str = TaskItem.Type.ExtractRefresh) -> None: if task_type == TaskItem.Type.DataAcceleration: self.parent_srv.assert_at_least_version("3.8", "Data Acceleration Tasks") @@ -88,4 +94,4 @@ def delete(self, task_id, task_type=TaskItem.Type.ExtractRefresh): raise ValueError(error) url = "{0}/{1}/{2}".format(self.baseurl, self.__normalize_task_type(task_type), task_id) self.delete_request(url) - logger.info("Deleted single task (ID: {0})".format(task_id)) + logger.info("Deleted single task (ID: %s)", task_id) diff --git a/tableauserverclient/server/endpoint/workbooks_endpoint.py b/tableauserverclient/server/endpoint/workbooks_endpoint.py index a73b0f0d5..506fe02c2 100644 --- a/tableauserverclient/server/endpoint/workbooks_endpoint.py +++ b/tableauserverclient/server/endpoint/workbooks_endpoint.py @@ -6,6 +6,8 @@ from contextlib import closing from pathlib import Path +from tableauserverclient.helpers.headers import fix_filename + from .endpoint import QuerysetEndpoint, api, parameter_added_in from .exceptions import InternalServerError, MissingRequiredFieldError from .permissions_endpoint import _PermissionsEndpoint @@ -88,8 +90,8 @@ def get_by_id(self, workbook_id: str) -> WorkbookItem: return WorkbookItem.from_response(server_response.content, self.parent_srv.namespace)[0] @api(version="2.8") - def refresh(self, workbook_id: str) -> JobItem: - id_ = getattr(workbook_id, "id", workbook_id) + def refresh(self, workbook_item: Union[WorkbookItem, str]) -> JobItem: + id_ = getattr(workbook_item, "id", workbook_item) url = "{0}/{1}/refresh".format(self.baseurl, id_) empty_req = RequestFactory.Empty.empty_req() server_response = self.post_request(url, empty_req) @@ -455,7 +457,7 @@ def _get_workbook_revisions( def download_revision( self, workbook_id: str, - revision_number: str, + revision_number: Optional[str], filepath: Optional[PathOrFileW] = None, include_extract: bool = True, no_extract: Optional[bool] = None, @@ -487,6 +489,7 @@ def download_revision( filepath.write(chunk) return_path = filepath else: + params = fix_filename(params) filename = to_filename(os.path.basename(params["filename"])) download_path = make_download_path(filepath, filename) with open(download_path, "wb") as f: diff --git a/tableauserverclient/server/request_factory.py b/tableauserverclient/server/request_factory.py index 7fb9bf9ed..6316527ec 100644 --- a/tableauserverclient/server/request_factory.py +++ b/tableauserverclient/server/request_factory.py @@ -1032,6 +1032,16 @@ def run_req(self, xml_request, task_item): def create_extract_req(self, xml_request: ET.Element, extract_item: "TaskItem") -> bytes: extract_element = ET.SubElement(xml_request, "extractRefresh") + # Main attributes + extract_element.attrib["type"] = extract_item.task_type + + if extract_item.target is not None: + target_element = ET.SubElement(extract_element, extract_item.target.type) + target_element.attrib["id"] = extract_item.target.id + + if extract_item.schedule_item is None: + return ET.tostring(xml_request) + # Schedule attributes schedule_element = ET.SubElement(xml_request, "schedule") @@ -1043,17 +1053,11 @@ def create_extract_req(self, xml_request: ET.Element, extract_item: "TaskItem") frequency_element.attrib["end"] = str(interval_item.end_time) if hasattr(interval_item, "interval") and interval_item.interval: intervals_element = ET.SubElement(frequency_element, "intervals") - for interval in interval_item._interval_type_pairs(): + for interval in interval_item._interval_type_pairs(): # type: ignore expression, value = interval single_interval_element = ET.SubElement(intervals_element, "interval") single_interval_element.attrib[expression] = value - # Main attributes - extract_element.attrib["type"] = extract_item.task_type - - target_element = ET.SubElement(extract_element, extract_item.target.type) - target_element.attrib["id"] = extract_item.target.id - return ET.tostring(xml_request) diff --git a/tableauserverclient/server/request_options.py b/tableauserverclient/server/request_options.py index 796f8add3..95233f8fc 100644 --- a/tableauserverclient/server/request_options.py +++ b/tableauserverclient/server/request_options.py @@ -37,35 +37,75 @@ class Operator: class Field: Args = "args" + AuthenticationType = "authenticationType" + Caption = "caption" + Channel = "channel" CompletedAt = "completedAt" + ConnectedWorkbookType = "connectedWorkbookType" + ConnectionTo = "connectionTo" + ConnectionType = "connectionType" ContentUrl = "contentUrl" CreatedAt = "createdAt" + DatabaseName = "databaseName" + DatabaseUserName = "databaseUserName" + Description = "description" + DisplayTabs = "displayTabs" DomainName = "domainName" DomainNickname = "domainNickname" + FavoritesTotal = "favoritesTotal" + Fields = "fields" + FlowId = "flowId" + FriendlyName = "friendlyName" + HasAlert = "hasAlert" + HasAlerts = "hasAlerts" + HasEmbeddedPassword = "hasEmbeddedPassword" + HasExtracts = "hasExtracts" HitsTotal = "hitsTotal" + Id = "id" + IsCertified = "isCertified" + IsConnectable = "isConnectable" + IsDefaultPort = "isDefaultPort" + IsHierarchical = "isHierarchical" IsLocal = "isLocal" + IsPublished = "isPublished" JobType = "jobType" LastLogin = "lastLogin" + Luid = "luid" MinimumSiteRole = "minimumSiteRole" Name = "name" Notes = "notes" + NotificationType = "notificationType" OwnerDomain = "ownerDomain" OwnerEmail = "ownerEmail" OwnerName = "ownerName" ParentProjectId = "parentProjectId" + Priority = "priority" Progress = "progress" + ProjectId = "projectId" ProjectName = "projectName" PublishSamples = "publishSamples" + ServerName = "serverName" + ServerPort = "serverPort" + SheetCount = "sheetCount" + SheetNumber = "sheetNumber" + SheetType = "sheetType" SiteRole = "siteRole" + Size = "size" StartedAt = "startedAt" Status = "status" + SubscriptionsTotal = "subscriptionsTotal" Subtitle = "subtitle" + TableName = "tableName" Tags = "tags" Title = "title" TopLevelProject = "topLevelProject" Type = "type" UpdatedAt = "updatedAt" UserCount = "userCount" + UserId = "userId" + ViewUrlName = "viewUrlName" + WorkbookDescription = "workbookDescription" + WorkbookName = "workbookName" class Direction: Desc = "desc" diff --git a/test/assets/schedule_get_daily_id.xml b/test/assets/schedule_get_daily_id.xml new file mode 100644 index 000000000..99467a391 --- /dev/null +++ b/test/assets/schedule_get_daily_id.xml @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/test/assets/schedule_get_hourly_id.xml b/test/assets/schedule_get_hourly_id.xml new file mode 100644 index 000000000..27c374ccf --- /dev/null +++ b/test/assets/schedule_get_hourly_id.xml @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/test/assets/schedule_get_monthly_id.xml b/test/assets/schedule_get_monthly_id.xml new file mode 100644 index 000000000..3fc32cc57 --- /dev/null +++ b/test/assets/schedule_get_monthly_id.xml @@ -0,0 +1,11 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/test/assets/tasks_without_schedule.xml b/test/assets/tasks_without_schedule.xml new file mode 100644 index 000000000..e669bf67f --- /dev/null +++ b/test/assets/tasks_without_schedule.xml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/test/test_datasource.py b/test/test_datasource.py index e299e5291..f258fdc52 100644 --- a/test/test_datasource.py +++ b/test/test_datasource.py @@ -696,3 +696,14 @@ def test_download_revision(self) -> None: ) file_path = self.server.datasources.download_revision("9dbd2263-16b5-46e1-9c43-a76bb8ab65fb", "3", td) self.assertTrue(os.path.exists(file_path)) + + def test_bad_download_response(self) -> None: + with requests_mock.mock() as m, tempfile.TemporaryDirectory() as td: + m.get( + self.baseurl + "/9dbd2263-16b5-46e1-9c43-a76bb8ab65fb/content", + headers={ + "Content-Disposition": '''name="tableau_datasource"; filename*=UTF-8''"Sample datasource.tds"''' + }, + ) + file_path = self.server.datasources.download("9dbd2263-16b5-46e1-9c43-a76bb8ab65fb", td) + self.assertTrue(os.path.exists(file_path)) diff --git a/test/test_flow.py b/test/test_flow.py index d10641809..a90b18171 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -1,5 +1,6 @@ import os import requests_mock +import tempfile import unittest from io import BytesIO @@ -203,3 +204,12 @@ def test_refresh(self): self.assertEqual(refresh_job.flow_run.id, "e0c3067f-2333-4eee-8028-e0a56ca496f6") self.assertEqual(refresh_job.flow_run.flow_id, "92967d2d-c7e2-46d0-8847-4802df58f484") self.assertEqual(format_datetime(refresh_job.flow_run.started_at), "2018-05-22T13:00:29Z") + + def test_bad_download_response(self) -> None: + with requests_mock.mock() as m, tempfile.TemporaryDirectory() as td: + m.get( + self.baseurl + "/9dbd2263-16b5-46e1-9c43-a76bb8ab65fb/content", + headers={"Content-Disposition": '''name="tableau_flow"; filename*=UTF-8''"Sample flow.tfl"'''}, + ) + file_path = self.server.flows.download("9dbd2263-16b5-46e1-9c43-a76bb8ab65fb", td) + self.assertTrue(os.path.exists(file_path)) diff --git a/test/test_schedule.py b/test/test_schedule.py index 807467918..76c8720b9 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -11,6 +11,9 @@ GET_XML = os.path.join(TEST_ASSET_DIR, "schedule_get.xml") GET_BY_ID_XML = os.path.join(TEST_ASSET_DIR, "schedule_get_by_id.xml") +GET_HOURLY_ID_XML = os.path.join(TEST_ASSET_DIR, "schedule_get_hourly_id.xml") +GET_DAILY_ID_XML = os.path.join(TEST_ASSET_DIR, "schedule_get_daily_id.xml") +GET_MONTHLY_ID_XML = os.path.join(TEST_ASSET_DIR, "schedule_get_monthly_id.xml") GET_EMPTY_XML = os.path.join(TEST_ASSET_DIR, "schedule_get_empty.xml") CREATE_HOURLY_XML = os.path.join(TEST_ASSET_DIR, "schedule_create_hourly.xml") CREATE_DAILY_XML = os.path.join(TEST_ASSET_DIR, "schedule_create_daily.xml") @@ -100,6 +103,51 @@ def test_get_by_id(self) -> None: self.assertEqual("Weekday early mornings", schedule.name) self.assertEqual("Active", schedule.state) + def test_get_hourly_by_id(self) -> None: + self.server.version = "3.8" + with open(GET_HOURLY_ID_XML, "rb") as f: + response_xml = f.read().decode("utf-8") + with requests_mock.mock() as m: + schedule_id = "c9cff7f9-309c-4361-99ff-d4ba8c9f5467" + baseurl = "{}/schedules/{}".format(self.server.baseurl, schedule_id) + m.get(baseurl, text=response_xml) + schedule = self.server.schedules.get_by_id(schedule_id) + self.assertIsNotNone(schedule) + self.assertEqual(schedule_id, schedule.id) + self.assertEqual("Hourly schedule", schedule.name) + self.assertEqual("Active", schedule.state) + self.assertEqual(("Monday", 0.5), schedule.interval_item.interval) + + def test_get_daily_by_id(self) -> None: + self.server.version = "3.8" + with open(GET_DAILY_ID_XML, "rb") as f: + response_xml = f.read().decode("utf-8") + with requests_mock.mock() as m: + schedule_id = "c9cff7f9-309c-4361-99ff-d4ba8c9f5467" + baseurl = "{}/schedules/{}".format(self.server.baseurl, schedule_id) + m.get(baseurl, text=response_xml) + schedule = self.server.schedules.get_by_id(schedule_id) + self.assertIsNotNone(schedule) + self.assertEqual(schedule_id, schedule.id) + self.assertEqual("Daily schedule", schedule.name) + self.assertEqual("Active", schedule.state) + self.assertEqual(("Monday", 2.0), schedule.interval_item.interval) + + def test_get_monthly_by_id(self) -> None: + self.server.version = "3.8" + with open(GET_MONTHLY_ID_XML, "rb") as f: + response_xml = f.read().decode("utf-8") + with requests_mock.mock() as m: + schedule_id = "c9cff7f9-309c-4361-99ff-d4ba8c9f5467" + baseurl = "{}/schedules/{}".format(self.server.baseurl, schedule_id) + m.get(baseurl, text=response_xml) + schedule = self.server.schedules.get_by_id(schedule_id) + self.assertIsNotNone(schedule) + self.assertEqual(schedule_id, schedule.id) + self.assertEqual("Monthly multiple days", schedule.name) + self.assertEqual("Active", schedule.state) + self.assertEqual(("1", "2"), schedule.interval_item.interval) + def test_delete(self) -> None: with requests_mock.mock() as m: m.delete(self.baseurl + "/c9cff7f9-309c-4361-99ff-d4ba8c9f5467", status_code=204) @@ -131,7 +179,7 @@ def test_create_hourly(self) -> None: self.assertEqual(TSC.ScheduleItem.ExecutionOrder.Parallel, new_schedule.execution_order) self.assertEqual(time(2, 30), new_schedule.interval_item.start_time) self.assertEqual(time(23), new_schedule.interval_item.end_time) # type: ignore[union-attr] - self.assertEqual("8", new_schedule.interval_item.interval) # type: ignore[union-attr] + self.assertEqual(("8",), new_schedule.interval_item.interval) # type: ignore[union-attr] def test_create_daily(self) -> None: with open(CREATE_DAILY_XML, "rb") as f: @@ -216,7 +264,7 @@ def test_create_monthly(self) -> None: self.assertEqual("2016-10-12T14:00:00Z", format_datetime(new_schedule.next_run_at)) self.assertEqual(TSC.ScheduleItem.ExecutionOrder.Serial, new_schedule.execution_order) self.assertEqual(time(7), new_schedule.interval_item.start_time) - self.assertEqual("12", new_schedule.interval_item.interval) # type: ignore[union-attr] + self.assertEqual(("12",), new_schedule.interval_item.interval) # type: ignore[union-attr] def test_update(self) -> None: with open(UPDATE_XML, "rb") as f: diff --git a/test/test_task.py b/test/test_task.py index 4eb2c02e2..4e0157dfd 100644 --- a/test/test_task.py +++ b/test/test_task.py @@ -1,6 +1,7 @@ import os import unittest from datetime import time +from pathlib import Path import requests_mock @@ -8,7 +9,7 @@ from tableauserverclient.datetime_helpers import parse_datetime from tableauserverclient.models.task_item import TaskItem -TEST_ASSET_DIR = os.path.join(os.path.dirname(__file__), "assets") +TEST_ASSET_DIR = Path(__file__).parent / "assets" GET_XML_NO_WORKBOOK = os.path.join(TEST_ASSET_DIR, "tasks_no_workbook_or_datasource.xml") GET_XML_WITH_WORKBOOK = os.path.join(TEST_ASSET_DIR, "tasks_with_workbook.xml") @@ -17,6 +18,7 @@ GET_XML_DATAACCELERATION_TASK = os.path.join(TEST_ASSET_DIR, "tasks_with_dataacceleration_task.xml") GET_XML_RUN_NOW_RESPONSE = os.path.join(TEST_ASSET_DIR, "tasks_run_now_response.xml") GET_XML_CREATE_TASK_RESPONSE = os.path.join(TEST_ASSET_DIR, "tasks_create_extract_task.xml") +GET_XML_WITHOUT_SCHEDULE = TEST_ASSET_DIR / "tasks_without_schedule.xml" class TaskTests(unittest.TestCase): @@ -86,6 +88,15 @@ def test_get_task_with_schedule(self): self.assertEqual("workbook", task.target.type) self.assertEqual("b60b4efd-a6f7-4599-beb3-cb677e7abac1", task.schedule_id) + def test_get_task_without_schedule(self): + with requests_mock.mock() as m: + m.get(self.baseurl, text=GET_XML_WITHOUT_SCHEDULE.read_text()) + all_tasks, pagination_item = self.server.tasks.get() + + task = all_tasks[0] + self.assertEqual("c7a9327e-1cda-4504-b026-ddb43b976d1d", task.target.id) + self.assertEqual("datasource", task.target.type) + def test_delete(self): with requests_mock.mock() as m: m.delete(self.baseurl + "/c7a9327e-1cda-4504-b026-ddb43b976d1d", status_code=204) diff --git a/test/test_workbook.py b/test/test_workbook.py index 5114ce1b8..212d55a37 100644 --- a/test/test_workbook.py +++ b/test/test_workbook.py @@ -932,3 +932,12 @@ def test_download_revision(self) -> None: ) file_path = self.server.workbooks.download_revision("9dbd2263-16b5-46e1-9c43-a76bb8ab65fb", "3", td) self.assertTrue(os.path.exists(file_path)) + + def test_bad_download_response(self) -> None: + with requests_mock.mock() as m, tempfile.TemporaryDirectory() as td: + m.get( + self.baseurl + "/9dbd2263-16b5-46e1-9c43-a76bb8ab65fb/content", + headers={"Content-Disposition": '''name="tableau_workbook"; filename*=UTF-8''"Sample workbook.twb"'''}, + ) + file_path = self.server.workbooks.download("9dbd2263-16b5-46e1-9c43-a76bb8ab65fb", td) + self.assertTrue(os.path.exists(file_path))