Skip to content

Commit

Permalink
add test_save_load_dragonnet
Browse files Browse the repository at this point in the history
  • Loading branch information
rolandrmgservices committed Dec 5, 2023
1 parent dc74352 commit f1b3282
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tests/test_dragon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from causalml.inference.tf import DragonNet
from causalml.dataset.regression import *
import shutil


def test_save_load_dragonnet():
y, X, w, tau, b, e = simulate_nuisance_and_easy_treatment(n=1000)

dragon = DragonNet(neurons_per_layer=200, targeted_reg=True, verbose=False)
dragon_ite = dragon.fit_predict(X, w, y, return_components=False)
dragon_ate = dragon_ite.mean()
dragon.save("smaug")

smaug = DragonNet()
smaug.load("smaug")
shutil.rmtree("smaug")

assert smaug.predict_tau(X).mean() == dragon_ate

0 comments on commit f1b3282

Please sign in to comment.