diesel/pg/connection/
mod.rs1mod cursor;
2pub mod raw;
3#[doc(hidden)]
4pub mod result;
5mod row;
6mod stmt;
7
8use std::ffi::CString;
9use std::os::raw as libc;
10
11use self::cursor::*;
12use self::raw::RawConnection;
13use self::result::PgResult;
14use self::stmt::Statement;
15use connection::*;
16use deserialize::{Queryable, QueryableByName};
17use pg::{Pg, PgMetadataLookup, TransactionBuilder};
18use query_builder::bind_collector::RawBytesBindCollector;
19use query_builder::*;
20use result::ConnectionError::CouldntSetupConfiguration;
21use result::*;
22use sql_types::HasSqlType;
23
// NOTE(review): `Debug` is deliberately not implemented; presumably because
// `RawConnection` wraps FFI connection state with no meaningful `Debug`
// output — confirm in `raw.rs`.
#[allow(missing_debug_implementations)]
/// A connection to a PostgreSQL database.
///
/// The string given to `PgConnection::establish` is passed unchanged to
/// `RawConnection::establish` (NOTE(review): presumably a libpq connection
/// string or URI — confirm in `raw.rs`).
///
/// Besides the low-level connection handle, each connection owns the ANSI
/// transaction manager used by `Connection::transaction_manager` and a cache
/// of prepared statements consulted by every query built with Diesel's DSL.
pub struct PgConnection {
    // Low-level connection handle; every statement is prepared and executed
    // against it.
    // NOTE(review): fields are dropped in declaration order, so this handle is
    // dropped before the statement cache below — re-audit before reordering.
    raw_connection: RawConnection,
    // Crate-visible so other modules (e.g. the transaction builder) can reach
    // it directly — NOTE(review): confirm which callers rely on this.
    pub(crate) transaction_manager: AnsiTransactionManager,
    // Prepared statements keyed by query shape; filled in `prepare_query`.
    statement_cache: StatementCache<Pg, Statement>,
}
33
// SAFETY: NOTE(review): this asserts that a `PgConnection` may be *moved* to
// another thread; no `Sync` impl appears in this file, so it is never shared
// between threads by reference. Soundness relies on `RawConnection` owning its
// connection handle exclusively and on libpq allowing a connection to be used
// from a thread other than the one that opened it, provided it is never used
// concurrently — confirm against `raw.rs` and the libpq thread-safety docs
// before changing either.
unsafe impl Send for PgConnection {}
35
36impl SimpleConnection for PgConnection {
37 fn batch_execute(&self, query: &str) -> QueryResult<()> {
38 let query = CString::new(query)?;
39 let inner_result = unsafe { self.raw_connection.exec(query.as_ptr()) };
40 PgResult::new(inner_result?)?;
41 Ok(())
42 }
43}
44
45impl Connection for PgConnection {
46 type Backend = Pg;
47 type TransactionManager = AnsiTransactionManager;
48
49 fn establish(database_url: &str) -> ConnectionResult<PgConnection> {
50 RawConnection::establish(database_url).and_then(|raw_conn| {
51 let conn = PgConnection {
52 raw_connection: raw_conn,
53 transaction_manager: AnsiTransactionManager::new(),
54 statement_cache: StatementCache::new(),
55 };
56 conn.set_config_options()
57 .map_err(CouldntSetupConfiguration)?;
58 Ok(conn)
59 })
60 }
61
62 #[doc(hidden)]
63 fn execute(&self, query: &str) -> QueryResult<usize> {
64 self.execute_inner(query).map(|res| res.rows_affected())
65 }
66
67 #[doc(hidden)]
68 fn query_by_index<T, U>(&self, source: T) -> QueryResult<Vec<U>>
69 where
70 T: AsQuery,
71 T::Query: QueryFragment<Pg> + QueryId,
72 Pg: HasSqlType<T::SqlType>,
73 U: Queryable<T::SqlType, Pg>,
74 {
75 let (query, params) = self.prepare_query(&source.as_query())?;
76 query
77 .execute(&self.raw_connection, ¶ms)
78 .and_then(|r| Cursor::new(r).collect())
79 }
80
81 #[doc(hidden)]
82 fn query_by_name<T, U>(&self, source: &T) -> QueryResult<Vec<U>>
83 where
84 T: QueryFragment<Pg> + QueryId,
85 U: QueryableByName<Pg>,
86 {
87 let (query, params) = self.prepare_query(source)?;
88 query
89 .execute(&self.raw_connection, ¶ms)
90 .and_then(|r| NamedCursor::new(r).collect())
91 }
92
93 #[doc(hidden)]
94 fn execute_returning_count<T>(&self, source: &T) -> QueryResult<usize>
95 where
96 T: QueryFragment<Pg> + QueryId,
97 {
98 let (query, params) = self.prepare_query(source)?;
99 query
100 .execute(&self.raw_connection, ¶ms)
101 .map(|r| r.rows_affected())
102 }
103
104 #[doc(hidden)]
105 fn transaction_manager(&self) -> &Self::TransactionManager {
106 &self.transaction_manager
107 }
108}
109
110impl PgConnection {
111 pub fn build_transaction(&self) -> TransactionBuilder {
136 TransactionBuilder::new(self)
137 }
138
139 #[allow(clippy::type_complexity)]
140 fn prepare_query<T: QueryFragment<Pg> + QueryId>(
141 &self,
142 source: &T,
143 ) -> QueryResult<(MaybeCached<Statement>, Vec<Option<Vec<u8>>>)> {
144 let mut bind_collector = RawBytesBindCollector::<Pg>::new();
145 source.collect_binds(&mut bind_collector, PgMetadataLookup::new(self))?;
146 let binds = bind_collector.binds;
147 let metadata = bind_collector.metadata;
148
149 let cache_len = self.statement_cache.len();
150 let query = self
151 .statement_cache
152 .cached_statement(source, &metadata, |sql| {
153 let query_name = if source.is_safe_to_cache_prepared()? {
154 Some(format!("__diesel_stmt_{}", cache_len))
155 } else {
156 None
157 };
158 Statement::prepare(
159 &self.raw_connection,
160 sql,
161 query_name.as_ref().map(|s| &**s),
162 &metadata,
163 )
164 });
165
166 Ok((query?, binds))
167 }
168
169 fn execute_inner(&self, query: &str) -> QueryResult<PgResult> {
170 let query = Statement::prepare(&self.raw_connection, query, None, &[])?;
171 query.execute(&self.raw_connection, &Vec::new())
172 }
173
174 fn set_config_options(&self) -> QueryResult<()> {
175 self.execute("SET TIME ZONE 'UTC'")?;
176 self.execute("SET CLIENT_ENCODING TO 'UTF8'")?;
177 self.raw_connection
178 .set_notice_processor(noop_notice_processor);
179 Ok(())
180 }
181}
182
/// Notice processor registered by `set_config_options`; it drops every notice.
///
/// libpq's default notice processor prints notices to stderr. Installing this
/// no-op callback keeps them silent.
extern "C" fn noop_notice_processor(_arg: *mut libc::c_void, _msg: *const libc::c_char) {
    // Intentionally empty: notices are discarded.
}
184
#[cfg(test)]
mod tests {
    extern crate dotenv;

    use self::dotenv::dotenv;
    use std::env;

    use super::*;
    use dsl::sql;
    use prelude::*;
    use sql_types::{Integer, VarChar};

    #[test]
    fn prepared_statements_are_cached() {
        let connection = connection();

        let query = ::select(1.into_sql::<Integer>());

        // Establishing a connection never goes through the statement cache.
        assert_eq!(0, connection.statement_cache.len());
        assert_eq!(Ok(1), query.get_result(&connection));
        assert_eq!(Ok(1), query.get_result(&connection));
        // Running the same query twice must reuse a single cached statement.
        assert_eq!(1, connection.statement_cache.len());
    }

    #[test]
    fn queries_with_identical_sql_but_different_types_are_cached_separately() {
        let connection = connection();

        let query = ::select(1.into_sql::<Integer>());
        let query2 = ::select("hi".into_sql::<VarChar>());

        assert_eq!(0, connection.statement_cache.len());
        assert_eq!(Ok(1), query.get_result(&connection));
        assert_eq!(Ok("hi".to_string()), query2.get_result(&connection));
        assert_eq!(2, connection.statement_cache.len());
    }

    #[test]
    fn queries_with_identical_types_and_sql_but_different_bind_types_are_cached_separately() {
        let connection = connection();

        let query = ::select(1.into_sql::<Integer>()).into_boxed::<Pg>();
        let query2 = ::select("hi".into_sql::<VarChar>()).into_boxed::<Pg>();

        assert_eq!(0, connection.statement_cache.len());
        assert_eq!(Ok(1), query.get_result(&connection));
        assert_eq!(Ok("hi".to_string()), query2.get_result(&connection));
        assert_eq!(2, connection.statement_cache.len());
    }

    #[test]
    fn queries_with_identical_types_and_binds_but_different_sql_are_cached_separately() {
        let connection = connection();

        sql_function!(fn lower(x: VarChar) -> VarChar);
        let hi = "HI".into_sql::<VarChar>();
        let query = ::select(hi).into_boxed::<Pg>();
        let query2 = ::select(lower(hi)).into_boxed::<Pg>();

        assert_eq!(0, connection.statement_cache.len());
        assert_eq!(Ok("HI".to_string()), query.get_result(&connection));
        assert_eq!(Ok("hi".to_string()), query2.get_result(&connection));
        assert_eq!(2, connection.statement_cache.len());
    }

    #[test]
    fn queries_with_sql_literal_nodes_are_not_cached() {
        let connection = connection();
        let query = ::select(sql::<Integer>("1"));

        assert_eq!(Ok(1), query.get_result(&connection));
        assert_eq!(0, connection.statement_cache.len());
    }

    /// Opens a test connection. `PG_DATABASE_URL` takes precedence over
    /// `DATABASE_URL`; either may come from a `.env` file.
    fn connection() -> PgConnection {
        dotenv().ok();
        let database_url = env::var("PG_DATABASE_URL")
            .or_else(|_| env::var("DATABASE_URL"))
            .expect("PG_DATABASE_URL or DATABASE_URL must be set in order to run tests");
        PgConnection::establish(&database_url).unwrap()
    }
}