diff --git a/probing/data_former.py b/probing/data_former.py index 30a447f..6f7aae4 100644 --- a/probing/data_former.py +++ b/probing/data_former.py @@ -24,7 +24,7 @@ def __init__( self.shuffle = shuffle self.data_path = get_probe_task_path(probe_task, data_path) - self.samples, self.unique_labels = self.form_data(sep=sep) + self.samples, self.unique_labels, self.num_words = self.form_data(sep=sep) def __len__(self): return len(self.samples) @@ -48,8 +48,9 @@ def form_data( samples_dict = defaultdict(list) unique_labels = set() dataset = pd.read_csv(self.data_path, sep=sep, header=None, dtype=str) - for _, (stage, label, text) in dataset.iterrows(): - samples_dict[stage].append((text, label)) + for _, (stage, label, word_indices, text) in dataset.iterrows(): + num_words = len(word_indices.split(",")) + samples_dict[stage].append((text, label, word_indices)) unique_labels.add(label) if self.shuffle: @@ -58,7 +59,7 @@ def form_data( } else: samples_dict = {k: np.array(v) for k, v in samples_dict.items()} - return samples_dict, unique_labels + return samples_dict, unique_labels, num_words class EncodedVectorFormer(Dataset): diff --git a/probing/ud_filter/filtering_probing.py b/probing/ud_filter/filtering_probing.py index d114353..1d642da 100644 --- a/probing/ud_filter/filtering_probing.py +++ b/probing/ud_filter/filtering_probing.py @@ -41,7 +41,7 @@ def __init__(self, shuffle: bool = True): self.classes: Dict[ str, Tuple[Dict[str, Dict[str, Any]], Dict[Tuple[str, str], Dict[str, Any]]] ] = {} - self.probing_dict: Dict[str, List[str]] = {} + self.probing_dict: Dict[str, List[Tuple[str, List[int]]]] = {} self.parts_data: Dict[str, List[List[str]]] = {} def upload_files( @@ -74,11 +74,13 @@ def upload_files( self.language = extract_lang_from_udfile_path(self.paths[0], language=language) self.sentences = parse(conllu_data) - def _filter_conllu(self, class_label: str) -> Tuple[List[str], List[str]]: + def _filter_conllu( + self, class_label: str + ) -> Tuple[List[Tuple[str, List[int]]], List[Tuple[str, List[int]]]]: """Filters sentences by class's query and saves the result to the relevant fields""" - matching = [] - not_matching = [] + matching: List[Tuple[str, List[int]]] = [] + not_matching: List[Tuple[str, List[int]]] = [] node_pattern = self.classes[class_label][0] constraints = self.classes[class_label][1] @@ -91,10 +93,11 @@ def _filter_conllu(self, class_label: str) -> Tuple[List[str], List[str]]: for sentence in self.sentences: sf = SentenceFilter(sentence) tokenized_sentence = " ".join(wordpunct_tokenize(sentence.metadata["text"])) - if sf.filter_sentence(node_pattern, constraints): - matching.append(tokenized_sentence) + filter_result = sf.filter_sentence(node_pattern, constraints) + if filter_result is not None: + matching.append((tokenized_sentence, filter_result)) else: - not_matching.append(tokenized_sentence) + not_matching.append((tokenized_sentence, [])) return matching, not_matching def filter_and_convert( @@ -128,7 +131,7 @@ def filter_and_convert( matching, not_matching = self._filter_conllu(label) self.probing_dict[label] = matching if len(self.classes) == 1: - self.probing_dict["not_" + list(self.classes.keys())[0]] = not_matching + self.probing_dict["not_" + label] = not_matching self.probing_dict = delete_duplicates(self.probing_dict) self.parts_data = subsamples_split( diff --git a/probing/ud_filter/sentence_filter.py b/probing/ud_filter/sentence_filter.py index b43f253..0bbd87f 100644 --- a/probing/ud_filter/sentence_filter.py +++ b/probing/ud_filter/sentence_filter.py 
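Note on the probing/data_former.py and probing/ud_filter/filtering_probing.py hunks above: the probe-task TSV now carries four columns (stage, label, word_indices, text), and DataFormer.form_data additionally returns num_words, the number of comma-separated token indices in a row. A minimal, self-contained sketch of that parsing loop follows; the sample rows and labels are invented for illustration only.

    from collections import defaultdict
    from io import StringIO

    import pandas as pd

    # Hypothetical probe-task rows in the new four-column layout:
    # stage \t label \t word_indices \t text
    tsv = "tr\tSing\t1\tThe dog sleeps .\nte\tPlur\t1\tThe dogs sleep .\n"

    samples_dict = defaultdict(list)
    unique_labels = set()
    dataset = pd.read_csv(StringIO(tsv), sep="\t", header=None, dtype=str)
    for _, (stage, label, word_indices, text) in dataset.iterrows():
        num_words = len(word_indices.split(","))  # number of marked target tokens
        samples_dict[stage].append((text, label, word_indices))
        unique_labels.add(label)

    print(dict(samples_dict), unique_labels, num_words)
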
@@ -1,248 +1,248 @@ -import re -from collections import defaultdict -from itertools import product -from typing import Any, DefaultDict, Dict, List, Set, Tuple - -import networkx as nx -from conllu import models - -from probing.ud_filter.utils import check_query - - -class SentenceFilter: - """ - Checks if a sentence matches patterns - - Attributes: - sentence: sentence represented as TokenList - node_pattern: a dictionary with a node_label as a key and a dictionary of feature restrictions as a value. - sample_node_pattern = { node_label: { - field_or_category: regex_pattern, - 'exclude': [exclude categories]} } - constraints: a dictionary with a node_pair as a key and a dictionary with different constraints on node_pair - sample_constraints = { ('W1', 'W2'): { - 'deprels': regexp_pattern (W1 as head, W2 as dependent), - 'fconstraint': { - 'disjoint': [grammar_category], - 'intersec': [grammar_category]}, - 'lindist': (start, end) (relatively W1)} } - sent_deprels: a dictionary of all relations and pairs with these relations {relation: [(head, dependent)]} - nodes_tokens: a dictionary with all tokens that can be a node in the pattern {node: [token id]}, - if filter_sentence == True, saves only one instance in nodes_tokens - possible_token_pairs: a dictionary with all nodes pairs as a key and a list of possible token, - pairs as a value, if filter_sentence == True, saves only one instance in possible_token_pairs - """ - - def __init__(self, sentence: models.TokenList): - self.sentence = sentence - self.node_pattern: Dict[str, Dict[str, str]] = {} - self.constraints: Dict[Tuple[str, str], Dict[Any, Any]] = {} - self.sent_deprels: DefaultDict[str, List[Tuple[int, int]]] = defaultdict(list) - self.nodes_tokens: Dict[str, List[int]] = {} - self.possible_token_pairs: Dict[Tuple[str, str], Set[Tuple[int, int]]] = {} - - def token_match_node( - self, token: models.Token, node_pattern: Dict[str, str] - ) -> bool: - """Checks if a token matches the node_pattern""" - - for feat in node_pattern: - if feat in token.keys(): - if not re.fullmatch(node_pattern[feat], token[feat], re.I): - return False - elif token["feats"]: - if feat in token["feats"]: - if not re.fullmatch(node_pattern[feat], token["feats"][feat], re.I): - return False - elif feat == "exclude": - for ef in node_pattern[feat]: - if ef in token["feats"]: - return False - else: - return False - else: - return False - return True - - def all_deprels(self) -> DefaultDict[str, List[Tuple[int, int]]]: - """Returns a dictionary {relation: [(head, dependent)]} of all relations in the sentence""" - - deprels: DefaultDict[str, list] = defaultdict(list) - for token in self.sentence: - if isinstance(token["head"], int) and isinstance(token["id"], int): - deprels[token["deprel"]].append((token["head"] - 1, token["id"] - 1)) - return deprels - - def search_suitable_tokens(self, node: str) -> None: - """Selects from a token_list those tokens that match given node_pattern - and saves it in self.nodes_tokens[node]""" - - for token in self.sentence: - if self.token_match_node(token, self.node_pattern[node]) and isinstance( - token["id"], int - ): - self.nodes_tokens[node].append(token["id"] - 1) - - def find_all_nodes(self) -> bool: - """Checks if every node in the pattern has at least one matching token""" - - for node in self.node_pattern: - self.search_suitable_tokens(node) - if not self.nodes_tokens[node]: - return False - return True - - def pattern_relations(self, rel_pattern: str) -> List[str]: - """Returns all relation names in the sentence that 
match the given pattern""" - - rels = [] - for rel in self.sent_deprels: - if re.fullmatch( - rel_pattern, rel, re.I - ): # changed from re.serach to re.fullmatch - rels.append(rel) - return rels - - def pairs_with_rel( - self, node_pair: Tuple[str, str], rel_name: str - ) -> Set[Tuple[int, int]]: - """Returns those pairs of tokens that: - 1) are related by a rel_name - 2) are among possible_token_pairs for a node_pair""" - - if rel_name not in self.sent_deprels: - return set() - else: - return set(self.sent_deprels[rel_name]).intersection( - self.possible_token_pairs[node_pair] - ) - - def pairs_matching_relpattern( - self, node_pair: Tuple[str, str] - ) -> Set[Tuple[int, int]]: - """Returns a set of token pairs, whose relations match the pattern""" - - all_suitable_rels: Set[Tuple[int, int]] = set() - for rel in self.pattern_relations(self.constraints[node_pair]["deprels"]): - all_suitable_rels = all_suitable_rels | self.pairs_with_rel(node_pair, rel) - return all_suitable_rels - - def linear_distance(self, node_pair: Tuple[str, str]) -> Set[Tuple[int, int]]: - """Returns a set of token pairs with a given linear distance between tokens""" - - suitable_pairs: Set[Tuple[int, int]] = set() - lindist = self.constraints[node_pair]["lindist"] - for pair in self.possible_token_pairs[node_pair]: - dist = pair[1] - pair[0] - if lindist[0] <= dist <= lindist[1]: - suitable_pairs.add(pair) - return suitable_pairs - - def pair_match_fconstraint( - self, token_pair: Tuple[int, int], fconstraint: Dict[Any, Any] - ) -> bool: - """Checks if a token pair matches all the feature constraints""" - t1_feats = self.sentence[token_pair[0]]["feats"] - t2_feats = self.sentence[token_pair[1]]["feats"] - if t1_feats and t2_feats: - for ctype in fconstraint: - for f in fconstraint[ctype]: - if (f in t1_feats) and (f in t2_feats): - if ctype == "intersec": - if t1_feats[f] != t2_feats[f]: - return False - elif ctype == "disjoint": - if t1_feats[f] == t2_feats[f]: - return False - else: - raise ValueError("Wrong feature constraint type") - else: - return False - return True - else: - return False - - def feature_constraint(self, node_pair: Tuple[str, str]) -> Set[Tuple[int, int]]: - """Returns all pairs that match constraints on features""" - - suitable_pairs: Set[Tuple[int, int]] = set() - fconstraint = self.constraints[node_pair]["fconstraint"] - for pair in self.possible_token_pairs[node_pair]: - if self.pair_match_fconstraint(pair, fconstraint): - suitable_pairs.add(pair) - return suitable_pairs - - def find_isomorphism(self) -> bool: - """Checks if there is at least one graph with possible_token_pairs - that is isomorphic to a constraint pairs graph""" - - nodes_graph = nx.Graph() - nodes_graph.add_edges_from(list(self.possible_token_pairs.keys())) - possible_edges = list(product(*self.possible_token_pairs.values())) - for edges in possible_edges: - tokens_graph = nx.Graph() - tokens_graph.add_edges_from(edges) - if len(tokens_graph.nodes) != len(nodes_graph.nodes): - continue - if nx.is_isomorphic(tokens_graph, nodes_graph): - self.possible_token_pairs = { - k: {edges[i]} for i, k in enumerate(self.possible_token_pairs) - } - self.nodes_tokens = { - np[i]: [list(self.possible_token_pairs[np])[0][i]] - for np in self.possible_token_pairs - for i in range(2) - } - return True - return False - - def match_constraints(self) -> bool: - """Checks if there is at least one token pair that matches all constraints.""" - - for np in self.constraints: - self.possible_token_pairs[np] = set( - 
product(self.nodes_tokens[np[0]], self.nodes_tokens[np[1]]) - ) - for constraint in self.constraints[np]: - if constraint == "deprels": - self.possible_token_pairs[np] = self.pairs_matching_relpattern(np) - elif constraint == "lindist": - self.possible_token_pairs[np] = self.linear_distance(np) - elif constraint == "fconstraint": - self.possible_token_pairs[np] = self.feature_constraint(np) - # else: - # raise ValueError("Wrong constraint type") - # (not possible, this is controlled by ud_filter.utils.check_constraints) - if not self.possible_token_pairs[np]: - return False - else: - self.nodes_tokens[np[0]] = list( - set([p[0] for p in self.possible_token_pairs[np]]) - ) - self.nodes_tokens[np[1]] = list( - set([p[1] for p in self.possible_token_pairs[np]]) - ) - if not self.find_isomorphism(): - return False - return True - - def filter_sentence( - self, - node_pattern: Dict[str, Dict[str, str]], - constraints: Dict[Tuple[str, str], dict], - ) -> bool: - """Check if a sentence contains at least one instance of a node_pattern that matches - all the given and isomophism constraints""" - check_query(node_pattern, constraints) - self.node_pattern = node_pattern - self.constraints = constraints - self.nodes_tokens = {node: [] for node in self.node_pattern} - self.possible_token_pairs = {pair: set() for pair in self.constraints} - if not self.find_all_nodes(): - return False - else: - self.sent_deprels = self.all_deprels() - if self.match_constraints(): - return True - else: - return False +import re +from collections import defaultdict +from itertools import product +from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple + +import networkx as nx +from conllu import models + +from probing.ud_filter.utils import check_query + + +class SentenceFilter: + """ + Checks if a sentence matches patterns + + Attributes: + sentence: sentence represented as TokenList + node_pattern: a dictionary with a node_label as a key and a dictionary of feature restrictions as a value. 
+ sample_node_pattern = { node_label: { + field_or_category: regex_pattern, + 'exclude': [exclude categories]} } + constraints: a dictionary with a node_pair as a key and a dictionary with different constraints on node_pair + sample_constraints = { ('W1', 'W2'): { + 'deprels': regexp_pattern (W1 as head, W2 as dependent), + 'fconstraint': { + 'disjoint': [grammar_category], + 'intersec': [grammar_category]}, + 'lindist': (start, end) (relatively W1)} } + sent_deprels: a dictionary of all relations and pairs with these relations {relation: [(head, dependent)]} + nodes_tokens: a dictionary with all tokens that can be a node in the pattern {node: [token id]}, + if filter_sentence == True, saves only one instance in nodes_tokens + possible_token_pairs: a dictionary with all nodes pairs as a key and a list of possible token, + pairs as a value, if filter_sentence == True, saves only one instance in possible_token_pairs + """ + + def __init__(self, sentence: models.TokenList): + self.sentence = sentence + self.node_pattern: Dict[str, Dict[str, str]] = {} + self.constraints: Dict[Tuple[str, str], Dict[Any, Any]] = {} + self.sent_deprels: DefaultDict[str, List[Tuple[int, int]]] = defaultdict(list) + self.nodes_tokens: Dict[str, List[int]] = {} + self.possible_token_pairs: Dict[Tuple[str, str], Set[Tuple[int, int]]] = {} + + def token_match_node( + self, token: models.Token, node_pattern: Dict[str, str] + ) -> bool: + """Checks if a token matches the node_pattern""" + + for feat in node_pattern: + if feat in token.keys(): + if not re.fullmatch(node_pattern[feat], token[feat], re.I): + return False + elif token["feats"]: + if feat in token["feats"]: + if not re.fullmatch(node_pattern[feat], token["feats"][feat], re.I): + return False + elif feat == "exclude": + for ef in node_pattern[feat]: + if ef in token["feats"]: + return False + else: + return False + else: + return False + return True + + def all_deprels(self) -> DefaultDict[str, List[Tuple[int, int]]]: + """Returns a dictionary {relation: [(head, dependent)]} of all relations in the sentence""" + + deprels: DefaultDict[str, list] = defaultdict(list) + for token in self.sentence: + if isinstance(token["head"], int) and isinstance(token["id"], int): + deprels[token["deprel"]].append((token["head"] - 1, token["id"] - 1)) + return deprels + + def search_suitable_tokens(self, node: str) -> None: + """Selects from a token_list those tokens that match given node_pattern + and saves it in self.nodes_tokens[node]""" + + for token in self.sentence: + if self.token_match_node(token, self.node_pattern[node]) and isinstance( + token["id"], int + ): + self.nodes_tokens[node].append(token["id"] - 1) + + def find_all_nodes(self) -> bool: + """Checks if every node in the pattern has at least one matching token""" + + for node in self.node_pattern: + self.search_suitable_tokens(node) + if not self.nodes_tokens[node]: + return False + return True + + def pattern_relations(self, rel_pattern: str) -> List[str]: + """Returns all relation names in the sentence that match the given pattern""" + + rels = [] + for rel in self.sent_deprels: + if re.fullmatch( + rel_pattern, rel, re.I + ): # changed from re.serach to re.fullmatch + rels.append(rel) + return rels + + def pairs_with_rel( + self, node_pair: Tuple[str, str], rel_name: str + ) -> Set[Tuple[int, int]]: + """Returns those pairs of tokens that: + 1) are related by a rel_name + 2) are among possible_token_pairs for a node_pair""" + + if rel_name not in self.sent_deprels: + return set() + else: + return 
set(self.sent_deprels[rel_name]).intersection( + self.possible_token_pairs[node_pair] + ) + + def pairs_matching_relpattern( + self, node_pair: Tuple[str, str] + ) -> Set[Tuple[int, int]]: + """Returns a set of token pairs, whose relations match the pattern""" + + all_suitable_rels: Set[Tuple[int, int]] = set() + for rel in self.pattern_relations(self.constraints[node_pair]["deprels"]): + all_suitable_rels = all_suitable_rels | self.pairs_with_rel(node_pair, rel) + return all_suitable_rels + + def linear_distance(self, node_pair: Tuple[str, str]) -> Set[Tuple[int, int]]: + """Returns a set of token pairs with a given linear distance between tokens""" + + suitable_pairs: Set[Tuple[int, int]] = set() + lindist = self.constraints[node_pair]["lindist"] + for pair in self.possible_token_pairs[node_pair]: + dist = pair[1] - pair[0] + if lindist[0] <= dist <= lindist[1]: + suitable_pairs.add(pair) + return suitable_pairs + + def pair_match_fconstraint( + self, token_pair: Tuple[int, int], fconstraint: Dict[Any, Any] + ) -> bool: + """Checks if a token pair matches all the feature constraints""" + t1_feats = self.sentence[token_pair[0]]["feats"] + t2_feats = self.sentence[token_pair[1]]["feats"] + if t1_feats and t2_feats: + for ctype in fconstraint: + for f in fconstraint[ctype]: + if (f in t1_feats) and (f in t2_feats): + if ctype == "intersec": + if t1_feats[f] != t2_feats[f]: + return False + elif ctype == "disjoint": + if t1_feats[f] == t2_feats[f]: + return False + else: + raise ValueError("Wrong feature constraint type") + else: + return False + return True + else: + return False + + def feature_constraint(self, node_pair: Tuple[str, str]) -> Set[Tuple[int, int]]: + """Returns all pairs that match constraints on features""" + + suitable_pairs: Set[Tuple[int, int]] = set() + fconstraint = self.constraints[node_pair]["fconstraint"] + for pair in self.possible_token_pairs[node_pair]: + if self.pair_match_fconstraint(pair, fconstraint): + suitable_pairs.add(pair) + return suitable_pairs + + def find_isomorphism(self) -> bool: + """Checks if there is at least one graph with possible_token_pairs + that is isomorphic to a constraint pairs graph""" + + nodes_graph = nx.Graph() + nodes_graph.add_edges_from(list(self.possible_token_pairs.keys())) + possible_edges = list(product(*self.possible_token_pairs.values())) + for edges in possible_edges: + tokens_graph = nx.Graph() + tokens_graph.add_edges_from(edges) + if len(tokens_graph.nodes) != len(nodes_graph.nodes): + continue + if nx.is_isomorphic(tokens_graph, nodes_graph): + self.possible_token_pairs = { + k: {edges[i]} for i, k in enumerate(self.possible_token_pairs) + } + self.nodes_tokens = { + np[i]: [list(self.possible_token_pairs[np])[0][i]] + for np in self.possible_token_pairs + for i in range(2) + } + return True + return False + + def match_constraints(self) -> bool: + """Checks if there is at least one token pair that matches all constraints.""" + + for np in self.constraints: + self.possible_token_pairs[np] = set( + product(self.nodes_tokens[np[0]], self.nodes_tokens[np[1]]) + ) + for constraint in self.constraints[np]: + if constraint == "deprels": + self.possible_token_pairs[np] = self.pairs_matching_relpattern(np) + elif constraint == "lindist": + self.possible_token_pairs[np] = self.linear_distance(np) + elif constraint == "fconstraint": + self.possible_token_pairs[np] = self.feature_constraint(np) + # else: + # raise ValueError("Wrong constraint type") + # (not possible, this is controlled by ud_filter.utils.check_constraints) + if 
not self.possible_token_pairs[np]: + return False + else: + self.nodes_tokens[np[0]] = list( + set([p[0] for p in self.possible_token_pairs[np]]) + ) + self.nodes_tokens[np[1]] = list( + set([p[1] for p in self.possible_token_pairs[np]]) + ) + if not self.find_isomorphism(): + return False + return True + + def filter_sentence( + self, + node_pattern: Dict[str, Dict[str, str]], + constraints: Dict[Tuple[str, str], dict], + ) -> Optional[List[int]]: + """Checks whether the sentence contains at least one instance of node_pattern that matches + all the given constraints and the isomorphism constraint; returns one 0-based token index + per pattern node if it does, otherwise None""" + check_query(node_pattern, constraints) + self.node_pattern = node_pattern + self.constraints = constraints + self.nodes_tokens = {node: [] for node in self.node_pattern} + self.possible_token_pairs = {pair: set() for pair in self.constraints} + if not self.find_all_nodes(): + return None + else: + self.sent_deprels = self.all_deprels() + if self.match_constraints(): + return [val[0] for val in self.nodes_tokens.values()] + else: + return None diff --git a/probing/ud_filter/utils.py b/probing/ud_filter/utils.py index c8d6361..8586f16 100644 --- a/probing/ud_filter/utils.py +++ b/probing/ud_filter/utils.py @@ -1,287 +1,295 @@ -import csv -import os -from collections import Counter -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -from sklearn.model_selection import train_test_split - - -def filter_labels_after_split(labels: List[str]) -> List[str]: - """Skipping those classes which have only 1 sentence???""" - - labels_repeat_dict = Counter(labels) - n_repeat = 1 # threshold to overcome further splitting problem - return [label for label, count in labels_repeat_dict.items() if count > n_repeat] - - -def subsamples_split( - probing_dict: Dict[str, List[str]], - partition: List[float], - random_seed: int, - shuffle: bool = True, - split: List[str] = ["tr", "va", "te"], -) -> Dict[str, List[List[str]]]: - """ - Splits data into three sets: train, validation, and test - in the given relation - Args: - probing_dict: {class_label: [sentences]} - partition: a relation that sentences should be split in. ex: [0.8, 0.1, 0.1] - random_seed: random seed for splitting - shuffle: if sentences should be randomly shuffled - split: parts that data should be split to - Returns: - parts: {part: [[sentences], [labels]]} - """ - num_classes = len(probing_dict.keys()) - probing_data = [] - for class_name, sentences in probing_dict.items(): - if len(sentences) > num_classes: - for s in sentences: - probing_data.append((s, class_name)) - else: - print( - f"Class {class_name} has less sentences ({len(sentences)}) " - f"than the number of classes ({num_classes}), so it is excluded." 
- ) - if not probing_data: - raise Exception("All classes have less sentences than the number of classes") - parts = {} - data, labels = map(np.array, zip(*probing_data)) - X_train, X_test, y_train, y_test = train_test_split( - data, - labels, - stratify=labels, - train_size=partition[0], - shuffle=shuffle, - random_state=random_seed, - ) - - if len(partition) == 2: - parts = {split[0]: [X_train, y_train], split[1]: [X_test, y_test]} - else: - filtered_labels = filter_labels_after_split(y_test) - if len(filtered_labels) >= 2: - X_train = X_train[np.isin(y_train, filtered_labels)] - y_train = y_train[np.isin(y_train, filtered_labels)] - X_test = X_test[np.isin(y_test, filtered_labels)] - y_test = y_test[np.isin(y_test, filtered_labels)] - - val_size = partition[1] / (1 - partition[0]) - if y_test.size != 0: - X_val, X_test, y_val, y_test = train_test_split( - X_test, - y_test, - stratify=y_test, - train_size=val_size, - shuffle=shuffle, - random_state=random_seed, - ) - parts = { - split[0]: [X_train, y_train], - split[1]: [X_test, y_test], - split[2]: [X_val, y_val], - } - else: - raise Exception( - f"There is not enough sentences for {partition} partition." - ) # TODO - return parts - - -def read(path: os.PathLike) -> str: - """Reads CoNLL-U file""" - with open(path, encoding="utf-8") as f: - conllu_file = f.read() - return conllu_file - - -def writer( - partition_sets: Dict[str, List[List[str]]], - task_name: str, - language: str, - save_path_dir: os.PathLike, -) -> Path: - """ - Writes to a csv file - Args: - - partition_sets: {part: [[sentences], [labels]]} - task_name: name for the probing task (will be used in result file name) - language: language title - save_path_dir: path to the directory where to save - - """ - result_path = Path(Path(save_path_dir).resolve(), f"{language}_{task_name}.csv") - with open(result_path, "w", encoding="utf-8") as newf: - my_writer = csv.writer(newf, delimiter="\t", lineterminator="\n") - for part in partition_sets: - for sentence, value in zip(*partition_sets[part]): - my_writer.writerow([part, value, sentence]) - return result_path - - -def extract_lang_from_udfile_path( - ud_file_path: os.PathLike, language: Optional[str] -) -> str: - """Extracts language from conllu file name""" - - if not language: - return Path(ud_file_path).stem.split("-")[0] - return language - - -def determine_ud_savepath( - path_from_files: os.PathLike, save_path_dir: Optional[os.PathLike] -) -> Path: - """Creates a path to save the result file (the same directory where conllu paths are stored""" - - final_path = None - if not save_path_dir: - final_path = path_from_files - else: - final_path = save_path_dir - os.makedirs(final_path, exist_ok=True) - return Path(final_path) - - -def delete_duplicates(probing_dict: Dict[str, List[str]]) -> Dict[str, List[str]]: - """Deletes sentences with more than one different classes of node_pattern found""" - - all_sent = [s for cl_sent in probing_dict.values() for s in cl_sent] - duplicates = [item for item, count in Counter(all_sent).items() if count > 1] - new_probing_dict = {} - for cl in probing_dict: - new_probing_dict[cl] = [s for s in probing_dict[cl] if s not in duplicates] - return new_probing_dict - - -def check_query( - node_pattern: Dict[str, Dict[str, str]], - constraints: Dict[Tuple[str, str], Dict[Any, Any]], -) -> bool: - """Checks that a query fits the syntax""" - - check_node_pattern(node_pattern) - check_constraints(constraints) - constr_nodes = set([n for p in constraints for n in p]) - nodes = set(node_pattern.keys()) 
- if not constr_nodes <= nodes: - raise ValueError( - f"Not all nodes from the constraints are defined in the node_pattern" - ) - return True - - -def check_node_pattern(node_pattern: Dict[str, Dict[str, str]]) -> bool: - """Checks that node_pattern uses only UD categories and given in a right format""" - - NODES_FIELDS = {"form", "lemma", "upos", "xpos", "exclude"} - AVAILABLE_CATEGORIES = { - "PronType", - "Gender", - "VerbForm", - "NumType", - "Animacy", - "Mood", - "Poss", - "NounClass", - "Tense", - "Reflex", - "Number", - "Aspect", - "Foreign", - "Case", - "Voice", - "Abbr", - "Definite", - "Evident", - "Typo", - "Degree", - "Polarity", - "Person", - "Polite", - "Clusivity", - } - - for n in node_pattern: - npattern_fields = set(node_pattern[n].keys()) - if not npattern_fields <= (NODES_FIELDS | AVAILABLE_CATEGORIES): - raise KeyError( - f"Node_pattern can only include keys from this set: {NODES_FIELDS} or from the list of available " - f"grammar categories from here: https://universaldependencies.org/u/feat/index.html" - ) - - exclude_cat = node_pattern[n].get("exclude") - if exclude_cat: - if not isinstance(exclude_cat, list): - raise TypeError("Exclude features should be given in a list") - if not set(exclude_cat) <= AVAILABLE_CATEGORIES: - raise ValueError( - f"Wrong category name: {set(exclude_cat) - AVAILABLE_CATEGORIES}. Please use the same " - f"names as in the UD: https://universaldependencies.org/u/feat/index.html" - ) - return True - - -def check_constraints(constraints: Dict[Tuple[str, str], Dict[Any, Any]]) -> bool: - """Checks that constrains use only UD categories""" - - AVAILABLE_CATEGORIES = { - "PronType", - "Gender", - "VerbForm", - "NumType", - "Animacy", - "Mood", - "Poss", - "NounClass", - "Tense", - "Reflex", - "Number", - "Aspect", - "Foreign", - "Case", - "Voice", - "Abbr", - "Definite", - "Evident", - "Typo", - "Degree", - "Polarity", - "Person", - "Polite", - "Clusivity", - } - CONSTRAINT_FIELDS = {"deprels", "fconstraint", "lindist"} - FCONSTRAINT_FIELDS = {"disjoint", "intersec"} - - for np in constraints: - constr_types = set(constraints[np].keys()) - if not constr_types <= CONSTRAINT_FIELDS: - raise KeyError( - f"Wrong constraint type: {constr_types - CONSTRAINT_FIELDS}. Only {CONSTRAINT_FIELDS} can be used as " - f"keys" - ) - - fconstr = constraints[np].get("fconstraint") - if fconstr: - fconst_types = set(fconstr.keys()) - if not fconst_types <= FCONSTRAINT_FIELDS: - raise KeyError( - f"Wrong feature constraint type {fconst_types - FCONSTRAINT_FIELDS}. It can be only: {FCONSTRAINT_FIELDS}" - ) - - for fctype in fconstr: - if not isinstance(fconstr[fctype], list): - raise TypeError( - f"{fctype} features should be a list of grammar categories not a {type(fconstr[fctype])}" - ) - if not set(fconstr[fctype]) <= AVAILABLE_CATEGORIES: - raise ValueError( - f"Wrong grammar category names: {set(fconstr[fctype]) - AVAILABLE_CATEGORIES}. 
Please use the " - f"same names as in the UD: https://universaldependencies.org/u/feat/index.html" - ) - return True +import csv +import os +from collections import Counter +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +from sklearn.model_selection import train_test_split + + +def filter_labels_after_split(labels: List[str]) -> List[str]: + """Skipping those classes which have only 1 sentence???""" + + labels_repeat_dict = Counter(labels) + n_repeat = 1 # threshold to overcome further splitting problem + return [label for label, count in labels_repeat_dict.items() if count > n_repeat] + + +def subsamples_split( + probing_dict: Dict[str, List[Tuple[str, List[int]]]], + partition: List[float], + random_seed: int, + shuffle: bool = True, + split: List[str] = ["tr", "va", "te"], +) -> Dict[str, List[List[str]]]: + """ + Splits data into three sets: train, validation, and test + in the given relation + Args: + probing_dict: {class_label: [sentences]} + partition: a relation that sentences should be split in. ex: [0.8, 0.1, 0.1] + random_seed: random seed for splitting + shuffle: if sentences should be randomly shuffled + split: parts that data should be split to + Returns: + parts: {part: [[sentences], [labels]]} + """ + num_classes = len(probing_dict.keys()) + probing_data = [] + for class_name, sentences in probing_dict.items(): + if len(sentences) > num_classes: + for s in sentences: + probing_data.append((s, class_name)) + else: + print( + f"Class {class_name} has less sentences ({len(sentences)}) " + f"than the number of classes ({num_classes}), so it is excluded." + ) + if not probing_data: + raise Exception("All classes have less sentences than the number of classes") + parts = {} + data, labels = map(list, zip(*probing_data)) + X_train, X_test, y_train, y_test = train_test_split( + data, + labels, + stratify=labels, + train_size=partition[0], + shuffle=shuffle, + random_state=random_seed, + ) + if len(partition) == 2: + parts = {split[0]: [X_train, y_train], split[1]: [X_test, y_test]} + else: + filtered_labels = filter_labels_after_split(y_test) + if len(filtered_labels) >= 2: + train_mask = np.isin(y_train, filtered_labels) + X_train = [X_train[i] for i in range(len(train_mask)) if train_mask[i]] + y_train = [y_train[i] for i in range(len(train_mask)) if train_mask[i]] + test_mask = np.isin(y_test, filtered_labels) + X_test = [X_test[i] for i in range(len(test_mask)) if test_mask[i]] + y_test = [y_test[i] for i in range(len(test_mask)) if test_mask[i]] + + val_size = partition[1] / (1 - partition[0]) + if len(y_test) != 0: + X_val, X_test, y_val, y_test = train_test_split( + X_test, + y_test, + stratify=y_test, + train_size=val_size, + shuffle=shuffle, + random_state=random_seed, + ) + parts = { + split[0]: [X_train, y_train], + split[1]: [X_test, y_test], + split[2]: [X_val, y_val], + } + else: + raise Exception( + f"There is not enough sentences for {partition} partition." 
+ ) # TODO + return parts + + +def read(path: os.PathLike) -> str: + """Reads CoNLL-U file""" + with open(path, encoding="utf-8") as f: + conllu_file = f.read() + return conllu_file + + +def writer( + partition_sets: Dict[str, List[List[str]]], + task_name: str, + language: str, + save_path_dir: os.PathLike, +) -> Path: + """ + Writes to a csv file + Args: + + partition_sets: {part: [[sentences], [labels]]} + task_name: name for the probing task (will be used in result file name) + language: language title + save_path_dir: path to the directory where to save + + """ + result_path = Path(Path(save_path_dir).resolve(), f"{language}_{task_name}.csv") + with open(result_path, "w", encoding="utf-8") as newf: + my_writer = csv.writer(newf, delimiter="\t", lineterminator="\n") + for part in partition_sets: + for sentence_and_ids, value in zip(*partition_sets[part]): + sentence, ids = sentence_and_ids + my_writer.writerow( + [part, value, ",".join([str(x) for x in ids]), sentence] + ) + return result_path + + +def extract_lang_from_udfile_path( + ud_file_path: os.PathLike, language: Optional[str] +) -> str: + """Extracts language from conllu file name""" + + if not language: + return Path(ud_file_path).stem.split("-")[0] + return language + + +def determine_ud_savepath( + path_from_files: os.PathLike, save_path_dir: Optional[os.PathLike] +) -> Path: + """Creates a path to save the result file (the same directory where conllu paths are stored""" + + final_path = None + if not save_path_dir: + final_path = path_from_files + else: + final_path = save_path_dir + os.makedirs(final_path, exist_ok=True) + return Path(final_path) + + +def delete_duplicates( + probing_dict: Dict[str, List[Tuple[str, List[int]]]] +) -> Dict[str, List[Tuple[str, List[int]]]]: + """Deletes sentences with more than one different classes of node_pattern found""" + + all_sent = [sent for cl_sent in probing_dict.values() for sent, inds in cl_sent] + duplicates = {item for item, count in Counter(all_sent).items() if count > 1} + new_probing_dict = {} + for cl in probing_dict: + new_probing_dict[cl] = [ + (sent, ind) for sent, ind in probing_dict[cl] if sent not in duplicates + ] + return new_probing_dict + + +def check_query( + node_pattern: Dict[str, Dict[str, str]], + constraints: Dict[Tuple[str, str], Dict[Any, Any]], +) -> bool: + """Checks that a query fits the syntax""" + + check_node_pattern(node_pattern) + check_constraints(constraints) + constr_nodes = set([n for p in constraints for n in p]) + nodes = set(node_pattern.keys()) + if not constr_nodes <= nodes: + raise ValueError( + f"Not all nodes from the constraints are defined in the node_pattern" + ) + return True + + +def check_node_pattern(node_pattern: Dict[str, Dict[str, str]]) -> bool: + """Checks that node_pattern uses only UD categories and given in a right format""" + + NODES_FIELDS = {"form", "lemma", "upos", "xpos", "exclude"} + AVAILABLE_CATEGORIES = { + "PronType", + "Gender", + "VerbForm", + "NumType", + "Animacy", + "Mood", + "Poss", + "NounClass", + "Tense", + "Reflex", + "Number", + "Aspect", + "Foreign", + "Case", + "Voice", + "Abbr", + "Definite", + "Evident", + "Typo", + "Degree", + "Polarity", + "Person", + "Polite", + "Clusivity", + } + + for n in node_pattern: + npattern_fields = set(node_pattern[n].keys()) + if not npattern_fields <= (NODES_FIELDS | AVAILABLE_CATEGORIES): + raise KeyError( + f"Node_pattern can only include keys from this set: {NODES_FIELDS} or from the list of available " + f"grammar categories from here: 
https://universaldependencies.org/u/feat/index.html" + ) + + exclude_cat = node_pattern[n].get("exclude") + if exclude_cat: + if not isinstance(exclude_cat, list): + raise TypeError("Exclude features should be given in a list") + if not set(exclude_cat) <= AVAILABLE_CATEGORIES: + raise ValueError( + f"Wrong category name: {set(exclude_cat) - AVAILABLE_CATEGORIES}. Please use the same " + f"names as in the UD: https://universaldependencies.org/u/feat/index.html" + ) + return True + + +def check_constraints(constraints: Dict[Tuple[str, str], Dict[Any, Any]]) -> bool: + """Checks that constrains use only UD categories""" + + AVAILABLE_CATEGORIES = { + "PronType", + "Gender", + "VerbForm", + "NumType", + "Animacy", + "Mood", + "Poss", + "NounClass", + "Tense", + "Reflex", + "Number", + "Aspect", + "Foreign", + "Case", + "Voice", + "Abbr", + "Definite", + "Evident", + "Typo", + "Degree", + "Polarity", + "Person", + "Polite", + "Clusivity", + } + CONSTRAINT_FIELDS = {"deprels", "fconstraint", "lindist"} + FCONSTRAINT_FIELDS = {"disjoint", "intersec"} + + for np in constraints: + constr_types = set(constraints[np].keys()) + if not constr_types <= CONSTRAINT_FIELDS: + raise KeyError( + f"Wrong constraint type: {constr_types - CONSTRAINT_FIELDS}. Only {CONSTRAINT_FIELDS} can be used as " + f"keys" + ) + + fconstr = constraints[np].get("fconstraint") + if fconstr: + fconst_types = set(fconstr.keys()) + if not fconst_types <= FCONSTRAINT_FIELDS: + raise KeyError( + f"Wrong feature constraint type {fconst_types - FCONSTRAINT_FIELDS}. It can be only: {FCONSTRAINT_FIELDS}" + ) + + for fctype in fconstr: + if not isinstance(fconstr[fctype], list): + raise TypeError( + f"{fctype} features should be a list of grammar categories not a {type(fconstr[fctype])}" + ) + if not set(fconstr[fctype]) <= AVAILABLE_CATEGORIES: + raise ValueError( + f"Wrong grammar category names: {set(fconstr[fctype]) - AVAILABLE_CATEGORIES}. Please use the " + f"same names as in the UD: https://universaldependencies.org/u/feat/index.html" + ) + return True diff --git a/probing/ud_parser/ud_parser.py b/probing/ud_parser/ud_parser.py index 0260b15..6bf7bcb 100644 --- a/probing/ud_parser/ud_parser.py +++ b/probing/ud_parser/ud_parser.py @@ -65,8 +65,9 @@ def writer( with open(result_path, "w", encoding="utf-8") as newf: my_writer = csv.writer(newf, delimiter="\t", lineterminator="\n") for part in partition_sets: - for sentence, value in zip(*partition_sets[part]): - my_writer.writerow([part, value, sentence]) + for sentence_and_id, value in zip(*partition_sets[part]): + sentence, id = sentence_and_id + my_writer.writerow([part, value, sentence, id]) return result_path def find_category_token( @@ -134,7 +135,8 @@ def classify( ) ): value = category_token["feats"][category] - probing_data[value].append(s_text) + token_id = category_token["id"] - 1 + probing_data[value].append((s_text, token_id)) elif self.sorting == "by_pos_and_deprel": pos, deprel = subcategory.split("_") if ( @@ -142,7 +144,8 @@ def classify( and category_token["deprel"] == deprel ): value = category_token["feats"][category] - probing_data[value].append(s_text) + token_id = category_token["id"] - 1 + probing_data[value].append((s_text, token_id)) return probing_data def filter_labels_after_split(self, labels: List[Any]) -> List[Any]:
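For reference, a hedged usage sketch of the updated SentenceFilter.filter_sentence contract in probing/ud_filter/sentence_filter.py: it now returns a list with one 0-based token index per pattern node when the sentence matches the query, and None otherwise, instead of a bool. The CoNLL-U fragment and the "noun with an adjectival modifier" query below are invented for illustration only.

    from conllu import parse

    from probing.ud_filter.sentence_filter import SentenceFilter

    # Tiny hand-written CoNLL-U fragment (hypothetical example sentence).
    conllu_data = "\n".join([
        "# text = Big dogs sleep .",
        "1\tBig\tbig\tADJ\t_\t_\t2\tamod\t_\t_",
        "2\tdogs\tdog\tNOUN\t_\t_\t3\tnsubj\t_\t_",
        "3\tsleep\tsleep\tVERB\t_\t_\t0\troot\t_\t_",
        "4\t.\t.\tPUNCT\t_\t_\t3\tpunct\t_\t_",
    ]) + "\n"

    node_pattern = {"N": {"upos": "NOUN"}, "A": {"upos": "ADJ"}}
    constraints = {("N", "A"): {"deprels": "amod"}}  # N is the head, A the dependent

    sf = SentenceFilter(parse(conllu_data)[0])
    result = sf.filter_sentence(node_pattern, constraints)
    # result is a list of 0-based token indices (one per pattern node) on a match,
    # and None otherwise -- no longer a bool.
    if result is not None:
        print("matched token indices:", result)  # here: [1, 0] -> "dogs", "Big"
    else:
        print("no match")
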
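A companion sketch of the row layout produced by the updated utils.writer: each (sentence, token_indices) pair becomes a tab-separated row [part, label, "i1,i2,...", sentence], which is the four-column layout that DataFormer.form_data unpacks in the first hunk. The partition_sets content and class labels below are hypothetical.

    import csv
    import sys

    # Hypothetical output of subsamples_split: {part: [[(sentence, token_indices)], [labels]]}
    partition_sets = {
        "tr": [[("Big dogs sleep .", [1, 0])], ["amod_pair"]],
        "te": [[("Dogs sleep .", [])], ["not_amod_pair"]],
    }

    my_writer = csv.writer(sys.stdout, delimiter="\t", lineterminator="\n")
    for part in partition_sets:
        for (sentence, ids), value in zip(*partition_sets[part]):
            # "not_*" rows produced from non-matching sentences carry an empty index field
            my_writer.writerow([part, value, ",".join(str(x) for x in ids), sentence])
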