Skip to content

Commit 6861515

Browse files
authored
feat: expose model api to client (#6)
* feat: expose model api to client * feat: expose model api to client * feat: expose the 3 needed APIs
1 parent 5451736 commit 6861515

File tree

5 files changed

+171
-23
lines changed

5 files changed

+171
-23
lines changed

build.gradle.kts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ dependencies {
3232
implementation("com.squareup.okio:okio:3.4.0")
3333

3434
testImplementation("org.junit.jupiter:junit-jupiter:5.10.0")
35+
testImplementation("org.mockito:mockito-core:5.14.2")
3536
testImplementation("com.squareup.okhttp3:mockwebserver:4.12.0")
3637
}
3738

@@ -45,6 +46,10 @@ tasks.withType<org.jetbrains.kotlin.gradle.tasks.KotlinCompile> {
4546
}
4647
}
4748

49+
tasks.test {
50+
useJUnitPlatform()
51+
}
52+
4853
sourceSets {
4954
main {
5055
java {

src/main/kotlin/client/ModelApiClient.kt renamed to src/main/kotlin/client/JaqpotApiClient.kt

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,25 @@
11
package org.jaqpot.client
22

3-
import client.BaseApiClient
43
import org.jaqpot.exception.JaqpotSDKException
4+
import org.jaqpot.http.JaqpotHttpClient
55
import org.openapitools.client.api.DatasetApi
6+
import org.openapitools.client.api.FeatureApi
67
import org.openapitools.client.api.ModelApi
78
import org.openapitools.client.model.Dataset
89
import org.openapitools.client.model.DatasetType
9-
import retrofit2.Call
1010

1111

12-
class ModelApiClient(
13-
apiKey: String,
14-
apiSecret: String
15-
) : BaseApiClient(apiKey, apiSecret) {
12+
class JaqpotApiClient(private val apiKey: String, private val apiSecret: String) {
1613

17-
private val modelApi: ModelApi = retrofit.create(ModelApi::class.java)
18-
private val datasetApi: DatasetApi = retrofit.create(DatasetApi::class.java)
14+
private val jaqpotHttpClient: JaqpotHttpClient = JaqpotHttpClient(apiKey, apiSecret)
15+
val modelApi: ModelApi = jaqpotHttpClient.retrofit.create(ModelApi::class.java)
16+
val datasetApi: DatasetApi = jaqpotHttpClient.retrofit.create(DatasetApi::class.java)
17+
val featureApi: FeatureApi = jaqpotHttpClient.retrofit.create(FeatureApi::class.java)
1918

2019
companion object {
2120
const val DATASET_CHECK_INTERVAL: Long = 2000
2221
}
2322

24-
fun predictAsync(modelId: Long, input: List<Any>): Call<Void>? {
25-
return modelApi.predictWithModel(
26-
modelId,
27-
Dataset.Builder().type(DatasetType.PREDICTION)
28-
.entryType(Dataset.EntryTypeEnum.ARRAY)
29-
.input(input)
30-
.build()
31-
)
32-
}
33-
3423
fun predictSync(modelId: Long, input: List<Any>): Dataset {
3524
val response = modelApi.predictWithModel(
3625
modelId,
@@ -41,7 +30,7 @@ class ModelApiClient(
4130
).execute()
4231

4332
if (!response.isSuccessful) {
44-
val message = response.errorBody()?.string();
33+
val message = response.errorBody()?.string()
4534
if (response.code() == 401) {
4635
throw JaqpotSDKException("Prediction failed: Unauthenticated \n$message", response.errorBody())
4736
} else if (response.code() == 403) {

src/main/kotlin/client/BaseApiClient.kt renamed to src/main/kotlin/http/JaqpotHttpClient.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package client
1+
package org.jaqpot.http
22

33
import auth.AuthorizationInterceptor
44
import com.google.gson.Gson
@@ -11,7 +11,7 @@ import retrofit2.converter.gson.GsonConverterFactory
1111
import java.time.OffsetDateTime
1212

1313

14-
open class BaseApiClient protected constructor(
14+
class JaqpotHttpClient(
1515
apiKey: String,
1616
apiSecret: String
1717
) {
@@ -21,11 +21,11 @@ open class BaseApiClient protected constructor(
2121
.addInterceptor(authorizationInterceptor)
2222
.build()
2323

24-
protected val gson: Gson = GsonBuilder()
24+
private val gson: Gson = GsonBuilder()
2525
.registerTypeAdapter(OffsetDateTime::class.java, JSON.OffsetDateTimeTypeAdapter())
2626
.create()
2727

28-
protected val retrofit: Retrofit = Retrofit.Builder()
28+
val retrofit: Retrofit = Retrofit.Builder()
2929
.baseUrl(SDKConfig.host)
3030
.client(httpClient)
3131
.addConverterFactory(GsonConverterFactory.create(gson))

src/test/kotlin/TestUtil.kt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import org.jaqpot.client.JaqpotApiClient
2+
import org.junit.platform.commons.util.ReflectionUtils
3+
import java.lang.reflect.Field
4+
5+
class TestUtil {
6+
companion object {
7+
fun mockPrivateProperty(instance: Any, propertyName: String, mockedValue: Any) {
8+
val field: Field = ReflectionUtils
9+
.findFields(
10+
JaqpotApiClient::class.java, { f: Field -> f.name == propertyName },
11+
ReflectionUtils.HierarchyTraversalMode.TOP_DOWN
12+
)[0]
13+
14+
field.setAccessible(true)
15+
field.set(instance, mockedValue)
16+
}
17+
}
18+
19+
}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
package client
2+
3+
import TestUtil.Companion.mockPrivateProperty
4+
import okhttp3.Headers.Companion.toHeaders
5+
import okhttp3.ResponseBody
6+
import org.jaqpot.client.JaqpotApiClient
7+
import org.jaqpot.exception.JaqpotSDKException
8+
import org.junit.jupiter.api.Assertions.assertEquals
9+
import org.junit.jupiter.api.Assertions.assertTrue
10+
import org.junit.jupiter.api.BeforeEach
11+
import org.junit.jupiter.api.Test
12+
import org.junit.jupiter.api.assertThrows
13+
import org.mockito.Mockito.*
14+
import org.openapitools.client.api.DatasetApi
15+
import org.openapitools.client.api.ModelApi
16+
import org.openapitools.client.model.Dataset
17+
import retrofit2.Call
18+
import retrofit2.Response
19+
20+
21+
class JaqpotApiClientTest {
22+
23+
private lateinit var modelApi: ModelApi
24+
25+
private lateinit var datasetApi: DatasetApi
26+
27+
private lateinit var jaqpotApiClient: JaqpotApiClient
28+
29+
@BeforeEach
30+
fun setUp() {
31+
modelApi = mock(ModelApi::class.java)
32+
datasetApi = mock(DatasetApi::class.java)
33+
jaqpotApiClient = JaqpotApiClient("apiKey", "apiSecret")
34+
mockPrivateProperty(jaqpotApiClient, "modelApi", modelApi)
35+
mockPrivateProperty(jaqpotApiClient, "datasetApi", datasetApi)
36+
}
37+
38+
39+
@Test
40+
fun `predictSync should return dataset on success`() {
41+
val dataset = Dataset().apply {
42+
id = 1L
43+
status = Dataset.StatusEnum.SUCCESS
44+
}
45+
46+
val mockCall = mock(Call::class.java) as Call<Void>
47+
val mockResponse = mock(Response::class.java) as Response<Void>
48+
val mockDatasetCall = mock(Call::class.java) as Call<Dataset>
49+
val mockDatasetResponse = mock(Response::class.java) as Response<Dataset>
50+
51+
`when`(modelApi.predictWithModel(anyLong(), any(Dataset::class.java))).thenReturn(mockCall)
52+
`when`(mockCall.execute()).thenReturn(mockResponse)
53+
`when`(mockResponse.isSuccessful).thenReturn(true)
54+
`when`(mockResponse.headers()).thenReturn(mapOf("Location" to "/datasets/1").toHeaders())
55+
`when`(datasetApi.getDatasetById(anyLong())).thenReturn(mockDatasetCall)
56+
`when`(mockDatasetCall.execute()).thenReturn(mockDatasetResponse)
57+
`when`(mockDatasetResponse.body()).thenReturn(dataset)
58+
59+
val result = jaqpotApiClient.predictSync(1L, listOf())
60+
61+
assertEquals(dataset, result)
62+
}
63+
64+
@Test
65+
fun `predictSync should throw exception on failure`() {
66+
val mockCall = mock(Call::class.java) as Call<Void>
67+
val mockResponse = mock(Response::class.java) as Response<Void>
68+
69+
`when`(modelApi.predictWithModel(anyLong(), any(Dataset::class.java))).thenReturn(mockCall)
70+
`when`(mockCall.execute()).thenReturn(mockResponse)
71+
`when`(mockResponse.isSuccessful).thenReturn(false)
72+
`when`(mockResponse.errorBody()).thenReturn(mock(ResponseBody::class.java))
73+
74+
val exception = assertThrows<JaqpotSDKException> {
75+
jaqpotApiClient.predictSync(1L, listOf())
76+
}
77+
78+
assertTrue(exception.message!!.contains("Prediction failed"))
79+
}
80+
81+
@Test
82+
fun `predictSync should throw exception on dataset failure`() {
83+
val dataset = Dataset().apply {
84+
id = 1L
85+
status = Dataset.StatusEnum.FAILURE
86+
failureReason = "Some failure reason"
87+
}
88+
89+
val mockCall = mock(Call::class.java) as Call<Void>
90+
val mockResponse = mock(Response::class.java) as Response<Void>
91+
val mockDatasetCall = mock(Call::class.java) as Call<Dataset>
92+
val mockDatasetResponse = mock(Response::class.java) as Response<Dataset>
93+
94+
`when`(modelApi.predictWithModel(anyLong(), any(Dataset::class.java))).thenReturn(mockCall)
95+
`when`(mockCall.execute()).thenReturn(mockResponse)
96+
`when`(mockResponse.isSuccessful).thenReturn(true)
97+
`when`(mockResponse.headers()).thenReturn(mapOf("Location" to "/datasets/1").toHeaders())
98+
`when`(datasetApi.getDatasetById(anyLong())).thenReturn(mockDatasetCall)
99+
`when`(mockDatasetCall.execute()).thenReturn(mockDatasetResponse)
100+
`when`(mockDatasetResponse.body()).thenReturn(dataset)
101+
102+
val exception = assertThrows<JaqpotSDKException> {
103+
jaqpotApiClient.predictSync(1L, listOf())
104+
}
105+
106+
assertTrue(exception.message!!.contains("Prediction failed: Some failure reason"))
107+
}
108+
109+
@Test
110+
fun `predictSync should throw exception on maximum retries reached`() {
111+
val dataset = Dataset().apply {
112+
id = 1L
113+
status = Dataset.StatusEnum.EXECUTING
114+
}
115+
116+
val mockCall = mock(Call::class.java) as Call<Void>
117+
val mockResponse = mock(Response::class.java) as Response<Void>
118+
val mockDatasetCall = mock(Call::class.java) as Call<Dataset>
119+
val mockDatasetResponse = mock(Response::class.java) as Response<Dataset>
120+
121+
`when`(modelApi.predictWithModel(anyLong(), any(Dataset::class.java))).thenReturn(mockCall)
122+
`when`(mockCall.execute()).thenReturn(mockResponse)
123+
`when`(mockResponse.isSuccessful).thenReturn(true)
124+
`when`(mockResponse.headers()).thenReturn(mapOf("Location" to "/datasets/1").toHeaders())
125+
`when`(datasetApi.getDatasetById(anyLong())).thenReturn(mockDatasetCall)
126+
`when`(mockDatasetCall.execute()).thenReturn(mockDatasetResponse)
127+
`when`(mockDatasetResponse.body()).thenReturn(dataset)
128+
129+
val exception = assertThrows<JaqpotSDKException> {
130+
jaqpotApiClient.predictSync(1L, listOf())
131+
}
132+
133+
assertTrue(exception.message!!.contains("Maximum amount of retries reached"))
134+
}
135+
}

0 commit comments

Comments
 (0)