diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 374e2a3..95cc6b6 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -2,6 +2,12 @@ use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{parse_macro_input, Data, DeriveInput, Fields}; +// TODO: wrap functions in a trait + +// TODO: add sync function (set fields of self by querying db) + +// TODO: doc comments + #[proc_macro_derive(Table)] pub fn derive_table(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input as DeriveInput); @@ -59,19 +65,13 @@ pub fn derive_table(input: TokenStream) -> TokenStream { panic!("Structs annotated with `Table` require a primary key field `id: Option`."); } + let create_table_sql = format!( "CREATE TABLE IF NOT EXISTS {} (id INTEGER PRIMARY KEY AUTOINCREMENT, {});", table_name, column_names.join(", ") ); - let insert_sql = format!( - "INSERT INTO {} (id, {}) VALUES ({});", - table_name, - column_names.join(", "), - vec!["?"; field_names.len()].join(", ") - ); - let create_table_fn = quote! { pub fn create_table(conn: &rusqlite::Connection) -> rusqlite::Result<()> where #(#to_sql_trait_bounds),* @@ -81,11 +81,18 @@ pub fn derive_table(input: TokenStream) -> TokenStream { } }; + + let insert_sql = format!( + "INSERT INTO {} (id, {}) VALUES ({});", + table_name, + column_names.join(", "), + vec!["?"; field_names.len()].join(", ") + ); + let insert_fn = quote! { pub fn insert(&mut self, conn: &rusqlite::Connection) -> rusqlite::Result where #(#to_sql_trait_bounds),* { - println!(#insert_sql); conn.execute(#insert_sql, rusqlite::params![#(#field_accessors),*])?; let id = conn.last_insert_rowid(); self.id = Some(id); @@ -93,30 +100,66 @@ pub fn derive_table(input: TokenStream) -> TokenStream { } }; + + let upsert_fn = quote! { + pub fn upsert(&mut self, conn: &rusqlite::Connection) -> rusqlite::Result + where #(#to_sql_trait_bounds),* + { + match self.id { + None => self.insert(conn), + Some(id) => { + if !self.update(conn)? { + return self.insert(conn); + } + Ok(id) + }, + } + } + }; + + + let update_sql = format!( + "UPDATE OR IGNORE {} SET ({}) = ({}) WHERE id = ?1", + table_name, + field_names.iter().map(|id| id.to_string()).collect::>().join(", "), + (1..=field_names.len()).map(|i| format!("?{}", i)).collect::>().join(", "), + ); + + let update_fn = quote! { + pub fn update(&mut self, conn: &rusqlite::Connection) -> rusqlite::Result + where #(#to_sql_trait_bounds),* + { + if self.id.is_none() { + return Ok(false); + } + let updated_count = conn.execute(#update_sql, rusqlite::params![#(#field_accessors),*])?; + Ok(updated_count > 0) + } + }; + + let get_by_id_fn = quote! { pub fn get_by_id(conn: &rusqlite::Connection, id: i64) -> rusqlite::Result> where Self: Sized, #(#from_sql_trait_bounds),* { - let mut stmt = conn.prepare(&format!( - "SELECT * FROM {} WHERE id = ?", - #table_name - ))?; + let mut stmt = conn.prepare(&format!("SELECT * FROM {} WHERE id = ?", #table_name))?; let mut rows = stmt.query(rusqlite::params![id])?; if let Some(row) = rows.next()? { - Ok(Some(Self { - #(#field_getters),* - })) + Ok(Some(Self { #(#field_getters),* })) } else { Ok(None) } } }; + let delete_fn = quote! { - pub fn delete(&mut self, conn: &rusqlite::Connection) -> rusqlite::Result { + pub fn delete(&mut self, conn: &rusqlite::Connection) -> rusqlite::Result + where #(#to_sql_trait_bounds),* + { if self.id.is_none() { return Ok(false); } @@ -139,16 +182,29 @@ pub fn derive_table(input: TokenStream) -> TokenStream { } }; + + let drop_table_fn = quote! { + pub fn drop_table(conn: &rusqlite::Connection) -> rusqlite::Result<()> { + conn.execute(&format!("DROP TABLE {}", #table_name), [])?; + Ok(()) + } + }; + + let expanded = quote! { impl #struct_name { #create_table_fn #insert_fn + #upsert_fn + #update_fn #get_by_id_fn #delete_fn #delete_by_id_fn + #drop_table_fn } }; - // dbg!(expanded.to_string()); + // if you want to see the generated code: + // println!("{}", expanded.to_string()); TokenStream::from(expanded) } diff --git a/src/test/macros.rs b/src/test/macros.rs index 1fa15fd..44724eb 100644 --- a/src/test/macros.rs +++ b/src/test/macros.rs @@ -43,6 +43,7 @@ fn test_table_derive_macro() { } } } + let mut larry = Person { id: None, name: String::from("larry"), @@ -55,17 +56,39 @@ fn test_table_derive_macro() { Person::create_table(&conn).unwrap(); larry.insert(&conn).unwrap(); - let larry_id = larry.id.unwrap(); + let larry_id = larry.id.expect("After (mutable) insertion, id should not be None"); - let larry_copy = Person::get_by_id(&conn, larry_id).unwrap(); - assert_eq!(larry_copy, Some(larry.clone())); + larry.age += 1; + let updated_something = larry.update(&conn).expect("Updating should work"); + assert!(updated_something, "Should have updated a row"); - let deleted_something = larry.delete(&conn).unwrap(); + let larry_copy = Person::get_by_id(&conn, larry_id).expect("Querying a row should work"); + assert_eq!(larry_copy, Some(larry.clone()), "Retrieving inserted row should give an identical row"); + + let deleted_something = larry.delete(&conn).expect("Deletion should work"); // also works: `Person::delete_by_id(&conn, larry_id).unwrap();` - assert!(deleted_something); + assert!(deleted_something, "Should have deleted something"); - let deleted_larry = Person::get_by_id(&conn, larry_id).unwrap(); - assert_eq!(deleted_larry, None); + let deleted_larry = Person::get_by_id(&conn, larry_id).expect("Querying a deleted row should return Ok(None), not Err(_)"); + assert_eq!(deleted_larry, None, "Received row that should have been deleted"); + + let id = larry.upsert(&conn).expect("Upsertion (insert) should work"); + let larry_id = larry.id.expect("After (mutable) upsertion, id should not be None"); + + let larry_copy = Person::get_by_id(&conn, larry_id).expect("Querying a row should work"); + assert_eq!(id, larry_id, "Upsert should return correct id"); + assert_eq!(larry_copy, Some(larry.clone()), "Retrieving upserted row should give an identical row"); + + larry.age += 1; + let id = larry.upsert(&conn).expect("Upsertion (update) should work"); + let larry_id = larry.id.expect("After (mutable) upsertion, id should not be None"); + assert_eq!(id, larry_id, "Upsert should return correct id"); + + let larry_copy = Person::get_by_id(&conn, larry_id).expect("Querying a row should work"); + assert_eq!(larry_copy, Some(larry.clone()), "Retrieving upserted row should give an identical row"); + + Person::drop_table(&conn).expect("Dropping table should work"); + Person::drop_table(&conn).expect_err("Dropping previously dropped table should err"); }