From 2c99eac2075dc992fc14b437f81bafdcf74da6f6 Mon Sep 17 00:00:00 2001 From: karpathy Date: Thu, 27 Nov 2014 16:21:08 -0800 Subject: [PATCH] first commit of code, phew --- .gitignore | 2 + Readme.md | 41 +++ cv/Readme.md | 1 + data/README.md | 1 + driver.py | 293 ++++++++++++++++ eval_sentence_predictions.py | 124 +++++++ imagernn/Readme.md | 9 + imagernn/__init__.py | 0 imagernn/data_provider.py | 116 +++++++ imagernn/generic_batch_generator.py | 155 +++++++++ imagernn/imagernn_utils.py | 40 +++ imagernn/lstm_generator.py | 298 ++++++++++++++++ imagernn/rnn_generator.py | 216 ++++++++++++ imagernn/solver.py | 150 ++++++++ imagernn/utils.py | 26 ++ matlab_features_reference/README.md | 5 + .../deploy_features.prototxt | 321 ++++++++++++++++++ matlab_features_reference/extract_features.m | 47 +++ .../prepare_images_batch.m | 28 ++ monitorcv.html | 165 +++++++++ status/Readme.md | 1 + vis_resources/d3.min.js | 5 + vis_resources/d3utils.css | 23 ++ vis_resources/d3utils.js | 48 +++ vis_resources/jquery-1.8.3.min.js | 2 + vis_resources/jsutils.js | 38 +++ vis_resources/underscore-min.js | 6 + vis_resources/underscore-min.map | 1 + visualize_result_struct.html | 183 ++++++++++ 29 files changed, 2345 insertions(+) create mode 100644 .gitignore create mode 100644 Readme.md create mode 100644 cv/Readme.md create mode 100644 data/README.md create mode 100644 driver.py create mode 100644 eval_sentence_predictions.py create mode 100644 imagernn/Readme.md create mode 100644 imagernn/__init__.py create mode 100644 imagernn/data_provider.py create mode 100644 imagernn/generic_batch_generator.py create mode 100644 imagernn/imagernn_utils.py create mode 100644 imagernn/lstm_generator.py create mode 100644 imagernn/rnn_generator.py create mode 100644 imagernn/solver.py create mode 100644 imagernn/utils.py create mode 100644 matlab_features_reference/README.md create mode 100644 matlab_features_reference/deploy_features.prototxt create mode 100644 matlab_features_reference/extract_features.m create mode 100644 matlab_features_reference/prepare_images_batch.m create mode 100644 monitorcv.html create mode 100644 status/Readme.md create mode 100644 vis_resources/d3.min.js create mode 100644 vis_resources/d3utils.css create mode 100644 vis_resources/d3utils.js create mode 100644 vis_resources/jquery-1.8.3.min.js create mode 100644 vis_resources/jsutils.js create mode 100644 vis_resources/underscore-min.js create mode 100644 vis_resources/underscore-min.map create mode 100644 visualize_result_struct.html diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fd9a7d1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.pyc +cv/*.p diff --git a/Readme.md b/Readme.md new file mode 100644 index 0000000..d4cbc47 --- /dev/null +++ b/Readme.md @@ -0,0 +1,41 @@ +#NeuralTalk + +This project contains *Python+numpy* source code for learning **Multimodal Recurrent Neural Networks** that describe images with sentences. + +This line of work was recently featured in a [New York Times article](http://www.nytimes.com/2014/11/18/science/researchers-announce-breakthrough-in-content-recognition-software.html) and has been the subject of multiple academic papers from the research community over the last few months. This code currently implements the models proposed by [Vinyals et al. from Google (CNN + LSTM)](http://arxiv.org/abs/1411.4555) and by [Karpathy and Fei-Fei from Stanford (CNN + RNN)](http://cs.stanford.edu/people/karpathy/deepimagesent/). Both models take an image and predict its sentence description with a Recurrent Neural Network (either an LSTM or an RNN). + +## Overview +The pipeline for the project looks as follows: + +- The **input** is a dataset of images and 5 sentence descriptions that were collected with Amazon Mechanical Turk. In particular, this code base is set up for [Flickr8K](http://nlp.cs.illinois.edu/HockenmaierGroup/Framing_Image_Description/KCCA.html), [Flickr30K](http://shannon.cs.illinois.edu/DenotationGraph/), and [MSCOCO](http://mscoco.org/) datasets. +- In the **training stage**, the images are fed as input to RNN and the RNN is asked to predict the words of the sentence, conditioned on the current word and previous context as mediated by the hidden layers of the neural network. In this stage, the parameters of the networks are trained with backpropagation. +- In the **prediction stage**, a witheld set of images is passed to RNN and the RNN generates the sentence one word at a time. The results are evaluated with **BLEU score** and with **ranking experiments** (coming soon). The code also includes utilities for visualizing the results in HTML. + +## Dependencies +**Python 2.7**, modern version of **numpy/scipy**, **nltk** (if you want to do BLEU score evaluation), **argparse** module. Most of these are okay to install with **pip**. + +I only tested this code with Ubuntu 12.04, but I tried to make it as generic as possible (e.g. use of **os** module for file system interactions etc. So it might work on Windows and Mac relatively easily.) + +*Protip*: you really want to link your numpy to use a BLAS implementation for its matrix operations. I use **virtualenv** and link numpy against a system installation of **OpenBLAS**. Doing this will make this code almost an order of time faster because it relies very heavily on large matrix multiplies. + +## Getting started + +1. **Get the code.** `$ git clone` the repo and install the Python dependencies +2. **Get the data.** I don't distribute the data in the Git repo, instead download the `data/` folder from [here](http://cs.stanford.edu/people/karpathy/deepimagesent/). Also, this download does not include the raw image files, so if you want to visualize the annotations on raw images, you have to obtain the images from Flickr8K / Flickr30K / COCO directly and dump them into the appropriate data folder. +3. **Train the model.** Run the training `$ python driver.py` (see many additional argument settings inside the file) and wait. You'll see that the learning code writes checkpoints into `cv/` and periodically reports its status in `status/` folder. +4. **Monitor the training.** The status can be inspected manually by reading the JSON and printing whatever you wish in a second process. In practice I run cross-validations on a cluster, so my `cv/` folder fills up with a lot of checkpoints that I further filter and inspect with other scripts. I am including my cluster training status visualization utility as well if you like. Run a local webserver (e.g. `$ python -m SimpleHTTPServer 8123`) and then open `monitorcv.html` in your browser on `http://localhost:8123/monitorcv.html`, or whatever the web server tells you the path is. You will have to edit the file to setup the paths properly and point it at the right json files. +5. **Evaluate model checkpoints.** To evaluate a checkpoint from `cv/`, run the `evaluate_sentence_predctions.py` script and pass it the path to a checkpoint. +6. **Visualize the predictions.** Use the included html file `visualize_result_struct.html` to visualize the JSON struct produced by the evaluation code. This will visualize the images and their predictions. Note that you'll have to download the raw images from the individual dataset pages and place them into the corresponding `data/` folder. + +Lastly, note that this is currently research code, so a lot of the documentation is inside individual Python files. If you wish to work with this code, you'll have to get familiar with it and be comfortable reading Python code. + +## Pretrained model + +Some pretrained models can be found [here](http://cs.stanford.edu/people/karpathy/deepimagesent/) (coming soon). The slightly hairy part is that if you wish to apply it to some arbitrary new image you have to first extract the VGG features with Caffe. I think there is opportunity for giving all of it as a single nice function that uses the Python wrapper to get the features and then runs the pretrained sentence model. I might add this in the future. + +## Using your own data + +The input to the system is the **data** folder, which contains the Flickr8K, Flickr30K and MSCOCO datasets. In particular, each folder (e.g. `data/flickr8k`) contains a `dataset.json` file that stores the image paths and sentences in the dataset (all images, sentences, raw preprocessed tokens, splits, and the mappings between images and sentences). Each folder additionally contains `vgg_feats.mat` , which is a `.mat` file that stores the CNN features from all images, one per row, using the VGG Net from ILSVRC 2014. Finally, there is the `imgs/` folder that holds the raw images. I also provide the Matlab script that I used to extract the features, which you may find helpful if you wish to use a different dataset. This is inside the `matlab_features_reference/` folder, and see the Readme file in that folder for more information. + +## License +BSD license. diff --git a/cv/Readme.md b/cv/Readme.md new file mode 100644 index 0000000..63162f0 --- /dev/null +++ b/cv/Readme.md @@ -0,0 +1 @@ +Checkpoints get written to this folder diff --git a/data/README.md b/data/README.md new file mode 100644 index 0000000..7cc82be --- /dev/null +++ b/data/README.md @@ -0,0 +1 @@ +This folder should contain the data folders, such as for example flickr8k, flickr30k, coco. diff --git a/driver.py b/driver.py new file mode 100644 index 0000000..3e45b77 --- /dev/null +++ b/driver.py @@ -0,0 +1,293 @@ +import argparse +import json +import time +import datetime +import numpy as np +import code +import socket +import os +import cPickle as pickle + +from imagernn.data_provider import getDataProvider +from imagernn.solver import Solver +from imagernn.imagernn_utils import decodeGenerator, eval_split + +def preProBuildWordVocab(sentence_iterator, word_count_threshold): + # count up all word counts so that we can threshold + # this shouldnt be too expensive of an operation + print 'preprocessing word counts and creating vocab based on word count threshold %d' % (word_count_threshold, ) + t0 = time.time() + word_counts = {} + nsents = 0 + for sent in sentence_iterator: + nsents += 1 + for w in sent['tokens']: + word_counts[w] = word_counts.get(w, 0) + 1 + vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold] + print 'filtered words from %d to %d in %.2fs' % (len(word_counts), len(vocab), time.time() - t0) + + # with K distinct words: + # - there are K+1 possible inputs (START token and all the words) + # - there are K+1 possible outputs (END token and all the words) + # we use ixtoword to take predicted indeces and map them to words for output visualization + # we use wordtoix to take raw words and get their index in word vector matrix + ixtoword = {} + ixtoword[0] = '.' # period at the end of the sentence. make first dimension be end token + wordtoix = {} + wordtoix['#START#'] = 0 # make first vector be the start token + ix = 1 + for w in vocab: + wordtoix[w] = ix + ixtoword[ix] = w + ix += 1 + + # compute bias vector, which is related to the log probability of the distribution + # of the labels (words) and how often they occur. We will use this vector to initialize + # the decoder weights, so that the loss function doesnt show a huge increase in performance + # very quickly (which is just the network learning this anyway, for the most part). This makes + # the visualizations of the cost function nicer because it doesn't look like a hockey stick. + # for example on Flickr8K, doing this brings down initial perplexity from ~2500 to ~170. + word_counts['.'] = nsents + bias_init_vector = np.array([1.0*word_counts[ixtoword[i]] for i in ixtoword]) + bias_init_vector /= np.sum(bias_init_vector) # normalize to frequencies + bias_init_vector = np.log(bias_init_vector) + bias_init_vector -= np.max(bias_init_vector) # shift to nice numeric range + return wordtoix, ixtoword, bias_init_vector + +def RNNGenCost(batch, model, params, misc): + """ cost function, returns cost and gradients for model """ + regc = params['regc'] # regularization cost + BatchGenerator = decodeGenerator(params) + wordtoix = misc['wordtoix'] + + # forward the RNN on each image sentence pair + # the generator returns a list of matrices that have word probabilities + # and a list of cache objects that will be needed for backprop + Ys, gen_caches = BatchGenerator.forward(batch, model, params, misc, predict_mode = False) + + # compute softmax costs for all generated sentences, and the gradients on top + loss_cost = 0.0 + dYs = [] + logppl = 0.0 + logppln = 0 + for i,pair in enumerate(batch): + img = pair['image'] + # ground truth indeces for this sentence we expect to see + gtix = [ wordtoix[w] for w in pair['sentence']['tokens'] if w in wordtoix ] + gtix.append(0) # don't forget END token must be predicted in the end! + # fetch the predicted probabilities, as rows + Y = Ys[i] + maxes = np.amax(Y, axis=1, keepdims=True) + e = np.exp(Y - maxes) # for numerical stability shift into good numerical range + P = e / np.sum(e, axis=1, keepdims=True) + loss_cost += - np.sum(np.log(1e-20 + P[range(len(gtix)),gtix])) # note: add smoothing to not get infs + logppl += - np.sum(np.log2(1e-20 + P[range(len(gtix)),gtix])) # also accumulate log2 perplexities + logppln += len(gtix) + + # lets be clever and optimize for speed here to derive the gradient in place quickly + for iy,y in enumerate(gtix): + P[iy,y] -= 1 # softmax derivatives are pretty simple + dYs.append(P) + + # backprop the RNN + grads = BatchGenerator.backward(dYs, gen_caches) + + # add L2 regularization cost and gradients + reg_cost = 0.0 + if regc > 0: + for p in misc['regularize']: + mat = model[p] + reg_cost += 0.5 * regc * np.sum(mat * mat) + grads[p] += regc * mat + + # normalize the cost and gradient by the batch size + batch_size = len(batch) + reg_cost /= batch_size + loss_cost /= batch_size + for k in grads: grads[k] /= batch_size + + # return output in json + out = {} + out['cost'] = {'reg_cost' : reg_cost, 'loss_cost' : loss_cost, 'total_cost' : loss_cost + reg_cost} + out['grad'] = grads + out['stats'] = { 'ppl2' : 2 ** (logppl / logppln)} + return out + +def main(params): + batch_size = params['batch_size'] + dataset = params['dataset'] + word_count_threshold = params['word_count_threshold'] + do_grad_check = params['do_grad_check'] + max_epochs = params['max_epochs'] + host = socket.gethostname() # get computer hostname + + # fetch the data provider + dp = getDataProvider(dataset) + + misc = {} # stores various misc items that need to be passed around the framework + + # go over all training sentences and find the vocabulary we want to use, i.e. the words that occur + # at least word_count_threshold number of times + misc['wordtoix'], misc['ixtoword'], bias_init_vector = preProBuildWordVocab(dp.iterSentences('train'), word_count_threshold) + + # delegate the initialization of the model to the Generator class + BatchGenerator = decodeGenerator(params) + init_struct = BatchGenerator.init(params, misc) + model, misc['update'], misc['regularize'] = (init_struct['model'], init_struct['update'], init_struct['regularize']) + + # force overwrite here. This is a bit of a hack, not happy about it + model['bd'] = bias_init_vector.reshape(1, bias_init_vector.size) + + print 'model init done.' + print 'model has keys: ' + ', '.join(model.keys()) + print 'updating: ' + ', '.join( '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1]) for k in misc['update']) + print 'updating: ' + ', '.join( '%s [%dx%d]' % (k, model[k].shape[0], model[k].shape[1]) for k in misc['regularize']) + print 'number of learnable parameters total: %d' % (sum(model[k].shape[0] * model[k].shape[1] for k in misc['update']), ) + + # initialize the Solver and the cost function + solver = Solver() + def costfun(batch, model): + # wrap the cost function to abstract some things away from the Solver + return RNNGenCost(batch, model, params, misc) + + # calculate how many iterations we need + num_sentences_total = dp.getSplitSize('train', ofwhat = 'sentences') + num_iters_one_epoch = num_sentences_total / batch_size + max_iters = max_epochs * num_iters_one_epoch + eval_period_in_epochs = params['eval_period'] + eval_period_in_iters = max(1, int(num_iters_one_epoch * eval_period_in_epochs)) + abort = False + top_val_ppl2 = -1 + smooth_train_ppl2 = len(misc['ixtoword']) # initially size of dictionary of confusion + val_ppl2 = len(misc['ixtoword']) + last_status_write_time = 0 # for writing worker job status reports + json_worker_status = {} + json_worker_status['params'] = params + json_worker_status['history'] = [] + for it in xrange(max_iters): + if abort: break + t0 = time.time() + # fetch a batch of data + batch = [dp.sampleImageSentencePair() for i in xrange(batch_size)] + # evaluate cost, gradient and perform parameter update + step_struct = solver.step(batch, model, costfun, **params) + cost = step_struct['cost'] + dt = time.time() - t0 + + # print training statistics + train_ppl2 = step_struct['stats']['ppl2'] + smooth_train_ppl2 = 0.99 * smooth_train_ppl2 + 0.01 * train_ppl2 # smooth exponentially decaying moving average + if it == 0: smooth_train_ppl2 = train_ppl2 # start out where we start out + epoch = it * 1.0 / num_iters_one_epoch + print '%d/%d batch done in %.3fs. at epoch %.2f. loss cost = %f, reg cost = %f, ppl2 = %.2f (smooth %.2f)' \ + % (it, max_iters, dt, epoch, cost['loss_cost'], cost['reg_cost'], \ + train_ppl2, smooth_train_ppl2) + + # perform gradient check if desired, with a bit of a burnin time (10 iterations) + if it == 10 and do_grad_check: + solver.gradCheck(batch, model, costfun) + print 'done gradcheck. continue?' + raw_input() + + # detect if loss is exploding and kill the job if so + total_cost = cost['total_cost'] + if it == 0: + total_cost0 = total_cost # store this initial cost + if total_cost > total_cost0 * 2: + print 'Aboring, cost seems to be exploding. Run gradcheck? Lower the learning rate?' + abort = True # set the abort flag, we'll break out + + # logging: write JSON files for visual inspection of the training + tnow = time.time() + if tnow > last_status_write_time + 60*1: # every now and then lets write a report + last_status_write_time = tnow + jstatus = {} + jstatus['time'] = datetime.datetime.now().isoformat() + jstatus['iter'] = (it, max_iters) + jstatus['epoch'] = (epoch, max_epochs) + jstatus['time_per_batch'] = dt + jstatus['smooth_train_ppl2'] = smooth_train_ppl2 + jstatus['val_ppl2'] = val_ppl2 # just write the last available one + jstatus['train_ppl2'] = train_ppl2 + json_worker_status['history'].append(jstatus) + status_file = os.path.join(params['worker_status_output_directory'], host + '_status.json') + try: + json.dump(json_worker_status, open(status_file, 'w')) + except Exception, e: # todo be more clever here + print 'tried to write worker status into %s but got error:' % (status_file, ) + print e + + # perform perplexity evaluation on the validation set and save a model checkpoint if it's good + is_last_iter = (it+1) == max_iters + if (((it+1) % eval_period_in_iters) == 0 and it < max_iters - 5) or is_last_iter: + val_ppl2 = eval_split('val', dp, model, params, misc) # perform the evaluation on VAL set + print 'validation perplexity = %f' % (val_ppl2, ) + write_checkpoint_ppl_threshold = params['write_checkpoint_ppl_threshold'] + if val_ppl2 < top_val_ppl2 or top_val_ppl2 < 0: + if val_ppl2 < write_checkpoint_ppl_threshold or write_checkpoint_ppl_threshold < 0: + # if we beat a previous record or if this is the first time + # AND we also beat the user-defined threshold or it doesnt exist + top_val_ppl2 = val_ppl2 + filename = 'model_checkpoint_%s_%s_%s_%.2f.p' % (dataset, host, params['fappend'], val_ppl2) + filepath = os.path.join(params['checkpoint_output_directory'], filename) + checkpoint = {} + checkpoint['it'] = it + checkpoint['epoch'] = epoch + checkpoint['model'] = model + checkpoint['params'] = params + checkpoint['perplexity'] = val_ppl2 + checkpoint['wordtoix'] = misc['wordtoix'] + checkpoint['ixtoword'] = misc['ixtoword'] + try: + pickle.dump(checkpoint, open(filepath, "wb")) + print 'saved checkpoint in %s' % (filepath, ) + except Exception, e: # todo be more clever here + print 'tried to write checkpoint into %s but got error: ' % (filepat, ) + print e + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + # global setup settings, and checkpoints + parser.add_argument('-d', '--dataset', dest='dataset', default='flickr8k', help='dataset: flickr8k/flickr30k') + parser.add_argument('-a', '--do_grad_check', dest='do_grad_check', type=int, default=0, help='perform gradcheck? program will block for visual inspection and will need manual user input') + parser.add_argument('--fappend', dest='fappend', type=str, default='baseline', help='append this string to checkpoint filenames') + parser.add_argument('-o', '--checkpoint_output_directory', dest='checkpoint_output_directory', type=str, default='cv/', help='output directory to write checkpoints to') + parser.add_argument('--worker_status_output_directory', dest='worker_status_output_directory', type=str, default='status/', help='directory to write worker status JSON blobs to') + parser.add_argument('--write_checkpoint_ppl_threshold', dest='write_checkpoint_ppl_threshold', type=float, default=-1, help='ppl threshold above which we dont bother writing a checkpoint to save space') + + # model parameters + parser.add_argument('--image_encoding_size', dest='image_encoding_size', type=int, default=256, help='size of the image encoding') + parser.add_argument('--word_encoding_size', dest='word_encoding_size', type=int, default=256, help='size of word encoding') + parser.add_argument('--hidden_size', dest='hidden_size', type=int, default=256, help='size of hidden layer in generator RNNs') + parser.add_argument('--generator', dest='generator', type=str, default='lstm', help='generator to use') + parser.add_argument('-c', '--regc', dest='regc', type=float, default=1e-8, help='regularization strength') + parser.add_argument('--tanhC_version', dest='tanhC_version', type=int, default=0, help='use tanh version of LSTM?') + + # optimization parameters + parser.add_argument('-m', '--max_epochs', dest='max_epochs', type=int, default=20, help='number of epochs to train for') + parser.add_argument('--solver', dest='solver', type=str, default='rmsprop', help='solver type: vanilla/adagrad/adadelta/rmsprop') + parser.add_argument('--momentum', dest='momentum', type=float, default=0.0, help='momentum for vanilla sgd') + parser.add_argument('--decay_rate', dest='decay_rate', type=float, default=0.999, help='decay rate for adadelta/rmsprop') + parser.add_argument('--smooth_eps', dest='smooth_eps', type=float, default=1e-8, help='epsilon smoothing for rmsprop/adagrad/adadelta') + parser.add_argument('-l', '--learning_rate', dest='learning_rate', type=float, default=1e-3, help='solver learning rate') + parser.add_argument('-b', '--batch_size', dest='batch_size', type=int, default=100, help='batch size') + parser.add_argument('--grad_clip', dest='grad_clip', type=float, default=5, help='clip gradients (normalized by batch size)? elementwise. if positive, at what threshold?') + parser.add_argument('--drop_prob_encoder', dest='drop_prob_encoder', type=float, default=0.5, help='what dropout to apply right after the encoder to an RNN/LSTM') + parser.add_argument('--drop_prob_decoder', dest='drop_prob_decoder', type=float, default=0.5, help='what dropout to apply right before the decoder in an RNN/LSTM') + + # data preprocessing parameters + parser.add_argument('--word_count_threshold', dest='word_count_threshold', type=int, default=5, help='if a word occurs less than this number of times in training data, it is discarded') + + # evaluation parameters + parser.add_argument('-p', '--eval_period', dest='eval_period', type=float, default=1.0, help='in units of epochs, how often do we evaluate on val set?') + parser.add_argument('--eval_batch_size', dest='eval_batch_size', type=int, default=100, help='for faster validation performance evaluation, what batch size to use on val img/sentences?') + parser.add_argument('--eval_max_images', dest='eval_max_images', type=int, default=-1, help='for efficiency we can use a smaller number of images to get validation error') + + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print 'parsed parameters:' + print json.dumps(params, indent = 2) + main(params) \ No newline at end of file diff --git a/eval_sentence_predictions.py b/eval_sentence_predictions.py new file mode 100644 index 0000000..9c84b3e --- /dev/null +++ b/eval_sentence_predictions.py @@ -0,0 +1,124 @@ +import argparse +import json +import time +import datetime +import numpy as np +import code +import socket +import os +import cPickle as pickle +import math + +from imagernn.data_provider import getDataProvider +from imagernn.solver import Solver +from imagernn.imagernn_utils import decodeGenerator, eval_split + +from nltk.align.bleu import BLEU + +# UTILS needed for BLEU score evaluation +def BLEUscore(candidate, references, weights): + p_ns = [BLEU.modified_precision(candidate, references, i) for i, _ in enumerate(weights, start=1)] + if all([x > 0 for x in p_ns]): + s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns)) + bp = BLEU.brevity_penalty(candidate, references) + return bp * math.exp(s) + else: # this is bad + return 0 + +def evalCandidate(candidate, references): + """ + candidate is a single list of words, references is a list of lists of words + written by humans. + """ + b1 = BLEUscore(candidate, references, [1.0]) + b2 = BLEUscore(candidate, references, [0.5, 0.5]) + b3 = BLEUscore(candidate, references, [1/3.0, 1/3.0, 1/3.0]) + return [b1,b2,b3] + +def main(params): + + # load the checkpoint + checkpoint_path = params['checkpoint_path'] + max_images = params['max_images'] + + print 'loading checkpoint %s' % (checkpoint_path, ) + checkpoint = pickle.load(open(checkpoint_path, 'rb')) + checkpoint_params = checkpoint['params'] + dataset = checkpoint_params['dataset'] + model = checkpoint['model'] + + # fetch the data provider + dp = getDataProvider(dataset) + + misc = {} + misc['wordtoix'] = checkpoint['wordtoix'] + ixtoword = checkpoint['ixtoword'] + + blob = {} # output blob which we will dump to JSON for visualizing the results + blob['params'] = params + blob['checkpoint_params'] = checkpoint_params + blob['imgblobs'] = [] + + # iterate over all images in test set and predict sentences + BatchGenerator = decodeGenerator(checkpoint_params) + all_bleu_scores = [] + n = 0 + #for img in dp.iterImages(split = 'test', shuffle = True, max_images = max_images): + for img in dp.iterImages(split = 'test', max_images = max_images): + n+=1 + print 'image %d/%d:' % (n, max_images) + references = [x['tokens'] for x in img['sentences']] # as list of lists of tokens + kwparams = { 'tanhC_version' : checkpoint_params.get('tanhC_version', 0) ,\ + 'beam_size' : params['beam_size'],\ + 'generator' : checkpoint_params['generator']} + Ys = BatchGenerator.predict([{'image':img}], model, **kwparams) + + img_blob = {} # we will build this up + img_blob['img_path'] = img['local_file_path'] + img_blob['imgid'] = img['imgid'] + + # encode the human-provided references + img_blob['references'] = [] + for gtwords in references: + print 'GT: ' + ' '.join(gtwords) + img_blob['references'].append({'text': ' '.join(gtwords)}) + + # now evaluate and encode the top prediction + top_predictions = Ys[0] # take predictions for the first (and only) image we passed in + top_prediction = top_predictions[0] # these are sorted with highest on top + candidate = [ixtoword[ix] for ix in top_prediction[1]] + print 'PRED: (%f) %s' % (top_prediction[0], ' '.join(candidate)) + bleu_scores = evalCandidate(candidate, references) + print 'BLEU: B-1: %f B-2: %f B-3: %f' % tuple(bleu_scores) + img_blob['candidate'] = {'text': ' '.join(candidate), 'logprob': top_prediction[0], 'bleu': bleu_scores} + + all_bleu_scores.append(bleu_scores) + blob['imgblobs'].append(img_blob) + + print 'final average bleu scores:' + bleu_averages = [sum(x[i] for x in all_bleu_scores)*1.0/len(all_bleu_scores) for i in xrange(3)] + blob['final_result'] = { 'bleu' : bleu_averages } + print 'FINAL BLEU: B-1: %f B-2: %f B-3: %f' % tuple(bleu_averages) + + # now also evaluate test split perplexity + gtppl = eval_split('test', dp, model, checkpoint_params, misc, eval_max_images = max_images) + print 'perplexity of ground truth words: %f' % (gtppl, ) + blob['gtppl'] = gtppl + + # dump result struct to file + print 'saving result struct to %s' % (params['result_struct_filename'], ) + json.dump(blob, open(params['result_struct_filename'], 'w')) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('checkpoint_path', type=str, help='the input checkpoint') + parser.add_argument('-b', '--beam_size', type=int, default=1, help='beam size in inference. 1 indicates greedy per-word max procedure. Good value is approx 20 or so, and more = better.') + parser.add_argument('--result_struct_filename', type=str, default='result_struct.json', help='filename of the result struct to save') + parser.add_argument('-m', '--max_images', type=int, default=-1, help='max images to use') + + args = parser.parse_args() + params = vars(args) # convert to ordinary dict + print 'parsed parameters:' + print json.dumps(params, indent = 2) + main(params) diff --git a/imagernn/Readme.md b/imagernn/Readme.md new file mode 100644 index 0000000..d2556ce --- /dev/null +++ b/imagernn/Readme.md @@ -0,0 +1,9 @@ +The code is organized as follows: + +- `data_provider.py` abstracts away the datasets and provides uniform API for the code. +- `utils.py` is what it sounds like it is :) +- `solver.py`: the solver class doesn't know anything about images or sentences, it gets a model and the gradients and performs a step update +- `generic_batch_generator.py` handles batching across a batch of image/sentences that need to be forwarded through the networks. It calls the +- `lstm_generator.py`, which is an implementation of the Google LSTM for generating images. +- `imagernn_utils.py` contains some image-rnn specific utilities, such as evaluation function etc. These come in handy when we want to use some functionality across different scripts (e.g. driver and evaluator) +- `rnn_generator.py` has a simple RNN implementation for now, an alternative to LSTM diff --git a/imagernn/__init__.py b/imagernn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/imagernn/data_provider.py b/imagernn/data_provider.py new file mode 100644 index 0000000..1f6d8d9 --- /dev/null +++ b/imagernn/data_provider.py @@ -0,0 +1,116 @@ +import json +import os +import random +import scipy.io +import codecs +from collections import defaultdict + +class BasicDataProvider: + def __init__(self, dataset): + print 'Initializing data provider for dataset %s...' % (dataset, ) + + # !assumptions on folder structure + self.dataset_root = os.path.join('data', dataset) + self.image_root = os.path.join('data', dataset, 'imgs') + + # load the dataset into memory + dataset_path = os.path.join(self.dataset_root, 'dataset.json') + print 'BasicDataProvider: reading %s' % (dataset_path, ) + self.dataset = json.load(open(dataset_path, 'r')) + + # load the image features into memory + features_path = os.path.join(self.dataset_root, 'vgg_feats.mat') + print 'BasicDataProvider: reading %s' % (features_path, ) + features_struct = scipy.io.loadmat(features_path) + self.features = features_struct['feats'] + + # group images by their train/val/test split into a dictionary -> list structure + self.split = defaultdict(list) + for img in self.dataset['images']: + self.split[img['split']].append(img) + + # "PRIVATE" FUNCTIONS + # in future we may want to create copies here so that we don't touch the + # data provider class data, but for now lets do the simple thing and + # just return raw internal img sent structs. This also has the advantage + # that the driver could store various useful caching stuff in these structs + # and they will be returned in the future with the cache present + def _getImage(self, img): + """ create an image structure for the driver """ + + # lazily fill in some attributes + if not 'local_file_path' in img: img['local_file_path'] = os.path.join(self.image_root, img['filename']) + if not 'feat' in img: # also fill in the features + feature_index = img['imgid'] # NOTE: imgid is an integer, and it indexes into features + img['feat'] = self.features[:,feature_index] + return img + + def _getSentence(self, sent): + """ create a sentence structure for the driver """ + # NOOP for now + return sent + + # PUBLIC FUNCTIONS + + def getSplitSize(self, split, ofwhat = 'sentences'): + """ return size of a split, either number of sentences or number of images """ + if ofwhat == 'sentences': + return sum(len(img['sentences']) for img in self.split[split]) + else: # assume images + return len(self.split[split]) + + def sampleImageSentencePair(self, split = 'train'): + """ sample image sentence pair from a split """ + images = self.split[split] + + img = random.choice(images) + sent = random.choice(img['sentences']) + + out = {} + out['image'] = self._getImage(img) + out['sentence'] = self._getSentence(sent) + return out + + def iterImageSentencePair(self, split = 'train', max_images = -1): + for i,img in enumerate(self.split[split]): + if max_images >= 0 and i >= max_images: break + for sent in img['sentences']: + out = {} + out['image'] = self._getImage(img) + out['sentence'] = self._getSentence(sent) + yield out + + def iterImageSentencePairBatch(self, split = 'train', max_images = -1, max_batch_size = 100): + batch = [] + for i,img in enumerate(self.split[split]): + if max_images >= 0 and i >= max_images: break + for sent in img['sentences']: + out = {} + out['image'] = self._getImage(img) + out['sentence'] = self._getSentence(sent) + batch.append(out) + if len(batch) >= max_batch_size: + yield batch + batch = [] + if batch: + yield batch + + def iterSentences(self, split = 'train'): + for img in self.split[split]: + for sent in img['sentences']: + yield self._getSentence(sent) + + def iterImages(self, split = 'train', shuffle = False, max_images = -1): + imglist = self.split[split] + ix = range(len(imglist)) + if shuffle: + random.shuffle(ix) + if max_images > 0: + ix = ix[:min(len(ix),max_images)] # crop the list + for i in ix: + yield self._getImage(imglist[i]) + +def getDataProvider(dataset): + """ we could intercept a special dataset and return different data providers """ + assert dataset in ['flickr8k', 'flickr30k', 'coco'], 'dataset %s unknown' % (dataset, ) + return BasicDataProvider(dataset) diff --git a/imagernn/generic_batch_generator.py b/imagernn/generic_batch_generator.py new file mode 100644 index 0000000..58b9951 --- /dev/null +++ b/imagernn/generic_batch_generator.py @@ -0,0 +1,155 @@ +import numpy as np +import code +from imagernn.utils import merge_init_structs, initw, accumNpDicts +from imagernn.lstm_generator import LSTMGenerator +from imagernn.rnn_generator import RNNGenerator + +def decodeGenerator(generator): + if generator == 'lstm': + return LSTMGenerator + if generator == 'rnn': + return RNNGenerator + else: + raise Exception('generator %s is not yet supported' % (base_generator_str,)) + +class GenericBatchGenerator: + """ + Base batch generator class. + This class is aware of the fact that we are generating + sentences from images. + """ + + @staticmethod + def init(params, misc): + + # inputs + image_encoding_size = params.get('image_encoding_size', 128) + word_encoding_size = params.get('word_encoding_size', 128) + hidden_size = params.get('hidden_size', 128) + generator = params.get('generator', 'lstm') + vocabulary_size = len(misc['wordtoix']) + output_size = len(misc['ixtoword']) # these should match though + image_size = 4096 # size of CNN vectors hardcoded here + + assert image_encoding_size == word_encoding_size, 'for now these must match. later with other models these could be different.' + + # initialize the encoder models + model = {} + model['We'] = initw(image_size, image_encoding_size) # image encoder + model['be'] = np.zeros((1,image_encoding_size)) + model['Ws'] = initw(vocabulary_size, word_encoding_size) # word encoder + update = ['We', 'be', 'Ws'] + regularize = ['We', 'Ws'] + init_struct = { 'model' : model, 'update' : update, 'regularize' : regularize} + + # descend into the specific Generator and initialize it + Generator = decodeGenerator(generator) + generator_init_struct = Generator.init(word_encoding_size, hidden_size, output_size) + merge_init_structs(init_struct, generator_init_struct) + return init_struct + + @staticmethod + def forward(batch, model, params, misc, predict_mode = False): + """ iterates over items in the batch and calls generators on them """ + # we do the encoding here across all images/words in batch in single matrix + # multiplies to gain efficiency. The RNNs are then called individually + # in for loop on per-image-sentence pair and all they are concerned about is + # taking single matrix of vectors and doing the forward/backward pass without + # knowing anything about images, sentences or anything of that sort. + + # encode all images + # concatenate as rows. If N is number of image-sentence pairs, + # F will be N x image_size + F = np.row_stack(x['image']['feat'] for x in batch) + We = model['We'] + be = model['be'] + Xe = F.dot(We) + be # Xe becomes N x image_encoding_size + + # decode the generator we wish to use + generator_str = params.get('generator', 'lstm') + Generator = decodeGenerator(generator_str) + + # encode all words in all sentences (which exist in our vocab) + wordtoix = misc['wordtoix'] + Ws = model['Ws'] + gen_caches = [] + Ys = [] # outputs + for i,x in enumerate(batch): + # take all words in this sentence and pluck out their word vectors + # from Ws. Then arrange them in a single matrix Xs + # Note that we are setting the start token as first vector + # and then all the words afterwards. And start token is the first row of Ws + ix = [0] + [ wordtoix[w] for w in x['sentence']['tokens'] if w in wordtoix ] + Xs = np.row_stack( [Ws[j, :] for j in ix] ) + Xi = Xe[i,:] + + # forward prop through the RNN + kwparams = { 'drop_prob_encoder' : params.get('drop_prob_encoder',0.0), \ + 'drop_prob_decoder' : params.get('drop_prob_decoder',0.0), \ + 'tanhC_version' : params.get('tanhC_version', 0) } + gen_Y, gen_cache = Generator.forward(Xi, Xs, model, predict_mode = predict_mode, **kwparams) + gen_caches.append((ix, gen_cache)) + Ys.append(gen_Y) + + # back up information we need for efficient backprop + cache = {} + if not predict_mode: + # ok we need cache as well because we'll do backward pass + cache['gen_caches'] = gen_caches + cache['Xe'] = Xe + cache['Ws_shape'] = Ws.shape + cache['F'] = F + cache['generator_str'] = generator_str + + return Ys, cache + + @staticmethod + def backward(dY, cache): + Xe = cache['Xe'] + generator_str = cache['generator_str'] + dWs = np.zeros(cache['Ws_shape']) + gen_caches = cache['gen_caches'] + F = cache['F'] + dXe = np.zeros(Xe.shape) + + Generator = decodeGenerator(generator_str) + + # backprop each item in the batch + grads = {} + for i in xrange(len(gen_caches)): + ix, gen_cache = gen_caches[i] # unpack + local_grads = Generator.backward(dY[i], gen_cache) + dXs = local_grads['dXs'] # intercept the gradients wrt Xi and Xs + del local_grads['dXs'] + dXi = local_grads['dXi'] + del local_grads['dXi'] + accumNpDicts(grads, local_grads) # add up the gradients wrt model parameters + + # now backprop from dXs to the image vector and word vectors + dXe[i,:] += dXi # image vector + for n,j in enumerate(ix): # and now all the other words + dWs[j,:] += dXs[n,:] + + # finally backprop into the image encoder + dWe = F.transpose().dot(dXe) + dbe = np.sum(dXe, axis=0, keepdims = True) + + accumNpDicts(grads, { 'We':dWe, 'be':dbe, 'Ws':dWs }) + return grads + + @staticmethod + def predict(batch, model, **kwparams): + """ some code duplication here with forward pass, but I think we want the freedom in future """ + F = np.row_stack(x['image']['feat'] for x in batch) + We = model['We'] + be = model['be'] + Xe = F.dot(We) + be # Xe becomes N x image_encoding_size + generator_str = kwparams['generator'] + Generator = decodeGenerator(generator_str) + Ys = [] + for i,x in enumerate(batch): + gen_Y = Generator.predict(Xe[i, :], model, model['Ws'], **kwparams) + Ys.append(gen_Y) + return Ys + + diff --git a/imagernn/imagernn_utils.py b/imagernn/imagernn_utils.py new file mode 100644 index 0000000..7331a76 --- /dev/null +++ b/imagernn/imagernn_utils.py @@ -0,0 +1,40 @@ +from imagernn.generic_batch_generator import GenericBatchGenerator +import numpy as np + +def decodeGenerator(params): + """ + in the future we may want to have different classes + and options for them. For now there is this one generator + implemented and simply returned here. + """ + return GenericBatchGenerator + +def eval_split(split, dp, model, params, misc, **kwargs): + """ evaluate performance on a given split """ + # allow kwargs to override what is inside params + eval_batch_size = kwargs.get('eval_batch_size', params.get('eval_batch_size',100)) + eval_max_images = kwargs.get('eval_max_images', params.get('eval_max_images', -1)) + BatchGenerator = decodeGenerator(params) + wordtoix = misc['wordtoix'] + + print 'evaluating %s performance in batches of %d' % (split, eval_batch_size) + logppl = 0 + logppln = 0 + nsent = 0 + for batch in dp.iterImageSentencePairBatch(split = split, max_batch_size = eval_batch_size, max_images = eval_max_images): + Ys, gen_caches = BatchGenerator.forward(batch, model, params, misc, predict_mode = True) + + for i,pair in enumerate(batch): + gtix = [ wordtoix[w] for w in pair['sentence']['tokens'] if w in wordtoix ] + gtix.append(0) # we expect END token at the end + Y = Ys[i] + maxes = np.amax(Y, axis=1, keepdims=True) + e = np.exp(Y - maxes) # for numerical stability shift into good numerical range + P = e / np.sum(e, axis=1, keepdims=True) + logppl += - np.sum(np.log2(1e-20 + P[range(len(gtix)),gtix])) # also accumulate log2 perplexities + logppln += len(gtix) + nsent += 1 + + ppl2 = 2 ** (logppl / logppln) + print 'evaluated %d sentences and got perplexity = %f' % (nsent, ppl2) + return ppl2 # return the perplexity diff --git a/imagernn/lstm_generator.py b/imagernn/lstm_generator.py new file mode 100644 index 0000000..8ab7452 --- /dev/null +++ b/imagernn/lstm_generator.py @@ -0,0 +1,298 @@ +import numpy as np +import code + +from imagernn.utils import initw + +class LSTMGenerator: + """ + A multimodal long short-term memory (LSTM) generator + """ + + @staticmethod + def init(input_size, hidden_size, output_size): + + model = {} + # Recurrent weights: take x_t, h_{t-1}, and bias unit + # and produce the 3 gates and the input to cell signal + model['WLSTM'] = initw(input_size + hidden_size + 1, 4 * hidden_size) + # Decoder weights (e.g. mapping to vocabulary) + model['Wd'] = initw(hidden_size, output_size) # decoder + model['bd'] = np.zeros((1, output_size)) + + update = ['WLSTM', 'Wd', 'bd'] + regularize = ['WLSTM', 'Wd'] + return { 'model' : model, 'update' : update, 'regularize' : regularize } + + @staticmethod + def forward(Xi, Xs, model, **kwargs): + """ + Xi is 1-d array of size D (containing the image representation) + Xs is N x D (N time steps, rows are data containng word representations), and + it is assumed that the first row is already filled in as the start token. So a + sentence with 10 words will be of size 11xD in Xs. + """ + predict_mode = kwargs.get('predict_mode', False) + + # Google paper concatenates the image to the word vectors as the first word vector + X = np.row_stack([Xi, Xs]) + + # options + # use the version of LSTM with tanh? Otherwise dont use tanh (Google style) + # following http://arxiv.org/abs/1409.3215 + tanhC_version = kwargs.get('tanhC_version', 0) + drop_prob_encoder = kwargs.get('drop_prob_encoder', 0.0) + drop_prob_decoder = kwargs.get('drop_prob_decoder', 0.0) + + if drop_prob_encoder > 0: # if we want dropout on the encoder + # inverted version of dropout here. Suppose the drop_prob is 0.5, then during training + # we are going to drop half of the units. In this inverted version we also boost the activations + # of the remaining 50% by 2.0 (scale). The nice property of this is that during prediction time + # we don't have to do any scailing, since all 100% of units will be active, but at their base + # firing rate, giving 100% of the "energy". So the neurons later in the pipeline dont't change + # their expected firing rate magnitudes + if not predict_mode: # and we are in training mode + scale = 1.0 / (1.0 - drop_prob_encoder) + U = (np.random.rand(*(X.shape)) < (1 - drop_prob_encoder)) * scale # generate scaled mask + X *= U # drop! + + # follows http://arxiv.org/pdf/1409.2329.pdf + WLSTM = model['WLSTM'] + n = X.shape[0] + d = model['Wd'].shape[0] # size of hidden layer + Hin = np.zeros((n, WLSTM.shape[0])) # xt, ht-1, bias + Hout = np.zeros((n, d)) + IFOG = np.zeros((n, d * 4)) + IFOGf = np.zeros((n, d * 4)) # after nonlinearity + C = np.zeros((n, d)) + for t in xrange(n): + # set input + prev = np.zeros(d) if t == 0 else Hout[t-1] + Hin[t,0] = 1 + Hin[t,1:1+d] = X[t] + Hin[t,1+d:] = prev + + # compute all gate activations. dots: + IFOG[t] = Hin[t].dot(WLSTM) + + # non-linearities + IFOGf[t,:3*d] = 1.0/(1.0+np.exp(-IFOG[t,:3*d])) # sigmoids; these are the gates + IFOGf[t,3*d:] = np.tanh(IFOG[t, 3*d:]) # tanh + + # compute the cell activation + C[t] = IFOGf[t,:d] * IFOGf[t, 3*d:] + if t > 0: C[t] += IFOGf[t,d:2*d] * C[t-1] + if tanhC_version: + Hout[t] = IFOGf[t,2*d:3*d] * np.tanh(C[t]) + else: + Hout[t] = IFOGf[t,2*d:3*d] * C[t] + + if drop_prob_decoder > 0: # if we want dropout on the decoder + if not predict_mode: # and we are in training mode + scale2 = 1.0 / (1.0 - drop_prob_decoder) + U2 = (np.random.rand(*(Hout.shape)) < (1 - drop_prob_decoder)) * scale2 # generate scaled mask + Hout *= U2 # drop! + + # decoder at the end + Wd = model['Wd'] + bd = model['bd'] + # NOTE1: we are leaving out the first prediction, which was made for the image + # and is meaningless. + Y = Hout[1:, :].dot(Wd) + bd + + cache = {} + if not predict_mode: + # we can expect to do a backward pass + cache['WLSTM'] = WLSTM + cache['Hout'] = Hout + cache['Wd'] = Wd + cache['IFOGf'] = IFOGf + cache['IFOG'] = IFOG + cache['C'] = C + cache['X'] = X + cache['Hin'] = Hin + cache['tanhC_version'] = tanhC_version + cache['drop_prob_encoder'] = drop_prob_encoder + cache['drop_prob_decoder'] = drop_prob_decoder + if drop_prob_encoder > 0: cache['U'] = U # keep the dropout masks around for backprop + if drop_prob_decoder > 0: cache['U2'] = U2 + + return Y, cache + + @staticmethod + def backward(dY, cache): + + Wd = cache['Wd'] + Hout = cache['Hout'] + IFOG = cache['IFOG'] + IFOGf = cache['IFOGf'] + C = cache['C'] + Hin = cache['Hin'] + WLSTM = cache['WLSTM'] + X = cache['X'] + tanhC_version = cache['tanhC_version'] + drop_prob_encoder = cache['drop_prob_encoder'] + drop_prob_decoder = cache['drop_prob_decoder'] + n,d = Hout.shape + + # we have to add back a row of zeros, since in the forward pass + # this information was not used. See NOTE1 above. + dY = np.row_stack([np.zeros(dY.shape[1]), dY]) + + # backprop the decoder + dWd = Hout.transpose().dot(dY) + dbd = np.sum(dY, axis=0, keepdims = True) + dHout = dY.dot(Wd.transpose()) + + # backprop dropout, if it was applied + if drop_prob_decoder > 0: + dHout *= cache['U2'] + + # backprop the LSTM + dIFOG = np.zeros(IFOG.shape) + dIFOGf = np.zeros(IFOGf.shape) + dWLSTM = np.zeros(WLSTM.shape) + dHin = np.zeros(Hin.shape) + dC = np.zeros(C.shape) + dX = np.zeros(X.shape) + for t in reversed(xrange(n)): + + if tanhC_version: + tanhCt = np.tanh(C[t]) # recompute this here + dIFOGf[t,2*d:3*d] = tanhCt * dHout[t] + # backprop tanh non-linearity first then continue backprop + dC[t] += (1-tanhCt**2) * (IFOGf[t,2*d:3*d] * dHout[t]) + else: + dIFOGf[t,2*d:3*d] = C[t] * dHout[t] + dC[t] += IFOGf[t,2*d:3*d] * dHout[t] + + if t > 0: + dIFOGf[t,d:2*d] = C[t-1] * dC[t] + dC[t-1] += IFOGf[t,d:2*d] * dC[t] + dIFOGf[t,:d] = IFOGf[t, 3*d:] * dC[t] + dIFOGf[t, 3*d:] = IFOGf[t,:d] * dC[t] + + # backprop activation functions + dIFOG[t,3*d:] = (1 - IFOGf[t, 3*d:] ** 2) * dIFOGf[t,3*d:] + y = IFOGf[t,:3*d] + dIFOG[t,:3*d] = (y*(1.0-y)) * dIFOGf[t,:3*d] + + # backprop matrix multiply + dWLSTM += np.outer(Hin[t], dIFOG[t]) + dHin[t] = dIFOG[t].dot(WLSTM.transpose()) + + # backprop the identity transforms into Hin + dX[t] = dHin[t,1:1+d] + if t > 0: + dHout[t-1] += dHin[t,1+d:] + + if drop_prob_encoder > 0: # backprop encoder dropout + dX *= cache['U'] + + return { 'WLSTM': dWLSTM, 'Wd': dWd, 'bd': dbd, 'dXi': dX[0,:], 'dXs': dX[1:,:] } + + @staticmethod + def predict(Xi, model, Ws, **kwargs): + """ + Run in prediction mode with beam search. The input is the vector Xi, which + should be a 1-D array that contains the encoded image vector. We go from there. + Ws should be NxD array where N is size of vocabulary + 1. So there should be exactly + as many rows in Ws as there are outputs in the decoder Y. We are passing in Ws like + this because we may not want it to be exactly model['Ws']. For example it could be + fixed word vectors from somewhere else. + """ + tanhC_version = kwargs['tanhC_version'] + beam_size = kwargs.get('beam_size', 1) + + WLSTM = model['WLSTM'] + d = model['Wd'].shape[0] # size of hidden layer + Wd = model['Wd'] + bd = model['bd'] + + # lets define a helper function that does a single LSTM tick + def LSTMtick(x, h_prev, c_prev): + t = 0 + + # setup the input vector + Hin = np.zeros((1,WLSTM.shape[0])) # xt, ht-1, bias + Hin[t,0] = 1 + Hin[t,1:1+d] = x + Hin[t,1+d:] = h_prev + + # LSTM tick forward + IFOG = np.zeros((1, d * 4)) + IFOGf = np.zeros((1, d * 4)) + C = np.zeros((1, d)) + Hout = np.zeros((1, d)) + IFOG[t] = Hin[t].dot(WLSTM) + IFOGf[t,:3*d] = 1.0/(1.0+np.exp(-IFOG[t,:3*d])) + IFOGf[t,3*d:] = np.tanh(IFOG[t, 3*d:]) + C[t] = IFOGf[t,:d] * IFOGf[t, 3*d:] + IFOGf[t,d:2*d] * c_prev + if tanhC_version: + Hout[t] = IFOGf[t,2*d:3*d] * np.tanh(C[t]) + else: + Hout[t] = IFOGf[t,2*d:3*d] * C[t] + Y = Hout.dot(Wd) + bd + return (Y, Hout, C) # return output, new hidden, new cell + + # forward prop the image + (y0, h, c) = LSTMtick(Xi, np.zeros(d), np.zeros(d)) + + # perform BEAM search. NOTE: I am not very confident in this implementation since I don't have + # a lot of experience with these models. This implements my current understanding but I'm not + # sure how to handle beams that predict END tokens. TODO: research this more. + if beam_size > 1: + # log probability, indices of words predicted in this beam so far, and the hidden and cell states + beams = [(0.0, [], h, c)] + nsteps = 0 + while True: + beam_candidates = [] + for b in beams: + ixprev = b[1][-1] if b[1] else 0 # start off with the word where this beam left off + if ixprev == 0 and b[1]: + # this beam predicted end token. Keep in the candidates but don't expand it out any more + beam_candidates.append(b) + continue + (y1, h1, c1) = LSTMtick(Ws[ixprev], b[2], b[3]) + y1 = y1.ravel() # make into 1D vector + maxy1 = np.amax(y1) + e1 = np.exp(y1 - maxy1) # for numerical stability shift into good numerical range + p1 = e1 / np.sum(e1) + y1 = np.log(1e-20 + p1) # and back to log domain + top_indices = np.argsort(-y1) # we do -y because we want decreasing order + for i in xrange(beam_size): + wordix = top_indices[i] + beam_candidates.append((b[0] + y1[wordix], b[1] + [wordix], h1, c1)) + beam_candidates.sort(reverse = True) # decreasing order + beams = beam_candidates[:beam_size] # truncate to get new beams + nsteps += 1 + if nsteps >= 20: # bad things are probably happening, break out + break + # strip the intermediates + predictions = [(b[0], b[1]) for b in beams] + else: + # greedy inference. lets write it up independently, should be bit faster and simpler + ixprev = 0 + nsteps = 0 + predix = [] + predlogprob = 0.0 + while True: + (y1, h, c) = LSTMtick(Ws[ixprev], h, c) + ixprev, ixlogprob = ymax(y1) + predix.append(ixprev) + predlogprob += ixlogprob + nsteps += 1 + if ixprev == 0 or nsteps >= 20: + break + predictions = [(predlogprob, predix)] + + return predictions + +def ymax(y): + """ simple helper function here that takes unnormalized logprobs """ + y1 = y.ravel() # make sure 1d + maxy1 = np.amax(y1) + e1 = np.exp(y1 - maxy1) # for numerical stability shift into good numerical range + p1 = e1 / np.sum(e1) + y1 = np.log(1e-20 + p1) # guard against zero probabilities just in case + ix = np.argmax(y1) + return (ix, y1[ix]) diff --git a/imagernn/rnn_generator.py b/imagernn/rnn_generator.py new file mode 100644 index 0000000..4135d42 --- /dev/null +++ b/imagernn/rnn_generator.py @@ -0,0 +1,216 @@ +import numpy as np +import code + +from imagernn.utils import initw + +class RNNGenerator: + """ + An RNN generator. + This class is as stupid as possible. It gets some conditioning vector, + a sequence of input vectors, and produces a sequence of output vectors + """ + + @staticmethod + def init(input_size, hidden_size, output_size): + + model = {} + # connections to h_{t-1} + model['Whh'] = initw(hidden_size, hidden_size) + model['bhh'] = np.zeros((1, hidden_size)) + # Decoder weights (e.g. mapping to vocabulary) + model['Wd'] = initw(hidden_size, output_size) * 0.01 # decoder + model['bd'] = np.zeros((1, output_size)) + + update = ['Whh', 'bhh', 'Wd', 'bd'] + regularize = ['Whh', 'Wd'] + return { 'model' : model, 'update' : update, 'regularize' : regularize } + + @staticmethod + def forward(Xi, Xs, model, **kwargs): + """ + Xi is 1-d array of size D1 (containing the image representation) + Xs is N x D2 (N time steps, rows are data containng word representations), and + it is assumed that the first row is already filled in as the start token. So a + sentence with 10 words will be of size 11xD2 in Xs. + """ + predict_mode = kwargs.get('predict_mode', False) + + # options + drop_prob_encoder = kwargs.get('drop_prob_encoder', 0.0) + drop_prob_decoder = kwargs.get('drop_prob_decoder', 0.0) + + if drop_prob_encoder > 0: # if we want dropout on the encoder + # inverted version of dropout here. Suppose the drop_prob is 0.5, then during training + # we are going to drop half of the units. In this inverted version we also boost the activations + # of the remaining 50% by 2.0 (scale). The nice property of this is that during prediction time + # we don't have to do any scailing, since all 100% of units will be active, but at their base + # firing rate, giving 100% of the "energy". So the neurons later in the pipeline dont't change + # their expected firing rate magnitudes + if not predict_mode: # and we are in training mode + scale = 1.0 / (1.0 - drop_prob_encoder) + Us = (np.random.rand(*(Xs.shape)) < (1 - drop_prob_encoder)) * scale # generate scaled mask + Xs *= Us # drop! + Ui = (np.random.rand(*(Xi.shape)) < (1 - drop_prob_encoder)) * scale + Xi *= Ui # drop! + + # recurrence iteration for the Multimodal RNN similar to one described in Karpathy et al. + d = model['Wd'].shape[0] # size of hidden layer + n = Xs.shape[0] + H = np.zeros((n, d)) # hidden layer representation + Whh = model['Whh'] + bhh = model['bhh'] + for t in xrange(n): + + prev = np.zeros(d) if t == 0 else H[t-1] + ht = Xi + Xs[t] + prev.dot(Whh) + bhh + H[t] = np.maximum(ht, 0) # ReLU nonlinearity + + if drop_prob_decoder > 0: # if we want dropout on the decoder + if not predict_mode: # and we are in training mode + scale2 = 1.0 / (1.0 - drop_prob_decoder) + U2 = (np.random.rand(*(H.shape)) < (1 - drop_prob_decoder)) * scale2 # generate scaled mask + H *= U2 # drop! + + # decoder at the end + Wd = model['Wd'] + bd = model['bd'] + Y = H.dot(Wd) + bd + + cache = {} + if not predict_mode: + # we can expect to do a backward pass + cache['Whh'] = Whh + cache['H'] = H + cache['Wd'] = Wd + cache['Xs'] = Xs + cache['drop_prob_encoder'] = drop_prob_encoder + cache['drop_prob_decoder'] = drop_prob_decoder + if drop_prob_encoder > 0: + cache['Us'] = Us # keep the dropout masks around for backprop + cache['Ui'] = Ui + if drop_prob_decoder > 0: cache['U2'] = U2 + + return Y, cache + + @staticmethod + def backward(dY, cache): + + Wd = cache['Wd'] + H = cache['H'] + Xs = cache['Xs'] + Whh = cache['Whh'] + drop_prob_encoder = cache['drop_prob_encoder'] + drop_prob_decoder = cache['drop_prob_decoder'] + n,d = H.shape + + # backprop the decoder + dWd = H.transpose().dot(dY) + dbd = np.sum(dY, axis=0, keepdims = True) + dH = dY.dot(Wd.transpose()) + + # backprop dropout, if it was applied + if drop_prob_decoder > 0: + dH *= cache['U2'] + + # backprop the recurrent connections + dXs = np.zeros(Xs.shape) + dXi = np.zeros(d) + dWhh = np.zeros(Whh.shape) + dbhh = np.zeros((1,d)) + for t in reversed(xrange(n)): + dht = (H[t] > 0) * dH[t] # backprop ReLU + dXi += dht # backprop to Xi + dXs[t] += dht # backprop to word encodings + dbhh[0] += dht # backprop to bias + + if t > 0: + dH[t-1] += dht.dot(Whh.transpose()) + dWhh += np.outer(H[t-1], dht) + + if drop_prob_encoder > 0: # backprop encoder dropout + dXi *= cache['Ui'] + dXs *= cache['Us'] + + return { 'Whh': dWhh, 'bhh': dbhh, 'Wd': dWd, 'bd': dbd, 'dXs' : dXs, 'dXi': dXi } + + @staticmethod + def predict(Xi, model, Ws, **kwargs): + + beam_size = kwargs.get('beam_size', 1) + + d = model['Wd'].shape[0] # size of hidden layer + Whh = model['Whh'] + bhh = model['bhh'] + Wd = model['Wd'] + bd = model['bd'] + + if beam_size > 1: + # perform beam search + # NOTE: code duplication here with lstm_generator + # ideally the beam search would be abstracted away nicely and would take + # a TICK function or something, but for now lets save time & copy code around. Sorry ;\ + beams = [(0.0, [], np.zeros(d))] + nsteps = 0 + while True: + beam_candidates = [] + for b in beams: + ixprev = b[1][-1] if b[1] else 0 + if ixprev == 0 and b[1]: + # this beam predicted end token. Keep in the candidates but don't expand it out any more + beam_candidates.append(b) + continue + # tick the RNN for this beam + h1 = np.maximum(Xi + Ws[ixprev] + b[2].dot(Whh) + bhh, 0) + y1 = h1.dot(Wd) + bd + + # compute new candidates that expand out form this beam + y1 = y1.ravel() # make into 1D vector + maxy1 = np.amax(y1) + e1 = np.exp(y1 - maxy1) # for numerical stability shift into good numerical range + p1 = e1 / np.sum(e1) + y1 = np.log(1e-20 + p1) # and back to log domain + top_indices = np.argsort(-y1) # we do -y because we want decreasing order + for i in xrange(beam_size): + wordix = top_indices[i] + beam_candidates.append((b[0] + y1[wordix], b[1] + [wordix], h1)) + + beam_candidates.sort(reverse = True) # decreasing order + beams = beam_candidates[:beam_size] # truncate to get new beams + nsteps += 1 + if nsteps >= 20: # bad things are probably happening, break out + break + # strip the intermediates + predictions = [(b[0], b[1]) for b in beams] + + else: + ixprev = 0 # start out on start token + nsteps = 0 + predix = [] + predlogprob = 0.0 + hprev = np.zeros((1, d)) # hidden layer representation + xsprev = Ws[0] # start token + while True: + ht = np.maximum(Xi + Ws[ixprev] + hprev.dot(Whh) + bhh, 0) + Y = ht.dot(Wd) + bd + hprev = ht + + ixprev, ixlogprob = ymax(Y) + predix.append(ixprev) + predlogprob += ixlogprob + + nsteps += 1 + if ixprev == 0 or nsteps >= 20: + break + predictions = [(predlogprob, predix)] + return predictions + + +def ymax(y): + """ simple helper function here that takes unnormalized logprobs """ + y1 = y.ravel() # make sure 1d + maxy1 = np.amax(y1) + e1 = np.exp(y1 - maxy1) # for numerical stability shift into good numerical range + p1 = e1 / np.sum(e1) + y1 = np.log(1e-20 + p1) # guard against zero probabilities just in case + ix = np.argmax(y1) + return (ix, y1[ix]) diff --git a/imagernn/solver.py b/imagernn/solver.py new file mode 100644 index 0000000..eb4a425 --- /dev/null +++ b/imagernn/solver.py @@ -0,0 +1,150 @@ +import time +import numpy as np +from imagernn.utils import randi + +class Solver: + """ + solver worries about: + - different optimization methods, updates, weight decays + - it can also perform gradient check + """ + def __init__(self): + self.step_cache_ = {} # might need this + self.step_cache_2 = {} # might need this + + def step(self, batch, model, cost_function, **kwargs): + """ + perform a single batch update. Takes as input: + - batch of data (X) + - model (W) + - cost function which takes batch, model + """ + + learning_rate = kwargs.get('learning_rate', 0.0) + update = kwargs.get('update', model.keys()) + grad_clip = kwargs.get('grad_clip', -1) + solver = kwargs.get('solver', 'vanilla') + momentum = kwargs.get('momentum', 0) + smooth_eps = kwargs.get('smooth_eps', 1e-8) + decay_rate = kwargs.get('decay_rate', 0.999) + + if not (solver == 'vanilla' and momentum == 0): + # lazily make sure we initialize step cache if needed + for u in update: + if not u in self.step_cache_: + self.step_cache_[u] = np.zeros(model[u].shape) + if solver == 'adadelta': + self.step_cache2_[u] = np.zeros(model[u].shape) # adadelta needs one more cache + + # compute cost and gradient + cg = cost_function(batch, model) + cost = cg['cost'] + grads = cg['grad'] + stats = cg['stats'] + + # clip gradients if needed, simplest possible version + # todo later: maybe implement the gradient direction conserving version + if grad_clip > 0: + for p in update: + if p in grads: + grads[p] = np.minimum(grads[p], grad_clip) + grads[p] = np.maximum(grads[p], -grad_clip) + + # perform parameter update + for p in update: + if p in grads: + + if solver == 'vanilla': # vanilla sgd, optional with momentum + if momentum > 0: + dx = momentum * self.step_cache_[p] - learning_rate * grads[p] + self.step_cache_[p] = dx + else: + dx = - learning_rate * grads[p] + + elif solver == 'rmsprop': + self.step_cache_[p] = self.step_cache_[p] * decay_rate + (1.0 - decay_rate) * grads[p] ** 2 + dx = -(learning_rate * grads[p]) / np.sqrt(self.step_cache_[p] + smooth_eps) + + elif solver == 'adagrad': + self.step_cache_[p] += grads[p] ** 2 + dx = -(learning_rate * grads[p]) / np.sqrt(self.step_cache_[p] + smooth_eps) + + elif solver == 'adadelta': + self.step_cache_[p] = self.step_cache_[p] * decay_rate + (1.0 - decay_rate) * grads[p] ** 2 + dx = - np.sqrt( (self.step_cache2_[p] + smooth_eps) / (self.step_cache_[p] + smooth_eps) ) * grads[p] + self.step_cache2_[p] = self.step_cache2_[p] * decay_rate + (1.0 - decay_rate) * (dx ** 2) + + else: + raise Exception("solver %s not supported" % (solver, )) + + # perform the parameter update + model[p] += dx + + # create output dict and return + out = {} + out['cost'] = cost + out['stats'] = stats + return out + + def gradCheck(self, batch, model, cost_function, **kwargs): + """ + perform gradient check. + since gradcheck can be tricky (especially with relus involved) + this function prints to console for visual inspection + """ + + num_checks = kwargs.get('num_checks', 10) + delta = kwargs.get('delta', 1e-5) + rel_error_thr_warning = kwargs.get('rel_error_thr_warning', 1e-2) + rel_error_thr_error = kwargs.get('rel_error_thr_error', 1) + + cg = cost_function(batch, model) + + print 'running gradient check...' + for p in model.keys(): + print 'checking gradient on parameter %s of shape %s...' % (p, `model[p].shape`) + mat = model[p] + + s0 = cg['grad'][p].shape + s1 = mat.shape + assert s0 == s1, 'Error dims dont match: %s and %s.' % (`s0`, `s1`) + + for i in xrange(num_checks): + ri = randi(mat.size) + + # evluate cost at [x + delta] and [x - delta] + old_val = mat.flat[ri] + mat.flat[ri] = old_val + delta + cg0 = cost_function(batch, model) + mat.flat[ri] = old_val - delta + cg1 = cost_function(batch, model) + mat.flat[ri] = old_val # reset old value for this parameter + + # fetch both numerical and analytic gradient + grad_analytic = cg['grad'][p].flat[ri] + grad_numerical = (cg0['cost']['total_cost'] - cg1['cost']['total_cost']) / ( 2 * delta ) + + # compare them + if grad_numerical == 0 and grad_analytic == 0: + rel_error = 0 # both are zero, OK. + status = 'OK' + elif abs(grad_numerical) < 1e-7 and abs(grad_analytic) < 1e-7: + rel_error = 0 # not enough precision to check this + status = 'VAL SMALL WARNING' + else: + rel_error = abs(grad_analytic - grad_numerical) / abs(grad_numerical + grad_analytic) + status = 'OK' + if rel_error > rel_error_thr_warning: status = 'WARNING' + if rel_error > rel_error_thr_error: status = '!!!!! NOTOK' + + # print stats + print '%s checking param %s index %8d (val = %+8f), analytic = %+8f, numerical = %+8f, relative error = %+8f' \ + % (status, p, ri, old_val, grad_analytic, grad_numerical, rel_error) + + + + + + + + diff --git a/imagernn/utils.py b/imagernn/utils.py new file mode 100644 index 0000000..70f3cb8 --- /dev/null +++ b/imagernn/utils.py @@ -0,0 +1,26 @@ +from random import uniform +import numpy as np + +def randi(N): + """ get random integer in range [0, N) """ + return int(uniform(0, N)) + +def merge_init_structs(s0, s1): + """ merge struct s1 into s0 """ + for k in s1['model']: + assert (not k in s0['model']), 'Error: looks like parameter %s is trying to be initialized twice!' % (k, ) + s0['model'][k] = s1['model'][k] # copy over the pointer + s0['update'].extend(s1['update']) + s0['regularize'].extend(s1['regularize']) + +def initw(n,d): # initialize matrix of this size + magic_number = 0.1 + return (np.random.rand(n,d) * 2 - 1) * magic_number # U[-0.1, 0.1] + +def accumNpDicts(d0, d1): + """ forall k in d0, d0 += d1 . d's are dictionaries of key -> numpy array """ + for k in d1: + if k in d0: + d0[k] += d1[k] + else: + d0[k] = d1[k] \ No newline at end of file diff --git a/matlab_features_reference/README.md b/matlab_features_reference/README.md new file mode 100644 index 0000000..1c5acda --- /dev/null +++ b/matlab_features_reference/README.md @@ -0,0 +1,5 @@ +This directory contains the files I used to extract CNN features for my images. This is not meant to be a ready-to-use utility. I just wanted to provide this code in hope that it can be useful to someone as starter code to work on top of. + +- This code uses [Caffe](http://caffe.berkeleyvision.org/) and their Matlab wrapper. +- I use VGG Net which can be found in the [Model Zoo ](https://github.com/BVLC/caffe/wiki/Model-Zoo) under the title *Models used by the VGG team in ILSVRC-2014*. I use the [16-layer version](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-md). +- Note that I provide my _features deploy network def as well, which is exactly what you see on that page but I chopped off the softmax to get the 4096-D codes below. diff --git a/matlab_features_reference/deploy_features.prototxt b/matlab_features_reference/deploy_features.prototxt new file mode 100644 index 0000000..ef137e2 --- /dev/null +++ b/matlab_features_reference/deploy_features.prototxt @@ -0,0 +1,321 @@ +name: "VGG_ILSVRC_16_layers" +input: "data" +input_dim: 10 +input_dim: 3 +input_dim: 224 +input_dim: 224 +layers { + bottom: "data" + top: "conv1_1" + name: "conv1_1" + type: CONVOLUTION + convolution_param { + num_output: 64 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv1_1" + top: "conv1_1" + name: "relu1_1" + type: RELU +} +layers { + bottom: "conv1_1" + top: "conv1_2" + name: "conv1_2" + type: CONVOLUTION + convolution_param { + num_output: 64 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv1_2" + top: "conv1_2" + name: "relu1_2" + type: RELU +} +layers { + bottom: "conv1_2" + top: "pool1" + name: "pool1" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 2 + stride: 2 + } +} +layers { + bottom: "pool1" + top: "conv2_1" + name: "conv2_1" + type: CONVOLUTION + convolution_param { + num_output: 128 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv2_1" + top: "conv2_1" + name: "relu2_1" + type: RELU +} +layers { + bottom: "conv2_1" + top: "conv2_2" + name: "conv2_2" + type: CONVOLUTION + convolution_param { + num_output: 128 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv2_2" + top: "conv2_2" + name: "relu2_2" + type: RELU +} +layers { + bottom: "conv2_2" + top: "pool2" + name: "pool2" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 2 + stride: 2 + } +} +layers { + bottom: "pool2" + top: "conv3_1" + name: "conv3_1" + type: CONVOLUTION + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv3_1" + top: "conv3_1" + name: "relu3_1" + type: RELU +} +layers { + bottom: "conv3_1" + top: "conv3_2" + name: "conv3_2" + type: CONVOLUTION + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv3_2" + top: "conv3_2" + name: "relu3_2" + type: RELU +} +layers { + bottom: "conv3_2" + top: "conv3_3" + name: "conv3_3" + type: CONVOLUTION + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv3_3" + top: "conv3_3" + name: "relu3_3" + type: RELU +} +layers { + bottom: "conv3_3" + top: "pool3" + name: "pool3" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 2 + stride: 2 + } +} +layers { + bottom: "pool3" + top: "conv4_1" + name: "conv4_1" + type: CONVOLUTION + convolution_param { + num_output: 512 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv4_1" + top: "conv4_1" + name: "relu4_1" + type: RELU +} +layers { + bottom: "conv4_1" + top: "conv4_2" + name: "conv4_2" + type: CONVOLUTION + convolution_param { + num_output: 512 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv4_2" + top: "conv4_2" + name: "relu4_2" + type: RELU +} +layers { + bottom: "conv4_2" + top: "conv4_3" + name: "conv4_3" + type: CONVOLUTION + convolution_param { + num_output: 512 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv4_3" + top: "conv4_3" + name: "relu4_3" + type: RELU +} +layers { + bottom: "conv4_3" + top: "pool4" + name: "pool4" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 2 + stride: 2 + } +} +layers { + bottom: "pool4" + top: "conv5_1" + name: "conv5_1" + type: CONVOLUTION + convolution_param { + num_output: 512 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv5_1" + top: "conv5_1" + name: "relu5_1" + type: RELU +} +layers { + bottom: "conv5_1" + top: "conv5_2" + name: "conv5_2" + type: CONVOLUTION + convolution_param { + num_output: 512 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv5_2" + top: "conv5_2" + name: "relu5_2" + type: RELU +} +layers { + bottom: "conv5_2" + top: "conv5_3" + name: "conv5_3" + type: CONVOLUTION + convolution_param { + num_output: 512 + pad: 1 + kernel_size: 3 + } +} +layers { + bottom: "conv5_3" + top: "conv5_3" + name: "relu5_3" + type: RELU +} +layers { + bottom: "conv5_3" + top: "pool5" + name: "pool5" + type: POOLING + pooling_param { + pool: MAX + kernel_size: 2 + stride: 2 + } +} +layers { + bottom: "pool5" + top: "fc6" + name: "fc6" + type: INNER_PRODUCT + inner_product_param { + num_output: 4096 + } +} +layers { + bottom: "fc6" + top: "fc6" + name: "relu6" + type: RELU +} +layers { + bottom: "fc6" + top: "fc6" + name: "drop6" + type: DROPOUT + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + bottom: "fc6" + top: "fc7" + name: "fc7" + type: INNER_PRODUCT + inner_product_param { + num_output: 4096 + } +} +layers { + bottom: "fc7" + top: "fc7" + name: "relu7" + type: RELU +} diff --git a/matlab_features_reference/extract_features.m b/matlab_features_reference/extract_features.m new file mode 100644 index 0000000..7acb05b --- /dev/null +++ b/matlab_features_reference/extract_features.m @@ -0,0 +1,47 @@ +%% vgg / caffe spec + +use_gpu = 1; +caffe('set_device', 1); +model_def_file = '/home/karpathy/caffe2/models/vgg_ilsvrc_16/deploy_features.prototxt'; +model_file = '/home/karpathy/caffe2/models/vgg_ilsvrc_16/VGG_ILSVRC_16_layers.caffemodel'; +batch_size = 10; + +matcaffe_init(use_gpu, model_def_file, model_file); + +%% input files spec + +root_path = '/data2/karpathy/flickr30k/'; +fs = textread([root_path 'all_imgs.txt'], '%s'); +N = length(fs); + +%% + +% iterate over the iamges in batches +feats = zeros(4096, N, 'single'); +for b=1:batch_size:N + + % enter images, and dont go out of bounds + Is = {}; + for i = b:min(N,b+batch_size-1) + I = imread([root_path fs{i}]); + if ndims(I) == 2 + I = cat(3, I, I, I); % handle grayscale edge case. Annoying! + end + Is{end+1} = I; + end + input_data = prepare_images_batch(Is); + + tic; + scores = caffe('forward', {input_data}); + scores = squeeze(scores{1}); + tt = toc; + + nb = length(Is); + feats(:, b:b+nb-1) = scores(:,1:nb); + fprintf('%d/%d = %.2f%% done in %.2fs\n', b, N, 100*(b-1)/N, tt); +end + +%% write to file + +save([root_path 'vgg_feats_hdf5.mat'], 'feats', '-v7.3'); +save([root_path 'vgg_feats.mat'], 'feats'); diff --git a/matlab_features_reference/prepare_images_batch.m b/matlab_features_reference/prepare_images_batch.m new file mode 100644 index 0000000..feb187b --- /dev/null +++ b/matlab_features_reference/prepare_images_batch.m @@ -0,0 +1,28 @@ +function images = prepare_images_batch(im_cell) +% takes a cell array of images in raw uint8 form + +mean_pix = [103.939, 116.779, 123.68]; % VGG mean assumed +batch_size = 10; + +CROPPED_DIM = 224; + +% insert images +images = zeros(CROPPED_DIM, CROPPED_DIM, 3, batch_size, 'single'); +N = length(im_cell); +for i=1:N + + % resize to fixed input + im = single(im_cell{i}); + im = imresize(im, [CROPPED_DIM CROPPED_DIM]); + + % RGB -> BGR + im = im(:, :, [3 2 1]); + + % flip xy + images(:, :, :, i) = permute(im, [2, 1, 3]); +end + +% mean BGR pixel subtraction +for c = 1:3 + images(:, :, c, :) = images(:, :, c, :) - mean_pix(c); +end diff --git a/monitorcv.html b/monitorcv.html new file mode 100644 index 0000000..31851ae --- /dev/null +++ b/monitorcv.html @@ -0,0 +1,165 @@ + + + + + + + + + + + + + + + +
+ + +
+
+
+ + + \ No newline at end of file diff --git a/status/Readme.md b/status/Readme.md new file mode 100644 index 0000000..ec61335 --- /dev/null +++ b/status/Readme.md @@ -0,0 +1 @@ +This contains status JSON files of running optimizations diff --git a/vis_resources/d3.min.js b/vis_resources/d3.min.js new file mode 100644 index 0000000..125058e --- /dev/null +++ b/vis_resources/d3.min.js @@ -0,0 +1,5 @@ +!function(){function n(n,t){return t>n?-1:n>t?1:n>=t?0:0/0}function t(n){return null!=n&&!isNaN(n)}function e(n){return{left:function(t,e,r,u){for(arguments.length<3&&(r=0),arguments.length<4&&(u=t.length);u>r;){var i=r+u>>>1;n(t[i],e)<0?r=i+1:u=i}return r},right:function(t,e,r,u){for(arguments.length<3&&(r=0),arguments.length<4&&(u=t.length);u>r;){var i=r+u>>>1;n(t[i],e)>0?u=i:r=i+1}return r}}}function r(n){return n.length}function u(n){for(var t=1;n*t%1;)t*=10;return t}function i(n,t){try{for(var e in t)Object.defineProperty(n.prototype,e,{value:t[e],enumerable:!1})}catch(r){n.prototype=t}}function o(){}function a(n){return ha+n in this}function c(n){return n=ha+n,n in this&&delete this[n]}function s(){var n=[];return this.forEach(function(t){n.push(t)}),n}function l(){var n=0;for(var t in this)t.charCodeAt(0)===ga&&++n;return n}function f(){for(var n in this)if(n.charCodeAt(0)===ga)return!1;return!0}function h(){}function g(n,t,e){return function(){var r=e.apply(t,arguments);return r===t?n:r}}function p(n,t){if(t in n)return t;t=t.charAt(0).toUpperCase()+t.substring(1);for(var e=0,r=pa.length;r>e;++e){var u=pa[e]+t;if(u in n)return u}}function v(){}function d(){}function m(n){function t(){for(var t,r=e,u=-1,i=r.length;++ue;e++)for(var u,i=n[e],o=0,a=i.length;a>o;o++)(u=i[o])&&t(u,o,e);return n}function U(n){return da(n,wa),n}function j(n){var t,e;return function(r,u,i){var o,a=n[i].update,c=a.length;for(i!=e&&(e=i,t=0),u>=t&&(t=u+1);!(o=a[t])&&++t0&&(n=n.substring(0,a));var s=ka.get(n);return s&&(n=s,c=I),a?t?u:r:t?v:i}function O(n,t){return function(e){var r=Go.event;Go.event=e,t[0]=this.__data__;try{n.apply(this,t)}finally{Go.event=r}}}function I(n,t){var e=O(n,t);return function(n){var t=this,r=n.relatedTarget;r&&(r===t||8&r.compareDocumentPosition(t))||e.call(t,n)}}function Y(){var n=".dragsuppress-"+ ++Aa,t="click"+n,e=Go.select(ea).on("touchmove"+n,y).on("dragstart"+n,y).on("selectstart"+n,y);if(Ea){var r=ta.style,u=r[Ea];r[Ea]="none"}return function(i){function o(){e.on(t,null)}e.on(n,null),Ea&&(r[Ea]=u),i&&(e.on(t,function(){y(),o()},!0),setTimeout(o,0))}}function Z(n,t){t.changedTouches&&(t=t.changedTouches[0]);var e=n.ownerSVGElement||n;if(e.createSVGPoint){var r=e.createSVGPoint();return r.x=t.clientX,r.y=t.clientY,r=r.matrixTransform(n.getScreenCTM().inverse()),[r.x,r.y]}var u=n.getBoundingClientRect();return[t.clientX-u.left-n.clientLeft,t.clientY-u.top-n.clientTop]}function V(){return Go.event.changedTouches[0].identifier}function $(){return Go.event.target}function X(){return ea}function B(n){return n>0?1:0>n?-1:0}function J(n,t,e){return(t[0]-n[0])*(e[1]-n[1])-(t[1]-n[1])*(e[0]-n[0])}function W(n){return n>1?0:-1>n?Ca:Math.acos(n)}function G(n){return n>1?La:-1>n?-La:Math.asin(n)}function K(n){return((n=Math.exp(n))-1/n)/2}function Q(n){return((n=Math.exp(n))+1/n)/2}function nt(n){return((n=Math.exp(2*n))-1)/(n+1)}function tt(n){return(n=Math.sin(n/2))*n}function et(){}function rt(n,t,e){return new ut(n,t,e)}function ut(n,t,e){this.h=n,this.s=t,this.l=e}function it(n,t,e){function r(n){return n>360?n-=360:0>n&&(n+=360),60>n?i+(o-i)*n/60:180>n?o:240>n?i+(o-i)*(240-n)/60:i}function u(n){return Math.round(255*r(n))}var i,o;return n=isNaN(n)?0:(n%=360)<0?n+360:n,t=isNaN(t)?0:0>t?0:t>1?1:t,e=0>e?0:e>1?1:e,o=.5>=e?e*(1+t):e+t-e*t,i=2*e-o,yt(u(n+120),u(n),u(n-120))}function ot(n,t,e){return new at(n,t,e)}function at(n,t,e){this.h=n,this.c=t,this.l=e}function ct(n,t,e){return isNaN(n)&&(n=0),isNaN(t)&&(t=0),st(e,Math.cos(n*=za)*t,Math.sin(n)*t)}function st(n,t,e){return new lt(n,t,e)}function lt(n,t,e){this.l=n,this.a=t,this.b=e}function ft(n,t,e){var r=(n+16)/116,u=r+t/500,i=r-e/200;return u=gt(u)*Za,r=gt(r)*Va,i=gt(i)*$a,yt(vt(3.2404542*u-1.5371385*r-.4985314*i),vt(-.969266*u+1.8760108*r+.041556*i),vt(.0556434*u-.2040259*r+1.0572252*i))}function ht(n,t,e){return n>0?ot(Math.atan2(e,t)*Ra,Math.sqrt(t*t+e*e),n):ot(0/0,0/0,n)}function gt(n){return n>.206893034?n*n*n:(n-4/29)/7.787037}function pt(n){return n>.008856?Math.pow(n,1/3):7.787037*n+4/29}function vt(n){return Math.round(255*(.00304>=n?12.92*n:1.055*Math.pow(n,1/2.4)-.055))}function dt(n){return yt(n>>16,255&n>>8,255&n)}function mt(n){return dt(n)+""}function yt(n,t,e){return new xt(n,t,e)}function xt(n,t,e){this.r=n,this.g=t,this.b=e}function Mt(n){return 16>n?"0"+Math.max(0,n).toString(16):Math.min(255,n).toString(16)}function _t(n,t,e){var r,u,i,o=0,a=0,c=0;if(r=/([a-z]+)\((.*)\)/i.exec(n))switch(u=r[2].split(","),r[1]){case"hsl":return e(parseFloat(u[0]),parseFloat(u[1])/100,parseFloat(u[2])/100);case"rgb":return t(kt(u[0]),kt(u[1]),kt(u[2]))}return(i=Ja.get(n))?t(i.r,i.g,i.b):(null==n||"#"!==n.charAt(0)||isNaN(i=parseInt(n.substring(1),16))||(4===n.length?(o=(3840&i)>>4,o=o>>4|o,a=240&i,a=a>>4|a,c=15&i,c=c<<4|c):7===n.length&&(o=(16711680&i)>>16,a=(65280&i)>>8,c=255&i)),t(o,a,c))}function bt(n,t,e){var r,u,i=Math.min(n/=255,t/=255,e/=255),o=Math.max(n,t,e),a=o-i,c=(o+i)/2;return a?(u=.5>c?a/(o+i):a/(2-o-i),r=n==o?(t-e)/a+(e>t?6:0):t==o?(e-n)/a+2:(n-t)/a+4,r*=60):(r=0/0,u=c>0&&1>c?0:r),rt(r,u,c)}function wt(n,t,e){n=St(n),t=St(t),e=St(e);var r=pt((.4124564*n+.3575761*t+.1804375*e)/Za),u=pt((.2126729*n+.7151522*t+.072175*e)/Va),i=pt((.0193339*n+.119192*t+.9503041*e)/$a);return st(116*u-16,500*(r-u),200*(u-i))}function St(n){return(n/=255)<=.04045?n/12.92:Math.pow((n+.055)/1.055,2.4)}function kt(n){var t=parseFloat(n);return"%"===n.charAt(n.length-1)?Math.round(2.55*t):t}function Et(n){return"function"==typeof n?n:function(){return n}}function At(n){return n}function Ct(n){return function(t,e,r){return 2===arguments.length&&"function"==typeof e&&(r=e,e=null),Nt(t,e,n,r)}}function Nt(n,t,e,r){function u(){var n,t=c.status;if(!t&&c.responseText||t>=200&&300>t||304===t){try{n=e.call(i,c)}catch(r){return o.error.call(i,r),void 0}o.load.call(i,n)}else o.error.call(i,c)}var i={},o=Go.dispatch("beforesend","progress","load","error"),a={},c=new XMLHttpRequest,s=null;return!ea.XDomainRequest||"withCredentials"in c||!/^(http(s)?:)?\/\//.test(n)||(c=new XDomainRequest),"onload"in c?c.onload=c.onerror=u:c.onreadystatechange=function(){c.readyState>3&&u()},c.onprogress=function(n){var t=Go.event;Go.event=n;try{o.progress.call(i,c)}finally{Go.event=t}},i.header=function(n,t){return n=(n+"").toLowerCase(),arguments.length<2?a[n]:(null==t?delete a[n]:a[n]=t+"",i)},i.mimeType=function(n){return arguments.length?(t=null==n?null:n+"",i):t},i.responseType=function(n){return arguments.length?(s=n,i):s},i.response=function(n){return e=n,i},["get","post"].forEach(function(n){i[n]=function(){return i.send.apply(i,[n].concat(Qo(arguments)))}}),i.send=function(e,r,u){if(2===arguments.length&&"function"==typeof r&&(u=r,r=null),c.open(e,n,!0),null==t||"accept"in a||(a.accept=t+",*/*"),c.setRequestHeader)for(var l in a)c.setRequestHeader(l,a[l]);return null!=t&&c.overrideMimeType&&c.overrideMimeType(t),null!=s&&(c.responseType=s),null!=u&&i.on("error",u).on("load",function(n){u(null,n)}),o.beforesend.call(i,c),c.send(null==r?null:r),i},i.abort=function(){return c.abort(),i},Go.rebind(i,o,"on"),null==r?i:i.get(Lt(r))}function Lt(n){return 1===n.length?function(t,e){n(null==t?e:null)}:n}function Tt(){var n=qt(),t=zt()-n;t>24?(isFinite(t)&&(clearTimeout(Qa),Qa=setTimeout(Tt,t)),Ka=0):(Ka=1,tc(Tt))}function qt(){var n=Date.now();for(nc=Wa;nc;)n>=nc.t&&(nc.f=nc.c(n-nc.t)),nc=nc.n;return n}function zt(){for(var n,t=Wa,e=1/0;t;)t.f?t=n?n.n=t.n:Wa=t.n:(t.t8?function(n){return n/e}:function(n){return n*e},symbol:n}}function Pt(n){var t=n.decimal,e=n.thousands,r=n.grouping,u=n.currency,i=r?function(n){for(var t=n.length,u=[],i=0,o=r[0];t>0&&o>0;)u.push(n.substring(t-=o,t+o)),o=r[i=(i+1)%r.length];return u.reverse().join(e)}:At;return function(n){var e=rc.exec(n),r=e[1]||" ",o=e[2]||">",a=e[3]||"",c=e[4]||"",s=e[5],l=+e[6],f=e[7],h=e[8],g=e[9],p=1,v="",d="",m=!1;switch(h&&(h=+h.substring(1)),(s||"0"===r&&"="===o)&&(s=r="0",o="=",f&&(l-=Math.floor((l-1)/4))),g){case"n":f=!0,g="g";break;case"%":p=100,d="%",g="f";break;case"p":p=100,d="%",g="r";break;case"b":case"o":case"x":case"X":"#"===c&&(v="0"+g.toLowerCase());case"c":case"d":m=!0,h=0;break;case"s":p=-1,g="r"}"$"===c&&(v=u[0],d=u[1]),"r"!=g||h||(g="g"),null!=h&&("g"==g?h=Math.max(1,Math.min(21,h)):("e"==g||"f"==g)&&(h=Math.max(0,Math.min(20,h)))),g=uc.get(g)||Ut;var y=s&&f;return function(n){var e=d;if(m&&n%1)return"";var u=0>n||0===n&&0>1/n?(n=-n,"-"):a;if(0>p){var c=Go.formatPrefix(n,h);n=c.scale(n),e=c.symbol+d}else n*=p;n=g(n,h);var x=n.lastIndexOf("."),M=0>x?n:n.substring(0,x),_=0>x?"":t+n.substring(x+1);!s&&f&&(M=i(M));var b=v.length+M.length+_.length+(y?0:u.length),w=l>b?new Array(b=l-b+1).join(r):"";return y&&(M=i(w+M)),u+=v,n=M+_,("<"===o?u+n+w:">"===o?w+u+n:"^"===o?w.substring(0,b>>=1)+u+n+w.substring(b):u+(y?n:w+n))+e}}}function Ut(n){return n+""}function jt(){this._=new Date(arguments.length>1?Date.UTC.apply(this,arguments):arguments[0])}function Ht(n,t,e){function r(t){var e=n(t),r=i(e,1);return r-t>t-e?e:r}function u(e){return t(e=n(new oc(e-1)),1),e}function i(n,e){return t(n=new oc(+n),e),n}function o(n,r,i){var o=u(n),a=[];if(i>1)for(;r>o;)e(o)%i||a.push(new Date(+o)),t(o,1);else for(;r>o;)a.push(new Date(+o)),t(o,1);return a}function a(n,t,e){try{oc=jt;var r=new jt;return r._=n,o(r,t,e)}finally{oc=Date}}n.floor=n,n.round=r,n.ceil=u,n.offset=i,n.range=o;var c=n.utc=Ft(n);return c.floor=c,c.round=Ft(r),c.ceil=Ft(u),c.offset=Ft(i),c.range=a,n}function Ft(n){return function(t,e){try{oc=jt;var r=new jt;return r._=t,n(r,e)._}finally{oc=Date}}}function Ot(n){function t(n){function t(t){for(var e,u,i,o=[],a=-1,c=0;++aa;){if(r>=s)return-1;if(u=t.charCodeAt(a++),37===u){if(o=t.charAt(a++),i=N[o in cc?t.charAt(a++):o],!i||(r=i(n,e,r))<0)return-1}else if(u!=e.charCodeAt(r++))return-1}return r}function r(n,t,e){b.lastIndex=0;var r=b.exec(t.substring(e));return r?(n.w=w.get(r[0].toLowerCase()),e+r[0].length):-1}function u(n,t,e){M.lastIndex=0;var r=M.exec(t.substring(e));return r?(n.w=_.get(r[0].toLowerCase()),e+r[0].length):-1}function i(n,t,e){E.lastIndex=0;var r=E.exec(t.substring(e));return r?(n.m=A.get(r[0].toLowerCase()),e+r[0].length):-1}function o(n,t,e){S.lastIndex=0;var r=S.exec(t.substring(e));return r?(n.m=k.get(r[0].toLowerCase()),e+r[0].length):-1}function a(n,t,r){return e(n,C.c.toString(),t,r)}function c(n,t,r){return e(n,C.x.toString(),t,r)}function s(n,t,r){return e(n,C.X.toString(),t,r)}function l(n,t,e){var r=x.get(t.substring(e,e+=2).toLowerCase());return null==r?-1:(n.p=r,e)}var f=n.dateTime,h=n.date,g=n.time,p=n.periods,v=n.days,d=n.shortDays,m=n.months,y=n.shortMonths;t.utc=function(n){function e(n){try{oc=jt;var t=new oc;return t._=n,r(t)}finally{oc=Date}}var r=t(n);return e.parse=function(n){try{oc=jt;var t=r.parse(n);return t&&t._}finally{oc=Date}},e.toString=r.toString,e},t.multi=t.utc.multi=ae;var x=Go.map(),M=Yt(v),_=Zt(v),b=Yt(d),w=Zt(d),S=Yt(m),k=Zt(m),E=Yt(y),A=Zt(y);p.forEach(function(n,t){x.set(n.toLowerCase(),t)});var C={a:function(n){return d[n.getDay()]},A:function(n){return v[n.getDay()]},b:function(n){return y[n.getMonth()]},B:function(n){return m[n.getMonth()]},c:t(f),d:function(n,t){return It(n.getDate(),t,2)},e:function(n,t){return It(n.getDate(),t,2)},H:function(n,t){return It(n.getHours(),t,2)},I:function(n,t){return It(n.getHours()%12||12,t,2)},j:function(n,t){return It(1+ic.dayOfYear(n),t,3)},L:function(n,t){return It(n.getMilliseconds(),t,3)},m:function(n,t){return It(n.getMonth()+1,t,2)},M:function(n,t){return It(n.getMinutes(),t,2)},p:function(n){return p[+(n.getHours()>=12)]},S:function(n,t){return It(n.getSeconds(),t,2)},U:function(n,t){return It(ic.sundayOfYear(n),t,2)},w:function(n){return n.getDay()},W:function(n,t){return It(ic.mondayOfYear(n),t,2)},x:t(h),X:t(g),y:function(n,t){return It(n.getFullYear()%100,t,2)},Y:function(n,t){return It(n.getFullYear()%1e4,t,4)},Z:ie,"%":function(){return"%"}},N={a:r,A:u,b:i,B:o,c:a,d:Qt,e:Qt,H:te,I:te,j:ne,L:ue,m:Kt,M:ee,p:l,S:re,U:$t,w:Vt,W:Xt,x:c,X:s,y:Jt,Y:Bt,Z:Wt,"%":oe};return t}function It(n,t,e){var r=0>n?"-":"",u=(r?-n:n)+"",i=u.length;return r+(e>i?new Array(e-i+1).join(t)+u:u)}function Yt(n){return new RegExp("^(?:"+n.map(Go.requote).join("|")+")","i")}function Zt(n){for(var t=new o,e=-1,r=n.length;++e68?1900:2e3)}function Kt(n,t,e){sc.lastIndex=0;var r=sc.exec(t.substring(e,e+2));return r?(n.m=r[0]-1,e+r[0].length):-1}function Qt(n,t,e){sc.lastIndex=0;var r=sc.exec(t.substring(e,e+2));return r?(n.d=+r[0],e+r[0].length):-1}function ne(n,t,e){sc.lastIndex=0;var r=sc.exec(t.substring(e,e+3));return r?(n.j=+r[0],e+r[0].length):-1}function te(n,t,e){sc.lastIndex=0;var r=sc.exec(t.substring(e,e+2));return r?(n.H=+r[0],e+r[0].length):-1}function ee(n,t,e){sc.lastIndex=0;var r=sc.exec(t.substring(e,e+2));return r?(n.M=+r[0],e+r[0].length):-1}function re(n,t,e){sc.lastIndex=0;var r=sc.exec(t.substring(e,e+2));return r?(n.S=+r[0],e+r[0].length):-1}function ue(n,t,e){sc.lastIndex=0;var r=sc.exec(t.substring(e,e+3));return r?(n.L=+r[0],e+r[0].length):-1}function ie(n){var t=n.getTimezoneOffset(),e=t>0?"-":"+",r=~~(fa(t)/60),u=fa(t)%60;return e+It(r,"0",2)+It(u,"0",2)}function oe(n,t,e){lc.lastIndex=0;var r=lc.exec(t.substring(e,e+1));return r?e+r[0].length:-1}function ae(n){for(var t=n.length,e=-1;++e=0?1:-1,a=o*e,c=Math.cos(t),s=Math.sin(t),l=i*s,f=u*c+l*Math.cos(a),h=l*o*Math.sin(a);dc.add(Math.atan2(h,f)),r=n,u=c,i=s}var t,e,r,u,i;mc.point=function(o,a){mc.point=n,r=(t=o)*za,u=Math.cos(a=(e=a)*za/2+Ca/4),i=Math.sin(a)},mc.lineEnd=function(){n(t,e)}}function pe(n){var t=n[0],e=n[1],r=Math.cos(e);return[r*Math.cos(t),r*Math.sin(t),Math.sin(e)]}function ve(n,t){return n[0]*t[0]+n[1]*t[1]+n[2]*t[2]}function de(n,t){return[n[1]*t[2]-n[2]*t[1],n[2]*t[0]-n[0]*t[2],n[0]*t[1]-n[1]*t[0]]}function me(n,t){n[0]+=t[0],n[1]+=t[1],n[2]+=t[2]}function ye(n,t){return[n[0]*t,n[1]*t,n[2]*t]}function xe(n){var t=Math.sqrt(n[0]*n[0]+n[1]*n[1]+n[2]*n[2]);n[0]/=t,n[1]/=t,n[2]/=t}function Me(n){return[Math.atan2(n[1],n[0]),G(n[2])]}function _e(n,t){return fa(n[0]-t[0])a;++a)u.point((e=n[a])[0],e[1]);return u.lineEnd(),void 0}var c=new Le(e,n,null,!0),s=new Le(e,null,c,!1);c.o=s,i.push(c),o.push(s),c=new Le(r,n,null,!1),s=new Le(r,null,c,!0),c.o=s,i.push(c),o.push(s)}}),o.sort(t),Ne(i),Ne(o),i.length){for(var a=0,c=e,s=o.length;s>a;++a)o[a].e=c=!c;for(var l,f,h=i[0];;){for(var g=h,p=!0;g.v;)if((g=g.n)===h)return;l=g.z,u.lineStart();do{if(g.v=g.o.v=!0,g.e){if(p)for(var a=0,s=l.length;s>a;++a)u.point((f=l[a])[0],f[1]);else r(g.x,g.n.x,1,u);g=g.n}else{if(p){l=g.p.z;for(var a=l.length-1;a>=0;--a)u.point((f=l[a])[0],f[1])}else r(g.x,g.p.x,-1,u);g=g.p}g=g.o,l=g.z,p=!p}while(!g.v);u.lineEnd()}}}function Ne(n){if(t=n.length){for(var t,e,r=0,u=n[0];++r0){for(_||(i.polygonStart(),_=!0),i.lineStart();++o1&&2&t&&e.push(e.pop().concat(e.shift())),g.push(e.filter(qe))}var g,p,v,d=t(i),m=u.invert(r[0],r[1]),y={point:o,lineStart:c,lineEnd:s,polygonStart:function(){y.point=l,y.lineStart=f,y.lineEnd=h,g=[],p=[]},polygonEnd:function(){y.point=o,y.lineStart=c,y.lineEnd=s,g=Go.merge(g);var n=De(m,p);g.length?(_||(i.polygonStart(),_=!0),Ce(g,Re,n,e,i)):n&&(_||(i.polygonStart(),_=!0),i.lineStart(),e(null,null,1,i),i.lineEnd()),_&&(i.polygonEnd(),_=!1),g=p=null},sphere:function(){i.polygonStart(),i.lineStart(),e(null,null,1,i),i.lineEnd(),i.polygonEnd()}},x=ze(),M=t(x),_=!1;return y}}function qe(n){return n.length>1}function ze(){var n,t=[];return{lineStart:function(){t.push(n=[])},point:function(t,e){n.push([t,e])},lineEnd:v,buffer:function(){var e=t;return t=[],n=null,e},rejoin:function(){t.length>1&&t.push(t.pop().concat(t.shift()))}}}function Re(n,t){return((n=n.x)[0]<0?n[1]-La-Ta:La-n[1])-((t=t.x)[0]<0?t[1]-La-Ta:La-t[1])}function De(n,t){var e=n[0],r=n[1],u=[Math.sin(e),-Math.cos(e),0],i=0,o=0;dc.reset();for(var a=0,c=t.length;c>a;++a){var s=t[a],l=s.length;if(l)for(var f=s[0],h=f[0],g=f[1]/2+Ca/4,p=Math.sin(g),v=Math.cos(g),d=1;;){d===l&&(d=0),n=s[d];var m=n[0],y=n[1]/2+Ca/4,x=Math.sin(y),M=Math.cos(y),_=m-h,b=_>=0?1:-1,w=b*_,S=w>Ca,k=p*x;if(dc.add(Math.atan2(k*b*Math.sin(w),v*M+k*Math.cos(w))),i+=S?_+b*Na:_,S^h>=e^m>=e){var E=de(pe(f),pe(n));xe(E);var A=de(u,E);xe(A);var C=(S^_>=0?-1:1)*G(A[2]);(r>C||r===C&&(E[0]||E[1]))&&(o+=S^_>=0?1:-1)}if(!d++)break;h=m,p=x,v=M,f=n}}return(-Ta>i||Ta>i&&0>dc)^1&o}function Pe(n){var t,e=0/0,r=0/0,u=0/0;return{lineStart:function(){n.lineStart(),t=1},point:function(i,o){var a=i>0?Ca:-Ca,c=fa(i-e);fa(c-Ca)0?La:-La),n.point(u,r),n.lineEnd(),n.lineStart(),n.point(a,r),n.point(i,r),t=0):u!==a&&c>=Ca&&(fa(e-u)Ta?Math.atan((Math.sin(t)*(i=Math.cos(r))*Math.sin(e)-Math.sin(r)*(u=Math.cos(t))*Math.sin(n))/(u*i*o)):(t+r)/2}function je(n,t,e,r){var u;if(null==n)u=e*La,r.point(-Ca,u),r.point(0,u),r.point(Ca,u),r.point(Ca,0),r.point(Ca,-u),r.point(0,-u),r.point(-Ca,-u),r.point(-Ca,0),r.point(-Ca,u);else if(fa(n[0]-t[0])>Ta){var i=n[0]i}function e(n){var e,i,c,s,l;return{lineStart:function(){s=c=!1,l=1},point:function(f,h){var g,p=[f,h],v=t(f,h),d=o?v?0:u(f,h):v?u(f+(0>f?Ca:-Ca),h):0;if(!e&&(s=c=v)&&n.lineStart(),v!==c&&(g=r(e,p),(_e(e,g)||_e(p,g))&&(p[0]+=Ta,p[1]+=Ta,v=t(p[0],p[1]))),v!==c)l=0,v?(n.lineStart(),g=r(p,e),n.point(g[0],g[1])):(g=r(e,p),n.point(g[0],g[1]),n.lineEnd()),e=g;else if(a&&e&&o^v){var m;d&i||!(m=r(p,e,!0))||(l=0,o?(n.lineStart(),n.point(m[0][0],m[0][1]),n.point(m[1][0],m[1][1]),n.lineEnd()):(n.point(m[1][0],m[1][1]),n.lineEnd(),n.lineStart(),n.point(m[0][0],m[0][1])))}!v||e&&_e(e,p)||n.point(p[0],p[1]),e=p,c=v,i=d},lineEnd:function(){c&&n.lineEnd(),e=null},clean:function(){return l|(s&&c)<<1}}}function r(n,t,e){var r=pe(n),u=pe(t),o=[1,0,0],a=de(r,u),c=ve(a,a),s=a[0],l=c-s*s;if(!l)return!e&&n;var f=i*c/l,h=-i*s/l,g=de(o,a),p=ye(o,f),v=ye(a,h);me(p,v);var d=g,m=ve(p,d),y=ve(d,d),x=m*m-y*(ve(p,p)-1);if(!(0>x)){var M=Math.sqrt(x),_=ye(d,(-m-M)/y);if(me(_,p),_=Me(_),!e)return _;var b,w=n[0],S=t[0],k=n[1],E=t[1];w>S&&(b=w,w=S,S=b);var A=S-w,C=fa(A-Ca)A;if(!C&&k>E&&(b=k,k=E,E=b),N?C?k+E>0^_[1]<(fa(_[0]-w)Ca^(w<=_[0]&&_[0]<=S)){var L=ye(d,(-m+M)/y);return me(L,p),[_,Me(L)]}}}function u(t,e){var r=o?n:Ca-n,u=0;return-r>t?u|=1:t>r&&(u|=2),-r>e?u|=4:e>r&&(u|=8),u}var i=Math.cos(n),o=i>0,a=fa(i)>Ta,c=gr(n,6*za);return Te(t,e,c,o?[0,-n]:[-Ca,n-Ca])}function Fe(n,t,e,r){return function(u){var i,o=u.a,a=u.b,c=o.x,s=o.y,l=a.x,f=a.y,h=0,g=1,p=l-c,v=f-s;if(i=n-c,p||!(i>0)){if(i/=p,0>p){if(h>i)return;g>i&&(g=i)}else if(p>0){if(i>g)return;i>h&&(h=i)}if(i=e-c,p||!(0>i)){if(i/=p,0>p){if(i>g)return;i>h&&(h=i)}else if(p>0){if(h>i)return;g>i&&(g=i)}if(i=t-s,v||!(i>0)){if(i/=v,0>v){if(h>i)return;g>i&&(g=i)}else if(v>0){if(i>g)return;i>h&&(h=i)}if(i=r-s,v||!(0>i)){if(i/=v,0>v){if(i>g)return;i>h&&(h=i)}else if(v>0){if(h>i)return;g>i&&(g=i)}return h>0&&(u.a={x:c+h*p,y:s+h*v}),1>g&&(u.b={x:c+g*p,y:s+g*v}),u}}}}}}function Oe(n,t,e,r){function u(r,u){return fa(r[0]-n)0?0:3:fa(r[0]-e)0?2:1:fa(r[1]-t)0?1:0:u>0?3:2}function i(n,t){return o(n.x,t.x)}function o(n,t){var e=u(n,1),r=u(t,1);return e!==r?e-r:0===e?t[1]-n[1]:1===e?n[0]-t[0]:2===e?n[1]-t[1]:t[0]-n[0]}return function(a){function c(n){for(var t=0,e=d.length,r=n[1],u=0;e>u;++u)for(var i,o=1,a=d[u],c=a.length,s=a[0];c>o;++o)i=a[o],s[1]<=r?i[1]>r&&J(s,i,n)>0&&++t:i[1]<=r&&J(s,i,n)<0&&--t,s=i;return 0!==t}function s(i,a,c,s){var l=0,f=0;if(null==i||(l=u(i,c))!==(f=u(a,c))||o(i,a)<0^c>0){do s.point(0===l||3===l?n:e,l>1?r:t);while((l=(l+c+4)%4)!==f)}else s.point(a[0],a[1])}function l(u,i){return u>=n&&e>=u&&i>=t&&r>=i}function f(n,t){l(n,t)&&a.point(n,t)}function h(){N.point=p,d&&d.push(m=[]),S=!0,w=!1,_=b=0/0}function g(){v&&(p(y,x),M&&w&&A.rejoin(),v.push(A.buffer())),N.point=f,w&&a.lineEnd()}function p(n,t){n=Math.max(-Tc,Math.min(Tc,n)),t=Math.max(-Tc,Math.min(Tc,t));var e=l(n,t);if(d&&m.push([n,t]),S)y=n,x=t,M=e,S=!1,e&&(a.lineStart(),a.point(n,t));else if(e&&w)a.point(n,t);else{var r={a:{x:_,y:b},b:{x:n,y:t}};C(r)?(w||(a.lineStart(),a.point(r.a.x,r.a.y)),a.point(r.b.x,r.b.y),e||a.lineEnd(),k=!1):e&&(a.lineStart(),a.point(n,t),k=!1)}_=n,b=t,w=e}var v,d,m,y,x,M,_,b,w,S,k,E=a,A=ze(),C=Fe(n,t,e,r),N={point:f,lineStart:h,lineEnd:g,polygonStart:function(){a=A,v=[],d=[],k=!0},polygonEnd:function(){a=E,v=Go.merge(v);var t=c([n,r]),e=k&&t,u=v.length;(e||u)&&(a.polygonStart(),e&&(a.lineStart(),s(null,null,1,a),a.lineEnd()),u&&Ce(v,i,t,s,a),a.polygonEnd()),v=d=m=null}};return N}}function Ie(n,t){function e(e,r){return e=n(e,r),t(e[0],e[1])}return n.invert&&t.invert&&(e.invert=function(e,r){return e=t.invert(e,r),e&&n.invert(e[0],e[1])}),e}function Ye(n){var t=0,e=Ca/3,r=ir(n),u=r(t,e);return u.parallels=function(n){return arguments.length?r(t=n[0]*Ca/180,e=n[1]*Ca/180):[180*(t/Ca),180*(e/Ca)]},u}function Ze(n,t){function e(n,t){var e=Math.sqrt(i-2*u*Math.sin(t))/u;return[e*Math.sin(n*=u),o-e*Math.cos(n)]}var r=Math.sin(n),u=(r+Math.sin(t))/2,i=1+r*(2*u-r),o=Math.sqrt(i)/u;return e.invert=function(n,t){var e=o-t;return[Math.atan2(n,e)/u,G((i-(n*n+e*e)*u*u)/(2*u))]},e}function Ve(){function n(n,t){zc+=u*n-r*t,r=n,u=t}var t,e,r,u;jc.point=function(i,o){jc.point=n,t=r=i,e=u=o},jc.lineEnd=function(){n(t,e)}}function $e(n,t){Rc>n&&(Rc=n),n>Pc&&(Pc=n),Dc>t&&(Dc=t),t>Uc&&(Uc=t)}function Xe(){function n(n,t){o.push("M",n,",",t,i)}function t(n,t){o.push("M",n,",",t),a.point=e}function e(n,t){o.push("L",n,",",t)}function r(){a.point=n}function u(){o.push("Z")}var i=Be(4.5),o=[],a={point:n,lineStart:function(){a.point=t},lineEnd:r,polygonStart:function(){a.lineEnd=u},polygonEnd:function(){a.lineEnd=r,a.point=n},pointRadius:function(n){return i=Be(n),a},result:function(){if(o.length){var n=o.join("");return o=[],n}}};return a}function Be(n){return"m0,"+n+"a"+n+","+n+" 0 1,1 0,"+-2*n+"a"+n+","+n+" 0 1,1 0,"+2*n+"z"}function Je(n,t){Mc+=n,_c+=t,++bc}function We(){function n(n,r){var u=n-t,i=r-e,o=Math.sqrt(u*u+i*i);wc+=o*(t+n)/2,Sc+=o*(e+r)/2,kc+=o,Je(t=n,e=r)}var t,e;Fc.point=function(r,u){Fc.point=n,Je(t=r,e=u)}}function Ge(){Fc.point=Je}function Ke(){function n(n,t){var e=n-r,i=t-u,o=Math.sqrt(e*e+i*i);wc+=o*(r+n)/2,Sc+=o*(u+t)/2,kc+=o,o=u*n-r*t,Ec+=o*(r+n),Ac+=o*(u+t),Cc+=3*o,Je(r=n,u=t)}var t,e,r,u;Fc.point=function(i,o){Fc.point=n,Je(t=r=i,e=u=o)},Fc.lineEnd=function(){n(t,e)}}function Qe(n){function t(t,e){n.moveTo(t,e),n.arc(t,e,o,0,Na)}function e(t,e){n.moveTo(t,e),a.point=r}function r(t,e){n.lineTo(t,e)}function u(){a.point=t}function i(){n.closePath()}var o=4.5,a={point:t,lineStart:function(){a.point=e},lineEnd:u,polygonStart:function(){a.lineEnd=i},polygonEnd:function(){a.lineEnd=u,a.point=t},pointRadius:function(n){return o=n,a},result:v};return a}function nr(n){function t(n){return(a?r:e)(n)}function e(t){return rr(t,function(e,r){e=n(e,r),t.point(e[0],e[1])})}function r(t){function e(e,r){e=n(e,r),t.point(e[0],e[1])}function r(){x=0/0,S.point=i,t.lineStart()}function i(e,r){var i=pe([e,r]),o=n(e,r);u(x,M,y,_,b,w,x=o[0],M=o[1],y=e,_=i[0],b=i[1],w=i[2],a,t),t.point(x,M)}function o(){S.point=e,t.lineEnd()}function c(){r(),S.point=s,S.lineEnd=l}function s(n,t){i(f=n,h=t),g=x,p=M,v=_,d=b,m=w,S.point=i}function l(){u(x,M,y,_,b,w,g,p,f,v,d,m,a,t),S.lineEnd=o,o()}var f,h,g,p,v,d,m,y,x,M,_,b,w,S={point:e,lineStart:r,lineEnd:o,polygonStart:function(){t.polygonStart(),S.lineStart=c},polygonEnd:function(){t.polygonEnd(),S.lineStart=r}};return S}function u(t,e,r,a,c,s,l,f,h,g,p,v,d,m){var y=l-t,x=f-e,M=y*y+x*x;if(M>4*i&&d--){var _=a+g,b=c+p,w=s+v,S=Math.sqrt(_*_+b*b+w*w),k=Math.asin(w/=S),E=fa(fa(w)-1)i||fa((y*L+x*T)/M-.5)>.3||o>a*g+c*p+s*v)&&(u(t,e,r,a,c,s,C,N,E,_/=S,b/=S,w,d,m),m.point(C,N),u(C,N,E,_,b,w,l,f,h,g,p,v,d,m))}}var i=.5,o=Math.cos(30*za),a=16;return t.precision=function(n){return arguments.length?(a=(i=n*n)>0&&16,t):Math.sqrt(i)},t}function tr(n){var t=nr(function(t,e){return n([t*Ra,e*Ra])});return function(n){return or(t(n))}}function er(n){this.stream=n}function rr(n,t){return{point:t,sphere:function(){n.sphere()},lineStart:function(){n.lineStart()},lineEnd:function(){n.lineEnd()},polygonStart:function(){n.polygonStart()},polygonEnd:function(){n.polygonEnd()}}}function ur(n){return ir(function(){return n})()}function ir(n){function t(n){return n=a(n[0]*za,n[1]*za),[n[0]*h+c,s-n[1]*h]}function e(n){return n=a.invert((n[0]-c)/h,(s-n[1])/h),n&&[n[0]*Ra,n[1]*Ra]}function r(){a=Ie(o=sr(m,y,x),i);var n=i(v,d);return c=g-n[0]*h,s=p+n[1]*h,u() +}function u(){return l&&(l.valid=!1,l=null),t}var i,o,a,c,s,l,f=nr(function(n,t){return n=i(n,t),[n[0]*h+c,s-n[1]*h]}),h=150,g=480,p=250,v=0,d=0,m=0,y=0,x=0,M=Lc,_=At,b=null,w=null;return t.stream=function(n){return l&&(l.valid=!1),l=or(M(o,f(_(n)))),l.valid=!0,l},t.clipAngle=function(n){return arguments.length?(M=null==n?(b=n,Lc):He((b=+n)*za),u()):b},t.clipExtent=function(n){return arguments.length?(w=n,_=n?Oe(n[0][0],n[0][1],n[1][0],n[1][1]):At,u()):w},t.scale=function(n){return arguments.length?(h=+n,r()):h},t.translate=function(n){return arguments.length?(g=+n[0],p=+n[1],r()):[g,p]},t.center=function(n){return arguments.length?(v=n[0]%360*za,d=n[1]%360*za,r()):[v*Ra,d*Ra]},t.rotate=function(n){return arguments.length?(m=n[0]%360*za,y=n[1]%360*za,x=n.length>2?n[2]%360*za:0,r()):[m*Ra,y*Ra,x*Ra]},Go.rebind(t,f,"precision"),function(){return i=n.apply(this,arguments),t.invert=i.invert&&e,r()}}function or(n){return rr(n,function(t,e){n.point(t*za,e*za)})}function ar(n,t){return[n,t]}function cr(n,t){return[n>Ca?n-Na:-Ca>n?n+Na:n,t]}function sr(n,t,e){return n?t||e?Ie(fr(n),hr(t,e)):fr(n):t||e?hr(t,e):cr}function lr(n){return function(t,e){return t+=n,[t>Ca?t-Na:-Ca>t?t+Na:t,e]}}function fr(n){var t=lr(n);return t.invert=lr(-n),t}function hr(n,t){function e(n,t){var e=Math.cos(t),a=Math.cos(n)*e,c=Math.sin(n)*e,s=Math.sin(t),l=s*r+a*u;return[Math.atan2(c*i-l*o,a*r-s*u),G(l*i+c*o)]}var r=Math.cos(n),u=Math.sin(n),i=Math.cos(t),o=Math.sin(t);return e.invert=function(n,t){var e=Math.cos(t),a=Math.cos(n)*e,c=Math.sin(n)*e,s=Math.sin(t),l=s*i-c*o;return[Math.atan2(c*i+s*o,a*r+l*u),G(l*r-a*u)]},e}function gr(n,t){var e=Math.cos(n),r=Math.sin(n);return function(u,i,o,a){var c=o*t;null!=u?(u=pr(e,u),i=pr(e,i),(o>0?i>u:u>i)&&(u+=o*Na)):(u=n+o*Na,i=n-.5*c);for(var s,l=u;o>0?l>i:i>l;l-=c)a.point((s=Me([e,-r*Math.cos(l),-r*Math.sin(l)]))[0],s[1])}}function pr(n,t){var e=pe(t);e[0]-=n,xe(e);var r=W(-e[1]);return((-e[2]<0?-r:r)+2*Math.PI-Ta)%(2*Math.PI)}function vr(n,t,e){var r=Go.range(n,t-Ta,e).concat(t);return function(n){return r.map(function(t){return[n,t]})}}function dr(n,t,e){var r=Go.range(n,t-Ta,e).concat(t);return function(n){return r.map(function(t){return[t,n]})}}function mr(n){return n.source}function yr(n){return n.target}function xr(n,t,e,r){var u=Math.cos(t),i=Math.sin(t),o=Math.cos(r),a=Math.sin(r),c=u*Math.cos(n),s=u*Math.sin(n),l=o*Math.cos(e),f=o*Math.sin(e),h=2*Math.asin(Math.sqrt(tt(r-t)+u*o*tt(e-n))),g=1/Math.sin(h),p=h?function(n){var t=Math.sin(n*=h)*g,e=Math.sin(h-n)*g,r=e*c+t*l,u=e*s+t*f,o=e*i+t*a;return[Math.atan2(u,r)*Ra,Math.atan2(o,Math.sqrt(r*r+u*u))*Ra]}:function(){return[n*Ra,t*Ra]};return p.distance=h,p}function Mr(){function n(n,u){var i=Math.sin(u*=za),o=Math.cos(u),a=fa((n*=za)-t),c=Math.cos(a);Oc+=Math.atan2(Math.sqrt((a=o*Math.sin(a))*a+(a=r*i-e*o*c)*a),e*i+r*o*c),t=n,e=i,r=o}var t,e,r;Ic.point=function(u,i){t=u*za,e=Math.sin(i*=za),r=Math.cos(i),Ic.point=n},Ic.lineEnd=function(){Ic.point=Ic.lineEnd=v}}function _r(n,t){function e(t,e){var r=Math.cos(t),u=Math.cos(e),i=n(r*u);return[i*u*Math.sin(t),i*Math.sin(e)]}return e.invert=function(n,e){var r=Math.sqrt(n*n+e*e),u=t(r),i=Math.sin(u),o=Math.cos(u);return[Math.atan2(n*i,r*o),Math.asin(r&&e*i/r)]},e}function br(n,t){function e(n,t){o>0?-La+Ta>t&&(t=-La+Ta):t>La-Ta&&(t=La-Ta);var e=o/Math.pow(u(t),i);return[e*Math.sin(i*n),o-e*Math.cos(i*n)]}var r=Math.cos(n),u=function(n){return Math.tan(Ca/4+n/2)},i=n===t?Math.sin(n):Math.log(r/Math.cos(t))/Math.log(u(t)/u(n)),o=r*Math.pow(u(n),i)/i;return i?(e.invert=function(n,t){var e=o-t,r=B(i)*Math.sqrt(n*n+e*e);return[Math.atan2(n,e)/i,2*Math.atan(Math.pow(o/r,1/i))-La]},e):Sr}function wr(n,t){function e(n,t){var e=i-t;return[e*Math.sin(u*n),i-e*Math.cos(u*n)]}var r=Math.cos(n),u=n===t?Math.sin(n):(r-Math.cos(t))/(t-n),i=r/u+n;return fa(u)u;u++){for(;r>1&&J(n[e[r-2]],n[e[r-1]],n[u])<=0;)--r;e[r++]=u}return e.slice(0,r)}function Lr(n,t){return n[0]-t[0]||n[1]-t[1]}function Tr(n,t,e){return(e[0]-t[0])*(n[1]-t[1])<(e[1]-t[1])*(n[0]-t[0])}function qr(n,t,e,r){var u=n[0],i=e[0],o=t[0]-u,a=r[0]-i,c=n[1],s=e[1],l=t[1]-c,f=r[1]-s,h=(a*(c-s)-f*(u-i))/(f*o-a*l);return[u+h*o,c+h*l]}function zr(n){var t=n[0],e=n[n.length-1];return!(t[0]-e[0]||t[1]-e[1])}function Rr(){tu(this),this.edge=this.site=this.circle=null}function Dr(n){var t=ns.pop()||new Rr;return t.site=n,t}function Pr(n){$r(n),Gc.remove(n),ns.push(n),tu(n)}function Ur(n){var t=n.circle,e=t.x,r=t.cy,u={x:e,y:r},i=n.P,o=n.N,a=[n];Pr(n);for(var c=i;c.circle&&fa(e-c.circle.x)l;++l)s=a[l],c=a[l-1],Kr(s.edge,c.site,s.site,u);c=a[0],s=a[f-1],s.edge=Wr(c.site,s.site,null,u),Vr(c),Vr(s)}function jr(n){for(var t,e,r,u,i=n.x,o=n.y,a=Gc._;a;)if(r=Hr(a,o)-i,r>Ta)a=a.L;else{if(u=i-Fr(a,o),!(u>Ta)){r>-Ta?(t=a.P,e=a):u>-Ta?(t=a,e=a.N):t=e=a;break}if(!a.R){t=a;break}a=a.R}var c=Dr(n);if(Gc.insert(t,c),t||e){if(t===e)return $r(t),e=Dr(t.site),Gc.insert(c,e),c.edge=e.edge=Wr(t.site,c.site),Vr(t),Vr(e),void 0;if(!e)return c.edge=Wr(t.site,c.site),void 0;$r(t),$r(e);var s=t.site,l=s.x,f=s.y,h=n.x-l,g=n.y-f,p=e.site,v=p.x-l,d=p.y-f,m=2*(h*d-g*v),y=h*h+g*g,x=v*v+d*d,M={x:(d*y-g*x)/m+l,y:(h*x-v*y)/m+f};Kr(e.edge,s,p,M),c.edge=Wr(s,n,null,M),e.edge=Wr(n,p,null,M),Vr(t),Vr(e)}}function Hr(n,t){var e=n.site,r=e.x,u=e.y,i=u-t;if(!i)return r;var o=n.P;if(!o)return-1/0;e=o.site;var a=e.x,c=e.y,s=c-t;if(!s)return a;var l=a-r,f=1/i-1/s,h=l/s;return f?(-h+Math.sqrt(h*h-2*f*(l*l/(-2*s)-c+s/2+u-i/2)))/f+r:(r+a)/2}function Fr(n,t){var e=n.N;if(e)return Hr(e,t);var r=n.site;return r.y===t?r.x:1/0}function Or(n){this.site=n,this.edges=[]}function Ir(n){for(var t,e,r,u,i,o,a,c,s,l,f=n[0][0],h=n[1][0],g=n[0][1],p=n[1][1],v=Wc,d=v.length;d--;)if(i=v[d],i&&i.prepare())for(a=i.edges,c=a.length,o=0;c>o;)l=a[o].end(),r=l.x,u=l.y,s=a[++o%c].start(),t=s.x,e=s.y,(fa(r-t)>Ta||fa(u-e)>Ta)&&(a.splice(o,0,new Qr(Gr(i.site,l,fa(r-f)Ta?{x:f,y:fa(t-f)Ta?{x:fa(e-p)Ta?{x:h,y:fa(t-h)Ta?{x:fa(e-g)=-qa)){var g=c*c+s*s,p=l*l+f*f,v=(f*g-s*p)/h,d=(c*p-l*g)/h,f=d+a,m=ts.pop()||new Zr;m.arc=n,m.site=u,m.x=v+o,m.y=f+Math.sqrt(v*v+d*d),m.cy=f,n.circle=m;for(var y=null,x=Qc._;x;)if(m.yd||d>=a)return;if(h>p){if(i){if(i.y>=s)return}else i={x:d,y:c};e={x:d,y:s}}else{if(i){if(i.yr||r>1)if(h>p){if(i){if(i.y>=s)return}else i={x:(c-u)/r,y:c};e={x:(s-u)/r,y:s}}else{if(i){if(i.yg){if(i){if(i.x>=a)return}else i={x:o,y:r*o+u};e={x:a,y:r*a+u}}else{if(i){if(i.xi&&(u=t.substring(i,u),a[o]?a[o]+=u:a[++o]=u),(e=e[0])===(r=r[0])?a[o]?a[o]+=r:a[++o]=r:(a[++o]=null,c.push({i:o,x:pu(e,r)})),i=us.lastIndex;return ir;++r)a[(e=c[r]).i]=e.x(n);return a.join("")})}function du(n,t){for(var e,r=Go.interpolators.length;--r>=0&&!(e=Go.interpolators[r](n,t)););return e}function mu(n,t){var e,r=[],u=[],i=n.length,o=t.length,a=Math.min(n.length,t.length);for(e=0;a>e;++e)r.push(du(n[e],t[e]));for(;i>e;++e)u[e]=n[e];for(;o>e;++e)u[e]=t[e];return function(n){for(e=0;a>e;++e)u[e]=r[e](n);return u}}function yu(n){return function(t){return 0>=t?0:t>=1?1:n(t)}}function xu(n){return function(t){return 1-n(1-t)}}function Mu(n){return function(t){return.5*(.5>t?n(2*t):2-n(2-2*t))}}function _u(n){return n*n}function bu(n){return n*n*n}function wu(n){if(0>=n)return 0;if(n>=1)return 1;var t=n*n,e=t*n;return 4*(.5>n?e:3*(n-t)+e-.75)}function Su(n){return function(t){return Math.pow(t,n)}}function ku(n){return 1-Math.cos(n*La)}function Eu(n){return Math.pow(2,10*(n-1))}function Au(n){return 1-Math.sqrt(1-n*n)}function Cu(n,t){var e;return arguments.length<2&&(t=.45),arguments.length?e=t/Na*Math.asin(1/n):(n=1,e=t/4),function(r){return 1+n*Math.pow(2,-10*r)*Math.sin((r-e)*Na/t)}}function Nu(n){return n||(n=1.70158),function(t){return t*t*((n+1)*t-n)}}function Lu(n){return 1/2.75>n?7.5625*n*n:2/2.75>n?7.5625*(n-=1.5/2.75)*n+.75:2.5/2.75>n?7.5625*(n-=2.25/2.75)*n+.9375:7.5625*(n-=2.625/2.75)*n+.984375}function Tu(n,t){n=Go.hcl(n),t=Go.hcl(t);var e=n.h,r=n.c,u=n.l,i=t.h-e,o=t.c-r,a=t.l-u;return isNaN(o)&&(o=0,r=isNaN(r)?t.c:r),isNaN(i)?(i=0,e=isNaN(e)?t.h:e):i>180?i-=360:-180>i&&(i+=360),function(n){return ct(e+i*n,r+o*n,u+a*n)+""}}function qu(n,t){n=Go.hsl(n),t=Go.hsl(t);var e=n.h,r=n.s,u=n.l,i=t.h-e,o=t.s-r,a=t.l-u;return isNaN(o)&&(o=0,r=isNaN(r)?t.s:r),isNaN(i)?(i=0,e=isNaN(e)?t.h:e):i>180?i-=360:-180>i&&(i+=360),function(n){return it(e+i*n,r+o*n,u+a*n)+""}}function zu(n,t){n=Go.lab(n),t=Go.lab(t);var e=n.l,r=n.a,u=n.b,i=t.l-e,o=t.a-r,a=t.b-u;return function(n){return ft(e+i*n,r+o*n,u+a*n)+""}}function Ru(n,t){return t-=n,function(e){return Math.round(n+t*e)}}function Du(n){var t=[n.a,n.b],e=[n.c,n.d],r=Uu(t),u=Pu(t,e),i=Uu(ju(e,t,-u))||0;t[0]*e[1]180?l+=360:l-s>180&&(s+=360),u.push({i:r.push(r.pop()+"rotate(",null,")")-2,x:pu(s,l)})):l&&r.push(r.pop()+"rotate("+l+")"),f!=h?u.push({i:r.push(r.pop()+"skewX(",null,")")-2,x:pu(f,h)}):h&&r.push(r.pop()+"skewX("+h+")"),g[0]!=p[0]||g[1]!=p[1]?(e=r.push(r.pop()+"scale(",null,",",null,")"),u.push({i:e-4,x:pu(g[0],p[0])},{i:e-2,x:pu(g[1],p[1])})):(1!=p[0]||1!=p[1])&&r.push(r.pop()+"scale("+p+")"),e=u.length,function(n){for(var t,i=-1;++ie;++e)(t=n[e][1])>u&&(r=e,u=t);return r}function ai(n){return n.reduce(ci,0)}function ci(n,t){return n+t[1]}function si(n,t){return li(n,Math.ceil(Math.log(t.length)/Math.LN2+1))}function li(n,t){for(var e=-1,r=+n[0],u=(n[1]-r)/t,i=[];++e<=t;)i[e]=u*e+r;return i}function fi(n){return[Go.min(n),Go.max(n)]}function hi(n,t){return n.parent==t.parent?1:2}function gi(n){var t=n.children;return t&&t.length?t[0]:n._tree.thread}function pi(n){var t,e=n.children;return e&&(t=e.length)?e[t-1]:n._tree.thread}function vi(n,t){var e=n.children;if(e&&(u=e.length))for(var r,u,i=-1;++i0&&(n=r);return n}function di(n,t){return n.x-t.x}function mi(n,t){return t.x-n.x}function yi(n,t){return n.depth-t.depth}function xi(n,t){function e(n,r){var u=n.children;if(u&&(o=u.length))for(var i,o,a=null,c=-1;++c=0;)t=u[i]._tree,t.prelim+=e,t.mod+=e,e+=t.shift+(r+=t.change)}function _i(n,t,e){n=n._tree,t=t._tree;var r=e/(t.number-n.number);n.change+=r,t.change-=r,t.shift+=e,t.prelim+=e,t.mod+=e}function bi(n,t,e){return n._tree.ancestor.parent==t.parent?n._tree.ancestor:e}function wi(n,t){return n.value-t.value}function Si(n,t){var e=n._pack_next;n._pack_next=t,t._pack_prev=n,t._pack_next=e,e._pack_prev=t}function ki(n,t){n._pack_next=t,t._pack_prev=n}function Ei(n,t){var e=t.x-n.x,r=t.y-n.y,u=n.r+t.r;return.999*u*u>e*e+r*r}function Ai(n){function t(n){l=Math.min(n.x-n.r,l),f=Math.max(n.x+n.r,f),h=Math.min(n.y-n.r,h),g=Math.max(n.y+n.r,g)}if((e=n.children)&&(s=e.length)){var e,r,u,i,o,a,c,s,l=1/0,f=-1/0,h=1/0,g=-1/0;if(e.forEach(Ci),r=e[0],r.x=-r.r,r.y=0,t(r),s>1&&(u=e[1],u.x=u.r,u.y=0,t(u),s>2))for(i=e[2],Ti(r,u,i),t(i),Si(r,i),r._pack_prev=i,Si(i,u),u=r._pack_next,o=3;s>o;o++){Ti(r,u,i=e[o]);var p=0,v=1,d=1;for(a=u._pack_next;a!==u;a=a._pack_next,v++)if(Ei(a,i)){p=1;break}if(1==p)for(c=r._pack_prev;c!==a._pack_prev&&!Ei(c,i);c=c._pack_prev,d++);p?(d>v||v==d&&u.ro;o++)i=e[o],i.x-=m,i.y-=y,x=Math.max(x,i.r+Math.sqrt(i.x*i.x+i.y*i.y));n.r=x,e.forEach(Ni)}}function Ci(n){n._pack_next=n._pack_prev=n}function Ni(n){delete n._pack_next,delete n._pack_prev}function Li(n,t,e,r){var u=n.children;if(n.x=t+=r*n.x,n.y=e+=r*n.y,n.r*=r,u)for(var i=-1,o=u.length;++iu&&(e+=u/2,u=0),0>i&&(r+=i/2,i=0),{x:e,y:r,dx:u,dy:i}}function ji(n){var t=n[0],e=n[n.length-1];return e>t?[t,e]:[e,t]}function Hi(n){return n.rangeExtent?n.rangeExtent():ji(n.range())}function Fi(n,t,e,r){var u=e(n[0],n[1]),i=r(t[0],t[1]);return function(n){return i(u(n))}}function Oi(n,t){var e,r=0,u=n.length-1,i=n[r],o=n[u];return i>o&&(e=r,r=u,u=e,e=i,i=o,o=e),n[r]=t.floor(i),n[u]=t.ceil(o),n}function Ii(n){return n?{floor:function(t){return Math.floor(t/n)*n},ceil:function(t){return Math.ceil(t/n)*n}}:vs}function Yi(n,t,e,r){var u=[],i=[],o=0,a=Math.min(n.length,t.length)-1;for(n[a]2?Yi:Fi,c=r?Ou:Fu;return o=u(n,t,c,e),a=u(t,n,c,du),i}function i(n){return o(n)}var o,a;return i.invert=function(n){return a(n)},i.domain=function(t){return arguments.length?(n=t.map(Number),u()):n},i.range=function(n){return arguments.length?(t=n,u()):t},i.rangeRound=function(n){return i.range(n).interpolate(Ru)},i.clamp=function(n){return arguments.length?(r=n,u()):r},i.interpolate=function(n){return arguments.length?(e=n,u()):e},i.ticks=function(t){return Bi(n,t)},i.tickFormat=function(t,e){return Ji(n,t,e)},i.nice=function(t){return $i(n,t),u()},i.copy=function(){return Zi(n,t,e,r)},u()}function Vi(n,t){return Go.rebind(n,t,"range","rangeRound","interpolate","clamp")}function $i(n,t){return Oi(n,Ii(Xi(n,t)[2]))}function Xi(n,t){null==t&&(t=10);var e=ji(n),r=e[1]-e[0],u=Math.pow(10,Math.floor(Math.log(r/t)/Math.LN10)),i=t/r*u;return.15>=i?u*=10:.35>=i?u*=5:.75>=i&&(u*=2),e[0]=Math.ceil(e[0]/u)*u,e[1]=Math.floor(e[1]/u)*u+.5*u,e[2]=u,e}function Bi(n,t){return Go.range.apply(Go,Xi(n,t))}function Ji(n,t,e){var r=Xi(n,t);if(e){var u=rc.exec(e);if(u.shift(),"s"===u[8]){var i=Go.formatPrefix(Math.max(fa(r[0]),fa(r[1])));return u[7]||(u[7]="."+Wi(i.scale(r[2]))),u[8]="f",e=Go.format(u.join("")),function(n){return e(i.scale(n))+i.symbol}}u[7]||(u[7]="."+Gi(u[8],r)),e=u.join("")}else e=",."+Wi(r[2])+"f";return Go.format(e)}function Wi(n){return-Math.floor(Math.log(n)/Math.LN10+.01)}function Gi(n,t){var e=Wi(t[2]);return n in ds?Math.abs(e-Wi(Math.max(fa(t[0]),fa(t[1]))))+ +("e"!==n):e-2*("%"===n)}function Ki(n,t,e,r){function u(n){return(e?Math.log(0>n?0:n):-Math.log(n>0?0:-n))/Math.log(t)}function i(n){return e?Math.pow(t,n):-Math.pow(t,-n)}function o(t){return n(u(t))}return o.invert=function(t){return i(n.invert(t))},o.domain=function(t){return arguments.length?(e=t[0]>=0,n.domain((r=t.map(Number)).map(u)),o):r},o.base=function(e){return arguments.length?(t=+e,n.domain(r.map(u)),o):t},o.nice=function(){var t=Oi(r.map(u),e?Math:ys);return n.domain(t),r=t.map(i),o},o.ticks=function(){var n=ji(r),o=[],a=n[0],c=n[1],s=Math.floor(u(a)),l=Math.ceil(u(c)),f=t%1?2:t;if(isFinite(l-s)){if(e){for(;l>s;s++)for(var h=1;f>h;h++)o.push(i(s)*h);o.push(i(s))}else for(o.push(i(s));s++0;h--)o.push(i(s)*h);for(s=0;o[s]c;l--);o=o.slice(s,l)}return o},o.tickFormat=function(n,t){if(!arguments.length)return ms;arguments.length<2?t=ms:"function"!=typeof t&&(t=Go.format(t));var r,a=Math.max(.1,n/o.ticks().length),c=e?(r=1e-12,Math.ceil):(r=-1e-12,Math.floor);return function(n){return n/i(c(u(n)+r))<=a?t(n):""}},o.copy=function(){return Ki(n.copy(),t,e,r)},Vi(o,n)}function Qi(n,t,e){function r(t){return n(u(t))}var u=no(t),i=no(1/t);return r.invert=function(t){return i(n.invert(t))},r.domain=function(t){return arguments.length?(n.domain((e=t.map(Number)).map(u)),r):e},r.ticks=function(n){return Bi(e,n)},r.tickFormat=function(n,t){return Ji(e,n,t)},r.nice=function(n){return r.domain($i(e,n))},r.exponent=function(o){return arguments.length?(u=no(t=o),i=no(1/t),n.domain(e.map(u)),r):t},r.copy=function(){return Qi(n.copy(),t,e)},Vi(r,n)}function no(n){return function(t){return 0>t?-Math.pow(-t,n):Math.pow(t,n)}}function to(n,t){function e(e){return i[((u.get(e)||("range"===t.t?u.set(e,n.push(e)):0/0))-1)%i.length]}function r(t,e){return Go.range(n.length).map(function(n){return t+e*n})}var u,i,a;return e.domain=function(r){if(!arguments.length)return n;n=[],u=new o;for(var i,a=-1,c=r.length;++an?[0/0,0/0]:[n>0?o[n-1]:e[0],nt?0/0:t/i+n,[t,t+1/i]},r.copy=function(){return ro(n,t,e)},u()}function uo(n,t){function e(e){return e>=e?t[Go.bisect(n,e)]:void 0}return e.domain=function(t){return arguments.length?(n=t,e):n},e.range=function(n){return arguments.length?(t=n,e):t},e.invertExtent=function(e){return e=t.indexOf(e),[n[e-1],n[e]]},e.copy=function(){return uo(n,t)},e}function io(n){function t(n){return+n}return t.invert=t,t.domain=t.range=function(e){return arguments.length?(n=e.map(t),t):n},t.ticks=function(t){return Bi(n,t)},t.tickFormat=function(t,e){return Ji(n,t,e)},t.copy=function(){return io(n)},t}function oo(n){return n.innerRadius}function ao(n){return n.outerRadius}function co(n){return n.startAngle}function so(n){return n.endAngle}function lo(n){function t(t){function o(){s.push("M",i(n(l),a))}for(var c,s=[],l=[],f=-1,h=t.length,g=Et(e),p=Et(r);++f1&&u.push("H",r[0]),u.join("")}function po(n){for(var t=0,e=n.length,r=n[0],u=[r[0],",",r[1]];++t1){a=t[1],i=n[c],c++,r+="C"+(u[0]+o[0])+","+(u[1]+o[1])+","+(i[0]-a[0])+","+(i[1]-a[1])+","+i[0]+","+i[1];for(var s=2;s9&&(u=3*t/Math.sqrt(u),o[a]=u*e,o[a+1]=u*r));for(a=-1;++a<=c;)u=(n[Math.min(c,a+1)][0]-n[Math.max(0,a-1)][0])/(6*(1+o[a]*o[a])),i.push([u||0,o[a]*u||0]);return i}function To(n){return n.length<3?fo(n):n[0]+Mo(n,Lo(n))}function qo(n){for(var t,e,r,u=-1,i=n.length;++ue?s():(u.active=e,i.event&&i.event.start.call(n,l,t),i.tween.forEach(function(e,r){(r=r.call(n,l,t))&&v.push(r)}),Go.timer(function(){return p.c=c(r||1)?Ae:c,1},0,a),void 0)}function c(r){if(u.active!==e)return s();for(var o=r/g,a=f(o),c=v.length;c>0;)v[--c].call(n,a);return o>=1?(i.event&&i.event.end.call(n,l,t),s()):void 0}function s(){return--u.count?delete u[e]:delete n.__transition__,1}var l=n.__data__,f=i.ease,h=i.delay,g=i.duration,p=nc,v=[];return p.t=h+a,r>=h?o(r-h):(p.c=o,void 0)},0,a)}}function Zo(n,t){n.attr("transform",function(n){return"translate("+t(n)+",0)"})}function Vo(n,t){n.attr("transform",function(n){return"translate(0,"+t(n)+")"})}function $o(n){return n.toISOString()}function Xo(n,t,e){function r(t){return n(t) +}function u(n,e){var r=n[1]-n[0],u=r/e,i=Go.bisect(Ys,u);return i==Ys.length?[t.year,Xi(n.map(function(n){return n/31536e6}),e)[2]]:i?t[u/Ys[i-1]1?{floor:function(t){for(;e(t=n.floor(t));)t=Bo(t-1);return t},ceil:function(t){for(;e(t=n.ceil(t));)t=Bo(+t+1);return t}}:n))},r.ticks=function(n,t){var e=ji(r.domain()),i=null==n?u(e,10):"number"==typeof n?u(e,n):!n.range&&[{range:n},t];return i&&(n=i[0],t=i[1]),n.range(e[0],Bo(+e[1]+1),1>t?1:t)},r.tickFormat=function(){return e},r.copy=function(){return Xo(n.copy(),t,e)},Vi(r,n)}function Bo(n){return new Date(n)}function Jo(n){return JSON.parse(n.responseText)}function Wo(n){var t=na.createRange();return t.selectNode(na.body),t.createContextualFragment(n.responseText)}var Go={version:"3.4.6"};Date.now||(Date.now=function(){return+new Date});var Ko=[].slice,Qo=function(n){return Ko.call(n)},na=document,ta=na.documentElement,ea=window;try{Qo(ta.childNodes)[0].nodeType}catch(ra){Qo=function(n){for(var t=n.length,e=new Array(t);t--;)e[t]=n[t];return e}}try{na.createElement("div").style.setProperty("opacity",0,"")}catch(ua){var ia=ea.Element.prototype,oa=ia.setAttribute,aa=ia.setAttributeNS,ca=ea.CSSStyleDeclaration.prototype,sa=ca.setProperty;ia.setAttribute=function(n,t){oa.call(this,n,t+"")},ia.setAttributeNS=function(n,t,e){aa.call(this,n,t,e+"")},ca.setProperty=function(n,t,e){sa.call(this,n,t+"",e)}}Go.ascending=n,Go.descending=function(n,t){return n>t?-1:t>n?1:t>=n?0:0/0},Go.min=function(n,t){var e,r,u=-1,i=n.length;if(1===arguments.length){for(;++u=e);)e=void 0;for(;++ur&&(e=r)}else{for(;++u=e);)e=void 0;for(;++ur&&(e=r)}return e},Go.max=function(n,t){var e,r,u=-1,i=n.length;if(1===arguments.length){for(;++u=e);)e=void 0;for(;++ue&&(e=r)}else{for(;++u=e);)e=void 0;for(;++ue&&(e=r)}return e},Go.extent=function(n,t){var e,r,u,i=-1,o=n.length;if(1===arguments.length){for(;++i=e);)e=u=void 0;for(;++ir&&(e=r),r>u&&(u=r))}else{for(;++i=e);)e=void 0;for(;++ir&&(e=r),r>u&&(u=r))}return[e,u]},Go.sum=function(n,t){var e,r=0,u=n.length,i=-1;if(1===arguments.length)for(;++i1&&(e=e.map(r)),e=e.filter(t),e.length?Go.quantile(e.sort(n),.5):void 0};var la=e(n);Go.bisectLeft=la.left,Go.bisect=Go.bisectRight=la.right,Go.bisector=function(t){return e(1===t.length?function(e,r){return n(t(e),r)}:t)},Go.shuffle=function(n){for(var t,e,r=n.length;r;)e=0|Math.random()*r--,t=n[r],n[r]=n[e],n[e]=t;return n},Go.permute=function(n,t){for(var e=t.length,r=new Array(e);e--;)r[e]=n[t[e]];return r},Go.pairs=function(n){for(var t,e=0,r=n.length-1,u=n[0],i=new Array(0>r?0:r);r>e;)i[e]=[t=u,u=n[++e]];return i},Go.zip=function(){if(!(u=arguments.length))return[];for(var n=-1,t=Go.min(arguments,r),e=new Array(t);++n=0;)for(r=n[u],t=r.length;--t>=0;)e[--o]=r[t];return e};var fa=Math.abs;Go.range=function(n,t,e){if(arguments.length<3&&(e=1,arguments.length<2&&(t=n,n=0)),1/0===(t-n)/e)throw new Error("infinite range");var r,i=[],o=u(fa(e)),a=-1;if(n*=o,t*=o,e*=o,0>e)for(;(r=n+e*++a)>t;)i.push(r/o);else for(;(r=n+e*++a)=i.length)return r?r.call(u,a):e?a.sort(e):a;for(var s,l,f,h,g=-1,p=a.length,v=i[c++],d=new o;++g=i.length)return n;var r=[],u=a[e++];return n.forEach(function(n,u){r.push({key:n,values:t(u,e)})}),u?r.sort(function(n,t){return u(n.key,t.key)}):r}var e,r,u={},i=[],a=[];return u.map=function(t,e){return n(e,t,0)},u.entries=function(e){return t(n(Go.map,e,0),0)},u.key=function(n){return i.push(n),u},u.sortKeys=function(n){return a[i.length-1]=n,u},u.sortValues=function(n){return e=n,u},u.rollup=function(n){return r=n,u},u},Go.set=function(n){var t=new h;if(n)for(var e=0,r=n.length;r>e;++e)t.add(n[e]);return t},i(h,{has:a,add:function(n){return this[ha+n]=!0,n},remove:function(n){return n=ha+n,n in this&&delete this[n]},values:s,size:l,empty:f,forEach:function(n){for(var t in this)t.charCodeAt(0)===ga&&n.call(this,t.substring(1))}}),Go.behavior={},Go.rebind=function(n,t){for(var e,r=1,u=arguments.length;++r=0&&(r=n.substring(e+1),n=n.substring(0,e)),n)return arguments.length<2?this[n].on(r):this[n].on(r,t);if(2===arguments.length){if(null==t)for(n in this)this.hasOwnProperty(n)&&this[n].on(r,null);return this}},Go.event=null,Go.requote=function(n){return n.replace(va,"\\$&")};var va=/[\\\^\$\*\+\?\|\[\]\(\)\.\{\}]/g,da={}.__proto__?function(n,t){n.__proto__=t}:function(n,t){for(var e in t)n[e]=t[e]},ma=function(n,t){return t.querySelector(n)},ya=function(n,t){return t.querySelectorAll(n)},xa=ta[p(ta,"matchesSelector")],Ma=function(n,t){return xa.call(n,t)};"function"==typeof Sizzle&&(ma=function(n,t){return Sizzle(n,t)[0]||null},ya=Sizzle,Ma=Sizzle.matchesSelector),Go.selection=function(){return Sa};var _a=Go.selection.prototype=[];_a.select=function(n){var t,e,r,u,i=[];n=b(n);for(var o=-1,a=this.length;++o=0&&(e=n.substring(0,t),n=n.substring(t+1)),ba.hasOwnProperty(e)?{space:ba[e],local:n}:n}},_a.attr=function(n,t){if(arguments.length<2){if("string"==typeof n){var e=this.node();return n=Go.ns.qualify(n),n.local?e.getAttributeNS(n.space,n.local):e.getAttribute(n)}for(t in n)this.each(S(t,n[t]));return this}return this.each(S(n,t))},_a.classed=function(n,t){if(arguments.length<2){if("string"==typeof n){var e=this.node(),r=(n=A(n)).length,u=-1;if(t=e.classList){for(;++ur){if("string"!=typeof n){2>r&&(t="");for(e in n)this.each(L(e,n[e],t));return this}if(2>r)return ea.getComputedStyle(this.node(),null).getPropertyValue(n);e=""}return this.each(L(n,t,e))},_a.property=function(n,t){if(arguments.length<2){if("string"==typeof n)return this.node()[n];for(t in n)this.each(T(t,n[t]));return this}return this.each(T(n,t))},_a.text=function(n){return arguments.length?this.each("function"==typeof n?function(){var t=n.apply(this,arguments);this.textContent=null==t?"":t}:null==n?function(){this.textContent=""}:function(){this.textContent=n}):this.node().textContent},_a.html=function(n){return arguments.length?this.each("function"==typeof n?function(){var t=n.apply(this,arguments);this.innerHTML=null==t?"":t}:null==n?function(){this.innerHTML=""}:function(){this.innerHTML=n}):this.node().innerHTML},_a.append=function(n){return n=q(n),this.select(function(){return this.appendChild(n.apply(this,arguments))})},_a.insert=function(n,t){return n=q(n),t=b(t),this.select(function(){return this.insertBefore(n.apply(this,arguments),t.apply(this,arguments)||null)})},_a.remove=function(){return this.each(function(){var n=this.parentNode;n&&n.removeChild(this)})},_a.data=function(n,t){function e(n,e){var r,u,i,a=n.length,f=e.length,h=Math.min(a,f),g=new Array(f),p=new Array(f),v=new Array(a);if(t){var d,m=new o,y=new o,x=[];for(r=-1;++rr;++r)p[r]=z(e[r]);for(;a>r;++r)v[r]=n[r]}p.update=g,p.parentNode=g.parentNode=v.parentNode=n.parentNode,c.push(p),s.push(g),l.push(v)}var r,u,i=-1,a=this.length;if(!arguments.length){for(n=new Array(a=(r=this[0]).length);++ii;i++){u.push(t=[]),t.parentNode=(e=this[i]).parentNode;for(var a=0,c=e.length;c>a;a++)(r=e[a])&&n.call(r,r.__data__,a,i)&&t.push(r)}return _(u)},_a.order=function(){for(var n=-1,t=this.length;++n=0;)(e=r[u])&&(i&&i!==e.nextSibling&&i.parentNode.insertBefore(e,i),i=e);return this},_a.sort=function(n){n=D.apply(this,arguments);for(var t=-1,e=this.length;++tn;n++)for(var e=this[n],r=0,u=e.length;u>r;r++){var i=e[r];if(i)return i}return null},_a.size=function(){var n=0;return this.each(function(){++n}),n};var wa=[];Go.selection.enter=U,Go.selection.enter.prototype=wa,wa.append=_a.append,wa.empty=_a.empty,wa.node=_a.node,wa.call=_a.call,wa.size=_a.size,wa.select=function(n){for(var t,e,r,u,i,o=[],a=-1,c=this.length;++ar){if("string"!=typeof n){2>r&&(t=!1);for(e in n)this.each(F(e,n[e],t));return this}if(2>r)return(r=this.node()["__on"+n])&&r._;e=!1}return this.each(F(n,t,e))};var ka=Go.map({mouseenter:"mouseover",mouseleave:"mouseout"});ka.forEach(function(n){"on"+n in na&&ka.remove(n)});var Ea="onselectstart"in na?null:p(ta.style,"userSelect"),Aa=0;Go.mouse=function(n){return Z(n,x())},Go.touches=function(n,t){return arguments.length<2&&(t=x().touches),t?Qo(t).map(function(t){var e=Z(n,t);return e.identifier=t.identifier,e}):[]},Go.behavior.drag=function(){function n(){this.on("mousedown.drag",u).on("touchstart.drag",i)}function t(n,t,u,i,o){return function(){function a(){var n,e,r=t(h,v);r&&(n=r[0]-x[0],e=r[1]-x[1],p|=n|e,x=r,g({type:"drag",x:r[0]+s[0],y:r[1]+s[1],dx:n,dy:e}))}function c(){t(h,v)&&(m.on(i+d,null).on(o+d,null),y(p&&Go.event.target===f),g({type:"dragend"}))}var s,l=this,f=Go.event.target,h=l.parentNode,g=e.of(l,arguments),p=0,v=n(),d=".drag"+(null==v?"":"-"+v),m=Go.select(u()).on(i+d,a).on(o+d,c),y=Y(),x=t(h,v);r?(s=r.apply(l,arguments),s=[s.x-x[0],s.y-x[1]]):s=[0,0],g({type:"dragstart"})}}var e=M(n,"drag","dragstart","dragend"),r=null,u=t(v,Go.mouse,X,"mousemove","mouseup"),i=t(V,Go.touch,$,"touchmove","touchend");return n.origin=function(t){return arguments.length?(r=t,n):r},Go.rebind(n,e,"on")};var Ca=Math.PI,Na=2*Ca,La=Ca/2,Ta=1e-6,qa=Ta*Ta,za=Ca/180,Ra=180/Ca,Da=Math.SQRT2,Pa=2,Ua=4;Go.interpolateZoom=function(n,t){function e(n){var t=n*y;if(m){var e=Q(v),o=i/(Pa*h)*(e*nt(Da*t+v)-K(v));return[r+o*s,u+o*l,i*e/Q(Da*t+v)]}return[r+n*s,u+n*l,i*Math.exp(Da*t)]}var r=n[0],u=n[1],i=n[2],o=t[0],a=t[1],c=t[2],s=o-r,l=a-u,f=s*s+l*l,h=Math.sqrt(f),g=(c*c-i*i+Ua*f)/(2*i*Pa*h),p=(c*c-i*i-Ua*f)/(2*c*Pa*h),v=Math.log(Math.sqrt(g*g+1)-g),d=Math.log(Math.sqrt(p*p+1)-p),m=d-v,y=(m||Math.log(c/i))/Da;return e.duration=1e3*y,e},Go.behavior.zoom=function(){function n(n){n.on(A,s).on(Fa+".zoom",f).on(C,h).on("dblclick.zoom",g).on(L,l)}function t(n){return[(n[0]-S.x)/S.k,(n[1]-S.y)/S.k]}function e(n){return[n[0]*S.k+S.x,n[1]*S.k+S.y]}function r(n){S.k=Math.max(E[0],Math.min(E[1],n))}function u(n,t){t=e(t),S.x+=n[0]-t[0],S.y+=n[1]-t[1]}function i(){_&&_.domain(x.range().map(function(n){return(n-S.x)/S.k}).map(x.invert)),w&&w.domain(b.range().map(function(n){return(n-S.y)/S.k}).map(b.invert))}function o(n){n({type:"zoomstart"})}function a(n){i(),n({type:"zoom",scale:S.k,translate:[S.x,S.y]})}function c(n){n({type:"zoomend"})}function s(){function n(){l=1,u(Go.mouse(r),g),a(s)}function e(){f.on(C,ea===r?h:null).on(N,null),p(l&&Go.event.target===i),c(s)}var r=this,i=Go.event.target,s=T.of(r,arguments),l=0,f=Go.select(ea).on(C,n).on(N,e),g=t(Go.mouse(r)),p=Y();H.call(r),o(s)}function l(){function n(){var n=Go.touches(g);return h=S.k,n.forEach(function(n){n.identifier in v&&(v[n.identifier]=t(n))}),n}function e(){for(var t=Go.event.changedTouches,e=0,i=t.length;i>e;++e)v[t[e].identifier]=null;var o=n(),c=Date.now();if(1===o.length){if(500>c-m){var s=o[0],l=v[s.identifier];r(2*S.k),u(s,l),y(),a(p)}m=c}else if(o.length>1){var s=o[0],f=o[1],h=s[0]-f[0],g=s[1]-f[1];d=h*h+g*g}}function i(){for(var n,t,e,i,o=Go.touches(g),c=0,s=o.length;s>c;++c,i=null)if(e=o[c],i=v[e.identifier]){if(t)break;n=e,t=i}if(i){var l=(l=e[0]-n[0])*l+(l=e[1]-n[1])*l,f=d&&Math.sqrt(l/d);n=[(n[0]+e[0])/2,(n[1]+e[1])/2],t=[(t[0]+i[0])/2,(t[1]+i[1])/2],r(f*h)}m=null,u(n,t),a(p)}function f(){if(Go.event.touches.length){for(var t=Go.event.changedTouches,e=0,r=t.length;r>e;++e)delete v[t[e].identifier];for(var u in v)return void n()}b.on(x,null),w.on(A,s).on(L,l),k(),c(p)}var h,g=this,p=T.of(g,arguments),v={},d=0,x=".zoom-"+Go.event.changedTouches[0].identifier,M="touchmove"+x,_="touchend"+x,b=Go.select(Go.event.target).on(M,i).on(_,f),w=Go.select(g).on(A,null).on(L,e),k=Y();H.call(g),e(),o(p)}function f(){var n=T.of(this,arguments);d?clearTimeout(d):(H.call(this),o(n)),d=setTimeout(function(){d=null,c(n)},50),y();var e=v||Go.mouse(this);p||(p=t(e)),r(Math.pow(2,.002*ja())*S.k),u(e,p),a(n)}function h(){p=null}function g(){var n=T.of(this,arguments),e=Go.mouse(this),i=t(e),s=Math.log(S.k)/Math.LN2;o(n),r(Math.pow(2,Go.event.shiftKey?Math.ceil(s)-1:Math.floor(s)+1)),u(e,i),a(n),c(n)}var p,v,d,m,x,_,b,w,S={x:0,y:0,k:1},k=[960,500],E=Ha,A="mousedown.zoom",C="mousemove.zoom",N="mouseup.zoom",L="touchstart.zoom",T=M(n,"zoomstart","zoom","zoomend");return n.event=function(n){n.each(function(){var n=T.of(this,arguments),t=S;Ls?Go.select(this).transition().each("start.zoom",function(){S=this.__chart__||{x:0,y:0,k:1},o(n)}).tween("zoom:zoom",function(){var e=k[0],r=k[1],u=e/2,i=r/2,o=Go.interpolateZoom([(u-S.x)/S.k,(i-S.y)/S.k,e/S.k],[(u-t.x)/t.k,(i-t.y)/t.k,e/t.k]);return function(t){var r=o(t),c=e/r[2];this.__chart__=S={x:u-r[0]*c,y:i-r[1]*c,k:c},a(n)}}).each("end.zoom",function(){c(n)}):(this.__chart__=S,o(n),a(n),c(n))})},n.translate=function(t){return arguments.length?(S={x:+t[0],y:+t[1],k:S.k},i(),n):[S.x,S.y]},n.scale=function(t){return arguments.length?(S={x:S.x,y:S.y,k:+t},i(),n):S.k},n.scaleExtent=function(t){return arguments.length?(E=null==t?Ha:[+t[0],+t[1]],n):E},n.center=function(t){return arguments.length?(v=t&&[+t[0],+t[1]],n):v},n.size=function(t){return arguments.length?(k=t&&[+t[0],+t[1]],n):k},n.x=function(t){return arguments.length?(_=t,x=t.copy(),S={x:0,y:0,k:1},n):_},n.y=function(t){return arguments.length?(w=t,b=t.copy(),S={x:0,y:0,k:1},n):w},Go.rebind(n,T,"on")};var ja,Ha=[0,1/0],Fa="onwheel"in na?(ja=function(){return-Go.event.deltaY*(Go.event.deltaMode?120:1)},"wheel"):"onmousewheel"in na?(ja=function(){return Go.event.wheelDelta},"mousewheel"):(ja=function(){return-Go.event.detail},"MozMousePixelScroll");et.prototype.toString=function(){return this.rgb()+""},Go.hsl=function(n,t,e){return 1===arguments.length?n instanceof ut?rt(n.h,n.s,n.l):_t(""+n,bt,rt):rt(+n,+t,+e)};var Oa=ut.prototype=new et;Oa.brighter=function(n){return n=Math.pow(.7,arguments.length?n:1),rt(this.h,this.s,this.l/n)},Oa.darker=function(n){return n=Math.pow(.7,arguments.length?n:1),rt(this.h,this.s,n*this.l)},Oa.rgb=function(){return it(this.h,this.s,this.l)},Go.hcl=function(n,t,e){return 1===arguments.length?n instanceof at?ot(n.h,n.c,n.l):n instanceof lt?ht(n.l,n.a,n.b):ht((n=wt((n=Go.rgb(n)).r,n.g,n.b)).l,n.a,n.b):ot(+n,+t,+e)};var Ia=at.prototype=new et;Ia.brighter=function(n){return ot(this.h,this.c,Math.min(100,this.l+Ya*(arguments.length?n:1)))},Ia.darker=function(n){return ot(this.h,this.c,Math.max(0,this.l-Ya*(arguments.length?n:1)))},Ia.rgb=function(){return ct(this.h,this.c,this.l).rgb()},Go.lab=function(n,t,e){return 1===arguments.length?n instanceof lt?st(n.l,n.a,n.b):n instanceof at?ct(n.l,n.c,n.h):wt((n=Go.rgb(n)).r,n.g,n.b):st(+n,+t,+e)};var Ya=18,Za=.95047,Va=1,$a=1.08883,Xa=lt.prototype=new et;Xa.brighter=function(n){return st(Math.min(100,this.l+Ya*(arguments.length?n:1)),this.a,this.b)},Xa.darker=function(n){return st(Math.max(0,this.l-Ya*(arguments.length?n:1)),this.a,this.b)},Xa.rgb=function(){return ft(this.l,this.a,this.b)},Go.rgb=function(n,t,e){return 1===arguments.length?n instanceof xt?yt(n.r,n.g,n.b):_t(""+n,yt,it):yt(~~n,~~t,~~e)};var Ba=xt.prototype=new et;Ba.brighter=function(n){n=Math.pow(.7,arguments.length?n:1);var t=this.r,e=this.g,r=this.b,u=30;return t||e||r?(t&&u>t&&(t=u),e&&u>e&&(e=u),r&&u>r&&(r=u),yt(Math.min(255,~~(t/n)),Math.min(255,~~(e/n)),Math.min(255,~~(r/n)))):yt(u,u,u)},Ba.darker=function(n){return n=Math.pow(.7,arguments.length?n:1),yt(~~(n*this.r),~~(n*this.g),~~(n*this.b))},Ba.hsl=function(){return bt(this.r,this.g,this.b)},Ba.toString=function(){return"#"+Mt(this.r)+Mt(this.g)+Mt(this.b)};var Ja=Go.map({aliceblue:15792383,antiquewhite:16444375,aqua:65535,aquamarine:8388564,azure:15794175,beige:16119260,bisque:16770244,black:0,blanchedalmond:16772045,blue:255,blueviolet:9055202,brown:10824234,burlywood:14596231,cadetblue:6266528,chartreuse:8388352,chocolate:13789470,coral:16744272,cornflowerblue:6591981,cornsilk:16775388,crimson:14423100,cyan:65535,darkblue:139,darkcyan:35723,darkgoldenrod:12092939,darkgray:11119017,darkgreen:25600,darkgrey:11119017,darkkhaki:12433259,darkmagenta:9109643,darkolivegreen:5597999,darkorange:16747520,darkorchid:10040012,darkred:9109504,darksalmon:15308410,darkseagreen:9419919,darkslateblue:4734347,darkslategray:3100495,darkslategrey:3100495,darkturquoise:52945,darkviolet:9699539,deeppink:16716947,deepskyblue:49151,dimgray:6908265,dimgrey:6908265,dodgerblue:2003199,firebrick:11674146,floralwhite:16775920,forestgreen:2263842,fuchsia:16711935,gainsboro:14474460,ghostwhite:16316671,gold:16766720,goldenrod:14329120,gray:8421504,green:32768,greenyellow:11403055,grey:8421504,honeydew:15794160,hotpink:16738740,indianred:13458524,indigo:4915330,ivory:16777200,khaki:15787660,lavender:15132410,lavenderblush:16773365,lawngreen:8190976,lemonchiffon:16775885,lightblue:11393254,lightcoral:15761536,lightcyan:14745599,lightgoldenrodyellow:16448210,lightgray:13882323,lightgreen:9498256,lightgrey:13882323,lightpink:16758465,lightsalmon:16752762,lightseagreen:2142890,lightskyblue:8900346,lightslategray:7833753,lightslategrey:7833753,lightsteelblue:11584734,lightyellow:16777184,lime:65280,limegreen:3329330,linen:16445670,magenta:16711935,maroon:8388608,mediumaquamarine:6737322,mediumblue:205,mediumorchid:12211667,mediumpurple:9662683,mediumseagreen:3978097,mediumslateblue:8087790,mediumspringgreen:64154,mediumturquoise:4772300,mediumvioletred:13047173,midnightblue:1644912,mintcream:16121850,mistyrose:16770273,moccasin:16770229,navajowhite:16768685,navy:128,oldlace:16643558,olive:8421376,olivedrab:7048739,orange:16753920,orangered:16729344,orchid:14315734,palegoldenrod:15657130,palegreen:10025880,paleturquoise:11529966,palevioletred:14381203,papayawhip:16773077,peachpuff:16767673,peru:13468991,pink:16761035,plum:14524637,powderblue:11591910,purple:8388736,red:16711680,rosybrown:12357519,royalblue:4286945,saddlebrown:9127187,salmon:16416882,sandybrown:16032864,seagreen:3050327,seashell:16774638,sienna:10506797,silver:12632256,skyblue:8900331,slateblue:6970061,slategray:7372944,slategrey:7372944,snow:16775930,springgreen:65407,steelblue:4620980,tan:13808780,teal:32896,thistle:14204888,tomato:16737095,turquoise:4251856,violet:15631086,wheat:16113331,white:16777215,whitesmoke:16119285,yellow:16776960,yellowgreen:10145074});Ja.forEach(function(n,t){Ja.set(n,dt(t))}),Go.functor=Et,Go.xhr=Ct(At),Go.dsv=function(n,t){function e(n,e,i){arguments.length<3&&(i=e,e=null);var o=Nt(n,t,null==e?r:u(e),i);return o.row=function(n){return arguments.length?o.response(null==(e=n)?r:u(n)):e},o}function r(n){return e.parse(n.responseText)}function u(n){return function(t){return e.parse(t.responseText,n)}}function i(t){return t.map(o).join(n)}function o(n){return a.test(n)?'"'+n.replace(/\"/g,'""')+'"':n}var a=new RegExp('["'+n+"\n]"),c=n.charCodeAt(0);return e.parse=function(n,t){var r;return e.parseRows(n,function(n,e){if(r)return r(n,e-1);var u=new Function("d","return {"+n.map(function(n,t){return JSON.stringify(n)+": d["+t+"]"}).join(",")+"}");r=t?function(n,e){return t(u(n),e)}:u})},e.parseRows=function(n,t){function e(){if(l>=s)return o;if(u)return u=!1,i;var t=l;if(34===n.charCodeAt(t)){for(var e=t;e++l;){var r=n.charCodeAt(l++),a=1;if(10===r)u=!0;else if(13===r)u=!0,10===n.charCodeAt(l)&&(++l,++a);else if(r!==c)continue;return n.substring(t,l-a)}return n.substring(t)}for(var r,u,i={},o={},a=[],s=n.length,l=0,f=0;(r=e())!==o;){for(var h=[];r!==i&&r!==o;)h.push(r),r=e();(!t||(h=t(h,f++)))&&a.push(h)}return a},e.format=function(t){if(Array.isArray(t[0]))return e.formatRows(t);var r=new h,u=[];return t.forEach(function(n){for(var t in n)r.has(t)||u.push(r.add(t))}),[u.map(o).join(n)].concat(t.map(function(t){return u.map(function(n){return o(t[n])}).join(n)})).join("\n")},e.formatRows=function(n){return n.map(i).join("\n")},e},Go.csv=Go.dsv(",","text/csv"),Go.tsv=Go.dsv(" ","text/tab-separated-values"),Go.touch=function(n,t,e){if(arguments.length<3&&(e=t,t=x().changedTouches),t)for(var r,u=0,i=t.length;i>u;++u)if((r=t[u]).identifier===e)return Z(n,r)};var Wa,Ga,Ka,Qa,nc,tc=ea[p(ea,"requestAnimationFrame")]||function(n){setTimeout(n,17)};Go.timer=function(n,t,e){var r=arguments.length;2>r&&(t=0),3>r&&(e=Date.now());var u=e+t,i={c:n,t:u,f:!1,n:null};Ga?Ga.n=i:Wa=i,Ga=i,Ka||(Qa=clearTimeout(Qa),Ka=1,tc(Tt))},Go.timer.flush=function(){qt(),zt()},Go.round=function(n,t){return t?Math.round(n*(t=Math.pow(10,t)))/t:Math.round(n)};var ec=["y","z","a","f","p","n","\xb5","m","","k","M","G","T","P","E","Z","Y"].map(Dt);Go.formatPrefix=function(n,t){var e=0;return n&&(0>n&&(n*=-1),t&&(n=Go.round(n,Rt(n,t))),e=1+Math.floor(1e-12+Math.log(n)/Math.LN10),e=Math.max(-24,Math.min(24,3*Math.floor((e-1)/3)))),ec[8+e/3]};var rc=/(?:([^{])?([<>=^]))?([+\- ])?([$#])?(0)?(\d+)?(,)?(\.-?\d+)?([a-z%])?/i,uc=Go.map({b:function(n){return n.toString(2)},c:function(n){return String.fromCharCode(n)},o:function(n){return n.toString(8)},x:function(n){return n.toString(16)},X:function(n){return n.toString(16).toUpperCase()},g:function(n,t){return n.toPrecision(t)},e:function(n,t){return n.toExponential(t)},f:function(n,t){return n.toFixed(t)},r:function(n,t){return(n=Go.round(n,Rt(n,t))).toFixed(Math.max(0,Math.min(20,Rt(n*(1+1e-15),t))))}}),ic=Go.time={},oc=Date;jt.prototype={getDate:function(){return this._.getUTCDate()},getDay:function(){return this._.getUTCDay()},getFullYear:function(){return this._.getUTCFullYear()},getHours:function(){return this._.getUTCHours()},getMilliseconds:function(){return this._.getUTCMilliseconds()},getMinutes:function(){return this._.getUTCMinutes()},getMonth:function(){return this._.getUTCMonth()},getSeconds:function(){return this._.getUTCSeconds()},getTime:function(){return this._.getTime()},getTimezoneOffset:function(){return 0},valueOf:function(){return this._.valueOf()},setDate:function(){ac.setUTCDate.apply(this._,arguments)},setDay:function(){ac.setUTCDay.apply(this._,arguments)},setFullYear:function(){ac.setUTCFullYear.apply(this._,arguments)},setHours:function(){ac.setUTCHours.apply(this._,arguments)},setMilliseconds:function(){ac.setUTCMilliseconds.apply(this._,arguments)},setMinutes:function(){ac.setUTCMinutes.apply(this._,arguments)},setMonth:function(){ac.setUTCMonth.apply(this._,arguments)},setSeconds:function(){ac.setUTCSeconds.apply(this._,arguments)},setTime:function(){ac.setTime.apply(this._,arguments)}};var ac=Date.prototype;ic.year=Ht(function(n){return n=ic.day(n),n.setMonth(0,1),n},function(n,t){n.setFullYear(n.getFullYear()+t)},function(n){return n.getFullYear()}),ic.years=ic.year.range,ic.years.utc=ic.year.utc.range,ic.day=Ht(function(n){var t=new oc(2e3,0);return t.setFullYear(n.getFullYear(),n.getMonth(),n.getDate()),t},function(n,t){n.setDate(n.getDate()+t)},function(n){return n.getDate()-1}),ic.days=ic.day.range,ic.days.utc=ic.day.utc.range,ic.dayOfYear=function(n){var t=ic.year(n);return Math.floor((n-t-6e4*(n.getTimezoneOffset()-t.getTimezoneOffset()))/864e5)},["sunday","monday","tuesday","wednesday","thursday","friday","saturday"].forEach(function(n,t){t=7-t;var e=ic[n]=Ht(function(n){return(n=ic.day(n)).setDate(n.getDate()-(n.getDay()+t)%7),n},function(n,t){n.setDate(n.getDate()+7*Math.floor(t))},function(n){var e=ic.year(n).getDay();return Math.floor((ic.dayOfYear(n)+(e+t)%7)/7)-(e!==t)});ic[n+"s"]=e.range,ic[n+"s"].utc=e.utc.range,ic[n+"OfYear"]=function(n){var e=ic.year(n).getDay();return Math.floor((ic.dayOfYear(n)+(e+t)%7)/7)}}),ic.week=ic.sunday,ic.weeks=ic.sunday.range,ic.weeks.utc=ic.sunday.utc.range,ic.weekOfYear=ic.sundayOfYear;var cc={"-":"",_:" ",0:"0"},sc=/^\s*\d+/,lc=/^%/;Go.locale=function(n){return{numberFormat:Pt(n),timeFormat:Ot(n)}};var fc=Go.locale({decimal:".",thousands:",",grouping:[3],currency:["$",""],dateTime:"%a %b %e %X %Y",date:"%m/%d/%Y",time:"%H:%M:%S",periods:["AM","PM"],days:["Sunday","Monday","Tuesday","Wednesday","Thursday","Friday","Saturday"],shortDays:["Sun","Mon","Tue","Wed","Thu","Fri","Sat"],months:["January","February","March","April","May","June","July","August","September","October","November","December"],shortMonths:["Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"]});Go.format=fc.numberFormat,Go.geo={},ce.prototype={s:0,t:0,add:function(n){se(n,this.t,hc),se(hc.s,this.s,this),this.s?this.t+=hc.t:this.s=hc.t},reset:function(){this.s=this.t=0},valueOf:function(){return this.s}};var hc=new ce;Go.geo.stream=function(n,t){n&&gc.hasOwnProperty(n.type)?gc[n.type](n,t):le(n,t)};var gc={Feature:function(n,t){le(n.geometry,t)},FeatureCollection:function(n,t){for(var e=n.features,r=-1,u=e.length;++rn?4*Ca+n:n,mc.lineStart=mc.lineEnd=mc.point=v}};Go.geo.bounds=function(){function n(n,t){x.push(M=[l=n,h=n]),f>t&&(f=t),t>g&&(g=t)}function t(t,e){var r=pe([t*za,e*za]);if(m){var u=de(m,r),i=[u[1],-u[0],0],o=de(i,u);xe(o),o=Me(o);var c=t-p,s=c>0?1:-1,v=o[0]*Ra*s,d=fa(c)>180;if(d^(v>s*p&&s*t>v)){var y=o[1]*Ra;y>g&&(g=y)}else if(v=(v+360)%360-180,d^(v>s*p&&s*t>v)){var y=-o[1]*Ra;f>y&&(f=y)}else f>e&&(f=e),e>g&&(g=e);d?p>t?a(l,t)>a(l,h)&&(h=t):a(t,h)>a(l,h)&&(l=t):h>=l?(l>t&&(l=t),t>h&&(h=t)):t>p?a(l,t)>a(l,h)&&(h=t):a(t,h)>a(l,h)&&(l=t)}else n(t,e);m=r,p=t}function e(){_.point=t}function r(){M[0]=l,M[1]=h,_.point=n,m=null}function u(n,e){if(m){var r=n-p;y+=fa(r)>180?r+(r>0?360:-360):r}else v=n,d=e;mc.point(n,e),t(n,e)}function i(){mc.lineStart()}function o(){u(v,d),mc.lineEnd(),fa(y)>Ta&&(l=-(h=180)),M[0]=l,M[1]=h,m=null}function a(n,t){return(t-=n)<0?t+360:t}function c(n,t){return n[0]-t[0]}function s(n,t){return t[0]<=t[1]?t[0]<=n&&n<=t[1]:ndc?(l=-(h=180),f=-(g=90)):y>Ta?g=90:-Ta>y&&(f=-90),M[0]=l,M[1]=h}};return function(n){g=h=-(l=f=1/0),x=[],Go.geo.stream(n,_);var t=x.length;if(t){x.sort(c);for(var e,r=1,u=x[0],i=[u];t>r;++r)e=x[r],s(e[0],u)||s(e[1],u)?(a(u[0],e[1])>a(u[0],u[1])&&(u[1]=e[1]),a(e[0],u[1])>a(u[0],u[1])&&(u[0]=e[0])):i.push(u=e); +for(var o,e,p=-1/0,t=i.length-1,r=0,u=i[t];t>=r;u=e,++r)e=i[r],(o=a(u[1],e[0]))>p&&(p=o,l=e[0],h=u[1])}return x=M=null,1/0===l||1/0===f?[[0/0,0/0],[0/0,0/0]]:[[l,f],[h,g]]}}(),Go.geo.centroid=function(n){yc=xc=Mc=_c=bc=wc=Sc=kc=Ec=Ac=Cc=0,Go.geo.stream(n,Nc);var t=Ec,e=Ac,r=Cc,u=t*t+e*e+r*r;return qa>u&&(t=wc,e=Sc,r=kc,Ta>xc&&(t=Mc,e=_c,r=bc),u=t*t+e*e+r*r,qa>u)?[0/0,0/0]:[Math.atan2(e,t)*Ra,G(r/Math.sqrt(u))*Ra]};var yc,xc,Mc,_c,bc,wc,Sc,kc,Ec,Ac,Cc,Nc={sphere:v,point:be,lineStart:Se,lineEnd:ke,polygonStart:function(){Nc.lineStart=Ee},polygonEnd:function(){Nc.lineStart=Se}},Lc=Te(Ae,Pe,je,[-Ca,-Ca/2]),Tc=1e9;Go.geo.clipExtent=function(){var n,t,e,r,u,i,o={stream:function(n){return u&&(u.valid=!1),u=i(n),u.valid=!0,u},extent:function(a){return arguments.length?(i=Oe(n=+a[0][0],t=+a[0][1],e=+a[1][0],r=+a[1][1]),u&&(u.valid=!1,u=null),o):[[n,t],[e,r]]}};return o.extent([[0,0],[960,500]])},(Go.geo.conicEqualArea=function(){return Ye(Ze)}).raw=Ze,Go.geo.albers=function(){return Go.geo.conicEqualArea().rotate([96,0]).center([-.6,38.7]).parallels([29.5,45.5]).scale(1070)},Go.geo.albersUsa=function(){function n(n){var i=n[0],o=n[1];return t=null,e(i,o),t||(r(i,o),t)||u(i,o),t}var t,e,r,u,i=Go.geo.albers(),o=Go.geo.conicEqualArea().rotate([154,0]).center([-2,58.5]).parallels([55,65]),a=Go.geo.conicEqualArea().rotate([157,0]).center([-3,19.9]).parallels([8,18]),c={point:function(n,e){t=[n,e]}};return n.invert=function(n){var t=i.scale(),e=i.translate(),r=(n[0]-e[0])/t,u=(n[1]-e[1])/t;return(u>=.12&&.234>u&&r>=-.425&&-.214>r?o:u>=.166&&.234>u&&r>=-.214&&-.115>r?a:i).invert(n)},n.stream=function(n){var t=i.stream(n),e=o.stream(n),r=a.stream(n);return{point:function(n,u){t.point(n,u),e.point(n,u),r.point(n,u)},sphere:function(){t.sphere(),e.sphere(),r.sphere()},lineStart:function(){t.lineStart(),e.lineStart(),r.lineStart()},lineEnd:function(){t.lineEnd(),e.lineEnd(),r.lineEnd()},polygonStart:function(){t.polygonStart(),e.polygonStart(),r.polygonStart()},polygonEnd:function(){t.polygonEnd(),e.polygonEnd(),r.polygonEnd()}}},n.precision=function(t){return arguments.length?(i.precision(t),o.precision(t),a.precision(t),n):i.precision()},n.scale=function(t){return arguments.length?(i.scale(t),o.scale(.35*t),a.scale(t),n.translate(i.translate())):i.scale()},n.translate=function(t){if(!arguments.length)return i.translate();var s=i.scale(),l=+t[0],f=+t[1];return e=i.translate(t).clipExtent([[l-.455*s,f-.238*s],[l+.455*s,f+.238*s]]).stream(c).point,r=o.translate([l-.307*s,f+.201*s]).clipExtent([[l-.425*s+Ta,f+.12*s+Ta],[l-.214*s-Ta,f+.234*s-Ta]]).stream(c).point,u=a.translate([l-.205*s,f+.212*s]).clipExtent([[l-.214*s+Ta,f+.166*s+Ta],[l-.115*s-Ta,f+.234*s-Ta]]).stream(c).point,n},n.scale(1070)};var qc,zc,Rc,Dc,Pc,Uc,jc={point:v,lineStart:v,lineEnd:v,polygonStart:function(){zc=0,jc.lineStart=Ve},polygonEnd:function(){jc.lineStart=jc.lineEnd=jc.point=v,qc+=fa(zc/2)}},Hc={point:$e,lineStart:v,lineEnd:v,polygonStart:v,polygonEnd:v},Fc={point:Je,lineStart:We,lineEnd:Ge,polygonStart:function(){Fc.lineStart=Ke},polygonEnd:function(){Fc.point=Je,Fc.lineStart=We,Fc.lineEnd=Ge}};Go.geo.path=function(){function n(n){return n&&("function"==typeof a&&i.pointRadius(+a.apply(this,arguments)),o&&o.valid||(o=u(i)),Go.geo.stream(n,o)),i.result()}function t(){return o=null,n}var e,r,u,i,o,a=4.5;return n.area=function(n){return qc=0,Go.geo.stream(n,u(jc)),qc},n.centroid=function(n){return Mc=_c=bc=wc=Sc=kc=Ec=Ac=Cc=0,Go.geo.stream(n,u(Fc)),Cc?[Ec/Cc,Ac/Cc]:kc?[wc/kc,Sc/kc]:bc?[Mc/bc,_c/bc]:[0/0,0/0]},n.bounds=function(n){return Pc=Uc=-(Rc=Dc=1/0),Go.geo.stream(n,u(Hc)),[[Rc,Dc],[Pc,Uc]]},n.projection=function(n){return arguments.length?(u=(e=n)?n.stream||tr(n):At,t()):e},n.context=function(n){return arguments.length?(i=null==(r=n)?new Xe:new Qe(n),"function"!=typeof a&&i.pointRadius(a),t()):r},n.pointRadius=function(t){return arguments.length?(a="function"==typeof t?t:(i.pointRadius(+t),+t),n):a},n.projection(Go.geo.albersUsa()).context(null)},Go.geo.transform=function(n){return{stream:function(t){var e=new er(t);for(var r in n)e[r]=n[r];return e}}},er.prototype={point:function(n,t){this.stream.point(n,t)},sphere:function(){this.stream.sphere()},lineStart:function(){this.stream.lineStart()},lineEnd:function(){this.stream.lineEnd()},polygonStart:function(){this.stream.polygonStart()},polygonEnd:function(){this.stream.polygonEnd()}},Go.geo.projection=ur,Go.geo.projectionMutator=ir,(Go.geo.equirectangular=function(){return ur(ar)}).raw=ar.invert=ar,Go.geo.rotation=function(n){function t(t){return t=n(t[0]*za,t[1]*za),t[0]*=Ra,t[1]*=Ra,t}return n=sr(n[0]%360*za,n[1]*za,n.length>2?n[2]*za:0),t.invert=function(t){return t=n.invert(t[0]*za,t[1]*za),t[0]*=Ra,t[1]*=Ra,t},t},cr.invert=ar,Go.geo.circle=function(){function n(){var n="function"==typeof r?r.apply(this,arguments):r,t=sr(-n[0]*za,-n[1]*za,0).invert,u=[];return e(null,null,1,{point:function(n,e){u.push(n=t(n,e)),n[0]*=Ra,n[1]*=Ra}}),{type:"Polygon",coordinates:[u]}}var t,e,r=[0,0],u=6;return n.origin=function(t){return arguments.length?(r=t,n):r},n.angle=function(r){return arguments.length?(e=gr((t=+r)*za,u*za),n):t},n.precision=function(r){return arguments.length?(e=gr(t*za,(u=+r)*za),n):u},n.angle(90)},Go.geo.distance=function(n,t){var e,r=(t[0]-n[0])*za,u=n[1]*za,i=t[1]*za,o=Math.sin(r),a=Math.cos(r),c=Math.sin(u),s=Math.cos(u),l=Math.sin(i),f=Math.cos(i);return Math.atan2(Math.sqrt((e=f*o)*e+(e=s*l-c*f*a)*e),c*l+s*f*a)},Go.geo.graticule=function(){function n(){return{type:"MultiLineString",coordinates:t()}}function t(){return Go.range(Math.ceil(i/d)*d,u,d).map(h).concat(Go.range(Math.ceil(s/m)*m,c,m).map(g)).concat(Go.range(Math.ceil(r/p)*p,e,p).filter(function(n){return fa(n%d)>Ta}).map(l)).concat(Go.range(Math.ceil(a/v)*v,o,v).filter(function(n){return fa(n%m)>Ta}).map(f))}var e,r,u,i,o,a,c,s,l,f,h,g,p=10,v=p,d=90,m=360,y=2.5;return n.lines=function(){return t().map(function(n){return{type:"LineString",coordinates:n}})},n.outline=function(){return{type:"Polygon",coordinates:[h(i).concat(g(c).slice(1),h(u).reverse().slice(1),g(s).reverse().slice(1))]}},n.extent=function(t){return arguments.length?n.majorExtent(t).minorExtent(t):n.minorExtent()},n.majorExtent=function(t){return arguments.length?(i=+t[0][0],u=+t[1][0],s=+t[0][1],c=+t[1][1],i>u&&(t=i,i=u,u=t),s>c&&(t=s,s=c,c=t),n.precision(y)):[[i,s],[u,c]]},n.minorExtent=function(t){return arguments.length?(r=+t[0][0],e=+t[1][0],a=+t[0][1],o=+t[1][1],r>e&&(t=r,r=e,e=t),a>o&&(t=a,a=o,o=t),n.precision(y)):[[r,a],[e,o]]},n.step=function(t){return arguments.length?n.majorStep(t).minorStep(t):n.minorStep()},n.majorStep=function(t){return arguments.length?(d=+t[0],m=+t[1],n):[d,m]},n.minorStep=function(t){return arguments.length?(p=+t[0],v=+t[1],n):[p,v]},n.precision=function(t){return arguments.length?(y=+t,l=vr(a,o,90),f=dr(r,e,y),h=vr(s,c,90),g=dr(i,u,y),n):y},n.majorExtent([[-180,-90+Ta],[180,90-Ta]]).minorExtent([[-180,-80-Ta],[180,80+Ta]])},Go.geo.greatArc=function(){function n(){return{type:"LineString",coordinates:[t||r.apply(this,arguments),e||u.apply(this,arguments)]}}var t,e,r=mr,u=yr;return n.distance=function(){return Go.geo.distance(t||r.apply(this,arguments),e||u.apply(this,arguments))},n.source=function(e){return arguments.length?(r=e,t="function"==typeof e?null:e,n):r},n.target=function(t){return arguments.length?(u=t,e="function"==typeof t?null:t,n):u},n.precision=function(){return arguments.length?n:0},n},Go.geo.interpolate=function(n,t){return xr(n[0]*za,n[1]*za,t[0]*za,t[1]*za)},Go.geo.length=function(n){return Oc=0,Go.geo.stream(n,Ic),Oc};var Oc,Ic={sphere:v,point:v,lineStart:Mr,lineEnd:v,polygonStart:v,polygonEnd:v},Yc=_r(function(n){return Math.sqrt(2/(1+n))},function(n){return 2*Math.asin(n/2)});(Go.geo.azimuthalEqualArea=function(){return ur(Yc)}).raw=Yc;var Zc=_r(function(n){var t=Math.acos(n);return t&&t/Math.sin(t)},At);(Go.geo.azimuthalEquidistant=function(){return ur(Zc)}).raw=Zc,(Go.geo.conicConformal=function(){return Ye(br)}).raw=br,(Go.geo.conicEquidistant=function(){return Ye(wr)}).raw=wr;var Vc=_r(function(n){return 1/n},Math.atan);(Go.geo.gnomonic=function(){return ur(Vc)}).raw=Vc,Sr.invert=function(n,t){return[n,2*Math.atan(Math.exp(t))-La]},(Go.geo.mercator=function(){return kr(Sr)}).raw=Sr;var $c=_r(function(){return 1},Math.asin);(Go.geo.orthographic=function(){return ur($c)}).raw=$c;var Xc=_r(function(n){return 1/(1+n)},function(n){return 2*Math.atan(n)});(Go.geo.stereographic=function(){return ur(Xc)}).raw=Xc,Er.invert=function(n,t){return[-t,2*Math.atan(Math.exp(n))-La]},(Go.geo.transverseMercator=function(){var n=kr(Er),t=n.center,e=n.rotate;return n.center=function(n){return n?t([-n[1],n[0]]):(n=t(),[-n[1],n[0]])},n.rotate=function(n){return n?e([n[0],n[1],n.length>2?n[2]+90:90]):(n=e(),[n[0],n[1],n[2]-90])},n.rotate([0,0])}).raw=Er,Go.geom={},Go.geom.hull=function(n){function t(n){if(n.length<3)return[];var t,u=Et(e),i=Et(r),o=n.length,a=[],c=[];for(t=0;o>t;t++)a.push([+u.call(this,n[t],t),+i.call(this,n[t],t),t]);for(a.sort(Lr),t=0;o>t;t++)c.push([a[t][0],-a[t][1]]);var s=Nr(a),l=Nr(c),f=l[0]===s[0],h=l[l.length-1]===s[s.length-1],g=[];for(t=s.length-1;t>=0;--t)g.push(n[a[s[t]][2]]);for(t=+f;t=r&&s.x<=i&&s.y>=u&&s.y<=o?[[r,o],[i,o],[i,u],[r,u]]:[];l.point=n[a]}),t}function e(n){return n.map(function(n,t){return{x:Math.round(i(n,t)/Ta)*Ta,y:Math.round(o(n,t)/Ta)*Ta,i:t}})}var r=Ar,u=Cr,i=r,o=u,a=es;return n?t(n):(t.links=function(n){return iu(e(n)).edges.filter(function(n){return n.l&&n.r}).map(function(t){return{source:n[t.l.i],target:n[t.r.i]}})},t.triangles=function(n){var t=[];return iu(e(n)).cells.forEach(function(e,r){for(var u,i,o=e.site,a=e.edges.sort(Yr),c=-1,s=a.length,l=a[s-1].edge,f=l.l===o?l.r:l.l;++c=s,h=r>=l,g=(h<<1)+f;n.leaf=!1,n=n.nodes[g]||(n.nodes[g]=lu()),f?u=s:a=s,h?o=l:c=l,i(n,t,e,r,u,o,a,c)}var l,f,h,g,p,v,d,m,y,x=Et(a),M=Et(c);if(null!=t)v=t,d=e,m=r,y=u;else if(m=y=-(v=d=1/0),f=[],h=[],p=n.length,o)for(g=0;p>g;++g)l=n[g],l.xm&&(m=l.x),l.y>y&&(y=l.y),f.push(l.x),h.push(l.y);else for(g=0;p>g;++g){var _=+x(l=n[g],g),b=+M(l,g);v>_&&(v=_),d>b&&(d=b),_>m&&(m=_),b>y&&(y=b),f.push(_),h.push(b)}var w=m-v,S=y-d;w>S?y=d+w:m=v+S;var k=lu();if(k.add=function(n){i(k,n,+x(n,++g),+M(n,g),v,d,m,y)},k.visit=function(n){fu(n,k,v,d,m,y)},g=-1,null==t){for(;++g=0?n.substring(0,t):n,r=t>=0?n.substring(t+1):"in";return e=os.get(e)||is,r=as.get(r)||At,yu(r(e.apply(null,Ko.call(arguments,1))))},Go.interpolateHcl=Tu,Go.interpolateHsl=qu,Go.interpolateLab=zu,Go.interpolateRound=Ru,Go.transform=function(n){var t=na.createElementNS(Go.ns.prefix.svg,"g");return(Go.transform=function(n){if(null!=n){t.setAttribute("transform",n);var e=t.transform.baseVal.consolidate()}return new Du(e?e.matrix:cs)})(n)},Du.prototype.toString=function(){return"translate("+this.translate+")rotate("+this.rotate+")skewX("+this.skew+")scale("+this.scale+")"};var cs={a:1,b:0,c:0,d:1,e:0,f:0};Go.interpolateTransform=Hu,Go.layout={},Go.layout.bundle=function(){return function(n){for(var t=[],e=-1,r=n.length;++ea*a/d){if(p>c){var s=t.charge/c;n.px-=i*s,n.py-=o*s}return!0}if(t.point&&c&&p>c){var s=t.pointCharge/c;n.px-=i*s,n.py-=o*s}}return!t.charge}}function t(n){n.px=Go.event.x,n.py=Go.event.y,a.resume()}var e,r,u,i,o,a={},c=Go.dispatch("start","tick","end"),s=[1,1],l=.9,f=ss,h=ls,g=-30,p=fs,v=.1,d=.64,m=[],y=[];return a.tick=function(){if((r*=.99)<.005)return c.end({type:"end",alpha:r=0}),!0;var t,e,a,f,h,p,d,x,M,_=m.length,b=y.length;for(e=0;b>e;++e)a=y[e],f=a.source,h=a.target,x=h.x-f.x,M=h.y-f.y,(p=x*x+M*M)&&(p=r*i[e]*((p=Math.sqrt(p))-u[e])/p,x*=p,M*=p,h.x-=x*(d=f.weight/(h.weight+f.weight)),h.y-=M*d,f.x+=x*(d=1-d),f.y+=M*d);if((d=r*v)&&(x=s[0]/2,M=s[1]/2,e=-1,d))for(;++e<_;)a=m[e],a.x+=(x-a.x)*d,a.y+=(M-a.y)*d;if(g)for(Ju(t=Go.geom.quadtree(m),r,o),e=-1;++e<_;)(a=m[e]).fixed||t.visit(n(a));for(e=-1;++e<_;)a=m[e],a.fixed?(a.x=a.px,a.y=a.py):(a.x-=(a.px-(a.px=a.x))*l,a.y-=(a.py-(a.py=a.y))*l);c.tick({type:"tick",alpha:r})},a.nodes=function(n){return arguments.length?(m=n,a):m},a.links=function(n){return arguments.length?(y=n,a):y},a.size=function(n){return arguments.length?(s=n,a):s},a.linkDistance=function(n){return arguments.length?(f="function"==typeof n?n:+n,a):f},a.distance=a.linkDistance,a.linkStrength=function(n){return arguments.length?(h="function"==typeof n?n:+n,a):h},a.friction=function(n){return arguments.length?(l=+n,a):l},a.charge=function(n){return arguments.length?(g="function"==typeof n?n:+n,a):g},a.chargeDistance=function(n){return arguments.length?(p=n*n,a):Math.sqrt(p)},a.gravity=function(n){return arguments.length?(v=+n,a):v},a.theta=function(n){return arguments.length?(d=n*n,a):Math.sqrt(d)},a.alpha=function(n){return arguments.length?(n=+n,r?r=n>0?n:0:n>0&&(c.start({type:"start",alpha:r=n}),Go.timer(a.tick)),a):r},a.start=function(){function n(n,r){if(!e){for(e=new Array(c),a=0;c>a;++a)e[a]=[];for(a=0;s>a;++a){var u=y[a];e[u.source.index].push(u.target),e[u.target.index].push(u.source)}}for(var i,o=e[t],a=-1,s=o.length;++at;++t)(r=m[t]).index=t,r.weight=0;for(t=0;l>t;++t)r=y[t],"number"==typeof r.source&&(r.source=m[r.source]),"number"==typeof r.target&&(r.target=m[r.target]),++r.source.weight,++r.target.weight;for(t=0;c>t;++t)r=m[t],isNaN(r.x)&&(r.x=n("x",p)),isNaN(r.y)&&(r.y=n("y",v)),isNaN(r.px)&&(r.px=r.x),isNaN(r.py)&&(r.py=r.y);if(u=[],"function"==typeof f)for(t=0;l>t;++t)u[t]=+f.call(this,y[t],t);else for(t=0;l>t;++t)u[t]=f;if(i=[],"function"==typeof h)for(t=0;l>t;++t)i[t]=+h.call(this,y[t],t);else for(t=0;l>t;++t)i[t]=h;if(o=[],"function"==typeof g)for(t=0;c>t;++t)o[t]=+g.call(this,m[t],t);else for(t=0;c>t;++t)o[t]=g;return a.resume()},a.resume=function(){return a.alpha(.1)},a.stop=function(){return a.alpha(0)},a.drag=function(){return e||(e=Go.behavior.drag().origin(At).on("dragstart.force",Vu).on("drag.force",t).on("dragend.force",$u)),arguments.length?(this.on("mouseover.force",Xu).on("mouseout.force",Bu).call(e),void 0):e},Go.rebind(a,c,"on")};var ss=20,ls=1,fs=1/0;Go.layout.hierarchy=function(){function n(t,o,a){var c=u.call(e,t,o);if(t.depth=o,a.push(t),c&&(s=c.length)){for(var s,l,f=-1,h=t.children=new Array(s),g=0,p=o+1;++fg;++g)for(u.call(n,s[0][g],p=v[g],l[0][g][1]),h=1;d>h;++h)u.call(n,s[h][g],p+=l[h-1][g][1],l[h][g][1]);return a}var t=At,e=ui,r=ii,u=ri,i=ti,o=ei;return n.values=function(e){return arguments.length?(t=e,n):t},n.order=function(t){return arguments.length?(e="function"==typeof t?t:gs.get(t)||ui,n):e},n.offset=function(t){return arguments.length?(r="function"==typeof t?t:ps.get(t)||ii,n):r},n.x=function(t){return arguments.length?(i=t,n):i},n.y=function(t){return arguments.length?(o=t,n):o},n.out=function(t){return arguments.length?(u=t,n):u},n};var gs=Go.map({"inside-out":function(n){var t,e,r=n.length,u=n.map(oi),i=n.map(ai),o=Go.range(r).sort(function(n,t){return u[n]-u[t]}),a=0,c=0,s=[],l=[];for(t=0;r>t;++t)e=o[t],c>a?(a+=i[e],s.push(e)):(c+=i[e],l.push(e));return l.reverse().concat(s)},reverse:function(n){return Go.range(n.length).reverse()},"default":ui}),ps=Go.map({silhouette:function(n){var t,e,r,u=n.length,i=n[0].length,o=[],a=0,c=[];for(e=0;i>e;++e){for(t=0,r=0;u>t;t++)r+=n[t][e][1];r>a&&(a=r),o.push(r)}for(e=0;i>e;++e)c[e]=(a-o[e])/2;return c},wiggle:function(n){var t,e,r,u,i,o,a,c,s,l=n.length,f=n[0],h=f.length,g=[];for(g[0]=c=s=0,e=1;h>e;++e){for(t=0,u=0;l>t;++t)u+=n[t][e][1];for(t=0,i=0,a=f[e][0]-f[e-1][0];l>t;++t){for(r=0,o=(n[t][e][1]-n[t][e-1][1])/(2*a);t>r;++r)o+=(n[r][e][1]-n[r][e-1][1])/a;i+=o*n[t][e][1]}g[e]=c-=u?i/u*a:0,s>c&&(s=c)}for(e=0;h>e;++e)g[e]-=s;return g},expand:function(n){var t,e,r,u=n.length,i=n[0].length,o=1/u,a=[];for(e=0;i>e;++e){for(t=0,r=0;u>t;t++)r+=n[t][e][1];if(r)for(t=0;u>t;t++)n[t][e][1]/=r;else for(t=0;u>t;t++)n[t][e][1]=o}for(e=0;i>e;++e)a[e]=0;return a},zero:ii});Go.layout.histogram=function(){function n(n,i){for(var o,a,c=[],s=n.map(e,this),l=r.call(this,s,i),f=u.call(this,l,s,i),i=-1,h=s.length,g=f.length-1,p=t?1:1/h;++i0)for(i=-1;++i=l[0]&&a<=l[1]&&(o=c[Go.bisect(f,a,1,g)-1],o.y+=p,o.push(n[i]));return c}var t=!0,e=Number,r=fi,u=si;return n.value=function(t){return arguments.length?(e=t,n):e},n.range=function(t){return arguments.length?(r=Et(t),n):r},n.bins=function(t){return arguments.length?(u="number"==typeof t?function(n){return li(n,t)}:Et(t),n):u},n.frequency=function(e){return arguments.length?(t=!!e,n):t},n},Go.layout.tree=function(){function n(n,i){function o(n,t){var r=n.children,u=n._tree;if(r&&(i=r.length)){for(var i,a,s,l=r[0],f=l,h=-1;++h0&&(_i(bi(a,n,r),n,u),s+=u,l+=u),f+=a._tree.mod,s+=i._tree.mod,h+=c._tree.mod,l+=o._tree.mod;a&&!pi(o)&&(o._tree.thread=a,o._tree.mod+=f-l),i&&!gi(c)&&(c._tree.thread=i,c._tree.mod+=s-h,r=n)}return r}var s=t.call(this,n,i),l=s[0];xi(l,function(n,t){n._tree={ancestor:n,prelim:0,mod:0,change:0,shift:0,number:t?t._tree.number+1:0}}),o(l),a(l,-l._tree.prelim);var f=vi(l,mi),h=vi(l,di),g=vi(l,yi),p=f.x-e(f,h)/2,v=h.x+e(h,f)/2,d=g.depth||1;return xi(l,u?function(n){n.x*=r[0],n.y=n.depth*r[1],delete n._tree}:function(n){n.x=(n.x-p)/(v-p)*r[0],n.y=n.depth/d*r[1],delete n._tree}),s}var t=Go.layout.hierarchy().sort(null).value(null),e=hi,r=[1,1],u=!1;return n.separation=function(t){return arguments.length?(e=t,n):e},n.size=function(t){return arguments.length?(u=null==(r=t),n):u?null:r},n.nodeSize=function(t){return arguments.length?(u=null!=(r=t),n):u?r:null},Wu(n,t)},Go.layout.pack=function(){function n(n,i){var o=e.call(this,n,i),a=o[0],c=u[0],s=u[1],l=null==t?Math.sqrt:"function"==typeof t?t:function(){return t};if(a.x=a.y=0,xi(a,function(n){n.r=+l(n.value)}),xi(a,Ai),r){var f=r*(t?1:Math.max(2*a.r/c,2*a.r/s))/2;xi(a,function(n){n.r+=f}),xi(a,Ai),xi(a,function(n){n.r-=f})}return Li(a,c/2,s/2,t?1:1/Math.max(2*a.r/c,2*a.r/s)),o}var t,e=Go.layout.hierarchy().sort(wi),r=0,u=[1,1];return n.size=function(t){return arguments.length?(u=t,n):u},n.radius=function(e){return arguments.length?(t=null==e||"function"==typeof e?e:+e,n):t},n.padding=function(t){return arguments.length?(r=+t,n):r},Wu(n,e)},Go.layout.cluster=function(){function n(n,i){var o,a=t.call(this,n,i),c=a[0],s=0;xi(c,function(n){var t=n.children;t&&t.length?(n.x=zi(t),n.y=qi(t)):(n.x=o?s+=e(n,o):0,n.y=0,o=n)});var l=Ri(c),f=Di(c),h=l.x-e(l,f)/2,g=f.x+e(f,l)/2;return xi(c,u?function(n){n.x=(n.x-c.x)*r[0],n.y=(c.y-n.y)*r[1]}:function(n){n.x=(n.x-h)/(g-h)*r[0],n.y=(1-(c.y?n.y/c.y:1))*r[1]}),a}var t=Go.layout.hierarchy().sort(null).value(null),e=hi,r=[1,1],u=!1;return n.separation=function(t){return arguments.length?(e=t,n):e},n.size=function(t){return arguments.length?(u=null==(r=t),n):u?null:r},n.nodeSize=function(t){return arguments.length?(u=null!=(r=t),n):u?r:null},Wu(n,t)},Go.layout.treemap=function(){function n(n,t){for(var e,r,u=-1,i=n.length;++ut?0:t),e.area=isNaN(r)||0>=r?0:r}function t(e){var i=e.children;if(i&&i.length){var o,a,c,s=f(e),l=[],h=i.slice(),p=1/0,v="slice"===g?s.dx:"dice"===g?s.dy:"slice-dice"===g?1&e.depth?s.dy:s.dx:Math.min(s.dx,s.dy);for(n(h,s.dx*s.dy/e.value),l.area=0;(c=h.length)>0;)l.push(o=h[c-1]),l.area+=o.area,"squarify"!==g||(a=r(l,v))<=p?(h.pop(),p=a):(l.area-=l.pop().area,u(l,v,s,!1),v=Math.min(s.dx,s.dy),l.length=l.area=0,p=1/0);l.length&&(u(l,v,s,!0),l.length=l.area=0),i.forEach(t)}}function e(t){var r=t.children;if(r&&r.length){var i,o=f(t),a=r.slice(),c=[];for(n(a,o.dx*o.dy/t.value),c.area=0;i=a.pop();)c.push(i),c.area+=i.area,null!=i.z&&(u(c,i.z?o.dx:o.dy,o,!a.length),c.length=c.area=0);r.forEach(e)}}function r(n,t){for(var e,r=n.area,u=0,i=1/0,o=-1,a=n.length;++oe&&(i=e),e>u&&(u=e));return r*=r,t*=t,r?Math.max(t*u*p/r,r/(t*i*p)):1/0}function u(n,t,e,r){var u,i=-1,o=n.length,a=e.x,s=e.y,l=t?c(n.area/t):0;if(t==e.dx){for((r||l>e.dy)&&(l=e.dy);++ie.dx)&&(l=e.dx);++ie&&(t=1),1>e&&(n=0),function(){var e,r,u;do e=2*Math.random()-1,r=2*Math.random()-1,u=e*e+r*r;while(!u||u>1);return n+t*e*Math.sqrt(-2*Math.log(u)/u)}},logNormal:function(){var n=Go.random.normal.apply(Go,arguments);return function(){return Math.exp(n())}},bates:function(n){var t=Go.random.irwinHall(n);return function(){return t()/n}},irwinHall:function(n){return function(){for(var t=0,e=0;n>e;e++)t+=Math.random();return t}}},Go.scale={};var vs={floor:At,ceil:At};Go.scale.linear=function(){return Zi([0,1],[0,1],du,!1)};var ds={s:1,g:1,p:1,r:1,e:1};Go.scale.log=function(){return Ki(Go.scale.linear().domain([0,1]),10,!0,[1,10])};var ms=Go.format(".0e"),ys={floor:function(n){return-Math.ceil(-n)},ceil:function(n){return-Math.floor(-n)}};Go.scale.pow=function(){return Qi(Go.scale.linear(),1,[0,1])},Go.scale.sqrt=function(){return Go.scale.pow().exponent(.5)},Go.scale.ordinal=function(){return to([],{t:"range",a:[[]]})},Go.scale.category10=function(){return Go.scale.ordinal().range(xs)},Go.scale.category20=function(){return Go.scale.ordinal().range(Ms)},Go.scale.category20b=function(){return Go.scale.ordinal().range(_s)},Go.scale.category20c=function(){return Go.scale.ordinal().range(bs)};var xs=[2062260,16744206,2924588,14034728,9725885,9197131,14907330,8355711,12369186,1556175].map(mt),Ms=[2062260,11454440,16744206,16759672,2924588,10018698,14034728,16750742,9725885,12955861,9197131,12885140,14907330,16234194,8355711,13092807,12369186,14408589,1556175,10410725].map(mt),_s=[3750777,5395619,7040719,10264286,6519097,9216594,11915115,13556636,9202993,12426809,15186514,15190932,8666169,11356490,14049643,15177372,8077683,10834324,13528509,14589654].map(mt),bs=[3244733,7057110,10406625,13032431,15095053,16616764,16625259,16634018,3253076,7652470,10607003,13101504,7695281,10394312,12369372,14342891,6513507,9868950,12434877,14277081].map(mt);Go.scale.quantile=function(){return eo([],[])},Go.scale.quantize=function(){return ro(0,1,[0,1])},Go.scale.threshold=function(){return uo([.5],[0,1])},Go.scale.identity=function(){return io([0,1])},Go.svg={},Go.svg.arc=function(){function n(){var n=t.apply(this,arguments),i=e.apply(this,arguments),o=r.apply(this,arguments)+ws,a=u.apply(this,arguments)+ws,c=(o>a&&(c=o,o=a,a=c),a-o),s=Ca>c?"0":"1",l=Math.cos(o),f=Math.sin(o),h=Math.cos(a),g=Math.sin(a); +return c>=Ss?n?"M0,"+i+"A"+i+","+i+" 0 1,1 0,"+-i+"A"+i+","+i+" 0 1,1 0,"+i+"M0,"+n+"A"+n+","+n+" 0 1,0 0,"+-n+"A"+n+","+n+" 0 1,0 0,"+n+"Z":"M0,"+i+"A"+i+","+i+" 0 1,1 0,"+-i+"A"+i+","+i+" 0 1,1 0,"+i+"Z":n?"M"+i*l+","+i*f+"A"+i+","+i+" 0 "+s+",1 "+i*h+","+i*g+"L"+n*h+","+n*g+"A"+n+","+n+" 0 "+s+",0 "+n*l+","+n*f+"Z":"M"+i*l+","+i*f+"A"+i+","+i+" 0 "+s+",1 "+i*h+","+i*g+"L0,0"+"Z"}var t=oo,e=ao,r=co,u=so;return n.innerRadius=function(e){return arguments.length?(t=Et(e),n):t},n.outerRadius=function(t){return arguments.length?(e=Et(t),n):e},n.startAngle=function(t){return arguments.length?(r=Et(t),n):r},n.endAngle=function(t){return arguments.length?(u=Et(t),n):u},n.centroid=function(){var n=(t.apply(this,arguments)+e.apply(this,arguments))/2,i=(r.apply(this,arguments)+u.apply(this,arguments))/2+ws;return[Math.cos(i)*n,Math.sin(i)*n]},n};var ws=-La,Ss=Na-Ta;Go.svg.line=function(){return lo(At)};var ks=Go.map({linear:fo,"linear-closed":ho,step:go,"step-before":po,"step-after":vo,basis:bo,"basis-open":wo,"basis-closed":So,bundle:ko,cardinal:xo,"cardinal-open":mo,"cardinal-closed":yo,monotone:To});ks.forEach(function(n,t){t.key=n,t.closed=/-closed$/.test(n)});var Es=[0,2/3,1/3,0],As=[0,1/3,2/3,0],Cs=[0,1/6,2/3,1/6];Go.svg.line.radial=function(){var n=lo(qo);return n.radius=n.x,delete n.x,n.angle=n.y,delete n.y,n},po.reverse=vo,vo.reverse=po,Go.svg.area=function(){return zo(At)},Go.svg.area.radial=function(){var n=zo(qo);return n.radius=n.x,delete n.x,n.innerRadius=n.x0,delete n.x0,n.outerRadius=n.x1,delete n.x1,n.angle=n.y,delete n.y,n.startAngle=n.y0,delete n.y0,n.endAngle=n.y1,delete n.y1,n},Go.svg.chord=function(){function n(n,a){var c=t(this,i,n,a),s=t(this,o,n,a);return"M"+c.p0+r(c.r,c.p1,c.a1-c.a0)+(e(c,s)?u(c.r,c.p1,c.r,c.p0):u(c.r,c.p1,s.r,s.p0)+r(s.r,s.p1,s.a1-s.a0)+u(s.r,s.p1,c.r,c.p0))+"Z"}function t(n,t,e,r){var u=t.call(n,e,r),i=a.call(n,u,r),o=c.call(n,u,r)+ws,l=s.call(n,u,r)+ws;return{r:i,a0:o,a1:l,p0:[i*Math.cos(o),i*Math.sin(o)],p1:[i*Math.cos(l),i*Math.sin(l)]}}function e(n,t){return n.a0==t.a0&&n.a1==t.a1}function r(n,t,e){return"A"+n+","+n+" 0 "+ +(e>Ca)+",1 "+t}function u(n,t,e,r){return"Q 0,0 "+r}var i=mr,o=yr,a=Ro,c=co,s=so;return n.radius=function(t){return arguments.length?(a=Et(t),n):a},n.source=function(t){return arguments.length?(i=Et(t),n):i},n.target=function(t){return arguments.length?(o=Et(t),n):o},n.startAngle=function(t){return arguments.length?(c=Et(t),n):c},n.endAngle=function(t){return arguments.length?(s=Et(t),n):s},n},Go.svg.diagonal=function(){function n(n,u){var i=t.call(this,n,u),o=e.call(this,n,u),a=(i.y+o.y)/2,c=[i,{x:i.x,y:a},{x:o.x,y:a},o];return c=c.map(r),"M"+c[0]+"C"+c[1]+" "+c[2]+" "+c[3]}var t=mr,e=yr,r=Do;return n.source=function(e){return arguments.length?(t=Et(e),n):t},n.target=function(t){return arguments.length?(e=Et(t),n):e},n.projection=function(t){return arguments.length?(r=t,n):r},n},Go.svg.diagonal.radial=function(){var n=Go.svg.diagonal(),t=Do,e=n.projection;return n.projection=function(n){return arguments.length?e(Po(t=n)):t},n},Go.svg.symbol=function(){function n(n,r){return(Ns.get(t.call(this,n,r))||Ho)(e.call(this,n,r))}var t=jo,e=Uo;return n.type=function(e){return arguments.length?(t=Et(e),n):t},n.size=function(t){return arguments.length?(e=Et(t),n):e},n};var Ns=Go.map({circle:Ho,cross:function(n){var t=Math.sqrt(n/5)/2;return"M"+-3*t+","+-t+"H"+-t+"V"+-3*t+"H"+t+"V"+-t+"H"+3*t+"V"+t+"H"+t+"V"+3*t+"H"+-t+"V"+t+"H"+-3*t+"Z"},diamond:function(n){var t=Math.sqrt(n/(2*zs)),e=t*zs;return"M0,"+-t+"L"+e+",0"+" 0,"+t+" "+-e+",0"+"Z"},square:function(n){var t=Math.sqrt(n)/2;return"M"+-t+","+-t+"L"+t+","+-t+" "+t+","+t+" "+-t+","+t+"Z"},"triangle-down":function(n){var t=Math.sqrt(n/qs),e=t*qs/2;return"M0,"+e+"L"+t+","+-e+" "+-t+","+-e+"Z"},"triangle-up":function(n){var t=Math.sqrt(n/qs),e=t*qs/2;return"M0,"+-e+"L"+t+","+e+" "+-t+","+e+"Z"}});Go.svg.symbolTypes=Ns.keys();var Ls,Ts,qs=Math.sqrt(3),zs=Math.tan(30*za),Rs=[],Ds=0;Rs.call=_a.call,Rs.empty=_a.empty,Rs.node=_a.node,Rs.size=_a.size,Go.transition=function(n){return arguments.length?Ls?n.transition():n:Sa.transition()},Go.transition.prototype=Rs,Rs.select=function(n){var t,e,r,u=this.id,i=[];n=b(n);for(var o=-1,a=this.length;++oi;i++){u.push(t=[]);for(var e=this[i],a=0,c=e.length;c>a;a++)(r=e[a])&&n.call(r,r.__data__,a,i)&&t.push(r)}return Fo(u,this.id)},Rs.tween=function(n,t){var e=this.id;return arguments.length<2?this.node().__transition__[e].tween.get(n):P(this,null==t?function(t){t.__transition__[e].tween.remove(n)}:function(r){r.__transition__[e].tween.set(n,t)})},Rs.attr=function(n,t){function e(){this.removeAttribute(a)}function r(){this.removeAttributeNS(a.space,a.local)}function u(n){return null==n?e:(n+="",function(){var t,e=this.getAttribute(a);return e!==n&&(t=o(e,n),function(n){this.setAttribute(a,t(n))})})}function i(n){return null==n?r:(n+="",function(){var t,e=this.getAttributeNS(a.space,a.local);return e!==n&&(t=o(e,n),function(n){this.setAttributeNS(a.space,a.local,t(n))})})}if(arguments.length<2){for(t in n)this.attr(t,n[t]);return this}var o="transform"==n?Hu:du,a=Go.ns.qualify(n);return Oo(this,"attr."+n,t,a.local?i:u)},Rs.attrTween=function(n,t){function e(n,e){var r=t.call(this,n,e,this.getAttribute(u));return r&&function(n){this.setAttribute(u,r(n))}}function r(n,e){var r=t.call(this,n,e,this.getAttributeNS(u.space,u.local));return r&&function(n){this.setAttributeNS(u.space,u.local,r(n))}}var u=Go.ns.qualify(n);return this.tween("attr."+n,u.local?r:e)},Rs.style=function(n,t,e){function r(){this.style.removeProperty(n)}function u(t){return null==t?r:(t+="",function(){var r,u=ea.getComputedStyle(this,null).getPropertyValue(n);return u!==t&&(r=du(u,t),function(t){this.style.setProperty(n,r(t),e)})})}var i=arguments.length;if(3>i){if("string"!=typeof n){2>i&&(t="");for(e in n)this.style(e,n[e],t);return this}e=""}return Oo(this,"style."+n,t,u)},Rs.styleTween=function(n,t,e){function r(r,u){var i=t.call(this,r,u,ea.getComputedStyle(this,null).getPropertyValue(n));return i&&function(t){this.style.setProperty(n,i(t),e)}}return arguments.length<3&&(e=""),this.tween("style."+n,r)},Rs.text=function(n){return Oo(this,"text",n,Io)},Rs.remove=function(){return this.each("end.transition",function(){var n;this.__transition__.count<2&&(n=this.parentNode)&&n.removeChild(this)})},Rs.ease=function(n){var t=this.id;return arguments.length<1?this.node().__transition__[t].ease:("function"!=typeof n&&(n=Go.ease.apply(Go,arguments)),P(this,function(e){e.__transition__[t].ease=n}))},Rs.delay=function(n){var t=this.id;return arguments.length<1?this.node().__transition__[t].delay:P(this,"function"==typeof n?function(e,r,u){e.__transition__[t].delay=+n.call(e,e.__data__,r,u)}:(n=+n,function(e){e.__transition__[t].delay=n}))},Rs.duration=function(n){var t=this.id;return arguments.length<1?this.node().__transition__[t].duration:P(this,"function"==typeof n?function(e,r,u){e.__transition__[t].duration=Math.max(1,n.call(e,e.__data__,r,u))}:(n=Math.max(1,n),function(e){e.__transition__[t].duration=n}))},Rs.each=function(n,t){var e=this.id;if(arguments.length<2){var r=Ts,u=Ls;Ls=e,P(this,function(t,r,u){Ts=t.__transition__[e],n.call(t,t.__data__,r,u)}),Ts=r,Ls=u}else P(this,function(r){var u=r.__transition__[e];(u.event||(u.event=Go.dispatch("start","end"))).on(n,t)});return this},Rs.transition=function(){for(var n,t,e,r,u=this.id,i=++Ds,o=[],a=0,c=this.length;c>a;a++){o.push(n=[]);for(var t=this[a],s=0,l=t.length;l>s;s++)(e=t[s])&&(r=Object.create(e.__transition__[u]),r.delay+=r.duration,Yo(e,s,i,r)),n.push(e)}return Fo(o,i)},Go.svg.axis=function(){function n(n){n.each(function(){var n,s=Go.select(this),l=this.__chart__||e,f=this.__chart__=e.copy(),h=null==c?f.ticks?f.ticks.apply(f,a):f.domain():c,g=null==t?f.tickFormat?f.tickFormat.apply(f,a):At:t,p=s.selectAll(".tick").data(h,f),v=p.enter().insert("g",".domain").attr("class","tick").style("opacity",Ta),d=Go.transition(p.exit()).style("opacity",Ta).remove(),m=Go.transition(p.order()).style("opacity",1),y=Hi(f),x=s.selectAll(".domain").data([0]),M=(x.enter().append("path").attr("class","domain"),Go.transition(x));v.append("line"),v.append("text");var _=v.select("line"),b=m.select("line"),w=p.select("text").text(g),S=v.select("text"),k=m.select("text");switch(r){case"bottom":n=Zo,_.attr("y2",u),S.attr("y",Math.max(u,0)+o),b.attr("x2",0).attr("y2",u),k.attr("x",0).attr("y",Math.max(u,0)+o),w.attr("dy",".71em").style("text-anchor","middle"),M.attr("d","M"+y[0]+","+i+"V0H"+y[1]+"V"+i);break;case"top":n=Zo,_.attr("y2",-u),S.attr("y",-(Math.max(u,0)+o)),b.attr("x2",0).attr("y2",-u),k.attr("x",0).attr("y",-(Math.max(u,0)+o)),w.attr("dy","0em").style("text-anchor","middle"),M.attr("d","M"+y[0]+","+-i+"V0H"+y[1]+"V"+-i);break;case"left":n=Vo,_.attr("x2",-u),S.attr("x",-(Math.max(u,0)+o)),b.attr("x2",-u).attr("y2",0),k.attr("x",-(Math.max(u,0)+o)).attr("y",0),w.attr("dy",".32em").style("text-anchor","end"),M.attr("d","M"+-i+","+y[0]+"H0V"+y[1]+"H"+-i);break;case"right":n=Vo,_.attr("x2",u),S.attr("x",Math.max(u,0)+o),b.attr("x2",u).attr("y2",0),k.attr("x",Math.max(u,0)+o).attr("y",0),w.attr("dy",".32em").style("text-anchor","start"),M.attr("d","M"+i+","+y[0]+"H0V"+y[1]+"H"+i)}if(f.rangeBand){var E=f,A=E.rangeBand()/2;l=f=function(n){return E(n)+A}}else l.rangeBand?l=f:d.call(n,f);v.call(n,l),m.call(n,f)})}var t,e=Go.scale.linear(),r=Ps,u=6,i=6,o=3,a=[10],c=null;return n.scale=function(t){return arguments.length?(e=t,n):e},n.orient=function(t){return arguments.length?(r=t in Us?t+"":Ps,n):r},n.ticks=function(){return arguments.length?(a=arguments,n):a},n.tickValues=function(t){return arguments.length?(c=t,n):c},n.tickFormat=function(e){return arguments.length?(t=e,n):t},n.tickSize=function(t){var e=arguments.length;return e?(u=+t,i=+arguments[e-1],n):u},n.innerTickSize=function(t){return arguments.length?(u=+t,n):u},n.outerTickSize=function(t){return arguments.length?(i=+t,n):i},n.tickPadding=function(t){return arguments.length?(o=+t,n):o},n.tickSubdivide=function(){return arguments.length&&n},n};var Ps="bottom",Us={top:1,right:1,bottom:1,left:1};Go.svg.brush=function(){function n(i){i.each(function(){var i=Go.select(this).style("pointer-events","all").style("-webkit-tap-highlight-color","rgba(0,0,0,0)").on("mousedown.brush",u).on("touchstart.brush",u),o=i.selectAll(".background").data([0]);o.enter().append("rect").attr("class","background").style("visibility","hidden").style("cursor","crosshair"),i.selectAll(".extent").data([0]).enter().append("rect").attr("class","extent").style("cursor","move");var a=i.selectAll(".resize").data(p,At);a.exit().remove(),a.enter().append("g").attr("class",function(n){return"resize "+n}).style("cursor",function(n){return js[n]}).append("rect").attr("x",function(n){return/[ew]$/.test(n)?-3:null}).attr("y",function(n){return/^[ns]/.test(n)?-3:null}).attr("width",6).attr("height",6).style("visibility","hidden"),a.style("display",n.empty()?"none":null);var l,f=Go.transition(i),h=Go.transition(o);c&&(l=Hi(c),h.attr("x",l[0]).attr("width",l[1]-l[0]),e(f)),s&&(l=Hi(s),h.attr("y",l[0]).attr("height",l[1]-l[0]),r(f)),t(f)})}function t(n){n.selectAll(".resize").attr("transform",function(n){return"translate("+l[+/e$/.test(n)]+","+f[+/^s/.test(n)]+")"})}function e(n){n.select(".extent").attr("x",l[0]),n.selectAll(".extent,.n>rect,.s>rect").attr("width",l[1]-l[0])}function r(n){n.select(".extent").attr("y",f[0]),n.selectAll(".extent,.e>rect,.w>rect").attr("height",f[1]-f[0])}function u(){function u(){32==Go.event.keyCode&&(C||(x=null,L[0]-=l[1],L[1]-=f[1],C=2),y())}function p(){32==Go.event.keyCode&&2==C&&(L[0]+=l[1],L[1]+=f[1],C=0,y())}function v(){var n=Go.mouse(_),u=!1;M&&(n[0]+=M[0],n[1]+=M[1]),C||(Go.event.altKey?(x||(x=[(l[0]+l[1])/2,(f[0]+f[1])/2]),L[0]=l[+(n[0]p?(u=r,r=p):u=p),v[0]!=r||v[1]!=u?(e?o=null:i=null,v[0]=r,v[1]=u,!0):void 0}function m(){v(),S.style("pointer-events","all").selectAll(".resize").style("display",n.empty()?"none":null),Go.select("body").style("cursor",null),T.on("mousemove.brush",null).on("mouseup.brush",null).on("touchmove.brush",null).on("touchend.brush",null).on("keydown.brush",null).on("keyup.brush",null),N(),w({type:"brushend"})}var x,M,_=this,b=Go.select(Go.event.target),w=a.of(_,arguments),S=Go.select(_),k=b.datum(),E=!/^(n|s)$/.test(k)&&c,A=!/^(e|w)$/.test(k)&&s,C=b.classed("extent"),N=Y(),L=Go.mouse(_),T=Go.select(ea).on("keydown.brush",u).on("keyup.brush",p);if(Go.event.changedTouches?T.on("touchmove.brush",v).on("touchend.brush",m):T.on("mousemove.brush",v).on("mouseup.brush",m),S.interrupt().selectAll("*").interrupt(),C)L[0]=l[0]-L[0],L[1]=f[0]-L[1];else if(k){var q=+/w$/.test(k),z=+/^n/.test(k);M=[l[1-q]-L[0],f[1-z]-L[1]],L[0]=l[q],L[1]=f[z]}else Go.event.altKey&&(x=L.slice());S.style("pointer-events","none").selectAll(".resize").style("display",null),Go.select("body").style("cursor",b.style("cursor")),w({type:"brushstart"}),v()}var i,o,a=M(n,"brushstart","brush","brushend"),c=null,s=null,l=[0,0],f=[0,0],h=!0,g=!0,p=Hs[0];return n.event=function(n){n.each(function(){var n=a.of(this,arguments),t={x:l,y:f,i:i,j:o},e=this.__chart__||t;this.__chart__=t,Ls?Go.select(this).transition().each("start.brush",function(){i=e.i,o=e.j,l=e.x,f=e.y,n({type:"brushstart"})}).tween("brush:brush",function(){var e=mu(l,t.x),r=mu(f,t.y);return i=o=null,function(u){l=t.x=e(u),f=t.y=r(u),n({type:"brush",mode:"resize"})}}).each("end.brush",function(){i=t.i,o=t.j,n({type:"brush",mode:"resize"}),n({type:"brushend"})}):(n({type:"brushstart"}),n({type:"brush",mode:"resize"}),n({type:"brushend"}))})},n.x=function(t){return arguments.length?(c=t,p=Hs[!c<<1|!s],n):c},n.y=function(t){return arguments.length?(s=t,p=Hs[!c<<1|!s],n):s},n.clamp=function(t){return arguments.length?(c&&s?(h=!!t[0],g=!!t[1]):c?h=!!t:s&&(g=!!t),n):c&&s?[h,g]:c?h:s?g:null},n.extent=function(t){var e,r,u,a,h;return arguments.length?(c&&(e=t[0],r=t[1],s&&(e=e[0],r=r[0]),i=[e,r],c.invert&&(e=c(e),r=c(r)),e>r&&(h=e,e=r,r=h),(e!=l[0]||r!=l[1])&&(l=[e,r])),s&&(u=t[0],a=t[1],c&&(u=u[1],a=a[1]),o=[u,a],s.invert&&(u=s(u),a=s(a)),u>a&&(h=u,u=a,a=h),(u!=f[0]||a!=f[1])&&(f=[u,a])),n):(c&&(i?(e=i[0],r=i[1]):(e=l[0],r=l[1],c.invert&&(e=c.invert(e),r=c.invert(r)),e>r&&(h=e,e=r,r=h))),s&&(o?(u=o[0],a=o[1]):(u=f[0],a=f[1],s.invert&&(u=s.invert(u),a=s.invert(a)),u>a&&(h=u,u=a,a=h))),c&&s?[[e,u],[r,a]]:c?[e,r]:s&&[u,a])},n.clear=function(){return n.empty()||(l=[0,0],f=[0,0],i=o=null),n},n.empty=function(){return!!c&&l[0]==l[1]||!!s&&f[0]==f[1]},Go.rebind(n,a,"on")};var js={n:"ns-resize",e:"ew-resize",s:"ns-resize",w:"ew-resize",nw:"nwse-resize",ne:"nesw-resize",se:"nwse-resize",sw:"nesw-resize"},Hs=[["n","e","s","w","nw","ne","se","sw"],["e","w"],["n","s"],[]],Fs=ic.format=fc.timeFormat,Os=Fs.utc,Is=Os("%Y-%m-%dT%H:%M:%S.%LZ");Fs.iso=Date.prototype.toISOString&&+new Date("2000-01-01T00:00:00.000Z")?$o:Is,$o.parse=function(n){var t=new Date(n);return isNaN(t)?null:t},$o.toString=Is.toString,ic.second=Ht(function(n){return new oc(1e3*Math.floor(n/1e3))},function(n,t){n.setTime(n.getTime()+1e3*Math.floor(t))},function(n){return n.getSeconds()}),ic.seconds=ic.second.range,ic.seconds.utc=ic.second.utc.range,ic.minute=Ht(function(n){return new oc(6e4*Math.floor(n/6e4))},function(n,t){n.setTime(n.getTime()+6e4*Math.floor(t))},function(n){return n.getMinutes()}),ic.minutes=ic.minute.range,ic.minutes.utc=ic.minute.utc.range,ic.hour=Ht(function(n){var t=n.getTimezoneOffset()/60;return new oc(36e5*(Math.floor(n/36e5-t)+t))},function(n,t){n.setTime(n.getTime()+36e5*Math.floor(t))},function(n){return n.getHours()}),ic.hours=ic.hour.range,ic.hours.utc=ic.hour.utc.range,ic.month=Ht(function(n){return n=ic.day(n),n.setDate(1),n},function(n,t){n.setMonth(n.getMonth()+t)},function(n){return n.getMonth()}),ic.months=ic.month.range,ic.months.utc=ic.month.utc.range;var Ys=[1e3,5e3,15e3,3e4,6e4,3e5,9e5,18e5,36e5,108e5,216e5,432e5,864e5,1728e5,6048e5,2592e6,7776e6,31536e6],Zs=[[ic.second,1],[ic.second,5],[ic.second,15],[ic.second,30],[ic.minute,1],[ic.minute,5],[ic.minute,15],[ic.minute,30],[ic.hour,1],[ic.hour,3],[ic.hour,6],[ic.hour,12],[ic.day,1],[ic.day,2],[ic.week,1],[ic.month,1],[ic.month,3],[ic.year,1]],Vs=Fs.multi([[".%L",function(n){return n.getMilliseconds()}],[":%S",function(n){return n.getSeconds()}],["%I:%M",function(n){return n.getMinutes()}],["%I %p",function(n){return n.getHours()}],["%a %d",function(n){return n.getDay()&&1!=n.getDate()}],["%b %d",function(n){return 1!=n.getDate()}],["%B",function(n){return n.getMonth()}],["%Y",Ae]]),$s={range:function(n,t,e){return Go.range(Math.ceil(n/e)*e,+t,e).map(Bo)},floor:At,ceil:At};Zs.year=ic.year,ic.scale=function(){return Xo(Go.scale.linear(),Zs,Vs)};var Xs=Zs.map(function(n){return[n[0].utc,n[1]]}),Bs=Os.multi([[".%L",function(n){return n.getUTCMilliseconds()}],[":%S",function(n){return n.getUTCSeconds()}],["%I:%M",function(n){return n.getUTCMinutes()}],["%I %p",function(n){return n.getUTCHours()}],["%a %d",function(n){return n.getUTCDay()&&1!=n.getUTCDate()}],["%b %d",function(n){return 1!=n.getUTCDate()}],["%B",function(n){return n.getUTCMonth()}],["%Y",Ae]]);Xs.year=ic.year.utc,ic.scale.utc=function(){return Xo(Go.scale.linear(),Xs,Bs)},Go.text=Ct(function(n){return n.responseText}),Go.json=function(n,t){return Nt(n,"application/json",Jo,t)},Go.html=function(n,t){return Nt(n,"text/html",Wo,t)},Go.xml=Ct(function(n){return n.responseXML}),"function"==typeof define&&define.amd?define(Go):"object"==typeof module&&module.exports?module.exports=Go:this.d3=Go}(); \ No newline at end of file diff --git a/vis_resources/d3utils.css b/vis_resources/d3utils.css new file mode 100644 index 0000000..5f50b4b --- /dev/null +++ b/vis_resources/d3utils.css @@ -0,0 +1,23 @@ +path { + stroke: steelblue; + stroke-width: 1; + fill: none; +} + +.axis { + shape-rendering: crispEdges; +} + +.x.axis line { + stroke: lightgrey; +} +.y.axis line { + stroke: lightgrey; +} +.x.axis path { + display: none; +} +.y.axis line, .y.axis path { + fill: none; + stroke: #000; +} diff --git a/vis_resources/d3utils.js b/vis_resources/d3utils.js new file mode 100644 index 0000000..1765c79 --- /dev/null +++ b/vis_resources/d3utils.js @@ -0,0 +1,48 @@ + +function plotToDiv(div, ydata, xdata, params) { + params = params || {}; + + var dw = div.offsetWidth; + var dh = div.offsetHeight; + + var m = [40, 0, 40, 100]; // margins + var w = dw - m[1] - m[3]; // width + var h = dh - m[0] - m[2]; // height + + var xmax = 'xmax' in params ? params.xmax : d3.max(xdata); + var ymax = 'ymax' in params ? params.ymax : d3.max(ydata); + + var x = d3.scale.linear().domain([0, xmax]).range([0, w]); + var y = d3.scale.linear().domain([0, ymax]).range([h, 0]); + + // create a line function that can convert data[] into x and y points + var line = d3.svg.line() + .x(function(d,i) { + return x(xdata[i]); + }) + .y(function(d,i) { + return y(ydata[i]); + }) + + // Add an SVG element with the desired dimensions and margin. + var graph = d3.select(div).append("svg:svg") + .attr("width", w + m[1] + m[3]) + .attr("height", h + m[0] + m[2]) + .append("svg:g") + .attr("transform", "translate(" + m[3] + "," + m[0] + ")"); + + var xAxis = d3.svg.axis().scale(x).tickSize(-h).tickSubdivide(true); + graph.append("svg:g") + .attr("class", "x axis") + .attr("transform", "translate(0," + h + ")") + .call(xAxis); + + var yAxisLeft = d3.svg.axis().scale(y).ticks(4).orient("left"); + graph.append("svg:g") + .attr("class", "y axis") + .attr("transform", "translate(-25,0)") + .call(yAxisLeft); + + graph.append("svg:path").attr("d", line(xdata)); + +} \ No newline at end of file diff --git a/vis_resources/jquery-1.8.3.min.js b/vis_resources/jquery-1.8.3.min.js new file mode 100644 index 0000000..3883779 --- /dev/null +++ b/vis_resources/jquery-1.8.3.min.js @@ -0,0 +1,2 @@ +/*! jQuery v1.8.3 jquery.com | jquery.org/license */ +(function(e,t){function _(e){var t=M[e]={};return v.each(e.split(y),function(e,n){t[n]=!0}),t}function H(e,n,r){if(r===t&&e.nodeType===1){var i="data-"+n.replace(P,"-$1").toLowerCase();r=e.getAttribute(i);if(typeof r=="string"){try{r=r==="true"?!0:r==="false"?!1:r==="null"?null:+r+""===r?+r:D.test(r)?v.parseJSON(r):r}catch(s){}v.data(e,n,r)}else r=t}return r}function B(e){var t;for(t in e){if(t==="data"&&v.isEmptyObject(e[t]))continue;if(t!=="toJSON")return!1}return!0}function et(){return!1}function tt(){return!0}function ut(e){return!e||!e.parentNode||e.parentNode.nodeType===11}function at(e,t){do e=e[t];while(e&&e.nodeType!==1);return e}function ft(e,t,n){t=t||0;if(v.isFunction(t))return v.grep(e,function(e,r){var i=!!t.call(e,r,e);return i===n});if(t.nodeType)return v.grep(e,function(e,r){return e===t===n});if(typeof t=="string"){var r=v.grep(e,function(e){return e.nodeType===1});if(it.test(t))return v.filter(t,r,!n);t=v.filter(t,r)}return v.grep(e,function(e,r){return v.inArray(e,t)>=0===n})}function lt(e){var t=ct.split("|"),n=e.createDocumentFragment();if(n.createElement)while(t.length)n.createElement(t.pop());return n}function Lt(e,t){return e.getElementsByTagName(t)[0]||e.appendChild(e.ownerDocument.createElement(t))}function At(e,t){if(t.nodeType!==1||!v.hasData(e))return;var n,r,i,s=v._data(e),o=v._data(t,s),u=s.events;if(u){delete o.handle,o.events={};for(n in u)for(r=0,i=u[n].length;r").appendTo(i.body),n=t.css("display");t.remove();if(n==="none"||n===""){Pt=i.body.appendChild(Pt||v.extend(i.createElement("iframe"),{frameBorder:0,width:0,height:0}));if(!Ht||!Pt.createElement)Ht=(Pt.contentWindow||Pt.contentDocument).document,Ht.write(""),Ht.close();t=Ht.body.appendChild(Ht.createElement(e)),n=Dt(t,"display"),i.body.removeChild(Pt)}return Wt[e]=n,n}function fn(e,t,n,r){var i;if(v.isArray(t))v.each(t,function(t,i){n||sn.test(e)?r(e,i):fn(e+"["+(typeof i=="object"?t:"")+"]",i,n,r)});else if(!n&&v.type(t)==="object")for(i in t)fn(e+"["+i+"]",t[i],n,r);else r(e,t)}function Cn(e){return function(t,n){typeof t!="string"&&(n=t,t="*");var r,i,s,o=t.toLowerCase().split(y),u=0,a=o.length;if(v.isFunction(n))for(;u)[^>]*$|#([\w\-]*)$)/,E=/^<(\w+)\s*\/?>(?:<\/\1>|)$/,S=/^[\],:{}\s]*$/,x=/(?:^|:|,)(?:\s*\[)+/g,T=/\\(?:["\\\/bfnrt]|u[\da-fA-F]{4})/g,N=/"[^"\\\r\n]*"|true|false|null|-?(?:\d\d*\.|)\d+(?:[eE][\-+]?\d+|)/g,C=/^-ms-/,k=/-([\da-z])/gi,L=function(e,t){return(t+"").toUpperCase()},A=function(){i.addEventListener?(i.removeEventListener("DOMContentLoaded",A,!1),v.ready()):i.readyState==="complete"&&(i.detachEvent("onreadystatechange",A),v.ready())},O={};v.fn=v.prototype={constructor:v,init:function(e,n,r){var s,o,u,a;if(!e)return this;if(e.nodeType)return this.context=this[0]=e,this.length=1,this;if(typeof e=="string"){e.charAt(0)==="<"&&e.charAt(e.length-1)===">"&&e.length>=3?s=[null,e,null]:s=w.exec(e);if(s&&(s[1]||!n)){if(s[1])return n=n instanceof v?n[0]:n,a=n&&n.nodeType?n.ownerDocument||n:i,e=v.parseHTML(s[1],a,!0),E.test(s[1])&&v.isPlainObject(n)&&this.attr.call(e,n,!0),v.merge(this,e);o=i.getElementById(s[2]);if(o&&o.parentNode){if(o.id!==s[2])return r.find(e);this.length=1,this[0]=o}return this.context=i,this.selector=e,this}return!n||n.jquery?(n||r).find(e):this.constructor(n).find(e)}return v.isFunction(e)?r.ready(e):(e.selector!==t&&(this.selector=e.selector,this.context=e.context),v.makeArray(e,this))},selector:"",jquery:"1.8.3",length:0,size:function(){return this.length},toArray:function(){return l.call(this)},get:function(e){return e==null?this.toArray():e<0?this[this.length+e]:this[e]},pushStack:function(e,t,n){var r=v.merge(this.constructor(),e);return r.prevObject=this,r.context=this.context,t==="find"?r.selector=this.selector+(this.selector?" ":"")+n:t&&(r.selector=this.selector+"."+t+"("+n+")"),r},each:function(e,t){return v.each(this,e,t)},ready:function(e){return v.ready.promise().done(e),this},eq:function(e){return e=+e,e===-1?this.slice(e):this.slice(e,e+1)},first:function(){return this.eq(0)},last:function(){return this.eq(-1)},slice:function(){return this.pushStack(l.apply(this,arguments),"slice",l.call(arguments).join(","))},map:function(e){return this.pushStack(v.map(this,function(t,n){return e.call(t,n,t)}))},end:function(){return this.prevObject||this.constructor(null)},push:f,sort:[].sort,splice:[].splice},v.fn.init.prototype=v.fn,v.extend=v.fn.extend=function(){var e,n,r,i,s,o,u=arguments[0]||{},a=1,f=arguments.length,l=!1;typeof u=="boolean"&&(l=u,u=arguments[1]||{},a=2),typeof u!="object"&&!v.isFunction(u)&&(u={}),f===a&&(u=this,--a);for(;a0)return;r.resolveWith(i,[v]),v.fn.trigger&&v(i).trigger("ready").off("ready")},isFunction:function(e){return v.type(e)==="function"},isArray:Array.isArray||function(e){return v.type(e)==="array"},isWindow:function(e){return e!=null&&e==e.window},isNumeric:function(e){return!isNaN(parseFloat(e))&&isFinite(e)},type:function(e){return e==null?String(e):O[h.call(e)]||"object"},isPlainObject:function(e){if(!e||v.type(e)!=="object"||e.nodeType||v.isWindow(e))return!1;try{if(e.constructor&&!p.call(e,"constructor")&&!p.call(e.constructor.prototype,"isPrototypeOf"))return!1}catch(n){return!1}var r;for(r in e);return r===t||p.call(e,r)},isEmptyObject:function(e){var t;for(t in e)return!1;return!0},error:function(e){throw new Error(e)},parseHTML:function(e,t,n){var r;return!e||typeof e!="string"?null:(typeof t=="boolean"&&(n=t,t=0),t=t||i,(r=E.exec(e))?[t.createElement(r[1])]:(r=v.buildFragment([e],t,n?null:[]),v.merge([],(r.cacheable?v.clone(r.fragment):r.fragment).childNodes)))},parseJSON:function(t){if(!t||typeof t!="string")return null;t=v.trim(t);if(e.JSON&&e.JSON.parse)return e.JSON.parse(t);if(S.test(t.replace(T,"@").replace(N,"]").replace(x,"")))return(new Function("return "+t))();v.error("Invalid JSON: "+t)},parseXML:function(n){var r,i;if(!n||typeof n!="string")return null;try{e.DOMParser?(i=new DOMParser,r=i.parseFromString(n,"text/xml")):(r=new ActiveXObject("Microsoft.XMLDOM"),r.async="false",r.loadXML(n))}catch(s){r=t}return(!r||!r.documentElement||r.getElementsByTagName("parsererror").length)&&v.error("Invalid XML: "+n),r},noop:function(){},globalEval:function(t){t&&g.test(t)&&(e.execScript||function(t){e.eval.call(e,t)})(t)},camelCase:function(e){return e.replace(C,"ms-").replace(k,L)},nodeName:function(e,t){return e.nodeName&&e.nodeName.toLowerCase()===t.toLowerCase()},each:function(e,n,r){var i,s=0,o=e.length,u=o===t||v.isFunction(e);if(r){if(u){for(i in e)if(n.apply(e[i],r)===!1)break}else for(;s0&&e[0]&&e[a-1]||a===0||v.isArray(e));if(f)for(;u-1)a.splice(n,1),i&&(n<=o&&o--,n<=u&&u--)}),this},has:function(e){return v.inArray(e,a)>-1},empty:function(){return a=[],this},disable:function(){return a=f=n=t,this},disabled:function(){return!a},lock:function(){return f=t,n||c.disable(),this},locked:function(){return!f},fireWith:function(e,t){return t=t||[],t=[e,t.slice?t.slice():t],a&&(!r||f)&&(i?f.push(t):l(t)),this},fire:function(){return c.fireWith(this,arguments),this},fired:function(){return!!r}};return c},v.extend({Deferred:function(e){var t=[["resolve","done",v.Callbacks("once memory"),"resolved"],["reject","fail",v.Callbacks("once memory"),"rejected"],["notify","progress",v.Callbacks("memory")]],n="pending",r={state:function(){return n},always:function(){return i.done(arguments).fail(arguments),this},then:function(){var e=arguments;return v.Deferred(function(n){v.each(t,function(t,r){var s=r[0],o=e[t];i[r[1]](v.isFunction(o)?function(){var e=o.apply(this,arguments);e&&v.isFunction(e.promise)?e.promise().done(n.resolve).fail(n.reject).progress(n.notify):n[s+"With"](this===i?n:this,[e])}:n[s])}),e=null}).promise()},promise:function(e){return e!=null?v.extend(e,r):r}},i={};return r.pipe=r.then,v.each(t,function(e,s){var o=s[2],u=s[3];r[s[1]]=o.add,u&&o.add(function(){n=u},t[e^1][2].disable,t[2][2].lock),i[s[0]]=o.fire,i[s[0]+"With"]=o.fireWith}),r.promise(i),e&&e.call(i,i),i},when:function(e){var t=0,n=l.call(arguments),r=n.length,i=r!==1||e&&v.isFunction(e.promise)?r:0,s=i===1?e:v.Deferred(),o=function(e,t,n){return function(r){t[e]=this,n[e]=arguments.length>1?l.call(arguments):r,n===u?s.notifyWith(t,n):--i||s.resolveWith(t,n)}},u,a,f;if(r>1){u=new Array(r),a=new Array(r),f=new Array(r);for(;t
a",n=p.getElementsByTagName("*"),r=p.getElementsByTagName("a")[0];if(!n||!r||!n.length)return{};s=i.createElement("select"),o=s.appendChild(i.createElement("option")),u=p.getElementsByTagName("input")[0],r.style.cssText="top:1px;float:left;opacity:.5",t={leadingWhitespace:p.firstChild.nodeType===3,tbody:!p.getElementsByTagName("tbody").length,htmlSerialize:!!p.getElementsByTagName("link").length,style:/top/.test(r.getAttribute("style")),hrefNormalized:r.getAttribute("href")==="/a",opacity:/^0.5/.test(r.style.opacity),cssFloat:!!r.style.cssFloat,checkOn:u.value==="on",optSelected:o.selected,getSetAttribute:p.className!=="t",enctype:!!i.createElement("form").enctype,html5Clone:i.createElement("nav").cloneNode(!0).outerHTML!=="<:nav>",boxModel:i.compatMode==="CSS1Compat",submitBubbles:!0,changeBubbles:!0,focusinBubbles:!1,deleteExpando:!0,noCloneEvent:!0,inlineBlockNeedsLayout:!1,shrinkWrapBlocks:!1,reliableMarginRight:!0,boxSizingReliable:!0,pixelPosition:!1},u.checked=!0,t.noCloneChecked=u.cloneNode(!0).checked,s.disabled=!0,t.optDisabled=!o.disabled;try{delete p.test}catch(d){t.deleteExpando=!1}!p.addEventListener&&p.attachEvent&&p.fireEvent&&(p.attachEvent("onclick",h=function(){t.noCloneEvent=!1}),p.cloneNode(!0).fireEvent("onclick"),p.detachEvent("onclick",h)),u=i.createElement("input"),u.value="t",u.setAttribute("type","radio"),t.radioValue=u.value==="t",u.setAttribute("checked","checked"),u.setAttribute("name","t"),p.appendChild(u),a=i.createDocumentFragment(),a.appendChild(p.lastChild),t.checkClone=a.cloneNode(!0).cloneNode(!0).lastChild.checked,t.appendChecked=u.checked,a.removeChild(u),a.appendChild(p);if(p.attachEvent)for(l in{submit:!0,change:!0,focusin:!0})f="on"+l,c=f in p,c||(p.setAttribute(f,"return;"),c=typeof p[f]=="function"),t[l+"Bubbles"]=c;return v(function(){var n,r,s,o,u="padding:0;margin:0;border:0;display:block;overflow:hidden;",a=i.getElementsByTagName("body")[0];if(!a)return;n=i.createElement("div"),n.style.cssText="visibility:hidden;border:0;width:0;height:0;position:static;top:0;margin-top:1px",a.insertBefore(n,a.firstChild),r=i.createElement("div"),n.appendChild(r),r.innerHTML="
t
",s=r.getElementsByTagName("td"),s[0].style.cssText="padding:0;margin:0;border:0;display:none",c=s[0].offsetHeight===0,s[0].style.display="",s[1].style.display="none",t.reliableHiddenOffsets=c&&s[0].offsetHeight===0,r.innerHTML="",r.style.cssText="box-sizing:border-box;-moz-box-sizing:border-box;-webkit-box-sizing:border-box;padding:1px;border:1px;display:block;width:4px;margin-top:1%;position:absolute;top:1%;",t.boxSizing=r.offsetWidth===4,t.doesNotIncludeMarginInBodyOffset=a.offsetTop!==1,e.getComputedStyle&&(t.pixelPosition=(e.getComputedStyle(r,null)||{}).top!=="1%",t.boxSizingReliable=(e.getComputedStyle(r,null)||{width:"4px"}).width==="4px",o=i.createElement("div"),o.style.cssText=r.style.cssText=u,o.style.marginRight=o.style.width="0",r.style.width="1px",r.appendChild(o),t.reliableMarginRight=!parseFloat((e.getComputedStyle(o,null)||{}).marginRight)),typeof r.style.zoom!="undefined"&&(r.innerHTML="",r.style.cssText=u+"width:1px;padding:1px;display:inline;zoom:1",t.inlineBlockNeedsLayout=r.offsetWidth===3,r.style.display="block",r.style.overflow="visible",r.innerHTML="
",r.firstChild.style.width="5px",t.shrinkWrapBlocks=r.offsetWidth!==3,n.style.zoom=1),a.removeChild(n),n=r=s=o=null}),a.removeChild(p),n=r=s=o=u=a=p=null,t}();var D=/(?:\{[\s\S]*\}|\[[\s\S]*\])$/,P=/([A-Z])/g;v.extend({cache:{},deletedIds:[],uuid:0,expando:"jQuery"+(v.fn.jquery+Math.random()).replace(/\D/g,""),noData:{embed:!0,object:"clsid:D27CDB6E-AE6D-11cf-96B8-444553540000",applet:!0},hasData:function(e){return e=e.nodeType?v.cache[e[v.expando]]:e[v.expando],!!e&&!B(e)},data:function(e,n,r,i){if(!v.acceptData(e))return;var s,o,u=v.expando,a=typeof n=="string",f=e.nodeType,l=f?v.cache:e,c=f?e[u]:e[u]&&u;if((!c||!l[c]||!i&&!l[c].data)&&a&&r===t)return;c||(f?e[u]=c=v.deletedIds.pop()||v.guid++:c=u),l[c]||(l[c]={},f||(l[c].toJSON=v.noop));if(typeof n=="object"||typeof n=="function")i?l[c]=v.extend(l[c],n):l[c].data=v.extend(l[c].data,n);return s=l[c],i||(s.data||(s.data={}),s=s.data),r!==t&&(s[v.camelCase(n)]=r),a?(o=s[n],o==null&&(o=s[v.camelCase(n)])):o=s,o},removeData:function(e,t,n){if(!v.acceptData(e))return;var r,i,s,o=e.nodeType,u=o?v.cache:e,a=o?e[v.expando]:v.expando;if(!u[a])return;if(t){r=n?u[a]:u[a].data;if(r){v.isArray(t)||(t in r?t=[t]:(t=v.camelCase(t),t in r?t=[t]:t=t.split(" ")));for(i=0,s=t.length;i1,null,!1))},removeData:function(e){return this.each(function(){v.removeData(this,e)})}}),v.extend({queue:function(e,t,n){var r;if(e)return t=(t||"fx")+"queue",r=v._data(e,t),n&&(!r||v.isArray(n)?r=v._data(e,t,v.makeArray(n)):r.push(n)),r||[]},dequeue:function(e,t){t=t||"fx";var n=v.queue(e,t),r=n.length,i=n.shift(),s=v._queueHooks(e,t),o=function(){v.dequeue(e,t)};i==="inprogress"&&(i=n.shift(),r--),i&&(t==="fx"&&n.unshift("inprogress"),delete s.stop,i.call(e,o,s)),!r&&s&&s.empty.fire()},_queueHooks:function(e,t){var n=t+"queueHooks";return v._data(e,n)||v._data(e,n,{empty:v.Callbacks("once memory").add(function(){v.removeData(e,t+"queue",!0),v.removeData(e,n,!0)})})}}),v.fn.extend({queue:function(e,n){var r=2;return typeof e!="string"&&(n=e,e="fx",r--),arguments.length1)},removeAttr:function(e){return this.each(function(){v.removeAttr(this,e)})},prop:function(e,t){return v.access(this,v.prop,e,t,arguments.length>1)},removeProp:function(e){return e=v.propFix[e]||e,this.each(function(){try{this[e]=t,delete this[e]}catch(n){}})},addClass:function(e){var t,n,r,i,s,o,u;if(v.isFunction(e))return this.each(function(t){v(this).addClass(e.call(this,t,this.className))});if(e&&typeof e=="string"){t=e.split(y);for(n=0,r=this.length;n=0)r=r.replace(" "+n[s]+" "," ");i.className=e?v.trim(r):""}}}return this},toggleClass:function(e,t){var n=typeof e,r=typeof t=="boolean";return v.isFunction(e)?this.each(function(n){v(this).toggleClass(e.call(this,n,this.className,t),t)}):this.each(function(){if(n==="string"){var i,s=0,o=v(this),u=t,a=e.split(y);while(i=a[s++])u=r?u:!o.hasClass(i),o[u?"addClass":"removeClass"](i)}else if(n==="undefined"||n==="boolean")this.className&&v._data(this,"__className__",this.className),this.className=this.className||e===!1?"":v._data(this,"__className__")||""})},hasClass:function(e){var t=" "+e+" ",n=0,r=this.length;for(;n=0)return!0;return!1},val:function(e){var n,r,i,s=this[0];if(!arguments.length){if(s)return n=v.valHooks[s.type]||v.valHooks[s.nodeName.toLowerCase()],n&&"get"in n&&(r=n.get(s,"value"))!==t?r:(r=s.value,typeof r=="string"?r.replace(R,""):r==null?"":r);return}return i=v.isFunction(e),this.each(function(r){var s,o=v(this);if(this.nodeType!==1)return;i?s=e.call(this,r,o.val()):s=e,s==null?s="":typeof s=="number"?s+="":v.isArray(s)&&(s=v.map(s,function(e){return e==null?"":e+""})),n=v.valHooks[this.type]||v.valHooks[this.nodeName.toLowerCase()];if(!n||!("set"in n)||n.set(this,s,"value")===t)this.value=s})}}),v.extend({valHooks:{option:{get:function(e){var t=e.attributes.value;return!t||t.specified?e.value:e.text}},select:{get:function(e){var t,n,r=e.options,i=e.selectedIndex,s=e.type==="select-one"||i<0,o=s?null:[],u=s?i+1:r.length,a=i<0?u:s?i:0;for(;a=0}),n.length||(e.selectedIndex=-1),n}}},attrFn:{},attr:function(e,n,r,i){var s,o,u,a=e.nodeType;if(!e||a===3||a===8||a===2)return;if(i&&v.isFunction(v.fn[n]))return v(e)[n](r);if(typeof e.getAttribute=="undefined")return v.prop(e,n,r);u=a!==1||!v.isXMLDoc(e),u&&(n=n.toLowerCase(),o=v.attrHooks[n]||(X.test(n)?F:j));if(r!==t){if(r===null){v.removeAttr(e,n);return}return o&&"set"in o&&u&&(s=o.set(e,r,n))!==t?s:(e.setAttribute(n,r+""),r)}return o&&"get"in o&&u&&(s=o.get(e,n))!==null?s:(s=e.getAttribute(n),s===null?t:s)},removeAttr:function(e,t){var n,r,i,s,o=0;if(t&&e.nodeType===1){r=t.split(y);for(;o=0}})});var $=/^(?:textarea|input|select)$/i,J=/^([^\.]*|)(?:\.(.+)|)$/,K=/(?:^|\s)hover(\.\S+|)\b/,Q=/^key/,G=/^(?:mouse|contextmenu)|click/,Y=/^(?:focusinfocus|focusoutblur)$/,Z=function(e){return v.event.special.hover?e:e.replace(K,"mouseenter$1 mouseleave$1")};v.event={add:function(e,n,r,i,s){var o,u,a,f,l,c,h,p,d,m,g;if(e.nodeType===3||e.nodeType===8||!n||!r||!(o=v._data(e)))return;r.handler&&(d=r,r=d.handler,s=d.selector),r.guid||(r.guid=v.guid++),a=o.events,a||(o.events=a={}),u=o.handle,u||(o.handle=u=function(e){return typeof v=="undefined"||!!e&&v.event.triggered===e.type?t:v.event.dispatch.apply(u.elem,arguments)},u.elem=e),n=v.trim(Z(n)).split(" ");for(f=0;f=0&&(y=y.slice(0,-1),a=!0),y.indexOf(".")>=0&&(b=y.split("."),y=b.shift(),b.sort());if((!s||v.event.customEvent[y])&&!v.event.global[y])return;n=typeof n=="object"?n[v.expando]?n:new v.Event(y,n):new v.Event(y),n.type=y,n.isTrigger=!0,n.exclusive=a,n.namespace=b.join("."),n.namespace_re=n.namespace?new RegExp("(^|\\.)"+b.join("\\.(?:.*\\.|)")+"(\\.|$)"):null,h=y.indexOf(":")<0?"on"+y:"";if(!s){u=v.cache;for(f in u)u[f].events&&u[f].events[y]&&v.event.trigger(n,r,u[f].handle.elem,!0);return}n.result=t,n.target||(n.target=s),r=r!=null?v.makeArray(r):[],r.unshift(n),p=v.event.special[y]||{};if(p.trigger&&p.trigger.apply(s,r)===!1)return;m=[[s,p.bindType||y]];if(!o&&!p.noBubble&&!v.isWindow(s)){g=p.delegateType||y,l=Y.test(g+y)?s:s.parentNode;for(c=s;l;l=l.parentNode)m.push([l,g]),c=l;c===(s.ownerDocument||i)&&m.push([c.defaultView||c.parentWindow||e,g])}for(f=0;f=0:v.find(h,this,null,[s]).length),u[h]&&f.push(c);f.length&&w.push({elem:s,matches:f})}d.length>m&&w.push({elem:this,matches:d.slice(m)});for(r=0;r0?this.on(t,null,e,n):this.trigger(t)},Q.test(t)&&(v.event.fixHooks[t]=v.event.keyHooks),G.test(t)&&(v.event.fixHooks[t]=v.event.mouseHooks)}),function(e,t){function nt(e,t,n,r){n=n||[],t=t||g;var i,s,a,f,l=t.nodeType;if(!e||typeof e!="string")return n;if(l!==1&&l!==9)return[];a=o(t);if(!a&&!r)if(i=R.exec(e))if(f=i[1]){if(l===9){s=t.getElementById(f);if(!s||!s.parentNode)return n;if(s.id===f)return n.push(s),n}else if(t.ownerDocument&&(s=t.ownerDocument.getElementById(f))&&u(t,s)&&s.id===f)return n.push(s),n}else{if(i[2])return S.apply(n,x.call(t.getElementsByTagName(e),0)),n;if((f=i[3])&&Z&&t.getElementsByClassName)return S.apply(n,x.call(t.getElementsByClassName(f),0)),n}return vt(e.replace(j,"$1"),t,n,r,a)}function rt(e){return function(t){var n=t.nodeName.toLowerCase();return n==="input"&&t.type===e}}function it(e){return function(t){var n=t.nodeName.toLowerCase();return(n==="input"||n==="button")&&t.type===e}}function st(e){return N(function(t){return t=+t,N(function(n,r){var i,s=e([],n.length,t),o=s.length;while(o--)n[i=s[o]]&&(n[i]=!(r[i]=n[i]))})})}function ot(e,t,n){if(e===t)return n;var r=e.nextSibling;while(r){if(r===t)return-1;r=r.nextSibling}return 1}function ut(e,t){var n,r,s,o,u,a,f,l=L[d][e+" "];if(l)return t?0:l.slice(0);u=e,a=[],f=i.preFilter;while(u){if(!n||(r=F.exec(u)))r&&(u=u.slice(r[0].length)||u),a.push(s=[]);n=!1;if(r=I.exec(u))s.push(n=new m(r.shift())),u=u.slice(n.length),n.type=r[0].replace(j," ");for(o in i.filter)(r=J[o].exec(u))&&(!f[o]||(r=f[o](r)))&&(s.push(n=new m(r.shift())),u=u.slice(n.length),n.type=o,n.matches=r);if(!n)break}return t?u.length:u?nt.error(e):L(e,a).slice(0)}function at(e,t,r){var i=t.dir,s=r&&t.dir==="parentNode",o=w++;return t.first?function(t,n,r){while(t=t[i])if(s||t.nodeType===1)return e(t,n,r)}:function(t,r,u){if(!u){var a,f=b+" "+o+" ",l=f+n;while(t=t[i])if(s||t.nodeType===1){if((a=t[d])===l)return t.sizset;if(typeof a=="string"&&a.indexOf(f)===0){if(t.sizset)return t}else{t[d]=l;if(e(t,r,u))return t.sizset=!0,t;t.sizset=!1}}}else while(t=t[i])if(s||t.nodeType===1)if(e(t,r,u))return t}}function ft(e){return e.length>1?function(t,n,r){var i=e.length;while(i--)if(!e[i](t,n,r))return!1;return!0}:e[0]}function lt(e,t,n,r,i){var s,o=[],u=0,a=e.length,f=t!=null;for(;u-1&&(s[f]=!(o[f]=c))}}else g=lt(g===o?g.splice(d,g.length):g),i?i(null,o,g,a):S.apply(o,g)})}function ht(e){var t,n,r,s=e.length,o=i.relative[e[0].type],u=o||i.relative[" "],a=o?1:0,f=at(function(e){return e===t},u,!0),l=at(function(e){return T.call(t,e)>-1},u,!0),h=[function(e,n,r){return!o&&(r||n!==c)||((t=n).nodeType?f(e,n,r):l(e,n,r))}];for(;a1&&ft(h),a>1&&e.slice(0,a-1).join("").replace(j,"$1"),n,a0,s=e.length>0,o=function(u,a,f,l,h){var p,d,v,m=[],y=0,w="0",x=u&&[],T=h!=null,N=c,C=u||s&&i.find.TAG("*",h&&a.parentNode||a),k=b+=N==null?1:Math.E;T&&(c=a!==g&&a,n=o.el);for(;(p=C[w])!=null;w++){if(s&&p){for(d=0;v=e[d];d++)if(v(p,a,f)){l.push(p);break}T&&(b=k,n=++o.el)}r&&((p=!v&&p)&&y--,u&&x.push(p))}y+=w;if(r&&w!==y){for(d=0;v=t[d];d++)v(x,m,a,f);if(u){if(y>0)while(w--)!x[w]&&!m[w]&&(m[w]=E.call(l));m=lt(m)}S.apply(l,m),T&&!u&&m.length>0&&y+t.length>1&&nt.uniqueSort(l)}return T&&(b=k,c=N),x};return o.el=0,r?N(o):o}function dt(e,t,n){var r=0,i=t.length;for(;r2&&(f=u[0]).type==="ID"&&t.nodeType===9&&!s&&i.relative[u[1].type]){t=i.find.ID(f.matches[0].replace($,""),t,s)[0];if(!t)return n;e=e.slice(u.shift().length)}for(o=J.POS.test(e)?-1:u.length-1;o>=0;o--){f=u[o];if(i.relative[l=f.type])break;if(c=i.find[l])if(r=c(f.matches[0].replace($,""),z.test(u[0].type)&&t.parentNode||t,s)){u.splice(o,1),e=r.length&&u.join("");if(!e)return S.apply(n,x.call(r,0)),n;break}}}return a(e,h)(r,t,s,n,z.test(e)),n}function mt(){}var n,r,i,s,o,u,a,f,l,c,h=!0,p="undefined",d=("sizcache"+Math.random()).replace(".",""),m=String,g=e.document,y=g.documentElement,b=0,w=0,E=[].pop,S=[].push,x=[].slice,T=[].indexOf||function(e){var t=0,n=this.length;for(;ti.cacheLength&&delete e[t.shift()],e[n+" "]=r},e)},k=C(),L=C(),A=C(),O="[\\x20\\t\\r\\n\\f]",M="(?:\\\\.|[-\\w]|[^\\x00-\\xa0])+",_=M.replace("w","w#"),D="([*^$|!~]?=)",P="\\["+O+"*("+M+")"+O+"*(?:"+D+O+"*(?:(['\"])((?:\\\\.|[^\\\\])*?)\\3|("+_+")|)|)"+O+"*\\]",H=":("+M+")(?:\\((?:(['\"])((?:\\\\.|[^\\\\])*?)\\2|([^()[\\]]*|(?:(?:"+P+")|[^:]|\\\\.)*|.*))\\)|)",B=":(even|odd|eq|gt|lt|nth|first|last)(?:\\("+O+"*((?:-\\d)?\\d*)"+O+"*\\)|)(?=[^-]|$)",j=new RegExp("^"+O+"+|((?:^|[^\\\\])(?:\\\\.)*)"+O+"+$","g"),F=new RegExp("^"+O+"*,"+O+"*"),I=new RegExp("^"+O+"*([\\x20\\t\\r\\n\\f>+~])"+O+"*"),q=new RegExp(H),R=/^(?:#([\w\-]+)|(\w+)|\.([\w\-]+))$/,U=/^:not/,z=/[\x20\t\r\n\f]*[+~]/,W=/:not\($/,X=/h\d/i,V=/input|select|textarea|button/i,$=/\\(?!\\)/g,J={ID:new RegExp("^#("+M+")"),CLASS:new RegExp("^\\.("+M+")"),NAME:new RegExp("^\\[name=['\"]?("+M+")['\"]?\\]"),TAG:new RegExp("^("+M.replace("w","w*")+")"),ATTR:new RegExp("^"+P),PSEUDO:new RegExp("^"+H),POS:new RegExp(B,"i"),CHILD:new RegExp("^:(only|nth|first|last)-child(?:\\("+O+"*(even|odd|(([+-]|)(\\d*)n|)"+O+"*(?:([+-]|)"+O+"*(\\d+)|))"+O+"*\\)|)","i"),needsContext:new RegExp("^"+O+"*[>+~]|"+B,"i")},K=function(e){var t=g.createElement("div");try{return e(t)}catch(n){return!1}finally{t=null}},Q=K(function(e){return e.appendChild(g.createComment("")),!e.getElementsByTagName("*").length}),G=K(function(e){return e.innerHTML="",e.firstChild&&typeof e.firstChild.getAttribute!==p&&e.firstChild.getAttribute("href")==="#"}),Y=K(function(e){e.innerHTML="";var t=typeof e.lastChild.getAttribute("multiple");return t!=="boolean"&&t!=="string"}),Z=K(function(e){return e.innerHTML="",!e.getElementsByClassName||!e.getElementsByClassName("e").length?!1:(e.lastChild.className="e",e.getElementsByClassName("e").length===2)}),et=K(function(e){e.id=d+0,e.innerHTML="
",y.insertBefore(e,y.firstChild);var t=g.getElementsByName&&g.getElementsByName(d).length===2+g.getElementsByName(d+0).length;return r=!g.getElementById(d),y.removeChild(e),t});try{x.call(y.childNodes,0)[0].nodeType}catch(tt){x=function(e){var t,n=[];for(;t=this[e];e++)n.push(t);return n}}nt.matches=function(e,t){return nt(e,null,null,t)},nt.matchesSelector=function(e,t){return nt(t,null,null,[e]).length>0},s=nt.getText=function(e){var t,n="",r=0,i=e.nodeType;if(i){if(i===1||i===9||i===11){if(typeof e.textContent=="string")return e.textContent;for(e=e.firstChild;e;e=e.nextSibling)n+=s(e)}else if(i===3||i===4)return e.nodeValue}else for(;t=e[r];r++)n+=s(t);return n},o=nt.isXML=function(e){var t=e&&(e.ownerDocument||e).documentElement;return t?t.nodeName!=="HTML":!1},u=nt.contains=y.contains?function(e,t){var n=e.nodeType===9?e.documentElement:e,r=t&&t.parentNode;return e===r||!!(r&&r.nodeType===1&&n.contains&&n.contains(r))}:y.compareDocumentPosition?function(e,t){return t&&!!(e.compareDocumentPosition(t)&16)}:function(e,t){while(t=t.parentNode)if(t===e)return!0;return!1},nt.attr=function(e,t){var n,r=o(e);return r||(t=t.toLowerCase()),(n=i.attrHandle[t])?n(e):r||Y?e.getAttribute(t):(n=e.getAttributeNode(t),n?typeof e[t]=="boolean"?e[t]?t:null:n.specified?n.value:null:null)},i=nt.selectors={cacheLength:50,createPseudo:N,match:J,attrHandle:G?{}:{href:function(e){return e.getAttribute("href",2)},type:function(e){return e.getAttribute("type")}},find:{ID:r?function(e,t,n){if(typeof t.getElementById!==p&&!n){var r=t.getElementById(e);return r&&r.parentNode?[r]:[]}}:function(e,n,r){if(typeof n.getElementById!==p&&!r){var i=n.getElementById(e);return i?i.id===e||typeof i.getAttributeNode!==p&&i.getAttributeNode("id").value===e?[i]:t:[]}},TAG:Q?function(e,t){if(typeof t.getElementsByTagName!==p)return t.getElementsByTagName(e)}:function(e,t){var n=t.getElementsByTagName(e);if(e==="*"){var r,i=[],s=0;for(;r=n[s];s++)r.nodeType===1&&i.push(r);return i}return n},NAME:et&&function(e,t){if(typeof t.getElementsByName!==p)return t.getElementsByName(name)},CLASS:Z&&function(e,t,n){if(typeof t.getElementsByClassName!==p&&!n)return t.getElementsByClassName(e)}},relative:{">":{dir:"parentNode",first:!0}," ":{dir:"parentNode"},"+":{dir:"previousSibling",first:!0},"~":{dir:"previousSibling"}},preFilter:{ATTR:function(e){return e[1]=e[1].replace($,""),e[3]=(e[4]||e[5]||"").replace($,""),e[2]==="~="&&(e[3]=" "+e[3]+" "),e.slice(0,4)},CHILD:function(e){return e[1]=e[1].toLowerCase(),e[1]==="nth"?(e[2]||nt.error(e[0]),e[3]=+(e[3]?e[4]+(e[5]||1):2*(e[2]==="even"||e[2]==="odd")),e[4]=+(e[6]+e[7]||e[2]==="odd")):e[2]&&nt.error(e[0]),e},PSEUDO:function(e){var t,n;if(J.CHILD.test(e[0]))return null;if(e[3])e[2]=e[3];else if(t=e[4])q.test(t)&&(n=ut(t,!0))&&(n=t.indexOf(")",t.length-n)-t.length)&&(t=t.slice(0,n),e[0]=e[0].slice(0,n)),e[2]=t;return e.slice(0,3)}},filter:{ID:r?function(e){return e=e.replace($,""),function(t){return t.getAttribute("id")===e}}:function(e){return e=e.replace($,""),function(t){var n=typeof t.getAttributeNode!==p&&t.getAttributeNode("id");return n&&n.value===e}},TAG:function(e){return e==="*"?function(){return!0}:(e=e.replace($,"").toLowerCase(),function(t){return t.nodeName&&t.nodeName.toLowerCase()===e})},CLASS:function(e){var t=k[d][e+" "];return t||(t=new RegExp("(^|"+O+")"+e+"("+O+"|$)"))&&k(e,function(e){return t.test(e.className||typeof e.getAttribute!==p&&e.getAttribute("class")||"")})},ATTR:function(e,t,n){return function(r,i){var s=nt.attr(r,e);return s==null?t==="!=":t?(s+="",t==="="?s===n:t==="!="?s!==n:t==="^="?n&&s.indexOf(n)===0:t==="*="?n&&s.indexOf(n)>-1:t==="$="?n&&s.substr(s.length-n.length)===n:t==="~="?(" "+s+" ").indexOf(n)>-1:t==="|="?s===n||s.substr(0,n.length+1)===n+"-":!1):!0}},CHILD:function(e,t,n,r){return e==="nth"?function(e){var t,i,s=e.parentNode;if(n===1&&r===0)return!0;if(s){i=0;for(t=s.firstChild;t;t=t.nextSibling)if(t.nodeType===1){i++;if(e===t)break}}return i-=r,i===n||i%n===0&&i/n>=0}:function(t){var n=t;switch(e){case"only":case"first":while(n=n.previousSibling)if(n.nodeType===1)return!1;if(e==="first")return!0;n=t;case"last":while(n=n.nextSibling)if(n.nodeType===1)return!1;return!0}}},PSEUDO:function(e,t){var n,r=i.pseudos[e]||i.setFilters[e.toLowerCase()]||nt.error("unsupported pseudo: "+e);return r[d]?r(t):r.length>1?(n=[e,e,"",t],i.setFilters.hasOwnProperty(e.toLowerCase())?N(function(e,n){var i,s=r(e,t),o=s.length;while(o--)i=T.call(e,s[o]),e[i]=!(n[i]=s[o])}):function(e){return r(e,0,n)}):r}},pseudos:{not:N(function(e){var t=[],n=[],r=a(e.replace(j,"$1"));return r[d]?N(function(e,t,n,i){var s,o=r(e,null,i,[]),u=e.length;while(u--)if(s=o[u])e[u]=!(t[u]=s)}):function(e,i,s){return t[0]=e,r(t,null,s,n),!n.pop()}}),has:N(function(e){return function(t){return nt(e,t).length>0}}),contains:N(function(e){return function(t){return(t.textContent||t.innerText||s(t)).indexOf(e)>-1}}),enabled:function(e){return e.disabled===!1},disabled:function(e){return e.disabled===!0},checked:function(e){var t=e.nodeName.toLowerCase();return t==="input"&&!!e.checked||t==="option"&&!!e.selected},selected:function(e){return e.parentNode&&e.parentNode.selectedIndex,e.selected===!0},parent:function(e){return!i.pseudos.empty(e)},empty:function(e){var t;e=e.firstChild;while(e){if(e.nodeName>"@"||(t=e.nodeType)===3||t===4)return!1;e=e.nextSibling}return!0},header:function(e){return X.test(e.nodeName)},text:function(e){var t,n;return e.nodeName.toLowerCase()==="input"&&(t=e.type)==="text"&&((n=e.getAttribute("type"))==null||n.toLowerCase()===t)},radio:rt("radio"),checkbox:rt("checkbox"),file:rt("file"),password:rt("password"),image:rt("image"),submit:it("submit"),reset:it("reset"),button:function(e){var t=e.nodeName.toLowerCase();return t==="input"&&e.type==="button"||t==="button"},input:function(e){return V.test(e.nodeName)},focus:function(e){var t=e.ownerDocument;return e===t.activeElement&&(!t.hasFocus||t.hasFocus())&&!!(e.type||e.href||~e.tabIndex)},active:function(e){return e===e.ownerDocument.activeElement},first:st(function(){return[0]}),last:st(function(e,t){return[t-1]}),eq:st(function(e,t,n){return[n<0?n+t:n]}),even:st(function(e,t){for(var n=0;n=0;)e.push(r);return e}),gt:st(function(e,t,n){for(var r=n<0?n+t:n;++r",e.querySelectorAll("[selected]").length||i.push("\\["+O+"*(?:checked|disabled|ismap|multiple|readonly|selected|value)"),e.querySelectorAll(":checked").length||i.push(":checked")}),K(function(e){e.innerHTML="

",e.querySelectorAll("[test^='']").length&&i.push("[*^$]="+O+"*(?:\"\"|'')"),e.innerHTML="",e.querySelectorAll(":enabled").length||i.push(":enabled",":disabled")}),i=new RegExp(i.join("|")),vt=function(e,r,s,o,u){if(!o&&!u&&!i.test(e)){var a,f,l=!0,c=d,h=r,p=r.nodeType===9&&e;if(r.nodeType===1&&r.nodeName.toLowerCase()!=="object"){a=ut(e),(l=r.getAttribute("id"))?c=l.replace(n,"\\$&"):r.setAttribute("id",c),c="[id='"+c+"'] ",f=a.length;while(f--)a[f]=c+a[f].join("");h=z.test(e)&&r.parentNode||r,p=a.join(",")}if(p)try{return S.apply(s,x.call(h.querySelectorAll(p),0)),s}catch(v){}finally{l||r.removeAttribute("id")}}return t(e,r,s,o,u)},u&&(K(function(t){e=u.call(t,"div");try{u.call(t,"[test!='']:sizzle"),s.push("!=",H)}catch(n){}}),s=new RegExp(s.join("|")),nt.matchesSelector=function(t,n){n=n.replace(r,"='$1']");if(!o(t)&&!s.test(n)&&!i.test(n))try{var a=u.call(t,n);if(a||e||t.document&&t.document.nodeType!==11)return a}catch(f){}return nt(n,null,null,[t]).length>0})}(),i.pseudos.nth=i.pseudos.eq,i.filters=mt.prototype=i.pseudos,i.setFilters=new mt,nt.attr=v.attr,v.find=nt,v.expr=nt.selectors,v.expr[":"]=v.expr.pseudos,v.unique=nt.uniqueSort,v.text=nt.getText,v.isXMLDoc=nt.isXML,v.contains=nt.contains}(e);var nt=/Until$/,rt=/^(?:parents|prev(?:Until|All))/,it=/^.[^:#\[\.,]*$/,st=v.expr.match.needsContext,ot={children:!0,contents:!0,next:!0,prev:!0};v.fn.extend({find:function(e){var t,n,r,i,s,o,u=this;if(typeof e!="string")return v(e).filter(function(){for(t=0,n=u.length;t0)for(i=r;i=0:v.filter(e,this).length>0:this.filter(e).length>0)},closest:function(e,t){var n,r=0,i=this.length,s=[],o=st.test(e)||typeof e!="string"?v(e,t||this.context):0;for(;r-1:v.find.matchesSelector(n,e)){s.push(n);break}n=n.parentNode}}return s=s.length>1?v.unique(s):s,this.pushStack(s,"closest",e)},index:function(e){return e?typeof e=="string"?v.inArray(this[0],v(e)):v.inArray(e.jquery?e[0]:e,this):this[0]&&this[0].parentNode?this.prevAll().length:-1},add:function(e,t){var n=typeof e=="string"?v(e,t):v.makeArray(e&&e.nodeType?[e]:e),r=v.merge(this.get(),n);return this.pushStack(ut(n[0])||ut(r[0])?r:v.unique(r))},addBack:function(e){return this.add(e==null?this.prevObject:this.prevObject.filter(e))}}),v.fn.andSelf=v.fn.addBack,v.each({parent:function(e){var t=e.parentNode;return t&&t.nodeType!==11?t:null},parents:function(e){return v.dir(e,"parentNode")},parentsUntil:function(e,t,n){return v.dir(e,"parentNode",n)},next:function(e){return at(e,"nextSibling")},prev:function(e){return at(e,"previousSibling")},nextAll:function(e){return v.dir(e,"nextSibling")},prevAll:function(e){return v.dir(e,"previousSibling")},nextUntil:function(e,t,n){return v.dir(e,"nextSibling",n)},prevUntil:function(e,t,n){return v.dir(e,"previousSibling",n)},siblings:function(e){return v.sibling((e.parentNode||{}).firstChild,e)},children:function(e){return v.sibling(e.firstChild)},contents:function(e){return v.nodeName(e,"iframe")?e.contentDocument||e.contentWindow.document:v.merge([],e.childNodes)}},function(e,t){v.fn[e]=function(n,r){var i=v.map(this,t,n);return nt.test(e)||(r=n),r&&typeof r=="string"&&(i=v.filter(r,i)),i=this.length>1&&!ot[e]?v.unique(i):i,this.length>1&&rt.test(e)&&(i=i.reverse()),this.pushStack(i,e,l.call(arguments).join(","))}}),v.extend({filter:function(e,t,n){return n&&(e=":not("+e+")"),t.length===1?v.find.matchesSelector(t[0],e)?[t[0]]:[]:v.find.matches(e,t)},dir:function(e,n,r){var i=[],s=e[n];while(s&&s.nodeType!==9&&(r===t||s.nodeType!==1||!v(s).is(r)))s.nodeType===1&&i.push(s),s=s[n];return i},sibling:function(e,t){var n=[];for(;e;e=e.nextSibling)e.nodeType===1&&e!==t&&n.push(e);return n}});var ct="abbr|article|aside|audio|bdi|canvas|data|datalist|details|figcaption|figure|footer|header|hgroup|mark|meter|nav|output|progress|section|summary|time|video",ht=/ jQuery\d+="(?:null|\d+)"/g,pt=/^\s+/,dt=/<(?!area|br|col|embed|hr|img|input|link|meta|param)(([\w:]+)[^>]*)\/>/gi,vt=/<([\w:]+)/,mt=/]","i"),Et=/^(?:checkbox|radio)$/,St=/checked\s*(?:[^=]|=\s*.checked.)/i,xt=/\/(java|ecma)script/i,Tt=/^\s*\s*$/g,Nt={option:[1,""],legend:[1,"
","
"],thead:[1,"","
"],tr:[2,"","
"],td:[3,"","
"],col:[2,"","
"],area:[1,"",""],_default:[0,"",""]},Ct=lt(i),kt=Ct.appendChild(i.createElement("div"));Nt.optgroup=Nt.option,Nt.tbody=Nt.tfoot=Nt.colgroup=Nt.caption=Nt.thead,Nt.th=Nt.td,v.support.htmlSerialize||(Nt._default=[1,"X
","
"]),v.fn.extend({text:function(e){return v.access(this,function(e){return e===t?v.text(this):this.empty().append((this[0]&&this[0].ownerDocument||i).createTextNode(e))},null,e,arguments.length)},wrapAll:function(e){if(v.isFunction(e))return this.each(function(t){v(this).wrapAll(e.call(this,t))});if(this[0]){var t=v(e,this[0].ownerDocument).eq(0).clone(!0);this[0].parentNode&&t.insertBefore(this[0]),t.map(function(){var e=this;while(e.firstChild&&e.firstChild.nodeType===1)e=e.firstChild;return e}).append(this)}return this},wrapInner:function(e){return v.isFunction(e)?this.each(function(t){v(this).wrapInner(e.call(this,t))}):this.each(function(){var t=v(this),n=t.contents();n.length?n.wrapAll(e):t.append(e)})},wrap:function(e){var t=v.isFunction(e);return this.each(function(n){v(this).wrapAll(t?e.call(this,n):e)})},unwrap:function(){return this.parent().each(function(){v.nodeName(this,"body")||v(this).replaceWith(this.childNodes)}).end()},append:function(){return this.domManip(arguments,!0,function(e){(this.nodeType===1||this.nodeType===11)&&this.appendChild(e)})},prepend:function(){return this.domManip(arguments,!0,function(e){(this.nodeType===1||this.nodeType===11)&&this.insertBefore(e,this.firstChild)})},before:function(){if(!ut(this[0]))return this.domManip(arguments,!1,function(e){this.parentNode.insertBefore(e,this)});if(arguments.length){var e=v.clean(arguments);return this.pushStack(v.merge(e,this),"before",this.selector)}},after:function(){if(!ut(this[0]))return this.domManip(arguments,!1,function(e){this.parentNode.insertBefore(e,this.nextSibling)});if(arguments.length){var e=v.clean(arguments);return this.pushStack(v.merge(this,e),"after",this.selector)}},remove:function(e,t){var n,r=0;for(;(n=this[r])!=null;r++)if(!e||v.filter(e,[n]).length)!t&&n.nodeType===1&&(v.cleanData(n.getElementsByTagName("*")),v.cleanData([n])),n.parentNode&&n.parentNode.removeChild(n);return this},empty:function(){var e,t=0;for(;(e=this[t])!=null;t++){e.nodeType===1&&v.cleanData(e.getElementsByTagName("*"));while(e.firstChild)e.removeChild(e.firstChild)}return this},clone:function(e,t){return e=e==null?!1:e,t=t==null?e:t,this.map(function(){return v.clone(this,e,t)})},html:function(e){return v.access(this,function(e){var n=this[0]||{},r=0,i=this.length;if(e===t)return n.nodeType===1?n.innerHTML.replace(ht,""):t;if(typeof e=="string"&&!yt.test(e)&&(v.support.htmlSerialize||!wt.test(e))&&(v.support.leadingWhitespace||!pt.test(e))&&!Nt[(vt.exec(e)||["",""])[1].toLowerCase()]){e=e.replace(dt,"<$1>");try{for(;r1&&typeof f=="string"&&St.test(f))return this.each(function(){v(this).domManip(e,n,r)});if(v.isFunction(f))return this.each(function(i){var s=v(this);e[0]=f.call(this,i,n?s.html():t),s.domManip(e,n,r)});if(this[0]){i=v.buildFragment(e,this,l),o=i.fragment,s=o.firstChild,o.childNodes.length===1&&(o=s);if(s){n=n&&v.nodeName(s,"tr");for(u=i.cacheable||c-1;a0?this.clone(!0):this).get(),v(o[i])[t](r),s=s.concat(r);return this.pushStack(s,e,o.selector)}}),v.extend({clone:function(e,t,n){var r,i,s,o;v.support.html5Clone||v.isXMLDoc(e)||!wt.test("<"+e.nodeName+">")?o=e.cloneNode(!0):(kt.innerHTML=e.outerHTML,kt.removeChild(o=kt.firstChild));if((!v.support.noCloneEvent||!v.support.noCloneChecked)&&(e.nodeType===1||e.nodeType===11)&&!v.isXMLDoc(e)){Ot(e,o),r=Mt(e),i=Mt(o);for(s=0;r[s];++s)i[s]&&Ot(r[s],i[s])}if(t){At(e,o);if(n){r=Mt(e),i=Mt(o);for(s=0;r[s];++s)At(r[s],i[s])}}return r=i=null,o},clean:function(e,t,n,r){var s,o,u,a,f,l,c,h,p,d,m,g,y=t===i&&Ct,b=[];if(!t||typeof t.createDocumentFragment=="undefined")t=i;for(s=0;(u=e[s])!=null;s++){typeof u=="number"&&(u+="");if(!u)continue;if(typeof u=="string")if(!gt.test(u))u=t.createTextNode(u);else{y=y||lt(t),c=t.createElement("div"),y.appendChild(c),u=u.replace(dt,"<$1>"),a=(vt.exec(u)||["",""])[1].toLowerCase(),f=Nt[a]||Nt._default,l=f[0],c.innerHTML=f[1]+u+f[2];while(l--)c=c.lastChild;if(!v.support.tbody){h=mt.test(u),p=a==="table"&&!h?c.firstChild&&c.firstChild.childNodes:f[1]===""&&!h?c.childNodes:[];for(o=p.length-1;o>=0;--o)v.nodeName(p[o],"tbody")&&!p[o].childNodes.length&&p[o].parentNode.removeChild(p[o])}!v.support.leadingWhitespace&&pt.test(u)&&c.insertBefore(t.createTextNode(pt.exec(u)[0]),c.firstChild),u=c.childNodes,c.parentNode.removeChild(c)}u.nodeType?b.push(u):v.merge(b,u)}c&&(u=c=y=null);if(!v.support.appendChecked)for(s=0;(u=b[s])!=null;s++)v.nodeName(u,"input")?_t(u):typeof u.getElementsByTagName!="undefined"&&v.grep(u.getElementsByTagName("input"),_t);if(n){m=function(e){if(!e.type||xt.test(e.type))return r?r.push(e.parentNode?e.parentNode.removeChild(e):e):n.appendChild(e)};for(s=0;(u=b[s])!=null;s++)if(!v.nodeName(u,"script")||!m(u))n.appendChild(u),typeof u.getElementsByTagName!="undefined"&&(g=v.grep(v.merge([],u.getElementsByTagName("script")),m),b.splice.apply(b,[s+1,0].concat(g)),s+=g.length)}return b},cleanData:function(e,t){var n,r,i,s,o=0,u=v.expando,a=v.cache,f=v.support.deleteExpando,l=v.event.special;for(;(i=e[o])!=null;o++)if(t||v.acceptData(i)){r=i[u],n=r&&a[r];if(n){if(n.events)for(s in n.events)l[s]?v.event.remove(i,s):v.removeEvent(i,s,n.handle);a[r]&&(delete a[r],f?delete i[u]:i.removeAttribute?i.removeAttribute(u):i[u]=null,v.deletedIds.push(r))}}}}),function(){var e,t;v.uaMatch=function(e){e=e.toLowerCase();var t=/(chrome)[ \/]([\w.]+)/.exec(e)||/(webkit)[ \/]([\w.]+)/.exec(e)||/(opera)(?:.*version|)[ \/]([\w.]+)/.exec(e)||/(msie) ([\w.]+)/.exec(e)||e.indexOf("compatible")<0&&/(mozilla)(?:.*? rv:([\w.]+)|)/.exec(e)||[];return{browser:t[1]||"",version:t[2]||"0"}},e=v.uaMatch(o.userAgent),t={},e.browser&&(t[e.browser]=!0,t.version=e.version),t.chrome?t.webkit=!0:t.webkit&&(t.safari=!0),v.browser=t,v.sub=function(){function e(t,n){return new e.fn.init(t,n)}v.extend(!0,e,this),e.superclass=this,e.fn=e.prototype=this(),e.fn.constructor=e,e.sub=this.sub,e.fn.init=function(r,i){return i&&i instanceof v&&!(i instanceof e)&&(i=e(i)),v.fn.init.call(this,r,i,t)},e.fn.init.prototype=e.fn;var t=e(i);return e}}();var Dt,Pt,Ht,Bt=/alpha\([^)]*\)/i,jt=/opacity=([^)]*)/,Ft=/^(top|right|bottom|left)$/,It=/^(none|table(?!-c[ea]).+)/,qt=/^margin/,Rt=new RegExp("^("+m+")(.*)$","i"),Ut=new RegExp("^("+m+")(?!px)[a-z%]+$","i"),zt=new RegExp("^([-+])=("+m+")","i"),Wt={BODY:"block"},Xt={position:"absolute",visibility:"hidden",display:"block"},Vt={letterSpacing:0,fontWeight:400},$t=["Top","Right","Bottom","Left"],Jt=["Webkit","O","Moz","ms"],Kt=v.fn.toggle;v.fn.extend({css:function(e,n){return v.access(this,function(e,n,r){return r!==t?v.style(e,n,r):v.css(e,n)},e,n,arguments.length>1)},show:function(){return Yt(this,!0)},hide:function(){return Yt(this)},toggle:function(e,t){var n=typeof e=="boolean";return v.isFunction(e)&&v.isFunction(t)?Kt.apply(this,arguments):this.each(function(){(n?e:Gt(this))?v(this).show():v(this).hide()})}}),v.extend({cssHooks:{opacity:{get:function(e,t){if(t){var n=Dt(e,"opacity");return n===""?"1":n}}}},cssNumber:{fillOpacity:!0,fontWeight:!0,lineHeight:!0,opacity:!0,orphans:!0,widows:!0,zIndex:!0,zoom:!0},cssProps:{"float":v.support.cssFloat?"cssFloat":"styleFloat"},style:function(e,n,r,i){if(!e||e.nodeType===3||e.nodeType===8||!e.style)return;var s,o,u,a=v.camelCase(n),f=e.style;n=v.cssProps[a]||(v.cssProps[a]=Qt(f,a)),u=v.cssHooks[n]||v.cssHooks[a];if(r===t)return u&&"get"in u&&(s=u.get(e,!1,i))!==t?s:f[n];o=typeof r,o==="string"&&(s=zt.exec(r))&&(r=(s[1]+1)*s[2]+parseFloat(v.css(e,n)),o="number");if(r==null||o==="number"&&isNaN(r))return;o==="number"&&!v.cssNumber[a]&&(r+="px");if(!u||!("set"in u)||(r=u.set(e,r,i))!==t)try{f[n]=r}catch(l){}},css:function(e,n,r,i){var s,o,u,a=v.camelCase(n);return n=v.cssProps[a]||(v.cssProps[a]=Qt(e.style,a)),u=v.cssHooks[n]||v.cssHooks[a],u&&"get"in u&&(s=u.get(e,!0,i)),s===t&&(s=Dt(e,n)),s==="normal"&&n in Vt&&(s=Vt[n]),r||i!==t?(o=parseFloat(s),r||v.isNumeric(o)?o||0:s):s},swap:function(e,t,n){var r,i,s={};for(i in t)s[i]=e.style[i],e.style[i]=t[i];r=n.call(e);for(i in t)e.style[i]=s[i];return r}}),e.getComputedStyle?Dt=function(t,n){var r,i,s,o,u=e.getComputedStyle(t,null),a=t.style;return u&&(r=u.getPropertyValue(n)||u[n],r===""&&!v.contains(t.ownerDocument,t)&&(r=v.style(t,n)),Ut.test(r)&&qt.test(n)&&(i=a.width,s=a.minWidth,o=a.maxWidth,a.minWidth=a.maxWidth=a.width=r,r=u.width,a.width=i,a.minWidth=s,a.maxWidth=o)),r}:i.documentElement.currentStyle&&(Dt=function(e,t){var n,r,i=e.currentStyle&&e.currentStyle[t],s=e.style;return i==null&&s&&s[t]&&(i=s[t]),Ut.test(i)&&!Ft.test(t)&&(n=s.left,r=e.runtimeStyle&&e.runtimeStyle.left,r&&(e.runtimeStyle.left=e.currentStyle.left),s.left=t==="fontSize"?"1em":i,i=s.pixelLeft+"px",s.left=n,r&&(e.runtimeStyle.left=r)),i===""?"auto":i}),v.each(["height","width"],function(e,t){v.cssHooks[t]={get:function(e,n,r){if(n)return e.offsetWidth===0&&It.test(Dt(e,"display"))?v.swap(e,Xt,function(){return tn(e,t,r)}):tn(e,t,r)},set:function(e,n,r){return Zt(e,n,r?en(e,t,r,v.support.boxSizing&&v.css(e,"boxSizing")==="border-box"):0)}}}),v.support.opacity||(v.cssHooks.opacity={get:function(e,t){return jt.test((t&&e.currentStyle?e.currentStyle.filter:e.style.filter)||"")?.01*parseFloat(RegExp.$1)+"":t?"1":""},set:function(e,t){var n=e.style,r=e.currentStyle,i=v.isNumeric(t)?"alpha(opacity="+t*100+")":"",s=r&&r.filter||n.filter||"";n.zoom=1;if(t>=1&&v.trim(s.replace(Bt,""))===""&&n.removeAttribute){n.removeAttribute("filter");if(r&&!r.filter)return}n.filter=Bt.test(s)?s.replace(Bt,i):s+" "+i}}),v(function(){v.support.reliableMarginRight||(v.cssHooks.marginRight={get:function(e,t){return v.swap(e,{display:"inline-block"},function(){if(t)return Dt(e,"marginRight")})}}),!v.support.pixelPosition&&v.fn.position&&v.each(["top","left"],function(e,t){v.cssHooks[t]={get:function(e,n){if(n){var r=Dt(e,t);return Ut.test(r)?v(e).position()[t]+"px":r}}}})}),v.expr&&v.expr.filters&&(v.expr.filters.hidden=function(e){return e.offsetWidth===0&&e.offsetHeight===0||!v.support.reliableHiddenOffsets&&(e.style&&e.style.display||Dt(e,"display"))==="none"},v.expr.filters.visible=function(e){return!v.expr.filters.hidden(e)}),v.each({margin:"",padding:"",border:"Width"},function(e,t){v.cssHooks[e+t]={expand:function(n){var r,i=typeof n=="string"?n.split(" "):[n],s={};for(r=0;r<4;r++)s[e+$t[r]+t]=i[r]||i[r-2]||i[0];return s}},qt.test(e)||(v.cssHooks[e+t].set=Zt)});var rn=/%20/g,sn=/\[\]$/,on=/\r?\n/g,un=/^(?:color|date|datetime|datetime-local|email|hidden|month|number|password|range|search|tel|text|time|url|week)$/i,an=/^(?:select|textarea)/i;v.fn.extend({serialize:function(){return v.param(this.serializeArray())},serializeArray:function(){return this.map(function(){return this.elements?v.makeArray(this.elements):this}).filter(function(){return this.name&&!this.disabled&&(this.checked||an.test(this.nodeName)||un.test(this.type))}).map(function(e,t){var n=v(this).val();return n==null?null:v.isArray(n)?v.map(n,function(e,n){return{name:t.name,value:e.replace(on,"\r\n")}}):{name:t.name,value:n.replace(on,"\r\n")}}).get()}}),v.param=function(e,n){var r,i=[],s=function(e,t){t=v.isFunction(t)?t():t==null?"":t,i[i.length]=encodeURIComponent(e)+"="+encodeURIComponent(t)};n===t&&(n=v.ajaxSettings&&v.ajaxSettings.traditional);if(v.isArray(e)||e.jquery&&!v.isPlainObject(e))v.each(e,function(){s(this.name,this.value)});else for(r in e)fn(r,e[r],n,s);return i.join("&").replace(rn,"+")};var ln,cn,hn=/#.*$/,pn=/^(.*?):[ \t]*([^\r\n]*)\r?$/mg,dn=/^(?:about|app|app\-storage|.+\-extension|file|res|widget):$/,vn=/^(?:GET|HEAD)$/,mn=/^\/\//,gn=/\?/,yn=/)<[^<]*)*<\/script>/gi,bn=/([?&])_=[^&]*/,wn=/^([\w\+\.\-]+:)(?:\/\/([^\/?#:]*)(?::(\d+)|)|)/,En=v.fn.load,Sn={},xn={},Tn=["*/"]+["*"];try{cn=s.href}catch(Nn){cn=i.createElement("a"),cn.href="",cn=cn.href}ln=wn.exec(cn.toLowerCase())||[],v.fn.load=function(e,n,r){if(typeof e!="string"&&En)return En.apply(this,arguments);if(!this.length)return this;var i,s,o,u=this,a=e.indexOf(" ");return a>=0&&(i=e.slice(a,e.length),e=e.slice(0,a)),v.isFunction(n)?(r=n,n=t):n&&typeof n=="object"&&(s="POST"),v.ajax({url:e,type:s,dataType:"html",data:n,complete:function(e,t){r&&u.each(r,o||[e.responseText,t,e])}}).done(function(e){o=arguments,u.html(i?v("
").append(e.replace(yn,"")).find(i):e)}),this},v.each("ajaxStart ajaxStop ajaxComplete ajaxError ajaxSuccess ajaxSend".split(" "),function(e,t){v.fn[t]=function(e){return this.on(t,e)}}),v.each(["get","post"],function(e,n){v[n]=function(e,r,i,s){return v.isFunction(r)&&(s=s||i,i=r,r=t),v.ajax({type:n,url:e,data:r,success:i,dataType:s})}}),v.extend({getScript:function(e,n){return v.get(e,t,n,"script")},getJSON:function(e,t,n){return v.get(e,t,n,"json")},ajaxSetup:function(e,t){return t?Ln(e,v.ajaxSettings):(t=e,e=v.ajaxSettings),Ln(e,t),e},ajaxSettings:{url:cn,isLocal:dn.test(ln[1]),global:!0,type:"GET",contentType:"application/x-www-form-urlencoded; charset=UTF-8",processData:!0,async:!0,accepts:{xml:"application/xml, text/xml",html:"text/html",text:"text/plain",json:"application/json, text/javascript","*":Tn},contents:{xml:/xml/,html:/html/,json:/json/},responseFields:{xml:"responseXML",text:"responseText"},converters:{"* text":e.String,"text html":!0,"text json":v.parseJSON,"text xml":v.parseXML},flatOptions:{context:!0,url:!0}},ajaxPrefilter:Cn(Sn),ajaxTransport:Cn(xn),ajax:function(e,n){function T(e,n,s,a){var l,y,b,w,S,T=n;if(E===2)return;E=2,u&&clearTimeout(u),o=t,i=a||"",x.readyState=e>0?4:0,s&&(w=An(c,x,s));if(e>=200&&e<300||e===304)c.ifModified&&(S=x.getResponseHeader("Last-Modified"),S&&(v.lastModified[r]=S),S=x.getResponseHeader("Etag"),S&&(v.etag[r]=S)),e===304?(T="notmodified",l=!0):(l=On(c,w),T=l.state,y=l.data,b=l.error,l=!b);else{b=T;if(!T||e)T="error",e<0&&(e=0)}x.status=e,x.statusText=(n||T)+"",l?d.resolveWith(h,[y,T,x]):d.rejectWith(h,[x,T,b]),x.statusCode(g),g=t,f&&p.trigger("ajax"+(l?"Success":"Error"),[x,c,l?y:b]),m.fireWith(h,[x,T]),f&&(p.trigger("ajaxComplete",[x,c]),--v.active||v.event.trigger("ajaxStop"))}typeof e=="object"&&(n=e,e=t),n=n||{};var r,i,s,o,u,a,f,l,c=v.ajaxSetup({},n),h=c.context||c,p=h!==c&&(h.nodeType||h instanceof v)?v(h):v.event,d=v.Deferred(),m=v.Callbacks("once memory"),g=c.statusCode||{},b={},w={},E=0,S="canceled",x={readyState:0,setRequestHeader:function(e,t){if(!E){var n=e.toLowerCase();e=w[n]=w[n]||e,b[e]=t}return this},getAllResponseHeaders:function(){return E===2?i:null},getResponseHeader:function(e){var n;if(E===2){if(!s){s={};while(n=pn.exec(i))s[n[1].toLowerCase()]=n[2]}n=s[e.toLowerCase()]}return n===t?null:n},overrideMimeType:function(e){return E||(c.mimeType=e),this},abort:function(e){return e=e||S,o&&o.abort(e),T(0,e),this}};d.promise(x),x.success=x.done,x.error=x.fail,x.complete=m.add,x.statusCode=function(e){if(e){var t;if(E<2)for(t in e)g[t]=[g[t],e[t]];else t=e[x.status],x.always(t)}return this},c.url=((e||c.url)+"").replace(hn,"").replace(mn,ln[1]+"//"),c.dataTypes=v.trim(c.dataType||"*").toLowerCase().split(y),c.crossDomain==null&&(a=wn.exec(c.url.toLowerCase()),c.crossDomain=!(!a||a[1]===ln[1]&&a[2]===ln[2]&&(a[3]||(a[1]==="http:"?80:443))==(ln[3]||(ln[1]==="http:"?80:443)))),c.data&&c.processData&&typeof c.data!="string"&&(c.data=v.param(c.data,c.traditional)),kn(Sn,c,n,x);if(E===2)return x;f=c.global,c.type=c.type.toUpperCase(),c.hasContent=!vn.test(c.type),f&&v.active++===0&&v.event.trigger("ajaxStart");if(!c.hasContent){c.data&&(c.url+=(gn.test(c.url)?"&":"?")+c.data,delete c.data),r=c.url;if(c.cache===!1){var N=v.now(),C=c.url.replace(bn,"$1_="+N);c.url=C+(C===c.url?(gn.test(c.url)?"&":"?")+"_="+N:"")}}(c.data&&c.hasContent&&c.contentType!==!1||n.contentType)&&x.setRequestHeader("Content-Type",c.contentType),c.ifModified&&(r=r||c.url,v.lastModified[r]&&x.setRequestHeader("If-Modified-Since",v.lastModified[r]),v.etag[r]&&x.setRequestHeader("If-None-Match",v.etag[r])),x.setRequestHeader("Accept",c.dataTypes[0]&&c.accepts[c.dataTypes[0]]?c.accepts[c.dataTypes[0]]+(c.dataTypes[0]!=="*"?", "+Tn+"; q=0.01":""):c.accepts["*"]);for(l in c.headers)x.setRequestHeader(l,c.headers[l]);if(!c.beforeSend||c.beforeSend.call(h,x,c)!==!1&&E!==2){S="abort";for(l in{success:1,error:1,complete:1})x[l](c[l]);o=kn(xn,c,n,x);if(!o)T(-1,"No Transport");else{x.readyState=1,f&&p.trigger("ajaxSend",[x,c]),c.async&&c.timeout>0&&(u=setTimeout(function(){x.abort("timeout")},c.timeout));try{E=1,o.send(b,T)}catch(k){if(!(E<2))throw k;T(-1,k)}}return x}return x.abort()},active:0,lastModified:{},etag:{}});var Mn=[],_n=/\?/,Dn=/(=)\?(?=&|$)|\?\?/,Pn=v.now();v.ajaxSetup({jsonp:"callback",jsonpCallback:function(){var e=Mn.pop()||v.expando+"_"+Pn++;return this[e]=!0,e}}),v.ajaxPrefilter("json jsonp",function(n,r,i){var s,o,u,a=n.data,f=n.url,l=n.jsonp!==!1,c=l&&Dn.test(f),h=l&&!c&&typeof a=="string"&&!(n.contentType||"").indexOf("application/x-www-form-urlencoded")&&Dn.test(a);if(n.dataTypes[0]==="jsonp"||c||h)return s=n.jsonpCallback=v.isFunction(n.jsonpCallback)?n.jsonpCallback():n.jsonpCallback,o=e[s],c?n.url=f.replace(Dn,"$1"+s):h?n.data=a.replace(Dn,"$1"+s):l&&(n.url+=(_n.test(f)?"&":"?")+n.jsonp+"="+s),n.converters["script json"]=function(){return u||v.error(s+" was not called"),u[0]},n.dataTypes[0]="json",e[s]=function(){u=arguments},i.always(function(){e[s]=o,n[s]&&(n.jsonpCallback=r.jsonpCallback,Mn.push(s)),u&&v.isFunction(o)&&o(u[0]),u=o=t}),"script"}),v.ajaxSetup({accepts:{script:"text/javascript, application/javascript, application/ecmascript, application/x-ecmascript"},contents:{script:/javascript|ecmascript/},converters:{"text script":function(e){return v.globalEval(e),e}}}),v.ajaxPrefilter("script",function(e){e.cache===t&&(e.cache=!1),e.crossDomain&&(e.type="GET",e.global=!1)}),v.ajaxTransport("script",function(e){if(e.crossDomain){var n,r=i.head||i.getElementsByTagName("head")[0]||i.documentElement;return{send:function(s,o){n=i.createElement("script"),n.async="async",e.scriptCharset&&(n.charset=e.scriptCharset),n.src=e.url,n.onload=n.onreadystatechange=function(e,i){if(i||!n.readyState||/loaded|complete/.test(n.readyState))n.onload=n.onreadystatechange=null,r&&n.parentNode&&r.removeChild(n),n=t,i||o(200,"success")},r.insertBefore(n,r.firstChild)},abort:function(){n&&n.onload(0,1)}}}});var Hn,Bn=e.ActiveXObject?function(){for(var e in Hn)Hn[e](0,1)}:!1,jn=0;v.ajaxSettings.xhr=e.ActiveXObject?function(){return!this.isLocal&&Fn()||In()}:Fn,function(e){v.extend(v.support,{ajax:!!e,cors:!!e&&"withCredentials"in e})}(v.ajaxSettings.xhr()),v.support.ajax&&v.ajaxTransport(function(n){if(!n.crossDomain||v.support.cors){var r;return{send:function(i,s){var o,u,a=n.xhr();n.username?a.open(n.type,n.url,n.async,n.username,n.password):a.open(n.type,n.url,n.async);if(n.xhrFields)for(u in n.xhrFields)a[u]=n.xhrFields[u];n.mimeType&&a.overrideMimeType&&a.overrideMimeType(n.mimeType),!n.crossDomain&&!i["X-Requested-With"]&&(i["X-Requested-With"]="XMLHttpRequest");try{for(u in i)a.setRequestHeader(u,i[u])}catch(f){}a.send(n.hasContent&&n.data||null),r=function(e,i){var u,f,l,c,h;try{if(r&&(i||a.readyState===4)){r=t,o&&(a.onreadystatechange=v.noop,Bn&&delete Hn[o]);if(i)a.readyState!==4&&a.abort();else{u=a.status,l=a.getAllResponseHeaders(),c={},h=a.responseXML,h&&h.documentElement&&(c.xml=h);try{c.text=a.responseText}catch(p){}try{f=a.statusText}catch(p){f=""}!u&&n.isLocal&&!n.crossDomain?u=c.text?200:404:u===1223&&(u=204)}}}catch(d){i||s(-1,d)}c&&s(u,f,c,l)},n.async?a.readyState===4?setTimeout(r,0):(o=++jn,Bn&&(Hn||(Hn={},v(e).unload(Bn)),Hn[o]=r),a.onreadystatechange=r):r()},abort:function(){r&&r(0,1)}}}});var qn,Rn,Un=/^(?:toggle|show|hide)$/,zn=new RegExp("^(?:([-+])=|)("+m+")([a-z%]*)$","i"),Wn=/queueHooks$/,Xn=[Gn],Vn={"*":[function(e,t){var n,r,i=this.createTween(e,t),s=zn.exec(t),o=i.cur(),u=+o||0,a=1,f=20;if(s){n=+s[2],r=s[3]||(v.cssNumber[e]?"":"px");if(r!=="px"&&u){u=v.css(i.elem,e,!0)||n||1;do a=a||".5",u/=a,v.style(i.elem,e,u+r);while(a!==(a=i.cur()/o)&&a!==1&&--f)}i.unit=r,i.start=u,i.end=s[1]?u+(s[1]+1)*n:n}return i}]};v.Animation=v.extend(Kn,{tweener:function(e,t){v.isFunction(e)?(t=e,e=["*"]):e=e.split(" ");var n,r=0,i=e.length;for(;r-1,f={},l={},c,h;a?(l=i.position(),c=l.top,h=l.left):(c=parseFloat(o)||0,h=parseFloat(u)||0),v.isFunction(t)&&(t=t.call(e,n,s)),t.top!=null&&(f.top=t.top-s.top+c),t.left!=null&&(f.left=t.left-s.left+h),"using"in t?t.using.call(e,f):i.css(f)}},v.fn.extend({position:function(){if(!this[0])return;var e=this[0],t=this.offsetParent(),n=this.offset(),r=er.test(t[0].nodeName)?{top:0,left:0}:t.offset();return n.top-=parseFloat(v.css(e,"marginTop"))||0,n.left-=parseFloat(v.css(e,"marginLeft"))||0,r.top+=parseFloat(v.css(t[0],"borderTopWidth"))||0,r.left+=parseFloat(v.css(t[0],"borderLeftWidth"))||0,{top:n.top-r.top,left:n.left-r.left}},offsetParent:function(){return this.map(function(){var e=this.offsetParent||i.body;while(e&&!er.test(e.nodeName)&&v.css(e,"position")==="static")e=e.offsetParent;return e||i.body})}}),v.each({scrollLeft:"pageXOffset",scrollTop:"pageYOffset"},function(e,n){var r=/Y/.test(n);v.fn[e]=function(i){return v.access(this,function(e,i,s){var o=tr(e);if(s===t)return o?n in o?o[n]:o.document.documentElement[i]:e[i];o?o.scrollTo(r?v(o).scrollLeft():s,r?s:v(o).scrollTop()):e[i]=s},e,i,arguments.length,null)}}),v.each({Height:"height",Width:"width"},function(e,n){v.each({padding:"inner"+e,content:n,"":"outer"+e},function(r,i){v.fn[i]=function(i,s){var o=arguments.length&&(r||typeof i!="boolean"),u=r||(i===!0||s===!0?"margin":"border");return v.access(this,function(n,r,i){var s;return v.isWindow(n)?n.document.documentElement["client"+e]:n.nodeType===9?(s=n.documentElement,Math.max(n.body["scroll"+e],s["scroll"+e],n.body["offset"+e],s["offset"+e],s["client"+e])):i===t?v.css(n,r,i,u):v.style(n,r,i,u)},n,o?i:t,o,null)}})}),e.jQuery=e.$=v,typeof define=="function"&&define.amd&&define.amd.jQuery&&define("jquery",[],function(){return v})})(window); \ No newline at end of file diff --git a/vis_resources/jsutils.js b/vis_resources/jsutils.js new file mode 100644 index 0000000..5e4fb65 --- /dev/null +++ b/vis_resources/jsutils.js @@ -0,0 +1,38 @@ +var randperm = function(n) { + var i = n, + j = 0, + temp; + var array = []; + for(var q=0;qu;u++)if(t.call(e,n[u],u,n)===r)return}else for(var a=j.keys(n),u=0,i=a.length;i>u;u++)if(t.call(e,n[a[u]],a[u],n)===r)return;return n};j.map=j.collect=function(n,t,r){var e=[];return null==n?e:p&&n.map===p?n.map(t,r):(A(n,function(n,u,i){e.push(t.call(r,n,u,i))}),e)};var O="Reduce of empty array with no initial value";j.reduce=j.foldl=j.inject=function(n,t,r,e){var u=arguments.length>2;if(null==n&&(n=[]),h&&n.reduce===h)return e&&(t=j.bind(t,e)),u?n.reduce(t,r):n.reduce(t);if(A(n,function(n,i,a){u?r=t.call(e,r,n,i,a):(r=n,u=!0)}),!u)throw new TypeError(O);return r},j.reduceRight=j.foldr=function(n,t,r,e){var u=arguments.length>2;if(null==n&&(n=[]),v&&n.reduceRight===v)return e&&(t=j.bind(t,e)),u?n.reduceRight(t,r):n.reduceRight(t);var i=n.length;if(i!==+i){var a=j.keys(n);i=a.length}if(A(n,function(o,c,l){c=a?a[--i]:--i,u?r=t.call(e,r,n[c],c,l):(r=n[c],u=!0)}),!u)throw new TypeError(O);return r},j.find=j.detect=function(n,t,r){var e;return k(n,function(n,u,i){return t.call(r,n,u,i)?(e=n,!0):void 0}),e},j.filter=j.select=function(n,t,r){var e=[];return null==n?e:g&&n.filter===g?n.filter(t,r):(A(n,function(n,u,i){t.call(r,n,u,i)&&e.push(n)}),e)},j.reject=function(n,t,r){return j.filter(n,function(n,e,u){return!t.call(r,n,e,u)},r)},j.every=j.all=function(n,t,e){t||(t=j.identity);var u=!0;return null==n?u:d&&n.every===d?n.every(t,e):(A(n,function(n,i,a){return(u=u&&t.call(e,n,i,a))?void 0:r}),!!u)};var k=j.some=j.any=function(n,t,e){t||(t=j.identity);var u=!1;return null==n?u:m&&n.some===m?n.some(t,e):(A(n,function(n,i,a){return u||(u=t.call(e,n,i,a))?r:void 0}),!!u)};j.contains=j.include=function(n,t){return null==n?!1:y&&n.indexOf===y?n.indexOf(t)!=-1:k(n,function(n){return n===t})},j.invoke=function(n,t){var r=o.call(arguments,2),e=j.isFunction(t);return j.map(n,function(n){return(e?t:n[t]).apply(n,r)})},j.pluck=function(n,t){return j.map(n,j.property(t))},j.where=function(n,t){return j.filter(n,j.matches(t))},j.findWhere=function(n,t){return j.find(n,j.matches(t))},j.max=function(n,t,r){if(!t&&j.isArray(n)&&n[0]===+n[0]&&n.length<65535)return Math.max.apply(Math,n);var e=-1/0,u=-1/0;return A(n,function(n,i,a){var o=t?t.call(r,n,i,a):n;o>u&&(e=n,u=o)}),e},j.min=function(n,t,r){if(!t&&j.isArray(n)&&n[0]===+n[0]&&n.length<65535)return Math.min.apply(Math,n);var e=1/0,u=1/0;return A(n,function(n,i,a){var o=t?t.call(r,n,i,a):n;u>o&&(e=n,u=o)}),e},j.shuffle=function(n){var t,r=0,e=[];return A(n,function(n){t=j.random(r++),e[r-1]=e[t],e[t]=n}),e},j.sample=function(n,t,r){return null==t||r?(n.length!==+n.length&&(n=j.values(n)),n[j.random(n.length-1)]):j.shuffle(n).slice(0,Math.max(0,t))};var E=function(n){return null==n?j.identity:j.isFunction(n)?n:j.property(n)};j.sortBy=function(n,t,r){return t=E(t),j.pluck(j.map(n,function(n,e,u){return{value:n,index:e,criteria:t.call(r,n,e,u)}}).sort(function(n,t){var r=n.criteria,e=t.criteria;if(r!==e){if(r>e||r===void 0)return 1;if(e>r||e===void 0)return-1}return n.index-t.index}),"value")};var F=function(n){return function(t,r,e){var u={};return r=E(r),A(t,function(i,a){var o=r.call(e,i,a,t);n(u,o,i)}),u}};j.groupBy=F(function(n,t,r){j.has(n,t)?n[t].push(r):n[t]=[r]}),j.indexBy=F(function(n,t,r){n[t]=r}),j.countBy=F(function(n,t){j.has(n,t)?n[t]++:n[t]=1}),j.sortedIndex=function(n,t,r,e){r=E(r);for(var u=r.call(e,t),i=0,a=n.length;a>i;){var o=i+a>>>1;r.call(e,n[o])t?[]:o.call(n,0,t)},j.initial=function(n,t,r){return o.call(n,0,n.length-(null==t||r?1:t))},j.last=function(n,t,r){return null==n?void 0:null==t||r?n[n.length-1]:o.call(n,Math.max(n.length-t,0))},j.rest=j.tail=j.drop=function(n,t,r){return o.call(n,null==t||r?1:t)},j.compact=function(n){return j.filter(n,j.identity)};var M=function(n,t,r){return t&&j.every(n,j.isArray)?c.apply(r,n):(A(n,function(n){j.isArray(n)||j.isArguments(n)?t?a.apply(r,n):M(n,t,r):r.push(n)}),r)};j.flatten=function(n,t){return M(n,t,[])},j.without=function(n){return j.difference(n,o.call(arguments,1))},j.partition=function(n,t){var r=[],e=[];return A(n,function(n){(t(n)?r:e).push(n)}),[r,e]},j.uniq=j.unique=function(n,t,r,e){j.isFunction(t)&&(e=r,r=t,t=!1);var u=r?j.map(n,r,e):n,i=[],a=[];return A(u,function(r,e){(t?e&&a[a.length-1]===r:j.contains(a,r))||(a.push(r),i.push(n[e]))}),i},j.union=function(){return j.uniq(j.flatten(arguments,!0))},j.intersection=function(n){var t=o.call(arguments,1);return j.filter(j.uniq(n),function(n){return j.every(t,function(t){return j.contains(t,n)})})},j.difference=function(n){var t=c.apply(e,o.call(arguments,1));return j.filter(n,function(n){return!j.contains(t,n)})},j.zip=function(){for(var n=j.max(j.pluck(arguments,"length").concat(0)),t=new Array(n),r=0;n>r;r++)t[r]=j.pluck(arguments,""+r);return t},j.object=function(n,t){if(null==n)return{};for(var r={},e=0,u=n.length;u>e;e++)t?r[n[e]]=t[e]:r[n[e][0]]=n[e][1];return r},j.indexOf=function(n,t,r){if(null==n)return-1;var e=0,u=n.length;if(r){if("number"!=typeof r)return e=j.sortedIndex(n,t),n[e]===t?e:-1;e=0>r?Math.max(0,u+r):r}if(y&&n.indexOf===y)return n.indexOf(t,r);for(;u>e;e++)if(n[e]===t)return e;return-1},j.lastIndexOf=function(n,t,r){if(null==n)return-1;var e=null!=r;if(b&&n.lastIndexOf===b)return e?n.lastIndexOf(t,r):n.lastIndexOf(t);for(var u=e?r:n.length;u--;)if(n[u]===t)return u;return-1},j.range=function(n,t,r){arguments.length<=1&&(t=n||0,n=0),r=arguments[2]||1;for(var e=Math.max(Math.ceil((t-n)/r),0),u=0,i=new Array(e);e>u;)i[u++]=n,n+=r;return i};var R=function(){};j.bind=function(n,t){var r,e;if(_&&n.bind===_)return _.apply(n,o.call(arguments,1));if(!j.isFunction(n))throw new TypeError;return r=o.call(arguments,2),e=function(){if(!(this instanceof e))return n.apply(t,r.concat(o.call(arguments)));R.prototype=n.prototype;var u=new R;R.prototype=null;var i=n.apply(u,r.concat(o.call(arguments)));return Object(i)===i?i:u}},j.partial=function(n){var t=o.call(arguments,1);return function(){for(var r=0,e=t.slice(),u=0,i=e.length;i>u;u++)e[u]===j&&(e[u]=arguments[r++]);for(;r=f?(clearTimeout(a),a=null,o=l,i=n.apply(e,u),e=u=null):a||r.trailing===!1||(a=setTimeout(c,f)),i}},j.debounce=function(n,t,r){var e,u,i,a,o,c=function(){var l=j.now()-a;t>l?e=setTimeout(c,t-l):(e=null,r||(o=n.apply(i,u),i=u=null))};return function(){i=this,u=arguments,a=j.now();var l=r&&!e;return e||(e=setTimeout(c,t)),l&&(o=n.apply(i,u),i=u=null),o}},j.once=function(n){var t,r=!1;return function(){return r?t:(r=!0,t=n.apply(this,arguments),n=null,t)}},j.wrap=function(n,t){return j.partial(t,n)},j.compose=function(){var n=arguments;return function(){for(var t=arguments,r=n.length-1;r>=0;r--)t=[n[r].apply(this,t)];return t[0]}},j.after=function(n,t){return function(){return--n<1?t.apply(this,arguments):void 0}},j.keys=function(n){if(!j.isObject(n))return[];if(w)return w(n);var t=[];for(var r in n)j.has(n,r)&&t.push(r);return t},j.values=function(n){for(var t=j.keys(n),r=t.length,e=new Array(r),u=0;r>u;u++)e[u]=n[t[u]];return e},j.pairs=function(n){for(var t=j.keys(n),r=t.length,e=new Array(r),u=0;r>u;u++)e[u]=[t[u],n[t[u]]];return e},j.invert=function(n){for(var t={},r=j.keys(n),e=0,u=r.length;u>e;e++)t[n[r[e]]]=r[e];return t},j.functions=j.methods=function(n){var t=[];for(var r in n)j.isFunction(n[r])&&t.push(r);return t.sort()},j.extend=function(n){return A(o.call(arguments,1),function(t){if(t)for(var r in t)n[r]=t[r]}),n},j.pick=function(n){var t={},r=c.apply(e,o.call(arguments,1));return A(r,function(r){r in n&&(t[r]=n[r])}),t},j.omit=function(n){var t={},r=c.apply(e,o.call(arguments,1));for(var u in n)j.contains(r,u)||(t[u]=n[u]);return t},j.defaults=function(n){return A(o.call(arguments,1),function(t){if(t)for(var r in t)n[r]===void 0&&(n[r]=t[r])}),n},j.clone=function(n){return j.isObject(n)?j.isArray(n)?n.slice():j.extend({},n):n},j.tap=function(n,t){return t(n),n};var S=function(n,t,r,e){if(n===t)return 0!==n||1/n==1/t;if(null==n||null==t)return n===t;n instanceof j&&(n=n._wrapped),t instanceof j&&(t=t._wrapped);var u=l.call(n);if(u!=l.call(t))return!1;switch(u){case"[object String]":return n==String(t);case"[object Number]":return n!=+n?t!=+t:0==n?1/n==1/t:n==+t;case"[object Date]":case"[object Boolean]":return+n==+t;case"[object RegExp]":return n.source==t.source&&n.global==t.global&&n.multiline==t.multiline&&n.ignoreCase==t.ignoreCase}if("object"!=typeof n||"object"!=typeof t)return!1;for(var i=r.length;i--;)if(r[i]==n)return e[i]==t;var a=n.constructor,o=t.constructor;if(a!==o&&!(j.isFunction(a)&&a instanceof a&&j.isFunction(o)&&o instanceof o)&&"constructor"in n&&"constructor"in t)return!1;r.push(n),e.push(t);var c=0,f=!0;if("[object Array]"==u){if(c=n.length,f=c==t.length)for(;c--&&(f=S(n[c],t[c],r,e)););}else{for(var s in n)if(j.has(n,s)&&(c++,!(f=j.has(t,s)&&S(n[s],t[s],r,e))))break;if(f){for(s in t)if(j.has(t,s)&&!c--)break;f=!c}}return r.pop(),e.pop(),f};j.isEqual=function(n,t){return S(n,t,[],[])},j.isEmpty=function(n){if(null==n)return!0;if(j.isArray(n)||j.isString(n))return 0===n.length;for(var t in n)if(j.has(n,t))return!1;return!0},j.isElement=function(n){return!(!n||1!==n.nodeType)},j.isArray=x||function(n){return"[object Array]"==l.call(n)},j.isObject=function(n){return n===Object(n)},A(["Arguments","Function","String","Number","Date","RegExp"],function(n){j["is"+n]=function(t){return l.call(t)=="[object "+n+"]"}}),j.isArguments(arguments)||(j.isArguments=function(n){return!(!n||!j.has(n,"callee"))}),"function"!=typeof/./&&(j.isFunction=function(n){return"function"==typeof n}),j.isFinite=function(n){return isFinite(n)&&!isNaN(parseFloat(n))},j.isNaN=function(n){return j.isNumber(n)&&n!=+n},j.isBoolean=function(n){return n===!0||n===!1||"[object Boolean]"==l.call(n)},j.isNull=function(n){return null===n},j.isUndefined=function(n){return n===void 0},j.has=function(n,t){return f.call(n,t)},j.noConflict=function(){return n._=t,this},j.identity=function(n){return n},j.constant=function(n){return function(){return n}},j.property=function(n){return function(t){return t[n]}},j.matches=function(n){return function(t){if(t===n)return!0;for(var r in n)if(n[r]!==t[r])return!1;return!0}},j.times=function(n,t,r){for(var e=Array(Math.max(0,n)),u=0;n>u;u++)e[u]=t.call(r,u);return e},j.random=function(n,t){return null==t&&(t=n,n=0),n+Math.floor(Math.random()*(t-n+1))},j.now=Date.now||function(){return(new Date).getTime()};var T={escape:{"&":"&","<":"<",">":">",'"':""","'":"'"}};T.unescape=j.invert(T.escape);var I={escape:new RegExp("["+j.keys(T.escape).join("")+"]","g"),unescape:new RegExp("("+j.keys(T.unescape).join("|")+")","g")};j.each(["escape","unescape"],function(n){j[n]=function(t){return null==t?"":(""+t).replace(I[n],function(t){return T[n][t]})}}),j.result=function(n,t){if(null==n)return void 0;var r=n[t];return j.isFunction(r)?r.call(n):r},j.mixin=function(n){A(j.functions(n),function(t){var r=j[t]=n[t];j.prototype[t]=function(){var n=[this._wrapped];return a.apply(n,arguments),z.call(this,r.apply(j,n))}})};var N=0;j.uniqueId=function(n){var t=++N+"";return n?n+t:t},j.templateSettings={evaluate:/<%([\s\S]+?)%>/g,interpolate:/<%=([\s\S]+?)%>/g,escape:/<%-([\s\S]+?)%>/g};var q=/(.)^/,B={"'":"'","\\":"\\","\r":"r","\n":"n"," ":"t","\u2028":"u2028","\u2029":"u2029"},D=/\\|'|\r|\n|\t|\u2028|\u2029/g;j.template=function(n,t,r){var e;r=j.defaults({},r,j.templateSettings);var u=new RegExp([(r.escape||q).source,(r.interpolate||q).source,(r.evaluate||q).source].join("|")+"|$","g"),i=0,a="__p+='";n.replace(u,function(t,r,e,u,o){return a+=n.slice(i,o).replace(D,function(n){return"\\"+B[n]}),r&&(a+="'+\n((__t=("+r+"))==null?'':_.escape(__t))+\n'"),e&&(a+="'+\n((__t=("+e+"))==null?'':__t)+\n'"),u&&(a+="';\n"+u+"\n__p+='"),i=o+t.length,t}),a+="';\n",r.variable||(a="with(obj||{}){\n"+a+"}\n"),a="var __t,__p='',__j=Array.prototype.join,"+"print=function(){__p+=__j.call(arguments,'');};\n"+a+"return __p;\n";try{e=new Function(r.variable||"obj","_",a)}catch(o){throw o.source=a,o}if(t)return e(t,j);var c=function(n){return e.call(this,n,j)};return c.source="function("+(r.variable||"obj")+"){\n"+a+"}",c},j.chain=function(n){return j(n).chain()};var z=function(n){return this._chain?j(n).chain():n};j.mixin(j),A(["pop","push","reverse","shift","sort","splice","unshift"],function(n){var t=e[n];j.prototype[n]=function(){var r=this._wrapped;return t.apply(r,arguments),"shift"!=n&&"splice"!=n||0!==r.length||delete r[0],z.call(this,r)}}),A(["concat","join","slice"],function(n){var t=e[n];j.prototype[n]=function(){return z.call(this,t.apply(this._wrapped,arguments))}}),j.extend(j.prototype,{chain:function(){return this._chain=!0,this},value:function(){return this._wrapped}}),"function"==typeof define&&define.amd&&define("underscore",[],function(){return j})}).call(this); +//# sourceMappingURL=underscore-min.map \ No newline at end of file diff --git a/vis_resources/underscore-min.map b/vis_resources/underscore-min.map new file mode 100644 index 0000000..73c951e --- /dev/null +++ b/vis_resources/underscore-min.map @@ -0,0 +1 @@ +{"version":3,"file":"underscore-min.js","sources":["underscore.js"],"names":["root","this","previousUnderscore","_","ArrayProto","Array","prototype","ObjProto","Object","FuncProto","Function","push","slice","concat","toString","hasOwnProperty","nativeIsArray","isArray","nativeKeys","keys","nativeBind","bind","obj","_wrapped","exports","module","VERSION","createCallback","func","context","argCount","value","call","other","index","collection","accumulator","apply","arguments","iteratee","identity","isFunction","isObject","matches","property","each","forEach","i","length","map","collect","currentKey","results","reduceError","reduce","foldl","inject","memo","TypeError","reduceRight","foldr","find","detect","predicate","result","some","list","filter","select","reject","negate","every","all","any","contains","include","target","values","indexOf","invoke","method","args","isFunc","pluck","key","where","attrs","findWhere","max","computed","Infinity","lastComputed","min","shuffle","rand","set","shuffled","random","sample","n","guard","Math","sortBy","criteria","sort","left","right","a","b","group","behavior","groupBy","has","indexBy","countBy","sortedIndex","array","low","high","mid","toArray","size","partition","pass","fail","first","head","take","initial","last","rest","tail","drop","compact","flatten","input","shallow","strict","output","isArguments","without","difference","uniq","unique","isSorted","isBoolean","seen","union","intersection","argsLength","item","j","zip","object","lastIndexOf","from","idx","range","start","stop","step","ceil","Ctor","bound","self","partial","boundArgs","position","bindAll","Error","memoize","hasher","cache","address","delay","wait","setTimeout","defer","throttle","options","timeout","previous","later","leading","now","remaining","clearTimeout","trailing","debounce","immediate","timestamp","callNow","wrap","wrapper","compose","after","times","before","once","pairs","invert","functions","methods","names","extend","source","prop","pick","omit","String","defaults","clone","tap","interceptor","eq","aStack","bStack","className","aCtor","constructor","bCtor","pop","isEqual","isEmpty","isString","isElement","nodeType","type","name","isFinite","isNaN","parseFloat","isNumber","isNull","isUndefined","noConflict","constant","noop","pair","accum","floor","Date","getTime","escapeMap","&","<",">","\"","'","`","unescapeMap","createEscaper","escaper","match","join","testRegexp","RegExp","replaceRegexp","string","test","replace","escape","unescape","idCounter","uniqueId","prefix","id","templateSettings","evaluate","interpolate","noMatch","escapes","\\","\r","\n","
","
","escapeChar","template","text","settings","oldSettings","matcher","offset","variable","render","e","data","argument","chain","instance","_chain","mixin","define","amd"],"mappings":";;;;CAKC,WAMC,GAAIA,GAAOC,KAGPC,EAAqBF,EAAKG,EAG1BC,EAAaC,MAAMC,UAAWC,EAAWC,OAAOF,UAAWG,EAAYC,SAASJ,UAIlFK,EAAmBP,EAAWO,KAC9BC,EAAmBR,EAAWQ,MAC9BC,EAAmBT,EAAWS,OAC9BC,EAAmBP,EAASO,SAC5BC,EAAmBR,EAASQ,eAK5BC,EAAqBX,MAAMY,QAC3BC,EAAqBV,OAAOW,KAC5BC,EAAqBX,EAAUY,KAG7BlB,EAAI,SAASmB,GACf,MAAIA,aAAenB,GAAUmB,EACvBrB,eAAgBE,QACtBF,KAAKsB,SAAWD,GADiB,GAAInB,GAAEmB,GAOlB,oBAAZE,UACa,mBAAXC,SAA0BA,OAAOD,UAC1CA,QAAUC,OAAOD,QAAUrB,GAE7BqB,QAAQrB,EAAIA,GAEZH,EAAKG,EAAIA,EAIXA,EAAEuB,QAAU,OAKZ,IAAIC,GAAiB,SAASC,EAAMC,EAASC,GAC3C,GAAID,QAAiB,GAAG,MAAOD,EAC/B,QAAoB,MAAZE,EAAmB,EAAIA,GAC7B,IAAK,GAAG,MAAO,UAASC,GACtB,MAAOH,GAAKI,KAAKH,EAASE,GAE5B,KAAK,GAAG,MAAO,UAASA,EAAOE,GAC7B,MAAOL,GAAKI,KAAKH,EAASE,EAAOE,GAEnC,KAAK,GAAG,MAAO,UAASF,EAAOG,EAAOC,GACpC,MAAOP,GAAKI,KAAKH,EAASE,EAAOG,EAAOC,GAE1C,KAAK,GAAG,MAAO,UAASC,EAAaL,EAAOG,EAAOC,GACjD,MAAOP,GAAKI,KAAKH,EAASO,EAAaL,EAAOG,EAAOC,IAGzD,MAAO,YACL,MAAOP,GAAKS,MAAMR,EAASS,YAO/BnC,GAAEoC,SAAW,SAASR,EAAOF,EAASC,GACpC,MAAa,OAATC,EAAsB5B,EAAEqC,SACxBrC,EAAEsC,WAAWV,GAAeJ,EAAeI,EAAOF,EAASC,GAC3D3B,EAAEuC,SAASX,GAAe5B,EAAEwC,QAAQZ,GACjC5B,EAAEyC,SAASb,IASpB5B,EAAE0C,KAAO1C,EAAE2C,QAAU,SAASxB,EAAKiB,EAAUV,GAC3C,GAAW,MAAPP,EAAa,MAAOA,EACxBiB,GAAWZ,EAAeY,EAAUV,EACpC,IAAIkB,GAAGC,EAAS1B,EAAI0B,MACpB,IAAIA,KAAYA,EACd,IAAKD,EAAI,EAAOC,EAAJD,EAAYA,IACtBR,EAASjB,EAAIyB,GAAIA,EAAGzB,OAEjB,CACL,GAAIH,GAAOhB,EAAEgB,KAAKG,EAClB,KAAKyB,EAAI,EAAGC,EAAS7B,EAAK6B,OAAYA,EAAJD,EAAYA,IAC5CR,EAASjB,EAAIH,EAAK4B,IAAK5B,EAAK4B,GAAIzB,GAGpC,MAAOA,IAITnB,EAAE8C,IAAM9C,EAAE+C,QAAU,SAAS5B,EAAKiB,EAAUV,GAC1C,GAAW,MAAPP,EAAa,QACjBiB,GAAWpC,EAAEoC,SAASA,EAAUV,EAKhC,KAAK,GADDsB,GAHAhC,EAAOG,EAAI0B,UAAY1B,EAAI0B,QAAU7C,EAAEgB,KAAKG,GAC5C0B,GAAU7B,GAAQG,GAAK0B,OACvBI,EAAU/C,MAAM2C,GAEXd,EAAQ,EAAWc,EAARd,EAAgBA,IAClCiB,EAAahC,EAAOA,EAAKe,GAASA,EAClCkB,EAAQlB,GAASK,EAASjB,EAAI6B,GAAaA,EAAY7B,EAEzD,OAAO8B,GAGT,IAAIC,GAAc,6CAIlBlD,GAAEmD,OAASnD,EAAEoD,MAAQpD,EAAEqD,OAAS,SAASlC,EAAKiB,EAAUkB,EAAM5B,GACjD,MAAPP,IAAaA,MACjBiB,EAAWZ,EAAeY,EAAUV,EAAS,EAC7C,IAEesB,GAFXhC,EAAOG,EAAI0B,UAAY1B,EAAI0B,QAAU7C,EAAEgB,KAAKG,GAC5C0B,GAAU7B,GAAQG,GAAK0B,OACvBd,EAAQ,CACZ,IAAII,UAAUU,OAAS,EAAG,CACxB,IAAKA,EAAQ,KAAM,IAAIU,WAAUL,EACjCI,GAAOnC,EAAIH,EAAOA,EAAKe,KAAWA,KAEpC,KAAec,EAARd,EAAgBA,IACrBiB,EAAahC,EAAOA,EAAKe,GAASA,EAClCuB,EAAOlB,EAASkB,EAAMnC,EAAI6B,GAAaA,EAAY7B,EAErD,OAAOmC,IAITtD,EAAEwD,YAAcxD,EAAEyD,MAAQ,SAAStC,EAAKiB,EAAUkB,EAAM5B,GAC3C,MAAPP,IAAaA,MACjBiB,EAAWZ,EAAeY,EAAUV,EAAS,EAC7C,IAEIsB,GAFAhC,EAAOG,EAAI0B,UAAa1B,EAAI0B,QAAU7C,EAAEgB,KAAKG,GAC7CY,GAASf,GAAQG,GAAK0B,MAE1B,IAAIV,UAAUU,OAAS,EAAG,CACxB,IAAKd,EAAO,KAAM,IAAIwB,WAAUL,EAChCI,GAAOnC,EAAIH,EAAOA,IAAOe,KAAWA,GAEtC,KAAOA,KACLiB,EAAahC,EAAOA,EAAKe,GAASA,EAClCuB,EAAOlB,EAASkB,EAAMnC,EAAI6B,GAAaA,EAAY7B,EAErD,OAAOmC,IAITtD,EAAE0D,KAAO1D,EAAE2D,OAAS,SAASxC,EAAKyC,EAAWlC,GAC3C,GAAImC,EAQJ,OAPAD,GAAY5D,EAAEoC,SAASwB,EAAWlC,GAClC1B,EAAE8D,KAAK3C,EAAK,SAASS,EAAOG,EAAOgC,GACjC,MAAIH,GAAUhC,EAAOG,EAAOgC,IAC1BF,EAASjC,GACF,GAFT,SAKKiC,GAKT7D,EAAEgE,OAAShE,EAAEiE,OAAS,SAAS9C,EAAKyC,EAAWlC,GAC7C,GAAIuB,KACJ,OAAW,OAAP9B,EAAoB8B,GACxBW,EAAY5D,EAAEoC,SAASwB,EAAWlC,GAClC1B,EAAE0C,KAAKvB,EAAK,SAASS,EAAOG,EAAOgC,GAC7BH,EAAUhC,EAAOG,EAAOgC,IAAOd,EAAQzC,KAAKoB,KAE3CqB,IAITjD,EAAEkE,OAAS,SAAS/C,EAAKyC,EAAWlC,GAClC,MAAO1B,GAAEgE,OAAO7C,EAAKnB,EAAEmE,OAAOnE,EAAEoC,SAASwB,IAAalC,IAKxD1B,EAAEoE,MAAQpE,EAAEqE,IAAM,SAASlD,EAAKyC,EAAWlC,GACzC,GAAW,MAAPP,EAAa,OAAO,CACxByC,GAAY5D,EAAEoC,SAASwB,EAAWlC,EAClC,IAEIK,GAAOiB,EAFPhC,EAAOG,EAAI0B,UAAY1B,EAAI0B,QAAU7C,EAAEgB,KAAKG,GAC5C0B,GAAU7B,GAAQG,GAAK0B,MAE3B,KAAKd,EAAQ,EAAWc,EAARd,EAAgBA,IAE9B,GADAiB,EAAahC,EAAOA,EAAKe,GAASA,GAC7B6B,EAAUzC,EAAI6B,GAAaA,EAAY7B,GAAM,OAAO,CAE3D,QAAO,GAKTnB,EAAE8D,KAAO9D,EAAEsE,IAAM,SAASnD,EAAKyC,EAAWlC,GACxC,GAAW,MAAPP,EAAa,OAAO,CACxByC,GAAY5D,EAAEoC,SAASwB,EAAWlC,EAClC,IAEIK,GAAOiB,EAFPhC,EAAOG,EAAI0B,UAAY1B,EAAI0B,QAAU7C,EAAEgB,KAAKG,GAC5C0B,GAAU7B,GAAQG,GAAK0B,MAE3B,KAAKd,EAAQ,EAAWc,EAARd,EAAgBA,IAE9B,GADAiB,EAAahC,EAAOA,EAAKe,GAASA,EAC9B6B,EAAUzC,EAAI6B,GAAaA,EAAY7B,GAAM,OAAO,CAE1D,QAAO,GAKTnB,EAAEuE,SAAWvE,EAAEwE,QAAU,SAASrD,EAAKsD,GACrC,MAAW,OAAPtD,GAAoB,GACpBA,EAAI0B,UAAY1B,EAAI0B,SAAQ1B,EAAMnB,EAAE0E,OAAOvD,IACxCnB,EAAE2E,QAAQxD,EAAKsD,IAAW,IAInCzE,EAAE4E,OAAS,SAASzD,EAAK0D,GACvB,GAAIC,GAAOrE,EAAMoB,KAAKM,UAAW,GAC7B4C,EAAS/E,EAAEsC,WAAWuC,EAC1B,OAAO7E,GAAE8C,IAAI3B,EAAK,SAASS,GACzB,OAAQmD,EAASF,EAASjD,EAAMiD,IAAS3C,MAAMN,EAAOkD,MAK1D9E,EAAEgF,MAAQ,SAAS7D,EAAK8D,GACtB,MAAOjF,GAAE8C,IAAI3B,EAAKnB,EAAEyC,SAASwC,KAK/BjF,EAAEkF,MAAQ,SAAS/D,EAAKgE,GACtB,MAAOnF,GAAEgE,OAAO7C,EAAKnB,EAAEwC,QAAQ2C,KAKjCnF,EAAEoF,UAAY,SAASjE,EAAKgE,GAC1B,MAAOnF,GAAE0D,KAAKvC,EAAKnB,EAAEwC,QAAQ2C,KAI/BnF,EAAEqF,IAAM,SAASlE,EAAKiB,EAAUV,GAC9B,GACIE,GAAO0D,EADPzB,GAAU0B,IAAUC,GAAgBD,GAExC,IAAgB,MAAZnD,GAA2B,MAAPjB,EAAa,CACnCA,EAAMA,EAAI0B,UAAY1B,EAAI0B,OAAS1B,EAAMnB,EAAE0E,OAAOvD,EAClD,KAAK,GAAIyB,GAAI,EAAGC,EAAS1B,EAAI0B,OAAYA,EAAJD,EAAYA,IAC/ChB,EAAQT,EAAIyB,GACRhB,EAAQiC,IACVA,EAASjC,OAIbQ,GAAWpC,EAAEoC,SAASA,EAAUV,GAChC1B,EAAE0C,KAAKvB,EAAK,SAASS,EAAOG,EAAOgC,GACjCuB,EAAWlD,EAASR,EAAOG,EAAOgC,IAC9BuB,EAAWE,GAAgBF,KAAcC,KAAY1B,KAAY0B,OACnE1B,EAASjC,EACT4D,EAAeF,IAIrB,OAAOzB,IAIT7D,EAAEyF,IAAM,SAAStE,EAAKiB,EAAUV,GAC9B,GACIE,GAAO0D,EADPzB,EAAS0B,IAAUC,EAAeD,GAEtC,IAAgB,MAAZnD,GAA2B,MAAPjB,EAAa,CACnCA,EAAMA,EAAI0B,UAAY1B,EAAI0B,OAAS1B,EAAMnB,EAAE0E,OAAOvD,EAClD,KAAK,GAAIyB,GAAI,EAAGC,EAAS1B,EAAI0B,OAAYA,EAAJD,EAAYA,IAC/ChB,EAAQT,EAAIyB,GACAiB,EAARjC,IACFiC,EAASjC,OAIbQ,GAAWpC,EAAEoC,SAASA,EAAUV,GAChC1B,EAAE0C,KAAKvB,EAAK,SAASS,EAAOG,EAAOgC,GACjCuB,EAAWlD,EAASR,EAAOG,EAAOgC,IACnByB,EAAXF,GAAwCC,MAAbD,GAAoCC,MAAX1B,KACtDA,EAASjC,EACT4D,EAAeF,IAIrB,OAAOzB,IAKT7D,EAAE0F,QAAU,SAASvE,GAInB,IAAK,GAAewE,GAHhBC,EAAMzE,GAAOA,EAAI0B,UAAY1B,EAAI0B,OAAS1B,EAAMnB,EAAE0E,OAAOvD,GACzD0B,EAAS+C,EAAI/C,OACbgD,EAAW3F,MAAM2C,GACZd,EAAQ,EAAiBc,EAARd,EAAgBA,IACxC4D,EAAO3F,EAAE8F,OAAO,EAAG/D,GACf4D,IAAS5D,IAAO8D,EAAS9D,GAAS8D,EAASF,IAC/CE,EAASF,GAAQC,EAAI7D,EAEvB,OAAO8D,IAMT7F,EAAE+F,OAAS,SAAS5E,EAAK6E,EAAGC,GAC1B,MAAS,OAALD,GAAaC,GACX9E,EAAI0B,UAAY1B,EAAI0B,SAAQ1B,EAAMnB,EAAE0E,OAAOvD,IACxCA,EAAInB,EAAE8F,OAAO3E,EAAI0B,OAAS,KAE5B7C,EAAE0F,QAAQvE,GAAKV,MAAM,EAAGyF,KAAKb,IAAI,EAAGW,KAI7ChG,EAAEmG,OAAS,SAAShF,EAAKiB,EAAUV,GAEjC,MADAU,GAAWpC,EAAEoC,SAASA,EAAUV,GACzB1B,EAAEgF,MAAMhF,EAAE8C,IAAI3B,EAAK,SAASS,EAAOG,EAAOgC,GAC/C,OACEnC,MAAOA,EACPG,MAAOA,EACPqE,SAAUhE,EAASR,EAAOG,EAAOgC,MAElCsC,KAAK,SAASC,EAAMC,GACrB,GAAIC,GAAIF,EAAKF,SACTK,EAAIF,EAAMH,QACd,IAAII,IAAMC,EAAG,CACX,GAAID,EAAIC,GAAKD,QAAW,GAAG,MAAO,EAClC,IAAQC,EAAJD,GAASC,QAAW,GAAG,OAAQ,EAErC,MAAOH,GAAKvE,MAAQwE,EAAMxE,QACxB,SAIN,IAAI2E,GAAQ,SAASC,GACnB,MAAO,UAASxF,EAAKiB,EAAUV,GAC7B,GAAImC,KAMJ,OALAzB,GAAWpC,EAAEoC,SAASA,EAAUV,GAChC1B,EAAE0C,KAAKvB,EAAK,SAASS,EAAOG,GAC1B,GAAIkD,GAAM7C,EAASR,EAAOG,EAAOZ,EACjCwF,GAAS9C,EAAQjC,EAAOqD,KAEnBpB,GAMX7D,GAAE4G,QAAUF,EAAM,SAAS7C,EAAQjC,EAAOqD,GACpCjF,EAAE6G,IAAIhD,EAAQoB,GAAMpB,EAAOoB,GAAKzE,KAAKoB,GAAaiC,EAAOoB,IAAQrD,KAKvE5B,EAAE8G,QAAUJ,EAAM,SAAS7C,EAAQjC,EAAOqD,GACxCpB,EAAOoB,GAAOrD,IAMhB5B,EAAE+G,QAAUL,EAAM,SAAS7C,EAAQjC,EAAOqD,GACpCjF,EAAE6G,IAAIhD,EAAQoB,GAAMpB,EAAOoB,KAAapB,EAAOoB,GAAO,IAK5DjF,EAAEgH,YAAc,SAASC,EAAO9F,EAAKiB,EAAUV,GAC7CU,EAAWpC,EAAEoC,SAASA,EAAUV,EAAS,EAGzC,KAFA,GAAIE,GAAQQ,EAASjB,GACjB+F,EAAM,EAAGC,EAAOF,EAAMpE,OACbsE,EAAND,GAAY,CACjB,GAAIE,GAAMF,EAAMC,IAAS,CACrB/E,GAAS6E,EAAMG,IAAQxF,EAAOsF,EAAME,EAAM,EAAQD,EAAOC,EAE/D,MAAOF,IAITlH,EAAEqH,QAAU,SAASlG,GACnB,MAAKA,GACDnB,EAAEc,QAAQK,GAAaV,EAAMoB,KAAKV,GAClCA,EAAI0B,UAAY1B,EAAI0B,OAAe7C,EAAE8C,IAAI3B,EAAKnB,EAAEqC,UAC7CrC,EAAE0E,OAAOvD,OAIlBnB,EAAEsH,KAAO,SAASnG,GAChB,MAAW,OAAPA,EAAoB,EACjBA,EAAI0B,UAAY1B,EAAI0B,OAAS1B,EAAI0B,OAAS7C,EAAEgB,KAAKG,GAAK0B,QAK/D7C,EAAEuH,UAAY,SAASpG,EAAKyC,EAAWlC,GACrCkC,EAAY5D,EAAEoC,SAASwB,EAAWlC,EAClC,IAAI8F,MAAWC,IAIf,OAHAzH,GAAE0C,KAAKvB,EAAK,SAASS,EAAOqD,EAAK9D,IAC9ByC,EAAUhC,EAAOqD,EAAK9D,GAAOqG,EAAOC,GAAMjH,KAAKoB,MAE1C4F,EAAMC,IAShBzH,EAAE0H,MAAQ1H,EAAE2H,KAAO3H,EAAE4H,KAAO,SAASX,EAAOjB,EAAGC,GAC7C,MAAa,OAATgB,MAA2B,GACtB,MAALjB,GAAaC,EAAcgB,EAAM,GAC7B,EAAJjB,KACGvF,EAAMoB,KAAKoF,EAAO,EAAGjB,IAO9BhG,EAAE6H,QAAU,SAASZ,EAAOjB,EAAGC,GAC7B,MAAOxF,GAAMoB,KAAKoF,EAAO,EAAGf,KAAKb,IAAI,EAAG4B,EAAMpE,QAAe,MAALmD,GAAaC,EAAQ,EAAID,MAKnFhG,EAAE8H,KAAO,SAASb,EAAOjB,EAAGC,GAC1B,MAAa,OAATgB,MAA2B,GACtB,MAALjB,GAAaC,EAAcgB,EAAMA,EAAMpE,OAAS,GAC7CpC,EAAMoB,KAAKoF,EAAOf,KAAKb,IAAI4B,EAAMpE,OAASmD,EAAG,KAOtDhG,EAAE+H,KAAO/H,EAAEgI,KAAOhI,EAAEiI,KAAO,SAAShB,EAAOjB,EAAGC,GAC5C,MAAOxF,GAAMoB,KAAKoF,EAAY,MAALjB,GAAaC,EAAQ,EAAID,IAIpDhG,EAAEkI,QAAU,SAASjB,GACnB,MAAOjH,GAAEgE,OAAOiD,EAAOjH,EAAEqC,UAI3B,IAAI8F,GAAU,SAASC,EAAOC,EAASC,EAAQC,GAC7C,GAAIF,GAAWrI,EAAEoE,MAAMgE,EAAOpI,EAAEc,SAC9B,MAAOJ,GAAOwB,MAAMqG,EAAQH,EAE9B,KAAK,GAAIxF,GAAI,EAAGC,EAASuF,EAAMvF,OAAYA,EAAJD,EAAYA,IAAK,CACtD,GAAIhB,GAAQwG,EAAMxF,EACb5C,GAAEc,QAAQc,IAAW5B,EAAEwI,YAAY5G,GAE7ByG,EACT7H,EAAK0B,MAAMqG,EAAQ3G,GAEnBuG,EAAQvG,EAAOyG,EAASC,EAAQC,GAJ3BD,GAAQC,EAAO/H,KAAKoB,GAO7B,MAAO2G,GAITvI,GAAEmI,QAAU,SAASlB,EAAOoB,GAC1B,MAAOF,GAAQlB,EAAOoB,GAAS,OAIjCrI,EAAEyI,QAAU,SAASxB,GACnB,MAAOjH,GAAE0I,WAAWzB,EAAOxG,EAAMoB,KAAKM,UAAW,KAMnDnC,EAAE2I,KAAO3I,EAAE4I,OAAS,SAAS3B,EAAO4B,EAAUzG,EAAUV,GACtD,GAAa,MAATuF,EAAe,QACdjH,GAAE8I,UAAUD,KACfnH,EAAUU,EACVA,EAAWyG,EACXA,GAAW,GAEG,MAAZzG,IAAkBA,EAAWpC,EAAEoC,SAASA,EAAUV,GAGtD,KAAK,GAFDmC,MACAkF,KACKnG,EAAI,EAAGC,EAASoE,EAAMpE,OAAYA,EAAJD,EAAYA,IAAK,CACtD,GAAIhB,GAAQqF,EAAMrE,EAClB,IAAIiG,EACGjG,GAAKmG,IAASnH,GAAOiC,EAAOrD,KAAKoB,GACtCmH,EAAOnH,MACF,IAAIQ,EAAU,CACnB,GAAIkD,GAAWlD,EAASR,EAAOgB,EAAGqE,EAC9BjH,GAAE2E,QAAQoE,EAAMzD,GAAY,IAC9ByD,EAAKvI,KAAK8E,GACVzB,EAAOrD,KAAKoB,QAEL5B,GAAE2E,QAAQd,EAAQjC,GAAS,GACpCiC,EAAOrD,KAAKoB,GAGhB,MAAOiC,IAKT7D,EAAEgJ,MAAQ,WACR,MAAOhJ,GAAE2I,KAAKR,EAAQhG,WAAW,GAAM,QAKzCnC,EAAEiJ,aAAe,SAAShC,GACxB,GAAa,MAATA,EAAe,QAGnB,KAAK,GAFDpD,MACAqF,EAAa/G,UAAUU,OAClBD,EAAI,EAAGC,EAASoE,EAAMpE,OAAYA,EAAJD,EAAYA,IAAK,CACtD,GAAIuG,GAAOlC,EAAMrE,EACjB,KAAI5C,EAAEuE,SAASV,EAAQsF,GAAvB,CACA,IAAK,GAAIC,GAAI,EAAOF,EAAJE,GACTpJ,EAAEuE,SAASpC,UAAUiH,GAAID,GADAC,KAG5BA,IAAMF,GAAYrF,EAAOrD,KAAK2I,IAEpC,MAAOtF,IAKT7D,EAAE0I,WAAa,SAASzB,GACtB,GAAIc,GAAOI,EAAQ1H,EAAMoB,KAAKM,UAAW,IAAI,GAAM,KACnD,OAAOnC,GAAEgE,OAAOiD,EAAO,SAASrF,GAC9B,OAAQ5B,EAAEuE,SAASwD,EAAMnG,MAM7B5B,EAAEqJ,IAAM,SAASpC,GACf,GAAa,MAATA,EAAe,QAGnB,KAAK,GAFDpE,GAAS7C,EAAEqF,IAAIlD,UAAW,UAAUU,OACpCI,EAAU/C,MAAM2C,GACXD,EAAI,EAAOC,EAAJD,EAAYA,IAC1BK,EAAQL,GAAK5C,EAAEgF,MAAM7C,UAAWS,EAElC,OAAOK,IAMTjD,EAAEsJ,OAAS,SAASvF,EAAMW,GACxB,GAAY,MAARX,EAAc,QAElB,KAAK,GADDF,MACKjB,EAAI,EAAGC,EAASkB,EAAKlB,OAAYA,EAAJD,EAAYA,IAC5C8B,EACFb,EAAOE,EAAKnB,IAAM8B,EAAO9B,GAEzBiB,EAAOE,EAAKnB,GAAG,IAAMmB,EAAKnB,GAAG,EAGjC,OAAOiB,IAOT7D,EAAE2E,QAAU,SAASsC,EAAOkC,EAAMN,GAChC,GAAa,MAAT5B,EAAe,OAAQ,CAC3B,IAAIrE,GAAI,EAAGC,EAASoE,EAAMpE,MAC1B,IAAIgG,EAAU,CACZ,GAAuB,gBAAZA,GAIT,MADAjG,GAAI5C,EAAEgH,YAAYC,EAAOkC,GAClBlC,EAAMrE,KAAOuG,EAAOvG,GAAK,CAHhCA,GAAe,EAAXiG,EAAe3C,KAAKb,IAAI,EAAGxC,EAASgG,GAAYA,EAMxD,KAAWhG,EAAJD,EAAYA,IAAK,GAAIqE,EAAMrE,KAAOuG,EAAM,MAAOvG,EACtD,QAAQ,GAGV5C,EAAEuJ,YAAc,SAAStC,EAAOkC,EAAMK,GACpC,GAAa,MAATvC,EAAe,OAAQ,CAC3B,IAAIwC,GAAMxC,EAAMpE,MAIhB,KAHmB,gBAAR2G,KACTC,EAAa,EAAPD,EAAWC,EAAMD,EAAO,EAAItD,KAAKT,IAAIgE,EAAKD,EAAO,MAEhDC,GAAO,GAAG,GAAIxC,EAAMwC,KAASN,EAAM,MAAOM,EACnD,QAAQ,GAMVzJ,EAAE0J,MAAQ,SAASC,EAAOC,EAAMC,GAC1B1H,UAAUU,QAAU,IACtB+G,EAAOD,GAAS,EAChBA,EAAQ,GAEVE,EAAOA,GAAQ,CAKf,KAAK,GAHDhH,GAASqD,KAAKb,IAAIa,KAAK4D,MAAMF,EAAOD,GAASE,GAAO,GACpDH,EAAQxJ,MAAM2C,GAET4G,EAAM,EAAS5G,EAAN4G,EAAcA,IAAOE,GAASE,EAC9CH,EAAMD,GAAOE,CAGf,OAAOD,GAOT,IAAIK,GAAO,YAKX/J,GAAEkB,KAAO,SAASO,EAAMC,GACtB,GAAIoD,GAAMkF,CACV,IAAI/I,GAAcQ,EAAKP,OAASD,EAAY,MAAOA,GAAWiB,MAAMT,EAAMhB,EAAMoB,KAAKM,UAAW,GAChG,KAAKnC,EAAEsC,WAAWb,GAAO,KAAM,IAAI8B,WAAU,oCAW7C,OAVAuB,GAAOrE,EAAMoB,KAAKM,UAAW,GAC7B6H,EAAQ,WACN,KAAMlK,eAAgBkK,IAAQ,MAAOvI,GAAKS,MAAMR,EAASoD,EAAKpE,OAAOD,EAAMoB,KAAKM,YAChF4H,GAAK5J,UAAYsB,EAAKtB,SACtB,IAAI8J,GAAO,GAAIF,EACfA,GAAK5J,UAAY,IACjB,IAAI0D,GAASpC,EAAKS,MAAM+H,EAAMnF,EAAKpE,OAAOD,EAAMoB,KAAKM,YACrD,OAAInC,GAAEuC,SAASsB,GAAgBA,EACxBoG,IAQXjK,EAAEkK,QAAU,SAASzI,GACnB,GAAI0I,GAAY1J,EAAMoB,KAAKM,UAAW,EACtC,OAAO,YAGL,IAAK,GAFDiI,GAAW,EACXtF,EAAOqF,EAAU1J,QACZmC,EAAI,EAAGC,EAASiC,EAAKjC,OAAYA,EAAJD,EAAYA,IAC5CkC,EAAKlC,KAAO5C,IAAG8E,EAAKlC,GAAKT,UAAUiI,KAEzC,MAAOA,EAAWjI,UAAUU,QAAQiC,EAAKtE,KAAK2B,UAAUiI,KACxD,OAAO3I,GAAKS,MAAMpC,KAAMgF,KAO5B9E,EAAEqK,QAAU,SAASlJ,GACnB,GAAIyB,GAA8BqC,EAA3BpC,EAASV,UAAUU,MAC1B,IAAc,GAAVA,EAAa,KAAM,IAAIyH,OAAM,wCACjC,KAAK1H,EAAI,EAAOC,EAAJD,EAAYA,IACtBqC,EAAM9C,UAAUS,GAChBzB,EAAI8D,GAAOjF,EAAEkB,KAAKC,EAAI8D,GAAM9D,EAE9B,OAAOA,IAITnB,EAAEuK,QAAU,SAAS9I,EAAM+I,GACzB,GAAID,GAAU,SAAStF,GACrB,GAAIwF,GAAQF,EAAQE,MAChBC,EAAUF,EAASA,EAAOtI,MAAMpC,KAAMqC,WAAa8C,CAEvD,OADKjF,GAAE6G,IAAI4D,EAAOC,KAAUD,EAAMC,GAAWjJ,EAAKS,MAAMpC,KAAMqC,YACvDsI,EAAMC,GAGf,OADAH,GAAQE,SACDF,GAKTvK,EAAE2K,MAAQ,SAASlJ,EAAMmJ,GACvB,GAAI9F,GAAOrE,EAAMoB,KAAKM,UAAW,EACjC,OAAO0I,YAAW,WAChB,MAAOpJ,GAAKS,MAAM,KAAM4C,IACvB8F,IAKL5K,EAAE8K,MAAQ,SAASrJ,GACjB,MAAOzB,GAAE2K,MAAMzI,MAAMlC,GAAIyB,EAAM,GAAGf,OAAOD,EAAMoB,KAAKM,UAAW,MAQjEnC,EAAE+K,SAAW,SAAStJ,EAAMmJ,EAAMI,GAChC,GAAItJ,GAASoD,EAAMjB,EACfoH,EAAU,KACVC,EAAW,CACVF,KAASA,KACd,IAAIG,GAAQ,WACVD,EAAWF,EAAQI,WAAY,EAAQ,EAAIpL,EAAEqL,MAC7CJ,EAAU,KACVpH,EAASpC,EAAKS,MAAMR,EAASoD,GACxBmG,IAASvJ,EAAUoD,EAAO,MAEjC,OAAO,YACL,GAAIuG,GAAMrL,EAAEqL,KACPH,IAAYF,EAAQI,WAAY,IAAOF,EAAWG,EACvD,IAAIC,GAAYV,GAAQS,EAAMH,EAY9B,OAXAxJ,GAAU5B,KACVgF,EAAO3C,UACU,GAAbmJ,GAAkBA,EAAYV,GAChCW,aAAaN,GACbA,EAAU,KACVC,EAAWG,EACXxH,EAASpC,EAAKS,MAAMR,EAASoD,GACxBmG,IAASvJ,EAAUoD,EAAO,OACrBmG,GAAWD,EAAQQ,YAAa,IAC1CP,EAAUJ,WAAWM,EAAOG,IAEvBzH,IAQX7D,EAAEyL,SAAW,SAAShK,EAAMmJ,EAAMc,GAChC,GAAIT,GAASnG,EAAMpD,EAASiK,EAAW9H,EAEnCsH,EAAQ,WACV,GAAIrD,GAAO9H,EAAEqL,MAAQM,CAEVf,GAAP9C,GAAeA,EAAO,EACxBmD,EAAUJ,WAAWM,EAAOP,EAAO9C,IAEnCmD,EAAU,KACLS,IACH7H,EAASpC,EAAKS,MAAMR,EAASoD,GACxBmG,IAASvJ,EAAUoD,EAAO,QAKrC,OAAO,YACLpD,EAAU5B,KACVgF,EAAO3C,UACPwJ,EAAY3L,EAAEqL,KACd,IAAIO,GAAUF,IAAcT,CAO5B,OANKA,KAASA,EAAUJ,WAAWM,EAAOP,IACtCgB,IACF/H,EAASpC,EAAKS,MAAMR,EAASoD,GAC7BpD,EAAUoD,EAAO,MAGZjB,IAOX7D,EAAE6L,KAAO,SAASpK,EAAMqK,GACtB,MAAO9L,GAAEkK,QAAQ4B,EAASrK,IAI5BzB,EAAEmE,OAAS,SAASP,GAClB,MAAO,YACL,OAAQA,EAAU1B,MAAMpC,KAAMqC,aAMlCnC,EAAE+L,QAAU,WACV,GAAIjH,GAAO3C,UACPwH,EAAQ7E,EAAKjC,OAAS,CAC1B,OAAO,YAGL,IAFA,GAAID,GAAI+G,EACJ9F,EAASiB,EAAK6E,GAAOzH,MAAMpC,KAAMqC,WAC9BS,KAAKiB,EAASiB,EAAKlC,GAAGf,KAAK/B,KAAM+D,EACxC,OAAOA,KAKX7D,EAAEgM,MAAQ,SAASC,EAAOxK,GACxB,MAAO,YACL,QAAMwK,EAAQ,EACLxK,EAAKS,MAAMpC,KAAMqC,WAD1B,SAOJnC,EAAEkM,OAAS,SAASD,EAAOxK,GACzB,GAAI6B,EACJ,OAAO,YAML,QALM2I,EAAQ,EACZ3I,EAAO7B,EAAKS,MAAMpC,KAAMqC,WAExBV,EAAO,KAEF6B,IAMXtD,EAAEmM,KAAOnM,EAAEkK,QAAQlK,EAAEkM,OAAQ,GAO7BlM,EAAEgB,KAAO,SAASG,GAChB,IAAKnB,EAAEuC,SAASpB,GAAM,QACtB,IAAIJ,EAAY,MAAOA,GAAWI,EAClC,IAAIH,KACJ,KAAK,GAAIiE,KAAO9D,GAASnB,EAAE6G,IAAI1F,EAAK8D,IAAMjE,EAAKR,KAAKyE,EACpD,OAAOjE,IAIThB,EAAE0E,OAAS,SAASvD,GAIlB,IAAK,GAHDH,GAAOhB,EAAEgB,KAAKG,GACd0B,EAAS7B,EAAK6B,OACd6B,EAASxE,MAAM2C,GACVD,EAAI,EAAOC,EAAJD,EAAYA,IAC1B8B,EAAO9B,GAAKzB,EAAIH,EAAK4B,GAEvB,OAAO8B,IAIT1E,EAAEoM,MAAQ,SAASjL,GAIjB,IAAK,GAHDH,GAAOhB,EAAEgB,KAAKG,GACd0B,EAAS7B,EAAK6B,OACduJ,EAAQlM,MAAM2C,GACTD,EAAI,EAAOC,EAAJD,EAAYA,IAC1BwJ,EAAMxJ,IAAM5B,EAAK4B,GAAIzB,EAAIH,EAAK4B,IAEhC,OAAOwJ,IAITpM,EAAEqM,OAAS,SAASlL,GAGlB,IAAK,GAFD0C,MACA7C,EAAOhB,EAAEgB,KAAKG,GACTyB,EAAI,EAAGC,EAAS7B,EAAK6B,OAAYA,EAAJD,EAAYA,IAChDiB,EAAO1C,EAAIH,EAAK4B,KAAO5B,EAAK4B,EAE9B,OAAOiB,IAKT7D,EAAEsM,UAAYtM,EAAEuM,QAAU,SAASpL,GACjC,GAAIqL,KACJ,KAAK,GAAIvH,KAAO9D,GACVnB,EAAEsC,WAAWnB,EAAI8D,KAAOuH,EAAMhM,KAAKyE,EAEzC,OAAOuH,GAAMnG,QAIfrG,EAAEyM,OAAS,SAAStL,GAClB,IAAKnB,EAAEuC,SAASpB,GAAM,MAAOA,EAE7B,KAAK,GADDuL,GAAQC,EACH/J,EAAI,EAAGC,EAASV,UAAUU,OAAYA,EAAJD,EAAYA,IAAK,CAC1D8J,EAASvK,UAAUS,EACnB,KAAK+J,IAAQD,GACP9L,EAAeiB,KAAK6K,EAAQC,KAC5BxL,EAAIwL,GAAQD,EAAOC,IAI3B,MAAOxL,IAITnB,EAAE4M,KAAO,SAASzL,EAAKiB,EAAUV,GAC/B,GAAiBuD,GAAbpB,IACJ,IAAW,MAAP1C,EAAa,MAAO0C,EACxB,IAAI7D,EAAEsC,WAAWF,GAAW,CAC1BA,EAAWZ,EAAeY,EAAUV,EACpC,KAAKuD,IAAO9D,GAAK,CACf,GAAIS,GAAQT,EAAI8D,EACZ7C,GAASR,EAAOqD,EAAK9D,KAAM0C,EAAOoB,GAAOrD,QAE1C,CACL,GAAIZ,GAAON,EAAOwB,SAAUzB,EAAMoB,KAAKM,UAAW,GAClDhB,GAAM,GAAId,QAAOc,EACjB,KAAK,GAAIyB,GAAI,EAAGC,EAAS7B,EAAK6B,OAAYA,EAAJD,EAAYA,IAChDqC,EAAMjE,EAAK4B,GACPqC,IAAO9D,KAAK0C,EAAOoB,GAAO9D,EAAI8D,IAGtC,MAAOpB,IAIT7D,EAAE6M,KAAO,SAAS1L,EAAKiB,EAAUV,GAC/B,GAAI1B,EAAEsC,WAAWF,GACfA,EAAWpC,EAAEmE,OAAO/B,OACf,CACL,GAAIpB,GAAOhB,EAAE8C,IAAIpC,EAAOwB,SAAUzB,EAAMoB,KAAKM,UAAW,IAAK2K,OAC7D1K,GAAW,SAASR,EAAOqD,GACzB,OAAQjF,EAAEuE,SAASvD,EAAMiE,IAG7B,MAAOjF,GAAE4M,KAAKzL,EAAKiB,EAAUV,IAI/B1B,EAAE+M,SAAW,SAAS5L,GACpB,IAAKnB,EAAEuC,SAASpB,GAAM,MAAOA,EAC7B,KAAK,GAAIyB,GAAI,EAAGC,EAASV,UAAUU,OAAYA,EAAJD,EAAYA,IAAK,CAC1D,GAAI8J,GAASvK,UAAUS,EACvB,KAAK,GAAI+J,KAAQD,GACXvL,EAAIwL,SAAe,KAAGxL,EAAIwL,GAAQD,EAAOC,IAGjD,MAAOxL,IAITnB,EAAEgN,MAAQ,SAAS7L,GACjB,MAAKnB,GAAEuC,SAASpB,GACTnB,EAAEc,QAAQK,GAAOA,EAAIV,QAAUT,EAAEyM,UAAWtL,GADtBA,GAO/BnB,EAAEiN,IAAM,SAAS9L,EAAK+L,GAEpB,MADAA,GAAY/L,GACLA,EAIT,IAAIgM,GAAK,SAAS3G,EAAGC,EAAG2G,EAAQC,GAG9B,GAAI7G,IAAMC,EAAG,MAAa,KAAND,GAAW,EAAIA,IAAM,EAAIC,CAE7C,IAAS,MAALD,GAAkB,MAALC,EAAW,MAAOD,KAAMC,CAErCD,aAAaxG,KAAGwG,EAAIA,EAAEpF,UACtBqF,YAAazG,KAAGyG,EAAIA,EAAErF,SAE1B,IAAIkM,GAAY3M,EAASkB,KAAK2E,EAC9B,IAAI8G,IAAc3M,EAASkB,KAAK4E,GAAI,OAAO,CAC3C,QAAQ6G,GAEN,IAAK,kBAEL,IAAK,kBAGH,MAAO,GAAK9G,GAAM,GAAKC,CACzB,KAAK,kBAGH,OAAKD,KAAOA,GAAWC,KAAOA,EAEhB,KAAND,EAAU,GAAKA,IAAM,EAAIC,GAAKD,KAAOC,CAC/C,KAAK,gBACL,IAAK,mBAIH,OAAQD,KAAOC,EAEnB,GAAgB,gBAALD,IAA6B,gBAALC,GAAe,OAAO,CAIzD,KADA,GAAI5D,GAASuK,EAAOvK,OACbA,KAGL,GAAIuK,EAAOvK,KAAY2D,EAAG,MAAO6G,GAAOxK,KAAY4D,CAItD,IAAI8G,GAAQ/G,EAAEgH,YAAaC,EAAQhH,EAAE+G,WACrC,IACED,IAAUE,GAEV,eAAiBjH,IAAK,eAAiBC,MACrCzG,EAAEsC,WAAWiL,IAAUA,YAAiBA,IACxCvN,EAAEsC,WAAWmL,IAAUA,YAAiBA,IAE1C,OAAO,CAGTL,GAAO5M,KAAKgG,GACZ6G,EAAO7M,KAAKiG,EACZ,IAAIa,GAAMzD,CAEV,IAAkB,mBAAdyJ,GAIF,GAFAhG,EAAOd,EAAE3D,OACTgB,EAASyD,IAASb,EAAE5D,OAGlB,KAAOyE,MACCzD,EAASsJ,EAAG3G,EAAEc,GAAOb,EAAEa,GAAO8F,EAAQC,WAG3C,CAEL,GAAsBpI,GAAlBjE,EAAOhB,EAAEgB,KAAKwF,EAIlB,IAHAc,EAAOtG,EAAK6B,OAEZgB,EAAS7D,EAAEgB,KAAKyF,GAAG5D,SAAWyE,EAE5B,KAAOA,MAELrC,EAAMjE,EAAKsG,GACLzD,EAAS7D,EAAE6G,IAAIJ,EAAGxB,IAAQkI,EAAG3G,EAAEvB,GAAMwB,EAAExB,GAAMmI,EAAQC,OAOjE,MAFAD,GAAOM,MACPL,EAAOK,MACA7J,EAIT7D,GAAE2N,QAAU,SAASnH,EAAGC,GACtB,MAAO0G,GAAG3G,EAAGC,UAKfzG,EAAE4N,QAAU,SAASzM,GACnB,GAAW,MAAPA,EAAa,OAAO,CACxB,IAAInB,EAAEc,QAAQK,IAAQnB,EAAE6N,SAAS1M,IAAQnB,EAAEwI,YAAYrH,GAAM,MAAsB,KAAfA,EAAI0B,MACxE,KAAK,GAAIoC,KAAO9D,GAAK,GAAInB,EAAE6G,IAAI1F,EAAK8D,GAAM,OAAO,CACjD,QAAO,GAITjF,EAAE8N,UAAY,SAAS3M,GACrB,SAAUA,GAAwB,IAAjBA,EAAI4M,WAKvB/N,EAAEc,QAAUD,GAAiB,SAASM,GACpC,MAA8B,mBAAvBR,EAASkB,KAAKV,IAIvBnB,EAAEuC,SAAW,SAASpB,GACpB,GAAI6M,SAAc7M,EAClB,OAAgB,aAAT6M,GAAgC,WAATA,KAAuB7M,GAIvDnB,EAAE0C,MAAM,YAAa,WAAY,SAAU,SAAU,OAAQ,UAAW,SAASuL,GAC/EjO,EAAE,KAAOiO,GAAQ,SAAS9M,GACxB,MAAOR,GAASkB,KAAKV,KAAS,WAAa8M,EAAO,OAMjDjO,EAAEwI,YAAYrG,aACjBnC,EAAEwI,YAAc,SAASrH,GACvB,MAAOnB,GAAE6G,IAAI1F,EAAK,YAKH,kBAAR,MACTnB,EAAEsC,WAAa,SAASnB,GACtB,MAAqB,kBAAPA,KAAqB,IAKvCnB,EAAEkO,SAAW,SAAS/M,GACpB,MAAO+M,UAAS/M,KAASgN,MAAMC,WAAWjN,KAI5CnB,EAAEmO,MAAQ,SAAShN,GACjB,MAAOnB,GAAEqO,SAASlN,IAAQA,KAASA,GAIrCnB,EAAE8I,UAAY,SAAS3H,GACrB,MAAOA,MAAQ,GAAQA,KAAQ,GAAgC,qBAAvBR,EAASkB,KAAKV,IAIxDnB,EAAEsO,OAAS,SAASnN,GAClB,MAAe,QAARA,GAITnB,EAAEuO,YAAc,SAASpN,GACvB,MAAOA,SAAa,IAKtBnB,EAAE6G,IAAM,SAAS1F,EAAK8D,GACpB,MAAc,OAAP9D,GAAeP,EAAeiB,KAAKV,EAAK8D,IAQjDjF,EAAEwO,WAAa,WAEb,MADA3O,GAAKG,EAAID,EACFD,MAITE,EAAEqC,SAAW,SAAST,GACpB,MAAOA,IAGT5B,EAAEyO,SAAW,SAAS7M,GACpB,MAAO,YACL,MAAOA,KAIX5B,EAAE0O,KAAO,aAET1O,EAAEyC,SAAW,SAASwC,GACpB,MAAO,UAAS9D,GACd,MAAOA,GAAI8D,KAKfjF,EAAEwC,QAAU,SAAS2C,GACnB,GAAIiH,GAAQpM,EAAEoM,MAAMjH,GAAQtC,EAASuJ,EAAMvJ,MAC3C,OAAO,UAAS1B,GACd,GAAW,MAAPA,EAAa,OAAQ0B,CACzB1B,GAAM,GAAId,QAAOc,EACjB,KAAK,GAAIyB,GAAI,EAAOC,EAAJD,EAAYA,IAAK,CAC/B,GAAI+L,GAAOvC,EAAMxJ,GAAIqC,EAAM0J,EAAK,EAChC,IAAIA,EAAK,KAAOxN,EAAI8D,MAAUA,IAAO9D,IAAM,OAAO,EAEpD,OAAO,IAKXnB,EAAEiM,MAAQ,SAASjG,EAAG5D,EAAUV,GAC9B,GAAIkN,GAAQ1O,MAAMgG,KAAKb,IAAI,EAAGW,GAC9B5D,GAAWZ,EAAeY,EAAUV,EAAS,EAC7C,KAAK,GAAIkB,GAAI,EAAOoD,EAAJpD,EAAOA,IAAKgM,EAAMhM,GAAKR,EAASQ,EAChD,OAAOgM,IAIT5O,EAAE8F,OAAS,SAASL,EAAKJ,GAKvB,MAJW,OAAPA,IACFA,EAAMI,EACNA,EAAM,GAEDA,EAAMS,KAAK2I,MAAM3I,KAAKJ,UAAYT,EAAMI,EAAM,KAIvDzF,EAAEqL,IAAMyD,KAAKzD,KAAO,WAClB,OAAO,GAAIyD,OAAOC,UAIpB,IAAIC,IACFC,IAAK,QACLC,IAAK,OACLC,IAAK,OACLC,IAAK,SACLC,IAAK,SACLC,IAAK,UAEHC,EAAcvP,EAAEqM,OAAO2C,GAGvBQ,EAAgB,SAAS1M,GAC3B,GAAI2M,GAAU,SAASC,GACrB,MAAO5M,GAAI4M,IAGThD,EAAS,MAAQ1M,EAAEgB,KAAK8B,GAAK6M,KAAK,KAAO,IACzCC,EAAaC,OAAOnD,GACpBoD,EAAgBD,OAAOnD,EAAQ,IACnC,OAAO,UAASqD,GAEd,MADAA,GAAmB,MAAVA,EAAiB,GAAK,GAAKA,EAC7BH,EAAWI,KAAKD,GAAUA,EAAOE,QAAQH,EAAeL,GAAWM,GAG9E/P,GAAEkQ,OAASV,EAAcR,GACzBhP,EAAEmQ,SAAWX,EAAcD,GAI3BvP,EAAE6D,OAAS,SAASyF,EAAQ7G,GAC1B,GAAc,MAAV6G,EAAgB,WAAY,EAChC,IAAI1H,GAAQ0H,EAAO7G,EACnB,OAAOzC,GAAEsC,WAAWV,GAAS0H,EAAO7G,KAAcb,EAKpD,IAAIwO,GAAY,CAChBpQ,GAAEqQ,SAAW,SAASC,GACpB,GAAIC,KAAOH,EAAY,EACvB,OAAOE,GAASA,EAASC,EAAKA,GAKhCvQ,EAAEwQ,kBACAC,SAAc,kBACdC,YAAc,mBACdR,OAAc,mBAMhB,IAAIS,GAAU,OAIVC,GACFvB,IAAU,IACVwB,KAAU,KACVC,KAAU,IACVC,KAAU,IACVC,SAAU,QACVC,SAAU,SAGRxB,EAAU,4BAEVyB,EAAa,SAASxB,GACxB,MAAO,KAAOkB,EAAQlB,GAOxB1P,GAAEmR,SAAW,SAASC,EAAMC,EAAUC,IAC/BD,GAAYC,IAAaD,EAAWC,GACzCD,EAAWrR,EAAE+M,YAAasE,EAAUrR,EAAEwQ,iBAGtC,IAAIe,GAAU1B,SACXwB,EAASnB,QAAUS,GAASjE,QAC5B2E,EAASX,aAAeC,GAASjE,QACjC2E,EAASZ,UAAYE,GAASjE,QAC/BiD,KAAK,KAAO,KAAM,KAGhB5N,EAAQ,EACR2K,EAAS,QACb0E,GAAKnB,QAAQsB,EAAS,SAAS7B,EAAOQ,EAAQQ,EAAaD,EAAUe,GAanE,MAZA9E,IAAU0E,EAAK3Q,MAAMsB,EAAOyP,GAAQvB,QAAQR,EAASyB,GACrDnP,EAAQyP,EAAS9B,EAAM7M,OAEnBqN,EACFxD,GAAU,cAAgBwD,EAAS,iCAC1BQ,EACThE,GAAU,cAAgBgE,EAAc,uBAC/BD,IACT/D,GAAU,OAAS+D,EAAW,YAIzBf,IAEThD,GAAU,OAGL2E,EAASI,WAAU/E,EAAS,mBAAqBA,EAAS,OAE/DA,EAAS,2CACP,oDACAA,EAAS,eAEX,KACE,GAAIgF,GAAS,GAAInR,UAAS8Q,EAASI,UAAY,MAAO,IAAK/E,GAC3D,MAAOiF,GAEP,KADAA,GAAEjF,OAASA,EACLiF,EAGR,GAAIR,GAAW,SAASS,GACtB,MAAOF,GAAO7P,KAAK/B,KAAM8R,EAAM5R,IAI7B6R,EAAWR,EAASI,UAAY,KAGpC,OAFAN,GAASzE,OAAS,YAAcmF,EAAW,OAASnF,EAAS,IAEtDyE,GAITnR,EAAE8R,MAAQ,SAAS3Q,GACjB,GAAI4Q,GAAW/R,EAAEmB,EAEjB,OADA4Q,GAASC,QAAS,EACXD,EAUT,IAAIlO,GAAS,SAAS1C,GACpB,MAAOrB,MAAKkS,OAAShS,EAAEmB,GAAK2Q,QAAU3Q,EAIxCnB,GAAEiS,MAAQ,SAAS9Q,GACjBnB,EAAE0C,KAAK1C,EAAEsM,UAAUnL,GAAM,SAAS8M,GAChC,GAAIxM,GAAOzB,EAAEiO,GAAQ9M,EAAI8M,EACzBjO,GAAEG,UAAU8N,GAAQ,WAClB,GAAInJ,IAAQhF,KAAKsB,SAEjB,OADAZ,GAAK0B,MAAM4C,EAAM3C,WACV0B,EAAOhC,KAAK/B,KAAM2B,EAAKS,MAAMlC,EAAG8E,QAM7C9E,EAAEiS,MAAMjS,GAGRA,EAAE0C,MAAM,MAAO,OAAQ,UAAW,QAAS,OAAQ,SAAU,WAAY,SAASuL,GAChF,GAAIpJ,GAAS5E,EAAWgO,EACxBjO,GAAEG,UAAU8N,GAAQ,WAClB,GAAI9M,GAAMrB,KAAKsB,QAGf,OAFAyD,GAAO3C,MAAMf,EAAKgB,WACJ,UAAT8L,GAA6B,WAATA,GAAqC,IAAf9M,EAAI0B,cAAqB1B,GAAI,GACrE0C,EAAOhC,KAAK/B,KAAMqB,MAK7BnB,EAAE0C,MAAM,SAAU,OAAQ,SAAU,SAASuL,GAC3C,GAAIpJ,GAAS5E,EAAWgO,EACxBjO,GAAEG,UAAU8N,GAAQ,WAClB,MAAOpK,GAAOhC,KAAK/B,KAAM+E,EAAO3C,MAAMpC,KAAKsB,SAAUe,eAKzDnC,EAAEG,UAAUyB,MAAQ,WAClB,MAAO9B,MAAKsB,UAUQ,kBAAX8Q,SAAyBA,OAAOC,KACzCD,OAAO,gBAAkB,WACvB,MAAOlS,OAGX6B,KAAK/B"} \ No newline at end of file diff --git a/visualize_result_struct.html b/visualize_result_struct.html new file mode 100644 index 0000000..a91ac0a --- /dev/null +++ b/visualize_result_struct.html @@ -0,0 +1,183 @@ + + + + + + Image Annotation Viewer + + + + + + + + + + + + + + +
+
+
+
+
+ +