diff --git a/siuba/tests/test_verb_mutate.py b/siuba/tests/test_verb_mutate.py index decb2ddf..6a4359f6 100644 --- a/siuba/tests/test_verb_mutate.py +++ b/siuba/tests/test_verb_mutate.py @@ -119,5 +119,35 @@ def test_mutate_overwrites_prev(backend): ) +def test_mutate_after_summarize_on_non_derived_column(backend): + dfs = backend.load_df(data_frame(x = range(1, 5), g = [1,2,2,2])) + query = group_by(_.g) >> summarize(avg = _.x.min()) >> mutate(avg_g = _.g.mean()) + assert_equal_query( + dfs, + query, + data_frame(g = [1,2], avg = [1,2], avg_g = 1.5) + ) + + +def test_mutate_after_summarize_on_derived_column(backend): + dfs = backend.load_df(data_frame(x = range(1, 5), g = [1,2,2,2])) + + query = group_by(_.g) >> summarize(avg = _.x.min()) >> mutate(avg_avg = _.avg.mean()) + assert_equal_query( + dfs, + query, + data_frame(g = [1,2], avg = [1,2], avg_avg = 1.5) + ) + + +def test_mutate_after_summarize_limits_column_access(backend): + dfs = backend.load_df(data_frame(x = range(1, 5), g = [1,2,2,2])) + query = group_by(_.g) >> summarize(avg = _.x.min()) >> mutate(x2 = _.x + 1) + + with pytest.raises(AttributeError) as exc_info: + query(dfs) + + + assert "x" in exc_info.value.args[0]