Skip to content

Commit acfb7b0

Browse files
dozyioachingbrain
andauthored
feat: async middleware (#3204)
Adds middleware handlers for protocol streams. They are invoked for incoming and outgoing streams and allow access to the stream and connection before the handler (incoming) or caller (outgoing) receive them. This way middleware can wrap streams in transforms, or deny access, or something else. ```ts libp2p.use('/my/protocol/1.0.0', (stream, connection, next) => { const originalSource = stream.source // increment all byte values in the stream by one stream.source = (async function * () { for await (const buf of originalSource) { buf = buf.map(val => val + 1) yield buf } })() // pass the stream on to the next middleware next(stream, connection) }) ``` --------- Co-authored-by: achingbrain <[email protected]>
1 parent 6332556 commit acfb7b0

File tree

2 files changed

+123
-41
lines changed

2 files changed

+123
-41
lines changed

packages/libp2p/src/connection.ts

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { CONNECTION_CLOSE_TIMEOUT, PROTOCOL_NEGOTIATION_TIMEOUT } from './connec
66
import { isDirect } from './connection-manager/utils.ts'
77
import { MuxerUnavailableError } from './errors.ts'
88
import { DEFAULT_MAX_INBOUND_STREAMS, DEFAULT_MAX_OUTBOUND_STREAMS } from './registrar.ts'
9-
import type { AbortOptions, Logger, MessageStreamDirection, Connection as ConnectionInterface, Stream, NewStreamOptions, PeerId, ConnectionLimits, StreamMuxer, Metrics, PeerStore, MultiaddrConnection, MessageStreamEvents, MultiaddrConnectionTimeline, ConnectionStatus, MessageStream } from '@libp2p/interface'
9+
import type { AbortOptions, Logger, MessageStreamDirection, Connection as ConnectionInterface, Stream, NewStreamOptions, PeerId, ConnectionLimits, StreamMuxer, Metrics, PeerStore, MultiaddrConnection, MessageStreamEvents, MultiaddrConnectionTimeline, ConnectionStatus, MessageStream, StreamMiddleware } from '@libp2p/interface'
1010
import type { Registrar } from '@libp2p/interface-internal'
1111
import type { Multiaddr } from '@multiformats/multiaddr'
1212

@@ -126,7 +126,7 @@ export class Connection extends TypedEventEmitter<MessageStreamEvents> implement
126126
}
127127

128128
this.log.trace('starting new stream for protocols %s', protocols)
129-
let muxedStream = await this.muxer.createStream({
129+
const muxedStream = await this.muxer.createStream({
130130
...options,
131131

132132
// most underlying transports only support negotiating a single protocol
@@ -179,23 +179,7 @@ export class Connection extends TypedEventEmitter<MessageStreamEvents> implement
179179

180180
const middleware = this.components.registrar.getMiddleware(muxedStream.protocol)
181181

182-
middleware.push((stream, connection, next) => {
183-
next(stream, connection)
184-
})
185-
186-
let i = 0
187-
let connection: ConnectionInterface = this
188-
189-
while (i < middleware.length) {
190-
// eslint-disable-next-line no-loop-func
191-
middleware[i](muxedStream, connection, (s, c) => {
192-
muxedStream = s
193-
connection = c
194-
i++
195-
})
196-
}
197-
198-
return muxedStream
182+
return await this.runMiddlewareChain(muxedStream, this, middleware)
199183
} catch (err: any) {
200184
if (muxedStream.status === 'open') {
201185
muxedStream.abort(err)
@@ -208,7 +192,7 @@ export class Connection extends TypedEventEmitter<MessageStreamEvents> implement
208192
}
209193

210194
private async onIncomingStream (evt: CustomEvent<Stream>): Promise<void> {
211-
let muxedStream = evt.detail
195+
const muxedStream = evt.detail
212196

213197
const signal = AbortSignal.timeout(this.inboundStreamProtocolNegotiationTimeout)
214198
setMaxListeners(Infinity, signal)
@@ -260,20 +244,40 @@ export class Connection extends TypedEventEmitter<MessageStreamEvents> implement
260244
next(stream, connection)
261245
})
262246

263-
let connection: ConnectionInterface = this
264-
265-
for (const m of middleware) {
266-
// eslint-disable-next-line no-loop-func
267-
await m(muxedStream, connection, (s, c) => {
268-
muxedStream = s
269-
connection = c
270-
})
271-
}
247+
await this.runMiddlewareChain(muxedStream, this, middleware)
272248
} catch (err: any) {
273249
muxedStream.abort(err)
274250
}
275251
}
276252

253+
private async runMiddlewareChain (stream: Stream, connection: ConnectionInterface, middleware: StreamMiddleware[]): Promise<Stream> {
254+
for (let i = 0; i < middleware.length; i++) {
255+
const mw = middleware[i]
256+
stream.log.trace('running middleware', i, mw)
257+
258+
// eslint-disable-next-line no-loop-func
259+
await new Promise<void>((resolve, reject) => {
260+
try {
261+
const result = mw(stream, connection, (s, c) => {
262+
stream = s
263+
connection = c
264+
resolve()
265+
})
266+
267+
if (result instanceof Promise) {
268+
result.catch(reject)
269+
}
270+
} catch (err) {
271+
reject(err)
272+
}
273+
})
274+
275+
stream.log.trace('ran middleware', i, mw)
276+
}
277+
278+
return stream
279+
}
280+
277281
/**
278282
* Close the connection
279283
*/

packages/libp2p/test/connection/index.spec.ts

Lines changed: 90 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { StreamCloseEvent } from '@libp2p/interface'
2+
import { defaultLogger } from '@libp2p/logger'
23
import { peerIdFromString } from '@libp2p/peer-id'
34
import { echoStream, streamPair, echo, multiaddrConnectionPair, mockMuxer } from '@libp2p/utils'
45
import { multiaddr } from '@multiformats/multiaddr'
@@ -361,6 +362,7 @@ describe('connection', () => {
361362
}
362363

363364
const incomingStream = stubInterface<Stream>({
365+
log: defaultLogger().forComponent('stream'),
364366
protocol: streamProtocol
365367
})
366368

@@ -371,24 +373,100 @@ describe('connection', () => {
371373
onIncomingStream(new CustomEvent('stream', {
372374
detail: incomingStream
373375
}))
374-
/*
376+
377+
// incoming stream is opened asynchronously
378+
await delay(100)
379+
380+
expect(middleware1.called).to.be.true()
381+
expect(middleware2.called).to.be.true()
382+
})
383+
384+
it('should not call outbound middleware if previous middleware errors', async () => {
385+
const streamProtocol = '/test/protocol'
386+
const err = new Error('boom')
387+
388+
const middleware1 = Sinon.stub().callsFake((stream, connection, next) => {
389+
throw err
390+
})
391+
const middleware2 = Sinon.stub().callsFake((stream, connection, next) => {
392+
next(stream, connection)
393+
})
394+
395+
const middleware = [
396+
middleware1,
397+
middleware2
398+
]
399+
400+
registrar.getMiddleware.withArgs(streamProtocol).returns(middleware)
401+
registrar.getHandler.withArgs(streamProtocol).returns({
402+
handler: () => {},
403+
options: {}
404+
})
405+
406+
const connection = createConnection(components, init)
407+
408+
await expect(connection.newStream(streamProtocol))
409+
.to.eventually.be.rejectedWith(err)
410+
411+
expect(middleware1.called).to.be.true()
412+
expect(middleware2.called).to.be.false()
413+
})
414+
415+
it('should not call inbound middleware if previous middleware errors', async () => {
416+
const streamProtocol = '/test/protocol'
417+
418+
const middleware1 = Sinon.stub().callsFake((stream, connection, next) => {
419+
throw new Error('boom')
420+
})
421+
const middleware2 = Sinon.stub().callsFake((stream, connection, next) => {
422+
next(stream, connection)
423+
})
424+
425+
const middleware = [
426+
middleware1,
427+
middleware2
428+
]
429+
430+
registrar.getMiddleware.withArgs(streamProtocol).returns(middleware)
431+
registrar.getHandler.withArgs(streamProtocol).returns({
432+
handler: () => {},
433+
options: {}
434+
})
435+
436+
const muxer = stubInterface<StreamMuxer>({
437+
streams: []
438+
})
439+
440+
createConnection(components, {
441+
...init,
442+
muxer
443+
})
444+
445+
expect(muxer.addEventListener.getCall(0).args[0]).to.equal('stream')
446+
const onIncomingStream = muxer.addEventListener.getCall(0).args[1]
447+
448+
if (onIncomingStream == null) {
449+
throw new Error('No incoming stream handler registered')
450+
}
451+
375452
const incomingStream = stubInterface<Stream>({
376-
id: 'stream-id',
377-
log: logger('test-stream'),
378-
direction: 'outbound',
379-
sink: async (source) => drain(source),
380-
source: map((async function * () {
381-
yield '/multistream/1.0.0\n'
382-
yield `${streamProtocol}\n`
383-
})(), str => encode.single(uint8ArrayFromString(str)))
453+
log: defaultLogger().forComponent('stream'),
454+
protocol: streamProtocol
384455
})
385-
*/
386-
// onIncomingStream?.(incomingStream)
456+
457+
if (typeof onIncomingStream !== 'function') {
458+
throw new Error('Stream handler was not function')
459+
}
460+
461+
onIncomingStream(new CustomEvent('stream', {
462+
detail: incomingStream
463+
}))
387464

388465
// incoming stream is opened asynchronously
389466
await delay(100)
390467

391468
expect(middleware1.called).to.be.true()
392-
expect(middleware2.called).to.be.true()
469+
expect(middleware2.called).to.be.false()
470+
expect(incomingStream).to.have.nested.property('abort.called', true)
393471
})
394472
})

0 commit comments

Comments
 (0)