Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions watchtower-plugin/src/convert.rs
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why rust fmt is not properly formatting this, but here you have a proper formatted version:

diff --git a/watchtower-plugin/src/convert.rs b/watchtower-plugin/src/convert.rs
index d49889c..b9475de 100644
--- a/watchtower-plugin/src/convert.rs
+++ b/watchtower-plugin/src/convert.rs
@@ -289,20 +289,22 @@ impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
         match value {
             serde_json::Value::Array(a) => {
                 let param_count = a.len();
-                if param_count == 2{
-                    Err(GetRegistrationReceiptError::InvalidFormat(("Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string()))
+                if param_count == 2 {
+                    Err(GetRegistrationReceiptError::InvalidFormat(
+                        "Both ends of boundary (subscription_start and subscription_expiry) are required".to_string()
+                    ))
                 } else if param_count != 1 && param_count != 3 {
                     Err(GetRegistrationReceiptError::InvalidFormat(format!(
                         "Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}"
                     )))
-                } else{
+                } else {
                     let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
                         TowerId::from_str(s).map_err(|_| {
-                        GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned())
+                            GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned())
                         })
                     } else {
                         Err(GetRegistrationReceiptError::InvalidId(
-                        "tower_id must be a hex encoded string".to_owned(),
+                            "tower_id must be a hex encoded string".to_owned(),
                         ))
                     }?;

@@ -311,22 +313,24 @@ impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
                             (Some(start as u32), Some(expire as u32))
                         } else {
                             return Err(GetRegistrationReceiptError::InvalidFormat(
-                                    "Subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(),
-                                    ));
+                                "subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(),
+                            ));
                         }
                     } else if a.get(1).is_some() || a.get(2).is_some() {
                         return Err(GetRegistrationReceiptError::InvalidFormat(
-                                "Subscription_start and subscription_expiry must be provided together as positive integers".to_owned(),
-                                ));
+                             "subscription_start and subscription_expiry must be provided together as positive integers".to_owned(),
+                        ));
                     } else {
                         (None, None)
                     };

-                    Ok(Self {
-                    tower_id,
-                    subscription_start,
-                    subscription_expiry,
-                    })
+                    Ok(
+                        Self {
+                            tower_id,
+                            subscription_start,
+                            subscription_expiry,
+                        }
+                    )
                 }
             },
             serde_json::Value::Object(mut m) => {

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sharing the formatted code. Both rustfmt and cargo fmt are not formatting this code. Did it work for you or you had to format manually ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately I had to do it manually, I don't know why rust fmt seems not to properly handle big files with multiple levels of identation :(

Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,117 @@ impl TryFrom<serde_json::Value> for GetAppointmentParams {
}
}

// Errors related to `getregistrationreceipt` command
#[derive(Debug)]
pub enum GetRegistrationReceiptError {
InvalidId(String),
InvalidFormat(String),
}

impl std::fmt::Display for GetRegistrationReceiptError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
GetRegistrationReceiptError::InvalidId(x) => write!(f, "{x}"),
GetRegistrationReceiptError::InvalidFormat(x) => write!(f, "{x}"),
}
}
}
Comment on lines +262 to +275
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no distinction in usage between these two different error variants. I think we can combine them into one.
We might also be able to get rid of them and use an Err instead.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I looked how other methods had this implemented and followed them 😅. Which way do you think would be better, keeping as it is, merging into one error or removing this entirely ? I will make changes accordingly.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I noticed other param parsers are doing the same. I think we can make all of them to use Err(String)||Err(&'static str).

Which way do you think would be better, keeping as it is, merging into one error or removing this entirely ? I will make changes accordingly.

Try to remove it entirely, then we can do the same for other param errors in a follow up if it made sense.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Err seems to be working fine, just had to tweak the code where parsing was done. I think I would wait for decision on issue of Vec or Option<vec before final submission, rest changes are done.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason why this is here is to make it properly testable. If you look at the other params, there are test suits to make sure the proper error is returned.

I guess we should add them too for GetRegistrationReceipt now

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, got it. I will add tests for GetRegistrationReceipt params there. And in case of #199 (comment), I think it would be better as Some(Vec<>) just because its common way in rust-teos code.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And in case of #199 (comment), I think it would be better as Some(Vec<>) just because its common way in rust-teos code.

This would only make sense if #199 (comment) is carried out. I am fine with both ways, but the None variant has to carry a meaning and not just be there so the interface look similar to other methods.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed


// Parameters related to the `getregistrationreceipt` command
#[derive(Debug)]
pub struct GetRegistrationReceiptParams {
Comment thread
sr-gi marked this conversation as resolved.
pub tower_id: TowerId,
pub subscription_start: Option<u32>,
pub subscription_expiry: Option<u32>,
}

impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
type Error = GetRegistrationReceiptError;

fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
match value {
serde_json::Value::Array(a) => {
let param_count = a.len();
if param_count == 2 {
Err(GetRegistrationReceiptError::InvalidFormat((
"Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string()
))
} else if param_count != 1 && param_count != 3 {
Err(GetRegistrationReceiptError::InvalidFormat(format!(
"Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}"
)))
} else {
let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
TowerId::from_str(s).map_err(|_| {
GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned())
})
} else {
Err(GetRegistrationReceiptError::InvalidId(
"tower_id must be a hex encoded string".to_owned(),
))
}?;

let (subscription_start, subscription_expiry) = if let (Some(start), Some(expire)) = (a.get(1), a.get(2)){
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Insert a space after (a.get(1), a.get(2)) before {.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has this been fixed?

let start = start.as_i64().ok_or_else(|| {
GetRegistrationReceiptError::InvalidFormat(
"Subscription_start must be a positive integer".to_owned(),
)
})?;
Comment on lines +312 to +316
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want it a positive integer, shouldn't we use as_u64?


let expire = expire.as_i64().ok_or_else(|| {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we name this variable expiry?

GetRegistrationReceiptError::InvalidFormat(
"Subscription_expire must be a positive integer".to_owned(),
)
})?;

if start >= 0 && expire > start {
(Some(start as u32), Some(expire as u32))
} else {
return Err(GetRegistrationReceiptError::InvalidFormat(
"subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't they be the same value? If, for instance, I know what specific block I registered in.

));
}
} else {
(None, None)
};

Ok(
Self {
tower_id,
subscription_start,
subscription_expiry,
}
)
}
},
serde_json::Value::Object(mut m) => {
let allowed_keys = ["tower_id", "subscription_start", "subscription_expiry"];
let param_count = m.len();

if m.is_empty() || param_count > allowed_keys.len() {
Err(GetRegistrationReceiptError::InvalidFormat(format!("Unexpected request format. The request needs 1-3 parameters. Received: {param_count}")))
} else if !m.contains_key(allowed_keys[0]){
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing a space here as well, before the {.

Err(GetRegistrationReceiptError::InvalidId(format!("{} is mandatory", allowed_keys[0])))
} else if !m.iter().all(|(k, _)| allowed_keys.contains(&k.as_str())) {
Err(GetRegistrationReceiptError::InvalidFormat("Invalid named parameter found in request".to_owned()))
} else {
let mut params = Vec::with_capacity(allowed_keys.len());
for k in allowed_keys {
if let Some(v) = m.remove(k) {
params.push(v);
}
}

GetRegistrationReceiptParams::try_from(json!(params))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add a line break here

}
},
_ => Err(GetRegistrationReceiptError::InvalidFormat(format!(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional parameters are usually expressed in square brackets, so this should read something like:

Unexpected request format. Expected: tower_id [subscription_start] [subscription_expire]. Received: '{value}'

"Unexpected request format. Expected: tower_id [subscription_start] [subscription_expire]. Received: '{value}'"
))),
}
}
}

/// Data associated with a commitment revocation. Represents the data sent by CoreLN through the `commitment_revocation` hook.
#[derive(Debug, Serialize, Deserialize)]
pub struct CommitmentRevocation {
Expand Down
91 changes: 60 additions & 31 deletions watchtower-plugin/src/dbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::iter::FromIterator;
use std::path::PathBuf;
use std::str::FromStr;

use rusqlite::{params, Connection, Error as SqliteError};
use rusqlite::{params, Connection, Error as SqliteError, ToSql};

use bitcoin::secp256k1::SecretKey;

Expand Down Expand Up @@ -209,36 +209,43 @@ impl DBM {
Some(tower)
}

/// Loads the latest registration receipt for a given tower.
///
/// Loads the registration receipt(s) for a given tower in the given subscription range.
/// If no range is given, then loads the latest receipt
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a period after this docstring.

/// Latests is determined by the one with the `subscription_expiry` further into the future.
pub fn load_registration_receipt(
&self,
tower_id: TowerId,
user_id: UserId,
) -> Option<RegistrationReceipt> {
let mut stmt = self
.connection
.prepare(
"SELECT available_slots, subscription_start, subscription_expiry, signature
FROM registration_receipts
WHERE tower_id = ?1 AND subscription_expiry = (SELECT MAX(subscription_expiry)
FROM registration_receipts
WHERE tower_id = ?1)",
)
.unwrap();
subscription_start: Option<u32>,
subscription_expiry: Option<u32>,
) -> Option<Vec<RegistrationReceipt>> {
let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string();

let tower_id_encoded = tower_id.to_vec();
let mut params: Vec<&dyn ToSql> = vec![&tower_id_encoded];

if subscription_expiry.is_none() {
query.push_str(" AND subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add a ; at the end of this line.

} else {
query.push_str(" AND subscription_start>=?2 AND subscription_expiry <=?3");
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please fix the spacing around the >= & <= (keep one space on each side).

params.push(&subscription_start);
params.push(&subscription_expiry)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and this one as well.

}
let mut stmt = self.connection.prepare(&query).unwrap();

stmt.query_row([tower_id.to_vec()], |row| {
let slots: u32 = row.get(0).unwrap();
let start: u32 = row.get(1).unwrap();
let expiry: u32 = row.get(2).unwrap();
let signature: String = row.get(3).unwrap();
stmt.query_map(params.as_slice(), |row| {
let slots: u32 = row.get(0)?;
let start: u32 = row.get(1)?;
let expiry: u32 = row.get(2)?;
let signature: String = row.get(3)?;

Ok(RegistrationReceipt::with_signature(
user_id, slots, start, expiry, signature,
))
})
.ok()
.unwrap()
.map(|r| r.ok())
.collect()
Comment on lines +246 to +248
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't really get my head around how collect() produces Option<Vec< here 😕,
but aside from that, I see that the failure that's being accounted for in map(|r| r.ok()) is for the de-serialization of the data from the DB (slots, start, ...) which we assume shouldn't fail anyways.

I think we can unwrap while de-serializing directly since this shouldn't actually fail.
and substitute map(|r| r.ok()) with map(|r| r.unwrap()). This will make this method return a Vec< instead.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was previous approach, but then sr-gi suggested to keep it Option<Vec<
ref

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my point here was being able to distinguish between when no data was found because the tower was not even there, and when the range didn't produce any output. However, we are checking that afterward in main.rs by checking if the tower was part of the state. So this may indeed be simplified.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my point here was being able to distinguish between when no data was found because the tower was not even there, and when the range didn't produce any output.

I would be for implementing this kind of logic, but the one implemented in here doesn't actually distinguish between the two cases. It always returns Some(a_possibly_zero_len_vec).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that is my point, that the suggestion was for that, but we really don't do that here, but always return Some(...) and check afterward.

I'm happy with either one or the other.

}

/// Removes a tower record from the database.
Expand Down Expand Up @@ -725,34 +732,49 @@ mod tests {
let tower_id = get_random_user_id();
let net_addr = "talaia.watch";
let receipt = get_random_registration_receipt();
let subscription_start = Some(receipt.subscription_start());
let subscription_expiry = Some(receipt.subscription_expiry());

// Check the receipt was stored
dbm.store_tower_record(tower_id, net_addr, &receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, receipt.user_id())
.unwrap(),
dbm.load_registration_receipt(
tower_id,
receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap()[0],
receipt
);

// Add another receipt for the same tower with a higher expiry and check this last one is loaded
// Add another receipt for the same tower with a higher expiry and check that output gives vector of both receipts
let middle_receipt = get_registration_receipt_from_previous(&receipt);
let latest_receipt = get_registration_receipt_from_previous(&middle_receipt);

let latest_subscription_expiry = Some(latest_receipt.subscription_expiry());

dbm.store_tower_record(tower_id, net_addr, &latest_receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, latest_receipt.user_id())
.unwrap(),
latest_receipt
dbm.load_registration_receipt(
tower_id,
latest_receipt.user_id(),
subscription_start,
latest_subscription_expiry
)
.unwrap(),
vec![receipt, latest_receipt.clone()]
);

// Add a final one with a lower expiry and check the last is still loaded
// Add a final one with a lower expiry and check if the lastest receipt is loaded when boundry
// params are not passed
dbm.store_tower_record(tower_id, net_addr, &middle_receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, latest_receipt.user_id())
.unwrap(),
dbm.load_registration_receipt(tower_id, latest_receipt.user_id(), None, None)
.unwrap()[0],
latest_receipt
);
}
Expand All @@ -765,13 +787,20 @@ mod tests {
let tower_id = get_random_user_id();
let net_addr = "talaia.watch";
let receipt = get_random_registration_receipt();
let subscription_start = Some(receipt.subscription_start());
let subscription_expiry = Some(receipt.subscription_expiry());

// Store it once
dbm.store_tower_record(tower_id, net_addr, &receipt)
.unwrap();
assert_eq!(
dbm.load_registration_receipt(tower_id, receipt.user_id())
.unwrap(),
dbm.load_registration_receipt(
tower_id,
receipt.user_id(),
subscription_start,
subscription_expiry
)
.unwrap()[0],
receipt
);

Expand Down
31 changes: 22 additions & 9 deletions watchtower-plugin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ use teos_common::protos as common_msgs;
use teos_common::TowerId;
use teos_common::{cryptography, errors};

use watchtower_plugin::convert::{CommitmentRevocation, GetAppointmentParams, RegisterParams};
use watchtower_plugin::convert::{
CommitmentRevocation, GetAppointmentParams, GetRegistrationReceiptParams, RegisterParams,
};
use watchtower_plugin::net::http::{
self, get_request, post_request, process_post_response, AddAppointmentError, ApiResponse,
RequestError,
Expand Down Expand Up @@ -127,22 +129,33 @@ async fn register(
Ok(json!(receipt))
}

/// Gets the latest registration receipt from the client to a given tower (if it exists).
///
/// Gets the registration receipt(s) from the client to a given tower (if it exists) in the given
/// range. If no range is given, then gets the latest registration receipt.
/// This is pulled from the database
async fn get_registration_receipt(
plugin: Plugin<Arc<Mutex<WTClient>>>,
v: serde_json::Value,
) -> Result<serde_json::Value, Error> {
let tower_id = TowerId::try_from(v).map_err(|x| anyhow!(x))?;
let params = GetRegistrationReceiptParams::try_from(v).map_err(|x| anyhow!(x))?;
let tower_id = params.tower_id;
let subscription_start = params.subscription_start;
let subscription_expiry = params.subscription_expiry;
let state = plugin.state().lock().unwrap();

if let Some(response) = state.get_registration_receipt(tower_id) {
Ok(json!(response))
let response =
state.get_registration_receipt(tower_id, subscription_start, subscription_expiry);
if response.clone().unwrap().is_empty() {
if state.towers.contains_key(&tower_id) {
Err(anyhow!(
"No registration receipt found for {tower_id} on the given range"
))
} else {
Err(anyhow!(
"Cannot find {tower_id} within the known towers. Have you registered?"
))
}
Comment on lines +147 to +156
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really grasp this part. why is there an .unwrap()?!
response is an optional vector, so Option::None & an empty vector signal different meanings? right?

I guess Option::None is when the tower isn't found and an empty vector is when the registration range doesn't contain any receipts?
If so, then the .clone().unwrap() part will panic with no useful information to the user.

Otherwise, we don't need the None variant anymore.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that is actually what i think I was suggesting for, however, checking DBM::load_registration_receipt I don't think we ever produce a None return.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be good to test this out by creating a test that queries a tower that is not part of the DB and see whether the method is returning an empty vec or None

} else {
Err(anyhow!(
"Cannot find {tower_id} within the known towers. Have you registered?"
))
Ok(json!(response))
}
}

Expand Down
17 changes: 14 additions & 3 deletions watchtower-plugin/src/wt_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,20 @@ impl WTClient {
Ok(())
}

/// Gets the latest registration receipt of a given tower.
pub fn get_registration_receipt(&self, tower_id: TowerId) -> Option<RegistrationReceipt> {
self.dbm.load_registration_receipt(tower_id, self.user_id)
/// Gets the registration receipt(s) of a given tower in the given range.
/// If no range is given then gets the latest registration receipt
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end this docstring with a period.

pub fn get_registration_receipt(
&self,
tower_id: TowerId,
subscription_start: Option<u32>,
subscription_expiry: Option<u32>,
) -> Option<Vec<RegistrationReceipt>> {
self.dbm.load_registration_receipt(
tower_id,
self.user_id,
subscription_start,
subscription_expiry,
)
}

/// Loads a tower record from the database.
Expand Down