diff --git a/engine-plugins/model-plugins/src/test/scala/org/apache/spark/ml/regression/CoxTest.scala b/engine-plugins/model-plugins/src/test/scala/org/apache/spark/ml/regression/CoxTest.scala index 763516612b..ded6fe1bd4 100644 --- a/engine-plugins/model-plugins/src/test/scala/org/apache/spark/ml/regression/CoxTest.scala +++ b/engine-plugins/model-plugins/src/test/scala/org/apache/spark/ml/regression/CoxTest.scala @@ -15,8 +15,10 @@ */ package org.apache.spark.ml.regression +import org.apache.spark.frame.FrameRdd import org.apache.spark.mllib.linalg.DenseVector import org.scalatest.Matchers +import org.trustedanalytics.atk.domain.schema.{DataTypes, Column, FrameSchema} import org.trustedanalytics.atk.testutils.MatcherUtils._ import org.trustedanalytics.atk.testutils.TestingSparkContextFlatSpec import breeze.linalg.{ DenseVector => BDV, * } @@ -90,4 +92,15 @@ class CoxTest extends TestingSparkContextFlatSpec with Matchers { gradient.toArray should equalWithTolerance(estimatedGradient.toArray) informationMatrix shouldBe estimatedInformationMatrix +- 1e-6 } + + "predict" should "compute predicted hazard ratio" in { + val features = new DenseVector(Array(27.9)) + val estimatedPredictedOutput = 1.00261043677 + + val coxModel = new CoxModel("coxModelId",beta=new DenseVector(Array(-0.03351902788328871)), meanVector = new DenseVector(Array(27.977777777777778))) + val predictedOutput = coxModel.predict(features, coxModel.meanVector) + + predictedOutput shouldBe estimatedPredictedOutput +- 1e-6 + } + }