Skip to content
This repository was archived by the owner on May 12, 2021. It is now read-only.

Commit 73fd214

Browse files
author
Chris Wewerka
committed
making async support backward compatible
1 parent 183c9a7 commit 73fd214

File tree

24 files changed

+159
-59
lines changed

24 files changed

+159
-59
lines changed

core/src/main/scala/org/apache/predictionio/controller/LAlgorithm.scala

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,38 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P]
7777
cartesian.flatMap { case (m, qArray) =>
7878
qArray.map {
7979
case (qx, q) =>
80-
(qx, Await.result(predict(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes) )
80+
(qx,
81+
Await.result(predictAsync(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes) )
8182
}
8283
}
8384
}
8485

85-
def predictBase(localBaseModel: Any, q: Q)(implicit ec: ExecutionContext): Future[P] = {
86-
predict(localBaseModel.asInstanceOf[M], q)(ec)
87-
}
86+
override def predictBaseAsync(localBaseModel: Any, q: Q)(implicit ec: ExecutionContext)
87+
: Future[P] =
88+
predictAsync(localBaseModel.asInstanceOf[M], q)(ec)
89+
90+
@deprecated(message =
91+
"this method is just here for backward compatibility, predictBaseAsync() is called now",
92+
since = "0.14.0")
93+
override def predictBase(localBaseModel: Any, q: Q): P =
94+
predict(localBaseModel.asInstanceOf[M], q)
95+
96+
/** Implement this method to produce a Future of a prediction in a non blocking way
97+
* from a query and trained model.
98+
*
99+
* This method is implemented to just delegate to blocking predict() for
100+
* backward compatibility reasons.
101+
* Definitely overwrite it to implement your blocking prediction method, and leave
102+
* the old blocking predict() as it is (throwing an exception), it won't be called from
103+
* now on.
104+
*
105+
* @param model Trained model produced by [[train]].
106+
* @param query An input query.
107+
* @param ec ExecutionContext to use for async operations
108+
* @return A Future of a prediction.
109+
*/
110+
def predictAsync(model: M, query: Q)(implicit ec: ExecutionContext): Future[P] =
111+
Future.successful(blocking(predict(model, query)))
88112

89113
/** Implement this method to produce a prediction from a query and trained
90114
* model.
@@ -93,7 +117,9 @@ abstract class LAlgorithm[PD, M : ClassTag, Q, P]
93117
* @param q An input query.
94118
* @return A prediction.
95119
*/
96-
def predict(m: M, q: Q)(implicit ec: ExecutionContext): Future[P]
120+
@deprecated(message = "override non blocking predictAsync() instead", since = "0.14.0")
121+
def predict(m: M, q: Q): P =
122+
throw new NotImplementedError("predict() is deprecated, override predictAsync() instead")
97123

98124
/** :: DeveloperApi ::
99125
* Engine developers should not use this directly (read on to see how local

core/src/main/scala/org/apache/predictionio/controller/P2LAlgorithm.scala

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,34 @@ abstract class P2LAlgorithm[PD, M: ClassTag, Q: ClassTag, P]
7171
*/
7272
def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] = {
7373
qs.mapValues { q =>
74-
Await.result(predict(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes)
74+
Await.result(predictAsync(m, q)(scala.concurrent.ExecutionContext.global), 60 minutes)
7575
}
7676
}
7777

78-
def predictBase(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] =
79-
predict(bm.asInstanceOf[M], q)(ec)
78+
override def predictBaseAsync(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] =
79+
predictAsync(bm.asInstanceOf[M], q)(ec)
80+
81+
@deprecated(message =
82+
"this method is just here for backward compatibility, predictBaseAsync() is called now",
83+
since = "0.14.0")
84+
override def predictBase(bm: Any, q: Q): P = predict(bm.asInstanceOf[M], q)
85+
86+
/** Implement this method to produce a Future of a prediction in a non blocking way
87+
* from a query and trained model.
88+
*
89+
* This method is implemented to just delegate to blocking predict() for
90+
* backward compatibility reasons.
91+
* Definitely overwrite it to implement your blocking prediction method, and leave
92+
* the old blocking predict() as it is (throwing an exception), it won't be called from
93+
* now on.
94+
*
95+
* @param model Trained model produced by [[train]].
96+
* @param query An input query.
97+
* @param ec ExecutionContext to use for async operations
98+
* @return A Future of a prediction.
99+
*/
100+
def predictAsync(model: M, query: Q)(implicit ec: ExecutionContext): Future[P] =
101+
Future.successful(blocking(predict(model, query)))
80102

81103
/** Implement this method to produce a prediction from a query and trained
82104
* model.
@@ -85,7 +107,9 @@ abstract class P2LAlgorithm[PD, M: ClassTag, Q: ClassTag, P]
85107
* @param query An input query.
86108
* @return A prediction.
87109
*/
88-
def predict(model: M, query: Q)(implicit ec: ExecutionContext): Future[P]
110+
@deprecated(message = "override non blocking predictAsync() instead", since = "0.14.0")
111+
def predict(model: M, query: Q): P =
112+
throw new NotImplementedError("predict() is deprecated, override predictAsync() instead")
89113

90114
/** :: DeveloperApi ::
91115
* Engine developers should not use this directly (read on to see how

core/src/main/scala/org/apache/predictionio/controller/PAlgorithm.scala

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.predictionio.workflow.PersistentModelManifest
2424
import org.apache.spark.SparkContext
2525
import org.apache.spark.rdd.RDD
2626

27-
import scala.concurrent.{ExecutionContext, Future}
27+
import scala.concurrent.{ExecutionContext, Future, blocking}
2828

2929
/** Base class of a parallel algorithm.
3030
*
@@ -74,9 +74,31 @@ abstract class PAlgorithm[PD, M, Q, P]
7474
def batchPredict(m: M, qs: RDD[(Long, Q)]): RDD[(Long, P)] =
7575
throw new NotImplementedError("batchPredict not implemented")
7676

77-
def predictBase(baseModel: Any, query: Q)(implicit ec: ExecutionContext): Future[P] = {
78-
predict(baseModel.asInstanceOf[M], query)(ec)
79-
}
77+
override def predictBaseAsync(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] =
78+
predictAsync(bm.asInstanceOf[M], q)(ec)
79+
80+
@deprecated(message =
81+
"this method is just here for backward compatibility, predictBaseAsync() is called now",
82+
since = "0.14.0")
83+
override def predictBase(baseModel: Any, query: Q): P =
84+
predict(baseModel.asInstanceOf[M], query)
85+
86+
/** Implement this method to produce a Future of a prediction in a non blocking way
87+
* from a query and trained model.
88+
*
89+
* This method is implemented to just delegate to blocking predict() for
90+
* backward compatibility reasons.
91+
* Definitely overwrite it to implement your blocking prediction method, and leave
92+
* the old blocking predict() as it is (throwing an exception), it won't be called from
93+
* now on.
94+
*
95+
* @param model Trained model produced by [[train]].
96+
* @param query An input query.
97+
* @param ec ExecutionContext to use for async operations
98+
* @return A Future of a prediction.
99+
*/
100+
def predictAsync(model: M, query: Q)(implicit ec: ExecutionContext): Future[P] =
101+
Future.successful(blocking(predict(model, query)))
80102

81103
/** Implement this method to produce a prediction from a query and trained
82104
* model.
@@ -85,7 +107,9 @@ abstract class PAlgorithm[PD, M, Q, P]
85107
* @param query An input query.
86108
* @return A prediction.
87109
*/
88-
def predict(model: M, query: Q)(implicit ec: ExecutionContext): Future[P]
110+
@deprecated(message = "override non blocking predictAsync() instead", since = "0.14.0")
111+
def predict(model: M, query: Q): P =
112+
throw new NotImplementedError("predict() is deprecated, override predictAsync() instead")
89113

90114
/** :: DeveloperApi ::
91115
* Engine developers should not use this directly (read on to see how parallel

core/src/main/scala/org/apache/predictionio/core/BaseAlgorithm.scala

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import net.jodah.typetools.TypeResolver
2626
import org.apache.spark.SparkContext
2727
import org.apache.spark.rdd.RDD
2828

29-
import scala.concurrent.{ExecutionContext, Future}
29+
import scala.concurrent.{ExecutionContext, Future, blocking}
3030

3131
/** :: DeveloperApi ::
3232
* Base trait with default custom query serializer, exposed to engine developer
@@ -83,6 +83,19 @@ abstract class BaseAlgorithm[PD, M, Q, P]
8383
def batchPredictBase(sc: SparkContext, bm: Any, qs: RDD[(Long, Q)])
8484
: RDD[(Long, P)]
8585

86+
/** :: DeveloperApi ::
87+
* Engine developers should not use this directly. Called by serving to
88+
* perform a single prediction.
89+
*
90+
* @param bm Model
91+
* @param q Query
92+
* @param ec ExecutionContext to use for async operations
93+
* @return Future of a Predicted result
94+
*/
95+
@DeveloperApi
96+
def predictBaseAsync(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P] =
97+
Future.successful(blocking {predictBase(bm, q)})
98+
8699
/** :: DeveloperApi ::
87100
* Engine developers should not use this directly. Called by serving to
88101
* perform a single prediction.
@@ -92,7 +105,11 @@ abstract class BaseAlgorithm[PD, M, Q, P]
92105
* @return Predicted result
93106
*/
94107
@DeveloperApi
95-
def predictBase(bm: Any, q: Q)(implicit ec: ExecutionContext): Future[P]
108+
@deprecated(message = "override non blocking predictBaseAsync() instead", since = "0.14.0")
109+
def predictBase(bm: Any, q: Q): P =
110+
throw new NotImplementedError(
111+
"predictBase() is deprecated, override predictBaseAsync() instead"
112+
)
96113

97114
/** :: DeveloperApi ::
98115
* Engine developers should not use this directly. Prepare a model for

core/src/main/scala/org/apache/predictionio/workflow/BatchPredict.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ object BatchPredict extends Logging {
213213
// finally Serving.serve.
214214
val supplementedQuery = serving.supplementBase(query)
215215
val predictionsFuture = Future.sequence(algorithms.zip(models).map { case (a, m) =>
216-
a.predictBase(m, supplementedQuery)
216+
a.predictBaseAsync(m, supplementedQuery)
217217
})
218218
// Notice that it is by design to call Serving.serve with the
219219
// *original* query.

core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ class PredictionServer[Q, P](
506506
val supplementedQuery = serving.supplementBase(query)
507507

508508
val predictionsFuture = Future.sequence(algorithms.zip(models).map { case (a, m) =>
509-
a.predictBase(m, supplementedQuery)
509+
a.predictBaseAsync(m, supplementedQuery)
510510
})
511511
// Notice that it is by design to call Serving.serve with the
512512
// *original* query.

0 commit comments

Comments
 (0)