Skip to content

Commit 05daa43

Browse files
authored
feat(indexworker): use auth_trgm extension if available (#2263)
* Installs the trgm extension if it's not already installed * Prefers `auth_trgm` index and fallsback to `pg_trgm` if not available * Create index statements are namespaced with the schema * Failing to install the trgm extension aborts index creation
1 parent 4be12b3 commit 05daa43

File tree

2 files changed

+124
-15
lines changed

2 files changed

+124
-15
lines changed

internal/indexworker/indexworker.go

Lines changed: 107 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package indexworker
22

33
import (
44
"context"
5+
"database/sql"
56
"errors"
67
"fmt"
78
"log"
@@ -10,6 +11,7 @@ import (
1011
"time"
1112

1213
"github.com/gobuffalo/pop/v6"
14+
pkgerrors "github.com/pkg/errors"
1315
"github.com/sirupsen/logrus"
1416
"github.com/supabase/auth/internal/conf"
1517
)
@@ -93,10 +95,17 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo
9395
}
9496
}()
9597

96-
// Look up which schema the pg_trgm extension is installed in
97-
trgmSchema, err := getTrgmExtensionSchema(db)
98+
// Ensure either auth_trgm or pg_trgm extension is installed
99+
extName, err := ensureTrgmExtension(db, config.DB.Namespace, le)
98100
if err != nil {
99-
le.Errorf("Failed to find pg_trgm extension schema: %+v", err)
101+
le.Errorf("Failed to ensure trgm extension is available: %+v", err)
102+
return err
103+
}
104+
105+
// Look up which schema the trgm extension is installed in
106+
trgmSchema, err := getTrgmExtensionSchema(db, extName)
107+
if err != nil {
108+
le.Errorf("Failed to find %s extension schema: %+v", extName, err)
100109
return ErrExtensionNotFound
101110
}
102111

@@ -170,23 +179,113 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo
170179
return nil
171180
}
172181

173-
// getTrgmExtensionSchema looks up which schema the pg_trgm extension is installed in
174-
func getTrgmExtensionSchema(db *pop.Connection) (string, error) {
182+
// getTrgmExtensionSchema looks up which schema the specified trgm extension is installed in
183+
func getTrgmExtensionSchema(db *pop.Connection, extName string) (string, error) {
175184
var schema string
176185
query := `
177186
SELECT extnamespace::regnamespace::text AS schema_name
178187
FROM pg_extension
179-
WHERE extname = 'pg_trgm'
188+
WHERE extname = $1
180189
LIMIT 1
181190
`
182191

183-
if err := db.RawQuery(query).First(&schema); err != nil {
184-
return "", fmt.Errorf("failed to find pg_trgm extension schema: %w", err)
192+
if err := db.RawQuery(query, extName).First(&schema); err != nil {
193+
return "", fmt.Errorf("failed to find %s extension schema: %w", extName, err)
185194
}
186195

187196
return schema, nil
188197
}
189198

199+
// extensionStatus represents the status of an extension from pg_available_extensions
200+
type extensionStatus struct {
201+
Available bool
202+
Installed bool
203+
}
204+
205+
// getExtensionStatus checks if an extension is available and/or installed
206+
func getExtensionStatus(db *pop.Connection, extName string) (extensionStatus, error) {
207+
var result struct {
208+
Name *string `db:"name"`
209+
InstalledVersion *string `db:"installed_version"`
210+
}
211+
212+
query := `
213+
SELECT name, installed_version
214+
FROM pg_available_extensions
215+
WHERE name = $1
216+
`
217+
218+
if err := db.RawQuery(query, extName).First(&result); err != nil {
219+
// If no rows returned, extension is not available
220+
if pkgerrors.Cause(err) == sql.ErrNoRows {
221+
return extensionStatus{Available: false, Installed: false}, nil
222+
}
223+
return extensionStatus{}, fmt.Errorf("failed to check extension status for %s: %w", extName, err)
224+
}
225+
226+
return extensionStatus{
227+
Available: result.Name != nil,
228+
Installed: result.InstalledVersion != nil,
229+
}, nil
230+
}
231+
232+
// installExtension installs the specified extension in the provided schema
233+
func installExtension(db *pop.Connection, extName string, schema string) error {
234+
query := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s SCHEMA %s", extName, schema)
235+
if err := db.RawQuery(query).Exec(); err != nil {
236+
return fmt.Errorf("failed to install extension %s in schema %s: %w", extName, schema, err)
237+
}
238+
return nil
239+
}
240+
241+
// ensureTrgmExtension ensures that either auth_trgm or pg_trgm extension is installed
242+
// It prefers auth_trgm if available, otherwise falls back to pg_trgm
243+
// Returns the name of the installed extension
244+
func ensureTrgmExtension(db *pop.Connection, authSchema string, le *logrus.Entry) (string, error) {
245+
authTrgmStatus, err := getExtensionStatus(db, "auth_trgm")
246+
if err != nil {
247+
return "", fmt.Errorf("failed to check auth_trgm extension status: %w", err)
248+
}
249+
250+
if authTrgmStatus.Available {
251+
if !authTrgmStatus.Installed {
252+
le.Infof("auth_trgm extension is available but not installed. Installing...")
253+
if err := installExtension(db, "auth_trgm", authSchema); err != nil {
254+
le.Errorf("Failed to install auth_trgm extension: %v", err)
255+
return "", fmt.Errorf("auth_trgm extension is available but failed to install: %w", err)
256+
}
257+
le.Infof("Successfully installed auth_trgm extension")
258+
} else {
259+
le.Infof("auth_trgm extension is already installed")
260+
}
261+
return "auth_trgm", nil
262+
}
263+
264+
le.Infof("auth_trgm extension is not available, checking pg_trgm...")
265+
266+
pgTrgmStatus, err := getExtensionStatus(db, "pg_trgm")
267+
if err != nil {
268+
return "", fmt.Errorf("failed to check pg_trgm extension status: %w", err)
269+
}
270+
271+
if !pgTrgmStatus.Available {
272+
return "", fmt.Errorf("neither auth_trgm nor pg_trgm extensions are available")
273+
}
274+
275+
if !pgTrgmStatus.Installed {
276+
le.Infof("pg_trgm extension is available but not installed. Installing...")
277+
if err := installExtension(db, "pg_trgm", "pg_catalog"); err != nil {
278+
le.Errorf("Failed to install pg_trgm extension: %v", err)
279+
return "", fmt.Errorf("pg_trgm extension is available but failed to install: %w", err)
280+
}
281+
le.Infof("Successfully installed pg_trgm extension")
282+
} else {
283+
le.Infof("pg_trgm extension is already installed")
284+
}
285+
286+
return "pg_trgm", nil
287+
}
288+
190289
// getUsersIndexes returns the list of indexes to create on the users table
191290
func getUsersIndexes(namespace, trgmSchema string) []struct {
192291
name string

internal/indexworker/indexworker_test.go

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -393,12 +393,12 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithInvalidIndexes() {
393393
ts.logger.Infof("Successfully recovered from %d invalid indexes", len(indexesToInvalidate))
394394
}
395395

396-
// TestCreateIndexesWithoutTrgmExtension tests that CreateIndexes fails when pg_trgm extension doesn't exist
397-
// and that no indexes are created when this prerequisite check fails.
396+
// TestCreateIndexesWithoutTrgmExtension tests that CreateIndexes installs pg_trgm extension
397+
// when it's available but not installed, and then successfully creates indexes.
398398
func (ts *IndexWorkerTestSuite) TestCreateIndexesWithoutTrgmExtension() {
399399
ctx := context.Background()
400400

401-
// Drop the pg_trgm extension to simulate it not being available
401+
// Drop the pg_trgm extension to simulate it not being installed
402402
dropExtQuery := "DROP EXTENSION IF EXISTS pg_trgm CASCADE"
403403
err := ts.db.RawQuery(dropExtQuery).Exec()
404404
require.NoError(ts.T(), err, "Should be able to drop pg_trgm extension")
@@ -416,14 +416,24 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithoutTrgmExtension() {
416416
require.NoError(ts.T(), err)
417417
assert.Empty(ts.T(), existingIndexes, "No indexes should exist initially")
418418

419-
// Try to create indexes without pg_trgm extension
419+
// Run CreateIndexes - it should install the pg_trgm extension and create indexes
420420
err = CreateIndexes(ctx, ts.config, ts.logger)
421-
assert.Error(ts.T(), err, "CreateIndexes should fail when pg_trgm extension doesn't exist")
422-
assert.ErrorIs(ts.T(), err, ErrExtensionNotFound)
421+
require.NoError(ts.T(), err, "CreateIndexes should succeed by installing the pg_trgm extension")
423422

423+
// Verify that pg_trgm is now installed
424+
err = ts.db.RawQuery(checkExtQuery).First(&extensionExists)
425+
require.NoError(ts.T(), err)
426+
assert.True(ts.T(), extensionExists, "pg_trgm extension should have been installed")
427+
428+
// Verify all indexes were created successfully
424429
existingIndexes, err = getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes))
425430
require.NoError(ts.T(), err)
426-
assert.Empty(ts.T(), existingIndexes, "No indexes should have been created when pg_trgm is missing")
431+
assert.Equal(ts.T(), len(indexes), len(existingIndexes), "All indexes should have been created")
432+
433+
for _, idx := range existingIndexes {
434+
assert.True(ts.T(), idx.IsValid, "Index %s should be valid", idx.IndexName)
435+
assert.True(ts.T(), idx.IsReady, "Index %s should be ready", idx.IndexName)
436+
}
427437

428438
// Restore pg_trgm extension for other tests
429439
createExtQuery := "CREATE EXTENSION IF NOT EXISTS pg_trgm"

0 commit comments

Comments
 (0)