diesel/pg/connection/
mod.rs

1mod cursor;
2pub mod raw;
3#[doc(hidden)]
4pub mod result;
5mod row;
6mod stmt;
7
8use std::ffi::CString;
9use std::os::raw as libc;
10
11use self::cursor::*;
12use self::raw::RawConnection;
13use self::result::PgResult;
14use self::stmt::Statement;
15use connection::*;
16use deserialize::{Queryable, QueryableByName};
17use pg::{Pg, PgMetadataLookup, TransactionBuilder};
18use query_builder::bind_collector::RawBytesBindCollector;
19use query_builder::*;
20use result::ConnectionError::CouldntSetupConfiguration;
21use result::*;
22use sql_types::HasSqlType;
23
24/// The connection string expected by `PgConnection::establish`
25/// should be a PostgreSQL connection string, as documented at
26/// <https://www.postgresql.org/docs/9.4/static/libpq-connect.html#LIBPQ-CONNSTRING>
27#[allow(missing_debug_implementations)]
28pub struct PgConnection {
29    raw_connection: RawConnection,
30    pub(crate) transaction_manager: AnsiTransactionManager,
31    statement_cache: StatementCache<Pg, Statement>,
32}
33
34unsafe impl Send for PgConnection {}
35
36impl SimpleConnection for PgConnection {
37    fn batch_execute(&self, query: &str) -> QueryResult<()> {
38        let query = CString::new(query)?;
39        let inner_result = unsafe { self.raw_connection.exec(query.as_ptr()) };
40        PgResult::new(inner_result?)?;
41        Ok(())
42    }
43}
44
45impl Connection for PgConnection {
46    type Backend = Pg;
47    type TransactionManager = AnsiTransactionManager;
48
49    fn establish(database_url: &str) -> ConnectionResult<PgConnection> {
50        RawConnection::establish(database_url).and_then(|raw_conn| {
51            let conn = PgConnection {
52                raw_connection: raw_conn,
53                transaction_manager: AnsiTransactionManager::new(),
54                statement_cache: StatementCache::new(),
55            };
56            conn.set_config_options()
57                .map_err(CouldntSetupConfiguration)?;
58            Ok(conn)
59        })
60    }
61
62    #[doc(hidden)]
63    fn execute(&self, query: &str) -> QueryResult<usize> {
64        self.execute_inner(query).map(|res| res.rows_affected())
65    }
66
67    #[doc(hidden)]
68    fn query_by_index<T, U>(&self, source: T) -> QueryResult<Vec<U>>
69    where
70        T: AsQuery,
71        T::Query: QueryFragment<Pg> + QueryId,
72        Pg: HasSqlType<T::SqlType>,
73        U: Queryable<T::SqlType, Pg>,
74    {
75        let (query, params) = self.prepare_query(&source.as_query())?;
76        query
77            .execute(&self.raw_connection, &params)
78            .and_then(|r| Cursor::new(r).collect())
79    }
80
81    #[doc(hidden)]
82    fn query_by_name<T, U>(&self, source: &T) -> QueryResult<Vec<U>>
83    where
84        T: QueryFragment<Pg> + QueryId,
85        U: QueryableByName<Pg>,
86    {
87        let (query, params) = self.prepare_query(source)?;
88        query
89            .execute(&self.raw_connection, &params)
90            .and_then(|r| NamedCursor::new(r).collect())
91    }
92
93    #[doc(hidden)]
94    fn execute_returning_count<T>(&self, source: &T) -> QueryResult<usize>
95    where
96        T: QueryFragment<Pg> + QueryId,
97    {
98        let (query, params) = self.prepare_query(source)?;
99        query
100            .execute(&self.raw_connection, &params)
101            .map(|r| r.rows_affected())
102    }
103
104    #[doc(hidden)]
105    fn transaction_manager(&self) -> &Self::TransactionManager {
106        &self.transaction_manager
107    }
108}
109
110impl PgConnection {
111    /// Build a transaction, specifying additional details such as isolation level
112    ///
113    /// See [`TransactionBuilder`] for more examples.
114    ///
115    /// [`TransactionBuilder`]: ../pg/struct.TransactionBuilder.html
116    ///
117    /// ```rust
118    /// # #[macro_use] extern crate diesel;
119    /// # include!("../../doctest_setup.rs");
120    /// #
121    /// # fn main() {
122    /// #     run_test().unwrap();
123    /// # }
124    /// #
125    /// # fn run_test() -> QueryResult<()> {
126    /// #     use schema::users::dsl::*;
127    /// #     let conn = connection_no_transaction();
128    /// conn.build_transaction()
129    ///     .read_only()
130    ///     .serializable()
131    ///     .deferrable()
132    ///     .run(|| Ok(()))
133    /// # }
134    /// ```
135    pub fn build_transaction(&self) -> TransactionBuilder {
136        TransactionBuilder::new(self)
137    }
138
139    #[allow(clippy::type_complexity)]
140    fn prepare_query<T: QueryFragment<Pg> + QueryId>(
141        &self,
142        source: &T,
143    ) -> QueryResult<(MaybeCached<Statement>, Vec<Option<Vec<u8>>>)> {
144        let mut bind_collector = RawBytesBindCollector::<Pg>::new();
145        source.collect_binds(&mut bind_collector, PgMetadataLookup::new(self))?;
146        let binds = bind_collector.binds;
147        let metadata = bind_collector.metadata;
148
149        let cache_len = self.statement_cache.len();
150        let query = self
151            .statement_cache
152            .cached_statement(source, &metadata, |sql| {
153                let query_name = if source.is_safe_to_cache_prepared()? {
154                    Some(format!("__diesel_stmt_{}", cache_len))
155                } else {
156                    None
157                };
158                Statement::prepare(
159                    &self.raw_connection,
160                    sql,
161                    query_name.as_ref().map(|s| &**s),
162                    &metadata,
163                )
164            });
165
166        Ok((query?, binds))
167    }
168
169    fn execute_inner(&self, query: &str) -> QueryResult<PgResult> {
170        let query = Statement::prepare(&self.raw_connection, query, None, &[])?;
171        query.execute(&self.raw_connection, &Vec::new())
172    }
173
174    fn set_config_options(&self) -> QueryResult<()> {
175        self.execute("SET TIME ZONE 'UTC'")?;
176        self.execute("SET CLIENT_ENCODING TO 'UTF8'")?;
177        self.raw_connection
178            .set_notice_processor(noop_notice_processor);
179        Ok(())
180    }
181}
182
/// Notice processor that ignores both of its arguments and does nothing.
///
/// Installed on the raw connection by `PgConnection::set_config_options`.
extern "C" fn noop_notice_processor(_: *mut libc::c_void, _: *const libc::c_char) {}
184
#[cfg(test)]
mod tests {
    extern crate dotenv;

    use self::dotenv::dotenv;
    use std::env;

    use super::*;
    use dsl::sql;
    use prelude::*;
    use sql_types::{Integer, VarChar};

    /// Opens a connection to the database named by `PG_DATABASE_URL`,
    /// falling back to `DATABASE_URL` when the former is not set.
    fn connection() -> PgConnection {
        dotenv().ok();
        let url = match env::var("PG_DATABASE_URL") {
            Ok(url) => url,
            Err(_) => env::var("DATABASE_URL")
                .expect("DATABASE_URL must be set in order to run tests"),
        };
        PgConnection::establish(&url).unwrap()
    }

    #[test]
    fn prepared_statements_are_cached() {
        let conn = connection();
        let select_one = ::select(1.into_sql::<Integer>());

        // Running the same query twice must leave a single cache entry.
        for _ in 0..2 {
            assert_eq!(Ok(1), select_one.get_result(&conn));
        }
        assert_eq!(1, conn.statement_cache.len());
    }

    #[test]
    fn queries_with_identical_sql_but_different_types_are_cached_separately() {
        let conn = connection();
        let int_query = ::select(1.into_sql::<Integer>());
        let text_query = ::select("hi".into_sql::<VarChar>());

        assert_eq!(Ok(1), int_query.get_result(&conn));
        assert_eq!(Ok(String::from("hi")), text_query.get_result(&conn));
        assert_eq!(2, conn.statement_cache.len());
    }

    #[test]
    fn queries_with_identical_types_and_sql_but_different_bind_types_are_cached_separately() {
        let conn = connection();
        let int_query = ::select(1.into_sql::<Integer>()).into_boxed::<Pg>();
        let text_query = ::select("hi".into_sql::<VarChar>()).into_boxed::<Pg>();

        assert_eq!(0, conn.statement_cache.len());
        assert_eq!(Ok(1), int_query.get_result(&conn));
        assert_eq!(Ok(String::from("hi")), text_query.get_result(&conn));
        assert_eq!(2, conn.statement_cache.len());
    }

    #[test]
    fn queries_with_identical_types_and_binds_but_different_sql_are_cached_separately() {
        let conn = connection();

        sql_function!(fn lower(x: VarChar) -> VarChar);
        let shouted = "HI".into_sql::<VarChar>();
        let plain_query = ::select(shouted).into_boxed::<Pg>();
        let lowered_query = ::select(lower(shouted)).into_boxed::<Pg>();

        assert_eq!(0, conn.statement_cache.len());
        assert_eq!(Ok(String::from("HI")), plain_query.get_result(&conn));
        assert_eq!(Ok(String::from("hi")), lowered_query.get_result(&conn));
        assert_eq!(2, conn.statement_cache.len());
    }

    #[test]
    fn queries_with_sql_literal_nodes_are_not_cached() {
        let conn = connection();
        let literal_query = ::select(sql::<Integer>("1"));

        assert_eq!(Ok(1), literal_query.get_result(&conn));
        assert_eq!(0, conn.statement_cache.len());
    }
}