diff --git a/causallearn/search/ConstraintBased/PC.py b/causallearn/search/ConstraintBased/PC.py index e3dde17..cb428e9 100644 --- a/causallearn/search/ConstraintBased/PC.py +++ b/causallearn/search/ConstraintBased/PC.py @@ -30,6 +30,7 @@ def pc( verbose: bool = False, show_progress: bool = True, node_names: List[str] | None = None, + max_k: int = None, **kwargs ): if data.shape[0] < data.shape[1]: @@ -41,11 +42,11 @@ def pc( return mvpc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, correction_name=correction_name, stable=stable, uc_rule=uc_rule, uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose, - show_progress=show_progress, **kwargs) + show_progress=show_progress, max_k=max_k, **kwargs) else: return pc_alg(data=data, node_names=node_names, alpha=alpha, indep_test=indep_test, stable=stable, uc_rule=uc_rule, uc_priority=uc_priority, background_knowledge=background_knowledge, verbose=verbose, - show_progress=show_progress, **kwargs) + show_progress=show_progress, max_k=max_k, **kwargs) def pc_alg( @@ -59,6 +60,7 @@ def pc_alg( background_knowledge: BackgroundKnowledge | None = None, verbose: bool = False, show_progress: bool = True, + max_k=None, **kwargs ) -> CausalGraph: """ @@ -103,7 +105,7 @@ def pc_alg( indep_test = CIT(data, indep_test, **kwargs) cg_1 = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable, background_knowledge=background_knowledge, verbose=verbose, - show_progress=show_progress, node_names=node_names) + show_progress=show_progress, node_names=node_names, max_k=max_k) if background_knowledge is not None: orient_by_background_knowledge(cg_1, background_knowledge) @@ -142,7 +144,7 @@ def mvpc_alg( data: ndarray, node_names: List[str] | None, alpha: float, - indep_test: str, + indep_test: Any, correction_name: str, stable: bool, uc_rule: int, @@ -150,6 +152,7 @@ def mvpc_alg( background_knowledge: BackgroundKnowledge | None = None, verbose: bool = False, show_progress: bool = True, + max_k: int | None = None, **kwargs, ) -> CausalGraph: """ @@ -197,14 +200,14 @@ def mvpc_alg( start = time.time() indep_test = CIT(data, indep_test, **kwargs) ## Step 1: detect the direct causes of missingness indicators - prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable) + prt_m = get_parent_missingness_pairs(data, alpha, indep_test, stable, max_k=max_k) # print('Finish detecting the parents of missingness indicators. ') ## Step 2: ## a) Run PC algorithm with the 1st step skeleton; cg_pre = SkeletonDiscovery.skeleton_discovery(data, alpha, indep_test, stable, background_knowledge=background_knowledge, - verbose=verbose, show_progress=show_progress, node_names=node_names) + verbose=verbose, show_progress=show_progress, node_names=node_names, max_k=max_k) if background_knowledge is not None: orient_by_background_knowledge(cg_pre, background_knowledge) @@ -251,7 +254,7 @@ def mvpc_alg( ####################################################################################################################### ## *********** Functions for Step 1 *********** -def get_parent_missingness_pairs(data: ndarray, alpha: float, indep_test, stable: bool = True) -> Dict[str, list]: +def get_parent_missingness_pairs(data: ndarray, alpha: float, indep_test, stable: bool = True, max_k: int | None = None) -> Dict[str, list]: """ Detect the parents of missingness indicators If a missingness indicator has no parent, it will not be included in the result @@ -272,7 +275,7 @@ def get_parent_missingness_pairs(data: ndarray, alpha: float, indep_test, stable ## Get the index of parents of missingness indicators # If the missingness indicator has no parent, then it will not be collected in prt_m for missingness_i in missingness_index: - parent_of_missingness_i = detect_parent(missingness_i, data, alpha, indep_test, stable) + parent_of_missingness_i = detect_parent(missingness_i, data, alpha, indep_test, stable, max_k=max_k) if not isempty(parent_of_missingness_i): parent_missingness_pairs['prt'].append(parent_of_missingness_i) parent_missingness_pairs['m'].append(missingness_i) @@ -299,7 +302,7 @@ def get_missingness_index(data: ndarray) -> List[int]: return missingness_index -def detect_parent(r: int, data_: ndarray, alpha: float, indep_test, stable: bool = True) -> ndarray: +def detect_parent(r: int, data_: ndarray, alpha: float, indep_test: Any, stable: bool = True, max_k: int | None = None) -> ndarray: """Detect the parents of a missingness indicator :param r: the missingness indicator :param data_: data set (numpy ndarray) @@ -334,7 +337,8 @@ def detect_parent(r: int, data_: ndarray, alpha: float, indep_test, stable: bool no_of_var = data.shape[1] cg = CausalGraph(no_of_var) - cg.set_ind_test(CIT(data, indep_test.method)) + cg.set_ind_test(indep_test) + node_ids = range(no_of_var) pair_of_variables = list(permutations(node_ids, 2)) @@ -342,7 +346,10 @@ def detect_parent(r: int, data_: ndarray, alpha: float, indep_test, stable: bool depth = -1 while cg.max_degree() - 1 > depth: depth += 1 + if max_k is not None and depth > max_k: + break edge_removal = [] + for (x, y) in pair_of_variables: ## *********** Adaptation 2 *********** @@ -495,3 +502,4 @@ def matrix_diff(cg1: CausalGraph, cg2: CausalGraph) -> (float, List[Tuple[int, i diff_ls.append((i, j)) count += 1 return count / 2, diff_ls + diff --git a/causallearn/utils/PCUtils/SkeletonDiscovery.py b/causallearn/utils/PCUtils/SkeletonDiscovery.py index 428d2ce..7c8908b 100644 --- a/causallearn/utils/PCUtils/SkeletonDiscovery.py +++ b/causallearn/utils/PCUtils/SkeletonDiscovery.py @@ -21,7 +21,7 @@ def skeleton_discovery( background_knowledge: BackgroundKnowledge | None = None, verbose: bool = False, show_progress: bool = True, - node_names: List[str] | None = None, + node_names: List[str] | None = None, max_k=None, ) -> CausalGraph: """ Perform skeleton discovery @@ -63,6 +63,8 @@ def skeleton_discovery( pbar = tqdm(total=no_of_var) if show_progress else None while cg.max_degree() - 1 > depth: depth += 1 + if max_k is not None and depth > max_k: + break edge_removal = [] if show_progress: pbar.reset()