Skip to content

Commit c73180a

Browse files
committed
WIP: use phantom types to restrict roles
1 parent 695b4ca commit c73180a

File tree

7 files changed

+87
-56
lines changed

7 files changed

+87
-56
lines changed

src/main/scala/com/simple/jdub/Database.scala

+34-23
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,21 @@ import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
44

55
import java.io.FileInputStream
66
import java.security.KeyStore
7-
import java.util.{UUID, Properties}
7+
import java.util.{Properties, UUID}
88
import javax.sql.DataSource
9-
109
import com.codahale.metrics.MetricRegistry
1110
import com.codahale.metrics.health.HealthCheckRegistry
11+
import com.simple.jdub.Database.Primary
12+
import com.simple.jdub.Database.Replica
13+
14+
import scala.annotation.implicitNotFound
1215

1316
object Database {
1417

18+
sealed trait Role
19+
final class Primary extends Role
20+
final class Replica extends Role
21+
1522
/**
1623
* Create a pool of connections to the given database.
1724
*
@@ -20,7 +27,7 @@ object Database {
2027
* @param password the database password
2128
* @param sslSettings if present, uses the given SSL settings for a client-side SSL cert.
2229
*/
23-
def connect(url: String,
30+
def connect[R <: Role](url: String,
2431
username: String,
2532
password: String,
2633
name: Option[String] = None,
@@ -30,7 +37,7 @@ object Database {
3037
sslSettings: Option[SslSettings] = None,
3138
healthCheckRegistry: Option[HealthCheckRegistry] = None,
3239
metricRegistry: Option[MetricRegistry] = None,
33-
connectionInitSql: Option[String] = None): Database = {
40+
connectionInitSql: Option[String] = None): Database[R] = {
3441

3542
val properties = new Properties
3643

@@ -95,8 +102,8 @@ object Database {
95102
/**
96103
* A set of pooled connections to a database.
97104
*/
98-
class Database protected(val source: DataSource, metrics: Option[MetricRegistry])
99-
extends Queryable {
105+
class Database[R <: Database.Role] protected(val source: DataSource, metrics: Option[MetricRegistry])
106+
extends Queryable[R] {
100107

101108
private[jdub] def time[A](klass: java.lang.Class[_])(f: => A) = {
102109
metrics.fold(f) { registry =>
@@ -110,34 +117,38 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
110117
}
111118
}
112119

113-
val transactionProvider: TransactionProvider = new TransactionManager
120+
val transactionProvider: TransactionProvider[R] = new TransactionManager
121+
122+
def replica: Database[Replica] = new Database[Replica](source, metrics)
123+
124+
def primary: Database[Primary] = new Database[Primary](source, metrics)
114125

115126
/**
116127
* Opens a transaction which is committed after `f` is called. If `f` throws
117128
* an exception, the transaction is rolled back.
118129
*/
119-
def transaction[A](f: Transaction => A): A = transaction(true, f)
130+
def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(true, f)
120131

121132
/**
122133
* Opens a transaction which is committed after `f` is called. If `f` throws
123134
* an exception, the transaction is rolled back, but the exception is not
124135
* logged (since it is rethrown).
125136
*/
126-
def quietTransaction[A](f: Transaction => A): A = transaction(false, f)
137+
def quietTransaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(false, f)
127138

128-
def transaction[A](logError: Boolean, f: Transaction => A): A = transaction(false, false, f)
139+
def transaction[A](logError: Boolean, f: Transaction[R] => A)(implicit ev: R =:= Primary): A = transaction(false, false, f)
129140

130141
/**
131142
* Opens a transaction which is committed after `f` is called. If `f` throws
132143
* an exception, the transaction is rolled back.
133144
*/
134-
def transaction[A](logError: Boolean, forceNew: Boolean, f: Transaction => A): A = {
145+
def transaction[A](logError: Boolean, forceNew: Boolean, f: Transaction[R] => A)(implicit ev: R =:= Primary): A = {
135146
if (!forceNew && transactionProvider.transactionExists) {
136147
f(transactionProvider.currentTransaction)
137148
} else {
138149
val connection = source.getConnection
139150
connection.setAutoCommit(false)
140-
val txn = new Transaction(connection)
151+
val txn = new Transaction[R](connection)
141152
try {
142153
logger.debug("Starting transaction")
143154
val result = f(txn)
@@ -162,8 +173,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
162173
* thread within the scope of `f`. If `f` throws an exception the transaction
163174
* is rolled back. Logs exceptions thrown by `f` as errors.
164175
*/
165-
def transactionScope[A](f: => A): A = {
166-
transaction(logError = true, forceNew = false, (txn: Transaction) => {
176+
def transactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
177+
transaction(logError = true, forceNew = false, (txn: Transaction[R]) => {
167178
transactionProvider.begin(txn)
168179
try {
169180
f
@@ -181,8 +192,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
181192
* exception the transaction is rolled back. Logs exceptions thrown by
182193
* `f` as errors.
183194
*/
184-
def newTransactionScope[A](f: => A): A = {
185-
transaction(logError = true, forceNew = true, (txn: Transaction) => {
195+
def newTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
196+
transaction(logError = true, forceNew = true, (txn: Transaction[R]) => {
186197
transactionProvider.begin(txn)
187198
try {
188199
f
@@ -197,8 +208,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
197208
* thread within the scope of `f`. If `f` throws an exception the transaction
198209
* is rolled back. Will not log exceptions thrown by `f`.
199210
*/
200-
def quietTransactionScope[A](f: => A): A = {
201-
transaction(logError = false, forceNew = false, (txn: Transaction) => {
211+
def quietTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
212+
transaction(logError = false, forceNew = false, (txn: Transaction[R]) => {
202213
transactionProvider.begin(txn)
203214
try {
204215
f
@@ -216,8 +227,8 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
216227
* exception the transaction is rolled back. Will not log exceptions
217228
* thrown by `f`.
218229
*/
219-
def newQuietTransactionScope[A](f: => A): A = {
220-
transaction(logError = false, forceNew = true, (txn: Transaction) => {
230+
def newQuietTransactionScope[A](f: => A)(implicit ev: R =:= Primary): A = {
231+
transaction(logError = false, forceNew = true, (txn: Transaction[R]) => {
221232
transactionProvider.begin(txn)
222233
try {
223234
f
@@ -230,7 +241,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
230241
/**
231242
* The transaction currently scoped via transactionScope.
232243
*/
233-
def currentTransaction = {
244+
def currentTransaction(implicit ev: R =:= Primary) = {
234245
transactionProvider.currentTransaction
235246
}
236247

@@ -260,7 +271,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
260271
/**
261272
* Executes an update, insert, delete, or DDL statement.
262273
*/
263-
def execute(statement: Statement) = {
274+
def execute(statement: Statement)(implicit ev: R =:= Primary): Int = {
264275
if (transactionProvider.transactionExists) {
265276
transactionProvider.currentTransaction.execute(statement)
266277
} else {
@@ -278,7 +289,7 @@ class Database protected(val source: DataSource, metrics: Option[MetricRegistry]
278289
/**
279290
* Rollback any existing ambient transaction
280291
*/
281-
def rollback() {
292+
def rollback()(implicit ev: R =:= Primary) {
282293
transactionProvider.rollback
283294
}
284295

src/main/scala/com/simple/jdub/Queryable.scala

+10-8
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55

66
package com.simple.jdub
77

8-
import java.sql.Connection
8+
import com.simple.jdub.Database.Primary
9+
import com.simple.jdub.Database.Role
910

11+
import java.sql.Connection
1012
import grizzled.slf4j.Logging
1113

12-
trait Queryable extends Logging {
14+
trait Queryable[R <: Role] extends Logging {
1315
import Utils._
1416

1517
/**
@@ -35,7 +37,7 @@ trait Queryable extends Logging {
3537
/**
3638
* Executes an update, insert, delete, or DDL statement.
3739
*/
38-
def execute(connection: Connection, statement: Statement): Int = {
40+
def execute(connection: Connection, statement: Statement)(implicit ev: R =:= Primary): Int = {
3941
logger.debug("%s with %s".format(statement.sql,
4042
statement.values.mkString("(", ", ", ")")))
4143
val stmt = connection.prepareStatement(prependComment(statement, statement.sql))
@@ -47,9 +49,9 @@ trait Queryable extends Logging {
4749
}
4850
}
4951

50-
def execute(statement: Statement): Int
52+
def execute(statement: Statement)(implicit ev: R =:= Primary): Int
5153
def apply[A](query: RawQuery[A]): A
52-
def transaction[A](f: Transaction => A): A
54+
def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A
5355

5456
/**
5557
* Performs a query and returns the results.
@@ -59,15 +61,15 @@ trait Queryable extends Logging {
5961
/**
6062
* Executes an update statement.
6163
*/
62-
def update(statement: Statement): Int = execute(statement)
64+
def update(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement)
6365

6466
/**
6567
* Executes an insert statement.
6668
*/
67-
def insert(statement: Statement): Int = execute(statement)
69+
def insert(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement)
6870

6971
/**
7072
* Executes a delete statement.
7173
*/
72-
def delete(statement: Statement): Int = execute(statement)
74+
def delete(statement: Statement)(implicit ev: R =:= Primary): Int = execute(statement)
7375
}

src/main/scala/com/simple/jdub/RawQuery.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.simple.jdub
22

3+
import com.simple.jdub.Database.Role
4+
35
import java.sql.ResultSet
46

57
trait RawQuery[A] extends SqlBase {
@@ -9,5 +11,5 @@ trait RawQuery[A] extends SqlBase {
911

1012
def handle(results: ResultSet): A
1113

12-
def apply(db: Database): A = db(this)
14+
def apply(db: Database[Role]): A = db(this)
1315
}

src/main/scala/com/simple/jdub/Transaction.scala

+12-9
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package com.simple.jdub
22

3+
import com.simple.jdub.Database.Primary
4+
import com.simple.jdub.Database.Role
5+
36
import java.sql.{Connection, Savepoint}
47
import scala.collection.mutable.ListBuffer
58

6-
class Transaction(val connection: Connection) extends Queryable {
9+
class Transaction[R <: Role](val connection: Connection) extends Queryable[R] {
710
private[this] var rolledback = false
811

912
/**
@@ -14,12 +17,12 @@ class Transaction(val connection: Connection) extends Queryable {
1417
/**
1518
* Executes an update, insert, delete, or DDL statement.
1619
*/
17-
def execute(statement: Statement) = execute(connection, statement)
20+
def execute(statement: Statement)(implicit ev: R =:= Primary) = execute(connection, statement)
1821

1922
/**
2023
* Roll back the transaction.
2124
*/
22-
def rollback() {
25+
def rollback()(implicit ev: R =:= Primary) {
2326
logger.debug("Rolling back transaction")
2427
connection.rollback()
2528
rolledback = true
@@ -29,36 +32,36 @@ class Transaction(val connection: Connection) extends Queryable {
2932
/**
3033
* Roll back the transaction to a savepoint.
3134
*/
32-
def rollback(savepoint: Savepoint) {
35+
def rollback(savepoint: Savepoint)(implicit ev: R =:= Primary) {
3336
logger.debug("Rolling back to savepoint")
3437
connection.rollback(savepoint)
3538
}
3639

3740
/**
3841
* Release a transaction from a savepoint.
3942
*/
40-
def release(savepoint: Savepoint) {
43+
def release(savepoint: Savepoint)(implicit ev: R =:= Primary) {
4144
logger.debug("Releasing savepoint")
4245
connection.releaseSavepoint(savepoint)
4346
}
4447

4548
/**
4649
* Set an unnamed savepoint.
4750
*/
48-
def savepoint(): Savepoint = {
51+
def savepoint()(implicit ev: R =:= Primary): Savepoint = {
4952
logger.debug("Setting unnamed savepoint")
5053
connection.setSavepoint()
5154
}
5255

5356
/**
5457
* Set a named savepoint.
5558
*/
56-
def savepoint(name: String): Savepoint = {
59+
def savepoint(name: String)(implicit ev: R =:= Primary): Savepoint = {
5760
logger.debug("Setting savepoint")
5861
connection.setSavepoint(name)
5962
}
6063

61-
private[jdub] def commit() {
64+
private[jdub] def commit()(implicit ev: R =:= Primary) {
6265
if (!rolledback) {
6366
logger.debug("Committing transaction")
6467
connection.commit()
@@ -72,7 +75,7 @@ class Transaction(val connection: Connection) extends Queryable {
7275
onClose.foreach(_())
7376
}
7477

75-
def transaction[A](f: Transaction => A): A = f(this)
78+
def transaction[A](f: Transaction[R] => A)(implicit ev: R =:= Primary): A = f(this)
7679

7780
var onCommit: ListBuffer[() => Unit] = ListBuffer.empty[() => Unit]
7881
var onClose: ListBuffer[() => Unit] = ListBuffer.empty[() => Unit]

src/main/scala/com/simple/jdub/TransactionManager.scala

+16-13
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,21 @@
55

66
package com.simple.jdub
77

8+
import com.simple.jdub.Database.Primary
9+
import com.simple.jdub.Database.Role
10+
811
import java.util.Stack
912

10-
trait TransactionProvider {
13+
trait TransactionProvider[R <: Role] {
1114
def transactionExists: Boolean
12-
def currentTransaction: Transaction
13-
def begin(transaction: Transaction): Unit
14-
def end(): Unit
15-
def rollback(): Unit
15+
def currentTransaction: Transaction[R]
16+
def begin(transaction: Transaction[R])(implicit ev: R =:= Primary): Unit
17+
def end()(implicit ev: R =:= Primary): Unit
18+
def rollback()(implicit ev: R =:= Primary): Unit
1619
}
1720

18-
class TransactionManager extends TransactionProvider {
19-
case class TransactionState(transactions: Stack[Transaction])
21+
class TransactionManager[R <: Role] extends TransactionProvider[R] {
22+
case class TransactionState(transactions: Stack[Transaction[R]])
2023

2124
private val localTransactionStorage = new ThreadLocal[Option[TransactionState]] {
2225
override def initialValue = None
@@ -26,7 +29,7 @@ class TransactionManager extends TransactionProvider {
2629
localTransactionStorage.get
2730
}
2831

29-
protected def ambientTransaction: Option[Transaction] = {
32+
protected def ambientTransaction: Option[Transaction[R]] = {
3033
ambientTransactionState.map(_.transactions.peek)
3134
}
3235

@@ -40,23 +43,23 @@ class TransactionManager extends TransactionProvider {
4043
ambientTransactionState.isDefined
4144
}
4245

43-
def currentTransaction: Transaction = {
46+
def currentTransaction: Transaction[R] = {
4447
ambientTransaction.getOrElse(
4548
throw new Exception("No transaction in current context")
4649
)
4750
}
4851

49-
def begin(transaction: Transaction): Unit = {
52+
def begin(transaction: Transaction[R])(implicit ev: R =:= Primary): Unit = {
5053
if (!transactionExists) {
51-
val stack = new Stack[Transaction]()
54+
val stack = new Stack[Transaction[R]]()
5255
stack.push(transaction)
5356
localTransactionStorage.set(Some(new TransactionState(stack)))
5457
} else {
5558
currentTransactionState.transactions.push(transaction)
5659
}
5760
}
5861

59-
def end(): Unit = {
62+
def end()(implicit ev: R =:= Primary): Unit = {
6063
if (!transactionExists) {
6164
throw new Exception("No transaction in current context")
6265
} else {
@@ -67,7 +70,7 @@ class TransactionManager extends TransactionProvider {
6770
}
6871
}
6972

70-
def rollback(): Unit = {
73+
def rollback()(implicit ev: R =:= Primary): Unit = {
7174
currentTransaction.rollback
7275
}
7376
}

0 commit comments

Comments
 (0)