Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
*/
package org.apache.spark.ml.regression

import org.apache.spark.frame.FrameRdd
import org.apache.spark.mllib.linalg.DenseVector
import org.scalatest.Matchers
import org.trustedanalytics.atk.domain.schema.{DataTypes, Column, FrameSchema}
import org.trustedanalytics.atk.testutils.MatcherUtils._
import org.trustedanalytics.atk.testutils.TestingSparkContextFlatSpec
import breeze.linalg.{ DenseVector => BDV, * }
Expand Down Expand Up @@ -90,4 +92,15 @@ class CoxTest extends TestingSparkContextFlatSpec with Matchers {
gradient.toArray should equalWithTolerance(estimatedGradient.toArray)
informationMatrix shouldBe estimatedInformationMatrix +- 1e-6
}

"predict" should "compute predicted hazard ratio" in {
val features = new DenseVector(Array(27.9))
val estimatedPredictedOutput = 1.00261043677

val coxModel = new CoxModel("coxModelId",beta=new DenseVector(Array(-0.03351902788328871)), meanVector = new DenseVector(Array(27.977777777777778)))
val predictedOutput = coxModel.predict(features, coxModel.meanVector)

predictedOutput shouldBe estimatedPredictedOutput +- 1e-6
}

}