diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index cb29bb8549a..c6c6a908c02 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -53,6 +53,8 @@ jobs: run: | pip install -r requirements.txt pip install pytype + # required for test and utl directory typecheck + pip install hyperopt matplotlib seaborn - name: Install wheel shell: bash run: | @@ -64,6 +66,8 @@ jobs: shell: bash run: | python -m pytype ./python/vowpalwabbit/ --verbosity=2 + python -m pytype ./test/ --verbosity=2 + python -m pytype ./utl/ --verbosity=2 python-formatting: name: python.formatting runs-on: ubuntu-latest @@ -76,7 +80,7 @@ jobs: - run: pip install black - shell: bash run: | - python -m black --check python/vowpalwabbit || (echo -e "---\nTo fix, run:\n\tpython -m black python/vowpalwabbit"; exit 1) + python -m black --check . --exclude ext_libs/ || (echo -e "---\nTo fix, run:\n\tpython -m black . --exclude ext_libs"; exit 1) cpp-formatting: name: c++.formatting runs-on: ubuntu-20.04 diff --git a/big_tests/testCode/ocr2vw.py b/big_tests/testCode/ocr2vw.py index 63fa3ecb426..3f423a92024 100755 --- a/big_tests/testCode/ocr2vw.py +++ b/big_tests/testCode/ocr2vw.py @@ -1,63 +1,71 @@ #!/usr/bin/env python # convert letter.data to letter.vw -def read_letter_names (fn): + +def read_letter_names(fn): ret = list() with open(fn) as ins: for line in ins: ret.append(line.rstrip()) - print "Read %d names from %s" % (len(ret),fn) + print("Read %d names from %s" % (len(ret), fn)) return ret -def find_pixel_start (names): + +def find_pixel_start(names): for i in range(len(names)): if names[i].startswith("p_"): return i - raise ValueError("No pixel data",names) + raise ValueError("No pixel data", names) + -def data2vw (ifn, train, test, names): +def data2vw(ifn, train, test, names): lineno = 0 trainN = 0 testN = 0 if ifn.endswith(".gz"): import gzip + iopener = gzip.open else: iopener = open id_pos = names.index("id") letter_pos = names.index("letter") pixel_start = find_pixel_start(names) - with iopener(ifn) as ins, open(train,"wb") as trainS, open(test,"wb") as testS: + with iopener(ifn) as ins, open(train, "wb") as trainS, open(test, "wb") as testS: for line in ins: lineno += 1 - vals = line.rstrip().split('\t') + vals = line.rstrip().split("\t") if len(vals) != len(names): - raise ValueError("Bad field count", - len(vals),len(names),vals,names) + raise ValueError("Bad field count", len(vals), len(names), vals, names) char = vals[letter_pos] if len(char) != 1: - raise ValueError("Bad letter",char) + raise ValueError("Bad letter", char) if lineno % 10 == 0: testN += 1 outs = testS else: trainN += 1 outs = trainS - outs.write("%d 1 %s-%s|Pixel" % (ord(char)-ord('a')+1,char,vals[id_pos])) - for i in range(pixel_start,len(names)): - if vals[i] != '0': - outs.write(' %s:%s' % (names[i],vals[i])) - outs.write('\n') - print "Read %d lines from %s; wrote %d lines into %s and %d lines into %s" % ( - lineno,ifn,trainN,train,testN,test) + outs.write( + "%d 1 %s-%s|Pixel" % (ord(char) - ord("a") + 1, char, vals[id_pos]) + ) + for i in range(pixel_start, len(names)): + if vals[i] != "0": + outs.write(" %s:%s" % (names[i], vals[i])) + outs.write("\n") + print( + "Read %d lines from %s; wrote %d lines into %s and %d lines into %s" + % (lineno, ifn, trainN, train, testN, test) + ) -if __name__ == '__main__': +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Convert letters.data to VW format') - parser.add_argument('input',help='path to letter.data[.gz]') - 
parser.add_argument('names',help='path to letter.names') - parser.add_argument('train',help='VW train file location (90%)') - parser.add_argument('test',help='VW test file location (10%)') + + parser = argparse.ArgumentParser(description="Convert letters.data to VW format") + parser.add_argument("input", help="path to letter.data[.gz]") + parser.add_argument("names", help="path to letter.names") + parser.add_argument("train", help="VW train file location (90%)") + parser.add_argument("test", help="VW test file location (10%)") args = parser.parse_args() - data2vw(args.input,args.train,args.test,read_letter_names(args.names)) + data2vw(args.input, args.train, args.test, read_letter_names(args.names)) diff --git a/demo/advertising/naive_baseline.py b/demo/advertising/naive_baseline.py index f98a7b37b02..6a0a758fb08 100755 --- a/demo/advertising/naive_baseline.py +++ b/demo/advertising/naive_baseline.py @@ -6,87 +6,96 @@ from os import devnull # The learning algorithm is vowpal wabbit, available at https://github.com/JohnLangford/vowpal_wabbit/wiki -vw_train_cmd = '../../vowpalwabbit/vw -c -f model --bfgs --passes 30 -b 22 --loss_function logistic --l2 14 --termination 0.00001 --holdout_off' -vw_test_cmd = '../../vowpalwabbit/vw -t -i model -p /dev/stdout' +vw_train_cmd = "../../vowpalwabbit/vw -c -f model --bfgs --passes 30 -b 22 --loss_function logistic --l2 14 --termination 0.00001 --holdout_off" +vw_test_cmd = "../../vowpalwabbit/vw -t -i model -p /dev/stdout" + def get_features(line): feat = line[2:] # Bucketizing the integer features on a logarithmic scale for i in range(8): - if feat[i]: + if feat[i]: v = int(feat[i]) - if v>0: - feat[i] = str(int(log(v+0.5)/log(1.5))) - return ' '.join(['%d_%s' % (i,v) for i,v in enumerate(feat) if v]) + if v > 0: + feat[i] = str(int(log(v + 0.5) / log(1.5))) + return " ".join(["%d_%s" % (i, v) for i, v in enumerate(feat) if v]) + def train_test_oneday(day): - ts_beginning_test = 86400*(day-1) + ts_beginning_test = 86400 * (day - 1) - with open('data.txt') as f: + with open("data.txt") as f: line = f.readline() # Beginning of the training set: 3 weeks before the test period - while int(line.split()[0]) < ts_beginning_test - 86400*21: + while int(line.split()[0]) < ts_beginning_test - 86400 * 21: line = f.readline() - call('rm -f .cache', shell=True) + call("rm -f .cache", shell=True) vw = Popen(vw_train_cmd, shell=True, stdin=PIPE) - print '---------- Training on days %d to %d ----------------' % (day-21, day-1) - print + print( + "---------- Training on days %d to %d ----------------" + % (day - 21, day - 1) + ) + print() while int(line.split()[0]) < ts_beginning_test: - line = line[:-1].split('\t') + line = line[:-1].split("\t") label = -1 if line[1]: conv_ts = int(line[1]) - if conv_ts < ts_beginning_test: - label = 1 # Positive label iff conversion and the conversion occured before the test period + if conv_ts < ts_beginning_test: + label = 1 # Positive label iff conversion and the conversion occured before the test period - out = '%d | %s' % (label, get_features(line)) - print >>vw.stdin, out + out = "%d | %s" % (label, get_features(line)) + print >> vw.stdin, out line = f.readline() vw.stdin.close() vw.wait() - print - print '---------- Testing on day %d ----------------' % (day-21) - - vw = Popen(vw_test_cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=open(devnull, 'w')) + print() + print("---------- Testing on day %d ----------------" % (day - 21)) + + vw = Popen( + vw_test_cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=open(devnull, "w") + 
) ll = 0 n = 0 # Test is one day long while int(line.split()[0]) < ts_beginning_test + 86400: - line = line[:-1].split('\t') + line = line[:-1].split("\t") - print >>vw.stdin, '| '+get_features(line) + print >> vw.stdin, "| " + get_features(line) dotproduct = float(vw.stdout.readline()) - + # Test log likelihood - if line[1]: # Positive example - ll += log(1+exp(-dotproduct)) - else: # Negative sample - ll += log(1+exp(dotproduct)) + if line[1]: # Positive example + ll += log(1 + exp(-dotproduct)) + else: # Negative sample + ll += log(1 + exp(dotproduct)) n += 1 line = f.readline() return (ll, n) + def main(): ll = 0 n = 0 # Iterating over the 7 test days - for day in range(54,61): + for day in range(54, 61): ll_day, n_day = train_test_oneday(day) ll += ll_day n += n_day - print ll_day, n_day - print - print 'Average test log likelihood: %f' % (ll/n) - + print(ll_day, n_day) + print() + print("Average test log likelihood: %f" % (ll / n)) + + if __name__ == "__main__": main() diff --git a/demo/dependencyparsing/evaluate.py b/demo/dependencyparsing/evaluate.py index e74fe2b72f3..cebb4c3572f 100755 --- a/demo/dependencyparsing/evaluate.py +++ b/demo/dependencyparsing/evaluate.py @@ -5,22 +5,26 @@ import sys from collections import defaultdict + def pc(num, den): - return (num / float(den+1e-100)) * 100 + return (num / float(den + 1e-100)) * 100 + def fmt_acc(label, n, l_corr, u_corr, total_errs): l_pc = pc(l_corr, n) u_pc = pc(u_corr, n) err_pc = pc(n - l_corr, total_errs) - return '%s\t%d\t%.3f\t%.3f\t%.3f' % (label, n, l_pc, u_pc, err_pc) + return "%s\t%d\t%.3f\t%.3f\t%.3f" % (label, n, l_pc, u_pc, err_pc) def gen_toks(loc): - sent_strs = open(str(loc)).read().strip().split('\n\n') + sent_strs = open(str(loc)).read().strip().split("\n\n") token = None i = 0 for sent_str in sent_strs: - tokens = [Token(i, tok_str.split()) for i, tok_str in enumerate(sent_str.split('\n'))] + tokens = [ + Token(i, tok_str.split()) for i, tok_str in enumerate(sent_str.split("\n")) + ] for token in tokens: yield sent_str, token @@ -37,24 +41,24 @@ def __init__(self, id_, attrs): new_attrs.append(attrs[-3]) attrs = new_attrs self.label = attrs[-1] - if self.label.lower() == 'root': - self.label = 'ROOT' + if self.label.lower() == "root": + self.label = "ROOT" try: head = int(attrs[-2]) except: try: - self.label = 'P' + self.label = "P" head = int(attrs[-1]) except: - print attrs + print(attrs) raise attrs.pop() attrs.pop() self.head = head self.pos = attrs.pop() self.word = attrs.pop() - self.dir = 'R' if head >= 0 and head < self.id else 'L' - + self.dir = "R" if head >= 0 and head < self.id else "L" + def mymain(test_loc, gold_loc, eval_punct=False): if not os.path.exists(test_loc): @@ -67,7 +71,7 @@ def mymain(test_loc, gold_loc, eval_punct=False): l_nc = 0 for (sst, t), (ss, g) in zip(gen_toks(test_loc), gen_toks(gold_loc)): if not eval_punct and g.word in ",.-;:'\"!?`{}()[]": - continue + continue prev_g = g prev_t = t u_c = g.head == t.head @@ -79,7 +83,7 @@ def mymain(test_loc, gold_loc, eval_punct=False): u_by_label[g.dir][g.label] += u_c l_by_label[g.dir][g.label] += l_c n_l_err = N - l_nc - for D in ['L', 'R']: + for D in ["L", "R"]: n_other = 0 l_other = 0 u_other = 0 @@ -93,12 +97,13 @@ def mymain(test_loc, gold_loc, eval_punct=False): else: l_corr = l_by_label[D][label] u_corr = u_by_label[D][label] - yield 'U: %.3f' % pc(u_nc, N) - yield 'L: %.3f' % pc(l_nc, N) + yield "U: %.3f" % pc(u_nc, N) + yield "L: %.3f" % pc(l_nc, N) + -if __name__ == '__main__': - if(sys.argv < 3): - print 'Usage: 
parsed_pred_file gold_test_conll_file' - sys.exit(0) - for line in mymain(sys.argv[1], sys.argv[2], eval_punct=False): - print line +if __name__ == "__main__": + if sys.argv < 3: + print("Usage: parsed_pred_file gold_test_conll_file") + sys.exit(0) + for line in mymain(sys.argv[1], sys.argv[2], eval_punct=False): + print(line) diff --git a/demo/dependencyparsing/parse_data.py b/demo/dependencyparsing/parse_data.py index 2ba75c091e8..8bdf7ef102a 100644 --- a/demo/dependencyparsing/parse_data.py +++ b/demo/dependencyparsing/parse_data.py @@ -1,29 +1,34 @@ from sys import argv hash = {} + + def readtags(): - for line in open('tags').readlines(): - hash[line.split()[0]] = int(line.strip().split()[1]) + for line in open("tags").readlines(): + hash[line.split()[0]] = int(line.strip().split()[1]) -if __name__ == '__main__': - c = 1 - readtags() - if len(argv) != 3: - print 'parseDepData.py input output' - data = open(argv[1]).readlines() - writer = open(argv[2],'w') - for line in data: - if line == '\n': - writer.write('\n') - continue - splits = line.strip().lower().split() - strw = "|w %s"%splits[1].replace(":","COL"); - strp = "|p %s"%splits[4].replace(":","COL"); - tag = splits[8] - if tag not in hash: - hash[tag] = c - c+=1 - #writer.write('%s 1.0 %s:%s%s %s\n'%((int(splits[7])+1) + (hash[tag]<<8), int(splits[7]),tag,strw, strp)) - writer.write('%s %s %s:%s%s %s\n' % (int(splits[7]), hash[tag], int(splits[7]), tag, strw, strp)) - writer.close() +if __name__ == "__main__": + c = 1 + readtags() + if len(argv) != 3: + print("parseDepData.py input output") + data = open(argv[1]).readlines() + writer = open(argv[2], "w") + for line in data: + if line == "\n": + writer.write("\n") + continue + splits = line.strip().lower().split() + strw = "|w %s" % splits[1].replace(":", "COL") + strp = "|p %s" % splits[4].replace(":", "COL") + tag = splits[8] + if tag not in hash: + hash[tag] = c + c += 1 + # writer.write('%s 1.0 %s:%s%s %s\n'%((int(splits[7])+1) + (hash[tag]<<8), int(splits[7]),tag,strw, strp)) + writer.write( + "%s %s %s:%s%s %s\n" + % (int(splits[7]), hash[tag], int(splits[7]), tag, strw, strp) + ) + writer.close() diff --git a/demo/dependencyparsing/parse_test_result.py b/demo/dependencyparsing/parse_test_result.py index 2d4bdf4c119..dafd4a70a6e 100644 --- a/demo/dependencyparsing/parse_test_result.py +++ b/demo/dependencyparsing/parse_test_result.py @@ -1,25 +1,25 @@ from sys import argv from sys import exit + dict = {} -if len(argv) <4: - print "Usage: test_conll_file annotation_file_from_vw tag_id_mapping" - exit(1) +if len(argv) < 4: + print("Usage: test_conll_file annotation_file_from_vw tag_id_mapping") + exit(1) for data in open(argv[3]).readlines(): - dict[data.strip().split()[1]] = data.strip().split()[0] -annotation = open(argv[2]).readlines() -#for item in list(annotation): -# if item == ' w\n': -# annotation.remove(item) + dict[data.strip().split()[1]] = data.strip().split()[0] +annotation = open(argv[2]).readlines() +# for item in list(annotation): +# if item == ' w\n': +# annotation.remove(item) for idx, line in enumerate(open(argv[1]).readlines()): - item = line.split() - # conll07 - if len(item) ==10: - item[-4] = annotation[idx].strip().split(":")[0] - item[-3] = dict[annotation[idx].strip().split(":")[1]] - # wsj corpus - elif len(item) >0: -# print idx - item[-2] = annotation[idx].strip().split(":")[0] - item[-1] = dict[annotation[idx].strip().split(":")[1]] - print "\t".join(item) - + item = line.split() + # conll07 + if len(item) == 10: + item[-4] = 
annotation[idx].strip().split(":")[0] + item[-3] = dict[annotation[idx].strip().split(":")[1]] + # wsj corpus + elif len(item) > 0: + # print idx + item[-2] = annotation[idx].strip().split(":")[0] + item[-1] = dict[annotation[idx].strip().split(":")[1]] + print("\t".join(item)) diff --git a/demo/ocr/ocr2vw.py b/demo/ocr/ocr2vw.py index ce70ad75b6a..bf598f42ced 100755 --- a/demo/ocr/ocr2vw.py +++ b/demo/ocr/ocr2vw.py @@ -1,62 +1,70 @@ # convert letter.data to letter.vw -def read_letter_names (fn): + +def read_letter_names(fn): ret = list() with open(fn) as ins: for line in ins: ret.append(line.rstrip()) - print "Read %d names from %s" % (len(ret),fn) + print("Read %d names from %s" % (len(ret), fn)) return ret -def find_pixel_start (names): + +def find_pixel_start(names): for i in range(len(names)): if names[i].startswith("p_"): return i - raise ValueError("No pixel data",names) + raise ValueError("No pixel data", names) + -def data2vw (ifn, train, test, names): +def data2vw(ifn, train, test, names): lineno = 0 trainN = 0 testN = 0 if ifn.endswith(".gz"): import gzip + iopener = gzip.open else: iopener = open id_pos = names.index("id") letter_pos = names.index("letter") pixel_start = find_pixel_start(names) - with iopener(ifn) as ins, open(train,"wb") as trainS, open(test,"wb") as testS: + with iopener(ifn) as ins, open(train, "wb") as trainS, open(test, "wb") as testS: for line in ins: lineno += 1 - vals = line.rstrip().split('\t') + vals = line.rstrip().split("\t") if len(vals) != len(names): - raise ValueError("Bad field count", - len(vals),len(names),vals,names) + raise ValueError("Bad field count", len(vals), len(names), vals, names) char = vals[letter_pos] if len(char) != 1: - raise ValueError("Bad letter",char) + raise ValueError("Bad letter", char) if lineno % 10 == 0: testN += 1 outs = testS else: trainN += 1 outs = trainS - outs.write("%d 1 %s-%s|Pixel" % (ord(char)-ord('a')+1,char,vals[id_pos])) - for i in range(pixel_start,len(names)): - if vals[i] != '0': - outs.write(' %s:%s' % (names[i],vals[i])) - outs.write('\n') - print "Read %d lines from %s; wrote %d lines into %s and %d lines into %s" % ( - lineno,ifn,trainN,train,testN,test) + outs.write( + "%d 1 %s-%s|Pixel" % (ord(char) - ord("a") + 1, char, vals[id_pos]) + ) + for i in range(pixel_start, len(names)): + if vals[i] != "0": + outs.write(" %s:%s" % (names[i], vals[i])) + outs.write("\n") + print( + "Read %d lines from %s; wrote %d lines into %s and %d lines into %s" + % (lineno, ifn, trainN, train, testN, test) + ) -if __name__ == '__main__': +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Convert letters.data to VW format') - parser.add_argument('input',help='path to letter.data[.gz]') - parser.add_argument('names',help='path to letter.names') - parser.add_argument('train',help='VW train file location (90%)') - parser.add_argument('test',help='VW test file location (10%)') + + parser = argparse.ArgumentParser(description="Convert letters.data to VW format") + parser.add_argument("input", help="path to letter.data[.gz]") + parser.add_argument("names", help="path to letter.names") + parser.add_argument("train", help="VW train file location (90%)") + parser.add_argument("test", help="VW test file location (10%)") args = parser.parse_args() - data2vw(args.input,args.train,args.test,read_letter_names(args.names)) + data2vw(args.input, args.train, args.test, read_letter_names(args.names)) diff --git a/demo/recall_tree/wikipara/WikiExtractor.py 
b/demo/recall_tree/wikipara/WikiExtractor.py index 88edca21f2b..23019b17459 100755 --- a/demo/recall_tree/wikipara/WikiExtractor.py +++ b/demo/recall_tree/wikipara/WikiExtractor.py @@ -2123,7 +2123,7 @@ def clean(extractor, text): text = text.replace(match.group(), "%s_%d" % (placeholder, index)) index += 1 - text = text.replace("<<", u"«").replace(">>", u"»") + text = text.replace("<<", "«").replace(">>", "»") ############################################# @@ -2131,8 +2131,8 @@ def clean(extractor, text): text = text.replace("\t", " ") text = spaces.sub(" ", text) text = dots.sub("...", text) - text = re.sub(u" (,:\.\)\]»)", r"\1", text) - text = re.sub(u"(\[\(«) ", r"\1", text) + text = re.sub(" (,:\.\)\]»)", r"\1", text) + text = re.sub("(\[\(«) ", r"\1", text) text = re.sub(r"\n\W+?\n", "\n", text, flags=re.U) # lines with only punctuations text = text.replace(",,", ",").replace(",.", ".") @@ -2747,7 +2747,7 @@ def main(): try: power = "kmg".find(args.bytes[-1].lower()) + 1 - file_size = int(args.bytes[:-1]) * 1024 ** power + file_size = int(args.bytes[:-1]) * 1024**power if file_size < minFileSize: raise ValueError() except ValueError: diff --git a/python/docs/source/conf.py b/python/docs/source/conf.py index 2a6cd01b8d3..0d42dbeea1c 100644 --- a/python/docs/source/conf.py +++ b/python/docs/source/conf.py @@ -127,7 +127,7 @@ # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { "https://docs.python.org/3/": None, - "http://pandas.pydata.org/pandas-docs/dev": None + "http://pandas.pydata.org/pandas-docs/dev": None, } html_favicon = "favicon.png" diff --git a/python/tests/DistributionallyRobustUnitTestData.py b/python/tests/DistributionallyRobustUnitTestData.py index d8d963e02db..19dbfc39b17 100644 --- a/python/tests/DistributionallyRobustUnitTestData.py +++ b/python/tests/DistributionallyRobustUnitTestData.py @@ -29,7 +29,7 @@ def intervalimpl( uncgstar = 1 + 1 / n else: unca = (uncwfake + sumw) / (1 + n) - uncb = (uncwfake ** 2 + sumwsq) / (1 + n) + uncb = (uncwfake**2 + sumwsq) / (1 + n) uncgstar = (n + 1) * (unca - 1) ** 2 / (uncb - unca * unca) Delta = chi2.isf(q=alpha, df=1) phi = (-uncgstar - Delta) / (2 * (n + 1)) @@ -41,7 +41,7 @@ def intervalimpl( if wfake == inf: x = sign * (r + (sumwr - sumw * r) / n) y = (r * sumw - sumwr) ** 2 / (n * (1 + n)) - ( - r ** 2 * sumwsq - 2 * r * sumwsqr + sumwsqrsq + r**2 * sumwsq - 2 * r * sumwsqr + sumwsqrsq ) / (1 + n) z = phi + 1 / (2 * n) if isclose(y * z, 0, abs_tol=1e-9): @@ -79,16 +79,16 @@ def intervalimpl( barwsqr = sign * (wfake * wfake * r + sumwsqr) / (1 + n) barwsqrsq = (wfake * wfake * r * r + sumwsqrsq) / (1 + n) - if barwsq > barw ** 2: + if barwsq > barw**2: x = barwr + ( (1 - barw) * (barwsqr - barw * barwr) - / (barwsq - barw ** 2) + / (barwsq - barw**2) ) - y = (barwsqr - barw * barwr) ** 2 / (barwsq - barw ** 2) - ( - barwsqrsq - barwr ** 2 + y = (barwsqr - barw * barwr) ** 2 / (barwsq - barw**2) - ( + barwsqrsq - barwr**2 ) - z = phi + (1 / 2) * (1 - barw) ** 2 / (barwsq - barw ** 2) + z = phi + (1 / 2) * (1 - barw) ** 2 / (barwsq - barw**2) if isclose(y * z, 0, abs_tol=1e-9): y = 0 @@ -150,13 +150,13 @@ def update(self, c, w, r): ), "w = {} < {} < {}".format(self.wmin, w, self.wmax) assert r >= 0 and r <= 1, "r = {}".format(r) - decay = self.tau ** c + decay = self.tau**c self.n = decay * self.n + c self.sumw = decay * self.sumw + c * w - self.sumwsq = decay * self.sumwsq + c * w ** 2 + self.sumwsq = decay * self.sumwsq + c * w**2 self.sumwr = decay * self.sumwr + c * w * r 
- self.sumwsqr = decay * self.sumwsqr + c * (w ** 2) * r - self.sumwsqrsq = decay * self.sumwsqrsq + c * (w ** 2) * (r ** 2) + self.sumwsqr = decay * self.sumwsqr + c * (w**2) * r + self.sumwsqrsq = decay * self.sumwsqrsq + c * (w**2) * (r**2) self.duals = None self.mleduals = None diff --git a/python/tests/crminustwo.py b/python/tests/crminustwo.py index 68d590ba735..ddc28be5680 100644 --- a/python/tests/crminustwo.py +++ b/python/tests/crminustwo.py @@ -32,6 +32,7 @@ def estimate(datagen, wmin, wmax, rmin=0, rmax=1, raiseonerr=False, censored=Fal censored, ) + @staticmethod def estimateimpl( n, sumw, @@ -62,7 +63,7 @@ def estimateimpl( gstar = 1 + 1 / n else: a = (wfake + sumw) / (1 + n) - b = (wfake ** 2 + sumwsq) / (1 + n) + b = (wfake**2 + sumwsq) / (1 + n) assert a * a < b gammastar = (b - a) / (a * a - b) betastar = (1 - a) / (a * a - b) @@ -147,8 +148,8 @@ def estimatediff( sumu += c * u sumw += c * w sumuw += c * u * w - sumusq += c * u ** 2 - sumwsq += c * w ** 2 + sumusq += c * u**2 + sumwsq += c * w**2 sumuMwr += c * (u - w) * r sumuuMwr += c * u * (u - w) * r sumwuMwr += c * w * (u - w) * r @@ -159,10 +160,10 @@ def estimatediff( wfake = wmax if sumw < n else wmin ubar = (sumu + ufake) / (n + 1) - usqbar = (sumusq + ufake ** 2) / (n + 1) + usqbar = (sumusq + ufake**2) / (n + 1) uwbar = (sumuw + ufake * wfake) / (n + 1) wbar = (sumw + wfake) / (n + 1) - wsqbar = (sumwsq + wfake ** 2) / (n + 1) + wsqbar = (sumwsq + wfake**2) / (n + 1) A = np.array( [[-1, -ubar, -wbar], [-ubar, -usqbar, -uwbar], [-wbar, -uwbar, -wsqbar]], @@ -176,8 +177,8 @@ def estimatediff( deltavhat = (-beta * sumuMwr - gamma * sumuuMwr - tau * sumwuMwr) / n missing = ( -beta * (ufake - wfake) - - gamma * (ufake ** 2 - ufake * wfake) - - tau * (ufake * wfake - wfake ** 2) + - gamma * (ufake**2 - ufake * wfake) + - tau * (ufake * wfake - wfake**2) ) / (n + 1) deltavmin = deltavhat + min(rmin * missing, rmax * missing) @@ -218,10 +219,10 @@ def interval(datagen, wmin, wmax, alpha=0.05, rmin=0, rmax=1, raiseonerr=False): for c, w, r in datagen(): n += c sumw += c * w - sumwsq += c * w ** 2 + sumwsq += c * w**2 sumwr += c * w * r - sumwsqr += c * w ** 2 * r - sumwsqrsq += c * w ** 2 * r ** 2 + sumwsqr += c * w**2 * r + sumwsqrsq += c * w**2 * r**2 assert n > 0 return CrMinusTwo.intervalimpl( @@ -266,7 +267,7 @@ def intervalimpl( uncgstar = 1 + 1 / n else: unca = (uncwfake + sumw) / (1 + n) - uncb = (uncwfake ** 2 + sumwsq) / (1 + n) + uncb = (uncwfake**2 + sumwsq) / (1 + n) uncgstar = (n + 1) * (unca - 1) ** 2 / (uncb - unca * unca) Delta = f.isf(q=alpha, dfn=1, dfd=n) phi = (-uncgstar - Delta) / (2 * (n + 1)) @@ -278,7 +279,7 @@ def intervalimpl( if wfake == inf: x = sign * (r + (sumwr - sumw * r) / n) y = (r * sumw - sumwr) ** 2 / (n * (1 + n)) - ( - r ** 2 * sumwsq - 2 * r * sumwsqr + sumwsqrsq + r**2 * sumwsq - 2 * r * sumwsqr + sumwsqrsq ) / (1 + n) z = phi + 1 / (2 * n) if isclose(y * z, 0, abs_tol=1e-9): @@ -314,14 +315,14 @@ def intervalimpl( barwsqr = sign * (wfake * wfake * r + sumwsqr) / (1 + n) barwsqrsq = (wfake * wfake * r * r + sumwsqrsq) / (1 + n) - if barwsq > barw ** 2: + if barwsq > barw**2: x = barwr + ( - (1 - barw) * (barwsqr - barw * barwr) / (barwsq - barw ** 2) + (1 - barw) * (barwsqr - barw * barwr) / (barwsq - barw**2) ) - y = (barwsqr - barw * barwr) ** 2 / (barwsq - barw ** 2) - ( - barwsqrsq - barwr ** 2 + y = (barwsqr - barw * barwr) ** 2 / (barwsq - barw**2) - ( + barwsqrsq - barwr**2 ) - z = phi + (1 / 2) * (1 - barw) ** 2 / (barwsq - barw ** 2) + z = phi + (1 / 2) * (1 - barw) ** 
2 / (barwsq - barw**2) if isclose(y * z, 0, abs_tol=1e-9): y = 0 @@ -387,12 +388,12 @@ def intervaldiff( sumu += c * u sumw += c * w sumuw += c * u * w - sumusq += c * u ** 2 - sumwsq += c * w ** 2 + sumusq += c * u**2 + sumwsq += c * w**2 sumuMwr += c * (u - w) * r sumuuMwr += c * u * (u - w) * r sumwuMwr += c * w * (u - w) * r - sumuMwsqrsq += c * (u - w) ** 2 * r ** 2 + sumuMwsqrsq += c * (u - w) ** 2 * r**2 assert n > 0 @@ -404,13 +405,13 @@ def intervaldiff( baru = (sumu + ufake) / (n + 1) barw = (sumw + wfake) / (n + 1) - barusq = (sumusq + ufake ** 2) / (n + 1) - barwsq = (sumwsq + wfake ** 2) / (n + 1) + barusq = (sumusq + ufake**2) / (n + 1) + barwsq = (sumwsq + wfake**2) / (n + 1) baruw = (sumuw + ufake * wfake) / (n + 1) baruMwr = sign * (sumuMwr + (ufake - wfake) * rex) / (n + 1) baruuMwr = sign * (sumuuMwr + ufake * (ufake - wfake) * rex) / (n + 1) barwuMwr = sign * (sumwuMwr + wfake * (ufake - wfake) * rex) / (n + 1) - baruMwsqrsq = (sumuMwsqrsq + (ufake - wfake) ** 2 * rex ** 2) / (n + 1) + baruMwsqrsq = (sumuMwsqrsq + (ufake - wfake) ** 2 * rex**2) / (n + 1) C = np.array( [ diff --git a/python/tests/test_cats.py b/python/tests/test_cats.py index d4c13b54378..2b90a1cca28 100644 --- a/python/tests/test_cats.py +++ b/python/tests/test_cats.py @@ -1,5 +1,6 @@ import vowpalwabbit + def test_cats(): min_value = 10 max_value = 20 diff --git a/python/tests/test_pyvw.py b/python/tests/test_pyvw.py index 1e5cb3c34d6..15db320adaa 100644 --- a/python/tests/test_pyvw.py +++ b/python/tests/test_pyvw.py @@ -40,7 +40,7 @@ def test_get_tag(self): assert ex.get_tag() == "baz" def test_num_weights(self): - assert self.model.num_weights() == 2 ** BIT_SIZE + assert self.model.num_weights() == 2**BIT_SIZE def test_get_weight(self): assert self.model.get_weight(0, 0) == 0 @@ -168,7 +168,9 @@ def test_CBContinuousLabel(): model = Workspace( cats=4, min_value=185, max_value=23959, bandwidth=3000, quiet=True ) - cb_contl = vowpalwabbit.CBContinuousLabel.from_example(model.example("ca 1:10:0.5 |")) + cb_contl = vowpalwabbit.CBContinuousLabel.from_example( + model.example("ca 1:10:0.5 |") + ) assert cb_contl.costs[0].action == 1 assert cb_contl.costs[0].pdf_value == 0.5 assert cb_contl.costs[0].cost == 10.0 @@ -189,7 +191,9 @@ def test_CostSensitiveLabel(): def test_MulticlassProbabilitiesLabel(): n = 4 - model = vowpalwabbit.Workspace(loss_function="logistic", oaa=n, probabilities=True, quiet=True) + model = vowpalwabbit.Workspace( + loss_function="logistic", oaa=n, probabilities=True, quiet=True + ) ex = model.example("1 | a b c d", 2) model.learn(ex) mpl = vowpalwabbit.MulticlassProbabilitiesLabel.from_example(ex) @@ -209,7 +213,9 @@ def test_ccb_label(): ccb_slot_label = vowpalwabbit.CCBLabel.from_example( model.example("ccb slot 0:0.8:1.0 0 | slot_0") ) - ccb_slot_pred_label = vowpalwabbit.CCBLabel.from_example(model.example("ccb slot |")) + ccb_slot_pred_label = vowpalwabbit.CCBLabel.from_example( + model.example("ccb slot |") + ) assert ccb_shared_label.type == vowpalwabbit.CCBLabelType.SHARED assert len(ccb_shared_label.explicit_included_actions) == 0 assert ccb_shared_label.outcome is None @@ -263,6 +269,7 @@ def test_slates_label(): assert str(slates_slot_label) == "slates slot 1:0.8,0:0.1,2:0.1" del model + def test_multilabel_label(): model = Workspace(multilabel_oaa=5, quiet=True) multil = vowpalwabbit.MultilabelLabel.from_example(model.example("1,2,3 |")) @@ -272,6 +279,7 @@ def test_multilabel_label(): assert multil.labels[2] == 3 assert str(multil) == "1,2,3" + def 
test_regressor_args(): # load and parse external data file data_file = os.path.join( @@ -298,6 +306,7 @@ def test_regressor_args(): os.remove("{}.cache".format(data_file)) os.remove("tmp.model") + def test_command_line_with_space_and_escape_kwargs(): # load and parse external data file test_file_dir = Path(__file__).resolve().parent @@ -311,12 +320,20 @@ def test_command_line_with_space_and_escape_kwargs(): assert model_file.is_file() model_file.unlink() + def test_command_line_using_arg_list(): # load and parse external data file test_file_dir = Path(__file__).resolve().parent data_file = test_file_dir / "resources" / "train file.dat" - args = ["--oaa", "3", "--data", str(data_file), "--final_regressor", "test model2.vw"] + args = [ + "--oaa", + "3", + "--data", + str(data_file), + "--final_regressor", + "test model2.vw", + ] model = Workspace(arg_list=args) assert model.predict("| feature1:2.5") == 1 del model @@ -325,11 +342,13 @@ def test_command_line_using_arg_list(): assert model_file.is_file() model_file.unlink() + def test_command_line_with_double_space_in_str(): # Test regression for double space in string breaking splitting model = Workspace(arg_list="--oaa 3 -q :: ") del model + def test_keys_with_list_of_values(): # No exception in creating and executing model with a key/list pair model = Workspace(quiet=True, q=["fa", "fb"]) diff --git a/python/tests/test_vwconfig.py b/python/tests/test_vwconfig.py index 6a38c89bdfe..49814328fc9 100644 --- a/python/tests/test_vwconfig.py +++ b/python/tests/test_vwconfig.py @@ -1,6 +1,7 @@ from vowpalwabbit import pyvw import vowpalwabbit + def helper_options_to_list_strings(config): cmd_str_list = [] diff --git a/setup.py b/setup.py index 0552fb6dfe8..c05b4ace3f6 100644 --- a/setup.py +++ b/setup.py @@ -88,13 +88,19 @@ def build_cmake(self, ext): # See bug: https://bugs.python.org/issue39825 if system == "Windows" and sys.version_info.minor < 8: from distutils import sysconfig as distutils_sysconfig - required_shared_lib_suffix = distutils_sysconfig.get_config_var('EXT_SUFFIX') + + required_shared_lib_suffix = distutils_sysconfig.get_config_var( + "EXT_SUFFIX" + ) else: import sysconfig + required_shared_lib_suffix = sysconfig.get_config_var("EXT_SUFFIX") if required_shared_lib_suffix is not None: - cmake_args += ["-DVW_PYTHON_SHARED_LIB_SUFFIX={}".format(required_shared_lib_suffix)] + cmake_args += [ + "-DVW_PYTHON_SHARED_LIB_SUFFIX={}".format(required_shared_lib_suffix) + ] if self.distribution.enable_boost_cmake is None: # Add this flag as default since testing indicates its safe. 
diff --git a/test/cluster_test.py b/test/cluster_test.py
index 0f0bac33af3..32a6d1fdddd 100644
--- a/test/cluster_test.py
+++ b/test/cluster_test.py
@@ -76,8 +76,10 @@
         print("VW succeeded")
     if return_code != 0:
         print("VW failed:")
-        print("STDOUT: \n" + proc.stdout.read().decode("utf-8"))
-        print("STDERR: \n" + proc.stderr.read().decode("utf-8"))
+        if proc.stdout:
+            print("STDOUT: \n" + proc.stdout.read().decode("utf-8"))
+        if proc.stderr:
+            print("STDERR: \n" + proc.stderr.read().decode("utf-8"))
         spanning_tree_proc.kill()
         sys.exit(1)

diff --git a/test/run_tests.py b/test/run_tests.py
index 8fc29ba383c..88449efeb0b 100644
--- a/test/run_tests.py
+++ b/test/run_tests.py
@@ -878,7 +878,8 @@ def get_test(test_number: int, tests: List[TestData]) -> Optional[TestData]:
             return test
     return None

-def interpret_test_arg(arg: str, *, num_tests:int) -> List[int]:
+
+def interpret_test_arg(arg: str, *, num_tests: int) -> List[int]:
     single_number_pattern = re.compile(r"^\d+$")
     range_pattern = re.compile(r"^(\d+)?\.\.(\d+)?$")
     if single_number_pattern.match(arg):
@@ -891,7 +892,9 @@ def interpret_test_arg(arg: str, *, num_tests:int) -> List[int]:
             raise ValueError(f"Invalid range: {arg}")
         return list(range(start, end + 1))
     else:
-        raise ValueError(f"Invalid test argument '{arg}'. Must either be a single integer 'id' or a range in the form 'start..end'")
+        raise ValueError(
+            f"Invalid test argument '{arg}'. Must either be a single integer 'id' or a range in the form 'start..end'"
+        )


 def main():
@@ -1096,23 +1099,27 @@ def main():

     # Flatten nested lists for arg.test argument and process
     # Ideally we would have used action="extend", but that was added in 3.8
+    interpreted_test_arg: Optional[List[int]] = None
     if args.test is not None:
-        res = []
+        interpreted_test_arg = []
         for arg in args.test:
             for value in arg:
-                res.extend(interpret_test_arg(value, num_tests=len(tests)))
-        args.test = res
+                interpreted_test_arg.extend(
+                    interpret_test_arg(value, num_tests=len(tests))
+                )

     print()

     # Filter the test list if the requested tests were explicitly specified
     tests_to_run_explicitly = None
-    if args.test is not None:
-        tests_to_run_explicitly = calculate_test_to_run_explicitly(args.test, tests)
+    if interpreted_test_arg is not None:
+        tests_to_run_explicitly = calculate_test_to_run_explicitly(
+            interpreted_test_arg, tests
+        )
         print(f"Running tests: {list(tests_to_run_explicitly)}")
-        if len(args.test) != len(tests_to_run_explicitly):
+        if len(interpreted_test_arg) != len(tests_to_run_explicitly):
             print(
-                f"Note: due to test dependencies, more than just tests {args.test} must be run"
+                f"Note: due to test dependencies, more than just tests {interpreted_test_arg} must be run"
             )
         tests = list(filter(lambda x: x.id in tests_to_run_explicitly, tests))

diff --git a/test/same-model-test.py b/test/same-model-test.py
index d64b55c87e5..36ead355eeb 100644
--- a/test/same-model-test.py
+++ b/test/same-model-test.py
@@ -71,8 +71,10 @@
         print("VW succeeded")
     if return_code != 0:
         print("VW failed:")
-        print("STDOUT: \n" + proc.stdout.read().decode("utf-8"))
-        print("STDERR: \n" + proc.stderr.read().decode("utf-8"))
+        if proc.stdout:
+            print("STDOUT: \n" + proc.stdout.read().decode("utf-8"))
+        if proc.stderr:
+            print("STDERR: \n" + proc.stderr.read().decode("utf-8"))
         spanning_tree_proc.kill()
         sys.exit(1)

diff --git a/utl/active_interactor.py b/utl/active_interactor.py
index 07f82bf6782..7463f2de248 100755
--- a/utl/active_interactor.py
+++ b/utl/active_interactor.py
@@ -30,12 +30,13 @@ def _get_getch_impl_unix() -> Optional[Callable[[], str]]:
     def _getch():
         fd = sys.stdin.fileno()
         old_settings = termios.tcgetattr(fd)
+        ch = None
         try:
             tty.setraw(fd)
             ch = sys.stdin.read(1)
         finally:
             termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
-        if ord(ch) == 3:
+        if ch is not None and ord(ch) == 3:
             raise KeyboardInterrupt
         return ch