diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py index f568d5dc8ea95..a8543d38a9f30 100755 --- a/airflow/bin/cli.py +++ b/airflow/bin/cli.py @@ -59,7 +59,6 @@ from sqlalchemy import func from sqlalchemy.orm import exc - api.load_auth() api_module = import_module(conf.get('cli', 'api_client')) api_client = api_module.Client(api_base_url=conf.get('cli', 'endpoint_url'), @@ -316,7 +315,7 @@ def run(args, dag=None): # Load custom airflow config if args.cfg_path: with open(args.cfg_path, 'r') as conf_file: - conf_dict = json.load(conf_file) + conf_dict = json.load(conf_file) if os.path.exists(args.cfg_path): os.remove(args.cfg_path) @@ -327,6 +326,21 @@ def run(args, dag=None): settings.configure_vars() settings.configure_orm() + if not args.pickle and not dag: + dag = get_dag(args) + elif not dag: + session = settings.Session() + logging.info('Loading pickle id {args.pickle}'.format(args=args)) + dag_pickle = session.query( + DagPickle).filter(DagPickle.id == args.pickle).first() + if not dag_pickle: + raise AirflowException("Who hid the pickle!? [missing pickle]") + dag = dag_pickle.pickle + + task = dag.get_task(task_id=args.task_id) + ti = TaskInstance(task, args.execution_date) + ti.refresh_from_db() + logging.root.handlers = [] if args.raw: # Output to STDOUT for the parent process to read and log @@ -350,19 +364,23 @@ def run(args, dag=None): # writable by both users, then it's possible that re-running a task # via the UI (or vice versa) results in a permission error as the task # tries to write to a log file created by the other user. + try_number = ti.try_number log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER')) - directory = log_base + "/{args.dag_id}/{args.task_id}".format(args=args) + log_relative_dir = logging_utils.get_log_directory(args.dag_id, args.task_id, + args.execution_date) + directory = os.path.join(log_base, log_relative_dir) # Create the log file and give it group writable permissions # TODO(aoen): Make log dirs and logs globally readable for now since the SubDag # operator is not compatible with impersonation (e.g. if a Celery executor is used # for a SubDag operator and the SubDag operator has a different owner than the # parent DAG) - if not os.path.exists(directory): + if not os.path.isdir(directory): # Create the directory as globally writable using custom mkdirs # as os.makedirs doesn't set mode properly. mkdirs(directory, 0o775) - iso = args.execution_date.isoformat() - filename = "{directory}/{iso}".format(**locals()) + log_relative = logging_utils.get_log_filename( + args.dag_id, args.task_id, args.execution_date, try_number) + filename = os.path.join(log_base, log_relative) if not os.path.exists(filename): open(filename, "a").close() @@ -376,21 +394,6 @@ def run(args, dag=None): hostname = socket.getfqdn() logging.info("Running on host {}".format(hostname)) - if not args.pickle and not dag: - dag = get_dag(args) - elif not dag: - session = settings.Session() - logging.info('Loading pickle id {args.pickle}'.format(**locals())) - dag_pickle = session.query( - DagPickle).filter(DagPickle.id == args.pickle).first() - if not dag_pickle: - raise AirflowException("Who hid the pickle!? 
[missing pickle]") - dag = dag_pickle.pickle - task = dag.get_task(task_id=args.task_id) - - ti = TaskInstance(task, args.execution_date) - ti.refresh_from_db() - if args.local: print("Logging into: " + filename) run_job = jobs.LocalTaskJob( @@ -424,8 +427,8 @@ def run(args, dag=None): session.commit() pickle_id = pickle.id print(( - 'Pickled dag {dag} ' - 'as pickle_id:{pickle_id}').format(**locals())) + 'Pickled dag {dag} ' + 'as pickle_id:{pickle_id}').format(**locals())) except Exception as e: print('Could not pickle the DAG') print(e) @@ -475,7 +478,8 @@ def run(args, dag=None): with open(filename, 'r') as logfile: log = logfile.read() - remote_log_location = filename.replace(log_base, remote_base) + remote_log_location = os.path.join(remote_base, log_relative) + logging.debug("Uploading to remote log location {}".format(remote_log_location)) # S3 if remote_base.startswith('s3:/'): logging_utils.S3Log().write(log, remote_log_location) @@ -669,10 +673,10 @@ def start_refresh(gunicorn_master_proc): gunicorn_master_proc.send_signal(signal.SIGTTIN) excess += 1 wait_until_true(lambda: num_workers_expected + excess == - get_num_workers_running(gunicorn_master_proc)) + get_num_workers_running(gunicorn_master_proc)) wait_until_true(lambda: num_workers_expected == - get_num_workers_running(gunicorn_master_proc)) + get_num_workers_running(gunicorn_master_proc)) while True: num_workers_running = get_num_workers_running(gunicorn_master_proc) @@ -695,7 +699,7 @@ def start_refresh(gunicorn_master_proc): gunicorn_master_proc.send_signal(signal.SIGTTOU) excess -= 1 wait_until_true(lambda: num_workers_expected + excess == - get_num_workers_running(gunicorn_master_proc)) + get_num_workers_running(gunicorn_master_proc)) # Start a new worker by asking gunicorn to increase number of workers elif num_workers_running == num_workers_expected: @@ -887,6 +891,7 @@ def serve_logs(filename): # noqa filename, mimetype="application/json", as_attachment=False) + WORKER_LOG_SERVER_PORT = \ int(conf.get('celery', 'WORKER_LOG_SERVER_PORT')) flask_app.run( @@ -947,8 +952,8 @@ def initdb(args): # noqa def resetdb(args): print("DB: " + repr(settings.engine.url)) if args.yes or input( - "This will drop existing tables if they exist. " - "Proceed? (y/n)").upper() == "Y": + "This will drop existing tables if they exist. " + "Proceed? 
(y/n)").upper() == "Y": logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT) db_utils.resetdb() @@ -966,7 +971,7 @@ def upgradedb(args): # noqa if not ds_rows: qry = ( session.query(DagRun.dag_id, DagRun.state, func.count('*')) - .group_by(DagRun.dag_id, DagRun.state) + .group_by(DagRun.dag_id, DagRun.state) ) for dag_id, state, count in qry: session.add(DagStat(dag_id=dag_id, state=state, count=count)) @@ -1065,8 +1070,8 @@ def connections(args): session = settings.Session() if not (session - .query(Connection) - .filter(Connection.conn_id == new_conn.conn_id).first()): + .query(Connection) + .filter(Connection.conn_id == new_conn.conn_id).first()): session.add(new_conn) session.commit() msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n' @@ -1168,16 +1173,16 @@ class CLIFactory(object): 'dry_run': Arg( ("-dr", "--dry_run"), "Perform a dry run", "store_true"), 'pid': Arg( - ("--pid", ), "PID file location", + ("--pid",), "PID file location", nargs='?'), 'daemon': Arg( ("-D", "--daemon"), "Daemonize instead of running " "in the foreground", "store_true"), 'stderr': Arg( - ("--stderr", ), "Redirect stderr to this file"), + ("--stderr",), "Redirect stderr to this file"), 'stdout': Arg( - ("--stdout", ), "Redirect stdout to this file"), + ("--stdout",), "Redirect stdout to this file"), 'log_file': Arg( ("-l", "--log-file"), "Location of the log file"), @@ -1333,7 +1338,7 @@ class CLIFactory(object): "Serialized pickle object of the entire dag (used internally)"), 'job_id': Arg(("-j", "--job_id"), argparse.SUPPRESS), 'cfg_path': Arg( - ("--cfg_path", ), "Path to config file to use instead of airflow.cfg"), + ("--cfg_path",), "Path to config file to use instead of airflow.cfg"), # webserver 'port': Arg( ("-p", "--port"), @@ -1341,11 +1346,11 @@ class CLIFactory(object): type=int, help="The port on which to run the server"), 'ssl_cert': Arg( - ("--ssl_cert", ), + ("--ssl_cert",), default=conf.get('webserver', 'WEB_SERVER_SSL_CERT'), help="Path to the SSL certificate for the webserver"), 'ssl_key': Arg( - ("--ssl_key", ), + ("--ssl_key",), default=conf.get('webserver', 'WEB_SERVER_SSL_KEY'), help="Path to the key to use with the SSL certificate"), 'workers': Arg( diff --git a/airflow/models.py b/airflow/models.py index 32ad144a22bb5..c1fd4a3e86225 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -108,7 +108,7 @@ def get_fernet(): _CONTEXT_MANAGER_DAG = None -def clear_task_instances(tis, session, activate_dag_runs=True): +def clear_task_instances(tis, session, activate_dag_runs=True, dag=None): """ Clears a set of task instances, but makes sure the running ones get killed. @@ -119,12 +119,20 @@ def clear_task_instances(tis, session, activate_dag_runs=True): if ti.job_id: ti.state = State.SHUTDOWN job_ids.append(ti.job_id) - # todo: this creates an issue with the webui tests - # elif ti.state != State.REMOVED: - # ti.state = State.NONE - # session.merge(ti) else: - session.delete(ti) + task_id = ti.task_id + if dag and dag.has_task(task_id): + task = dag.get_task(task_id) + task_retries = task.retries + ti.max_tries = ti.try_number + task_retries + else: + # Ignore errors when updating max_tries if dag is None or + # task not found in dag since database records could be + # outdated. We make max_tries the maximum value of its + # original max_tries or the current task try number. 
+ ti.max_tries = max(ti.max_tries, ti.try_number) + ti.state = State.NONE + session.merge(ti) if job_ids: from airflow.jobs import BaseJob as BJ @@ -1316,8 +1324,8 @@ def run( # not 0-indexed lists (i.e. Attempt 1 instead of # Attempt 0 for the first attempt). msg = "Starting attempt {attempt} of {total}".format( - attempt=self.try_number % (task.retries + 1) + 1, - total=task.retries + 1) + attempt=self.try_number + 1, + total=self.max_tries + 1) self.start_date = datetime.now() dep_context = DepContext( @@ -1338,8 +1346,8 @@ def run( self.state = State.NONE msg = ("FIXME: Rescheduling due to concurrency limits reached at task " "runtime. Attempt {attempt} of {total}. State set to NONE.").format( - attempt=self.try_number % (task.retries + 1) + 1, - total=task.retries + 1) + attempt=self.try_number + 1, + total=self.max_tries + 1) logging.warning(hr + msg + hr) self.queued_dttm = datetime.now() @@ -1486,7 +1494,11 @@ def handle_failure(self, error, test_mode=False, context=None): # Let's go deeper try: - if task.retries and self.try_number % (task.retries + 1) != 0: + # try_number is incremented by 1 during task instance run. So the + # current task instance try_number is the try_number for the next + # task instance run. We only mark task instance as FAILED if the + # next task instance try_number exceeds the max_tries. + if task.retries and self.try_number <= self.max_tries: self.state = State.UP_FOR_RETRY logging.info('Marking task as UP_FOR_RETRY') if task.email_on_retry and task.email: @@ -1641,15 +1653,17 @@ def email_alert(self, exception, is_retry=False): task = self.task title = "Airflow alert: {self}".format(**locals()) exception = str(exception).replace('\n', '
<br>')
-        try_ = task.retries + 1
+        # For reporting purposes, we report based on 1-indexed,
+        # not 0-indexed lists (i.e. Try 1 instead of
+        # Try 0 for the first attempt).
         body = (
-            "Try {self.try_number} out of {try_}<br>"
+            "Try {try_number} out of {max_tries}<br>"
             "Exception:<br>{exception}<br>"
             "Log: <a href='{self.log_url}'>Link</a><br>"
             "Host: {self.hostname}<br>"
             "Log file: {self.log_filepath}<br>"
             "Mark success: <a href='{self.mark_success_url}'>Link</a><br>
" - ).format(**locals()) + ).format(try_number=self.try_number + 1, max_tries=self.max_tries + 1, **locals()) send_email(task.email, title, body) def set_duration(self): @@ -2382,9 +2396,7 @@ def downstream_list(self): def downstream_task_ids(self): return self._downstream_task_ids - def clear( - self, start_date=None, end_date=None, - upstream=False, downstream=False): + def clear(self, start_date=None, end_date=None, upstream=False, downstream=False): """ Clears the state of task instances associated with the task, following the parameters specified. @@ -2413,7 +2425,7 @@ def clear( count = qry.count() - clear_task_instances(qry.all(), session) + clear_task_instances(qry.all(), session, dag=self.dag) session.commit() session.close() @@ -3244,7 +3256,7 @@ def clear( do_it = utils.helpers.ask_yesno(question) if do_it: - clear_task_instances(tis.all(), session) + clear_task_instances(tis.all(), session, dag=self) if reset_dag_runs: self.set_dag_runs_state(session=session) else: diff --git a/airflow/utils/logging.py b/airflow/utils/logging.py index 96767cb6ea5b4..b86d839261053 100644 --- a/airflow/utils/logging.py +++ b/airflow/utils/logging.py @@ -19,7 +19,9 @@ from builtins import object +import dateutil.parser import logging +import six from airflow import configuration from airflow.exceptions import AirflowException @@ -57,6 +59,19 @@ def __init__(self): 'Please make sure that airflow[s3] is installed and ' 'the S3 connection exists.'.format(remote_conn_id)) + def log_exists(self, remote_log_location): + """ + Check if remote_log_location exists in remote storage + :param remote_log_location: log's location in remote storage + :return: True if location exists else False + """ + if self.hook: + try: + return self.hook.get_key(remote_log_location) is not None + except Exception: + pass + return False + def read(self, remote_log_location, return_error=False): """ Returns the log found at the remote_log_location. Returns '' if no @@ -137,6 +152,20 @@ def __init__(self): '"{}". Please make sure that airflow[gcp_api] is installed ' 'and the GCS connection exists.'.format(remote_conn_id)) + def log_exists(self, remote_log_location): + """ + Check if remote_log_location exists in remote storage + :param remote_log_location: log's location in remote storage + :return: True if location exists else False + """ + if self.hook: + try: + bkt, blob = self.parse_gcs_url(remote_log_location) + return self.hook.exists(bkt, blob) + except Exception: + pass + return False + def read(self, remote_log_location, return_error=False): """ Returns the log found at the remote_log_location. @@ -211,3 +240,40 @@ def parse_gcs_url(self, gsurl): bucket = parsed_url.netloc blob = parsed_url.path.strip('/') return (bucket, blob) + + +# TODO: get_log_filename and get_log_directory are temporary helper +# functions to get airflow log filename. Logic of using FileHandler +# will be extract out and those two functions will be moved. +# For more details, please check issue AIRFLOW-1385. +def get_log_filename(dag_id, task_id, execution_date, try_number): + """ + Return relative log path. 
+ :arg dag_id: id of the dag + :arg task_id: id of the task + :arg execution_date: execution date of the task instance + :arg try_number: try_number of current task instance + """ + relative_dir = get_log_directory(dag_id, task_id, execution_date) + # For reporting purposes and keeping logs consistent with web UI + # display, we report based on 1-indexed, not 0-indexed lists + filename = "{}/{}.log".format(relative_dir, try_number+1) + + return filename + + +def get_log_directory(dag_id, task_id, execution_date): + """ + Return log directory path: dag_id/task_id/execution_date + :arg dag_id: id of the dag + :arg task_id: id of the task + :arg execution_date: execution date of the task instance + """ + # execution_date could be parsed in as unicode character + # instead of datetime object. + if isinstance(execution_date, six.string_types): + execution_date = dateutil.parser.parse(execution_date) + iso = execution_date.isoformat() + relative_dir = '{}/{}/{}'.format(dag_id, task_id, iso) + + return relative_dir diff --git a/airflow/www/templates/airflow/ti_log.html b/airflow/www/templates/airflow/ti_log.html new file mode 100644 index 0000000000000..03c0ed3707f9a --- /dev/null +++ b/airflow/www/templates/airflow/ti_log.html @@ -0,0 +1,40 @@ +{# + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +#} +{% extends "airflow/task_instance.html" %} +{% block title %}Airflow - DAGs{% endblock %} + +{% block body %} + {{ super() }} +

  <h4>{{ title }}</h4>
+
+  {% for log in logs %}
+  <pre>
+  <code>{{ log }}</code>
+  </pre>
+  {% endfor %}
+
+{% endblock %} diff --git a/airflow/www/views.py b/airflow/www/views.py index 6c3946207205a..046c2e1e21cfe 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -67,7 +67,7 @@ from airflow.models import BaseOperator from airflow.operators.subdag_operator import SubDagOperator -from airflow.utils.logging import LoggingMixin +from airflow.utils.logging import LoggingMixin, get_log_filename from airflow.utils.json import json_ser from airflow.utils.state import State from airflow.utils.db import provide_session @@ -694,110 +694,112 @@ def rendered(self): form=form, title=title,) + def _get_log(self, ti, log_filename): + """ + Get log for a specific try number. + :param ti: current task instance + :param log_filename: relative filename to fetch the log + """ + # TODO: This is not the best practice. Log handler and + # reader should be configurable and separated from the + # frontend. The new airflow logging design is in progress. + # Please refer to #2422(https://github.com/apache/incubator-airflow/pull/2422). + log = '' + # Load remote log + remote_log_base = conf.get('core', 'REMOTE_BASE_LOG_FOLDER') + remote_log_loaded = False + if remote_log_base: + remote_log_path = os.path.join(remote_log_base, log_filename) + remote_log = "" + + # S3 + if remote_log_path.startswith('s3:/'): + s3_log = log_utils.S3Log() + if s3_log.log_exists(remote_log_path): + remote_log += s3_log.read(remote_log_path, return_error=True) + remote_log_loaded = True + # GCS + elif remote_log_path.startswith('gs:/'): + gcs_log = log_utils.GCSLog() + if gcs_log.log_exists(remote_log_path): + remote_log += gcs_log.read(remote_log_path, return_error=True) + remote_log_loaded = True + # unsupported + else: + remote_log += '*** Unsupported remote log location.' + + if remote_log: + log += ('*** Reading remote log from {}.\n{}\n'.format( + remote_log_path, remote_log)) + + # We only want to display local log if the remote log is not loaded. 
+ if not remote_log_loaded: + # Load local log + local_log_base = os.path.expanduser(conf.get('core', 'BASE_LOG_FOLDER')) + local_log_path = os.path.join(local_log_base, log_filename) + if os.path.exists(local_log_path): + try: + f = open(local_log_path) + log += "*** Reading local log.\n" + "".join(f.readlines()) + f.close() + except: + log = "*** Failed to load local log file: {0}.\n".format(local_log_path) + else: + WORKER_LOG_SERVER_PORT = conf.get('celery', 'WORKER_LOG_SERVER_PORT') + url = os.path.join( + "http://{ti.hostname}:{WORKER_LOG_SERVER_PORT}/log", log_filename + ).format(**locals()) + log += "*** Log file isn't local.\n" + log += "*** Fetching here: {url}\n".format(**locals()) + try: + import requests + timeout = None # No timeout + try: + timeout = conf.getint('webserver', 'log_fetch_timeout_sec') + except (AirflowConfigException, ValueError): + pass + + response = requests.get(url, timeout=timeout) + response.raise_for_status() + log += '\n' + response.text + except: + log += "*** Failed to fetch log file from work r.\n".format( + **locals()) + + if PY2 and not isinstance(log, unicode): + log = log.decode('utf-8') + + return log + @expose('/log') @login_required @wwwutils.action_logging def log(self): - BASE_LOG_FOLDER = os.path.expanduser( - conf.get('core', 'BASE_LOG_FOLDER')) dag_id = request.args.get('dag_id') task_id = request.args.get('task_id') execution_date = request.args.get('execution_date') - dag = dagbag.get_dag(dag_id) - log_relative = "{dag_id}/{task_id}/{execution_date}".format( - **locals()) - loc = os.path.join(BASE_LOG_FOLDER, log_relative) - loc = loc.format(**locals()) - log = "" - TI = models.TaskInstance dttm = dateutil.parser.parse(execution_date) form = DateTimeForm(data={'execution_date': dttm}) + dag = dagbag.get_dag(dag_id) + TI = models.TaskInstance session = Session() ti = session.query(TI).filter( - TI.dag_id == dag_id, TI.task_id == task_id, + TI.dag_id == dag_id, + TI.task_id == task_id, TI.execution_date == dttm).first() - + logs = [] if ti is None: - log = "*** Task instance did not exist in the DB\n" + logs = ["*** Task instance did not exist in the DB\n"] else: - # load remote logs - remote_log_base = conf.get('core', 'REMOTE_BASE_LOG_FOLDER') - remote_log_loaded = False - if remote_log_base: - remote_log_path = os.path.join(remote_log_base, log_relative) - remote_log = "" - - # Only display errors reading the log if the task completed or ran at least - # once before (otherwise there won't be any remote log stored). - ti_execution_completed = ti.state in {State.SUCCESS, State.FAILED} - ti_ran_more_than_once = ti.try_number > 1 - surface_log_retrieval_errors = ( - ti_execution_completed or ti_ran_more_than_once) - - # S3 - if remote_log_path.startswith('s3:/'): - remote_log += log_utils.S3Log().read( - remote_log_path, return_error=surface_log_retrieval_errors) - remote_log_loaded = True - # GCS - elif remote_log_path.startswith('gs:/'): - remote_log += log_utils.GCSLog().read( - remote_log_path, return_error=surface_log_retrieval_errors) - remote_log_loaded = True - # unsupported - else: - remote_log += '*** Unsupported remote log location.' - - if remote_log: - log += ('*** Reading remote log from {}.\n{}\n'.format( - remote_log_path, remote_log)) - - # We only want to display the - # local logs while the task is running if a remote log configuration is set up - # since the logs will be transfered there after the run completes. 
- # TODO(aoen): One problem here is that if a task is running on a worker it - # already ran on, then duplicate logs will be printed for all of the previous - # runs of the task that already completed since they will have been printed as - # part of the remote log section above. This can be fixed either by streaming - # logs to the log servers as tasks are running, or by creating a proper - # abstraction for multiple task instance runs). - if not remote_log_loaded or ti.state == State.RUNNING: - if os.path.exists(loc): - try: - f = open(loc) - log += "*** Reading local log.\n" + "".join(f.readlines()) - f.close() - except: - log = "*** Failed to load local log file: {0}.\n".format(loc) - else: - WORKER_LOG_SERVER_PORT = \ - conf.get('celery', 'WORKER_LOG_SERVER_PORT') - url = os.path.join( - "http://{ti.hostname}:{WORKER_LOG_SERVER_PORT}/log", log_relative - ).format(**locals()) - log += "*** Log file isn't local.\n" - log += "*** Fetching here: {url}\n".format(**locals()) - try: - import requests - timeout = None # No timeout - try: - timeout = conf.getint('webserver', 'log_fetch_timeout_sec') - except (AirflowConfigException, ValueError): - pass - - response = requests.get(url, timeout=timeout) - response.raise_for_status() - log += '\n' + response.text - except: - log += "*** Failed to fetch log file from worker.\n".format( - **locals()) - - if PY2 and not isinstance(log, unicode): - log = log.decode('utf-8') + logs = [''] * ti.try_number + for try_number in range(ti.try_number): + log_filename = get_log_filename( + dag_id, task_id, execution_date, try_number) + logs[try_number] += self._get_log(ti, log_filename) return self.render( - 'airflow/ti_code.html', - code=log, dag=dag, title="Log", task_id=task_id, + 'airflow/ti_log.html', + logs=logs, dag=dag, title="Log by attempts", task_id=task_id, execution_date=execution_date, form=form) @expose('/task') diff --git a/dags/test_dag.py b/dags/test_dag.py index f2a9f6a27c00c..8dcde1594a99a 100644 --- a/dags/test_dag.py +++ b/dags/test_dag.py @@ -19,7 +19,7 @@ now = datetime.now() now_to_the_hour = (now - timedelta(0, 0, 0, 0, 0, 3)).replace(minute=0, second=0, microsecond=0) -START_DATE = now_to_the_hour +START_DATE = now_to_the_hour DAG_NAME = 'test_dag_v1' default_args = { @@ -34,5 +34,3 @@ run_this_2.set_upstream(run_this_1) run_this_3 = DummyOperator(task_id='run_this_3', dag=dag) run_this_3.set_upstream(run_this_2) - - diff --git a/docs/scheduler.rst b/docs/scheduler.rst index 4c5c6beed7f6e..8029eb05f69f2 100644 --- a/docs/scheduler.rst +++ b/docs/scheduler.rst @@ -147,8 +147,19 @@ To Keep in Mind Here are some of the ways you can **unblock tasks**: -* From the UI, you can **clear** (as in delete the status of) individual task instances from the task instances dialog, while defining whether you want to includes the past/future and the upstream/downstream dependencies. Note that a confirmation window comes next and allows you to see the set you are about to clear. -* The CLI command ``airflow clear -h`` has lots of options when it comes to clearing task instance states, including specifying date ranges, targeting task_ids by specifying a regular expression, flags for including upstream and downstream relatives, and targeting task instances in specific states (``failed``, or ``success``) -* Marking task instances as successful can be done through the UI. This is mostly to fix false negatives, or for instance when the fix has been applied outside of Airflow. 
-* The ``airflow backfill`` CLI subcommand has a flag to ``--mark_success`` and allows selecting subsections of the DAG as well as specifying date ranges. +* From the UI, you can **clear** (as in delete the status of) individual task instances + from the task instances dialog, while defining whether you want to includes the past/future + and the upstream/downstream dependencies. Note that a confirmation window comes next and + allows you to see the set you are about to clear. You can also clear all task instances + associated with the dag. +* The CLI command ``airflow clear -h`` has lots of options when it comes to clearing task instance + states, including specifying date ranges, targeting task_ids by specifying a regular expression, + flags for including upstream and downstream relatives, and targeting task instances in specific + states (``failed``, or ``success``) +* Clearing a task instance will no longer delete the task instance record. Instead it updates + max_tries and set the current task instance state to be None. +* Marking task instances as successful can be done through the UI. This is mostly to fix false negatives, + or for instance when the fix has been applied outside of Airflow. +* The ``airflow backfill`` CLI subcommand has a flag to ``--mark_success`` and allows selecting + subsections of the DAG as well as specifying date ranges. diff --git a/tests/models.py b/tests/models.py index 400c659a1ea89..cf2734b74622d 100644 --- a/tests/models.py +++ b/tests/models.py @@ -29,6 +29,7 @@ from airflow.models import DAG, TaskInstance as TI from airflow.models import State as ST from airflow.models import DagModel, DagStat +from airflow.models import clear_task_instances from airflow.operators.dummy_operator import DummyOperator from airflow.operators.bash_operator import BashOperator from airflow.operators.python_operator import PythonOperator @@ -912,7 +913,7 @@ def run_with_error(ti): # Clear the TI state since you can't run a task with a FAILED state without # clearing it first - ti.set_state(None, settings.Session()) + dag.clear() # third run -- up for retry run_with_error(ti) @@ -1154,3 +1155,137 @@ def post_execute(self, context, result): with self.assertRaises(TestError): ti.run() + + +class ClearTasksTest(unittest.TestCase): + def test_clear_task_instances(self): + dag = DAG('test_clear_task_instances', start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + task0 = DummyOperator(task_id='0', owner='test', dag=dag) + task1 = DummyOperator(task_id='1', owner='test', dag=dag, retries=2) + ti0 = TI(task=task0, execution_date=DEFAULT_DATE) + ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + + ti0.run() + ti1.run() + session = settings.Session() + qry = session.query(TI).filter( + TI.dag_id == dag.dag_id).all() + clear_task_instances(qry, session, dag=dag) + session.commit() + ti0.refresh_from_db() + ti1.refresh_from_db() + self.assertEqual(ti0.try_number, 1) + self.assertEqual(ti0.max_tries, 1) + self.assertEqual(ti1.try_number, 1) + self.assertEqual(ti1.max_tries, 3) + + def test_clear_task_instances_without_task(self): + dag = DAG('test_clear_task_instances_without_task', start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + task0 = DummyOperator(task_id='task0', owner='test', dag=dag) + task1 = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2) + ti0 = TI(task=task0, execution_date=DEFAULT_DATE) + ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + ti0.run() + ti1.run() + + # Remove the task from dag. 
+ dag.task_dict = {} + self.assertFalse(dag.has_task(task0.task_id)) + self.assertFalse(dag.has_task(task1.task_id)) + + session = settings.Session() + qry = session.query(TI).filter( + TI.dag_id == dag.dag_id).all() + clear_task_instances(qry, session) + session.commit() + # When dag is None, max_tries will be maximum of original max_tries or try_number. + ti0.refresh_from_db() + ti1.refresh_from_db() + self.assertEqual(ti0.try_number, 1) + self.assertEqual(ti0.max_tries, 1) + self.assertEqual(ti1.try_number, 1) + self.assertEqual(ti1.max_tries, 2) + + def test_clear_task_instances_without_dag(self): + dag = DAG('test_clear_task_instances_without_dag', start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + task0 = DummyOperator(task_id='task_0', owner='test', dag=dag) + task1 = DummyOperator(task_id='task_1', owner='test', dag=dag, retries=2) + ti0 = TI(task=task0, execution_date=DEFAULT_DATE) + ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + ti0.run() + ti1.run() + + session = settings.Session() + qry = session.query(TI).filter( + TI.dag_id == dag.dag_id).all() + clear_task_instances(qry, session) + session.commit() + # When dag is None, max_tries will be maximum of original max_tries or try_number. + ti0.refresh_from_db() + ti1.refresh_from_db() + self.assertEqual(ti0.try_number, 1) + self.assertEqual(ti0.max_tries, 1) + self.assertEqual(ti1.try_number, 1) + self.assertEqual(ti1.max_tries, 2) + + def test_dag_clear(self): + dag = DAG('test_dag_clear', start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag) + ti0 = TI(task=task0, execution_date=DEFAULT_DATE) + self.assertEqual(ti0.try_number, 0) + ti0.run() + self.assertEqual(ti0.try_number, 1) + dag.clear() + ti0.refresh_from_db() + self.assertEqual(ti0.try_number, 1) + self.assertEqual(ti0.state, State.NONE) + self.assertEqual(ti0.max_tries, 1) + + task1 = DummyOperator(task_id='test_dag_clear_task_1', owner='test', + dag=dag, retries=2) + ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + self.assertEqual(ti1.max_tries, 2) + ti1.try_number = 1 + ti1.run() + self.assertEqual(ti1.try_number, 2) + self.assertEqual(ti1.max_tries, 2) + + dag.clear() + ti0.refresh_from_db() + ti1.refresh_from_db() + # after clear dag, ti2 should show attempt 3 of 5 + self.assertEqual(ti1.max_tries, 4) + self.assertEqual(ti1.try_number, 2) + # after clear dag, ti1 should show attempt 2 of 2 + self.assertEqual(ti0.try_number, 1) + self.assertEqual(ti0.max_tries, 1) + + def test_operator_clear(self): + dag = DAG('test_operator_clear', start_date=DEFAULT_DATE, + end_date=DEFAULT_DATE + datetime.timedelta(days=10)) + t1 = DummyOperator(task_id='bash_op', owner='test', dag=dag) + t2 = DummyOperator(task_id='dummy_op', owner='test', dag=dag, retries=1) + + t2.set_upstream(t1) + + ti1 = TI(task=t1, execution_date=DEFAULT_DATE) + ti2 = TI(task=t2, execution_date=DEFAULT_DATE) + ti2.run() + # Dependency not met + self.assertEqual(ti2.try_number, 0) + self.assertEqual(ti2.max_tries, 1) + + t2.clear(upstream=True) + ti1.run() + ti2.run() + self.assertEqual(ti1.try_number, 1) + # max_tries is 0 because there is no task instance in db for ti1 + # so clear won't change the max_tries. 
+ self.assertEqual(ti1.max_tries, 0) + self.assertEqual(ti2.try_number, 1) + # try_number (0) + retries(1) + self.assertEqual(ti2.max_tries, 1) diff --git a/tests/operators/python_operator.py b/tests/operators/python_operator.py index 71432affd8199..74120fed0e3db 100644 --- a/tests/operators/python_operator.py +++ b/tests/operators/python_operator.py @@ -117,8 +117,8 @@ def test_without_dag_run(self): if ti.task_id == 'make_choice': self.assertEquals(ti.state, State.SUCCESS) elif ti.task_id == 'branch_1': - # should not exist - raise + # should exist with state None + self.assertEquals(ti.state, State.NONE) elif ti.task_id == 'branch_2': self.assertEquals(ti.state, State.SKIPPED) else: @@ -147,34 +147,31 @@ def test_with_dag_run(self): class ShortCircuitOperatorTest(unittest.TestCase): - def setUp(self): - self.dag = DAG('shortcircuit_operator_test', - default_args={ - 'owner': 'airflow', - 'start_date': DEFAULT_DATE}, - schedule_interval=INTERVAL) - self.short_op = ShortCircuitOperator(task_id='make_choice', - dag=self.dag, - python_callable=lambda: self.value) - - self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag) - self.branch_1.set_upstream(self.short_op) - self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag) - self.branch_2.set_upstream(self.branch_1) - self.upstream = DummyOperator(task_id='upstream', dag=self.dag) - self.upstream.set_downstream(self.short_op) - self.dag.clear() - - self.value = True - def test_without_dag_run(self): """This checks the defensive against non existent tasks in a dag run""" - self.value = False - self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + value = False + dag = DAG('shortcircuit_operator_test_without_dag_run', + default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + }, + schedule_interval=INTERVAL) + short_op = ShortCircuitOperator(task_id='make_choice', + dag=dag, + python_callable=lambda: value) + branch_1 = DummyOperator(task_id='branch_1', dag=dag) + branch_1.set_upstream(short_op) + branch_2 = DummyOperator(task_id='branch_2', dag=dag) + branch_2.set_upstream(branch_1) + upstream = DummyOperator(task_id='upstream', dag=dag) + upstream.set_downstream(short_op) + dag.clear() + + short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) session = Session() tis = session.query(TI).filter( - TI.dag_id == self.dag.dag_id, + TI.dag_id == dag.dag_id, TI.execution_date == DEFAULT_DATE ) @@ -189,10 +186,10 @@ def test_without_dag_run(self): else: raise - self.value = True - self.dag.clear() + value = True + dag.clear() - self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) for ti in tis: if ti.task_id == 'make_choice': self.assertEquals(ti.state, State.SUCCESS) @@ -207,17 +204,34 @@ def test_without_dag_run(self): session.close() def test_with_dag_run(self): - self.value = False - logging.error("Tasks {}".format(self.dag.tasks)) - dr = self.dag.create_dagrun( + value = False + dag = DAG('shortcircuit_operator_test_with_dag_run', + default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + }, + schedule_interval=INTERVAL) + short_op = ShortCircuitOperator(task_id='make_choice', + dag=dag, + python_callable=lambda: value) + branch_1 = DummyOperator(task_id='branch_1', dag=dag) + branch_1.set_upstream(short_op) + branch_2 = DummyOperator(task_id='branch_2', dag=dag) + branch_2.set_upstream(branch_1) + upstream = DummyOperator(task_id='upstream', dag=dag) + upstream.set_downstream(short_op) + dag.clear() + + 
logging.error("Tasks {}".format(dag.tasks)) + dr = dag.create_dagrun( run_id="manual__", start_date=datetime.datetime.now(), execution_date=DEFAULT_DATE, state=State.RUNNING ) - self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) tis = dr.get_task_instances() self.assertEqual(len(tis), 4) @@ -231,11 +245,11 @@ def test_with_dag_run(self): else: raise - self.value = True - self.dag.clear() + value = True + dag.clear() dr.verify_integrity() - self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) - self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) tis = dr.get_task_instances() self.assertEqual(len(tis), 4) diff --git a/tests/utils/test_dates.py b/tests/utils/test_dates.py index 56fae32b85241..13230349db837 100644 --- a/tests/utils/test_dates.py +++ b/tests/utils/test_dates.py @@ -40,7 +40,3 @@ def test_days_ago(self): self.assertTrue( dates.days_ago(0, microsecond=3) == today_midnight + timedelta(microseconds=3)) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py new file mode 100644 index 0000000000000..474430f3537dc --- /dev/null +++ b/tests/utils/test_logging.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from airflow.exceptions import AirflowException +from airflow.utils import logging as logging_utils +from datetime import datetime, timedelta + +class Logging(unittest.TestCase): + + def test_get_log_filename(self): + dag_id = 'dag_id' + task_id = 'task_id' + execution_date = datetime(2017, 1, 1, 0, 0, 0) + try_number = 0 + filename = logging_utils.get_log_filename(dag_id, task_id, execution_date, try_number) + self.assertEqual(filename, 'dag_id/task_id/2017-01-01T00:00:00/1.log')
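
The retry bookkeeping this patch introduces can be summarized in a minimal standalone sketch (plain Python with illustrative names, not the Airflow API): clearing a task instance now keeps the database record, resets its state, and extends `max_tries` by the task's configured retries, while `handle_failure` keeps retrying as long as the already-incremented `try_number` has not exceeded `max_tries`.

```python
# Illustrative sketch of the new clear/retry accounting; Task and
# TaskInstance here are simplified stand-ins, not Airflow classes.
from dataclasses import dataclass


@dataclass
class Task:
    retries: int = 0


@dataclass
class TaskInstance:
    try_number: int = 0   # incremented by 1 each time the instance runs
    max_tries: int = 0    # starts equal to task.retries


def clear(ti: TaskInstance, task: Task = None) -> None:
    """Clearing no longer deletes the record; it extends the retry budget."""
    if task is not None:
        ti.max_tries = ti.try_number + task.retries
    else:
        # Task missing from the DAG: keep whichever bound is larger.
        ti.max_tries = max(ti.max_tries, ti.try_number)
    # state is reset to None here instead of deleting the row


def should_retry(ti: TaskInstance) -> bool:
    """Mirrors handle_failure(): retry while the next try is within budget."""
    return ti.try_number <= ti.max_tries


if __name__ == "__main__":
    task = Task(retries=2)
    ti = TaskInstance(try_number=1, max_tries=2)  # already failed once
    clear(ti, task)
    assert ti.max_tries == 3          # 1 completed try + 2 more retries
    assert should_retry(ti)
    # UI reporting stays 1-indexed: attempt 2 of 4
    print("attempt {} of {}".format(ti.try_number + 1, ti.max_tries + 1))
```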