Skip to content

Commit 6eccb08

Browse files
committed
Refactor getUserGroups (no functional changes)
1 parent e3dcb30 commit 6eccb08

File tree

1 file changed

+47
-53
lines changed
  • libs/wire-subsystems/src/Wire/UserGroupStore

1 file changed

+47
-53
lines changed

libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs

Lines changed: 47 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,44 @@ getUserGroupsWithMembers ::
231231
Sem r UserGroupPageWithMembers
232232
getUserGroupsWithMembers _req = pure $ UserGroupPage [] 0 -- this is the stub
233233

234+
groupMatchIdName :: UserGroupPageRequest -> [QueryFragment]
235+
groupMatchIdName req =
236+
clause1 "team_id" "=" req.team :
237+
[like "name" name | name <- toList $ req.searchString]
238+
239+
groupPaginationWhereClause :: UserGroupPageRequest -> [QueryFragment]
240+
groupPaginationWhereClause req =
241+
[clause (sortOrderOperator req.sortOrder) ps | ps <- toList $ paginationClause req.paginationState]
242+
243+
groupPaginationOrderBy :: UserGroupPageRequest -> [QueryFragment]
244+
groupPaginationOrderBy req =
245+
[ orderBy [ (sortColumn req.paginationState, req.sortOrder),
246+
("id", req.sortOrder)
247+
],
248+
limit (pageSizeToInt32 req.pageSize)
249+
]
250+
where
251+
sortColumn :: PaginationState a -> Text
252+
sortColumn = \case
253+
PaginationSortByName _ -> "name"
254+
PaginationSortByCreatedAt _ -> "created_at"
255+
256+
getCountSession :: UserGroupPageRequest -> Tx.Transaction Int
257+
getCountSession req = Tx.statement () $ refineResult parseCount $ buildStatement query decoder
258+
where
259+
query = literal "select count(*) from user_group" <> where_ (groupMatchIdName req)
260+
decoder = HD.singleRow (HD.column (HD.nonNullable HD.int8))
261+
262+
decodeUuidVector :: HD.Row (Vector UUID)
263+
decodeUuidVector = HD.column $ HD.nonNullable $
264+
HD.array $ HD.dimension V.replicateM $ HD.element $ HD.nonNullable HD.uuid
265+
266+
parseManagedBy :: Int32 -> Either Text ManagedBy
267+
parseManagedBy = \case
268+
0 -> pure ManagedByWire
269+
1 -> pure ManagedByScim
270+
bad -> Left $ "Could not parse managedBy value: " <> T.pack (show bad)
271+
234272
getUserGroups ::
235273
forall r.
236274
( UserGroupStorePostgresEffectConstraints r,
@@ -241,7 +279,7 @@ getUserGroups ::
241279
getUserGroups req@(UserGroupPageRequest {..}) = do
242280
loc <- inputQualifyLocal ()
243281
runTransaction TxSessions.ReadCommitted TxSessions.Read $
244-
UserGroupPage <$> getUserGroupsSession loc <*> getCountSession
282+
UserGroupPage <$> getUserGroupsSession loc <*> getCountSession req
245283
where
246284
getUserGroupsSession :: Local () -> Tx.Transaction [UserGroupMeta]
247285
getUserGroupsSession loc =
@@ -252,36 +290,11 @@ getUserGroups req@(UserGroupPageRequest {..}) = do
252290
[ literal "select",
253291
literal selectors,
254292
literal "from user_group as ug",
255-
where_
256-
( [clause1 "team_id" "=" req.team]
257-
<> [ clause (sortOrderOperator sortOrder) c
258-
| c <- toList (paginationClause paginationState)
259-
]
260-
<> toList (like "name" <$> searchString)
261-
)
262-
]
263-
<> [ orderBy
264-
[ (sortColumn, sortOrder),
265-
("id", sortOrder)
266-
],
267-
limit (pageSizeToInt32 req.pageSize)
268-
]
293+
where_ (groupMatchIdName req <> groupPaginationWhereClause req)
294+
] <> groupPaginationOrderBy req
269295
)
270296
decodeRow
271297

272-
getCountSession :: Tx.Transaction Int
273-
getCountSession =
274-
Tx.statement () $
275-
refineResult parseCount $
276-
buildStatement
277-
( literal "select count(*) from user_group"
278-
<> where_
279-
( [clause1 "team_id" "=" req.team]
280-
<> toList (like "name" <$> searchString)
281-
)
282-
)
283-
(HD.singleRow (HD.column (HD.nonNullable HD.int8)))
284-
285298
decodeRow :: HD.Result [(UUID, Text, Int32, UTCTime, Maybe Int32, Int32, Maybe (Vector UUID))]
286299
decodeRow =
287300
HD.rowList
@@ -293,47 +306,28 @@ getUserGroups req@(UserGroupPageRequest {..}) = do
293306
<*> (if req.includeMemberCount then Just <$> HD.column (HD.nonNullable HD.int4) else pure Nothing)
294307
<*> HD.column (HD.nonNullable HD.int4)
295308
<*> ( if req.includeChannels
296-
then
297-
Just
298-
<$> HD.column
299-
( HD.nonNullable
300-
( HD.array
301-
( HD.dimension
302-
V.replicateM
303-
(HD.element (HD.nonNullable HD.uuid))
304-
)
305-
)
306-
)
309+
then Just <$> decodeUuidVector
307310
else pure Nothing
308311
)
309312
)
310313

311314
parseRow :: Local a -> (UUID, Text, Int32, UTCTime, Maybe Int32, Int32, Maybe (Vector UUID)) -> Either Text UserGroupMeta
312315
parseRow loc (Id -> id_, namePre, managedByPre, toUTCTimeMillis -> createdAt, membersCountRaw, channelsCountRaw, maybeChannels) = do
313-
managedBy <- case managedByPre of
314-
0 -> pure ManagedByWire
315-
1 -> pure ManagedByScim
316-
bad -> Left $ "Could not parse managedBy value: " <> T.pack (show bad)
316+
managedBy <- parseManagedBy managedByPre
317317
name <- userGroupNameFromText namePre
318318
let members = Const ()
319319
membersCount = fromIntegral <$> membersCountRaw
320320
channelsCount = Just (fromIntegral channelsCountRaw)
321321
channels = fmap (fmap (tUntagged . qualifyAs loc . Id)) maybeChannels
322322
pure $ UserGroup_ {..}
323323

324-
sortColumn :: Text
325-
sortColumn = case paginationState of
326-
PaginationSortByName _ -> "name"
327-
PaginationSortByCreatedAt _ -> "created_at"
328-
329324
selectors :: Text
330325
selectors =
331326
T.intercalate ", " $
332-
filter (not . T.null) $
333-
["id", "name", "managed_by", "created_at"]
334-
<> ["(select count(*) from user_group_member as ugm where ugm.user_group_id = ug.id) as members" | includeMemberCount]
335-
<> ["(select count(*) from user_group_channel as ugc where ugc.user_group_id = ug.id) as channels"]
336-
<> ["coalesce((select array_agg(ugc.conv_id) from user_group_channel as ugc where ugc.user_group_id = ug.id), array[]::uuid[]) as channel_ids" | includeChannels]
327+
["id", "name", "managed_by", "created_at"]
328+
<> ["(select count(*) from user_group_member as ugm where ugm.user_group_id = ug.id) as members" | includeMemberCount]
329+
<> ["(select count(*) from user_group_channel as ugc where ugc.user_group_id = ug.id) as channels"]
330+
<> ["coalesce((select array_agg(ugc.conv_id) from user_group_channel as ugc where ugc.user_group_id = ug.id), array[]::uuid[]) as channel_ids" | includeChannels]
337331

338332
createUserGroup ::
339333
forall r.

0 commit comments

Comments
 (0)