diff --git a/payjoin-cli/src/app/config.rs b/payjoin-cli/src/app/config.rs index f0cd22a0b..f671f3de1 100644 --- a/payjoin-cli/src/app/config.rs +++ b/payjoin-cli/src/app/config.rs @@ -338,6 +338,10 @@ fn handle_subcommands(config: Builder, cli: &Cli) -> Result Ok(config), #[cfg(feature = "v2")] Commands::Fallback { .. } => Ok(config), + #[cfg(feature = "v2")] + Commands::Cancel { .. } => Ok(config), + #[cfg(feature = "v2")] + Commands::CancelWithoutBroadcast { .. } => Ok(config), } } diff --git a/payjoin-cli/src/app/mod.rs b/payjoin-cli/src/app/mod.rs index 48499bebe..0b1f8c246 100644 --- a/payjoin-cli/src/app/mod.rs +++ b/payjoin-cli/src/app/mod.rs @@ -32,6 +32,10 @@ pub trait App: Send + Sync { async fn history(&self) -> Result<()>; #[cfg(feature = "v2")] async fn fallback_sender(&self, session_id: SessionId) -> Result<()>; + #[cfg(feature = "v2")] + async fn cancel_receiver(&self, session_id: SessionId) -> Result<()>; + #[cfg(feature = "v2")] + async fn cancel_receiver_without_broadcast(&self, session_id: SessionId) -> Result<()>; fn create_original_psbt( &self, diff --git a/payjoin-cli/src/app/v1.rs b/payjoin-cli/src/app/v1.rs index 5d98ddc26..cc14c08c3 100644 --- a/payjoin-cli/src/app/v1.rs +++ b/payjoin-cli/src/app/v1.rs @@ -132,6 +132,19 @@ impl AppTrait for App { async fn fallback_sender(&self, _session_id: crate::db::v2::SessionId) -> Result<()> { anyhow::bail!("fallback is only supported for v2 (BIP77) sessions") } + + #[cfg(feature = "v2")] + async fn cancel_receiver(&self, _session_id: crate::db::v2::SessionId) -> Result<()> { + anyhow::bail!("receiver cancellation is only supported for v2 (BIP77) sessions") + } + + #[cfg(feature = "v2")] + async fn cancel_receiver_without_broadcast( + &self, + _session_id: crate::db::v2::SessionId, + ) -> Result<()> { + anyhow::bail!("receiver cancellation is only supported for v2 (BIP77) sessions") + } } impl App { diff --git a/payjoin-cli/src/app/v2/mod.rs b/payjoin-cli/src/app/v2/mod.rs index 82fb87dcc..62231a939 100644 --- a/payjoin-cli/src/app/v2/mod.rs +++ b/payjoin-cli/src/app/v2/mod.rs @@ -1,13 +1,14 @@ use std::fmt; +use std::io::{self, Write}; use std::sync::{Arc, Mutex}; use anyhow::{anyhow, Context, Result}; use payjoin::bitcoin::consensus::encode::serialize_hex; -use payjoin::bitcoin::{Amount, FeeRate}; +use payjoin::bitcoin::{Amount, FeeRate, Transaction}; use payjoin::persist::{OptionalTransitionOutcome, SessionPersister}; use payjoin::receive::v2::{ replay_event_log as replay_receiver_event_log, HasReplyableError, Initialized, - MaybeInputsOwned, MaybeInputsSeen, Monitor, OutputsUnknown, PayjoinProposal, + MaybeInputsOwned, MaybeInputsSeen, Monitor, OutputsUnknown, PayjoinProposal, PendingFallback, ProvisionalProposal, ReceiveSession, Receiver, ReceiverBuilder, SessionOutcome as ReceiverSessionOutcome, UncheckedOriginalPayload, WantsFeeRange, WantsInputs, WantsOutputs, @@ -77,6 +78,7 @@ impl StatusText for ReceiveSession { ReceiveSession::HasReplyableError(_) => "Session failure, waiting to post error response", ReceiveSession::Monitor(_) => "Monitoring payjoin proposal", + ReceiveSession::PendingFallback(_) => "Original transaction awaiting fallback decision", ReceiveSession::Closed(session_outcome) => match session_outcome { ReceiverSessionOutcome::Failure => "Session failure", ReceiverSessionOutcome::Success(_) => "Session success, Payjoin proposal was broadcasted", @@ -109,6 +111,12 @@ impl Role { } } +#[derive(Clone, Copy)] +enum FallbackHandling { + Prompt, + CloseWithoutBroadcast, +} + struct SessionHistoryRow { session_id: SessionId, role: Role, @@ -315,11 +323,22 @@ impl AppTrait for App { let self_clone = self.clone(); let recv_persister = ReceiverPersister::from_id(self.db.clone(), session_id.clone()); match replay_receiver_event_log(&recv_persister) { - Ok((receiver_state, _)) => { - tasks.push(tokio::spawn(async move { - self_clone.process_receiver_session(receiver_state, &recv_persister).await - })); + Ok((ReceiveSession::PendingFallback(pending_fallback), _)) => { + if let Err(e) = self.complete_pending_fallback( + pending_fallback, + &recv_persister, + FallbackHandling::Prompt, + ) { + tracing::error!( + "An error {:?} occurred while handling receiver session {}", + e, + session_id + ); + } } + Ok((receiver_state, _)) => tasks.push(tokio::spawn(async move { + self_clone.process_receiver_session(receiver_state, &recv_persister).await + })), Err(e) => { tracing::error!("An error {:?} occurred while replaying receiver session", e); Self::close_failed_session(&recv_persister, &session_id, "receiver"); @@ -509,6 +528,15 @@ impl AppTrait for App { } Ok(()) } + + async fn cancel_receiver(&self, session_id: SessionId) -> Result<()> { + self.cancel_receiver_with_handling(session_id, FallbackHandling::Prompt).await + } + + async fn cancel_receiver_without_broadcast(&self, session_id: SessionId) -> Result<()> { + self.cancel_receiver_with_handling(session_id, FallbackHandling::CloseWithoutBroadcast) + .await + } } impl App { @@ -523,6 +551,95 @@ impl App { } } + async fn cancel_receiver_with_handling( + &self, + session_id: SessionId, + handling: FallbackHandling, + ) -> Result<()> { + let persister = ReceiverPersister::from_id(self.db.clone(), session_id.clone()); + let (session, _) = replay_receiver_event_log(&persister)?; + + macro_rules! cancel_to_pending_fallback { + ($state:expr) => {{ + let pending_fallback = $state.cancel().save(&persister)?; + self.complete_pending_fallback(pending_fallback, &persister, handling) + }}; + } + + match session { + ReceiveSession::Initialized(state) => { + state.cancel().save(&persister)?; + println!("Receiver session {session_id} cancelled."); + Ok(()) + } + ReceiveSession::UncheckedOriginalPayload(state) => { + state.cancel().save(&persister)?; + println!("Receiver session {session_id} cancelled."); + Ok(()) + } + ReceiveSession::MaybeInputsOwned(state) => cancel_to_pending_fallback!(state), + ReceiveSession::MaybeInputsSeen(state) => cancel_to_pending_fallback!(state), + ReceiveSession::OutputsUnknown(state) => cancel_to_pending_fallback!(state), + ReceiveSession::WantsOutputs(state) => cancel_to_pending_fallback!(state), + ReceiveSession::WantsInputs(state) => cancel_to_pending_fallback!(state), + ReceiveSession::WantsFeeRange(state) => cancel_to_pending_fallback!(state), + ReceiveSession::ProvisionalProposal(state) => cancel_to_pending_fallback!(state), + ReceiveSession::PayjoinProposal(state) => cancel_to_pending_fallback!(state), + ReceiveSession::Monitor(state) => cancel_to_pending_fallback!(state), + ReceiveSession::HasReplyableError(state) => match state.cancel().save(&persister)? { + Some(pending_fallback) => + self.complete_pending_fallback(pending_fallback, &persister, handling), + None => { + println!("Receiver session {session_id} cancelled."); + Ok(()) + } + }, + ReceiveSession::PendingFallback(pending_fallback) => + self.complete_pending_fallback(pending_fallback, &persister, handling), + ReceiveSession::Closed(outcome) => { + println!("Receiver session {session_id} is already closed: {outcome:?}"); + Ok(()) + } + } + } + + fn complete_pending_fallback( + &self, + pending_fallback: Receiver, + persister: &ReceiverPersister, + handling: FallbackHandling, + ) -> Result<()> { + let should_broadcast = match handling { + FallbackHandling::Prompt => + self.prompt_broadcast_fallback(pending_fallback.fallback_tx())?, + FallbackHandling::CloseWithoutBroadcast => false, + }; + + if should_broadcast { + let txid = self.wallet().broadcast_tx(pending_fallback.fallback_tx())?; + println!("Broadcasted fallback transaction txid: {txid}"); + } else { + println!("Closing receiver session without broadcasting the fallback transaction."); + } + + pending_fallback.close().save(persister)?; + Ok(()) + } + + fn prompt_broadcast_fallback(&self, fallback_tx: &Transaction) -> Result { + println!( + "Original transaction is pending fallback handling. TXID: {}", + fallback_tx.compute_txid() + ); + print!("Broadcast the original transaction before closing? [Y/n]: "); + io::stdout().flush()?; + + let mut answer = String::new(); + io::stdin().read_line(&mut answer)?; + let answer = answer.trim().to_ascii_lowercase(); + Ok(!matches!(answer.as_str(), "n" | "no" | "c" | "close")) + } + async fn process_sender_session( &self, session: SendSession, @@ -656,6 +773,14 @@ impl App { self.handle_error(error, persister).await, ReceiveSession::Monitor(proposal) => self.monitor_payjoin_proposal(proposal, persister).await, + ReceiveSession::PendingFallback(pending_fallback) => { + self.complete_pending_fallback( + pending_fallback, + persister, + FallbackHandling::Prompt, + )?; + Ok(()) + } ReceiveSession::Closed(_) => return Err(anyhow!("Session closed")), } }; @@ -807,7 +932,23 @@ impl App { .map_err(|e| anyhow!("v2 req extraction failed {}", e))?; let res = self.post_request(req).await?; let payjoin_psbt = proposal.psbt().clone(); - let session = proposal.process_response(&res.bytes().await?, ohttp_ctx).save(persister)?; + let session = + match proposal.process_response(&res.bytes().await?, ohttp_ctx).save(persister) { + Ok(session) => session, + Err(e) => { + let message = e.to_string(); + if let Some(pending_fallback) = e.error_state() { + println!("Payjoin proposal post failed: {message}"); + self.complete_pending_fallback( + pending_fallback, + persister, + FallbackHandling::Prompt, + )?; + return Ok(()); + } + return Err(anyhow!("Failed to process payjoin proposal response: {message}")); + } + }; println!( "Response successful. Watch mempool for successful Payjoin. TXID: {}", payjoin_psbt.extract_tx_unchecked_fee_rate().compute_txid() @@ -901,11 +1042,15 @@ impl App { Err(e) => return Err(anyhow!("Failed to get error response bytes: {}", e)), }; - if let Err(e) = session.process_error_response(&err_bytes, err_ctx).save(persister) { - return Err(anyhow!("Failed to process error response: {}", e)); + match session.process_error_response(&err_bytes, err_ctx).save(persister) { + Ok(Some(pending_fallback)) => self.complete_pending_fallback( + pending_fallback, + persister, + FallbackHandling::Prompt, + ), + Ok(None) => Ok(()), + Err(e) => Err(anyhow!("Failed to process error response: {}", e)), } - - Ok(()) } async fn post_request(&self, req: payjoin::Request) -> Result { diff --git a/payjoin-cli/src/cli/mod.rs b/payjoin-cli/src/cli/mod.rs index 7cef2551b..17dc70919 100644 --- a/payjoin-cli/src/cli/mod.rs +++ b/payjoin-cli/src/cli/mod.rs @@ -139,6 +139,21 @@ pub enum Commands { #[arg(required = true)] session_id: i64, }, + #[cfg(feature = "v2")] + /// Cancel a receiver session and prompt for fallback handling + Cancel { + /// The receiver session ID to cancel + #[arg(required = true)] + session_id: i64, + }, + #[cfg(feature = "v2")] + /// Cancel a receiver session without broadcasting the fallback transaction + #[command(name = "cancel-without-broadcast")] + CancelWithoutBroadcast { + /// The receiver session ID to cancel + #[arg(required = true)] + session_id: i64, + }, } pub fn parse_amount_in_sat(s: &str) -> Result { diff --git a/payjoin-cli/src/main.rs b/payjoin-cli/src/main.rs index 6b2b038f4..697ac6c52 100644 --- a/payjoin-cli/src/main.rs +++ b/payjoin-cli/src/main.rs @@ -82,6 +82,14 @@ async fn main() -> Result<()> { Commands::Fallback { session_id } => { app.fallback_sender(SessionId(*session_id)).await?; } + #[cfg(feature = "v2")] + Commands::Cancel { session_id } => { + app.cancel_receiver(SessionId(*session_id)).await?; + } + #[cfg(feature = "v2")] + Commands::CancelWithoutBroadcast { session_id } => { + app.cancel_receiver_without_broadcast(SessionId(*session_id)).await?; + } }; Ok(()) diff --git a/payjoin-ffi/javascript/wasm-manifest-patch.toml b/payjoin-ffi/javascript/wasm-manifest-patch.toml index 3717d63dd..fbc89c3c6 100644 --- a/payjoin-ffi/javascript/wasm-manifest-patch.toml +++ b/payjoin-ffi/javascript/wasm-manifest-patch.toml @@ -20,3 +20,6 @@ features = ["wasm-unstable-single-threaded"] payjoin = { path = "../../../../payjoin" } payjoin-mailroom = { path = "../../../../payjoin-mailroom" } payjoin-test-utils = { path = "../../../../payjoin-test-utils" } + +[patch."https://github.com/payjoin/rust-payjoin.git"] +payjoin = { path = "../../../../payjoin" } diff --git a/payjoin-ffi/src/receive/mod.rs b/payjoin-ffi/src/receive/mod.rs index c0b2a1079..75e6b5aad 100644 --- a/payjoin-ffi/src/receive/mod.rs +++ b/payjoin-ffi/src/receive/mod.rs @@ -65,87 +65,156 @@ macro_rules! impl_save_for_transition { }; } -/// A terminal transition produced by cancelling a receiver session. +enum ReceiverCancelTransition { + Terminal(payjoin::persist::TerminalTransition), + PendingFallback( + payjoin::persist::NextStateTransition< + payjoin::receive::v2::SessionEvent, + payjoin::receive::v2::Receiver, + >, + ), + MaybePendingFallback( + payjoin::persist::MaybeTerminalTransition< + payjoin::receive::v2::SessionEvent, + payjoin::receive::v2::Receiver, + >, + ), +} + +/// A transition produced by cancelling a receiver session. #[derive(uniffi::Object)] pub struct CancelTransition { - transition: RwLock< - Option< - payjoin::persist::TerminalTransition< - payjoin::receive::v2::SessionEvent, - Option, - >, - >, - >, + transition: RwLock>, } #[uniffi::export] impl CancelTransition { - /// Persist the cancellation and return the fallback transaction if available. - /// - /// The fallback transaction is the consensus-encoded raw transaction bytes, - /// or `None` if the session was cancelled before the sender's original - /// proposal was received. + /// Persist the cancellation and return pending fallback handling if needed. pub fn save( &self, persister: Arc, - ) -> Result>, ReceiverPersistedError> { + ) -> Result>, ReceiverPersistedError> { let adapter = CallbackPersisterAdapter::new(persister); let mut inner = self.transition.write().expect("Lock should not be poisoned"); let value = inner.take().expect("Already saved or moved"); - let fallback = value - .save(&adapter) - .map_err(|e| ReceiverPersistedError::from(ImplementationError::new(e)))?; - Ok(fallback.map(|tx| payjoin::bitcoin::consensus::serialize(&tx))) + match value { + ReceiverCancelTransition::Terminal(transition) => { + transition + .save(&adapter) + .map_err(|e| ReceiverPersistedError::from(ImplementationError::new(e)))?; + Ok(None) + } + ReceiverCancelTransition::PendingFallback(transition) => { + let pending_fallback = transition + .save(&adapter) + .map_err(|e| ReceiverPersistedError::from(ImplementationError::new(e)))?; + Ok(Some(Arc::new(pending_fallback.into()))) + } + ReceiverCancelTransition::MaybePendingFallback(transition) => { + let pending_fallback = transition + .save(&adapter) + .map_err(|e| ReceiverPersistedError::from(ImplementationError::new(e)))?; + Ok(pending_fallback.map(|pending_fallback| Arc::new(pending_fallback.into()))) + } + } } pub async fn save_async( &self, persister: Arc, - ) -> Result>, ReceiverPersistedError> { + ) -> Result>, ReceiverPersistedError> { let adapter = AsyncCallbackPersisterAdapter::new(persister); let value = { let mut inner = self.transition.write().expect("Lock should not be poisoned"); inner.take().expect("Already saved or moved") }; - let fallback = value - .save_async(&adapter) - .await - .map_err(|e| ReceiverPersistedError::from(ImplementationError::new(e)))?; - Ok(fallback.map(|tx| payjoin::bitcoin::consensus::serialize(&tx))) + match value { + ReceiverCancelTransition::Terminal(transition) => { + transition + .save_async(&adapter) + .await + .map_err(|e| ReceiverPersistedError::from(ImplementationError::new(e)))?; + Ok(None) + } + ReceiverCancelTransition::PendingFallback(transition) => { + let pending_fallback = transition + .save_async(&adapter) + .await + .map_err(|e| ReceiverPersistedError::from(ImplementationError::new(e)))?; + Ok(Some(Arc::new(pending_fallback.into()))) + } + ReceiverCancelTransition::MaybePendingFallback(transition) => { + let pending_fallback = transition + .save_async(&adapter) + .await + .map_err(|e| ReceiverPersistedError::from(ImplementationError::new(e)))?; + Ok(pending_fallback.map(|pending_fallback| Arc::new(pending_fallback.into()))) + } + } } } -macro_rules! impl_cancel_for_receiver { +macro_rules! impl_terminal_cancel_for_receiver { ($ty:ident) => { #[uniffi::export] impl $ty { /// Cancel the Payjoin session immediately. - /// - /// Returns a [`CancelTransition`] that, once persisted, yields the fallback - /// transaction when applicable. The fallback transaction is the sender's original - /// transaction that should be broadcast to complete the payment without Payjoin. - /// - /// This is a terminal transition — the session cannot be used after cancellation. pub fn cancel(&self) -> CancelTransition { let transition = self.0.clone().cancel(); - CancelTransition { transition: RwLock::new(Some(transition)) } + CancelTransition { + transition: RwLock::new(Some(ReceiverCancelTransition::Terminal(transition))), + } + } + } + }; +} + +macro_rules! impl_pending_fallback_cancel_for_receiver { + ($ty:ident) => { + #[uniffi::export] + impl $ty { + /// Cancel the Payjoin session and return pending fallback handling. + pub fn cancel(&self) -> CancelTransition { + let transition = self.0.clone().cancel(); + CancelTransition { + transition: RwLock::new(Some(ReceiverCancelTransition::PendingFallback( + transition, + ))), + } } } }; } -impl_cancel_for_receiver!(Initialized); -impl_cancel_for_receiver!(UncheckedOriginalPayload); -impl_cancel_for_receiver!(MaybeInputsOwned); -impl_cancel_for_receiver!(MaybeInputsSeen); -impl_cancel_for_receiver!(OutputsUnknown); -impl_cancel_for_receiver!(WantsOutputs); -impl_cancel_for_receiver!(WantsInputs); -impl_cancel_for_receiver!(WantsFeeRange); -impl_cancel_for_receiver!(ProvisionalProposal); -impl_cancel_for_receiver!(PayjoinProposal); -impl_cancel_for_receiver!(HasReplyableError); -impl_cancel_for_receiver!(Monitor); +macro_rules! impl_maybe_pending_fallback_cancel_for_receiver { + ($ty:ident) => { + #[uniffi::export] + impl $ty { + /// Cancel the Payjoin session. + pub fn cancel(&self) -> CancelTransition { + let transition = self.0.clone().cancel(); + CancelTransition { + transition: RwLock::new(Some(ReceiverCancelTransition::MaybePendingFallback( + transition, + ))), + } + } + } + }; +} + +impl_terminal_cancel_for_receiver!(Initialized); +impl_terminal_cancel_for_receiver!(UncheckedOriginalPayload); +impl_pending_fallback_cancel_for_receiver!(MaybeInputsOwned); +impl_pending_fallback_cancel_for_receiver!(MaybeInputsSeen); +impl_pending_fallback_cancel_for_receiver!(OutputsUnknown); +impl_pending_fallback_cancel_for_receiver!(WantsOutputs); +impl_pending_fallback_cancel_for_receiver!(WantsInputs); +impl_pending_fallback_cancel_for_receiver!(WantsFeeRange); +impl_pending_fallback_cancel_for_receiver!(ProvisionalProposal); +impl_pending_fallback_cancel_for_receiver!(PayjoinProposal); +impl_maybe_pending_fallback_cancel_for_receiver!(HasReplyableError); +impl_pending_fallback_cancel_for_receiver!(Monitor); #[derive(Debug, Clone, uniffi::Object)] pub struct ReceiverSessionEvent(payjoin::receive::v2::SessionEvent); @@ -198,6 +267,7 @@ pub enum ReceiveSession { PayjoinProposal { inner: Arc }, HasReplyableError { inner: Arc }, Monitor { inner: Arc }, + PendingFallback { inner: Arc }, Closed { inner: Arc }, } @@ -228,6 +298,8 @@ impl From for ReceiveSession { ReceiveSession::HasReplyableError(inner) => Self::HasReplyableError { inner: Arc::new(inner.into()) }, ReceiveSession::Monitor(inner) => Self::Monitor { inner: Arc::new(inner.into()) }, + ReceiveSession::PendingFallback(inner) => + Self::PendingFallback { inner: Arc::new(inner.into()) }, ReceiveSession::Closed(session_outcome) => Self::Closed { inner: Arc::new(session_outcome.into()) }, } @@ -1208,6 +1280,7 @@ pub struct PayjoinProposalTransition( payjoin::receive::v2::SessionEvent, payjoin::receive::v2::Receiver, payjoin::receive::ProtocolError, + payjoin::receive::v2::Receiver, >, >, >, @@ -1287,13 +1360,14 @@ impl From, payjoin::receive::ProtocolError, >, >, @@ -1306,28 +1380,29 @@ impl HasReplyableErrorTransition { pub fn save( &self, persister: Arc, - ) -> Result<(), ReceiverPersistedError> { + ) -> Result>, ReceiverPersistedError> { let adapter = CallbackPersisterAdapter::new(persister); let mut inner = self.0.write().expect("Lock should not be poisoned"); let value = inner.take().expect("Already saved or moved"); - value.save(&adapter).map_err(ReceiverPersistedError::from)?; - Ok(()) + let pending_fallback = value.save(&adapter).map_err(ReceiverPersistedError::from)?; + Ok(pending_fallback.map(|pending_fallback| Arc::new(pending_fallback.into()))) } pub async fn save_async( &self, persister: Arc, - ) -> Result<(), ReceiverPersistedError> { + ) -> Result>, ReceiverPersistedError> { let adapter = AsyncCallbackPersisterAdapter::new(persister); let value = { let mut inner = self.0.write().expect("Lock should not be poisoned"); inner.take().expect("Already saved or moved") }; - value.save_async(&adapter).await.map_err(ReceiverPersistedError::from)?; - Ok(()) + let pending_fallback = + value.save_async(&adapter).await.map_err(ReceiverPersistedError::from)?; + Ok(pending_fallback.map(|pending_fallback| Arc::new(pending_fallback.into()))) } } @@ -1432,6 +1507,81 @@ impl Monitor { } } +#[derive(uniffi::Object)] +#[allow(clippy::type_complexity)] +pub struct PendingFallbackTransition( + Arc< + RwLock< + Option>, + >, + >, +); + +#[uniffi::export] +impl PendingFallbackTransition { + pub fn save( + &self, + persister: Arc, + ) -> Result<(), ReceiverPersistedError> { + let adapter = CallbackPersisterAdapter::new(persister); + let mut inner = self.0.write().expect("Lock should not be poisoned"); + + let value = inner.take().expect("Already saved or moved"); + + value + .save(&adapter) + .map_err(|e| ReceiverPersistedError::from(ImplementationError::new(e)))?; + Ok(()) + } + + pub async fn save_async( + &self, + persister: Arc, + ) -> Result<(), ReceiverPersistedError> { + let adapter = AsyncCallbackPersisterAdapter::new(persister); + let value = { + let mut inner = self.0.write().expect("Lock should not be poisoned"); + inner.take().expect("Already saved or moved") + }; + + value + .save_async(&adapter) + .await + .map_err(|e| ReceiverPersistedError::from(ImplementationError::new(e)))?; + Ok(()) + } +} + +#[derive(Clone, uniffi::Object)] +pub struct PendingFallback( + Arc>, +); + +impl From> + for PendingFallback +{ + fn from(value: payjoin::receive::v2::Receiver) -> Self { + Self(Arc::new(value)) + } +} + +impl From + for payjoin::receive::v2::Receiver +{ + fn from(value: PendingFallback) -> Self { value.0.as_ref().clone() } +} + +#[uniffi::export] +impl PendingFallback { + pub fn fallback_tx(&self) -> Vec { + payjoin::bitcoin::consensus::encode::serialize(self.0.fallback_tx()) + } + + pub fn close(&self) -> PendingFallbackTransition { + PendingFallbackTransition(Arc::new(RwLock::new(Some(self.0.as_ref().clone().close())))) + } +} + /// Session persister that should save and load events as JSON strings. #[uniffi::export(with_foreign)] pub trait JsonReceiverSessionPersister: Send + Sync { diff --git a/payjoin/src/core/persist.rs b/payjoin/src/core/persist.rs index e665d86db..fe950bc84 100644 --- a/payjoin/src/core/persist.rs +++ b/payjoin/src/core/persist.rs @@ -361,102 +361,161 @@ where } } -/// A transition that can result in the completion of a state machine or a transient error -/// Fatal errors cannot occur in this transition. -pub struct MaybeSuccessTransition( - Result, Rejection>, -); +/// A transition that always results in a state transition. +pub struct NextStateTransition(AcceptNextState); -impl MaybeSuccessTransition -where - Err: std::error::Error, -{ - pub(crate) fn success(event: Event, success_value: SuccessValue) -> Self { - MaybeSuccessTransition(Ok(AcceptNextState(event, success_value))) +impl NextStateTransition { + pub(crate) fn success(event: Event, next_state: NextState) -> Self { + NextStateTransition(AcceptNextState(event, next_state)) } - pub(crate) fn transient(error: Err) -> Self { - MaybeSuccessTransition(Err(Rejection::transient(error))) + pub(crate) fn deconstruct(self) -> (PersistActions, NextState) { + let AcceptNextState(event, next_state) = self.0; + (PersistActions::Save(event), next_state) } - pub(crate) fn fatal(event: Event, error: Err) -> Self { - MaybeSuccessTransition(Err(Rejection::fatal(event, error))) + pub fn save

(self, persister: &P) -> Result + where + P: SessionPersister, + { + let (actions, next_state) = self.deconstruct(); + actions.execute(persister)?; + Ok(next_state) } - pub(crate) fn deconstruct( - self, - ) -> (PersistActions, Result>) { + pub async fn save_async

(self, persister: &P) -> Result + where + P: AsyncSessionPersister, + NextState: Send, + Event: Send, + { + let (actions, next_state) = self.deconstruct(); + actions.execute_async(persister).await?; + Ok(next_state) + } +} + +/// A transition that either advances to a live state or terminates the session. +/// +/// No error path exists. Both outcomes are successful from the protocol's point +/// of view. The choice is determined by the source typestate's internal data, +/// not by the caller. +pub struct MaybeTerminalTransition(MaybeTerminalOutcome); + +impl MaybeTerminalTransition { + pub(crate) fn advance(event: Event, next_state: NextState) -> Self { + Self(MaybeTerminalOutcome::Advance(AcceptNextState(event, next_state))) + } + + pub(crate) fn terminate(event: Event) -> Self { Self(MaybeTerminalOutcome::Terminate(event)) } + + pub(crate) fn deconstruct(self) -> (PersistActions, Option) { match self.0 { - Ok(AcceptNextState(event, success_value)) => - (PersistActions::SaveAndClose(event), Ok(success_value)), - Err(Rejection::Transient(RejectTransient(error))) => - (PersistActions::NoOp, Err(ApiError::Transient(error))), - Err(Rejection::Fatal(RejectFatal(event, error))) => - (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), - Err(Rejection::ReplyableError(RejectReplyableError(event, _, error))) => - (PersistActions::Save(event), Err(ApiError::Fatal(error))), + MaybeTerminalOutcome::Advance(AcceptNextState(event, next_state)) => + (PersistActions::Save(event), Some(next_state)), + MaybeTerminalOutcome::Terminate(event) => (PersistActions::SaveAndClose(event), None), } } - pub fn save

( - self, - persister: &P, - ) -> Result> + pub fn save

(self, persister: &P) -> Result, P::InternalStorageError> where P: SessionPersister, { - let (actions, outcome) = self.deconstruct(); - actions.execute(persister).map_err(InternalPersistedError::Storage)?; - Ok(outcome.map_err(InternalPersistedError::Api)?) + let (actions, next_state) = self.deconstruct(); + actions.execute(persister)?; + Ok(next_state) } pub async fn save_async

( self, persister: &P, - ) -> Result> + ) -> Result, P::InternalStorageError> where P: AsyncSessionPersister, - Err: Send, - SuccessValue: Send, + NextState: Send, Event: Send, { - let (actions, outcome) = self.deconstruct(); - actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; - Ok(outcome.map_err(InternalPersistedError::Api)?) + let (actions, next_state) = self.deconstruct(); + actions.execute_async(persister).await?; + Ok(next_state) } } -/// A transition that always results in a state transition. -pub struct NextStateTransition(AcceptNextState); +/// A transition that can either advance, terminate, or fail transiently. +/// +/// Fatal outcomes still persist an event. When the fatal outcome advances, the +/// saved event keeps the session live for replay while the caller receives the +/// fatal protocol error. +pub struct MaybeTerminalSuccessTransition( + MaybeTerminalSuccessOutcome, +); -impl NextStateTransition { - pub(crate) fn success(event: Event, next_state: NextState) -> Self { - NextStateTransition(AcceptNextState(event, next_state)) +impl MaybeTerminalSuccessTransition +where + Err: std::error::Error, +{ + pub(crate) fn advance(event: Event, next_state: NextState) -> Self { + Self(MaybeTerminalSuccessOutcome::Advance(AcceptNextState(event, next_state))) } - pub(crate) fn deconstruct(self) -> (PersistActions, NextState) { - let AcceptNextState(event, next_state) = self.0; - (PersistActions::Save(event), next_state) + pub(crate) fn terminate(event: Event) -> Self { + Self(MaybeTerminalSuccessOutcome::Terminate(event)) } - pub fn save

(self, persister: &P) -> Result + pub(crate) fn fatal_advance(event: Event, next_state: NextState, error: Err) -> Self { + Self(MaybeTerminalSuccessOutcome::FatalAdvance(event, next_state, error)) + } + + pub(crate) fn fatal_terminate(event: Event, error: Err) -> Self { + Self(MaybeTerminalSuccessOutcome::FatalTerminate(event, error)) + } + + pub(crate) fn transient(error: Err) -> Self { + Self(MaybeTerminalSuccessOutcome::Transient(error)) + } + + pub(crate) fn deconstruct( + self, + ) -> (PersistActions, Result, ApiError>) { + match self.0 { + MaybeTerminalSuccessOutcome::Advance(AcceptNextState(event, next_state)) => + (PersistActions::Save(event), Ok(Some(next_state))), + MaybeTerminalSuccessOutcome::Terminate(event) => + (PersistActions::SaveAndClose(event), Ok(None)), + MaybeTerminalSuccessOutcome::FatalAdvance(event, _next_state, error) => + (PersistActions::Save(event), Err(ApiError::Fatal(error))), + MaybeTerminalSuccessOutcome::FatalTerminate(event, error) => + (PersistActions::SaveAndClose(event), Err(ApiError::Fatal(error))), + MaybeTerminalSuccessOutcome::Transient(error) => + (PersistActions::NoOp, Err(ApiError::Transient(error))), + } + } + + pub fn save

( + self, + persister: &P, + ) -> Result, PersistedError> where P: SessionPersister, { - let (actions, next_state) = self.deconstruct(); - actions.execute(persister)?; - Ok(next_state) + let (actions, outcome) = self.deconstruct(); + actions.execute(persister).map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } - pub async fn save_async

(self, persister: &P) -> Result + pub async fn save_async

( + self, + persister: &P, + ) -> Result, PersistedError> where P: AsyncSessionPersister, + Err: Send, NextState: Send, Event: Send, { - let (actions, next_state) = self.deconstruct(); - actions.execute_async(persister).await?; - Ok(next_state) + let (actions, outcome) = self.deconstruct(); + actions.execute_async(persister).await.map_err(InternalPersistedError::Storage)?; + Ok(outcome.map_err(InternalPersistedError::Api)?) } } @@ -576,6 +635,19 @@ where /// Wrapper that marks the progression of a state machine pub struct AcceptNextState(Event, NextState); +enum MaybeTerminalOutcome { + Advance(AcceptNextState), + Terminate(Event), +} + +enum MaybeTerminalSuccessOutcome { + Advance(AcceptNextState), + Terminate(Event), + FatalAdvance(Event, NextState, Err), + FatalTerminate(Event, Err), + Transient(Err), +} + /// Wrapper that represents either a successful state transition or indicates no state change occurred pub enum AcceptOptionalTransition { /// A state transition that was successful and returned session event to be persisted @@ -1103,44 +1175,119 @@ mod tests { } #[tokio::test] - async fn test_maybe_success_transition() { + async fn test_maybe_terminal_transition() { let event = InMemoryTestEvent("foo".to_string()); - let error_event = InMemoryTestEvent("error event".to_string()); + let close_event = InMemoryTestEvent("close".to_string()); + let next_state = "Next state".to_string(); let test_cases = vec![ TestCase { make_transition: Box::new({ let event = event.clone(); - move || MaybeSuccessTransition::success(event.clone(), ()) + let next_state = next_state.clone(); + move || MaybeTerminalTransition::advance(event.clone(), next_state.clone()) }), expected_result: ExpectedResult { events: vec![event.clone()], + is_closed: false, + error: None, + success: Some(Some(next_state.clone())), + }, + }, + TestCase { + make_transition: Box::new({ + let close_event = close_event.clone(); + move || { + MaybeTerminalTransition::<_, InMemoryTestState>::terminate( + close_event.clone(), + ) + } + }), + expected_result: ExpectedResult { + events: vec![close_event.clone()], is_closed: true, error: None, - success: Some(()), + success: Some(None), }, }, + ]; + + run_test_cases!(test_cases); + } + + #[tokio::test] + async fn test_maybe_terminal_success_transition() { + let event = InMemoryTestEvent("foo".to_string()); + let close_event = InMemoryTestEvent("close".to_string()); + let fatal_event = InMemoryTestEvent("fatal".to_string()); + let fatal_close_event = InMemoryTestEvent("fatal close".to_string()); + let next_state = "Next state".to_string(); + + let test_cases = vec![ TestCase { - make_transition: Box::new(|| { - MaybeSuccessTransition::transient(InMemoryTestError {}) + make_transition: Box::new({ + let event = event.clone(); + let next_state = next_state.clone(); + move || { + MaybeTerminalSuccessTransition::advance(event.clone(), next_state.clone()) + } }), expected_result: ExpectedResult { - events: vec![], + events: vec![event.clone()], + is_closed: false, + error: None, + success: Some(Some(next_state.clone())), + }, + }, + TestCase { + make_transition: Box::new({ + let close_event = close_event.clone(); + move || { + MaybeTerminalSuccessTransition::<_, InMemoryTestState, InMemoryTestError>::terminate( + close_event.clone(), + ) + } + }), + expected_result: ExpectedResult { + events: vec![close_event.clone()], + is_closed: true, + error: None, + success: Some(None), + }, + }, + TestCase { + make_transition: Box::new({ + let fatal_event = fatal_event.clone(); + let next_state = next_state.clone(); + move || { + MaybeTerminalSuccessTransition::fatal_advance( + fatal_event.clone(), + next_state.clone(), + InMemoryTestError {}, + ) + } + }), + expected_result: ExpectedResult { + events: vec![fatal_event.clone()], is_closed: false, error: Some( - InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) - .into(), + InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), ), success: None, }, }, TestCase { make_transition: Box::new({ - let error_event = error_event.clone(); - move || MaybeSuccessTransition::fatal(error_event.clone(), InMemoryTestError {}) + let fatal_close_event = fatal_close_event.clone(); + move || { + MaybeTerminalSuccessTransition::<_, InMemoryTestState, InMemoryTestError>::fatal_terminate( + fatal_close_event.clone(), + InMemoryTestError {}, + ) + } }), expected_result: ExpectedResult { - events: vec![error_event.clone()], + events: vec![fatal_close_event.clone()], is_closed: true, error: Some( InternalPersistedError::Api(ApiError::Fatal(InMemoryTestError {})).into(), @@ -1148,6 +1295,24 @@ mod tests { success: None, }, }, + TestCase { + make_transition: Box::new(|| { + MaybeTerminalSuccessTransition::< + InMemoryTestEvent, + InMemoryTestState, + InMemoryTestError, + >::transient(InMemoryTestError {}) + }), + expected_result: ExpectedResult { + events: vec![], + is_closed: false, + error: Some( + InternalPersistedError::Api(ApiError::Transient(InMemoryTestError {})) + .into(), + ), + success: None, + }, + }, ]; run_test_cases!(test_cases); diff --git a/payjoin/src/core/receive/v2/mod.rs b/payjoin/src/core/receive/v2/mod.rs index 4b57c6abd..6ca938862 100644 --- a/payjoin/src/core/receive/v2/mod.rs +++ b/payjoin/src/core/receive/v2/mod.rs @@ -42,6 +42,7 @@ pub use session::{ #[cfg(target_arch = "wasm32")] use web_time::Duration; +use self::sealed::FallbackTx; use super::error::{Error, InputContributionError}; use super::{ common, InternalPayloadError, JsonReply, OutputSubstitutionError, ProtocolError, SelectionError, @@ -55,7 +56,8 @@ use crate::ohttp::{ use crate::output_substitution::OutputSubstitution; use crate::persist::{ MaybeFatalOrSuccessTransition, MaybeFatalTransition, MaybeFatalTransitionWithNoResults, - MaybeSuccessTransition, MaybeTransientTransition, NextStateTransition, TerminalTransition, + MaybeTerminalSuccessTransition, MaybeTerminalTransition, MaybeTransientTransition, + NextStateTransition, TerminalTransition, }; use crate::receive::{parse_payload, InputPair, OriginalPayload, PsbtContext}; use crate::time::Time; @@ -145,6 +147,7 @@ pub enum ReceiveSession { PayjoinProposal(Receiver), HasReplyableError(Receiver), Monitor(Receiver), + PendingFallback(Receiver), Closed(SessionOutcome), } @@ -198,29 +201,65 @@ impl ReceiveSession { (ReceiveSession::PayjoinProposal(state), SessionEvent::PostedPayjoinProposal()) => Ok(state.apply_payjoin_posted()), + (session, SessionEvent::Cancelled) => + try_pending_fallback(session, PendingFallbackCause::Cancelled).map_err(|session| { + InternalReplayError::InvalidEvent( + Box::new(SessionEvent::Cancelled), + Some(session), + ) + .into() + }), + + (session, SessionEvent::ProtocolFailed) => + try_pending_fallback(session, PendingFallbackCause::ProtocolFailed).map_err( + |session| { + InternalReplayError::InvalidEvent( + Box::new(SessionEvent::ProtocolFailed), + Some(session), + ) + .into() + }, + ), + (_, SessionEvent::Closed(session_outcome)) => Ok(ReceiveSession::Closed(session_outcome)), - (session, SessionEvent::GotReplyableError(error)) => + (session, SessionEvent::GotReplyableError(error)) => { + let (session_context, fallback_tx) = match session { + ReceiveSession::Initialized(r) => (r.session_context, None), + ReceiveSession::UncheckedOriginalPayload(r) => (r.session_context, None), + ReceiveSession::MaybeInputsOwned(r) => + (r.session_context, Some(r.state.fallback_tx())), + ReceiveSession::MaybeInputsSeen(r) => + (r.session_context, Some(r.state.fallback_tx())), + ReceiveSession::OutputsUnknown(r) => + (r.session_context, Some(r.state.fallback_tx())), + ReceiveSession::WantsOutputs(r) => + (r.session_context, Some(r.state.fallback_tx())), + ReceiveSession::WantsInputs(r) => + (r.session_context, Some(r.state.fallback_tx())), + ReceiveSession::WantsFeeRange(r) => + (r.session_context, Some(r.state.fallback_tx())), + ReceiveSession::ProvisionalProposal(r) => + (r.session_context, Some(r.state.fallback_tx())), + ReceiveSession::PayjoinProposal(r) => + (r.session_context, Some(r.state.fallback_tx())), + ReceiveSession::HasReplyableError(r) => + (r.session_context, r.state.fallback_tx.clone()), + ReceiveSession::Monitor(r) => (r.session_context, Some(r.state.fallback_tx())), + ReceiveSession::PendingFallback(r) => { + let fallback_tx = r.fallback_tx().clone(); + (r.session_context, Some(fallback_tx)) + } + ReceiveSession::Closed(session_outcome) => + return Ok(ReceiveSession::Closed(session_outcome)), + }; + Ok(ReceiveSession::HasReplyableError(Receiver { - state: HasReplyableError { error_reply: error.clone() }, - session_context: match session { - ReceiveSession::Initialized(r) => r.session_context, - ReceiveSession::UncheckedOriginalPayload(r) => r.session_context, - ReceiveSession::MaybeInputsOwned(r) => r.session_context, - ReceiveSession::MaybeInputsSeen(r) => r.session_context, - ReceiveSession::OutputsUnknown(r) => r.session_context, - ReceiveSession::WantsOutputs(r) => r.session_context, - ReceiveSession::WantsInputs(r) => r.session_context, - ReceiveSession::WantsFeeRange(r) => r.session_context, - ReceiveSession::ProvisionalProposal(r) => r.session_context, - ReceiveSession::PayjoinProposal(r) => r.session_context, - ReceiveSession::HasReplyableError(r) => r.session_context, - ReceiveSession::Monitor(r) => r.session_context, - ReceiveSession::Closed(session_outcome) => - return Ok(ReceiveSession::Closed(session_outcome)), - }, - })), + state: HasReplyableError { error_reply: error, fallback_tx }, + session_context, + })) + } (current_state, event) => Err(InternalReplayError::InvalidEvent( Box::new(event), @@ -231,72 +270,125 @@ impl ReceiveSession { } } -mod sealed { - pub trait State { - fn fallback_tx(&self) -> Option { None } +fn pending_fallback_from( + r: Receiver, + cause: PendingFallbackCause, +) -> ReceiveSession { + let fallback_tx = r.state.fallback_tx(); + ReceiveSession::PendingFallback(Receiver { + state: PendingFallback { fallback_tx, cause }, + session_context: r.session_context, + }) +} + +fn pending_fallback_from_replyable_error( + r: Receiver, + cause: PendingFallbackCause, +) -> Result> { + let Receiver { state: HasReplyableError { error_reply, fallback_tx }, session_context } = r; + match fallback_tx { + Some(fallback_tx) => Ok(ReceiveSession::PendingFallback(Receiver { + state: PendingFallback { fallback_tx, cause }, + session_context, + })), + None => Err(Box::new(ReceiveSession::HasReplyableError(Receiver { + state: HasReplyableError { error_reply, fallback_tx: None }, + session_context, + }))), } +} +fn try_pending_fallback( + session: ReceiveSession, + cause: PendingFallbackCause, +) -> Result> { + match session { + ReceiveSession::MaybeInputsOwned(receiver) => Ok(pending_fallback_from(receiver, cause)), + ReceiveSession::MaybeInputsSeen(receiver) => Ok(pending_fallback_from(receiver, cause)), + ReceiveSession::OutputsUnknown(receiver) => Ok(pending_fallback_from(receiver, cause)), + ReceiveSession::WantsOutputs(receiver) => Ok(pending_fallback_from(receiver, cause)), + ReceiveSession::WantsInputs(receiver) => Ok(pending_fallback_from(receiver, cause)), + ReceiveSession::WantsFeeRange(receiver) => Ok(pending_fallback_from(receiver, cause)), + ReceiveSession::ProvisionalProposal(receiver) => Ok(pending_fallback_from(receiver, cause)), + ReceiveSession::PayjoinProposal(receiver) => Ok(pending_fallback_from(receiver, cause)), + ReceiveSession::HasReplyableError(receiver) => + pending_fallback_from_replyable_error(receiver, cause), + ReceiveSession::Monitor(receiver) => Ok(pending_fallback_from(receiver, cause)), + session => Err(Box::new(session)), + } +} + +mod sealed { + pub trait State {} impl State for super::Initialized {} + impl State for super::UncheckedOriginalPayload {} + impl State for super::MaybeInputsOwned {} + impl State for super::MaybeInputsSeen {} + impl State for super::OutputsUnknown {} + impl State for super::WantsOutputs {} + impl State for super::WantsInputs {} + impl State for super::WantsFeeRange {} + impl State for super::ProvisionalProposal {} + impl State for super::PayjoinProposal {} + impl State for super::HasReplyableError {} + impl State for super::Monitor {} + impl State for super::PendingFallback {} - impl State for super::UncheckedOriginalPayload { - fn fallback_tx(&self) -> Option { - Some(self.original.psbt.clone().extract_tx_unchecked_fee_rate()) - } + pub trait FallbackTx: State { + fn fallback_tx(&self) -> bitcoin::Transaction; } - impl State for super::MaybeInputsOwned { - fn fallback_tx(&self) -> Option { - Some(self.original.psbt.clone().extract_tx_unchecked_fee_rate()) + impl FallbackTx for super::MaybeInputsOwned { + fn fallback_tx(&self) -> bitcoin::Transaction { + self.original.psbt.clone().extract_tx_unchecked_fee_rate() } } - impl State for super::MaybeInputsSeen { - fn fallback_tx(&self) -> Option { - Some(self.original.psbt.clone().extract_tx_unchecked_fee_rate()) + impl FallbackTx for super::MaybeInputsSeen { + fn fallback_tx(&self) -> bitcoin::Transaction { + self.original.psbt.clone().extract_tx_unchecked_fee_rate() } } - impl State for super::OutputsUnknown { - fn fallback_tx(&self) -> Option { - Some(self.original.psbt.clone().extract_tx_unchecked_fee_rate()) + impl FallbackTx for super::OutputsUnknown { + fn fallback_tx(&self) -> bitcoin::Transaction { + self.original.psbt.clone().extract_tx_unchecked_fee_rate() } } - impl State for super::WantsOutputs { - fn fallback_tx(&self) -> Option { - Some(self.inner.original_psbt.clone().extract_tx_unchecked_fee_rate()) + impl FallbackTx for super::WantsOutputs { + fn fallback_tx(&self) -> bitcoin::Transaction { + self.inner.original_psbt.clone().extract_tx_unchecked_fee_rate() } } - impl State for super::WantsInputs { - fn fallback_tx(&self) -> Option { - Some(self.inner.original_psbt.clone().extract_tx_unchecked_fee_rate()) + impl FallbackTx for super::WantsInputs { + fn fallback_tx(&self) -> bitcoin::Transaction { + self.inner.original_psbt.clone().extract_tx_unchecked_fee_rate() } } - impl State for super::WantsFeeRange { - fn fallback_tx(&self) -> Option { - Some(self.inner.original_psbt.clone().extract_tx_unchecked_fee_rate()) + impl FallbackTx for super::WantsFeeRange { + fn fallback_tx(&self) -> bitcoin::Transaction { + self.inner.original_psbt.clone().extract_tx_unchecked_fee_rate() } } - impl State for super::ProvisionalProposal { - fn fallback_tx(&self) -> Option { - Some(self.psbt_context.original_psbt.clone().extract_tx_unchecked_fee_rate()) + impl FallbackTx for super::ProvisionalProposal { + fn fallback_tx(&self) -> bitcoin::Transaction { + self.psbt_context.original_psbt.clone().extract_tx_unchecked_fee_rate() } } - impl State for super::PayjoinProposal { - fn fallback_tx(&self) -> Option { - Some(self.psbt_context.original_psbt.clone().extract_tx_unchecked_fee_rate()) + impl FallbackTx for super::PayjoinProposal { + fn fallback_tx(&self) -> bitcoin::Transaction { + self.psbt_context.original_psbt.clone().extract_tx_unchecked_fee_rate() } } - impl State for super::HasReplyableError {} - - impl State for super::Monitor { - fn fallback_tx(&self) -> Option { - Some(self.psbt_context.original_psbt.clone().extract_tx_unchecked_fee_rate()) + impl FallbackTx for super::Monitor { + fn fallback_tx(&self) -> bitcoin::Transaction { + self.psbt_context.original_psbt.clone().extract_tx_unchecked_fee_rate() } } } @@ -310,6 +402,14 @@ pub trait State: sealed::State {} impl State for S {} +/// Marker trait for receiver protocol states that hold a verified broadcastable +/// fallback transaction. +/// +/// This trait is sealed to prevent external implementations. +pub trait HasFallbackTx: sealed::FallbackTx {} + +impl HasFallbackTx for T {} + /// A higher-level receiver construct which will be taken through different states through the /// protocol workflow. /// @@ -340,17 +440,48 @@ impl core::ops::DerefMut for Receiver { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.state } } -impl Receiver { - /// Cancel the Payjoin session immediately. - /// - /// Returns a [`TerminalTransition`] that, once persisted, yields the fallback - /// transaction when applicable. The fallback transaction is the sender's original - /// transaction that should be broadcast to complete the payment without Payjoin. - /// - /// This is a terminal transition — the session cannot be used after cancellation. - pub fn cancel(self) -> TerminalTransition> { - let fallback = self.state.fallback_tx(); - TerminalTransition::new(SessionEvent::Closed(SessionOutcome::Cancel), fallback) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PendingFallback { + fallback_tx: bitcoin::Transaction, + cause: PendingFallbackCause, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PendingFallbackCause { + Cancelled, + ProtocolFailed, +} + +impl Receiver { + pub fn fallback_tx(&self) -> &bitcoin::Transaction { &self.state.fallback_tx } + + pub fn close(self) -> TerminalTransition { + let outcome = match self.state.cause { + PendingFallbackCause::Cancelled => SessionOutcome::Cancel, + PendingFallbackCause::ProtocolFailed => SessionOutcome::Failure, + }; + TerminalTransition::new(SessionEvent::Closed(outcome), ()) + } +} + +impl Receiver { + /// Cancel the Payjoin session and surface the fallback transaction. + pub fn cancel(self) -> NextStateTransition> { + let fallback_tx = self.state.fallback_tx(); + NextStateTransition::success( + SessionEvent::Cancelled, + Receiver { + state: PendingFallback { fallback_tx, cause: PendingFallbackCause::Cancelled }, + session_context: self.session_context, + }, + ) + } +} + +impl Receiver { + /// Cancel before any fallback transaction exists. + pub fn cancel(self) -> TerminalTransition { + TerminalTransition::new(SessionEvent::Closed(SessionOutcome::Cancel), ()) } } @@ -601,6 +732,13 @@ pub struct UncheckedOriginalPayload { pub(crate) original: OriginalPayload, } +impl Receiver { + /// Cancel before broadcast suitability has been checked. + pub fn cancel(self) -> TerminalTransition { + TerminalTransition::new(SessionEvent::Closed(SessionOutcome::Cancel), ()) + } +} + /// The original PSBT and the optional parameters received from the sender. /// /// This is the first typestate after the retrieval of the sender's original proposal in @@ -653,7 +791,7 @@ impl Receiver { Err(e) => MaybeFatalTransition::replyable_error( SessionEvent::GotReplyableError((&e).into()), Receiver { - state: HasReplyableError { error_reply: (&e).into() }, + state: HasReplyableError { error_reply: (&e).into(), fallback_tx: None }, session_context: self.session_context, }, e, @@ -730,7 +868,10 @@ impl Receiver { return MaybeFatalTransition::replyable_error( SessionEvent::GotReplyableError((&e).into()), Receiver { - state: HasReplyableError { error_reply: (&e).into() }, + state: HasReplyableError { + error_reply: (&e).into(), + fallback_tx: Some(self.state.fallback_tx()), + }, session_context: self.session_context, }, e, @@ -792,7 +933,10 @@ impl Receiver { return MaybeFatalTransition::replyable_error( SessionEvent::GotReplyableError((&e).into()), Receiver { - state: HasReplyableError { error_reply: (&e).into() }, + state: HasReplyableError { + error_reply: (&e).into(), + fallback_tx: Some(self.state.fallback_tx()), + }, session_context: self.session_context, }, e, @@ -849,6 +993,7 @@ impl Receiver { Error, Receiver, > { + let fallback_tx = Some(self.state.fallback_tx()); let inner = match self.state.original.identify_receiver_outputs(is_receiver_output) { Ok(inner) => inner, Err(e) => match e { @@ -859,7 +1004,7 @@ impl Receiver { return MaybeFatalTransition::replyable_error( SessionEvent::GotReplyableError((&e).into()), Receiver { - state: HasReplyableError { error_reply: (&e).into() }, + state: HasReplyableError { error_reply: (&e).into(), fallback_tx }, session_context: self.session_context, }, e, @@ -1203,7 +1348,12 @@ impl Receiver { self, res: &[u8], ohttp_context: ohttp::ClientResponse, - ) -> MaybeFatalTransition, ProtocolError> { + ) -> MaybeFatalTransition< + SessionEvent, + Receiver, + ProtocolError, + Receiver, + > { match process_post_res(res, ohttp_context) { Ok(_) => MaybeFatalTransition::success( SessionEvent::PostedPayjoinProposal(), @@ -1214,8 +1364,15 @@ impl Receiver { ), Err(e) => if e.is_fatal() { - MaybeFatalTransition::fatal( - SessionEvent::Closed(SessionOutcome::Failure), + MaybeFatalTransition::replyable_error( + SessionEvent::ProtocolFailed, + Receiver { + state: PendingFallback { + fallback_tx: self.state.fallback_tx(), + cause: PendingFallbackCause::ProtocolFailed, + }, + session_context: self.session_context.clone(), + }, ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()), ) } else { @@ -1237,9 +1394,26 @@ impl Receiver { #[derive(Debug, Clone, PartialEq)] pub struct HasReplyableError { error_reply: JsonReply, + fallback_tx: Option, } impl Receiver { + /// Cancel without sending the error response. + pub fn cancel(self) -> MaybeTerminalTransition> { + let Receiver { state: HasReplyableError { fallback_tx, .. }, session_context } = self; + match fallback_tx { + Some(fallback_tx) => MaybeTerminalTransition::advance( + SessionEvent::Cancelled, + Receiver { + state: PendingFallback { fallback_tx, cause: PendingFallbackCause::Cancelled }, + session_context, + }, + ), + None => + MaybeTerminalTransition::terminate(SessionEvent::Closed(SessionOutcome::Cancel)), + } + } + /// Construct an OHTTP Encapsulated HTTP POST request to return /// a Receiver Error Response pub fn create_error_request( @@ -1275,28 +1449,57 @@ impl Receiver { } /// Process an OHTTP Encapsulated HTTP POST Error response - /// to ensure it has been posted properly + /// to ensure it has been posted properly. + /// + /// This uses [`MaybeTerminalSuccessTransition`] because the posted error is + /// successfully handled by either surfacing a fallback obligation or closing + /// the session. Fatal directory errors follow the same fallback or close + /// split while returning the fatal error to the caller. pub fn process_error_response( &self, res: &[u8], ohttp_context: ohttp::ClientResponse, - ) -> MaybeSuccessTransition { + ) -> MaybeTerminalSuccessTransition, ProtocolError> + { match process_post_res(res, ohttp_context) { - Ok(_) => - MaybeSuccessTransition::success(SessionEvent::Closed(SessionOutcome::Failure), ()), + Ok(_) => match self.pending_fallback_after_protocol_failure() { + Some(pending_fallback) => MaybeTerminalSuccessTransition::advance( + SessionEvent::ProtocolFailed, + pending_fallback, + ), + None => MaybeTerminalSuccessTransition::terminate(SessionEvent::Closed( + SessionOutcome::Failure, + )), + }, Err(e) => if e.is_fatal() { - MaybeSuccessTransition::fatal( - SessionEvent::Closed(SessionOutcome::Failure), - ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()), - ) + let error = + ProtocolError::V2(InternalSessionError::DirectoryResponse(e).into()); + match self.pending_fallback_after_protocol_failure() { + Some(pending_fallback) => MaybeTerminalSuccessTransition::fatal_advance( + SessionEvent::ProtocolFailed, + pending_fallback, + error, + ), + None => MaybeTerminalSuccessTransition::fatal_terminate( + SessionEvent::Closed(SessionOutcome::Failure), + error, + ), + } } else { - MaybeSuccessTransition::transient(ProtocolError::V2( + MaybeTerminalSuccessTransition::transient(ProtocolError::V2( InternalSessionError::DirectoryResponse(e).into(), )) }, } } + + fn pending_fallback_after_protocol_failure(&self) -> Option> { + self.state.fallback_tx.clone().map(|fallback_tx| Receiver { + state: PendingFallback { fallback_tx, cause: PendingFallbackCause::ProtocolFailed }, + session_context: self.session_context.clone(), + }) + } } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -1439,7 +1642,7 @@ pub mod test { use super::*; use crate::output_substitution::OutputSubstitution; use crate::persist::{ - InMemoryPersister, OptionalTransitionOutcome, RejectTransient, Rejection, + InMemoryPersister, OptionalTransitionOutcome, RejectTransient, Rejection, SessionPersister, }; use crate::receive::optional_parameters::Params; use crate::receive::v2; @@ -1491,6 +1694,46 @@ pub mod test { JsonReply::from(&res) } + pub(crate) fn mock_fallback_tx() -> bitcoin::Transaction { + PARSED_ORIGINAL_PSBT.clone().extract_tx_unchecked_fee_rate() + } + + fn receiver(state: S) -> Receiver { + Receiver { state, session_context: SHARED_CONTEXT.clone() } + } + + fn assert_events( + persister: &InMemoryPersister, + expected_events: &[SessionEvent], + expected_closed: bool, + ) { + let inner = persister.inner.read().expect("Shouldn't be poisoned"); + assert_eq!(&*inner.events, expected_events); + assert_eq!(inner.is_closed, expected_closed); + } + + fn ohttp_response_for(req_body: &[u8], status: http::StatusCode) -> Vec { + let server = ohttp::Server::new(SHARED_CONTEXT.ohttp_keys.0.clone()) + .expect("test OHTTP server should be valid"); + let (_, probe_response) = server.decapsulate(req_body).expect("request should decapsulate"); + let response_overhead = + probe_response.encapsulate(&[]).expect("probe should encrypt").len(); + + let (_, server_response) = + server.decapsulate(req_body).expect("request should decapsulate again"); + let mut bhttp_response = + vec![0u8; crate::directory::ENCAPSULATED_MESSAGE_BYTES - response_overhead]; + bhttp::Message::response( + bhttp::StatusCode::try_from(status.as_u16()).expect("status should be valid"), + ) + .write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_response.as_mut_slice()) + .expect("BHTTP response should encode"); + let encrypted = + server_response.encapsulate(&bhttp_response).expect("response should encrypt"); + assert_eq!(encrypted.len(), crate::directory::ENCAPSULATED_MESSAGE_BYTES); + encrypted + } + #[test] fn test_monitor_typestate() -> Result<(), BoxError> { let psbt_ctx = PsbtContext { @@ -1765,7 +2008,10 @@ pub mod test { assert_eq!(mock_err.to_json(), expected_json); let receiver = Receiver { - state: HasReplyableError { error_reply: mock_err.clone() }, + state: HasReplyableError { + error_reply: mock_err.clone(), + fallback_tx: Some(mock_fallback_tx()), + }, session_context: SHARED_CONTEXT.clone(), }; @@ -1779,7 +2025,10 @@ pub mod test { let now = crate::time::Time::now(); let context = SessionContext { expiration: now, ..SHARED_CONTEXT.clone() }; let receiver = Receiver { - state: HasReplyableError { error_reply: mock_err() }, + state: HasReplyableError { + error_reply: mock_err(), + fallback_tx: Some(mock_fallback_tx()), + }, session_context: context.clone(), }; @@ -1795,6 +2044,123 @@ pub mod test { Ok(()) } + #[test] + fn process_error_response_success_with_fallback_enters_pending_fallback() -> Result<(), BoxError> + { + let expected_tx = mock_fallback_tx(); + let receiver = receiver(HasReplyableError { + error_reply: mock_err(), + fallback_tx: Some(expected_tx.clone()), + }); + let (req, ctx) = receiver.create_error_request(EXAMPLE_URL)?; + let response = ohttp_response_for(&req.body, http::StatusCode::OK); + let persister = InMemoryPersister::::default(); + + let pending_fallback = receiver + .process_error_response(&response, ctx) + .save(&persister)? + .expect("pending fallback should be returned"); + + assert_eq!(pending_fallback.fallback_tx(), &expected_tx); + assert_events(&persister, &[SessionEvent::ProtocolFailed], false); + Ok(()) + } + + #[test] + fn process_error_response_success_without_fallback_closes_session() -> Result<(), BoxError> { + let receiver = receiver(HasReplyableError { error_reply: mock_err(), fallback_tx: None }); + let (req, ctx) = receiver.create_error_request(EXAMPLE_URL)?; + let response = ohttp_response_for(&req.body, http::StatusCode::OK); + let persister = InMemoryPersister::::default(); + + let pending_fallback = receiver.process_error_response(&response, ctx).save(&persister)?; + + assert!(pending_fallback.is_none()); + assert_events(&persister, &[SessionEvent::Closed(SessionOutcome::Failure)], true); + Ok(()) + } + + #[test] + fn process_error_response_fatal_with_fallback_enters_pending_fallback() -> Result<(), BoxError> + { + let receiver = receiver(HasReplyableError { + error_reply: mock_err(), + fallback_tx: Some(mock_fallback_tx()), + }); + let (req, ctx) = receiver.create_error_request(EXAMPLE_URL)?; + let response = ohttp_response_for(&req.body, http::StatusCode::BAD_REQUEST); + let persister = InMemoryPersister::::default(); + + let err = receiver + .process_error_response(&response, ctx) + .save(&persister) + .expect_err("fatal response should error"); + + assert!(err.api_error_ref().is_some()); + assert_events(&persister, &[SessionEvent::ProtocolFailed], false); + Ok(()) + } + + #[test] + fn process_error_response_fatal_without_fallback_closes_session() -> Result<(), BoxError> { + let receiver = receiver(HasReplyableError { error_reply: mock_err(), fallback_tx: None }); + let (req, ctx) = receiver.create_error_request(EXAMPLE_URL)?; + let response = ohttp_response_for(&req.body, http::StatusCode::BAD_REQUEST); + let persister = InMemoryPersister::::default(); + + let err = receiver + .process_error_response(&response, ctx) + .save(&persister) + .expect_err("fatal response should error"); + + assert!(err.api_error_ref().is_some()); + assert_events(&persister, &[SessionEvent::Closed(SessionOutcome::Failure)], true); + Ok(()) + } + + #[test] + fn process_error_response_transient_leaves_session_open() -> Result<(), BoxError> { + let receiver = receiver(HasReplyableError { + error_reply: mock_err(), + fallback_tx: Some(mock_fallback_tx()), + }); + let (req, ctx) = receiver.create_error_request(EXAMPLE_URL)?; + let response = ohttp_response_for(&req.body, http::StatusCode::INTERNAL_SERVER_ERROR); + let persister = InMemoryPersister::::default(); + + let err = receiver + .process_error_response(&response, ctx) + .save(&persister) + .expect_err("transient response should error"); + + assert!(err.api_error_ref().is_some()); + assert_events(&persister, &[], false); + Ok(()) + } + + #[test] + fn payjoin_proposal_fatal_response_enters_pending_fallback() -> Result<(), BoxError> { + let expected_tx = mock_fallback_tx(); + let psbt_context = PsbtContext { + original_psbt: PARSED_ORIGINAL_PSBT.clone(), + payjoin_psbt: PARSED_PAYJOIN_PROPOSAL.clone(), + }; + let proposal = receiver(PayjoinProposal { psbt_context }); + let (req, ctx) = proposal.create_post_request(EXAMPLE_URL)?; + let response = ohttp_response_for(&req.body, http::StatusCode::BAD_REQUEST); + let persister = InMemoryPersister::::default(); + + let err = proposal + .process_response(&response, ctx) + .save(&persister) + .expect_err("fatal response should error"); + let pending_fallback = err.error_state().expect("pending fallback should be carried"); + + assert_eq!(pending_fallback.fallback_tx(), &expected_tx); + assert_events(&persister, &[SessionEvent::ProtocolFailed], false); + Ok(()) + } + #[test] fn default_max_fee_rate() { let persister = InMemoryPersister::default(); @@ -1897,52 +2263,241 @@ pub mod test { } #[test] - fn cancel_returns_expected_fallback() { - macro_rules! do_cancel_test { - ($state:expr, $expected:expr) => {{ - let persister = InMemoryPersister::::default(); - let fallback = Receiver { state: $state, session_context: SHARED_CONTEXT.clone() } - .cancel() - .save(&persister) - .expect("save should succeed"); - assert_eq!(fallback, $expected, "cancel from {}", stringify!($state)); - }}; - } + fn cancel_initialized_closes_session() { + let persister = InMemoryPersister::::default(); + receiver(Initialized {}).cancel().save(&persister).expect("save should succeed"); + + assert_events(&persister, &[SessionEvent::Closed(SessionOutcome::Cancel)], true); + } + #[test] + fn cancel_unchecked_original_payload_closes_session() { + let original = + OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params: Params::default() }; + let persister = InMemoryPersister::::default(); + receiver(UncheckedOriginalPayload { original }) + .cancel() + .save(&persister) + .expect("save should succeed"); + + assert_events(&persister, &[SessionEvent::Closed(SessionOutcome::Cancel)], true); + } + + #[test] + fn cancel_has_fallback_enters_pending_fallback() { let original = OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params: Params::default() }; let expected_tx = PARSED_ORIGINAL_PSBT.clone().extract_tx_unchecked_fee_rate(); - let psbt_ctx = PsbtContext { - original_psbt: PARSED_ORIGINAL_PSBT.clone(), - payjoin_psbt: PARSED_PAYJOIN_PROPOSAL.clone(), - }; - let wants_outputs = common::WantsOutputs::new(original.clone(), vec![0]); - let wants_inputs = wants_outputs.clone().commit_outputs(); - let wants_fee_range = wants_inputs.clone().commit_inputs(); - - // States without a fallback transaction - do_cancel_test!(Initialized {}, None); - do_cancel_test!(HasReplyableError { error_reply: mock_err() }, None); - - // States with a fallback transaction - do_cancel_test!( - UncheckedOriginalPayload { original: original.clone() }, - Some(expected_tx.clone()) - ); - do_cancel_test!(MaybeInputsOwned { original: original.clone() }, Some(expected_tx.clone())); - do_cancel_test!(MaybeInputsSeen { original: original.clone() }, Some(expected_tx.clone())); - do_cancel_test!(OutputsUnknown { original }, Some(expected_tx.clone())); - do_cancel_test!(WantsOutputs { inner: wants_outputs }, Some(expected_tx.clone())); - do_cancel_test!(WantsInputs { inner: wants_inputs }, Some(expected_tx.clone())); - do_cancel_test!(WantsFeeRange { inner: wants_fee_range }, Some(expected_tx.clone())); - do_cancel_test!( - ProvisionalProposal { psbt_context: psbt_ctx.clone() }, - Some(expected_tx.clone()) - ); - do_cancel_test!( - PayjoinProposal { psbt_context: psbt_ctx.clone() }, - Some(expected_tx.clone()) - ); - do_cancel_test!(Monitor { psbt_context: psbt_ctx }, Some(expected_tx)); + let persister = InMemoryPersister::::default(); + let pending_fallback = receiver(MaybeInputsOwned { original }) + .cancel() + .save(&persister) + .expect("save should succeed"); + + assert_eq!(pending_fallback.fallback_tx(), &expected_tx); + assert_events(&persister, &[SessionEvent::Cancelled], false); + } + + #[test] + fn cancel_replyable_error_with_fallback_enters_pending_fallback() { + let expected_tx = mock_fallback_tx(); + let persister = InMemoryPersister::::default(); + let pending_fallback = receiver(HasReplyableError { + error_reply: mock_err(), + fallback_tx: Some(expected_tx.clone()), + }) + .cancel() + .save(&persister) + .expect("save should succeed") + .expect("pending fallback should be returned"); + + assert_eq!(pending_fallback.fallback_tx(), &expected_tx); + assert_events(&persister, &[SessionEvent::Cancelled], false); + } + + #[test] + fn cancel_replyable_error_without_fallback_closes_session() { + let persister = InMemoryPersister::::default(); + let pending_fallback = + receiver(HasReplyableError { error_reply: mock_err(), fallback_tx: None }) + .cancel() + .save(&persister) + .expect("save should succeed"); + + assert!(pending_fallback.is_none()); + assert_events(&persister, &[SessionEvent::Closed(SessionOutcome::Cancel)], true); + } + + #[test] + fn replaying_cancel_event_sequences_reaches_expected_states() { + let original = + OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params: Params::default() }; + let expected_tx = PARSED_ORIGINAL_PSBT.clone().extract_tx_unchecked_fee_rate(); + let replyable_error = mock_err(); + + let test_cases = vec![ + ( + vec![ + SessionEvent::Created(SHARED_CONTEXT.clone()), + SessionEvent::Closed(SessionOutcome::Cancel), + ], + ReceiveSession::Closed(SessionOutcome::Cancel), + ), + ( + vec![ + SessionEvent::Created(SHARED_CONTEXT.clone()), + SessionEvent::RetrievedOriginalPayload { + original: original.clone(), + reply_key: None, + }, + SessionEvent::Closed(SessionOutcome::Cancel), + ], + ReceiveSession::Closed(SessionOutcome::Cancel), + ), + ( + vec![ + SessionEvent::Created(SHARED_CONTEXT.clone()), + SessionEvent::RetrievedOriginalPayload { + original: original.clone(), + reply_key: None, + }, + SessionEvent::CheckedBroadcastSuitability(), + SessionEvent::Cancelled, + ], + ReceiveSession::PendingFallback(Receiver { + state: PendingFallback { + fallback_tx: expected_tx.clone(), + cause: PendingFallbackCause::Cancelled, + }, + session_context: SHARED_CONTEXT.clone(), + }), + ), + ( + vec![ + SessionEvent::Created(SHARED_CONTEXT.clone()), + SessionEvent::RetrievedOriginalPayload { + original: original.clone(), + reply_key: None, + }, + SessionEvent::CheckedBroadcastSuitability(), + SessionEvent::GotReplyableError(replyable_error.clone()), + SessionEvent::Cancelled, + ], + ReceiveSession::PendingFallback(Receiver { + state: PendingFallback { + fallback_tx: expected_tx, + cause: PendingFallbackCause::Cancelled, + }, + session_context: SHARED_CONTEXT.clone(), + }), + ), + ( + vec![ + SessionEvent::Created(SHARED_CONTEXT.clone()), + SessionEvent::RetrievedOriginalPayload { original, reply_key: None }, + SessionEvent::GotReplyableError(replyable_error), + SessionEvent::Closed(SessionOutcome::Cancel), + ], + ReceiveSession::Closed(SessionOutcome::Cancel), + ), + ]; + + for (events, expected_state) in test_cases { + let persister = InMemoryPersister::::default(); + for event in events { + persister.save_event(event).expect("save should succeed"); + } + let (state, _) = replay_event_log(&persister).expect("replay should succeed"); + assert_eq!(state, expected_state); + } + } + + #[test] + fn replaying_replyable_error_from_unchecked_captures_no_fallback() { + let state = unchecked_proposal_v2_from_test_vector(); + let error = mock_err(); + let session = ReceiveSession::UncheckedOriginalPayload(Receiver { + state, + session_context: SHARED_CONTEXT.clone(), + }); + + let replayed = session + .process_event(SessionEvent::GotReplyableError(error.clone())) + .expect("replyable error should replay"); + + match replayed { + ReceiveSession::HasReplyableError(receiver) => { + assert_eq!(receiver.state.error_reply, error); + assert_eq!(receiver.state.fallback_tx, None); + } + other => panic!("Expected HasReplyableError, got {other:?}"), + } + } + + #[test] + fn replaying_replyable_error_from_initialized_captures_no_fallback() { + let error = mock_err(); + let session = ReceiveSession::Initialized(Receiver { + state: Initialized {}, + session_context: SHARED_CONTEXT.clone(), + }); + + let replayed = session + .process_event(SessionEvent::GotReplyableError(error.clone())) + .expect("replyable error should replay"); + + match replayed { + ReceiveSession::HasReplyableError(receiver) => { + assert_eq!(receiver.state.error_reply, error); + assert_eq!(receiver.state.fallback_tx, None); + } + other => panic!("Expected HasReplyableError, got {other:?}"), + } + } + + #[test] + fn replaying_replyable_error_from_replyable_error_carries_some_fallback() { + let expected_fallback = mock_fallback_tx(); + let error = mock_err(); + let session = ReceiveSession::HasReplyableError(Receiver { + state: HasReplyableError { + error_reply: mock_err(), + fallback_tx: Some(expected_fallback.clone()), + }, + session_context: SHARED_CONTEXT.clone(), + }); + + let replayed = session + .process_event(SessionEvent::GotReplyableError(error.clone())) + .expect("replyable error should replay"); + + match replayed { + ReceiveSession::HasReplyableError(receiver) => { + assert_eq!(receiver.state.error_reply, error); + assert_eq!(receiver.state.fallback_tx, Some(expected_fallback)); + } + other => panic!("Expected HasReplyableError, got {other:?}"), + } + } + + #[test] + fn replaying_replyable_error_from_replyable_error_carries_no_fallback() { + let error = mock_err(); + let session = ReceiveSession::HasReplyableError(Receiver { + state: HasReplyableError { error_reply: mock_err(), fallback_tx: None }, + session_context: SHARED_CONTEXT.clone(), + }); + + let replayed = session + .process_event(SessionEvent::GotReplyableError(error.clone())) + .expect("replyable error should replay"); + + match replayed { + ReceiveSession::HasReplyableError(receiver) => { + assert_eq!(receiver.state.error_reply, error); + assert_eq!(receiver.state.fallback_tx, None); + } + other => panic!("Expected HasReplyableError, got {other:?}"), + } } } diff --git a/payjoin/src/core/receive/v2/session.rs b/payjoin/src/core/receive/v2/session.rs index 7907453a8..6addb628e 100644 --- a/payjoin/src/core/receive/v2/session.rs +++ b/payjoin/src/core/receive/v2/session.rs @@ -166,6 +166,8 @@ impl SessionHistory { SessionOutcome::Failure | SessionOutcome::Cancel => SessionStatus::Failed, SessionOutcome::FallbackBroadcasted => SessionStatus::FallbackBroadcasted, }, + Some(SessionEvent::Cancelled | SessionEvent::ProtocolFailed) => + SessionStatus::PendingFallback, _ => SessionStatus::Active, } } @@ -180,6 +182,7 @@ pub enum SessionStatus { Failed, Completed, FallbackBroadcasted, + PendingFallback, } /// Represents a piece of information that the receiver has obtained from the session @@ -198,6 +201,8 @@ pub enum SessionEvent { FinalizedProposal(bitcoin::Psbt), GotReplyableError(JsonReply), PostedPayjoinProposal(), + Cancelled, + ProtocolFailed, Closed(SessionOutcome), } @@ -229,7 +234,8 @@ mod tests { use crate::receive::tests::original_from_test_vector; use crate::receive::v2::test::{mock_err, SHARED_CONTEXT}; use crate::receive::v2::{ - Initialized, MaybeInputsOwned, ProvisionalProposal, Receiver, UncheckedOriginalPayload, + Initialized, MaybeInputsOwned, PendingFallback, PendingFallbackCause, ProvisionalProposal, + Receiver, UncheckedOriginalPayload, }; use crate::receive::{InternalPayloadError, PayloadError}; @@ -299,6 +305,8 @@ mod tests { SessionEvent::AppliedFeeRange(provisional_proposal.state.psbt_context.clone()), SessionEvent::FinalizedProposal(payjoin_proposal.psbt().clone()), SessionEvent::GotReplyableError(mock_err()), + SessionEvent::Cancelled, + SessionEvent::ProtocolFailed, ]; for event in test_cases { @@ -523,6 +531,72 @@ mod tests { run_session_history_test_async(&test).await; } + #[tokio::test] + async fn replaying_cancelled_session_enters_pending_fallback() { + let session_context = SHARED_CONTEXT.clone(); + let original = original_from_test_vector(); + let reply_key = Some(crate::HpkeKeyPair::gen_keypair().1); + let expected_fallback = original.psbt.clone().extract_tx_unchecked_fee_rate(); + + let test = SessionHistoryTest { + events: vec![ + SessionEvent::Created(session_context.clone()), + SessionEvent::RetrievedOriginalPayload { + original: original.clone(), + reply_key: reply_key.clone(), + }, + SessionEvent::CheckedBroadcastSuitability(), + SessionEvent::Cancelled, + ], + expected_session_history: SessionHistoryExpectedOutcome { + fallback_tx: Some(expected_fallback.clone()), + expected_status: SessionStatus::PendingFallback, + }, + expected_receiver_state: ReceiveSession::PendingFallback(Receiver { + state: PendingFallback { + fallback_tx: expected_fallback, + cause: PendingFallbackCause::Cancelled, + }, + session_context: SessionContext { reply_key, ..session_context }, + }), + }; + run_session_history_test(&test); + run_session_history_test_async(&test).await; + } + + #[tokio::test] + async fn replaying_protocol_failed_session_enters_pending_fallback() { + let session_context = SHARED_CONTEXT.clone(); + let original = original_from_test_vector(); + let reply_key = Some(crate::HpkeKeyPair::gen_keypair().1); + let expected_fallback = original.psbt.clone().extract_tx_unchecked_fee_rate(); + + let test = SessionHistoryTest { + events: vec![ + SessionEvent::Created(session_context.clone()), + SessionEvent::RetrievedOriginalPayload { + original: original.clone(), + reply_key: reply_key.clone(), + }, + SessionEvent::CheckedBroadcastSuitability(), + SessionEvent::ProtocolFailed, + ], + expected_session_history: SessionHistoryExpectedOutcome { + fallback_tx: Some(expected_fallback.clone()), + expected_status: SessionStatus::PendingFallback, + }, + expected_receiver_state: ReceiveSession::PendingFallback(Receiver { + state: PendingFallback { + fallback_tx: expected_fallback, + cause: PendingFallbackCause::ProtocolFailed, + }, + session_context: SessionContext { reply_key, ..session_context }, + }), + }; + run_session_history_test(&test); + run_session_history_test_async(&test).await; + } + #[tokio::test] async fn test_contributed_inputs() { let persister = InMemoryPersister::::default(); diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index a72f44ed6..2696abf87 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -456,7 +456,18 @@ mod integration { let err_bytes = err_response.bytes().await?; has_error.process_error_response(&err_bytes, err_ctx).save(&persister)?; - // Ensure the session is closed properly + // The error response was sent successfully and the source state + // carried a fallback transaction, so the session is now waiting + // on the wallet to acknowledge the fallback obligation. + let (session, session_history) = replay_receiver_event_log(&persister)?; + assert_eq!(session_history.status(), SessionStatus::PendingFallback); + let pending = match session { + ReceiveSession::PendingFallback(r) => r, + _ => panic!("Expected PendingFallback"), + }; + pending.close().save(&persister)?; + + // After the wallet closes, the session terminates with a Failed status. let (_, session_history) = replay_receiver_event_log(&persister)?; assert_eq!(session_history.status(), SessionStatus::Failed);