Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 194 additions & 7 deletions src/embit/psbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ def __init__(self, unknown: dict = {}, vin=None, compress=CompressMode.KEEP_ALL)
self.taproot_sigs = OrderedDict()
self.taproot_scripts = OrderedDict()

# Silent Payments BIP-375 fields
self.sp_ecdh_shares = (
OrderedDict()
) # key: scan_key (bytes), value: share (bytes)
self.sp_dleq_proofs = (
OrderedDict()
) # key: scan_key (bytes), value: proof (bytes)

self.final_scriptsig = None
self.final_scriptwitness = None
self.parse_unknowns()
Expand All @@ -168,6 +176,9 @@ def clear_metadata(self, compress=CompressMode.CLEAR_ALL):
self.taproot_internal_key = None
self.taproot_merkle_root = None
self.taproot_scripts = OrderedDict()
# Clear SP fields
self.sp_ecdh_shares = OrderedDict()
self.sp_dleq_proofs = OrderedDict()

def update(self, other):
self.txid = other.txid or self.txid
Expand All @@ -189,6 +200,9 @@ def update(self, other):
self.taproot_merkle_root = other.taproot_merkle_root or self.taproot_merkle_root
self.taproot_sigs.update(other.taproot_sigs)
self.taproot_scripts.update(other.taproot_scripts)
# Update SP fields
self.sp_ecdh_shares.update(other.sp_ecdh_shares)
self.sp_dleq_proofs.update(other.sp_dleq_proofs)
self.final_scriptsig = other.final_scriptsig or self.final_scriptsig
self.final_scriptwitness = other.final_scriptwitness or self.final_scriptwitness

Expand Down Expand Up @@ -388,6 +402,28 @@ def read_value(self, stream, k):
elif k[0] == 0x18:
self.taproot_merkle_root = v

# PSBT_IN_SP_ECDH_SHARE (BIP-375)
elif k[0] == 0x1D:
if len(k) != 34: # 1 byte type + 33 byte scan key
raise PSBTError("Invalid key length for PSBT_IN_SP_ECDH_SHARE")
scan_key = k[1:]
if scan_key in self.sp_ecdh_shares:
raise PSBTError("Duplicated PSBT_IN_SP_ECDH_SHARE for scan key")
if len(v) != 33:
raise PSBTError("Invalid value length for PSBT_IN_SP_ECDH_SHARE")
self.sp_ecdh_shares[scan_key] = v

# PSBT_IN_SP_DLEQ (BIP-375)
elif k[0] == 0x1E:
if len(k) != 34: # 1 byte type + 33 byte scan key
raise PSBTError("Invalid key length for PSBT_IN_SP_DLEQ")
scan_key = k[1:]
if scan_key in self.sp_dleq_proofs:
raise PSBTError("Duplicated PSBT_IN_SP_DLEQ for scan key")
if len(v) != 64:
raise PSBTError("Invalid value length for PSBT_IN_SP_DLEQ")
self.sp_dleq_proofs[scan_key] = v

else:
if k in self.unknown:
raise PSBTError("Duplicated key")
Expand Down Expand Up @@ -434,6 +470,16 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int:
r += ser_string(stream, b"\x10")
r += ser_string(stream, self.sequence.to_bytes(4, "little"))

# PSBT_IN_SP_ECDH_SHARE (BIP-375)
for scan_key in self.sp_ecdh_shares:
r += ser_string(stream, b"\x1d" + scan_key)
r += ser_string(stream, self.sp_ecdh_shares[scan_key])

# PSBT_IN_SP_DLEQ (BIP-375)
for scan_key in self.sp_dleq_proofs:
r += ser_string(stream, b"\x1e" + scan_key)
r += ser_string(stream, self.sp_dleq_proofs[scan_key])

# PSBT_IN_TAP_SCRIPT_SIG
for pub, leaf in self.taproot_sigs:
r += ser_string(stream, b"\x14" + pub.xonly() + leaf)
Expand Down Expand Up @@ -489,6 +535,9 @@ def __init__(self, unknown: dict = {}, vout=None, compress=CompressMode.KEEP_ALL
self.bip32_derivations = OrderedDict()
self.taproot_bip32_derivations = OrderedDict()
self.taproot_internal_key = None
# Silent Payments BIP-375 fields
self.sp_v0_info = None # tuple(scan_key_bytes, spend_key_bytes)
self.sp_v0_label = None # int(32 bit)
self.parse_unknowns()

def clear_metadata(self, compress=CompressMode.CLEAR_ALL):
Expand All @@ -501,6 +550,9 @@ def clear_metadata(self, compress=CompressMode.CLEAR_ALL):
self.bip32_derivations = OrderedDict()
self.taproot_bip32_derivations = OrderedDict()
self.taproot_internal_key = None
# Clear SP fields
self.sp_v0_info = None
self.sp_v0_label = None

def update(self, other):
self.value = other.value if other.value is not None else self.value
Expand All @@ -511,9 +563,23 @@ def update(self, other):
self.bip32_derivations.update(other.bip32_derivations)
self.taproot_bip32_derivations.update(other.taproot_bip32_derivations)
self.taproot_internal_key = other.taproot_internal_key
# Update SP fields
self.sp_v0_info = other.sp_v0_info or self.sp_v0_info
self.sp_v0_label = (
other.sp_v0_label if other.sp_v0_label is not None else self.sp_v0_label
)

@property
def vout(self):
# If script_pubkey is not set (because it's an SP output not yet computed),
# we cannot construct a valid TransactionOutput yet.
if self.script_pubkey is None and self.sp_v0_info is not None:
# Or raise an error, or return a placeholder?
# Returning None might be safest for now.
return None
if self.value is None or self.script_pubkey is None:
# If value or script_pubkey is missing (and not SP), it's incomplete
return None
return TransactionOutput(self.value, self.script_pubkey)

def read_value(self, stream, k):
Expand Down Expand Up @@ -568,6 +634,26 @@ def read_value(self, stream, k):
der = DerivationPath.read_from(b)
self.taproot_bip32_derivations[pub] = (leaf_hashes, der)

# PSBT_OUT_SP_V0_INFO (BIP-375)
elif k == b"\x09":
if len(k) != 1:
raise PSBTError("Invalid key length for PSBT_OUT_SP_V0_INFO")
if self.sp_v0_info is not None:
raise PSBTError("Duplicated PSBT_OUT_SP_V0_INFO")
if len(v) != 66: # 33 byte scan key + 33 byte spend key
raise PSBTError("Invalid value length for PSBT_OUT_SP_V0_INFO")
self.sp_v0_info = (v[:33], v[33:])

# PSBT_OUT_SP_V0_LABEL (BIP-375)
elif k == b"\x0a":
if len(k) != 1:
raise PSBTError("Invalid key length for PSBT_OUT_SP_V0_LABEL")
if self.sp_v0_label is not None:
raise PSBTError("Duplicated PSBT_OUT_SP_V0_LABEL")
if len(v) != 4:
raise PSBTError("Invalid value length for PSBT_OUT_SP_V0_LABEL")
self.sp_v0_label = int.from_bytes(v, "little")

else:
if k in self.unknown:
raise PSBTError("Duplicated key")
Expand Down Expand Up @@ -609,6 +695,17 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int:
+ derivation.serialize(),
)

# PSBT_OUT_SP_V0_INFO (BIP-375)
if self.sp_v0_info is not None:
r += ser_string(stream, b"\x09")
scan_key, spend_key = self.sp_v0_info
r += ser_string(stream, scan_key + spend_key)

# PSBT_OUT_SP_V0_LABEL (BIP-375)
if self.sp_v0_label is not None:
r += ser_string(stream, b"\x0a")
r += ser_string(stream, self.sp_v0_label.to_bytes(4, "little"))

# unknown
for key in self.unknown:
r += ser_string(stream, key)
Expand Down Expand Up @@ -638,6 +735,13 @@ def __init__(self, tx=None, unknown={}, version=None):

self.unknown = unknown
self.xpubs = OrderedDict()
# Silent Payments BIP-375 fields
self.sp_global_ecdh_shares = (
OrderedDict()
) # key: scan_key (bytes), value: share (bytes)
self.sp_global_dleq_proofs = (
OrderedDict()
) # key: scan_key (bytes), value: proof (bytes)
self.parse_unknowns()

def parse_tx(self, tx):
Expand Down Expand Up @@ -691,6 +795,13 @@ def fee(self):
def write_to(self, stream) -> int:
# magic bytes
r = stream.write(self.MAGIC)
# PSBTv0 uses global tx field, PSBTv2 uses per-input/output fields
# BIP-375 builds on PSBTv2, so we assume version 2 or higher for SP fields
if self.version != 2 and (
self.sp_global_ecdh_shares or self.sp_global_dleq_proofs
):
raise PSBTError("Silent Payment global fields require PSBT version 2")

if self.version != 2:
# unsigned tx flag
r += stream.write(b"\x01\x00")
Expand All @@ -702,6 +813,16 @@ def write_to(self, stream) -> int:
r += ser_string(stream, b"\x01" + xpub.serialize())
r += ser_string(stream, self.xpubs[xpub].serialize())

# PSBT_GLOBAL_SP_ECDH_SHARE (BIP-375)
for scan_key in self.sp_global_ecdh_shares:
r += ser_string(stream, b"\x07" + scan_key)
r += ser_string(stream, self.sp_global_ecdh_shares[scan_key])

# PSBT_GLOBAL_SP_DLEQ (BIP-375)
for scan_key in self.sp_global_dleq_proofs:
r += ser_string(stream, b"\x08" + scan_key)
r += ser_string(stream, self.sp_global_dleq_proofs[scan_key])

if self.version == 2:
if self.tx_version is not None:
r += ser_string(stream, b"\x02")
Expand Down Expand Up @@ -780,6 +901,13 @@ def read_from(cls, stream, compress=CompressMode.KEEP_ALL):
elif key == b"\xfb":
version = int.from_bytes(value, "little")
else:
# Handle potential global SP fields before creating the PSBT object
# Although parse_unknowns will handle them later, checking here avoids issues
# if they appear before the version field in a v0 PSBT.
if version != 2 and key[0] in {0x07, 0x08}:
raise PSBTError(
f"Silent Payment global field {key.hex()} found in non-v2 PSBT"
)
if key in unknown:
raise PSBTError("Duplicated key")
unknown[key] = value
Expand Down Expand Up @@ -824,6 +952,36 @@ def parse_unknowns(self):
for _ in range(compact.from_bytes(self.unknown.pop(k)))
]

# PSBT_GLOBAL_SP_ECDH_SHARE (BIP-375)
elif k[0] == 0x07:
if self.version != 2:
raise PSBTError("PSBT_GLOBAL_SP_ECDH_SHARE requires PSBTv2")
if len(k) != 34: # 1 byte type + 33 byte scan key
raise PSBTError("Invalid key length for PSBT_GLOBAL_SP_ECDH_SHARE")
scan_key = k[1:]
if scan_key in self.sp_global_ecdh_shares:
raise PSBTError("Duplicated PSBT_GLOBAL_SP_ECDH_SHARE for scan key")
value = self.unknown.pop(k)
if len(value) != 33:
raise PSBTError(
"Invalid value length for PSBT_GLOBAL_SP_ECDH_SHARE"
)
self.sp_global_ecdh_shares[scan_key] = value

# PSBT_GLOBAL_SP_DLEQ (BIP-375)
elif k[0] == 0x08:
if self.version != 2:
raise PSBTError("PSBT_GLOBAL_SP_DLEQ requires PSBTv2")
if len(k) != 34: # 1 byte type + 33 byte scan key
raise PSBTError("Invalid key length for PSBT_GLOBAL_SP_DLEQ")
scan_key = k[1:]
if scan_key in self.sp_global_dleq_proofs:
raise PSBTError("Duplicated PSBT_GLOBAL_SP_DLEQ for scan key")
value = self.unknown.pop(k)
if len(value) != 64:
raise PSBTError("Invalid value length for PSBT_GLOBAL_SP_DLEQ")
self.sp_global_dleq_proofs[scan_key] = value

def sighash(self, i, sighash=SIGHASH.ALL, **kwargs):
inp = self.inputs[i]

Expand Down Expand Up @@ -926,6 +1084,24 @@ def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int:
If you want to sign with sighashes provided in the PSBT - set sighash=None.
"""
counter = 0 # sigs counter

# Check for Silent Payment outputs
sp_outputs_present = any(out.sp_v0_info is not None for out in self.outputs)

# Check for incompatible inputs if SP outputs are present (BIP-375)
if sp_outputs_present:
for i, inp in enumerate(self.inputs):
script_pubkey = inp.script_pubkey
if (
script_pubkey
and script_pubkey.is_segwit()
and script_pubkey.version > 1
):
raise PSBTError(
f"Input {i} uses Segwit v{script_pubkey.version}, incompatible with Silent Payment outputs (BIP-375)"
)

# Check if root is a descriptor
# check if it's a descriptor, and sign with all private keys in this descriptor
if hasattr(root, "keys"):
for k in root.keys:
Expand Down Expand Up @@ -962,18 +1138,29 @@ def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int:
# check which sighash to use
inp_sighash = inp.sighash_type
if inp_sighash is None:
inp_sighash = required_sighash or SIGHASH.DEFAULT
if not inp.is_taproot and inp_sighash == SIGHASH.DEFAULT:
inp_sighash = SIGHASH.ALL
# If SP outputs are present, default MUST be SIGHASH_ALL (BIP-375)
if sp_outputs_present:
inp_sighash = SIGHASH.ALL
else:
inp_sighash = required_sighash or SIGHASH.DEFAULT
if not inp.is_taproot and inp_sighash == SIGHASH.DEFAULT:
inp_sighash = SIGHASH.ALL
elif sp_outputs_present and inp_sighash != SIGHASH.ALL:
# If SP outputs present, only SIGHASH_ALL is allowed (BIP-375)
raise PSBTError(
f"Input {i} has sighash {inp_sighash}, but Silent Payments require SIGHASH_ALL (0x{SIGHASH.ALL:02x})"
)

# if input sighash is set and is different from required sighash
# we don't sign this input
# except DEFAULT is functionally the same as ALL
if required_sighash is not None and inp_sighash != required_sighash:
if inp_sighash not in {
SIGHASH.DEFAULT,
SIGHASH.ALL,
} or required_sighash not in {SIGHASH.DEFAULT, SIGHASH.ALL}:
# Allow DEFAULT and ALL to be interchangeable if not SP
is_interchangeable = not sp_outputs_present and {
inp_sighash,
required_sighash,
}.issubset({SIGHASH.DEFAULT, SIGHASH.ALL})
if not is_interchangeable:
continue

# get all possible derivations with matching fingerprint
Expand Down