@@ -4,7 +4,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
44use std:: sync:: Arc ;
55
66use parking_lot:: { Mutex , RwLock } ;
7- use rusqlite:: { DatabaseName , ErrorCode , OpenFlags , StatementStatus } ;
7+ use rusqlite:: { DatabaseName , ErrorCode , OpenFlags , StatementStatus , TransactionState } ;
88use sqld_libsql_bindings:: wal_hook:: { TransparentMethods , WalMethodsHook } ;
99use tokio:: sync:: { watch, Notify } ;
1010use tokio:: time:: { Duration , Instant } ;
@@ -144,7 +144,6 @@ where
144144 }
145145}
146146
147- #[ derive( Clone ) ]
148147pub struct LibSqlConnection < W : WalHook > {
149148 inner : Arc < Mutex < Connection < W > > > ,
150149}
@@ -160,6 +159,12 @@ impl<W: WalHook> std::fmt::Debug for LibSqlConnection<W> {
160159 }
161160}
162161
162+ impl < W : WalHook > Clone for LibSqlConnection < W > {
163+ fn clone ( & self ) -> Self {
164+ Self { inner : self . inner . clone ( ) }
165+ }
166+ }
167+
163168pub fn open_conn < W > (
164169 path : & Path ,
165170 wal_methods : & ' static WalMethodsHook < W > ,
@@ -219,6 +224,15 @@ where
219224 inner : Arc :: new ( Mutex :: new ( conn) ) ,
220225 } )
221226 }
227+
228+ pub fn txn_status ( & self ) -> crate :: Result < TxnStatus > {
229+ Ok ( self
230+ . inner
231+ . lock ( )
232+ . conn
233+ . transaction_state ( Some ( DatabaseName :: Main ) ) ?
234+ . into ( ) )
235+ }
222236}
223237
224238struct Connection < W : WalHook = TransparentMethods > {
@@ -351,6 +365,16 @@ unsafe extern "C" fn busy_handler<W: WalHook>(state: *mut c_void, _retries: c_in
351365 } )
352366}
353367
368+ impl From < TransactionState > for TxnStatus {
369+ fn from ( value : TransactionState ) -> Self {
370+ use TransactionState as Tx ;
371+ match value {
372+ Tx :: None => TxnStatus :: Init ,
373+ Tx :: Read | Tx :: Write => TxnStatus :: Txn ,
374+ _ => unreachable ! ( ) ,
375+ }
376+ }
377+ }
354378impl < W : WalHook > Connection < W > {
355379 fn new (
356380 path : & Path ,
@@ -405,7 +429,7 @@ impl<W: WalHook> Connection<W> {
405429 this : Arc < Mutex < Self > > ,
406430 pgm : Program ,
407431 mut builder : B ,
408- ) -> Result < ( B , TxnStatus ) > {
432+ ) -> Result < B > {
409433 use rusqlite:: TransactionState as Tx ;
410434
411435 let state = this. lock ( ) . state . clone ( ) ;
@@ -469,23 +493,18 @@ impl<W: WalHook> Connection<W> {
469493 results. push ( res) ;
470494 }
471495
472- let status = if matches ! (
473- this. lock( )
474- . conn
475- . transaction_state( Some ( DatabaseName :: Main ) ) ?,
476- Tx :: Read | Tx :: Write
477- ) {
478- TxnStatus :: Txn
479- } else {
480- TxnStatus :: Init
481- } ;
496+ let status = this
497+ . lock ( )
498+ . conn
499+ . transaction_state ( Some ( DatabaseName :: Main ) ) ?
500+ . into ( ) ;
482501
483502 builder. finish (
484503 * this. lock ( ) . current_frame_no_receiver . borrow_and_update ( ) ,
485504 status,
486505 ) ?;
487506
488- Ok ( ( builder, status ) )
507+ Ok ( builder)
489508 }
490509
491510 fn execute_step (
@@ -736,7 +755,7 @@ where
736755 auth : Authenticated ,
737756 builder : B ,
738757 _replication_index : Option < FrameNo > ,
739- ) -> Result < ( B , TxnStatus ) > {
758+ ) -> Result < B > {
740759 check_program_auth ( auth, & pgm) ?;
741760 let conn = self . inner . clone ( ) ;
742761 tokio:: task:: spawn_blocking ( move || Connection :: run ( conn, pgm, builder) )
@@ -828,7 +847,7 @@ mod test {
828847 fn test_libsql_conn_builder_driver ( ) {
829848 test_driver ( 1000 , |b| {
830849 let conn = setup_test_conn ( ) ;
831- Connection :: run ( conn, Program :: seq ( & [ "select * from test" ] ) , b) . map ( |x| x . 0 )
850+ Connection :: run ( conn, Program :: seq ( & [ "select * from test" ] ) , b)
832851 } )
833852 }
834853
@@ -852,23 +871,23 @@ mod test {
852871
853872 tokio:: time:: pause ( ) ;
854873 let conn = make_conn. make_connection ( ) . await . unwrap ( ) ;
855- let ( _builder, state ) = Connection :: run (
874+ let _builder = Connection :: run (
856875 conn. inner . clone ( ) ,
857876 Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
858877 TestBuilder :: default ( ) ,
859878 )
860879 . unwrap ( ) ;
861- assert_eq ! ( state , TxnStatus :: Txn ) ;
880+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Txn ) ;
862881
863882 tokio:: time:: advance ( TXN_TIMEOUT * 2 ) . await ;
864883
865- let ( builder, state ) = Connection :: run (
884+ let builder = Connection :: run (
866885 conn. inner . clone ( ) ,
867886 Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
868887 TestBuilder :: default ( ) ,
869888 )
870889 . unwrap ( ) ;
871- assert_eq ! ( state , TxnStatus :: Init ) ;
890+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Init ) ;
872891 assert ! ( matches!( builder. into_ret( ) [ 0 ] , Err ( Error :: LibSqlTxTimeout ) ) ) ;
873892 }
874893
@@ -896,13 +915,13 @@ mod test {
896915 for _ in 0 ..10 {
897916 let conn = make_conn. make_connection ( ) . await . unwrap ( ) ;
898917 set. spawn_blocking ( move || {
899- let ( builder, state ) = Connection :: run (
900- conn. inner ,
918+ let builder = Connection :: run (
919+ conn. inner . clone ( ) ,
901920 Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
902921 TestBuilder :: default ( ) ,
903922 )
904923 . unwrap ( ) ;
905- assert_eq ! ( state , TxnStatus :: Txn ) ;
924+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Txn ) ;
906925 assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
907926 } ) ;
908927 }
@@ -937,15 +956,15 @@ mod test {
937956
938957 let conn1 = make_conn. make_connection ( ) . await . unwrap ( ) ;
939958 tokio:: task:: spawn_blocking ( {
940- let conn = conn1. inner . clone ( ) ;
959+ let conn = conn1. clone ( ) ;
941960 move || {
942- let ( builder, state ) = Connection :: run (
943- conn,
961+ let builder = Connection :: run (
962+ conn. inner . clone ( ) ,
944963 Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
945964 TestBuilder :: default ( ) ,
946965 )
947966 . unwrap ( ) ;
948- assert_eq ! ( state , TxnStatus :: Txn ) ;
967+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Txn ) ;
949968 assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
950969 }
951970 } )
@@ -954,16 +973,16 @@ mod test {
954973
955974 let conn2 = make_conn. make_connection ( ) . await . unwrap ( ) ;
956975 let handle = tokio:: task:: spawn_blocking ( {
957- let conn = conn2. inner . clone ( ) ;
976+ let conn = conn2. clone ( ) ;
958977 move || {
959978 let before = Instant :: now ( ) ;
960- let ( builder, state ) = Connection :: run (
961- conn,
979+ let builder = Connection :: run (
980+ conn. inner . clone ( ) ,
962981 Program :: seq ( & [ "BEGIN IMMEDIATE" ] ) ,
963982 TestBuilder :: default ( ) ,
964983 )
965984 . unwrap ( ) ;
966- assert_eq ! ( state , TxnStatus :: Txn ) ;
985+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Txn ) ;
967986 assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
968987 before. elapsed ( )
969988 }
@@ -973,12 +992,12 @@ mod test {
973992 tokio:: time:: sleep ( wait_time) . await ;
974993
975994 tokio:: task:: spawn_blocking ( {
976- let conn = conn1. inner . clone ( ) ;
995+ let conn = conn1. clone ( ) ;
977996 move || {
978- let ( builder, state ) =
979- Connection :: run ( conn, Program :: seq ( & [ "COMMIT" ] ) , TestBuilder :: default ( ) )
997+ let builder =
998+ Connection :: run ( conn. inner . clone ( ) , Program :: seq ( & [ "COMMIT" ] ) , TestBuilder :: default ( ) )
980999 . unwrap ( ) ;
981- assert_eq ! ( state , TxnStatus :: Init ) ;
1000+ assert_eq ! ( conn . txn_status ( ) . unwrap ( ) , TxnStatus :: Init ) ;
9821001 assert ! ( builder. into_ret( ) [ 0 ] . is_ok( ) ) ;
9831002 }
9841003 } )
0 commit comments