Skip to content

Commit aa47ce8

Browse files
Optimise push registration (#1627)
1 parent 3965e4d commit aa47ce8

File tree

10 files changed

+686
-312
lines changed

10 files changed

+686
-312
lines changed

app/src/firebaseCommon/kotlin/org/thoughtcrime/securesms/notifications/FirebaseTokenFetcher.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ class FirebaseTokenFetcher @Inject constructor(): TokenFetcher {
1212
override val token = MutableStateFlow<String?>(null)
1313

1414
init {
15+
fetchToken()
16+
}
17+
18+
private fun fetchToken() {
1519
FirebaseMessaging.getInstance()
1620
.token
1721
.addOnSuccessListener(this::onNewToken)
@@ -23,5 +27,6 @@ class FirebaseTokenFetcher @Inject constructor(): TokenFetcher {
2327

2428
override suspend fun resetToken() {
2529
FirebaseMessaging.getInstance().deleteToken().await()
30+
fetchToken()
2631
}
2732
}
Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
package org.thoughtcrime.securesms.database
2+
3+
import android.content.Context
4+
import androidx.sqlite.db.transaction
5+
import dagger.hilt.android.qualifiers.ApplicationContext
6+
import kotlinx.serialization.ExperimentalSerializationApi
7+
import kotlinx.serialization.SerialName
8+
import kotlinx.serialization.Serializable
9+
import kotlinx.serialization.json.Json
10+
import kotlinx.serialization.json.JsonClassDiscriminator
11+
import org.session.libsession.utilities.serializable.InstantAsMillisSerializer
12+
import org.thoughtcrime.securesms.database.helpers.SQLCipherOpenHelper
13+
import org.thoughtcrime.securesms.util.asSequence
14+
import java.time.Instant
15+
import javax.inject.Inject
16+
import javax.inject.Provider
17+
import javax.inject.Singleton
18+
19+
@Singleton
20+
class PushRegistrationDatabase @Inject constructor(
21+
@ApplicationContext context: Context,
22+
helper: Provider<SQLCipherOpenHelper>,
23+
private val json: Json,
24+
) : Database(context, helper) {
25+
26+
27+
@Serializable
28+
data class Registration(val accountId: String, val input: Input)
29+
30+
/**
31+
* Ensure that the provided registrations exist in the database. If input changes for an existing
32+
* registration, reset its state to NONE. Any registrations not in the provided list will be
33+
* marked as pending unregistration.
34+
*
35+
* @return The number of database rows that were changed.
36+
*/
37+
fun ensureRegistrations(registrations: Collection<Registration>): Int {
38+
val registrationsAsText = json.encodeToString(registrations)
39+
40+
// It's important to specify the base RegistrationState so that the discriminator is correct
41+
val pendingRegisterAsText = json.encodeToString<RegistrationState>(RegistrationState.PendingRegister)
42+
val pendingUnregisterAsText = json.encodeToString<RegistrationState>(RegistrationState.PendingUnregister)
43+
44+
return writableDatabase.transaction {
45+
var numChanges = 0
46+
47+
if (registrations.isNotEmpty()) {
48+
// Insert the provided registrations with PendingRegister state
49+
// If they already exist with a PendingUnregister state, flip them back to PendingRegister,
50+
// otherwise keep their existing state.
51+
compileStatement(
52+
"""
53+
INSERT INTO push_registration_state (account_id, input, state)
54+
SELECT
55+
value->>'$.accountId',
56+
value->>'$.input',
57+
:pending_register_state
58+
FROM json_each(:registrations)
59+
WHERE TRUE
60+
ON CONFLICT DO UPDATE
61+
SET state = :pending_register_state
62+
WHERE state_type = '$TYPE_PENDING_UNREGISTER'
63+
"""
64+
).use { stmt ->
65+
stmt.bindString(1, pendingRegisterAsText)
66+
stmt.bindString(2, registrationsAsText)
67+
numChanges += stmt.executeUpdateDelete()
68+
}
69+
}
70+
71+
// Mark all other registrations that are registered or error as PendingUnregister to be cleaned up
72+
compileStatement("""
73+
UPDATE push_registration_state
74+
SET state = ?
75+
WHERE (account_id, input) NOT IN (SELECT value->>'$.accountId', value->>'$.input' FROM json_each(?))
76+
AND state_type IN ('$TYPE_REGISTERED', '$TYPE_ERROR')
77+
""").use { stmt ->
78+
stmt.bindString(1, pendingUnregisterAsText)
79+
stmt.bindString(2, registrationsAsText)
80+
numChanges += stmt.executeUpdateDelete()
81+
}
82+
83+
// Delete no longer desired registrations that didn't start, or ended up in a permanent error
84+
// Note: the changes here do not count towards numChanges since they won't affect
85+
// the scheduling
86+
compileStatement("""
87+
DELETE FROM push_registration_state
88+
WHERE state_type IN ('$TYPE_PERMANENT_ERROR', '$TYPE_PENDING_REGISTER')
89+
AND (account_id, input) NOT IN (SELECT value->>'$.accountId', value->>'$.input' FROM json_each(?))
90+
""").use { stmt ->
91+
stmt.bindString(1, registrationsAsText)
92+
stmt.execute()
93+
}
94+
95+
numChanges
96+
}
97+
}
98+
99+
fun updateRegistrations(registrationWithStates: Collection<RegistrationWithState>) {
100+
writableDatabase.compileStatement(
101+
"""
102+
UPDATE push_registration_state
103+
SET state = ?
104+
WHERE account_id = ? AND input = ?
105+
"""
106+
).use { stmt ->
107+
for (r in registrationWithStates) {
108+
stmt.clearBindings()
109+
stmt.bindString(1, json.encodeToString(r.state))
110+
stmt.bindString(2, r.accountId)
111+
stmt.bindString(3, json.encodeToString(r.input))
112+
stmt.execute()
113+
}
114+
}
115+
}
116+
117+
data class PendingRegistrationWork(
118+
val register: List<RegistrationWithState>,
119+
val unregister: List<RegistrationWithState>,
120+
)
121+
122+
fun getPendingRegistrationWork(now: Instant = Instant.now(), limit: Int): PendingRegistrationWork {
123+
// This query needs to consider two type of data:
124+
// - Registrations that need to be registered (due REGISTER, due ERROR or NONE)
125+
// - Registrations that need to be unregistered (PENDING_UNREGISTER)
126+
// The query does not directly map to these two groups, so we partition the results in code.
127+
return readableDatabase.rawQuery(
128+
"""
129+
SELECT account_id, input, state, CAST(state->>'$.due' AS INTEGER) AS due_time
130+
FROM push_registration_state
131+
WHERE state_type IN ('$TYPE_ERROR', '$TYPE_REGISTERED')
132+
AND CAST(state->>'$.due' AS INTEGER) <= ?
133+
134+
UNION ALL
135+
136+
SELECT account_id, input, state, 0 AS due_time
137+
FROM push_registration_state
138+
WHERE state_type IN ('$TYPE_PENDING_REGISTER', '$TYPE_PENDING_UNREGISTER')
139+
140+
ORDER BY due_time ASC
141+
LIMIT ?
142+
""", now.toEpochMilli(), limit
143+
).use { cursor ->
144+
val (unregister, register) = cursor.asSequence()
145+
.map {
146+
RegistrationWithState(
147+
accountId = cursor.getString(0),
148+
input = json.decodeFromString(cursor.getString(1)),
149+
state = json.decodeFromString(cursor.getString(2)),
150+
)
151+
}
152+
.partition { it.state is RegistrationState.PendingUnregister }
153+
154+
PendingRegistrationWork(register, unregister)
155+
}
156+
}
157+
158+
fun removeRegistrations(registrations: Collection<Registration>) {
159+
if (registrations.isEmpty()) return
160+
161+
writableDatabase.rawExecSQL(
162+
"""
163+
DELETE FROM push_registration_state
164+
WHERE (account_id, input) IN (SELECT value->>'$.accountId', value->>'$.input' FROM json_each(?))
165+
""", json.encodeToString(registrations)
166+
)
167+
}
168+
169+
/**
170+
* Get the next due time among all registrations. Null if there are no pending registrations.
171+
*
172+
* Note that the due time can be in the past or now, meaning there are som registrations
173+
* that must be processed immediately.
174+
*/
175+
fun getNextProcessTime(now: Instant = Instant.now()): Instant? {
176+
// The NONE state means we should process immediately, so we'll look them up first
177+
readableDatabase.rawQuery(
178+
"""
179+
SELECT 1 FROM push_registration_state
180+
WHERE state_type IN ('$TYPE_PENDING_REGISTER', '$TYPE_PENDING_UNREGISTER')
181+
"""
182+
).use { cursor ->
183+
if (cursor.moveToNext()) {
184+
return now
185+
}
186+
}
187+
188+
// Otherwise, find the minimum due time among ERROR and REGISTERED states
189+
readableDatabase.rawQuery(
190+
"""
191+
SELECT MIN(CAST(state->>'$.due' AS INTEGER))
192+
FROM push_registration_state
193+
WHERE state_type IN ('$TYPE_ERROR', '$TYPE_REGISTERED')
194+
""",
195+
).use { cursor ->
196+
if (cursor.moveToFirst()) {
197+
val dueMillis = cursor.getLong(0)
198+
if (!cursor.isNull(0)) {
199+
return Instant.ofEpochMilli(dueMillis)
200+
}
201+
}
202+
}
203+
204+
return null
205+
}
206+
207+
@Serializable
208+
data class RegistrationWithState(
209+
val accountId: String,
210+
val input: Input,
211+
val state: RegistrationState
212+
)
213+
214+
/**
215+
* The registration state that is saved in the db.
216+
*
217+
* Please note that any changes to this class must consider the backward compatibility
218+
* to the existing data in the database.
219+
*/
220+
@OptIn(ExperimentalSerializationApi::class)
221+
@Serializable
222+
@JsonClassDiscriminator(STATE_TYPE_DISCRIMINATOR)
223+
sealed interface RegistrationState {
224+
@Serializable
225+
@SerialName(TYPE_PENDING_REGISTER)
226+
data object PendingRegister : RegistrationState
227+
228+
@Serializable
229+
@SerialName(TYPE_REGISTERED)
230+
data class Registered(
231+
@Serializable(with = InstantAsMillisSerializer::class)
232+
val due: Instant
233+
) : RegistrationState
234+
235+
@Serializable
236+
@SerialName(TYPE_ERROR)
237+
data class Error(
238+
@Serializable(with = InstantAsMillisSerializer::class)
239+
val due: Instant,
240+
val numRetried: Int,
241+
) : RegistrationState
242+
243+
@Serializable
244+
@SerialName(TYPE_PERMANENT_ERROR)
245+
data object PermanentError : RegistrationState
246+
247+
@Serializable
248+
@SerialName(TYPE_PENDING_UNREGISTER)
249+
data object PendingUnregister : RegistrationState
250+
}
251+
252+
/**
253+
* The input required to perform a push registration.
254+
*/
255+
@Serializable
256+
data class Input(
257+
val pushToken: String
258+
)
259+
260+
companion object {
261+
private const val STATE_TYPE_DISCRIMINATOR = "type"
262+
263+
private const val TYPE_PENDING_REGISTER = "PENDING_REGISTER"
264+
private const val TYPE_REGISTERED = "REGISTERED"
265+
private const val TYPE_ERROR = "ERROR"
266+
private const val TYPE_PERMANENT_ERROR = "PERMANENT_ERROR"
267+
private const val TYPE_PENDING_UNREGISTER = "PENDING_UNREGISTER"
268+
269+
fun createTableStatements() = arrayOf(
270+
"""
271+
CREATE TABLE push_registration_state(
272+
account_id TEXT NOT NULL,
273+
input TEXT NOT NULL,
274+
state TEXT NOT NULL,
275+
state_type TEXT GENERATED ALWAYS AS (state->>'$.$STATE_TYPE_DISCRIMINATOR') VIRTUAL,
276+
PRIMARY KEY (account_id, input)
277+
) WITHOUT ROWID""",
278+
"CREATE INDEX idx_push_state_type ON push_registration_state(state_type)",
279+
"CREATE INDEX idx_push_due ON push_registration_state(CAST(state->>'$.due' AS INTEGER))"
280+
)
281+
}
282+
}

app/src/main/java/org/thoughtcrime/securesms/database/helpers/SQLCipherOpenHelper.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.thoughtcrime.securesms.database.MmsDatabase;
3232
import org.thoughtcrime.securesms.database.MmsSmsDatabase;
3333
import org.thoughtcrime.securesms.database.PushDatabase;
34+
import org.thoughtcrime.securesms.database.PushRegistrationDatabase;
3435
import org.thoughtcrime.securesms.database.ReactionDatabase;
3536
import org.thoughtcrime.securesms.database.RecipientDatabase;
3637
import org.thoughtcrime.securesms.database.RecipientSettingsDatabase;
@@ -100,6 +101,7 @@ public class SQLCipherOpenHelper extends SQLiteOpenHelper {
100101
private static final int lokiV52 = 73;
101102
private static final int lokiV53 = 74;
102103
private static final int lokiV54 = 75;
104+
private static final int lokiV55 = 76;
103105

104106
// Loki - onUpgrade(...) must be updated to use Loki version numbers if Signal makes any database changes
105107
private static final int DATABASE_VERSION = lokiV54;
@@ -262,6 +264,8 @@ public void onCreate(SQLiteDatabase db) {
262264

263265
db.execSQL(SmsDatabase.ADD_LAST_MESSAGE_INDEX);
264266
db.execSQL(MmsDatabase.ADD_LAST_MESSAGE_INDEX);
267+
268+
executeStatements(db, PushRegistrationDatabase.Companion.createTableStatements());
265269
}
266270

267271
@Override
@@ -596,6 +600,10 @@ public void onUpgrade(SQLiteDatabase db, int oldVersion, int newVersion) {
596600
db.execSQL(MmsDatabase.ADD_LAST_MESSAGE_INDEX);
597601
}
598602

603+
if (oldVersion < lokiV55) {
604+
executeStatements(db, PushRegistrationDatabase.Companion.createTableStatements());
605+
}
606+
599607
db.setTransactionSuccessful();
600608
} finally {
601609
db.endTransaction();

app/src/main/java/org/thoughtcrime/securesms/debugmenu/DebugMenu.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,13 @@ fun DebugMenu(
479479
}
480480
)
481481

482+
SlimOutlineButton(
483+
modifier = Modifier.fillMaxWidth(),
484+
text = "Reset Push Token",
485+
) {
486+
sendCommand(DebugMenuViewModel.Commands.ResetPushToken)
487+
}
488+
482489
SlimOutlineButton(
483490
modifier = Modifier.fillMaxWidth(),
484491
text = "Clear All Trusted Downloads",

app/src/main/java/org/thoughtcrime/securesms/debugmenu/DebugMenuViewModel.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import network.loki.messenger.libsession_util.util.BlindKeyAPI
2424
import org.session.libsession.database.StorageProtocol
2525
import org.session.libsession.messaging.file_server.FileServerApi
2626
import org.session.libsession.messaging.groups.LegacyGroupDeprecationManager
27+
import org.session.libsession.messaging.notifications.TokenFetcher
2728
import org.session.libsession.messaging.sending_receiving.attachments.AttachmentState
2829
import org.session.libsession.utilities.Address
2930
import org.session.libsession.utilities.Address.Companion.toAddress
@@ -59,6 +60,7 @@ class DebugMenuViewModel @Inject constructor(
5960
private val attachmentDatabase: AttachmentDatabase,
6061
private val conversationRepository: ConversationRepository,
6162
private val databaseInspector: DatabaseInspector,
63+
private val tokenFetcher: TokenFetcher,
6264
subscriptionManagers: Set<@JvmSuppressWildcards SubscriptionManager>,
6365
) : ViewModel() {
6466
private val TAG = "DebugMenu"
@@ -318,6 +320,12 @@ class DebugMenuViewModel @Inject constructor(
318320
_uiState.update { it.copy(debugAvatarReupload = newValue) }
319321
textSecurePreferences.debugAvatarReupload = newValue
320322
}
323+
324+
is Commands.ResetPushToken -> {
325+
viewModelScope.launch {
326+
tokenFetcher.resetToken()
327+
}
328+
}
321329
}
322330
}
323331

@@ -469,5 +477,6 @@ class DebugMenuViewModel @Inject constructor(
469477
data class PurchaseDebugPlan(val plan: DebugProPlan) : Commands()
470478
data object ToggleDeterministicAttachmentUpload : Commands()
471479
data object ToggleDebugAvatarReupload : Commands()
480+
data object ResetPushToken : Commands()
472481
}
473482
}

0 commit comments

Comments
 (0)