diff --git a/changelog.d/20251030_125129_shane.obrien_OnConflict.md b/changelog.d/20251030_125129_shane.obrien_OnConflict.md new file mode 100644 index 00000000..eaca44d1 --- /dev/null +++ b/changelog.d/20251030_125129_shane.obrien_OnConflict.md @@ -0,0 +1,9 @@ +### Added + +- Added new `Conflict` and `Index` types. `Conflict` represents a [`conflict_target`](https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT) in an `ON CONFLICT`. It can be either a named constraint (`ON CONSTRAINT`) or a an `Index`. +- Added `Index`. `Index` is a description of a unique index which PostgreSQL can use for *unique index inference*. This is an alternative to specifying an explicit named constraint in a `conflict_target`. + +### Changed + +- The `Upsert` type was changed. Previously it had the columns (`index`, `predicate`) of what is now the `Index` type baked into its record. It now instead has a single `conflict` column (of type `Conflict`, which can be either an `Index` or a named constraint). +- The `DoNothing` constructor of `OnConflict` was changed to also take an optional `Conflict` value. Even though `ON CONFLICT DO NOTHING` does not generally require a `conflict_target`, there are cases where it can be necessary, e.g., if you have table that has both deferrable and non-deferrable constraints. diff --git a/src/Rel8.hs b/src/Rel8.hs index 56cb1d29..6616d138 100644 --- a/src/Rel8.hs +++ b/src/Rel8.hs @@ -360,6 +360,8 @@ module Rel8 -- ** @INSERT@ , Insert(..) , OnConflict(..) + , Conflict (..) + , Index (..) , Upsert(..) , insert , unsafeDefault diff --git a/src/Rel8/Statement/Insert.hs b/src/Rel8/Statement/Insert.hs index 120fe721..3e6ef14f 100644 --- a/src/Rel8/Statement/Insert.hs +++ b/src/Rel8/Statement/Insert.hs @@ -50,7 +50,7 @@ data Insert a where , rows :: Query exprs -- ^ The rows to insert. This can be an arbitrary query — use -- 'Rel8.values' insert a static list of rows. - , onConflict :: OnConflict names + , onConflict :: OnConflict exprs -- ^ What to do if the inserted rows conflict with data already in the -- table. , returning :: Returning names a diff --git a/src/Rel8/Statement/OnConflict.hs b/src/Rel8/Statement/OnConflict.hs index 4731b6d5..b5d1132e 100644 --- a/src/Rel8/Statement/OnConflict.hs +++ b/src/Rel8/Statement/OnConflict.hs @@ -11,13 +11,14 @@ module Rel8.Statement.OnConflict ( OnConflict(..) + , Conflict (..) + , Index (..) , Upsert(..) , ppOnConflict ) where -- base -import Data.Foldable ( toList ) import Data.Kind ( Type ) import Prelude @@ -31,28 +32,32 @@ import Text.PrettyPrint ( Doc, (<+>), ($$), parens, text ) -- rel8 import Rel8.Expr ( Expr ) import Rel8.Expr.Opaleye (toPrimExpr) -import Rel8.Schema.Name ( Name, Selects, ppColumn ) +import Rel8.Schema.Escape (escape) +import Rel8.Schema.Name ( Selects ) +import Rel8.Schema.HTable (hfoldMap) import Rel8.Schema.Table ( TableSchema(..) ) import Rel8.Statement.Set ( ppSet ) import Rel8.Statement.Where ( ppWhere ) import Rel8.Table ( Table, toColumns ) -import Rel8.Table.Cols ( Cols( Cols ) ) -import Rel8.Table.Name ( showNames ) import Rel8.Table.Opaleye (attributes, view) -import Rel8.Table.Projection ( Projecting, Projection, apply ) -- | 'OnConflict' represents the @ON CONFLICT@ clause of an @INSERT@ -- statement. This specifies what ought to happen when one or more of the -- rows proposed for insertion conflict with an existing row in the table. type OnConflict :: Type -> Type -data OnConflict names +data OnConflict exprs = Abort -- ^ Abort the transaction if there are conflicting rows (Postgres' default) - | DoNothing - -- ^ @ON CONFLICT DO NOTHING@ - | DoUpdate (Upsert names) - -- ^ @ON CONFLICT DO UPDATE@ + | DoNothing (Maybe (Conflict exprs)) + -- ^ @ON CONFLICT DO NOTHING@, or @ON CONFLICT (...) DO NOTHING@ if an + -- explicit conflict target is supplied. Specifying a conflict target is + -- essential when your table has has deferrable constraints — @ON + -- CONFLICT@ can't work on deferrable constraints, so it's necessary + -- to explicitly name one of its non-deferrable constraints in order to + -- use @ON CONFLICT@. + | DoUpdate (Upsert exprs) + -- ^ @ON CONFLICT (...) DO UPDATE ...@ -- | The @ON CONFLICT (...) DO UPDATE@ clause of an @INSERT@ statement, also @@ -69,35 +74,75 @@ data OnConflict names -- are specified by listing the columns that comprise them along with an -- optional predicate in the case of partial indexes. type Upsert :: Type -> Type -data Upsert names where - Upsert :: (Selects names exprs, Projecting names index, excluded ~ exprs) => - { index :: Projection names index - -- ^ The set of columns comprising the @UNIQUE@ index that forms our - -- conflict target, projected from the set of columns for the whole - -- table - , predicate :: Maybe (exprs -> Expr Bool) - -- ^ An optional predicate used to specify a - -- [partial index](https://www.postgresql.org/docs/current/indexes-partial.html). +data Upsert exprs where + Upsert :: excluded ~ exprs => + { conflict :: Conflict exprs + -- ^ The conflict target to supply to @DO UPDATE@. , set :: excluded -> exprs -> exprs -- ^ How to update each selected row. , updateWhere :: excluded -> exprs -> Expr Bool -- ^ Which rows to select for update. } - -> Upsert names + -> Upsert exprs + + +-- | Represents what PostgreSQL calls a +-- [@conflict_target@](https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT) +-- in an @ON CONFLICT@ clause of an @INSERT@ statement. +type Conflict :: Type -> Type +data Conflict exprs + = OnConstraint String + -- ^ Use a specific named constraint for the conflict target. This + -- corresponds the the syntax @ON CONFLICT constraint@ in PostgreSQL. + | OnIndex (Index exprs) + -- ^ Have PostgreSQL perform what it calls _unique index inference_ by + -- giving it a description of the target index. + + +-- | A description of the target unique index — its columns (and/or +-- expressions) and, in the case of partial indexes, a predicate. +type Index :: Type -> Type +data Index exprs where + Index :: Table Expr index => + { columns :: exprs -> index + -- ^ The set of columns and/or expressions comprising the @UNIQUE@ index + , predicate :: Maybe (exprs -> Expr Bool) + -- ^ An optional predicate used to specify a + -- [partial index](https://www.postgresql.org/docs/current/indexes-partial.html). + } + -> Index exprs -ppOnConflict :: TableSchema names -> OnConflict names -> Doc -ppOnConflict schema = \case +ppOnConflict :: Selects names exprs => TableSchema names -> OnConflict exprs -> Doc +ppOnConflict schema@TableSchema {columns} = \case Abort -> mempty - DoNothing -> text "ON CONFLICT DO NOTHING" - DoUpdate upsert -> ppUpsert schema upsert + DoNothing conflict -> text "ON CONFLICT" <+> foldMap (ppConflict row) conflict <+> text "DO NOTHING" + DoUpdate upsert -> ppUpsert schema row upsert + where + row = view columns + +ppConflict :: exprs -> Conflict exprs -> Doc +ppConflict row = \case + OnConstraint name -> "ON CONSTRAINT" <+> escape name + OnIndex index -> ppIndex row index -ppUpsert :: TableSchema names -> Upsert names -> Doc -ppUpsert schema@TableSchema {columns} Upsert {..} = - text "ON CONFLICT" <+> - ppIndex columns index <+> foldMap (ppPredicate columns) predicate <+> - text "DO UPDATE" $$ + +ppIndex :: exprs -> Index exprs -> Doc +ppIndex row Index {columns, predicate} = + parens (Opaleye.commaH id exprs) <> + foldMap (ppPredicate . ($ row)) predicate + where + exprs = hfoldMap (pure . parens . ppExpr) $ toColumns $ columns row + + +ppPredicate :: Expr Bool -> Doc +ppPredicate condition = text "WHERE" <+> ppExpr condition + + +ppUpsert :: Selects names exprs => TableSchema names -> exprs -> Upsert exprs -> Doc +ppUpsert schema@TableSchema {columns} row Upsert {..} = + text "ON CONFLICT" <+> ppConflict row conflict <+> "DO UPDATE" $$ ppSet schema (set excluded) $$ ppWhere schema (updateWhere excluded) where @@ -107,16 +152,5 @@ ppUpsert schema@TableSchema {columns} Upsert {..} = } -ppIndex :: (Table Name names, Projecting names index) - => names -> Projection names index -> Doc -ppIndex columns index = - parens $ Opaleye.commaV ppColumn $ toList $ - showNames $ Cols $ apply index $ toColumns columns - - -ppPredicate :: Selects names exprs - => names -> (exprs -> Expr Bool) -> Doc -ppPredicate schema where_ = text "WHERE" <+> ppExpr condition - where - ppExpr = Opaleye.ppSqlExpr . Opaleye.sqlExpr . toPrimExpr - condition = where_ (view schema) +ppExpr :: Expr a -> Doc +ppExpr = Opaleye.ppSqlExpr . Opaleye.sqlExpr . toPrimExpr diff --git a/src/Rel8/Table/Verify.hs b/src/Rel8/Table/Verify.hs index b1ceacdc..306c3b65 100644 --- a/src/Rel8/Table/Verify.hs +++ b/src/Rel8/Table/Verify.hs @@ -1,35 +1,34 @@ - {-# language BlockArguments #-} -{-# language LambdaCase #-} -{-# language RecordWildCards #-} -{-# language RankNTypes #-} -{-# language DuplicateRecordFields #-} -{-# language DerivingStrategies #-} -{-# language OverloadedRecordDot #-} -{-# language TypeApplications #-} -{-# language NamedFieldPuns #-} -{-# language ScopedTypeVariables #-} -{-# language StandaloneDeriving #-} {-# language DeriveAnyClass #-} +{-# language DeriveGeneric #-} +{-# language DerivingStrategies #-} +{-# language DuplicateRecordFields #-} {-# language FlexibleContexts #-} {-# language FlexibleInstances #-} -{-# language DeriveGeneric #-} +{-# language GADTs #-} {-# language GeneralizedNewtypeDeriving #-} +{-# language LambdaCase #-} +{-# language NamedFieldPuns #-} +{-# language OverloadedRecordDot #-} {-# language OverloadedStrings #-} -{-# language GADTs #-} +{-# language RankNTypes #-} +{-# language RecordWildCards #-} +{-# language ScopedTypeVariables #-} +{-# language StandaloneDeriving #-} +{-# language TypeApplications #-} +{-# options_ghc -Wno-partial-fields #-} module Rel8.Table.Verify - ( getSchemaErrors - , SomeTableSchema(..) - , showCreateTable - , checkedShowCreateTable - ) where + ( getSchemaErrors + , SomeTableSchema(..) + , showCreateTable + , checkedShowCreateTable + ) +where -- base -import Control.Monad import Data.Bits (shiftR, (.&.)) -import Data.Either (lefts) -import Data.Function +import Data.Function ((&)) import Data.Functor ((<&>)) import Data.Functor.Const import Data.Functor.Contravariant ( (>$<) ) @@ -48,32 +47,45 @@ import qualified Prelude as P import qualified Data.Map as M -- hasql -import Hasql.Connection import qualified Hasql.Statement as HS -- rel8 -import Rel8 -- not importing this seems to cause a type error??? import Rel8.Column ( Column ) import Rel8.Column.List ( HList ) import Rel8.Expr ( Expr ) +import Rel8.Expr.Eq ((==.)) +import Rel8.Expr.Ord ((>.)) +import Rel8.Expr.Order (asc) import Rel8.Generic.Rel8able (GFromExprs, Rel8able) import Rel8.Query ( Query ) +import Rel8.Query.Each (each) +import Rel8.Query.Filter (filter) +import Rel8.Query.List (many) +import Rel8.Query.Order (orderBy) import Rel8.Schema.HTable import Rel8.Schema.Name ( Name(Name) ) import Rel8.Schema.Null hiding (nullable) -import qualified Rel8.Schema.Null as Null -import qualified Rel8.Statement.Run as RSR -import Rel8.Schema.Table ( TableSchema(..) ) -import Rel8.Schema.Spec -import Rel8.Schema.Result ( Result ) import Rel8.Schema.QualifiedName ( QualifiedName(..) ) -import Rel8.Table ( Columns ) +import Rel8.Schema.Result ( Result ) +import Rel8.Schema.Spec (Spec (Spec)) +import qualified Rel8.Schema.Spec +import Rel8.Schema.Table ( TableSchema(..) ) +import Rel8.Statement.Run (run1) +import Rel8.Statement.Select (select) +import Rel8.Table (Columns, toColumns) import Rel8.Table.List ( ListTable ) -import Rel8.Table.Serialize ( ToExprs ) +import Rel8.Table.Name (namesFromLabelsWith) +import Rel8.Table.Rel8able () +import Rel8.Table.Serialize (ToExprs, lit) import Rel8.Type ( DBType(..) ) import Rel8.Type.Eq ( DBEq ) +import Rel8.Type.Information (parseTypeInformation) +import qualified Rel8.Type.Information import Rel8.Type.Name ( TypeName(..) ) +-- semialign +import Data.Semialign (align) + -- these import Data.These @@ -338,7 +350,7 @@ showCreateTable_helper name typeMap = "CREATE TABLE " <> show name <> " (" ++ "\n);" where go :: (String, TypeInfo) -> String - go (name, typeInfo) = "\n " ++ show name ++ " " ++ showTypeInfo typeInfo + go (name', typeInfo) = "\n " ++ show name' ++ " " ++ showTypeInfo typeInfo -- |@'showCreateTable'@ shows an example CREATE TABLE statement for the table. @@ -378,17 +390,9 @@ checkTypeEquality env db hs sameMods = db.typeName.modifiers == hs.typeName.modifiers sameDims = db.typeName.arrayDepth == hs.typeName.arrayDepth - sameName = equalName db.typeName.name hs.typeName.name - toName :: TypeInfo -> String toName typeInfo = case typeInfo.typeName.name of - QualifiedName name _ -> L.dropWhile (=='_') name - -equalName :: QualifiedName -> QualifiedName -> Bool -equalName (QualifiedName a (Just b)) (QualifiedName a' (Just b')) - = L.dropWhile (=='_') a == L.dropWhile (=='_') a' && b == b' -equalName (QualifiedName a _) (QualifiedName a' _) - = dropWhile (=='_') a == dropWhile (=='_') a' + QualifiedName name _ -> L.dropWhile (== '_') name -- check types for a single table compareTypes @@ -430,7 +434,7 @@ compareTypes env attrMap typeMap = fmap (uncurry go) $ M.assocs (disjointUnion a (T.unpack attr.typ.typname) (Just $ T.unpack attr.namespace.nspname) , modifiers = toModifier - (T.dropWhile (=='_') attr.typ.typname) + (T.dropWhile (== '_') attr.typ.typname) attr.attribute.atttypmod , arrayDepth = fromIntegral attr.attribute.attndims } @@ -444,14 +448,10 @@ compareTypes env attrMap typeMap = fmap (uncurry go) $ M.assocs (disjointUnion a toModifier _ _ = [] disjointUnion :: Ord k => M.Map k a -> M.Map k b -> M.Map k (These a b) - disjointUnion a b = M.unionWith go (fmap This a) (fmap That b) - where - go :: These a b -> These a b -> These a b - go (This a) (That b) = These a b - go _ _ = undefined + disjointUnion = align --- |@pShowTable@ is a helper function which takes a grid of text and prints it +-- |@pShowTable@ i's a helper f'unction which takes a grid of text and prints' it' -- as a table, with padding so that cells are lined in columns, and a bordered -- header for the first row pShowTable :: [[Text]] -> Text @@ -464,7 +464,7 @@ pShowTable xs where addHeaderBorder :: [Text] -> [Text] addHeaderBorder [] = [] - addHeaderBorder (x : xs) = x : T.replicate (T.length x) "-" : xs + addHeaderBorder (a : as) = a : T.replicate (T.length a) "-" : as xs' :: [[Text]] xs' = L.transpose xs @@ -489,8 +489,8 @@ pShowErrors = T.intercalate "\n\n" . fmap go [ "Table " , T.pack (show name) , " has multiple columns with the same name. This is an error with the Haskell code generating an impossible schema, rather than an error in your current setup of the database itself. Using 'namesFromLabels' can ensure each column has unique names, which is the easiest way to prevent this, but may require changing names in your database to match the new generated names." - , pShowTable (["DB name", "Haskell label"] : (M.assocs duplicates <&> \(name, typs) -> - [ T.pack name + , pShowTable (["DB name", "Haskell label"] : (M.assocs duplicates <&> \(name', typs) -> + [ T.pack name' , T.intercalate " " $ fmap (\typ -> T.intercalate "/" $ fmap T.pack typ.label) $ NonEmpty.toList typs ])) ] @@ -531,8 +531,8 @@ showTypeInfo typeInfo = concat ] where name = case typeInfo.typeName.name of - QualifiedName a Nothing -> show (dropWhile (=='_') a) - QualifiedName a (Just b) -> show b <> "." <> show (dropWhile (=='_') a) + QualifiedName a Nothing -> show (dropWhile (== '_') a) + QualifiedName a (Just b) -> show b <> "." <> show (dropWhile (== '_') a) modifiers :: [String] modifiers = typeInfo.typeName.modifiers diff --git a/tests/Main.hs b/tests/Main.hs index e7384715..8ba5cd58 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -277,7 +277,7 @@ testShowCreateTable getTestDatabase = testGroup "CREATE TABLE" statement () $ Rel8.run_ $ Rel8.insert Rel8.Insert { into = tableSchema , rows = Rel8.values $ map Rel8.lit rows - , onConflict = Rel8.DoNothing + , onConflict = Rel8.DoNothing Nothing , returning = Rel8.NoReturning } statement () $ Rel8.run $ Rel8.select do @@ -335,7 +335,7 @@ testSelectTestTable = databasePropertyTest "Can SELECT TestTable" \transaction - statement () $ Rel8.run_ $ Rel8.insert Rel8.Insert { into = testTableSchema , rows = Rel8.values $ map Rel8.lit rows - , onConflict = Rel8.DoNothing + , onConflict = Rel8.DoNothing Nothing , returning = Rel8.NoReturning } @@ -945,7 +945,7 @@ testUpdate = databasePropertyTest "Can UPDATE TestTable" \transaction -> do statement () $ Rel8.run_ $ Rel8.insert Rel8.Insert { into = testTableSchema , rows = Rel8.values $ map Rel8.lit $ Map.keys rows - , onConflict = Rel8.DoNothing + , onConflict = Rel8.DoNothing Nothing , returning = Rel8.NoReturning } @@ -989,7 +989,7 @@ testDelete = databasePropertyTest "Can DELETE TestTable" \transaction -> do statement () $ Rel8.run_ $ Rel8.insert Rel8.Insert { into = testTableSchema , rows = Rel8.values $ map Rel8.lit rows - , onConflict = Rel8.DoNothing + , onConflict = Rel8.DoNothing Nothing , returning = Rel8.NoReturning } @@ -1029,7 +1029,7 @@ testWithStatement genTestDatabase = inserted <- Rel8.insert $ Rel8.Insert { into = testTableSchema , rows = values - , onConflict = Rel8.DoNothing + , onConflict = Rel8.DoNothing Nothing , returning = Rel8.Returning id } @@ -1047,7 +1047,7 @@ testWithStatement genTestDatabase = Rel8.insert $ Rel8.Insert { into = testTableSchema , rows = Rel8.values $ map Rel8.lit rows - , onConflict = Rel8.DoNothing + , onConflict = Rel8.DoNothing Nothing , returning = Rel8.NoReturning } @@ -1063,7 +1063,7 @@ testWithStatement genTestDatabase = Rel8.insert $ Rel8.Insert { into = testTableSchema , rows = Rel8.values $ map Rel8.lit rows - , onConflict = Rel8.DoNothing + , onConflict = Rel8.DoNothing Nothing , returning = Rel8.Returning id } @@ -1123,7 +1123,7 @@ testUpsert = databasePropertyTest "Can UPSERT UniqueTable" \transaction -> do statement () $ Rel8.run_ $ Rel8.insert Rel8.Insert { into = uniqueTableSchema , rows = Rel8.values $ Rel8.lit <$> as - , onConflict = Rel8.DoNothing + , onConflict = Rel8.DoNothing Nothing , returning = Rel8.NoReturning } @@ -1131,8 +1131,12 @@ testUpsert = databasePropertyTest "Can UPSERT UniqueTable" \transaction -> do { into = uniqueTableSchema , rows = Rel8.values $ Rel8.lit <$> bs , onConflict = Rel8.DoUpdate Rel8.Upsert - { index = uniqueTableKey - , predicate = Nothing + { conflict = + Rel8.OnIndex + Rel8.Index + { columns = uniqueTableKey + , predicate = Nothing + } , set = \UniqueTable {uniqueTableValue} old -> old {uniqueTableValue} , updateWhere = \_ _ -> Rel8.true } diff --git a/tests/Rel8/Generic/Rel8able/Test.hs b/tests/Rel8/Generic/Rel8able/Test.hs index c5780150..601c9215 100644 --- a/tests/Rel8/Generic/Rel8able/Test.hs +++ b/tests/Rel8/Generic/Rel8able/Test.hs @@ -24,7 +24,6 @@ where -- aeson import Data.Aeson ( Value(..) ) -import qualified Data.Aeson as Aeson import qualified Data.Aeson.KeyMap as Aeson -- base @@ -53,7 +52,29 @@ import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range -- rel8 -import Rel8 +import Rel8 ( + Column, + DBType, + Expr, + HADT, + HEither, + HKD, + HList, + HMaybe, + HNonEmpty, + HThese, + KRel8able, + Lift, + Name, + QualifiedName, + Rel8able, + Result, + TableSchema (TableSchema), + ToExprs, + namesFromLabels, + namesFromLabelsWith, + ) +import qualified Rel8 -- scientific import Data.Scientific ( Scientific, fromFloatDigits ) @@ -79,7 +100,6 @@ import Data.UUID ( UUID ) import qualified Data.UUID as UUID -- vector -import Data.Vector ( Vector ) import qualified Data.Vector as Vector @@ -477,7 +497,7 @@ genTableType = do int64 <- Gen.int64 range float <- Gen.float linearFrac double <- Gen.double linearFrac - scientific <- fromFloatDigits <$> Gen.realFloat linearFrac + scientific <- fromFloatDigits @Double <$> Gen.realFloat linearFrac utctime <- UTCTime <$> (toEnum <$> Gen.integral range) <*> fmap secondsToDiffTime (Gen.integral range) day <- toEnum <$> Gen.integral range localtime <- LocalTime <$> (toEnum <$> Gen.integral range) <*> timeOfDay @@ -494,7 +514,7 @@ genTableType = do [ Object <$> Aeson.fromMapText <$> Map.fromList <$> Gen.list range (liftA2 (,) (Gen.text range Gen.alpha) (pure Null)) , Array <$> Vector.fromList <$> Gen.list range (pure Null) , String <$> Gen.text range Gen.alpha - , Number <$> fromFloatDigits <$> Gen.realFloat linearFrac + , Number <$> fromFloatDigits @Double <$> Gen.realFloat linearFrac , Bool <$> Gen.bool , pure Null ]