diff --git a/analyses/common/tables.py b/analyses/common/tables.py index 1b77fb0..1b329c6 100644 --- a/analyses/common/tables.py +++ b/analyses/common/tables.py @@ -14,7 +14,10 @@ "highlight_cols", "highlight_rows", "save_table", + "generate_column_rules", + "generate_partition_rules", ] + def get_styler(df: Union[pd.DataFrame, pd.Series], decimals: Optional[int]=2, thousands: Optional[str]=',') -> pandas.io.formats.style.Styler: '''Gets a Styler object for formatting a table. @@ -45,6 +48,73 @@ def highlight_rows(styler: pandas.io.formats.style.Styler) -> pandas.io.formats. Tuple[RuleLineIndex, Union[CmidruleSpec, List[CmidruleSpec]]]] ConcreteRule = Tuple[RuleLineIndex, str] +def generate_column_rules(df: pd.DataFrame, skip_index: bool=True, level: int=0, left_trim: TrimSpec=True, right_trim: TrimSpec=True) -> List[RuleSpecifier]: + """Generate post-header rule for DF, including cut rules for the column groups at LEVEL. + + If skip_index is False, simply return a specification for a regular \midrule after the column header(s). + Otherwise, generate an offset midrule or cmidrules for groups. + When grouping is performed, obey LEFT_TRIM and RIGHT_TRIM between \cmidrule s + """ + index_cols = 1 + if isinstance(df.index, pd.MultiIndex): + index_cols = df.index.nlevels + + if not isinstance(df.columns, pd.MultiIndex): + if skip_index: + return [(1, (1 + index_cols, df.columns.size + index_cols, False, False))] + return [1] + else: + if not skip_index: + return [df.columns.nlevels] + + cmidrules = [] + values = df.columns.get_level_values(level).array + cur = values[0] + cur_start = 1 + val = 1 + + for label in values[1:]: + if label != cur: + cur = label + cmidrules.append((cur_start + index_cols, val + index_cols, False if cur_start==1 else left_trim, right_trim)) + cur_start = val + 1 + val += 1 + + cmidrules.append((cur_start + index_cols, len(values) + index_cols, left_trim, False)) + + return [(df.columns.nlevels, cmidrules)] + +def generate_partition_rules(df: pd.DataFrame, skip_index: bool=False, level: int=0) -> List[RuleSpecifier]: + """Generate post-row-group rules for DF for row-groups at LEVEL. + + If SKIP_INDEX is true, generate a cmidrule which does not include the columns of the index. + """ + assert isinstance(df.index, pd.MultiIndex), "Index must be a MultiIndex" + + num_cols = df.columns.size + row_offset = 1 + if isinstance(df.columns, pd.MultiIndex): + row_offset = df.columns.nlevels + + index_offset = df.index.nlevels + + values = df.index.get_level_values(level).array + cur_row = 1 + cur_label = values[0] + + rules = [] + + for label in values[1:]: + if cur_label != label: + cur_label = label + if skip_index: + rules.append(cur_row + row_offset) + else: + rules.append((cur_row + row_offset, (index_offset + 1, num_cols + index_offset, False, False))) + cur_row += 1 + + return rules + def _trim_spec(trim_left: TrimSpec, trim_right: TrimSpec) -> str: if trim_left or trim_right: trim_spec = '('