Skip to content

Commit

Permalink
✅ added tests for new modules
Browse files Browse the repository at this point in the history
  • Loading branch information
julkaar9 committed Apr 8, 2023
1 parent 71019bc commit d00f502
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 22 deletions.
32 changes: 17 additions & 15 deletions tests/test_bar.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from pynimate.bar import Barplot
# Legacy tests for barplot, will be removed in 2.0.0
import pytest

from pynimate.bar import Barplot


def test_barplot_set_bar_color_list(sample_bar_data1):
bar = Barplot(sample_bar_data1, "%Y-%m-%d", "3MS")
def test_barplot_set_bar_color_list(sample_data1):
bar = Barplot(sample_data1, "%Y-%m-%d", "3MS")
bar_colors_list = [
"#2a9d8f",
"#e9c46a",
Expand All @@ -22,8 +24,8 @@ def test_barplot_set_bar_color_list(sample_bar_data1):
assert bar.datafier.bar_colors == bar_colors


def test_barplot_set_bar_color_dict(sample_bar_data1):
bar = Barplot(sample_bar_data1, "%Y-%m-%d", "3MS")
def test_barplot_set_bar_color_dict(sample_data1):
bar = Barplot(sample_data1, "%Y-%m-%d", "3MS")
bar_colors = {
"Afghanistan": "#2a9d8f",
"Angola": "#e9c46a",
Expand All @@ -35,9 +37,9 @@ def test_barplot_set_bar_color_dict(sample_bar_data1):
assert bar.datafier.bar_colors == bar_colors


def test_barplot_set_bar_color_error_length(sample_bar_data1):
def test_barplot_set_bar_color_error_length(sample_data1):
with pytest.raises(AssertionError):
bar = Barplot(sample_bar_data1, "%Y-%m-%d", "3MS")
bar = Barplot(sample_data1, "%Y-%m-%d", "3MS")
bar_colors = [
"#2a9d8f",
"#e9c46a",
Expand All @@ -48,9 +50,9 @@ def test_barplot_set_bar_color_error_length(sample_bar_data1):
assert bar.datafier.bar_colors == bar_colors


def test_barplot_set_bar_color_error_col_mismatch(sample_bar_data1):
def test_barplot_set_bar_color_error_col_mismatch(sample_data1):
with pytest.raises(AssertionError):
bar = Barplot(sample_bar_data1, "%Y-%m-%d", "3MS")
bar = Barplot(sample_data1, "%Y-%m-%d", "3MS")
bar_colors = {
"India": "#2a9d8f",
"Angola": "#e9c46a",
Expand All @@ -62,20 +64,20 @@ def test_barplot_set_bar_color_error_col_mismatch(sample_bar_data1):
assert bar.datafier.bar_colors == bar_colors


def test_barplot_set_text_error_empty_text(sample_bar_data1):
def test_barplot_set_text_error_empty_text(sample_data1):
with pytest.raises(AssertionError):
bar = Barplot(sample_bar_data1, "%Y-%m-%d", "3MS")
bar = Barplot(sample_data1, "%Y-%m-%d", "3MS")
bar.set_text("text1")


def test_barplot_set_text_priority(sample_bar_data1):
bar = Barplot(sample_bar_data1, "%Y-%m-%d", "3MS")
def test_barplot_set_text_priority(sample_data1):
bar = Barplot(sample_data1, "%Y-%m-%d", "3MS")
bar.set_text("text1", text="Test", callback=lambda *args: "Test")
assert "s" not in bar.text_collection["text1"][1]


def test_barplot_remove_text(sample_bar_data1):
bar = Barplot(sample_bar_data1, "%Y-%m-%d", "3MS")
def test_barplot_remove_text(sample_data1):
bar = Barplot(sample_data1, "%Y-%m-%d", "3MS")
bar.set_text("text1", text="Test1")
bar.set_text("text2", text="Test2")
bar.set_text("text3", text="Test3")
Expand Down
6 changes: 6 additions & 0 deletions tests/test_barhplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pynimate.barhplot import Barhplot


def test_barhplot_xylim(sample_data1_bardfr):
barhplot = Barhplot(sample_data1_bardfr)
assert barhplot.xlim == [None, 10] and barhplot.ylim == [0.5, 5.6]
152 changes: 152 additions & 0 deletions tests/test_baseplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import pytest

from pynimate.baseplot import Baseplot


def test_baseplot_generate_column_colors(sample_data1_basedfr):
base_plot = Baseplot(sample_data1_basedfr)
column_colors = {
"Afghanistan": (0.267968, 0.223549, 0.512008),
"Angola": (0.190631, 0.407061, 0.556089),
"Albania": (0.127568, 0.566949, 0.550556),
"USA": (0.20803, 0.718701, 0.472873),
"Argentina": (0.565498, 0.84243, 0.262877),
}
assert base_plot.generate_column_colors() == column_colors


def test_baseplot_set_column_colors_str(sample_data1_basedfr):
base_plot = Baseplot(sample_data1_basedfr)
base_plot.set_column_colors("#FF2C55")
column_colors = {
"Afghanistan": "#FF2C55",
"Angola": "#FF2C55",
"Albania": "#FF2C55",
"USA": "#FF2C55",
"Argentina": "#FF2C55",
}
assert base_plot.column_colors == column_colors


def test_baseplot_set_column_colors_list(sample_data1_basedfr):
base_plot = Baseplot(sample_data1_basedfr)
base_plot.set_column_colors(
[
"#C41E3D",
"#7D1128",
"#FF2C55",
"#3C0919",
"#E2294F",
]
)
column_colors = {
"Afghanistan": "#C41E3D",
"Angola": "#7D1128",
"Albania": "#FF2C55",
"USA": "#3C0919",
"Argentina": "#E2294F",
}
assert base_plot.column_colors == column_colors


def test_baseplot_set_column_colors_dict(sample_data1_basedfr):
base_plot = Baseplot(sample_data1_basedfr)
base_plot.set_column_colors(
{
"Afghanistan": "#C41E3D",
"Angola": "#7D1128",
"Albania": "#FF2C55",
"USA": "#3C0919",
"Argentina": "#E2294F",
}
)
column_colors = {
"Afghanistan": "#C41E3D",
"Angola": "#7D1128",
"Albania": "#FF2C55",
"USA": "#3C0919",
"Argentina": "#E2294F",
}
assert base_plot.column_colors == column_colors


def test_baseplot_set_column_color_err_length(sample_data1_basedfr):
with pytest.raises(AssertionError):
base_plot = Baseplot(sample_data1_basedfr)
column_colors = [
"#2a9d8f",
"#e9c46a",
"#e76f51",
"#a7c957",
]
base_plot.set_column_colors(column_colors)


def test_baseplot_set_column_color_err_col_mismatch(sample_data1_basedfr):
with pytest.raises(ValueError):
bar = Baseplot(sample_data1_basedfr)
bar_colors = {
"India": "#2a9d8f",
"Angola": "#e9c46a",
"Albania": "#e76f51",
"USA": "#a7c957",
"Argentina": "#e5989b",
}
bar.set_column_colors(bar_colors)


def test_baseplot_set_column_color_err_type(sample_data1_basedfr):
with pytest.raises(TypeError):
bar = Baseplot(sample_data1_basedfr)
bar_colors = set(
[
"#2a9d8f",
"#e9c46a",
"#e76f51",
"#a7c957",
"#e5989b",
]
)

bar.set_column_colors(bar_colors)


def test_baseplot_xylim(sample_data1_basedfr):
base_plot = Baseplot(sample_data1_basedfr)
xmin, xmax = base_plot.xlim
ymin, ymax = base_plot.ylim
assert xmin is None and ymin is None
assert xmax.strftime("%Y-%m-%d") == "1962-01-01" and ymax == 5


def test_baseplot_text_structure(sample_data1_basedfr):
base_plot = Baseplot(sample_data1_basedfr)
base_plot.set_title("Title")
base_plot.set_xlabel("Xlabel")
base_plot.set_time()
for text in base_plot.text_collection.values():
assert isinstance(text, tuple)
assert text[0] is None or callable(text[0])
assert isinstance(text[1], dict)


def test_baseplot_set_text_error_empty_text(sample_data1_basedfr):
with pytest.raises(AssertionError):
bar = Baseplot(sample_data1_basedfr)
bar.set_text("text1")


def test_baseplot_set_text_priority(sample_data1_basedfr):
bar = Baseplot(sample_data1_basedfr)
bar.set_text("text1", text="Test", callback=lambda *args: "Test")
assert "s" not in bar.text_collection["text1"][1]


def test_baseplot_remove_text(sample_data1_basedfr):
bar = Baseplot(sample_data1_basedfr)
bar.set_text("text1", text="Test1")
bar.set_text("text2", text="Test2")
bar.set_text("text3", text="Test3")

bar.remove_text(["text1", "text2"])
assert list(bar.text_collection.keys()) == ["text3"]
16 changes: 9 additions & 7 deletions tests/test_datafier.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from pynimate.datafier import Datafier
# Legacy tests for datafier, will be removed in 2.0.0
import pandas as pd

from pynimate.datafier import Datafier


def test_datafier_init(sample_bar_data1):
dfr = Datafier(sample_bar_data1, "%Y-%m-%d", "3MS", 0.1)
def test_datafier_init(sample_data1):
dfr = Datafier(sample_data1, "%Y-%m-%d", "3MS", 0.1)
assert dfr.n_bars == 5


def test_datafier_interpolate_even(sample_bar_data2):
dfr = Datafier(sample_bar_data2, "%Y", "3MS")
def test_datafier_interpolate_even(sample_data2):
dfr = Datafier(sample_data2, "%Y", "3MS")
interpolated_data = pd.DataFrame(
{
"time": pd.to_datetime(
Expand Down Expand Up @@ -94,8 +96,8 @@ def test_datafier_get_bar_colors(map_data):
assert dfr.get_bar_colors() == bar_colors


def test_datafier_get_prepared_data(sample_bar_data1):
dfr = Datafier(sample_bar_data1, "%Y-%m-%d", "3MS")
def test_datafier_get_prepared_data(sample_data1):
dfr = Datafier(sample_data1, "%Y-%m-%d", "3MS")
dfr.df_ranks.index.name = "time"
df_ranks = pd.DataFrame(
{
Expand Down

0 comments on commit d00f502

Please sign in to comment.