16
16
package za .co .absa .hyperdrive .trigger .persistance
17
17
18
18
import org .springframework .stereotype
19
+ import slick .ast .BaseTypedType
20
+ import slick .jdbc .SetParameter .{SetBoolean , SetInt , SetLong , SetString , SetTimestamp , SetUnit }
21
+ import slick .jdbc .{GetResult , PositionedParameters , SQLActionBuilder , SetParameter }
19
22
import za .co .absa .hyperdrive .trigger .models .dagRuns .DagRun
20
- import za .co .absa .hyperdrive .trigger .models .search .{TableSearchRequest , TableSearchResponse }
23
+ import za .co .absa .hyperdrive .trigger .models .search .{BooleanFilterAttributes , ContainsFilterAttributes , DateTimeRangeFilterAttributes , EqualsMultipleFilterAttributes , IntRangeFilterAttributes , LongFilterAttributes , SortAttributes , TableSearchRequest , TableSearchResponse }
21
24
25
+ import java .sql .Timestamp
22
26
import javax .inject .Inject
23
27
import scala .concurrent .{ExecutionContext , Future }
24
28
@@ -28,7 +32,211 @@ trait DagRunRepository extends Repository {
28
32
29
33
@ stereotype.Repository
30
34
class DagRunRepositoryImpl @ Inject ()(val dbProvider : DatabaseProvider ) extends DagRunRepository {
35
+ import api ._
36
+ private val fieldMapping = Map (
37
+ " workflowId" -> " workflow.id" ,
38
+ " workflowName" -> " workflow.name" ,
39
+ " projectName" -> " workflow.project" ,
40
+ " started" -> " dag_instance.started" ,
41
+ " finished" -> " dag_instance.finished" ,
42
+ " status" -> " dag_instance.status" ,
43
+ " triggeredBy" -> " dag_instance.triggered_by" ,
44
+ " id" -> " dag_instance.id"
45
+ )
46
+
47
+ private def orderByMapping (index : Int ) = if (index == - 1 ) " DESC" else " ASC"
48
+
31
49
override def searchDagRuns (searchRequest : TableSearchRequest )(implicit ec : ExecutionContext ): Future [TableSearchResponse [DagRun ]] = {
32
- db.run(dagRunTable.search(searchRequest))
50
+
51
+ val dagIdsQueryMain =
52
+ sql """
53
+ FROM dag_instance
54
+ JOIN workflow on dag_instance.workflow_id = workflow.id
55
+ WHERE 1=1
56
+ """
57
+ val queryFilters = generateQueryFilters(searchRequest)
58
+ val dagIdsQueryFilters = queryFilters._1
59
+ val setParameters = queryFilters._2
60
+
61
+ val queryOrderBy = searchRequest.sort match {
62
+ case Some (SortAttributes (by, order)) =>
63
+ sql """
64
+ ORDER BY # ${fieldMapping(by)} # ${orderByMapping(order)}, dag_instance.id DESC
65
+ """
66
+ case None =>
67
+ sql """
68
+ ORDER BY dag_instance.id DESC
69
+ """
70
+ }
71
+ val dagIdsQueryLimitOffset =
72
+ sql """
73
+ LIMIT # ${searchRequest.size} OFFSET # ${searchRequest.from}
74
+ """
75
+ val dagIdsQueryOpeningPart =
76
+ sql """
77
+ WITH dag_ids AS (
78
+ SELECT dag_instance.id
79
+ """
80
+ val dagIdsQueryClosingPart =
81
+ sql """
82
+ )
83
+ """
84
+ val countQuerySelect =
85
+ sql """
86
+ SELECT COUNT(1)
87
+ """
88
+ val dagRunQueryMain =
89
+ sql """
90
+ select workflow.id,
91
+ workflow.name,
92
+ workflow.project,
93
+ COALESCE(jobInstanceCount.count, 0) AS "job_count",
94
+ dag_instance.started,
95
+ dag_instance.finished,
96
+ dag_instance.status,
97
+ dag_instance.triggered_by,
98
+ dag_instance.id
99
+ from dag_instance
100
+ left join (
101
+ select job_instance.dag_instance_id, count(1) as "count"
102
+ from job_instance
103
+ join dag_ids on dag_ids.id = job_instance.dag_instance_id
104
+ group by dag_instance_id
105
+ ) as jobInstanceCount
106
+ on jobInstanceCount.dag_instance_id = dag_instance.id
107
+ left join workflow
108
+ on workflow.id = dag_instance.workflow_id
109
+ join dag_ids on dag_ids.id = dag_instance.id
110
+ """
111
+
112
+
113
+ val countQuery = SQLActionBuilder (
114
+ concatQueryParts(countQuerySelect, dagIdsQueryMain, dagIdsQueryFilters),
115
+ setParameters)
116
+ val countQueryAction = countQuery.as[Int ].head
117
+
118
+ implicit val getDagRunResult : GetResult [DagRun ] = GetResult (r => DagRun (
119
+ r.nextLong(), r.nextString(), r.nextString(), r.nextInt(),
120
+ r.nextTimestamp().toLocalDateTime, r.nextTimestampOption().map(_.toLocalDateTime),
121
+ r.nextString(), r.nextString(), r.nextLong()))
122
+ val dagRunsQuery = SQLActionBuilder (
123
+ concatQueryParts(dagIdsQueryOpeningPart, dagIdsQueryMain, dagIdsQueryFilters, queryOrderBy,
124
+ dagIdsQueryLimitOffset, dagIdsQueryClosingPart, dagRunQueryMain, queryOrderBy), setParameters)
125
+ val dagRunsQueryAction = dagRunsQuery.as[DagRun ]
126
+
127
+ db.run(
128
+ for {
129
+ l <- countQueryAction
130
+ r <- dagRunsQueryAction
131
+ } yield {
132
+ TableSearchResponse [DagRun ](items = r, total = l)
133
+ }
134
+ )
135
+ }
136
+
137
+ private def concatQueryParts (sqlActionBuilders : SQLActionBuilder * ): Seq [Any ] = {
138
+ sqlActionBuilders.map(_.queryParts).reduceOption(_ ++ _).getOrElse(Seq ())
139
+ }
140
+
141
+ private def generateQueryFilters (request : TableSearchRequest ) = {
142
+ val filters =
143
+ applyContainsFilter(request.getContainsFilterAttributes) ++
144
+ applyIntRangeFilter(request.getIntRangeFilterAttributes) ++
145
+ applyDateTimeRangeFilter(request.getDateTimeRangeFilterAttributes) ++
146
+ applyEqualsMultipleFilter(request.getEqualsMultipleFilterAttributes) ++
147
+ applyLongFilter(request.getLongFilterAttributes) ++
148
+ applyBooleanFilter(request.getBooleanFilterAttributes)
149
+ val queryPart = concatQueryParts(filters.map(_._1):_* )
150
+ val setParameterFns = filters.flatMap(_._2)
151
+ val setParameter = SetParameter ((_ : Unit , pp : PositionedParameters ) => {
152
+ setParameterFns.foreach(fn => fn(pp))
153
+ })
154
+ (SQLActionBuilder (queryPart, SetUnit ), setParameter)
155
+ }
156
+
157
+ private def applyContainsFilter (attributes : Seq [ContainsFilterAttributes ]) = {
158
+ attributes
159
+ .filter(attribute => fieldMapping.contains(attribute.field))
160
+ .map(attribute => (
161
+ sql """ AND # ${fieldMapping(attribute.field)} LIKE '%' || ? || '%' """ ,
162
+ Seq ((pp : PositionedParameters ) => SetString (attribute.value, pp))
163
+ ))
164
+ }
165
+
166
+ private def applyIntRangeFilter (attributes : Seq [IntRangeFilterAttributes ]) = {
167
+ val setIntParameter = (v : Int , pp: PositionedParameters ) => SetInt (v, pp)
168
+ attributes
169
+ .filter(attribute => fieldMapping.contains(attribute.field))
170
+ .map(attribute => applyRangeFilter(attribute.field, attribute.start, attribute.end, setIntParameter))
171
+ }
172
+
173
+ private def applyDateTimeRangeFilter (attributes : Seq [DateTimeRangeFilterAttributes ]) = {
174
+ val setTimestampParameter = (v : Timestamp , pp: PositionedParameters ) => SetTimestamp (v, pp)
175
+ attributes
176
+ .filter(attribute => fieldMapping.contains(attribute.field))
177
+ .map(attribute => applyRangeFilter(attribute.field, attribute.start.map(Timestamp .valueOf), attribute.end.map(Timestamp .valueOf), setTimestampParameter))
178
+ }
179
+
180
+ private def applyEqualsMultipleFilter (attributes : Seq [EqualsMultipleFilterAttributes ]) = {
181
+ attributes
182
+ .filter(attribute => fieldMapping.contains(attribute.field))
183
+ .map { attribute =>
184
+ val placeholders = List .fill(attribute.values.size)(" ?" ).mkString(" (" , " ," , " )" )
185
+ (
186
+ sql """ AND # ${fieldMapping(attribute.field)} IN # ${placeholders}""" ,
187
+ attribute.values.map(value => (pp : PositionedParameters ) => SetString (value, pp))
188
+ )
189
+ }
190
+ }
191
+
192
+ private def applyLongFilter (attributes : Seq [LongFilterAttributes ]) = {
193
+ attributes
194
+ .filter(attribute => fieldMapping.contains(attribute.field))
195
+ .map(attribute => (
196
+ sql """ AND # ${fieldMapping(attribute.field)} = ? """ ,
197
+ Seq ((pp : PositionedParameters ) => SetLong (attribute.value, pp))
198
+ ))
199
+ }
200
+
201
+ private def applyBooleanFilter (attributes : Seq [BooleanFilterAttributes ]) = {
202
+ attributes
203
+ .filter(attribute => fieldMapping.contains(attribute.field))
204
+ .filter(attribute => attribute.value.isTrue != attribute.value.isFalse)
205
+ .map(attribute => (
206
+ sql """ AND # ${fieldMapping(attribute.field)} = ? """ ,
207
+ Seq ((pp : PositionedParameters ) => SetBoolean (attribute.value.isTrue, pp))
208
+ ))
209
+ }
210
+
211
+ private def applyRangeFilter [B : BaseTypedType ](field : String , start : Option [B ], end : Option [B ], setParameterFn : (B , PositionedParameters ) => Unit ) = {
212
+ if (start.isDefined && end.isDefined) {
213
+ (
214
+ sql """ AND # ${fieldMapping(field)} >= ?
215
+ AND # ${fieldMapping(field)} <= ? """ ,
216
+ Seq (
217
+ (pp : PositionedParameters ) => setParameterFn(start.get, pp),
218
+ (pp : PositionedParameters ) => setParameterFn(end.get, pp)
219
+ )
220
+ )
221
+ } else if (start.isDefined) {
222
+ (
223
+ sql """ AND # ${fieldMapping(field)} >= ? """ ,
224
+ Seq (
225
+ (pp : PositionedParameters ) => setParameterFn(start.get, pp)
226
+ )
227
+ )
228
+ } else if (end.isDefined) {
229
+ (
230
+ sql """ AND # ${fieldMapping(field)} <= ? """ ,
231
+ Seq (
232
+ (pp : PositionedParameters ) => setParameterFn(end.get, pp)
233
+ )
234
+ )
235
+ } else {
236
+ (
237
+ sql """ AND 1=1 """ ,
238
+ Seq ()
239
+ )
240
+ }
33
241
}
34
242
}
0 commit comments