Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions persistent-postgresql/Database/Persist/Postgresql/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ data AlterColumn
--
-- @since 2.17.1.0
data AlterTable
= AddUniqueConstraint ConstraintNameDB [FieldNameDB]
= AddUniqueConstraint ConstraintNameDB [FieldNameDB] [Attr]
| DropConstraint ConstraintNameDB
deriving (Show, Eq)

Expand Down Expand Up @@ -430,8 +430,8 @@ migrateStructured allDefs getter entity = do
createText newcols fdefs_ udspair =
(addTable newcols entity) : uniques ++ references ++ foreignsAlt
where
uniques = flip concatMap udspair $ \(uname, ucols) ->
[AlterTable name $ AddUniqueConstraint uname ucols]
uniques = flip concatMap udspair $ \(uname, ucols, uattrs) ->
[AlterTable name $ AddUniqueConstraint uname ucols uattrs]
references =
mapMaybe
( \Column{cName, cReference} ->
Expand Down Expand Up @@ -464,8 +464,8 @@ mockMigrateStructured allDefs entity = migrationText
createText newcols fdefs udspair =
(addTable newcols entity) : uniques ++ references ++ foreignsAlt
where
uniques = flip concatMap udspair $ \(uname, ucols) ->
[AlterTable name $ AddUniqueConstraint uname ucols]
uniques = flip concatMap udspair $ \(uname, ucols, uattrs) ->
[AlterTable name $ AddUniqueConstraint uname ucols uattrs]
references =
mapMaybe
( \Column{cName, cReference} ->
Expand Down Expand Up @@ -508,7 +508,7 @@ mayDefault def = case def of
getAlters
:: [EntityDef]
-> EntityDef
-> ([Column], [(ConstraintNameDB, [FieldNameDB])])
-> ([Column], [(ConstraintNameDB, [FieldNameDB], [Attr])])
-> ([Column], [(ConstraintNameDB, [FieldNameDB])])
-> ([AlterColumn], [AlterTable])
getAlters defs def (c1, u1) (c2, u2) =
Expand All @@ -523,15 +523,15 @@ getAlters defs def (c1, u1) (c2, u2) =
alters ++ getAltersC news old'

getAltersU
:: [(ConstraintNameDB, [FieldNameDB])]
:: [(ConstraintNameDB, [FieldNameDB], [Attr])]
-> [(ConstraintNameDB, [FieldNameDB])]
-> [AlterTable]
getAltersU [] old =
map DropConstraint $ filter (not . isManual) $ map fst old
getAltersU ((name, cols) : news) old =
getAltersU ((name, cols, attrs) : news) old =
case lookup name old of
Nothing ->
AddUniqueConstraint name cols : getAltersU news old
AddUniqueConstraint name cols attrs : getAltersU news old
Just ocols ->
let
old' = filter (\(x, _) -> x /= name) old
Expand All @@ -540,7 +540,7 @@ getAlters defs def (c1, u1) (c2, u2) =
then getAltersU news old'
else
DropConstraint name
: AddUniqueConstraint name cols
: AddUniqueConstraint name cols attrs
: getAltersU news old'

-- Don't drop constraints which were manually added.
Expand Down Expand Up @@ -632,8 +632,8 @@ safeToRemove def (FieldNameDB colName) =
_ ->
[]

udToPair :: UniqueDef -> (ConstraintNameDB, [FieldNameDB])
udToPair ud = (uniqueDBName ud, map snd $ NEL.toList $ uniqueFields ud)
udToPair :: UniqueDef -> (ConstraintNameDB, [FieldNameDB], [Attr])
udToPair ud = (uniqueDBName ud, map snd $ NEL.toList $ uniqueFields ud, uniqueAttrs ud)

-- | Get the references to be added to a table for the given column.
getAddReference
Expand Down Expand Up @@ -739,13 +739,15 @@ showAlterDb (AlterColumn t ac) =
showAlterDb (AlterTable t at) = (False, showAlterTable t at)

showAlterTable :: EntityNameDB -> AlterTable -> Text
showAlterTable table (AddUniqueConstraint cname cols) =
showAlterTable table (AddUniqueConstraint cname cols attrs) =
T.concat
[ "ALTER TABLE "
, escapeE table
, " ADD CONSTRAINT "
, escapeC cname
, " UNIQUE("
, " UNIQUE"
, if "!nullsNotDistinct" `elem` attrs then " NULLS NOT DISTINCT" else ""
, "("
, T.intercalate "," $ map escapeF cols
, ")"
]
Expand Down
1 change: 1 addition & 0 deletions persistent-postgresql/persistent-postgresql.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ test-suite test
ImplicitUuidSpec
JSONTest
MigrationReferenceSpec
NullsNotDistinctTest
PgInit
PgIntervalTest
UpsertWhere
Expand Down
274 changes: 274 additions & 0 deletions persistent-postgresql/test/NullsNotDistinctTest.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module NullsNotDistinctTest where

import Control.Exception (SomeException, try)
import Control.Monad (unless, void, when)
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Trans.Reader (ReaderT)
import Data.Text (Text)
import qualified Data.Text as T
import Database.Persist
import Database.Persist.Postgresql
import Database.Persist.Postgresql.Internal
import Database.Persist.TH
import qualified Test.Hspec as Hspec
import qualified Test.Hspec.Expectations.Lifted as Lifted

import PgInit

-- Test entities with and without NULLS NOT DISTINCT
share
[mkPersist sqlSettings, mkMigrate "nullsNotDistinctMigrate"]
[persistLowerCase|
-- Standard unique constraint (allows multiple NULLs)
StandardUnique
name Text
email Text Maybe
UniqueStandardEmail name email !force
deriving Eq Show

-- Unique constraint with NULLS NOT DISTINCT (PostgreSQL 15+)
-- This should prevent multiple NULLs
NullsNotDistinctUnique
name Text
email Text Maybe
UniqueNNDEmail name email !nullsNotDistinct
deriving Eq Show

-- Multiple nullable fields with NULLS NOT DISTINCT
MultiFieldNND
fieldA Text
fieldB Text Maybe
fieldC Int Maybe
UniqueMultiNND fieldA fieldB fieldC !nullsNotDistinct
deriving Eq Show
|]

-- Helper to check PostgreSQL version
getPostgresVersion :: (MonadIO m) => ReaderT SqlBackend m (Maybe Int)
getPostgresVersion = do
result <- rawSql "SELECT current_setting('server_version_num')::integer" []
case result of
[Single version] -> return $ Just version
_ -> return Nothing

isPostgres15OrHigher :: (MonadIO m) => ReaderT SqlBackend m Bool
isPostgres15OrHigher = do
mVersion <- getPostgresVersion
case mVersion of
Just version -> return $ version >= 150000 -- PostgreSQL 15.0
Nothing -> return False

cleanDB
:: (BaseBackend backend ~ SqlBackend, PersistQueryWrite backend, MonadIO m)
=> ReaderT backend m ()
cleanDB = do
deleteWhere ([] :: [Filter StandardUnique])
deleteWhere ([] :: [Filter NullsNotDistinctUnique])
deleteWhere ([] :: [Filter MultiFieldNND])

specs :: Spec
specs = describe "NULLS NOT DISTINCT support" $ do
let
runDb = runConnAssert

it "generates correct SQL for NULLS NOT DISTINCT constraint" $ do
let
alterWithNND =
AddUniqueConstraint
(ConstraintNameDB "unique_nnd_email")
[FieldNameDB "name", FieldNameDB "email"]
["!nullsNotDistinct"]

let
alterWithoutNND =
AddUniqueConstraint
(ConstraintNameDB "unique_standard_email")
[FieldNameDB "name", FieldNameDB "email"]
["!force"]

let
tableName = EntityNameDB "test_table"
let
sqlWithNND = showAlterTable tableName alterWithNND
let
sqlWithoutNND = showAlterTable tableName alterWithoutNND

sqlWithNND
`Hspec.shouldBe` "ALTER TABLE \"test_table\" ADD CONSTRAINT \"unique_nnd_email\" UNIQUE NULLS NOT DISTINCT(\"name\",\"email\")"

sqlWithoutNND
`Hspec.shouldBe` "ALTER TABLE \"test_table\" ADD CONSTRAINT \"unique_standard_email\" UNIQUE(\"name\",\"email\")"

describe "runtime behavior" $ do
it "standard unique allows multiple NULLs" $ do
runDb $ do
cleanDB

-- These should both succeed with standard unique
k1 <- insert $ StandardUnique "user1" Nothing
k2 <- insert $ StandardUnique "user2" Nothing

-- Verify both were inserted
count1 <- count [StandardUniqueName ==. "user1"]
count2 <- count [StandardUniqueName ==. "user2"]

liftIO $ do
count1 `Lifted.shouldBe` 1
count2 `Lifted.shouldBe` 1

it "standard unique prevents duplicate non-NULLs" $ do
( runDb $ do
cleanDB
_ <- insert $ StandardUnique "user1" (Just "[email protected]")
_ <- insert $ StandardUnique "user1" (Just "[email protected]")
return ()
)
`Hspec.shouldThrow` Hspec.anyException

it
"standard unique getBy returns Nothing for NULL values (backwards compatibility)"
$ do
runDb $ do
cleanDB

-- Insert a record with NULL email
_ <- insert $ StandardUnique "user1" Nothing

-- getBy with NULL should return Nothing (standard SQL behavior)
-- This ensures backwards compatibility - without !nullsNotDistinct,
-- getBy cannot find NULL values
result <- getBy $ UniqueStandardEmail "user1" Nothing

liftIO $ result `Lifted.shouldBe` Nothing

-- Verify that getBy still works for non-NULL values
k2 <- insert $ StandardUnique "user2" (Just "[email protected]")
result2 <- getBy $ UniqueStandardEmail "user2" (Just "[email protected]")

liftIO $ case result2 of
Just (Entity key _) -> key `Lifted.shouldBe` k2
Nothing -> Hspec.expectationFailure "getBy should find non-NULL values"

describe "PostgreSQL 15+ features" $ do
it "NULLS NOT DISTINCT prevents multiple NULLs (PostgreSQL 15+)" $ do
runDb $ do
supported <- isPostgres15OrHigher
when supported $ do
-- Run the migration to ensure constraint is created
void $ runMigrationSilent nullsNotDistinctMigrate
unless supported $
liftIO $
Hspec.pendingWith "Requires PostgreSQL 15 or higher"

-- Now test the constraint enforcement separately
( runDb $ do
cleanDB
void $ runMigrationSilent nullsNotDistinctMigrate
_ <- insert $ NullsNotDistinctUnique "user1" Nothing
-- Same name and email - this should violate the unique constraint
_ <- insert $ NullsNotDistinctUnique "user1" Nothing
return ()
)
`Hspec.shouldThrow` Hspec.anyException

it "NULLS NOT DISTINCT with multiple nullable fields (PostgreSQL 15+)" $ do
-- First test that different NULL patterns work
runDb $ do
supported <- isPostgres15OrHigher
if supported
then do
cleanDB

-- First record with NULLs
_ <- insert $ MultiFieldNND "test1" Nothing Nothing

-- Different NULL pattern should succeed
_ <- insert $ MultiFieldNND "test1" (Just "value") Nothing
_ <- insert $ MultiFieldNND "test1" Nothing (Just 42)

count' <- count ([] :: [Filter MultiFieldNND])
liftIO $ count' `Hspec.shouldBe` 3
else
liftIO $ Hspec.pendingWith "Requires PostgreSQL 15 or higher"

-- Test duplicate prevention with same NULL pattern
( runDb $ do
supported <- isPostgres15OrHigher
when supported $ do
cleanDB
_ <- insert $ MultiFieldNND "test1" Nothing Nothing
_ <- insert $ MultiFieldNND "test1" Nothing Nothing
return ()
)
`Hspec.shouldThrow` Hspec.anyException

it "getBy finds NULL values with NULLS NOT DISTINCT (PostgreSQL 15+)" $ do
runDb $ do
supported <- isPostgres15OrHigher
if supported
then do
cleanDB
void $ runMigrationSilent nullsNotDistinctMigrate

-- Insert with NULL
k1 <- insert $ NullsNotDistinctUnique "user1" Nothing

-- With our runtime detection, getBy now uses IS NOT DISTINCT FROM
-- for entities with !nullsNotDistinct, allowing it to find NULL values
result <- getBy $ UniqueNNDEmail "user1" Nothing

-- We expect getBy TO find the entity with NULLS NOT DISTINCT
liftIO $ case result of
Just (Entity key _) -> key `Hspec.shouldBe` k1
Nothing ->
Hspec.expectationFailure
"getBy should find NULL values when !nullsNotDistinct is set"
else
liftIO $ Hspec.pendingWith "Requires PostgreSQL 15 or higher"

it "migration generates correct constraints" $ do
runDb $ do
-- Run migration to create tables
void $ runMigrationSilent nullsNotDistinctMigrate

-- Check that constraints were created
-- This query checks PostgreSQL's information schema
constraints :: [(Single Text, Single Text)] <-
rawSql
"SELECT conname, pg_get_constraintdef(oid) \
\FROM pg_constraint \
\WHERE conrelid = 'nulls_not_distinct_unique'::regclass \
\ AND contype = 'u'"
[]

supported <- isPostgres15OrHigher
liftIO $ case constraints of
[] -> return () -- Tables might not exist yet
results -> do
-- Check if any constraint has NULLS NOT DISTINCT
let
hasNND =
any
( \(Single _, Single def) ->
"NULLS NOT DISTINCT" `T.isInfixOf` def
)
results

when supported $
hasNND `Hspec.shouldBe` True
Loading