1+ package org.thoughtcrime.securesms.database
2+
3+ import android.content.Context
4+ import androidx.sqlite.db.transaction
5+ import dagger.hilt.android.qualifiers.ApplicationContext
6+ import kotlinx.serialization.ExperimentalSerializationApi
7+ import kotlinx.serialization.SerialName
8+ import kotlinx.serialization.Serializable
9+ import kotlinx.serialization.json.Json
10+ import kotlinx.serialization.json.JsonClassDiscriminator
11+ import org.session.libsession.utilities.serializable.InstantAsMillisSerializer
12+ import org.thoughtcrime.securesms.database.helpers.SQLCipherOpenHelper
13+ import org.thoughtcrime.securesms.util.asSequence
14+ import java.time.Instant
15+ import javax.inject.Inject
16+ import javax.inject.Provider
17+ import javax.inject.Singleton
18+
19+ @Singleton
20+ class PushRegistrationDatabase @Inject constructor(
21+ @ApplicationContext context : Context ,
22+ helper : Provider <SQLCipherOpenHelper >,
23+ private val json : Json ,
24+ ) : Database(context, helper) {
25+
26+
27+ @Serializable
28+ data class Registration (val accountId : String , val input : Input )
29+
30+ /* *
31+ * Ensure that the provided registrations exist in the database. If input changes for an existing
32+ * registration, reset its state to NONE. Any registrations not in the provided list will be
33+ * marked as pending unregistration.
34+ *
35+ * @return The number of database rows that were changed.
36+ */
37+ fun ensureRegistrations (registrations : Collection <Registration >): Int {
38+ val registrationsAsText = json.encodeToString(registrations)
39+
40+ // It's important to specify the base RegistrationState so that the discriminator is correct
41+ val pendingRegisterAsText = json.encodeToString<RegistrationState >(RegistrationState .PendingRegister )
42+ val pendingUnregisterAsText = json.encodeToString<RegistrationState >(RegistrationState .PendingUnregister )
43+
44+ return writableDatabase.transaction {
45+ var numChanges = 0
46+
47+ if (registrations.isNotEmpty()) {
48+ // Insert the provided registrations with PendingRegister state
49+ // If they already exist with a PendingUnregister state, flip them back to PendingRegister,
50+ // otherwise keep their existing state.
51+ compileStatement(
52+ """
53+ INSERT INTO push_registration_state (account_id, input, state)
54+ SELECT
55+ value->>'$.accountId',
56+ value->>'$.input',
57+ :pending_register_state
58+ FROM json_each(:registrations)
59+ WHERE TRUE
60+ ON CONFLICT DO UPDATE
61+ SET state = :pending_register_state
62+ WHERE state_type = '$TYPE_PENDING_UNREGISTER '
63+ """
64+ ).use { stmt ->
65+ stmt.bindString(1 , pendingRegisterAsText)
66+ stmt.bindString(2 , registrationsAsText)
67+ numChanges + = stmt.executeUpdateDelete()
68+ }
69+ }
70+
71+ // Mark all other registrations that are registered or error as PendingUnregister to be cleaned up
72+ compileStatement("""
73+ UPDATE push_registration_state
74+ SET state = ?
75+ WHERE (account_id, input) NOT IN (SELECT value->>'$.accountId', value->>'$.input' FROM json_each(?))
76+ AND state_type IN ('$TYPE_REGISTERED ', '$TYPE_ERROR ')
77+ """ ).use { stmt ->
78+ stmt.bindString(1 , pendingUnregisterAsText)
79+ stmt.bindString(2 , registrationsAsText)
80+ numChanges + = stmt.executeUpdateDelete()
81+ }
82+
83+ // Delete no longer desired registrations that didn't start, or ended up in a permanent error
84+ // Note: the changes here do not count towards numChanges since they won't affect
85+ // the scheduling
86+ compileStatement("""
87+ DELETE FROM push_registration_state
88+ WHERE state_type IN ('$TYPE_PERMANENT_ERROR ', '$TYPE_PENDING_REGISTER ')
89+ AND (account_id, input) NOT IN (SELECT value->>'$.accountId', value->>'$.input' FROM json_each(?))
90+ """ ).use { stmt ->
91+ stmt.bindString(1 , registrationsAsText)
92+ stmt.execute()
93+ }
94+
95+ numChanges
96+ }
97+ }
98+
99+ fun updateRegistrations (registrationWithStates : Collection <RegistrationWithState >) {
100+ writableDatabase.compileStatement(
101+ """
102+ UPDATE push_registration_state
103+ SET state = ?
104+ WHERE account_id = ? AND input = ?
105+ """
106+ ).use { stmt ->
107+ for (r in registrationWithStates) {
108+ stmt.clearBindings()
109+ stmt.bindString(1 , json.encodeToString(r.state))
110+ stmt.bindString(2 , r.accountId)
111+ stmt.bindString(3 , json.encodeToString(r.input))
112+ stmt.execute()
113+ }
114+ }
115+ }
116+
117+ data class PendingRegistrationWork (
118+ val register : List <RegistrationWithState >,
119+ val unregister : List <RegistrationWithState >,
120+ )
121+
122+ fun getPendingRegistrationWork (now : Instant = Instant .now(), limit : Int ): PendingRegistrationWork {
123+ // This query needs to consider two type of data:
124+ // - Registrations that need to be registered (due REGISTER, due ERROR or NONE)
125+ // - Registrations that need to be unregistered (PENDING_UNREGISTER)
126+ // The query does not directly map to these two groups, so we partition the results in code.
127+ return readableDatabase.rawQuery(
128+ """
129+ SELECT account_id, input, state, CAST(state->>'$.due' AS INTEGER) AS due_time
130+ FROM push_registration_state
131+ WHERE state_type IN ('$TYPE_ERROR ', '$TYPE_REGISTERED ')
132+ AND CAST(state->>'$.due' AS INTEGER) <= ?
133+
134+ UNION ALL
135+
136+ SELECT account_id, input, state, 0 AS due_time
137+ FROM push_registration_state
138+ WHERE state_type IN ('$TYPE_PENDING_REGISTER ', '$TYPE_PENDING_UNREGISTER ')
139+
140+ ORDER BY due_time ASC
141+ LIMIT ?
142+ """ , now.toEpochMilli(), limit
143+ ).use { cursor ->
144+ val (unregister, register) = cursor.asSequence()
145+ .map {
146+ RegistrationWithState (
147+ accountId = cursor.getString(0 ),
148+ input = json.decodeFromString(cursor.getString(1 )),
149+ state = json.decodeFromString(cursor.getString(2 )),
150+ )
151+ }
152+ .partition { it.state is RegistrationState .PendingUnregister }
153+
154+ PendingRegistrationWork (register, unregister)
155+ }
156+ }
157+
158+ fun removeRegistrations (registrations : Collection <Registration >) {
159+ if (registrations.isEmpty()) return
160+
161+ writableDatabase.rawExecSQL(
162+ """
163+ DELETE FROM push_registration_state
164+ WHERE (account_id, input) IN (SELECT value->>'$.accountId', value->>'$.input' FROM json_each(?))
165+ """ , json.encodeToString(registrations)
166+ )
167+ }
168+
169+ /* *
170+ * Get the next due time among all registrations. Null if there are no pending registrations.
171+ *
172+ * Note that the due time can be in the past or now, meaning there are som registrations
173+ * that must be processed immediately.
174+ */
175+ fun getNextProcessTime (now : Instant = Instant .now()): Instant ? {
176+ // The NONE state means we should process immediately, so we'll look them up first
177+ readableDatabase.rawQuery(
178+ """
179+ SELECT 1 FROM push_registration_state
180+ WHERE state_type IN ('$TYPE_PENDING_REGISTER ', '$TYPE_PENDING_UNREGISTER ')
181+ """
182+ ).use { cursor ->
183+ if (cursor.moveToNext()) {
184+ return now
185+ }
186+ }
187+
188+ // Otherwise, find the minimum due time among ERROR and REGISTERED states
189+ readableDatabase.rawQuery(
190+ """
191+ SELECT MIN(CAST(state->>'$.due' AS INTEGER))
192+ FROM push_registration_state
193+ WHERE state_type IN ('$TYPE_ERROR ', '$TYPE_REGISTERED ')
194+ """ ,
195+ ).use { cursor ->
196+ if (cursor.moveToFirst()) {
197+ val dueMillis = cursor.getLong(0 )
198+ if (! cursor.isNull(0 )) {
199+ return Instant .ofEpochMilli(dueMillis)
200+ }
201+ }
202+ }
203+
204+ return null
205+ }
206+
207+ @Serializable
208+ data class RegistrationWithState (
209+ val accountId : String ,
210+ val input : Input ,
211+ val state : RegistrationState
212+ )
213+
214+ /* *
215+ * The registration state that is saved in the db.
216+ *
217+ * Please note that any changes to this class must consider the backward compatibility
218+ * to the existing data in the database.
219+ */
220+ @OptIn(ExperimentalSerializationApi ::class )
221+ @Serializable
222+ @JsonClassDiscriminator(STATE_TYPE_DISCRIMINATOR )
223+ sealed interface RegistrationState {
224+ @Serializable
225+ @SerialName(TYPE_PENDING_REGISTER )
226+ data object PendingRegister : RegistrationState
227+
228+ @Serializable
229+ @SerialName(TYPE_REGISTERED )
230+ data class Registered (
231+ @Serializable(with = InstantAsMillisSerializer ::class )
232+ val due : Instant
233+ ) : RegistrationState
234+
235+ @Serializable
236+ @SerialName(TYPE_ERROR )
237+ data class Error (
238+ @Serializable(with = InstantAsMillisSerializer ::class )
239+ val due : Instant ,
240+ val numRetried : Int ,
241+ ) : RegistrationState
242+
243+ @Serializable
244+ @SerialName(TYPE_PERMANENT_ERROR )
245+ data object PermanentError : RegistrationState
246+
247+ @Serializable
248+ @SerialName(TYPE_PENDING_UNREGISTER )
249+ data object PendingUnregister : RegistrationState
250+ }
251+
252+ /* *
253+ * The input required to perform a push registration.
254+ */
255+ @Serializable
256+ data class Input (
257+ val pushToken : String
258+ )
259+
260+ companion object {
261+ private const val STATE_TYPE_DISCRIMINATOR = " type"
262+
263+ private const val TYPE_PENDING_REGISTER = " PENDING_REGISTER"
264+ private const val TYPE_REGISTERED = " REGISTERED"
265+ private const val TYPE_ERROR = " ERROR"
266+ private const val TYPE_PERMANENT_ERROR = " PERMANENT_ERROR"
267+ private const val TYPE_PENDING_UNREGISTER = " PENDING_UNREGISTER"
268+
269+ fun createTableStatements () = arrayOf(
270+ """
271+ CREATE TABLE push_registration_state(
272+ account_id TEXT NOT NULL,
273+ input TEXT NOT NULL,
274+ state TEXT NOT NULL,
275+ state_type TEXT GENERATED ALWAYS AS (state->>'$.$STATE_TYPE_DISCRIMINATOR ') VIRTUAL,
276+ PRIMARY KEY (account_id, input)
277+ ) WITHOUT ROWID""" ,
278+ " CREATE INDEX idx_push_state_type ON push_registration_state(state_type)" ,
279+ " CREATE INDEX idx_push_due ON push_registration_state(CAST(state->>'$.due' AS INTEGER))"
280+ )
281+ }
282+ }
0 commit comments