diff --git a/src/embit/psbt.py b/src/embit/psbt.py index 54497c2..ce62ca2 100644 --- a/src/embit/psbt.py +++ b/src/embit/psbt.py @@ -145,6 +145,14 @@ def __init__(self, unknown: dict = {}, vin=None, compress=CompressMode.KEEP_ALL) self.taproot_sigs = OrderedDict() self.taproot_scripts = OrderedDict() + # Silent Payments BIP-375 fields + self.sp_ecdh_shares = ( + OrderedDict() + ) # key: scan_key (bytes), value: share (bytes) + self.sp_dleq_proofs = ( + OrderedDict() + ) # key: scan_key (bytes), value: proof (bytes) + self.final_scriptsig = None self.final_scriptwitness = None self.parse_unknowns() @@ -168,6 +176,9 @@ def clear_metadata(self, compress=CompressMode.CLEAR_ALL): self.taproot_internal_key = None self.taproot_merkle_root = None self.taproot_scripts = OrderedDict() + # Clear SP fields + self.sp_ecdh_shares = OrderedDict() + self.sp_dleq_proofs = OrderedDict() def update(self, other): self.txid = other.txid or self.txid @@ -189,6 +200,9 @@ def update(self, other): self.taproot_merkle_root = other.taproot_merkle_root or self.taproot_merkle_root self.taproot_sigs.update(other.taproot_sigs) self.taproot_scripts.update(other.taproot_scripts) + # Update SP fields + self.sp_ecdh_shares.update(other.sp_ecdh_shares) + self.sp_dleq_proofs.update(other.sp_dleq_proofs) self.final_scriptsig = other.final_scriptsig or self.final_scriptsig self.final_scriptwitness = other.final_scriptwitness or self.final_scriptwitness @@ -388,6 +402,28 @@ def read_value(self, stream, k): elif k[0] == 0x18: self.taproot_merkle_root = v + # PSBT_IN_SP_ECDH_SHARE (BIP-375) + elif k[0] == 0x1D: + if len(k) != 34: # 1 byte type + 33 byte scan key + raise PSBTError("Invalid key length for PSBT_IN_SP_ECDH_SHARE") + scan_key = k[1:] + if scan_key in self.sp_ecdh_shares: + raise PSBTError("Duplicated PSBT_IN_SP_ECDH_SHARE for scan key") + if len(v) != 33: + raise PSBTError("Invalid value length for PSBT_IN_SP_ECDH_SHARE") + self.sp_ecdh_shares[scan_key] = v + + # PSBT_IN_SP_DLEQ (BIP-375) + elif k[0] == 0x1E: + if len(k) != 34: # 1 byte type + 33 byte scan key + raise PSBTError("Invalid key length for PSBT_IN_SP_DLEQ") + scan_key = k[1:] + if scan_key in self.sp_dleq_proofs: + raise PSBTError("Duplicated PSBT_IN_SP_DLEQ for scan key") + if len(v) != 64: + raise PSBTError("Invalid value length for PSBT_IN_SP_DLEQ") + self.sp_dleq_proofs[scan_key] = v + else: if k in self.unknown: raise PSBTError("Duplicated key") @@ -434,6 +470,16 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int: r += ser_string(stream, b"\x10") r += ser_string(stream, self.sequence.to_bytes(4, "little")) + # PSBT_IN_SP_ECDH_SHARE (BIP-375) + for scan_key in self.sp_ecdh_shares: + r += ser_string(stream, b"\x1d" + scan_key) + r += ser_string(stream, self.sp_ecdh_shares[scan_key]) + + # PSBT_IN_SP_DLEQ (BIP-375) + for scan_key in self.sp_dleq_proofs: + r += ser_string(stream, b"\x1e" + scan_key) + r += ser_string(stream, self.sp_dleq_proofs[scan_key]) + # PSBT_IN_TAP_SCRIPT_SIG for pub, leaf in self.taproot_sigs: r += ser_string(stream, b"\x14" + pub.xonly() + leaf) @@ -489,6 +535,9 @@ def __init__(self, unknown: dict = {}, vout=None, compress=CompressMode.KEEP_ALL self.bip32_derivations = OrderedDict() self.taproot_bip32_derivations = OrderedDict() self.taproot_internal_key = None + # Silent Payments BIP-375 fields + self.sp_v0_info = None # tuple(scan_key_bytes, spend_key_bytes) + self.sp_v0_label = None # int(32 bit) self.parse_unknowns() def clear_metadata(self, compress=CompressMode.CLEAR_ALL): @@ -501,6 +550,9 @@ def clear_metadata(self, compress=CompressMode.CLEAR_ALL): self.bip32_derivations = OrderedDict() self.taproot_bip32_derivations = OrderedDict() self.taproot_internal_key = None + # Clear SP fields + self.sp_v0_info = None + self.sp_v0_label = None def update(self, other): self.value = other.value if other.value is not None else self.value @@ -511,9 +563,23 @@ def update(self, other): self.bip32_derivations.update(other.bip32_derivations) self.taproot_bip32_derivations.update(other.taproot_bip32_derivations) self.taproot_internal_key = other.taproot_internal_key + # Update SP fields + self.sp_v0_info = other.sp_v0_info or self.sp_v0_info + self.sp_v0_label = ( + other.sp_v0_label if other.sp_v0_label is not None else self.sp_v0_label + ) @property def vout(self): + # If script_pubkey is not set (because it's an SP output not yet computed), + # we cannot construct a valid TransactionOutput yet. + if self.script_pubkey is None and self.sp_v0_info is not None: + # Or raise an error, or return a placeholder? + # Returning None might be safest for now. + return None + if self.value is None or self.script_pubkey is None: + # If value or script_pubkey is missing (and not SP), it's incomplete + return None return TransactionOutput(self.value, self.script_pubkey) def read_value(self, stream, k): @@ -568,6 +634,26 @@ def read_value(self, stream, k): der = DerivationPath.read_from(b) self.taproot_bip32_derivations[pub] = (leaf_hashes, der) + # PSBT_OUT_SP_V0_INFO (BIP-375) + elif k == b"\x09": + if len(k) != 1: + raise PSBTError("Invalid key length for PSBT_OUT_SP_V0_INFO") + if self.sp_v0_info is not None: + raise PSBTError("Duplicated PSBT_OUT_SP_V0_INFO") + if len(v) != 66: # 33 byte scan key + 33 byte spend key + raise PSBTError("Invalid value length for PSBT_OUT_SP_V0_INFO") + self.sp_v0_info = (v[:33], v[33:]) + + # PSBT_OUT_SP_V0_LABEL (BIP-375) + elif k == b"\x0a": + if len(k) != 1: + raise PSBTError("Invalid key length for PSBT_OUT_SP_V0_LABEL") + if self.sp_v0_label is not None: + raise PSBTError("Duplicated PSBT_OUT_SP_V0_LABEL") + if len(v) != 4: + raise PSBTError("Invalid value length for PSBT_OUT_SP_V0_LABEL") + self.sp_v0_label = int.from_bytes(v, "little") + else: if k in self.unknown: raise PSBTError("Duplicated key") @@ -609,6 +695,17 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int: + derivation.serialize(), ) + # PSBT_OUT_SP_V0_INFO (BIP-375) + if self.sp_v0_info is not None: + r += ser_string(stream, b"\x09") + scan_key, spend_key = self.sp_v0_info + r += ser_string(stream, scan_key + spend_key) + + # PSBT_OUT_SP_V0_LABEL (BIP-375) + if self.sp_v0_label is not None: + r += ser_string(stream, b"\x0a") + r += ser_string(stream, self.sp_v0_label.to_bytes(4, "little")) + # unknown for key in self.unknown: r += ser_string(stream, key) @@ -638,6 +735,13 @@ def __init__(self, tx=None, unknown={}, version=None): self.unknown = unknown self.xpubs = OrderedDict() + # Silent Payments BIP-375 fields + self.sp_global_ecdh_shares = ( + OrderedDict() + ) # key: scan_key (bytes), value: share (bytes) + self.sp_global_dleq_proofs = ( + OrderedDict() + ) # key: scan_key (bytes), value: proof (bytes) self.parse_unknowns() def parse_tx(self, tx): @@ -691,6 +795,13 @@ def fee(self): def write_to(self, stream) -> int: # magic bytes r = stream.write(self.MAGIC) + # PSBTv0 uses global tx field, PSBTv2 uses per-input/output fields + # BIP-375 builds on PSBTv2, so we assume version 2 or higher for SP fields + if self.version != 2 and ( + self.sp_global_ecdh_shares or self.sp_global_dleq_proofs + ): + raise PSBTError("Silent Payment global fields require PSBT version 2") + if self.version != 2: # unsigned tx flag r += stream.write(b"\x01\x00") @@ -702,6 +813,16 @@ def write_to(self, stream) -> int: r += ser_string(stream, b"\x01" + xpub.serialize()) r += ser_string(stream, self.xpubs[xpub].serialize()) + # PSBT_GLOBAL_SP_ECDH_SHARE (BIP-375) + for scan_key in self.sp_global_ecdh_shares: + r += ser_string(stream, b"\x07" + scan_key) + r += ser_string(stream, self.sp_global_ecdh_shares[scan_key]) + + # PSBT_GLOBAL_SP_DLEQ (BIP-375) + for scan_key in self.sp_global_dleq_proofs: + r += ser_string(stream, b"\x08" + scan_key) + r += ser_string(stream, self.sp_global_dleq_proofs[scan_key]) + if self.version == 2: if self.tx_version is not None: r += ser_string(stream, b"\x02") @@ -780,6 +901,13 @@ def read_from(cls, stream, compress=CompressMode.KEEP_ALL): elif key == b"\xfb": version = int.from_bytes(value, "little") else: + # Handle potential global SP fields before creating the PSBT object + # Although parse_unknowns will handle them later, checking here avoids issues + # if they appear before the version field in a v0 PSBT. + if version != 2 and key[0] in {0x07, 0x08}: + raise PSBTError( + f"Silent Payment global field {key.hex()} found in non-v2 PSBT" + ) if key in unknown: raise PSBTError("Duplicated key") unknown[key] = value @@ -824,6 +952,36 @@ def parse_unknowns(self): for _ in range(compact.from_bytes(self.unknown.pop(k))) ] + # PSBT_GLOBAL_SP_ECDH_SHARE (BIP-375) + elif k[0] == 0x07: + if self.version != 2: + raise PSBTError("PSBT_GLOBAL_SP_ECDH_SHARE requires PSBTv2") + if len(k) != 34: # 1 byte type + 33 byte scan key + raise PSBTError("Invalid key length for PSBT_GLOBAL_SP_ECDH_SHARE") + scan_key = k[1:] + if scan_key in self.sp_global_ecdh_shares: + raise PSBTError("Duplicated PSBT_GLOBAL_SP_ECDH_SHARE for scan key") + value = self.unknown.pop(k) + if len(value) != 33: + raise PSBTError( + "Invalid value length for PSBT_GLOBAL_SP_ECDH_SHARE" + ) + self.sp_global_ecdh_shares[scan_key] = value + + # PSBT_GLOBAL_SP_DLEQ (BIP-375) + elif k[0] == 0x08: + if self.version != 2: + raise PSBTError("PSBT_GLOBAL_SP_DLEQ requires PSBTv2") + if len(k) != 34: # 1 byte type + 33 byte scan key + raise PSBTError("Invalid key length for PSBT_GLOBAL_SP_DLEQ") + scan_key = k[1:] + if scan_key in self.sp_global_dleq_proofs: + raise PSBTError("Duplicated PSBT_GLOBAL_SP_DLEQ for scan key") + value = self.unknown.pop(k) + if len(value) != 64: + raise PSBTError("Invalid value length for PSBT_GLOBAL_SP_DLEQ") + self.sp_global_dleq_proofs[scan_key] = value + def sighash(self, i, sighash=SIGHASH.ALL, **kwargs): inp = self.inputs[i] @@ -926,6 +1084,24 @@ def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int: If you want to sign with sighashes provided in the PSBT - set sighash=None. """ counter = 0 # sigs counter + + # Check for Silent Payment outputs + sp_outputs_present = any(out.sp_v0_info is not None for out in self.outputs) + + # Check for incompatible inputs if SP outputs are present (BIP-375) + if sp_outputs_present: + for i, inp in enumerate(self.inputs): + script_pubkey = inp.script_pubkey + if ( + script_pubkey + and script_pubkey.is_segwit() + and script_pubkey.version > 1 + ): + raise PSBTError( + f"Input {i} uses Segwit v{script_pubkey.version}, incompatible with Silent Payment outputs (BIP-375)" + ) + + # Check if root is a descriptor # check if it's a descriptor, and sign with all private keys in this descriptor if hasattr(root, "keys"): for k in root.keys: @@ -962,18 +1138,29 @@ def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int: # check which sighash to use inp_sighash = inp.sighash_type if inp_sighash is None: - inp_sighash = required_sighash or SIGHASH.DEFAULT - if not inp.is_taproot and inp_sighash == SIGHASH.DEFAULT: - inp_sighash = SIGHASH.ALL + # If SP outputs are present, default MUST be SIGHASH_ALL (BIP-375) + if sp_outputs_present: + inp_sighash = SIGHASH.ALL + else: + inp_sighash = required_sighash or SIGHASH.DEFAULT + if not inp.is_taproot and inp_sighash == SIGHASH.DEFAULT: + inp_sighash = SIGHASH.ALL + elif sp_outputs_present and inp_sighash != SIGHASH.ALL: + # If SP outputs present, only SIGHASH_ALL is allowed (BIP-375) + raise PSBTError( + f"Input {i} has sighash {inp_sighash}, but Silent Payments require SIGHASH_ALL (0x{SIGHASH.ALL:02x})" + ) # if input sighash is set and is different from required sighash # we don't sign this input # except DEFAULT is functionally the same as ALL if required_sighash is not None and inp_sighash != required_sighash: - if inp_sighash not in { - SIGHASH.DEFAULT, - SIGHASH.ALL, - } or required_sighash not in {SIGHASH.DEFAULT, SIGHASH.ALL}: + # Allow DEFAULT and ALL to be interchangeable if not SP + is_interchangeable = not sp_outputs_present and { + inp_sighash, + required_sighash, + }.issubset({SIGHASH.DEFAULT, SIGHASH.ALL}) + if not is_interchangeable: continue # get all possible derivations with matching fingerprint