@@ -2,6 +2,7 @@ package indexworker
22
33import (
44 "context"
5+ "database/sql"
56 "errors"
67 "fmt"
78 "log"
@@ -10,6 +11,7 @@ import (
1011 "time"
1112
1213 "github.com/gobuffalo/pop/v6"
14+ pkgerrors "github.com/pkg/errors"
1315 "github.com/sirupsen/logrus"
1416 "github.com/supabase/auth/internal/conf"
1517)
@@ -93,10 +95,17 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo
9395 }
9496 }()
9597
96- // Look up which schema the pg_trgm extension is installed in
97- trgmSchema , err := getTrgmExtensionSchema (db )
98+ // Ensure either auth_trgm or pg_trgm extension is installed
99+ extName , err := ensureTrgmExtension (db , config . DB . Namespace , le )
98100 if err != nil {
99- le .Errorf ("Failed to find pg_trgm extension schema: %+v" , err )
101+ le .Errorf ("Failed to ensure trgm extension is available: %+v" , err )
102+ return err
103+ }
104+
105+ // Look up which schema the trgm extension is installed in
106+ trgmSchema , err := getTrgmExtensionSchema (db , extName )
107+ if err != nil {
108+ le .Errorf ("Failed to find %s extension schema: %+v" , extName , err )
100109 return ErrExtensionNotFound
101110 }
102111
@@ -170,23 +179,113 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo
170179 return nil
171180}
172181
173- // getTrgmExtensionSchema looks up which schema the pg_trgm extension is installed in
174- func getTrgmExtensionSchema (db * pop.Connection ) (string , error ) {
182+ // getTrgmExtensionSchema looks up which schema the specified trgm extension is installed in
183+ func getTrgmExtensionSchema (db * pop.Connection , extName string ) (string , error ) {
175184 var schema string
176185 query := `
177186 SELECT extnamespace::regnamespace::text AS schema_name
178187 FROM pg_extension
179- WHERE extname = 'pg_trgm'
188+ WHERE extname = $1
180189 LIMIT 1
181190 `
182191
183- if err := db .RawQuery (query ).First (& schema ); err != nil {
184- return "" , fmt .Errorf ("failed to find pg_trgm extension schema: %w" , err )
192+ if err := db .RawQuery (query , extName ).First (& schema ); err != nil {
193+ return "" , fmt .Errorf ("failed to find %s extension schema: %w" , extName , err )
185194 }
186195
187196 return schema , nil
188197}
189198
199+ // extensionStatus represents the status of an extension from pg_available_extensions
200+ type extensionStatus struct {
201+ Available bool
202+ Installed bool
203+ }
204+
205+ // getExtensionStatus checks if an extension is available and/or installed
206+ func getExtensionStatus (db * pop.Connection , extName string ) (extensionStatus , error ) {
207+ var result struct {
208+ Name * string `db:"name"`
209+ InstalledVersion * string `db:"installed_version"`
210+ }
211+
212+ query := `
213+ SELECT name, installed_version
214+ FROM pg_available_extensions
215+ WHERE name = $1
216+ `
217+
218+ if err := db .RawQuery (query , extName ).First (& result ); err != nil {
219+ // If no rows returned, extension is not available
220+ if pkgerrors .Cause (err ) == sql .ErrNoRows {
221+ return extensionStatus {Available : false , Installed : false }, nil
222+ }
223+ return extensionStatus {}, fmt .Errorf ("failed to check extension status for %s: %w" , extName , err )
224+ }
225+
226+ return extensionStatus {
227+ Available : result .Name != nil ,
228+ Installed : result .InstalledVersion != nil ,
229+ }, nil
230+ }
231+
232+ // installExtension installs the specified extension in the provided schema
233+ func installExtension (db * pop.Connection , extName string , schema string ) error {
234+ query := fmt .Sprintf ("CREATE EXTENSION IF NOT EXISTS %s SCHEMA %s" , extName , schema )
235+ if err := db .RawQuery (query ).Exec (); err != nil {
236+ return fmt .Errorf ("failed to install extension %s in schema %s: %w" , extName , schema , err )
237+ }
238+ return nil
239+ }
240+
241+ // ensureTrgmExtension ensures that either auth_trgm or pg_trgm extension is installed
242+ // It prefers auth_trgm if available, otherwise falls back to pg_trgm
243+ // Returns the name of the installed extension
244+ func ensureTrgmExtension (db * pop.Connection , authSchema string , le * logrus.Entry ) (string , error ) {
245+ authTrgmStatus , err := getExtensionStatus (db , "auth_trgm" )
246+ if err != nil {
247+ return "" , fmt .Errorf ("failed to check auth_trgm extension status: %w" , err )
248+ }
249+
250+ if authTrgmStatus .Available {
251+ if ! authTrgmStatus .Installed {
252+ le .Infof ("auth_trgm extension is available but not installed. Installing..." )
253+ if err := installExtension (db , "auth_trgm" , authSchema ); err != nil {
254+ le .Errorf ("Failed to install auth_trgm extension: %v" , err )
255+ return "" , fmt .Errorf ("auth_trgm extension is available but failed to install: %w" , err )
256+ }
257+ le .Infof ("Successfully installed auth_trgm extension" )
258+ } else {
259+ le .Infof ("auth_trgm extension is already installed" )
260+ }
261+ return "auth_trgm" , nil
262+ }
263+
264+ le .Infof ("auth_trgm extension is not available, checking pg_trgm..." )
265+
266+ pgTrgmStatus , err := getExtensionStatus (db , "pg_trgm" )
267+ if err != nil {
268+ return "" , fmt .Errorf ("failed to check pg_trgm extension status: %w" , err )
269+ }
270+
271+ if ! pgTrgmStatus .Available {
272+ return "" , fmt .Errorf ("neither auth_trgm nor pg_trgm extensions are available" )
273+ }
274+
275+ if ! pgTrgmStatus .Installed {
276+ le .Infof ("pg_trgm extension is available but not installed. Installing..." )
277+ if err := installExtension (db , "pg_trgm" , "pg_catalog" ); err != nil {
278+ le .Errorf ("Failed to install pg_trgm extension: %v" , err )
279+ return "" , fmt .Errorf ("pg_trgm extension is available but failed to install: %w" , err )
280+ }
281+ le .Infof ("Successfully installed pg_trgm extension" )
282+ } else {
283+ le .Infof ("pg_trgm extension is already installed" )
284+ }
285+
286+ return "pg_trgm" , nil
287+ }
288+
190289// getUsersIndexes returns the list of indexes to create on the users table
191290func getUsersIndexes (namespace , trgmSchema string ) []struct {
192291 name string
0 commit comments