Skip to content

Commit b18cc28

Browse files
committed
Add high-level helpers for using Musig2 with Taproot
When using Musig2 for a taproot key path, we can provide simpler helper functions to collaboratively build a shared signature for the spending transaction. This hides all of the low-level details of how the musig2 algorithm works, by exposing a subset of what can be done that is sufficient for spending taproot inputs.
1 parent 76d89a4 commit b18cc28

File tree

3 files changed

+88
-71
lines changed

3 files changed

+88
-71
lines changed

src/commonMain/kotlin/fr/acinq/bitcoin/ByteVector.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package fr.acinq.bitcoin
1818

1919
import fr.acinq.secp256k1.Hex
2020
import kotlin.experimental.or
21+
import kotlin.experimental.xor
2122
import kotlin.jvm.JvmField
2223
import kotlin.jvm.JvmStatic
2324

@@ -151,6 +152,15 @@ public class ByteVector32(bytes: ByteArray, offset: Int) : ByteVector(bytes, off
151152

152153
@JvmStatic
153154
public fun fromValidHex(input: String): ByteVector32 = ByteVector32(input)
155+
156+
@JvmStatic
157+
public fun xor(a: ByteVector32, b: ByteVector32): ByteVector32 {
158+
val result = ByteArray(32)
159+
for (i in 0..31) {
160+
result[i] = a[i].xor(b[i])
161+
}
162+
return result.byteVector32()
163+
}
154164
}
155165
}
156166

src/commonMain/kotlin/fr/acinq/bitcoin/musig2/Musig2.kt

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@ import fr.acinq.bitcoin.*
44
import fr.acinq.bitcoin.crypto.Pack
55
import fr.acinq.secp256k1.Hex
66
import fr.acinq.secp256k1.Secp256k1
7-
import kotlin.experimental.xor
87
import kotlin.jvm.JvmStatic
98

10-
119
/**
1210
* Key Aggregation Context
1311
* Holds a public key aggregate that can optionally be tweaked
@@ -16,6 +14,9 @@ import kotlin.jvm.JvmStatic
1614
* @param tacc tweak accumulator
1715
*/
1816
public data class KeyAggCtx(val Q: PublicKey, val gacc: Boolean, val tacc: ByteVector32) {
17+
public constructor(Q: PublicKey) : this(Q, true, ByteVector32.Zeroes)
18+
public constructor(pubkeys: List<PublicKey>) : this(Musig2.keyAgg(pubkeys))
19+
1920
public fun tweak(tweak: ByteVector32, isXonly: Boolean): KeyAggCtx {
2021
require(tweak == ByteVector32.Zeroes || PrivateKey(tweak).isValid()) { "invalid tweak" }
2122
return if (isXonly && !Q.isEven()) {
@@ -30,15 +31,37 @@ public data class KeyAggCtx(val Q: PublicKey, val gacc: Boolean, val tacc: ByteV
3031

3132
public object Musig2 {
3233
@JvmStatic
33-
public fun keyAgg(pubkeys: List<PublicKey>): KeyAggCtx {
34+
public fun keyAgg(pubkeys: List<PublicKey>): PublicKey {
3435
val pk2 = getSecondKey(pubkeys)
3536
val a = pubkeys.map { keyAggCoeffInternal(pubkeys, it, pk2) }
36-
val Q = pubkeys.zip(a).map { it.first.times(PrivateKey(it.second)) }.reduce { p1, p2 -> p1 + p2 }
37-
return KeyAggCtx(Q, true, ByteVector32.Zeroes)
37+
return pubkeys.zip(a).map { it.first.times(PrivateKey(it.second)) }.reduce { p1, p2 -> p1 + p2 }
3838
}
3939

4040
@JvmStatic
4141
public fun keySort(pubkeys: List<PublicKey>): List<PublicKey> = pubkeys.sortedWith { a, b -> LexicographicalOrdering.compare(a, b) }
42+
43+
private fun taprootSessionCtx(tx: Transaction, inputIndex: Int, inputs: List<TxOut>, pubkeys: List<PublicKey>, publicNonces: List<IndividualNonce>, scriptTree: ScriptTree?): SessionCtx {
44+
val aggregatedNonce = IndividualNonce.aggregate(publicNonces)
45+
val aggregatedKey = keyAgg(pubkeys).xOnly()
46+
val tweak = when (scriptTree) {
47+
null -> Pair(aggregatedKey.tweak(Crypto.TaprootTweak.NoScriptTweak), true)
48+
else -> Pair(aggregatedKey.tweak(Crypto.TaprootTweak.ScriptTweak(scriptTree)), true)
49+
}
50+
val txHash = Transaction.hashForSigningTaprootKeyPath(tx, inputIndex, inputs, SigHash.SIGHASH_DEFAULT)
51+
return SessionCtx(aggregatedNonce, pubkeys, listOf(tweak), txHash)
52+
}
53+
54+
@JvmStatic
55+
public fun signTaprootInput(privateKey: PrivateKey, tx: Transaction, inputIndex: Int, inputs: List<TxOut>, pubkeys: List<PublicKey>, secretNonce: SecretNonce, publicNonces: List<IndividualNonce>, scriptTree: ScriptTree?): ByteVector32? {
56+
val ctx = taprootSessionCtx(tx, inputIndex, inputs, pubkeys, publicNonces, scriptTree)
57+
return ctx.sign(secretNonce, privateKey)
58+
}
59+
60+
@JvmStatic
61+
public fun aggregateTaprootSignatures(partialSigs: List<ByteVector32>, tx: Transaction, inputIndex: Int, inputs: List<TxOut>, pubkeys: List<PublicKey>, publicNonces: List<IndividualNonce>, scriptTree: ScriptTree?): ByteVector64? {
62+
val ctx = taprootSessionCtx(tx, inputIndex, inputs, pubkeys, publicNonces, scriptTree)
63+
return ctx.partialSigAgg(partialSigs)
64+
}
4265
}
4366

4467
/**
@@ -70,17 +93,8 @@ public data class SecretNonce(val data: ByteVector) {
7093
*/
7194
@JvmStatic
7295
public fun generate(sk: PrivateKey?, pk: PublicKey, aggpk: XonlyPublicKey?, msg: ByteArray?, extraInput: ByteArray?, randprime: ByteVector32): SecretNonce {
73-
74-
fun xor(a: ByteVector32, b: ByteVector32): ByteVector32 {
75-
val result = ByteArray(32)
76-
for (i in 0..31) {
77-
result[i] = a[i].xor(b[i])
78-
}
79-
return result.byteVector32()
80-
}
81-
8296
val rand = if (sk != null) {
83-
xor(sk.value, Crypto.taggedHash(randprime.toByteArray(), "MuSig/aux"))
97+
ByteVector32.xor(sk.value, Crypto.taggedHash(randprime.toByteArray(), "MuSig/aux"))
8498
} else {
8599
randprime
86100
}
@@ -102,6 +116,11 @@ public data class SecretNonce(val data: ByteVector) {
102116
val secnonce = SecretNonce(PrivateKey(k1).value + PrivateKey(k2).value + pk.value)
103117
return secnonce
104118
}
119+
120+
@JvmStatic
121+
public fun generate(sk: PrivateKey, aggregatedKey: XonlyPublicKey, rand: ByteVector32): SecretNonce {
122+
return generate(sk, sk.publicKey(), aggregatedKey, null, null, rand)
123+
}
105124
}
106125
}
107126

@@ -192,7 +211,6 @@ internal fun add(a: PublicKey?, b: PublicKey?): PublicKey? = when {
192211
else -> a + b
193212
}
194213

195-
196214
internal fun mul(a: PublicKey?, b: PrivateKey): PublicKey? = a?.times(b)
197215

198216
/**
@@ -204,7 +222,7 @@ internal fun mul(a: PublicKey?, b: PrivateKey): PublicKey? = a?.times(b)
204222
*/
205223
public data class SessionCtx(val aggnonce: AggregatedNonce, val pubkeys: List<PublicKey>, val tweaks: List<Pair<ByteVector32, Boolean>>, val message: ByteVector) {
206224
private fun build(): SessionValues {
207-
val keyAggCtx0 = Musig2.keyAgg(pubkeys)
225+
val keyAggCtx0 = KeyAggCtx(pubkeys)
208226
val keyAggCtx = tweaks.fold(keyAggCtx0) { ctx, tweak -> ctx.tweak(tweak.first, tweak.second) }
209227
val (Q, gacc, tacc) = keyAggCtx
210228
val b = PrivateKey(Crypto.taggedHash((aggnonce.toByteArray().byteVector() + Q.xOnly().value + message).toByteArray(), "MuSig/noncecoef"))
@@ -221,7 +239,7 @@ public data class SessionCtx(val aggnonce: AggregatedNonce, val pubkeys: List<Pu
221239
/**
222240
* @param secnonce secret nonce
223241
* @param sk private key
224-
* @return a Musig2 partial signature, or null if the nonce does not match the private key or the partial signature cannot be verified
242+
* @return a Musig2 partial signature, or null if the nonce does not match the private key or the partial signature cannot be verified
225243
*/
226244
public fun sign(secnonce: SecretNonce, sk: PrivateKey): ByteVector32? = runCatching {
227245
val (Q, gacc, _, b, R, e) = build()

src/commonTest/kotlin/fr/acinq/bitcoin/Musig2TestsCommon.kt

Lines changed: 42 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ class Musig2TestsCommon {
2929
val keyIndices = it.jsonObject["key_indices"]!!.jsonArray.map { it.jsonPrimitive.int }
3030
val expected = XonlyPublicKey(ByteVector32.fromValidHex(it.jsonObject["expected"]!!.jsonPrimitive.content))
3131
val ctx = Musig2.keyAgg(keyIndices.map { pubkeys[it] })
32-
assertEquals(expected, ctx.Q.xOnly())
32+
assertEquals(expected, ctx.xOnly())
3333
}
3434
tests.jsonObject["error_test_cases"]!!.jsonArray.forEach {
3535
val keyIndices = it.jsonObject["key_indices"]!!.jsonArray.map { it.jsonPrimitive.int }
3636
val tweakIndices = it.jsonObject["tweak_indices"]!!.jsonArray.map { it.jsonPrimitive.int }
3737
val isXonly = it.jsonObject["is_xonly"]!!.jsonArray.map { it.jsonPrimitive.boolean }
3838
assertFails {
39-
var ctx = Musig2.keyAgg(keyIndices.map { pubkeys[it] })
39+
var ctx = KeyAggCtx(keyIndices.map { pubkeys[it] })
4040
tweakIndices.zip(isXonly).forEach { ctx = ctx.tweak(tweaks[it.first], it.second) }
4141
}
4242
}
@@ -267,7 +267,7 @@ class Musig2TestsCommon {
267267
}
268268

269269
// aggregate public keys
270-
val aggpub = Musig2.keyAgg(pubkeys)
270+
val aggpub = KeyAggCtx(pubkeys)
271271
.tweak(plainTweak, false)
272272
.tweak(xonlyTweak, true)
273273

@@ -277,51 +277,45 @@ class Musig2TestsCommon {
277277

278278
@Test
279279
fun `use musig2 to replace multisig 2-of-2`() {
280+
val random = Random.Default
280281
val alicePrivKey = PrivateKey(ByteArray(32) { 1 })
281282
val alicePubKey = alicePrivKey.publicKey()
282283
val bobPrivKey = PrivateKey(ByteArray(32) { 2 })
283284
val bobPubKey = bobPrivKey.publicKey()
284285

285-
// Alice and Bob exchange public keys and agree on a common aggregated key
286-
val internalPubKey = Musig2.keyAgg(listOf(alicePubKey, bobPubKey)).Q.xOnly()
287-
// we use the standard BIP86 tweak
288-
val commonPubKey = internalPubKey.outputKey(Crypto.TaprootTweak.NoScriptTweak).first
289-
290-
// this tx sends to a standard p2tr(commonPubKey) script
291-
val tx = Transaction(2, listOf(), listOf(TxOut(Satoshi(10000), Script.pay2tr(commonPubKey))), 0)
286+
// Alice and Bob exchange public keys and agree on a common aggregated key.
287+
val aggregatedKey = Musig2.keyAgg(listOf(alicePubKey, bobPubKey)).xOnly()
288+
// This tx sends to a taproot script that doesn't contain any script path.
289+
val tx = Transaction(2, listOf(), listOf(TxOut(Satoshi(10000), Script.pay2tr(aggregatedKey, scripts = null))), 0)
292290

293291
// this is how Alice and Bob would spend that tx
294292
val spendingTx = Transaction(2, listOf(TxIn(OutPoint(tx, 0), sequence = 0)), listOf(TxOut(Satoshi(10000), Script.pay2wpkh(alicePubKey))), 0)
295-
296-
val commonSig = run {
297-
val random = Random.Default
298-
val aliceNonce = SecretNonce.generate(alicePrivKey, alicePubKey, commonPubKey, null, null, random.nextBytes(32).byteVector32())
299-
val bobNonce = SecretNonce.generate(bobPrivKey, bobPubKey, commonPubKey, null, null, random.nextBytes(32).byteVector32())
300-
301-
val aggnonce = IndividualNonce.aggregate(listOf(aliceNonce.publicNonce(), bobNonce.publicNonce()))
302-
val msg = Transaction.hashForSigningTaprootKeyPath(spendingTx, 0, listOf(tx.txOut[0]), SigHash.SIGHASH_DEFAULT)
303-
304-
// we use the same ctx for Alice and Bob, they both know all the public keys that are used here
305-
val ctx = SessionCtx(
306-
aggnonce,
307-
listOf(alicePubKey, bobPubKey),
308-
listOf(Pair(internalPubKey.tweak(Crypto.TaprootTweak.NoScriptTweak), true)),
309-
msg
310-
)
311-
val aliceSig = ctx.sign(aliceNonce, alicePrivKey)!!
312-
val bobSig = ctx.sign(bobNonce, bobPrivKey)!!
313-
ctx.partialSigAgg(listOf(aliceSig, bobSig))!!
293+
val sig = run {
294+
// The first step of a musig2 signing session is to exchange nonces.
295+
// If participants are disconnected before the end of the signing session, they must start again with fresh nonces.
296+
val aliceNonce = SecretNonce.generate(alicePrivKey, aggregatedKey, random.nextBytes(32).byteVector32())
297+
val bobNonce = SecretNonce.generate(bobPrivKey, aggregatedKey, random.nextBytes(32).byteVector32())
298+
299+
// Once they have each other's public nonce, they can produce partial signatures.
300+
val publicNonces = listOf(aliceNonce.publicNonce(), bobNonce.publicNonce())
301+
val aliceSig = Musig2.signTaprootInput(alicePrivKey, spendingTx, 0, tx.txOut, listOf(alicePubKey, bobPubKey), aliceNonce, publicNonces, scriptTree = null)!!
302+
val bobSig = Musig2.signTaprootInput(bobPrivKey, spendingTx, 0, tx.txOut, listOf(alicePubKey, bobPubKey), bobNonce, publicNonces, scriptTree = null)!!
303+
304+
// Once they have each other's partial signature, they can aggregate them into a valid signature.
305+
Musig2.aggregateTaprootSignatures(listOf(aliceSig, bobSig), spendingTx, 0, tx.txOut, listOf(alicePubKey, bobPubKey), publicNonces, scriptTree = null)!!
314306
}
315307

316-
// this tx looks like any other tx that spends a p2tr output, with a single signature
317-
val signedSpendingTx = spendingTx.updateWitness(0, ScriptWitness(listOf(commonSig)))
308+
// This tx looks like any other tx that spends a p2tr output, with a single signature.
309+
val signedSpendingTx = spendingTx.updateWitness(0, Script.witnessKeyPathPay2tr(sig))
318310
Transaction.correctlySpends(signedSpendingTx, tx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
319311
}
320312

321313
@Test
322314
fun `swap-in-potentiam example with musig2 and taproot`() {
323315
val userPrivateKey = PrivateKey(ByteArray(32) { 1 })
316+
val userPublicKey = userPrivateKey.publicKey()
324317
val serverPrivateKey = PrivateKey(ByteArray(32) { 2 })
318+
val serverPublicKey = serverPrivateKey.publicKey()
325319
val userRefundPrivateKey = PrivateKey(ByteArray(32) { 3 })
326320
val refundDelay = 25920
327321

@@ -333,8 +327,8 @@ class Musig2TestsCommon {
333327
val scriptTree = ScriptTree.Leaf(0, redeemScript)
334328

335329
// the internal pubkey is the musig2 aggregation of the user's and server's public keys: it does not depend upon the user's refund's key
336-
val internalPubKey = Musig2.keyAgg(listOf(userPrivateKey.publicKey(), serverPrivateKey.publicKey())).Q.xOnly()
337-
val pubkeyScript = Script.pay2tr(internalPubKey, scriptTree)
330+
val aggregatedKey = Musig2.keyAgg(listOf(userPrivateKey.publicKey(), serverPrivateKey.publicKey())).xOnly()
331+
val pubkeyScript = Script.pay2tr(aggregatedKey, scriptTree)
338332

339333
val swapInTx = Transaction(
340334
version = 2,
@@ -348,27 +342,22 @@ class Musig2TestsCommon {
348342
val tx = Transaction(
349343
version = 2,
350344
txIn = listOf(TxIn(OutPoint(swapInTx, 0), sequence = TxIn.SEQUENCE_FINAL)),
351-
txOut = listOf(TxOut(Satoshi(10000), Script.pay2wpkh(userPrivateKey.publicKey()))),
345+
txOut = listOf(TxOut(Satoshi(10000), Script.pay2wpkh(userPublicKey))),
352346
lockTime = 0
353347
)
354-
// this is the beginning of an interactive musig2 signing session. if user and server are disconnected before they have exchanged partial
355-
// signatures they will have to start again with fresh nonces
356-
val userNonce = SecretNonce.generate(userPrivateKey, userPrivateKey.publicKey(), internalPubKey, null, null, random.nextBytes(32).byteVector32())
357-
val serverNonce = SecretNonce.generate(serverPrivateKey, serverPrivateKey.publicKey(), internalPubKey, null, null, random.nextBytes(32).byteVector32())
358-
359-
val txHash = Transaction.hashForSigningTaprootKeyPath(tx, 0, swapInTx.txOut, SigHash.SIGHASH_DEFAULT)
360-
val commonNonce = IndividualNonce.aggregate(listOf(userNonce.publicNonce(), serverNonce.publicNonce()))
361-
val ctx = SessionCtx(
362-
commonNonce,
363-
listOf(userPrivateKey.publicKey(), serverPrivateKey.publicKey()),
364-
listOf(Pair(internalPubKey.tweak(Crypto.TaprootTweak.ScriptTweak(scriptTree)), true)),
365-
txHash
366-
)
367-
368-
val userSig = ctx.sign(userNonce, userPrivateKey)!!
369-
val serverSig = ctx.sign(serverNonce, serverPrivateKey)!!
370-
val commonSig = ctx.partialSigAgg(listOf(userSig, serverSig))!!
371-
val signedTx = tx.updateWitness(0, Script.witnessKeyPathPay2tr(commonSig))
348+
// The first step of a musig2 signing session is to exchange nonces.
349+
// If participants are disconnected before the end of the signing session, they must start again with fresh nonces.
350+
val userNonce = SecretNonce.generate(userPrivateKey, aggregatedKey, random.nextBytes(32).byteVector32())
351+
val serverNonce = SecretNonce.generate(serverPrivateKey, aggregatedKey, random.nextBytes(32).byteVector32())
352+
353+
// Once they have each other's public nonce, they can produce partial signatures.
354+
val publicNonces = listOf(userNonce.publicNonce(), serverNonce.publicNonce())
355+
val userSig = Musig2.signTaprootInput(userPrivateKey, tx, 0, swapInTx.txOut, listOf(userPublicKey, serverPublicKey), userNonce, publicNonces, scriptTree)!!
356+
val serverSig = Musig2.signTaprootInput(serverPrivateKey, tx, 0, swapInTx.txOut, listOf(userPublicKey, serverPublicKey), serverNonce, publicNonces, scriptTree)!!
357+
358+
// Once they have each other's partial signature, they can aggregate them into a valid signature.
359+
val sig = Musig2.aggregateTaprootSignatures(listOf(userSig, serverSig), tx, 0, swapInTx.txOut, listOf(userPublicKey, serverPublicKey), publicNonces, scriptTree)!!
360+
val signedTx = tx.updateWitness(0, Script.witnessKeyPathPay2tr(sig))
372361
Transaction.correctlySpends(signedTx, swapInTx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
373362
}
374363

@@ -381,7 +370,7 @@ class Musig2TestsCommon {
381370
lockTime = 0
382371
)
383372
val sig = Crypto.signTaprootScriptPath(userRefundPrivateKey, tx, 0, swapInTx.txOut, SigHash.SIGHASH_DEFAULT, scriptTree.hash())
384-
val witness = Script.witnessScriptPathPay2tr(internalPubKey, scriptTree, ScriptWitness(listOf(sig)), scriptTree)
373+
val witness = Script.witnessScriptPathPay2tr(aggregatedKey, scriptTree, ScriptWitness(listOf(sig)), scriptTree)
385374
val signedTx = tx.updateWitness(0, witness)
386375
Transaction.correctlySpends(signedTx, swapInTx, ScriptFlags.STANDARD_SCRIPT_VERIFY_FLAGS)
387376
}

0 commit comments

Comments
 (0)