From 8bcb35fd7ce586087ddc112c721a6f0965d575ca Mon Sep 17 00:00:00 2001 From: mxhagen Date: Wed, 11 Dec 2024 21:29:10 +0100 Subject: [PATCH] add from_sql_row, rename update_or_insert also fix duplicate trait bounds --- macros/src/lib.rs | 58 +++++++++++++++++++++++++++++++--------------- src/lib.rs | 2 ++ src/test/macros.rs | 27 ++++++++++++++------- 3 files changed, 60 insertions(+), 27 deletions(-) diff --git a/macros/src/lib.rs b/macros/src/lib.rs index df7622d..6b3f08b 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -1,11 +1,13 @@ +use std::collections::HashMap; use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{parse_macro_input, Data, DeriveInput, Fields}; -// TODO: wrap functions in a trait +// TODO: wrap functions in a trait? would probably use the other crate // TODO: doc comments + #[proc_macro_derive(Table)] pub fn derive_table(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -29,8 +31,8 @@ pub fn derive_table(input: TokenStream) -> TokenStream { let mut field_getters = Vec::new(); let mut field_accessors = Vec::new(); - let mut to_sql_trait_bounds = Vec::new(); - let mut from_sql_trait_bounds = Vec::new(); + let mut to_sql_trait_bounds = HashMap::new(); + let mut from_sql_trait_bounds = HashMap::new(); for field in fields.named.iter() { let field_name = field.ident.as_ref().unwrap(); @@ -40,8 +42,8 @@ pub fn derive_table(input: TokenStream) -> TokenStream { field_getters.push(quote!(#field_name: row.get(stringify!(#field_name))?)); field_accessors.push(quote!(self.#field_name)); - to_sql_trait_bounds.push(quote!(#field_type: rusqlite::types::ToSql)); - from_sql_trait_bounds.push(quote!(#field_type: rusqlite::types::FromSql)); + to_sql_trait_bounds.insert(stringify!(#field_type), quote!(#field_type: rusqlite::types::ToSql)); + from_sql_trait_bounds.insert(stringify!(#field_type), quote!(#field_type: rusqlite::types::FromSql)); if field_name == "id" { if let syn::Type::Path(type_path) = field_type { @@ -59,6 +61,9 @@ pub fn derive_table(input: TokenStream) -> TokenStream { } } + let to_sql_trait_bounds = to_sql_trait_bounds.values().collect::>(); + let from_sql_trait_bounds = from_sql_trait_bounds.values().collect::>(); + if !field_names.iter().map(|id| id.to_string()).any(|id| &id == "id") { panic!("Structs annotated with `Table` require a primary key field `id: Option`."); } @@ -92,6 +97,7 @@ pub fn derive_table(input: TokenStream) -> TokenStream { where #(#to_sql_trait_bounds),* { conn.execute(#insert_sql, rusqlite::params![#(#field_accessors),*])?; + // TODO: test this with manually set id. also test that this can't update!!! let id = conn.last_insert_rowid(); self.id = Some(id); Ok(id) @@ -99,8 +105,8 @@ pub fn derive_table(input: TokenStream) -> TokenStream { }; - let upsert_fn = quote! { - pub fn upsert(&mut self, conn: &rusqlite::Connection) -> rusqlite::Result + let update_or_insert_fn = quote! { + pub fn update_or_insert(&mut self, conn: &rusqlite::Connection) -> rusqlite::Result where #(#to_sql_trait_bounds),* { match self.id { @@ -128,6 +134,7 @@ pub fn derive_table(input: TokenStream) -> TokenStream { where #(#to_sql_trait_bounds),* { if self.id.is_none() { + // TODO: bad design, should probably fail instead return Ok(false); } let updated_count = conn.execute(#update_sql, rusqlite::params![#(#field_accessors),*])?; @@ -143,17 +150,31 @@ pub fn derive_table(input: TokenStream) -> TokenStream { if self.id.is_none() { return Ok(false); } - match #struct_name::get_by_id(conn, self.id.unwrap())? { - Some(person) => *self = person, - _ => return Ok(false), - }; - Ok(true) + match #struct_name::get_by_id(conn, self.id.unwrap()) { + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(false), + Ok(person) => { + *self = person; + Ok(true) + }, + Err(e) => Err(e), + } + } + }; + + + let from_sql_row_fn = quote! { + pub fn from_sql_row(row: &rusqlite::Row) -> rusqlite::Result + where + Self: Sized, + #(#from_sql_trait_bounds),* + { + Ok(Self { #(#field_getters),* }) } }; let get_by_id_fn = quote! { - pub fn get_by_id(conn: &rusqlite::Connection, id: i64) -> rusqlite::Result> + pub fn get_by_id(conn: &rusqlite::Connection, id: i64) -> rusqlite::Result where Self: Sized, #(#from_sql_trait_bounds),* @@ -162,18 +183,16 @@ pub fn derive_table(input: TokenStream) -> TokenStream { let mut rows = stmt.query(rusqlite::params![id])?; if let Some(row) = rows.next()? { - Ok(Some(Self { #(#field_getters),* })) + Self::from_sql_row(row) } else { - Ok(None) + Err(rusqlite::Error::QueryReturnedNoRows) } } }; let delete_fn = quote! { - pub fn delete(&mut self, conn: &rusqlite::Connection) -> rusqlite::Result - where #(#to_sql_trait_bounds),* - { + pub fn delete(&mut self, conn: &rusqlite::Connection) -> rusqlite::Result { if self.id.is_none() { return Ok(false); } @@ -209,9 +228,10 @@ pub fn derive_table(input: TokenStream) -> TokenStream { impl #struct_name { #create_table_fn #insert_fn - #upsert_fn + #update_or_insert_fn #update_fn #sync_fn + #from_sql_row_fn #get_by_id_fn #delete_fn #delete_by_id_fn diff --git a/src/lib.rs b/src/lib.rs index 97375f7..e0a8b15 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,2 +1,4 @@ pub use squail_macros::*; mod test; + +pub trait SquailTable {} diff --git a/src/test/macros.rs b/src/test/macros.rs index 7669ccc..4883c43 100644 --- a/src/test/macros.rs +++ b/src/test/macros.rs @@ -73,7 +73,7 @@ fn test_table_derive_macro() { let larry_copy = Person::get_by_id(&conn, larry_id).expect("Querying a row should work"); assert_eq!( larry_copy, - Some(larry.clone()), + larry.clone(), "Retrieving inserted row should give an identical row" ); @@ -81,14 +81,14 @@ fn test_table_derive_macro() { // also works: `Person::delete_by_id(&conn, larry_id).unwrap();` assert!(deleted_something, "Should have deleted something"); - let deleted_larry = Person::get_by_id(&conn, larry_id) - .expect("Querying a deleted row should return Ok(None), not Err(_)"); + let err = Person::get_by_id(&conn, larry_id) + .expect_err("Querying a deleted row should Err(QueryReturnedNoRows)"); assert_eq!( - deleted_larry, None, + err, rusqlite::Error::QueryReturnedNoRows, "Received row that should have been deleted" ); - let id = larry.upsert(&conn).expect("Upsertion (insert) should work"); + let id = larry.update_or_insert(&conn).expect("Upsertion (insert) should work"); let larry_id = larry .id .expect("After (mutable) upsertion, id should not be None"); @@ -97,12 +97,12 @@ fn test_table_derive_macro() { assert_eq!(id, larry_id, "Upsert should return correct id"); assert_eq!( larry_copy, - Some(larry.clone()), + larry.clone(), "Retrieving upserted row should give an identical row" ); larry.age += 1; - let id = larry.upsert(&conn).expect("Upsertion (update) should work"); + let id = larry.update_or_insert(&conn).expect("Upsertion (update) should work"); let larry_id = larry .id .expect("After (mutable) upsertion, id should not be None"); @@ -111,7 +111,7 @@ fn test_table_derive_macro() { let larry_copy = Person::get_by_id(&conn, larry_id).expect("Querying a row should work"); assert_eq!( larry_copy, - Some(larry.clone()), + larry.clone(), "Retrieving upserted row should give an identical row" ); @@ -138,6 +138,17 @@ fn test_table_derive_macro() { ) .expect("Explicit Sqlite statement (not a library test) failed"); assert!(!exists, "Deleted table should not exist anymore but does"); + + + /// Another example struct to assure this works more than once + #[derive(Table, Default)] + #[allow(unused)] + struct Car { + id: Option, + name: String, + brand: String, + year: i32, + } } // TODO: implement compile-error test(s) -- perhaps with `trybuild`?