// diesel/connection/transaction_manager.rs
use backend::UsesAnsiSavepointSyntax;
use connection::{Connection, SimpleConnection};
use result::{DatabaseErrorKind, Error, QueryResult};
/// Drives the transaction lifecycle (begin / rollback / commit) of a connection `Conn`.
///
/// Implementations keep track of how many transactions are currently open so each call
/// can be translated into the right SQL for the current nesting level.
pub trait TransactionManager<Conn>
where
    Conn: Connection,
{
    /// Opens a transaction on `conn`. When one is already open, an implementation may
    /// open a nested level instead (for example via a savepoint).
    fn begin_transaction(&self, conn: &Conn) -> QueryResult<()>;

    /// Rolls back the innermost open transaction level on `conn`.
    fn rollback_transaction(&self, conn: &Conn) -> QueryResult<()>;

    /// Commits the innermost open transaction level on `conn`.
    fn commit_transaction(&self, conn: &Conn) -> QueryResult<()>;

    /// Reports how many transaction levels are currently open.
    fn get_transaction_depth(&self) -> u32;
}
use std::cell::Cell;
/// Transaction manager for backends whose nested transactions use ANSI `SAVEPOINT` syntax.
///
/// It counts the currently open transaction levels: 0 means no transaction is open. The
/// first level is opened with a plain `BEGIN`; deeper levels use savepoints.
// NOTE(review): no `Debug` impl is provided and the lint is silenced below; the reason is
// not visible in this file (`Cell<i32>` itself is `Debug`) — consider deriving it.
#[derive(Default)]
#[allow(missing_debug_implementations)]
pub struct AnsiTransactionManager {
    // Current nesting depth. `Cell` lets the methods update it through `&self`.
    transaction_depth: Cell<i32>,
}
impl AnsiTransactionManager {
    /// Creates a manager with no open transaction (depth 0).
    pub fn new() -> Self {
        AnsiTransactionManager::default()
    }

    /// Adjusts the tracked depth by `by`, but only when `query` succeeded, and hands
    /// `query` back unchanged so callers can forward it.
    ///
    /// The depth is clamped at zero. A `COMMIT` issued while no transaction is open can
    /// still succeed (PostgreSQL, for instance, only emits a warning). Letting that
    /// success push the counter to -1 would make later calls build savepoint names like
    /// `diesel_savepoint_-1`, and would wrap around in the `as u32` cast of
    /// `get_transaction_depth`.
    fn change_transaction_depth(&self, by: i32, query: QueryResult<()>) -> QueryResult<()> {
        if query.is_ok() {
            let current = self.transaction_depth.get();
            self.transaction_depth.set(current.saturating_add(by).max(0));
        }
        query
    }

    /// Opens a top-level transaction by executing the caller-supplied `sql`.
    ///
    /// # Errors
    ///
    /// Returns `AlreadyInTransaction` without touching `conn` if any transaction level is
    /// already open. Otherwise returns whatever `batch_execute` returns; the depth becomes
    /// 1 only when that succeeds.
    pub fn begin_transaction_sql<Conn>(&self, conn: &Conn, sql: &str) -> QueryResult<()>
    where
        Conn: SimpleConnection,
    {
        use result::Error::AlreadyInTransaction;

        if self.transaction_depth.get() == 0 {
            self.change_transaction_depth(1, conn.batch_execute(sql))
        } else {
            Err(AlreadyInTransaction)
        }
    }
}
impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
where
Conn: Connection,
Conn::Backend: UsesAnsiSavepointSyntax,
{
fn begin_transaction(&self, conn: &Conn) -> QueryResult<()> {
let transaction_depth = self.transaction_depth.get();
self.change_transaction_depth(
1,
if transaction_depth == 0 {
conn.batch_execute("BEGIN")
} else {
conn.batch_execute(&format!("SAVEPOINT diesel_savepoint_{}", transaction_depth))
},
)
}
fn rollback_transaction(&self, conn: &Conn) -> QueryResult<()> {
let transaction_depth = self.transaction_depth.get();
self.change_transaction_depth(
-1,
if transaction_depth == 1 {
conn.batch_execute("ROLLBACK")
} else {
conn.batch_execute(&format!(
"ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
transaction_depth - 1
))
},
)
}
fn commit_transaction(&self, conn: &Conn) -> QueryResult<()> {
let transaction_depth = self.transaction_depth.get();
if transaction_depth <= 1 {
match conn.batch_execute("COMMIT") {
e @ Err(Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) => {
self.change_transaction_depth(-1, conn.batch_execute("ROLLBACK"))?;
e
}
result => self.change_transaction_depth(-1, result),
}
} else {
self.change_transaction_depth(
-1,
conn.batch_execute(&format!(
"RELEASE SAVEPOINT diesel_savepoint_{}",
transaction_depth - 1
)),
)
}
}
fn get_transaction_depth(&self) -> u32 {
self.transaction_depth.get() as u32
}
}
// Tests for `AnsiTransactionManager`'s transaction-depth bookkeeping.
#[cfg(test)]
mod test {
    // Local stand-in for `std::matches!`: expands to `true` when `$expression` matches any
    // of the given patterns (and the optional guard holds), `false` otherwise.
    // NOTE(review): presumably defined here so the tests build on compilers older than
    // Rust 1.42, where `std::matches!` was stabilised — confirm against the crate's MSRV.
    #[cfg(feature = "postgres")]
    macro_rules! matches {
        ($expression:expr, $( $pattern:pat )|+ $( if $guard: expr )?) => {
            match $expression {
                $( $pattern )|+ $( if $guard )? => true,
                _ => false
            }
        }
    }

    // Runs two concurrent SERIALIZABLE transactions (one per thread, each on its own
    // connection) whose reads and writes overlap. It asserts two things:
    // - exactly one transaction ends `Ok` and the other ends with a `SerializationFailure`
    //   database error;
    // - in both threads, the manager's depth is 0 before `run`, 1 inside the closure, and
    //   back to 0 after `run` returns. The failing transaction is expected (per the test
    //   name) to fail at COMMIT time, so this checks that a failed commit does not leave
    //   the depth counter stuck at 1.
    // Presumably needs a live PostgreSQL server reachable through `test_helpers`.
    #[test]
    #[cfg(feature = "postgres")]
    fn transaction_depth_is_tracked_properly_on_commit_failure() {
        use crate::result::DatabaseErrorKind::SerializationFailure;
        use crate::result::Error::DatabaseError;
        use crate::*;
        use std::sync::{Arc, Barrier};
        use std::thread;

        // `serialization_example` maps onto the test-specific table created below.
        table! {
            #[sql_name = "transaction_depth_is_tracked_properly_on_commit_failure"]
            serialization_example {
                id -> Serial,
                class -> Integer,
            }
        }

        // Recreate the table from scratch so earlier runs cannot influence the outcome.
        let conn = crate::test_helpers::pg_connection_no_transaction();
        sql_query("DROP TABLE IF EXISTS transaction_depth_is_tracked_properly_on_commit_failure;")
            .execute(&conn)
            .unwrap();
        // The raw string below is sent verbatim; its lines are intentionally unindented.
        sql_query(
            r#"
CREATE TABLE transaction_depth_is_tracked_properly_on_commit_failure (
id SERIAL PRIMARY KEY,
class INTEGER NOT NULL
)
"#,
        )
        .execute(&conn)
        .unwrap();

        // Seed one row for class 1 and one for class 2.
        insert_into(serialization_example::table)
            .values(&vec![
                serialization_example::class.eq(1),
                serialization_example::class.eq(2),
            ])
            .execute(&conn)
            .unwrap();

        // Two-party barrier: each thread blocks after its read, so both reads happen
        // before either insert.
        let barrier = Arc::new(Barrier::new(2));
        // Spawn one thread per class, with `i` taking the values 1 and 2.
        let threads = (1..3)
            .map(|i| {
                let barrier = barrier.clone();
                thread::spawn(move || {
                    use crate::connection::transaction_manager::AnsiTransactionManager;
                    use crate::connection::transaction_manager::TransactionManager;
                    let conn = crate::test_helpers::pg_connection_no_transaction();
                    // Fully-qualified calls are needed because `get_transaction_depth` is
                    // implemented generically over `Conn`, which cannot be inferred here.
                    assert_eq!(0, <AnsiTransactionManager as TransactionManager<PgConnection>>::get_transaction_depth(&conn.transaction_manager));
                    let result =
                        conn.build_transaction().serializable().run(|| {
                            assert_eq!(1, <AnsiTransactionManager as TransactionManager<PgConnection>>::get_transaction_depth(&conn.transaction_manager));
                            // Read the rows of this thread's own class...
                            let _ = serialization_example::table
                                .filter(serialization_example::class.eq(i))
                                .count()
                                .execute(&conn)?;
                            barrier.wait();
                            // ...then write a row of the other thread's class. Each
                            // transaction writes what the other read, which is the
                            // dependency pattern SERIALIZABLE isolation is expected to
                            // reject for one of them.
                            let other_i = if i == 1 { 2 } else { 1 };
                            insert_into(serialization_example::table)
                                .values(serialization_example::class.eq(other_i))
                                .execute(&conn)
                        });
                    // Whether `run` succeeded or failed, no transaction may remain open.
                    assert_eq!(0, <AnsiTransactionManager as TransactionManager<PgConnection>>::get_transaction_depth(&conn.transaction_manager));
                    result
                })
            })
            .collect::<Vec<_>>();

        // Join both threads; a panic in either (e.g. a failed depth assertion) fails here.
        let mut results = threads
            .into_iter()
            .map(|t| t.join().unwrap())
            .collect::<Vec<_>>();

        // `false < true`, so the successful result sorts first.
        results.sort_by_key(|r| r.is_err());
        assert!(matches!(results[0], Ok(_)));
        assert!(matches!(
            results[1],
            Err(DatabaseError(SerializationFailure, _))
        ));
    }
}