diesel/connection/
transaction_manager.rs1use backend::UsesAnsiSavepointSyntax;
2use connection::{Connection, SimpleConnection};
3use result::{DatabaseErrorKind, Error, QueryResult};
4
5pub trait TransactionManager<Conn: Connection> {
10 fn begin_transaction(&self, conn: &Conn) -> QueryResult<()>;
16
17 fn rollback_transaction(&self, conn: &Conn) -> QueryResult<()>;
23
24 fn commit_transaction(&self, conn: &Conn) -> QueryResult<()>;
30
31 fn get_transaction_depth(&self) -> u32;
36}
37
38use std::cell::Cell;
39
/// Transaction manager for connections whose backend implements
/// `UsesAnsiSavepointSyntax`: nesting is emulated with named savepoints.
///
/// The only state is the current nesting depth. It is kept in a `Cell` so
/// it can be updated through `&self`.
#[derive(Default)]
// No `Debug` implementation is provided for this type, so the
// `missing_debug_implementations` lint is silenced explicitly.
#[allow(missing_debug_implementations)]
pub struct AnsiTransactionManager {
    /// Number of open transaction levels; 0 means no transaction is open.
    transaction_depth: Cell<i32>,
}
47
48impl AnsiTransactionManager {
49 pub fn new() -> Self {
51 AnsiTransactionManager::default()
52 }
53
54 fn change_transaction_depth(&self, by: i32, query: QueryResult<()>) -> QueryResult<()> {
55 if query.is_ok() {
56 self.transaction_depth
57 .set(self.transaction_depth.get() + by)
58 }
59 query
60 }
61
62 pub fn begin_transaction_sql<Conn>(&self, conn: &Conn, sql: &str) -> QueryResult<()>
68 where
69 Conn: SimpleConnection,
70 {
71 use result::Error::AlreadyInTransaction;
72
73 if self.transaction_depth.get() == 0 {
74 self.change_transaction_depth(1, conn.batch_execute(sql))
75 } else {
76 Err(AlreadyInTransaction)
77 }
78 }
79}
80
81impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
82where
83 Conn: Connection,
84 Conn::Backend: UsesAnsiSavepointSyntax,
85{
86 fn begin_transaction(&self, conn: &Conn) -> QueryResult<()> {
87 let transaction_depth = self.transaction_depth.get();
88 self.change_transaction_depth(
89 1,
90 if transaction_depth == 0 {
91 conn.batch_execute("BEGIN")
92 } else {
93 conn.batch_execute(&format!("SAVEPOINT diesel_savepoint_{}", transaction_depth))
94 },
95 )
96 }
97
98 fn rollback_transaction(&self, conn: &Conn) -> QueryResult<()> {
99 let transaction_depth = self.transaction_depth.get();
100 self.change_transaction_depth(
101 -1,
102 if transaction_depth == 1 {
103 conn.batch_execute("ROLLBACK")
104 } else {
105 conn.batch_execute(&format!(
106 "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
107 transaction_depth - 1
108 ))
109 },
110 )
111 }
112
113 fn commit_transaction(&self, conn: &Conn) -> QueryResult<()> {
118 let transaction_depth = self.transaction_depth.get();
119 if transaction_depth <= 1 {
120 match conn.batch_execute("COMMIT") {
121 e @ Err(Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) => {
125 self.change_transaction_depth(-1, conn.batch_execute("ROLLBACK"))?;
126 e
127 }
128 result => self.change_transaction_depth(-1, result),
129 }
130 } else {
131 self.change_transaction_depth(
132 -1,
133 conn.batch_execute(&format!(
134 "RELEASE SAVEPOINT diesel_savepoint_{}",
135 transaction_depth - 1
136 )),
137 )
138 }
139 }
140
141 fn get_transaction_depth(&self) -> u32 {
142 self.transaction_depth.get() as u32
143 }
144}
145
#[cfg(test)]
mod test {
    // Local boolean pattern-test macro used by the assertions below. It is
    // gated on `postgres` because the only test that uses it is also gated
    // on `postgres` (presumably to avoid an unused-macro warning otherwise).
    #[cfg(feature = "postgres")]
    macro_rules! matches {
        ($expression:expr, $( $pattern:pat )|+ $( if $guard: expr )?) => {
            match $expression {
                $( $pattern )|+ $( if $guard )? => true,
                _ => false
            }
        }
    }

    /// Checks that the transaction depth returns to 0 when a SERIALIZABLE
    /// transaction fails with a serialization failure.
    ///
    /// Two threads each open a SERIALIZABLE transaction, count the rows of
    /// one `class`, wait for each other, then insert a row of the other
    /// `class`. The test expects one result to be `Ok` and the other to be
    /// `DatabaseError(SerializationFailure, _)`. Each thread asserts that its
    /// connection reports depth 0 before and after the transaction, and
    /// depth 1 inside it.
    /// NOTE(review): whether PostgreSQL raises the failure on the INSERT or
    /// on `COMMIT` is not fixed by this test; both paths must end at depth 0.
    #[test]
    #[cfg(feature = "postgres")]
    fn transaction_depth_is_tracked_properly_on_commit_failure() {
        use crate::result::DatabaseErrorKind::SerializationFailure;
        use crate::result::Error::DatabaseError;
        use crate::*;
        use std::sync::{Arc, Barrier};
        use std::thread;

        table! {
            #[sql_name = "transaction_depth_is_tracked_properly_on_commit_failure"]
            serialization_example {
                id -> Serial,
                class -> Integer,
            }
        }

        // Presumably a connection that is not wrapped in a test transaction,
        // so real BEGIN/COMMIT statements reach the server — confirm in
        // test_helpers.
        let conn = crate::test_helpers::pg_connection_no_transaction();

        // Recreate the table so leftovers from earlier runs cannot interfere.
        sql_query("DROP TABLE IF EXISTS transaction_depth_is_tracked_properly_on_commit_failure;")
            .execute(&conn)
            .unwrap();
        sql_query(
            r#"
            CREATE TABLE transaction_depth_is_tracked_properly_on_commit_failure (
                id SERIAL PRIMARY KEY,
                class INTEGER NOT NULL
            )
        "#,
        )
        .execute(&conn)
        .unwrap();

        // Seed one row for each class read by the two threads.
        insert_into(serialization_example::table)
            .values(&vec![
                serialization_example::class.eq(1),
                serialization_example::class.eq(2),
            ])
            .execute(&conn)
            .unwrap();

        // Both threads must finish their read before either one inserts.
        let barrier = Arc::new(Barrier::new(2));
        // `1..3` spawns one thread for i = 1 and one for i = 2. The
        // `collect` forces both spawns to happen before any `join` below.
        let threads = (1..3)
            .map(|i| {
                let barrier = barrier.clone();
                thread::spawn(move || {
                    use crate::connection::transaction_manager::AnsiTransactionManager;
                    use crate::connection::transaction_manager::TransactionManager;
                    let conn = crate::test_helpers::pg_connection_no_transaction();
                    assert_eq!(0, <AnsiTransactionManager as TransactionManager<PgConnection>>::get_transaction_depth(&conn.transaction_manager));

                    let result =
                        conn.build_transaction().serializable().run(|| {
                            assert_eq!(1, <AnsiTransactionManager as TransactionManager<PgConnection>>::get_transaction_depth(&conn.transaction_manager));

                            // Read the rows of this thread's own class...
                            let _ = serialization_example::table
                                .filter(serialization_example::class.eq(i))
                                .count()
                                .execute(&conn)?;

                            barrier.wait();

                            // ...then write a row of the class the other
                            // thread read, creating the read/write conflict.
                            let other_i = if i == 1 { 2 } else { 1 };
                            insert_into(serialization_example::table)
                                .values(serialization_example::class.eq(other_i))
                                .execute(&conn)
                        });

                    // Whatever the outcome, the manager must be back at depth 0.
                    assert_eq!(0, <AnsiTransactionManager as TransactionManager<PgConnection>>::get_transaction_depth(&conn.transaction_manager));
                    result
                })
            })
            .collect::<Vec<_>>();

        let mut results = threads
            .into_iter()
            .map(|t| t.join().unwrap())
            .collect::<Vec<_>>();

        // `false < true`, so the `Ok` result sorts first.
        results.sort_by_key(|r| r.is_err());

        assert!(matches!(results[0], Ok(_)));
        assert!(matches!(
            results[1],
            Err(DatabaseError(SerializationFailure, _))
        ));
    }
}