diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..22948f9 Binary files /dev/null and b/.DS_Store differ diff --git a/codebleu/dataflow_match.py b/codebleu/dataflow_match.py index dfe1ebe..e42910f 100644 --- a/codebleu/dataflow_match.py +++ b/codebleu/dataflow_match.py @@ -13,6 +13,7 @@ DFG_python, DFG_ruby, DFG_rust, + DFG_fortran, index_to_code_token, remove_comments_and_docstrings, tree_to_token_index, @@ -30,6 +31,7 @@ "c": DFG_csharp, # XLCoST uses C# parser for C "cpp": DFG_csharp, # XLCoST uses C# parser for C++ "rust": DFG_rust, + "fortran": DFG_fortran, } diff --git a/codebleu/keywords/fortran.txt b/codebleu/keywords/fortran.txt new file mode 100644 index 0000000..c402892 --- /dev/null +++ b/codebleu/keywords/fortran.txt @@ -0,0 +1,78 @@ +allocatable +allocate +assignment +backspace +call +case +character +close +complex +contains +cycle +deallocate +default +do +else +else if +elsewhere +end do +end function +end if +end interface +end module +end program +end select +end subroutine +end type +end where +endfile +exit +format +function +if +implicit +in +inout +integer +intent +interface +intrinsic +inquire +kind +len +logical +module +namelist +nullify +only +open +operator +optional +out +print +pointer +private +program +public +read +real +recursive +result +return +rewind +select case +stop +subroutine +target +then +type +use +where +while +write + + + + + + diff --git a/codebleu/parser/DFG.py b/codebleu/parser/DFG.py index 74e4b0b..d335c01 100644 --- a/codebleu/parser/DFG.py +++ b/codebleu/parser/DFG.py @@ -1385,3 +1385,162 @@ def DFG_rust(root_node, index_to_code, states): DFG += temp return sorted(DFG, key=lambda x: x[1]), states + +def DFG_fortran(root_node, index_to_code, states): + assignment = ["assignment_stmt"] + def_statement = ["declaration_stmt"] + increment_statement = ["arithmetic_if_stmt"] + if_statement = ["if_stmt", "else"] + for_statement = ["do_loop"] + while_statement = ["while_loop"] + do_first_statement = [] + + states = states.copy() + + if ( + len(root_node.children) == 0 or root_node.type in ["string_literal", "character_literal"] + ) and root_node.type != "comment": + idx, code = index_to_code[(root_node.start_point, root_node.end_point)] + if root_node.type == code: + return [], states + elif code in states: + return [(code, idx, "comesFrom", [code], states[code].copy())], states + else: + if root_node.type == "identifier": + states[code] = [idx] + return [(code, idx, "comesFrom", [], [])], states + + elif root_node.type in def_statement: + name = root_node.child_by_field_name("name") + value = root_node.child_by_field_name("value") + DFG = [] + if value is None: + indexs = tree_to_variable_index(name, index_to_code) + for index in indexs: + idx, code = index_to_code[index] + DFG.append((code, idx, "comesFrom", [], [])) + states[code] = [idx] + return sorted(DFG, key=lambda x: x[1]), states + else: + name_indexs = tree_to_variable_index(name, index_to_code) + value_indexs = tree_to_variable_index(value, index_to_code) + temp, states = DFG_fortran(value, index_to_code, states) + DFG += temp + for index1 in name_indexs: + idx1, code1 = index_to_code[index1] + for index2 in value_indexs: + idx2, code2 = index_to_code[index2] + DFG.append((code1, idx1, "comesFrom", [code2], [idx2])) + states[code1] = [idx1] + return sorted(DFG, key=lambda x: x[1]), states + + elif root_node.type in assignment: + left_nodes = root_node.child_by_field_name("left") + right_nodes = root_node.child_by_field_name("right") + DFG = [] + temp, states = DFG_fortran(right_nodes, index_to_code, states) + DFG += temp + name_indexs = tree_to_variable_index(left_nodes, index_to_code) + value_indexs = tree_to_variable_index(right_nodes, index_to_code) + for index1 in name_indexs: + idx1, code1 = index_to_code[index1] + for index2 in value_indexs: + idx2, code2 = index_to_code[index2] + DFG.append((code1, idx1, "computedFrom", [code2], [idx2])) + states[code1] = [idx1] + return sorted(DFG, key=lambda x: x[1]), states + + elif root_node.type in increment_statement: + DFG = [] + indexs = tree_to_variable_index(root_node, index_to_code) + for index1 in indexs: + idx1, code1 = index_to_code[index1] + for index2 in indexs: + idx2, code2 = index_to_code[index2] + DFG.append((code1, idx1, "computedFrom", [code2], [idx2])) + states[code1] = [idx1] + return sorted(DFG, key=lambda x: x[1]), states + + elif root_node.type in if_statement: + DFG = [] + current_states = states.copy() + others_states = [] + flag = False + tag = False + if "else" in root_node.type: + tag = True + for child in root_node.children: + if "else" in child.type: + tag = True + if child.type not in if_statement and flag is False: + temp, current_states = DFG_fortran(child, index_to_code, current_states) + DFG += temp + else: + flag = True + temp, new_states = DFG_fortran(child, index_to_code, states) + DFG += temp + others_states.append(new_states) + others_states.append(current_states) + if tag is False: + others_states.append(states) + new_states = {} + for dic in others_states: + for key in dic: + if key not in new_states: + new_states[key] = dic[key].copy() + else: + new_states[key] += dic[key] + for key in new_states: + new_states[key] = sorted(list(set(new_states[key]))) + return sorted(DFG, key=lambda x: x[1]), new_states + + elif root_node.type in for_statement: + DFG = [] + for child in root_node.children: + temp, states = DFG_fortran(child, index_to_code, states) + DFG += temp + flag = False + for child in root_node.children: + if flag: + temp, states = DFG_fortran(child, index_to_code, states) + DFG += temp + elif child.type == "local_variable_declaration": + flag = True + dic = {} + for x in DFG: + if (x[0], x[1], x[2]) not in dic: + dic[(x[0], x[1], x[2])] = [x[3], x[4]] + else: + dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) + dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) + DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] + return sorted(DFG, key=lambda x: x[1]), states + + elif root_node.type in while_statement: + DFG = [] + for i in range(2): + for child in root_node.children: + temp, states = DFG_fortran(child, index_to_code, states) + DFG += temp + dic = {} + for x in DFG: + if (x[0], x[1], x[2]) not in dic: + dic[(x[0], x[1], x[2])] = [x[3], x[4]] + else: + dic[(x[0], x[1], x[2])][0] = list(set(dic[(x[0], x[1], x[2])][0] + x[3])) + dic[(x[0], x[1], x[2])][1] = sorted(list(set(dic[(x[0], x[1], x[2])][1] + x[4]))) + DFG = [(x[0], x[1], x[2], y[0], y[1]) for x, y in sorted(dic.items(), key=lambda t: t[0][1])] + return sorted(DFG, key=lambda x: x[1]), states + + else: + DFG = [] + for child in root_node.children: + if child.type in do_first_statement: + temp, states = DFG_fortran(child, index_to_code, states) + DFG += temp + for child in root_node.children: + if child.type not in do_first_statement: + temp, states = DFG_fortran(child, index_to_code, states) + DFG += temp + + return sorted(DFG, key=lambda x: x[1]), states diff --git a/codebleu/parser/__init__.py b/codebleu/parser/__init__.py index 2b4f1a5..4fe514c 100644 --- a/codebleu/parser/__init__.py +++ b/codebleu/parser/__init__.py @@ -10,6 +10,7 @@ DFG_python, DFG_ruby, DFG_rust, + DFG_fortran, ) from .utils import ( index_to_code_token, diff --git a/codebleu/utils.py b/codebleu/utils.py index 468df81..5dc84a1 100644 --- a/codebleu/utils.py +++ b/codebleu/utils.py @@ -20,6 +20,7 @@ "go", "ruby", "rust", + "fortran" ] # keywords available @@ -173,6 +174,10 @@ def get_tree_sitter_language(lang: str) -> Language: import tree_sitter_rust return Language(tree_sitter_rust.language()) + elif lang == "fortran": + import tree_sitter_fortran + + return Language(tree_sitter_fortran.language()) else: assert False, "Not reachable" except ImportError: diff --git a/tests/test_codebleu.py b/tests/test_codebleu.py index 65936a9..b52615a 100644 --- a/tests/test_codebleu.py +++ b/tests/test_codebleu.py @@ -42,6 +42,7 @@ def test_exact_match_works_for_all_langs(lang: str) -> None: ("go", ["func foo ( x ) { return x }"], ["func bar ( y ) {\n return y\n}"]), ("ruby", ["def foo ( x ) :\n return x"], ["def bar ( y ) :\n return y"]), ("rust", ["fn foo ( x ) -> i32 { x }"], ["fn bar ( y ) -> i32 { y }"]), + ("fortran", ["function foo ( x ) result ( x )\n end function foo"], ["function bar ( y ) result ( y )\n end function bar"]), ], ) def test_simple_cases_work_for_all_langs(lang: str, predictions: List[Any], references: List[Any]) -> None: diff --git a/use.py b/use.py new file mode 100644 index 0000000..ffbff4f --- /dev/null +++ b/use.py @@ -0,0 +1,31 @@ +# from codebleu import calc_codebleu + +# prediction = "public function add(a,b) { return (a+b) }" +# reference = "public function sum(a,b) { return (a+b) }" + +# result = calc_codebleu([reference], [prediction], lang="java", weights=(0.25, 0.25, 0.25, 0.25), tokenizer=None) +# print(result) + +# from codebleu import calc_codebleu + +# prediction = "def add ( a , b ) :\n return a + b" +# reference = "def sum ( first , second ) :\n return second + first" + +# result = calc_codebleu([reference], [prediction], lang="python", weights=(0.25, 0.25, 0.25, 0.25), tokenizer=None) +# print(result) + +from codebleu import calc_codebleu + +prediction = """function foo(x) + real :: foo + foo = x +end function foo +""" +reference = """function foo(x) + real :: foo + foo = x +end function foo +""" + +result = calc_codebleu([reference], [prediction], lang="fortran", weights=(0.25, 0.25, 0.25, 0.25), tokenizer=None) +print(result)