Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get_dataset matches R #150

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions marginaleffects/hypotheses_joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def joint_hypotheses(obj, joint_index=None, joint_test="f", hypothesis=0):
joint_index = [
i for i in range(len(var_names)) if var_names[i] in joint_index
]
assert min(joint_index) >= 0 and max(joint_index) <= len(
var_names
), "`joint_index` contain invalid indices"
assert min(joint_index) >= 0 and max(joint_index) <= len(var_names), (
"`joint_index` contain invalid indices"
)

V_hat = obj.get_vcov()

Expand Down
16 changes: 8 additions & 8 deletions marginaleffects/hypothesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def lincom_revreference(x, by):
lincom[0] = 1
lab = get_hypothesis_row_labels(x, by)
if len(lab) == 0 or len(set(lab)) != len(lab):
lab = [f"Row 1 - Row {i+1}" for i in range(len(lincom))]
lab = [f"Row 1 - Row {i + 1}" for i in range(len(lincom))]
else:
lab = [f"{lab[0]} - {la}" for la in lab]
lincom = pl.DataFrame(lincom, schema=lab)
Expand All @@ -138,7 +138,7 @@ def lincom_reference(x, by):
lincom[0, :] = -1
lab = get_hypothesis_row_labels(x, by)
if len(lab) == 0 or len(set(lab)) != len(lab):
lab = [f"Row {i+1} - Row 1" for i in range(len(lincom))]
lab = [f"Row {i + 1} - Row 1" for i in range(len(lincom))]
else:
lab = [f"{la} - {lab[0]}" for la in lab]
if lincom.shape[1] == 1:
Expand All @@ -153,9 +153,9 @@ def lincom_revsequential(x, by):
lincom = np.zeros((len(x), len(x) - 1))
lab = get_hypothesis_row_labels(x, by)
if len(lab) == 0 or len(set(lab)) != len(lab):
lab = [f"Row {i+1} - Row {i+2}" for i in range(lincom.shape[1])]
lab = [f"Row {i + 1} - Row {i + 2}" for i in range(lincom.shape[1])]
else:
lab = [f"{lab[i]} - {lab[i+1]}" for i in range(lincom.shape[1])]
lab = [f"{lab[i]} - {lab[i + 1]}" for i in range(lincom.shape[1])]
for i in range(lincom.shape[1]):
lincom[i : i + 2, i] = [1, -1]
if lincom.shape[1] == 1:
Expand All @@ -169,9 +169,9 @@ def lincom_sequential(x, by):
lincom = np.zeros((len(x), len(x) - 1))
lab = get_hypothesis_row_labels(x, by)
if len(lab) == 0 or len(set(lab)) != len(lab):
lab = [f"Row {i+2} - Row {i+1}" for i in range(lincom.shape[1])]
lab = [f"Row {i + 2} - Row {i + 1}" for i in range(lincom.shape[1])]
else:
lab = [f"{lab[i+1]} - {lab[i]}" for i in range(lincom.shape[1])]
lab = [f"{lab[i + 1]} - {lab[i]}" for i in range(lincom.shape[1])]
for i in range(lincom.shape[1]):
lincom[i : i + 2, i] = [-1, 1]
if lincom.shape[1] == 1:
Expand All @@ -194,7 +194,7 @@ def lincom_revpairwise(x, by):
tmp[j] = 1
mat.append(tmp)
if flag:
lab_col.append(f"Row {j+1} - Row {i+1}")
lab_col.append(f"Row {j + 1} - Row {i + 1}")
else:
lab_col.append(f"{lab_row[j]} - {lab_row[i]}")
if len(mat) == 1:
Expand All @@ -217,7 +217,7 @@ def lincom_pairwise(x, by):
tmp[i] = 1
mat.append(tmp)
if flag:
lab_col.append(f"Row {i+1} - Row {j+1}")
lab_col.append(f"Row {i + 1} - Row {j + 1}")
else:
lab_col.append(f"{lab_row[i]} - {lab_row[j]}")
if len(mat) == 1:
Expand Down
24 changes: 12 additions & 12 deletions marginaleffects/plot_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ def dt_on_condition(model, condition):
first_key = "" # special case when the first element is numeric

if isinstance(condition_new, list):
assert all(
ele in modeldata.columns for ele in condition_new
), "All elements of condition must be columns of the model."
assert all(ele in modeldata.columns for ele in condition_new), (
"All elements of condition must be columns of the model."
)
first_key = condition_new[0]
to_datagrid = {key: None for key in condition_new}

elif isinstance(condition_new, dict):
assert all(
key in modeldata.columns for key in condition_new.keys()
), "All keys of condition must be columns of the model."
assert all(key in modeldata.columns for key in condition_new.keys()), (
"All keys of condition must be columns of the model."
)
first_key = next(iter(condition_new))
to_datagrid = (
condition_new # third pointer to the same object? looks like a BUG
Expand All @@ -38,9 +38,9 @@ def dt_on_condition(model, condition):
if isinstance(condition_new, dict) and "newdata" in to_datagrid.keys():
condition_new.pop("newdata", None)

assert (
1 <= len(condition_new) <= 4
), f"Lenght of condition must be inclusively between 1 and 4. Got : {len(condition_new)}."
assert 1 <= len(condition_new) <= 4, (
f"Lenght of condition must be inclusively between 1 and 4. Got : {len(condition_new)}."
)

for key, value in to_datagrid.items():
variable_type = model.variables_type[key]
Expand All @@ -58,9 +58,9 @@ def dt_on_condition(model, condition):
if to_datagrid[key]
else modeldata[key].unique().sort().to_list()
)
assert (
len(to_datagrid[key]) <= 10
), f"Character type variables of more than 10 unique values are not supported. {key} variable has {len(to_datagrid[key])} unique values."
assert len(to_datagrid[key]) <= 10, (
f"Character type variables of more than 10 unique values are not supported. {key} variable has {len(to_datagrid[key])} unique values."
)

to_datagrid["newdata"] = modeldata
dt = datagrid(**to_datagrid)
Expand Down
24 changes: 12 additions & 12 deletions marginaleffects/plot_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,17 @@ def plot_comparisons(

model = sanitize_model(model)

assert not (
not by and newdata is not None
), "The `newdata` argument requires a `by` argument."
assert not (not by and newdata is not None), (
"The `newdata` argument requires a `by` argument."
)

assert (condition is None and by) or (
condition is not None and not by
), "One of the `condition` and `by` arguments must be supplied, but not both."
assert (condition is None and by) or (condition is not None and not by), (
"One of the `condition` and `by` arguments must be supplied, but not both."
)

assert not (
wts is not None and not by
), "The `wts` argument requires a `by` argument."
assert not (wts is not None and not by), (
"The `wts` argument requires a `by` argument."
)

# before dt_on_condition, which modifies in-place
condition_input = copy.deepcopy(condition)
Expand Down Expand Up @@ -134,9 +134,9 @@ def plot_comparisons(
# not sure why these get appended
var_list = [x for x in var_list if x not in ["newdata", "model"]]

assert (
len(var_list) < 4
), "The `condition` and `by` arguments can have a max length of 3."
assert len(var_list) < 4, (
"The `condition` and `by` arguments can have a max length of 3."
)

if "contrast" in dt.columns and dt["contrast"].unique().len() > 1:
var_list = var_list + ["contrast"]
Expand Down
18 changes: 9 additions & 9 deletions marginaleffects/plot_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,17 @@ def plot_predictions(

model = sanitize_model(model)

assert not (
not by and newdata is not None
), "The `newdata` argument requires a `by` argument."
assert not (not by and newdata is not None), (
"The `newdata` argument requires a `by` argument."
)

assert not (
wts is not None and not by
), "The `wts` argument requires a `by` argument."
assert not (wts is not None and not by), (
"The `wts` argument requires a `by` argument."
)

assert not (
condition is None and not by
), "One of the `condition` and `by` arguments must be supplied, but not both."
assert not (condition is None and not by), (
"One of the `condition` and `by` arguments must be supplied, but not both."
)

# before dt_on_condition, which modifies in-place
condition_input = copy.deepcopy(condition)
Expand Down
42 changes: 21 additions & 21 deletions marginaleffects/plot_slopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,29 +75,29 @@ def plot_slopes(

assert variables, "The `variables` argument must be supplied."

assert not (
not by and newdata is not None
), "The `newdata` argument requires a `by` argument."
assert not (not by and newdata is not None), (
"The `newdata` argument requires a `by` argument."
)

assert (condition is None and by) or (
condition is not None and not by
), "One of the `condition` and `by` arguments must be supplied, but not both."
assert (condition is None and by) or (condition is not None and not by), (
"One of the `condition` and `by` arguments must be supplied, but not both."
)

assert not (
wts is not None and not by
), "The `wts` argument requires a `by` argument."
assert not (wts is not None and not by), (
"The `wts` argument requires a `by` argument."
)

assert not (
not by and newdata is not None
), "The `newdata` argument requires a `by` argument."
assert not (not by and newdata is not None), (
"The `newdata` argument requires a `by` argument."
)

assert not (
wts is not None and not by
), "The `wts` argument requires a `by` argument."
assert not (wts is not None and not by), (
"The `wts` argument requires a `by` argument."
)

assert not (
condition is None and not by
), "One of the `condition` and `by` arguments must be supplied, but not both."
assert not (condition is None and not by), (
"One of the `condition` and `by` arguments must be supplied, but not both."
)

# before dt_on_condition, which modifies in-place
condition_input = copy.deepcopy(condition)
Expand Down Expand Up @@ -139,9 +139,9 @@ def plot_slopes(
# not sure why these get appended
var_list = [x for x in var_list if x not in ["newdata", "model"]]

assert (
len(var_list) < 5
), "The `condition` and `by` arguments can have a max length of 4."
assert len(var_list) < 5, (
"The `condition` and `by` arguments can have a max length of 4."
)

if "contrast" in dt.columns and dt["contrast"].unique().len() > 1:
var_list = var_list + ["contrast"]
Expand Down
1 change: 1 addition & 0 deletions marginaleffects/sanitize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def sanitize_model(model):

try:
from linearmodels.panel.results import PanelResults

if isinstance(model, PanelResults):
return ModelLinearmodels(model)
except ImportError:
Expand Down
18 changes: 9 additions & 9 deletions marginaleffects/sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def sanitize_newdata(model, newdata, wts, by=[]):
"contrast",
"statistic",
}
assert not (
set(out.columns) & reserved_names
), f"Input data contain reserved column name(s) : {set(out.columns).intersection(reserved_names)}"
assert not (set(out.columns) & reserved_names), (
f"Input data contain reserved column name(s) : {set(out.columns).intersection(reserved_names)}"
)

datagrid_explicit = None
if isinstance(out, pl.DataFrame) and hasattr(out, "datagrid_explicit"):
Expand Down Expand Up @@ -151,9 +151,9 @@ def sanitize_comparison(comparison, by, wts=None):
"expdydx": "exp(dY/dX)",
}

assert (
out in lab.keys()
), f"`comparison` must be one of: {', '.join(list(lab.keys()))}."
assert out in lab.keys(), (
f"`comparison` must be one of: {', '.join(list(lab.keys()))}."
)

return (out, lab[out])

Expand Down Expand Up @@ -297,9 +297,9 @@ def clean(k):

elif callable(value):
tmp = value(newdata[variable])
assert (
tmp.shape[1] == 2
), f"The function passed to `variables` must return a DataFrame with two columns. Got {tmp.shape[1]}."
assert tmp.shape[1] == 2, (
f"The function passed to `variables` must return a DataFrame with two columns. Got {tmp.shape[1]}."
)
lo = tmp[:, 0]
hi = tmp[:, 1]
lab = "custom"
Expand Down
56 changes: 44 additions & 12 deletions marginaleffects/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def ingest(df: ArrowStreamExportable):

try:
import pandas as pd

if isinstance(df, pd.DataFrame):
df = df.reset_index()
except ImportError:
Expand Down Expand Up @@ -170,16 +171,26 @@ def wrapper(*args, **kwargs):
return wrapper


def get_dataset(dataset: str, docs: bool = False):
def get_dataset(
dataset: str = "ArgentinaCPI",
package: str = "AER",
docs: bool = False,
search: str = None,
):
"""
Download and read a dataset as a Polars DataFrame or return documentation link.
Download and read a dataset as a Polars DataFrame from the `marginaleffects` or from the list at https://vincentarelbundock.github.io/Rdatasets/.
Returns documentation link if `docs` is True.

Parameters
----------
dataset : str
The dataset to download. Must be one of "affairs", "airbnb", "immigration", "military", "thornton".
The dataset to download. One of "affairs", "airbnb", "immigration", "military", "thornton", or the name of a dataset from the Rdatasets archive.
package : str, optional
The Rdatasets package to download the dataset from. Default is "AER".
docs : bool, optional
If True, return the documentation URL instead of the dataset. Default is False.
search: str, optional
A regular expression. Downloads the dataset index from Rdatasets; searches the "Package", "Item", and "Title" columns; and returns the matching rows.

Returns
-------
Expand All @@ -192,6 +203,20 @@ def get_dataset(dataset: str, docs: bool = False):
ValueError
If the dataset is not among the specified choices.
"""
if search:
try:
index = pl.read_csv(
"https://vincentarelbundock.github.io/Rdatasets/datasets.csv"
)
index = index.filter(
index["Package"].str.contains(search)
| index["Item"].str.contains(search)
| index["Title"].str.contains(search)
)
return index.select(["Package", "Item", "Title", "Rows", "Cols", "CSV"])
except BaseException as e:
raise ValueError(f"Error searching dataset: {e}")

datasets = {
"affairs": "https://marginaleffects.com/data/affairs",
"airbnb": "https://marginaleffects.com/data/airbnb",
Expand All @@ -200,15 +225,22 @@ def get_dataset(dataset: str, docs: bool = False):
"thornton": "https://marginaleffects.com/data/thornton",
}

if dataset not in datasets:
raise ValueError(
f"Invalid dataset choice. Expected one of {list(datasets.keys())}."
)
try:
if dataset in datasets:
base_url = datasets[dataset]
df = pl.read_parquet(f"{base_url}.parquet")
doc_url = (
"https://github.com/vincentarelbundock/marginaleffects/issues/1368"
)
else:
csv_url = f"https://vincentarelbundock.github.io/Rdatasets/csv/{package}/{dataset}.csv"
doc_url = f"https://vincentarelbundock.github.io/Rdatasets/doc/{package}/{dataset}.html"
df = pl.read_csv(csv_url)

base_url = datasets[dataset]
if docs:
return doc_url

if docs:
return f"{base_url}.html"
return df

df = pl.read_parquet(f"{base_url}.parquet")
return df
except BaseException as e:
raise ValueError(f"Error reading dataset: {e}")
Loading
Loading