From 87b06b3990ad5852e9fdb6923330aa6a8c0305a5 Mon Sep 17 00:00:00 2001 From: Anahita Bhiwandiwalla Date: Mon, 13 Jun 2016 14:50:18 -0700 Subject: [PATCH] Added a test --- .../survivalanalysis/CoxPhPredictArgs.scala | 2 -- .../org/apache/spark/ml/regression/CoxTest.scala | 15 +++++++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/engine-plugins/model-plugins/src/main/scala/org/trustedanalytics/atk/engine/model/plugins/survivalanalysis/CoxPhPredictArgs.scala b/engine-plugins/model-plugins/src/main/scala/org/trustedanalytics/atk/engine/model/plugins/survivalanalysis/CoxPhPredictArgs.scala index 9fb33f3851..c50e74feb4 100644 --- a/engine-plugins/model-plugins/src/main/scala/org/trustedanalytics/atk/engine/model/plugins/survivalanalysis/CoxPhPredictArgs.scala +++ b/engine-plugins/model-plugins/src/main/scala/org/trustedanalytics/atk/engine/model/plugins/survivalanalysis/CoxPhPredictArgs.scala @@ -8,13 +8,11 @@ *       http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ - package org.trustedanalytics.atk.engine.model.plugins.survivalanalysis import org.trustedanalytics.atk.domain.frame.FrameReference diff --git a/engine-plugins/model-plugins/src/test/scala/org/apache/spark/ml/regression/CoxTest.scala b/engine-plugins/model-plugins/src/test/scala/org/apache/spark/ml/regression/CoxTest.scala index 9aac4f3d74..ded6fe1bd4 100644 --- a/engine-plugins/model-plugins/src/test/scala/org/apache/spark/ml/regression/CoxTest.scala +++ b/engine-plugins/model-plugins/src/test/scala/org/apache/spark/ml/regression/CoxTest.scala @@ -9,16 +9,16 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.ml.regression +import org.apache.spark.frame.FrameRdd import org.apache.spark.mllib.linalg.DenseVector import org.scalatest.Matchers +import org.trustedanalytics.atk.domain.schema.{DataTypes, Column, FrameSchema} import org.trustedanalytics.atk.testutils.MatcherUtils._ import org.trustedanalytics.atk.testutils.TestingSparkContextFlatSpec import breeze.linalg.{ DenseVector => BDV, * } @@ -92,4 +92,15 @@ class CoxTest extends TestingSparkContextFlatSpec with Matchers { gradient.toArray should equalWithTolerance(estimatedGradient.toArray) informationMatrix shouldBe estimatedInformationMatrix +- 1e-6 } + + "predict" should "compute predicted hazard ratio" in { + val features = new DenseVector(Array(27.9)) + val estimatedPredictedOutput = 1.00261043677 + + val coxModel = new CoxModel("coxModelId",beta=new DenseVector(Array(-0.03351902788328871)), meanVector = new DenseVector(Array(27.977777777777778))) + val predictedOutput = coxModel.predict(features, coxModel.meanVector) + + predictedOutput shouldBe estimatedPredictedOutput +- 1e-6 + } + }