Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/core/router/router-execution-context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ export class RouterExecutionContext {
}
const isSseHandler = !!this.reflectSse(callback);
if (isSseHandler) {
return <
return async <
TResult extends Observable<unknown> = any,
TResponse extends HeaderStream = any,
TRequest extends IncomingMessage = any,
Expand All @@ -443,7 +443,7 @@ export class RouterExecutionContext {
res: TResponse,
req: TRequest,
) => {
this.responseController.sse(
await this.responseController.sse(
result,
(res as any).raw || res,
(req as any).raw || req,
Expand Down
20 changes: 15 additions & 5 deletions packages/core/router/router-response-controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ export class RouterResponseController {
this.applicationRef.status(response, statusCode);
}

public sse<
public async sse<
TInput extends Observable<unknown> = any,
TResponse extends WritableHeaderStream = any,
TRequest extends IncomingMessage = any,
>(
result: TInput,
result: TInput | Promise<TInput>,
response: TResponse,
request: TRequest,
options?: { additionalHeaders: AdditionalHeaders },
Expand All @@ -114,12 +114,22 @@ export class RouterResponseController {
return;
}

this.assertObservable(result);
const observableResult = await Promise.resolve(result);

this.assertObservable(observableResult);

const stream = new SseStream(request);
stream.pipe(response, options);

const subscription = result
// Extract custom status code from response if it was set
const customStatusCode = (response as any).statusCode;
const pipeOptions =
customStatusCode && customStatusCode !== 200
? { ...options, statusCode: customStatusCode }
: options;

stream.pipe(response, pipeOptions);

const subscription = observableResult
.pipe(
map((message): MessageEvent => {
if (isObject(message)) {
Expand Down
4 changes: 3 additions & 1 deletion packages/core/router/sse-stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ export class SseStream extends Transform {
destination: T,
options?: {
additionalHeaders?: AdditionalHeaders;
statusCode?: number;
end?: boolean;
},
): T {
if (destination.writeHead) {
destination.writeHead(200, {
const statusCode = options?.statusCode ?? 200;
destination.writeHead(statusCode, {
...options?.additionalHeaders,
// See https://github.com/dunglas/mercure/blob/master/hub/subscribe.go#L124-L130
'Content-Type': 'text/event-stream',
Expand Down
72 changes: 71 additions & 1 deletion packages/core/test/router/router-response-controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ describe('RouterResponseController', () => {
it('should accept only observables', async () => {
const result = Promise.resolve('test');
try {
routerResponseController.sse(
await routerResponseController.sse(
result as unknown as any,
{} as unknown as ServerResponse,
{} as unknown as IncomingMessage,
Expand All @@ -275,6 +275,76 @@ describe('RouterResponseController', () => {
}
});

it('should accept Promise<Observable>', async () => {
class Sink extends Writable {
private readonly chunks: string[] = [];

_write(
chunk: any,
encoding: string,
callback: (error?: Error | null) => void,
): void {
this.chunks.push(chunk);
callback();
}

get content() {
return this.chunks.join('');
}
}

const written = (stream: Writable) =>
new Promise((resolve, reject) =>
stream.on('error', reject).on('finish', resolve),
);

const result = Promise.resolve(of('test'));
const response = new Sink();
const request = new PassThrough();
await routerResponseController.sse(
result,
response as unknown as ServerResponse,
request as unknown as IncomingMessage,
);
request.destroy();
await written(response);
expect(response.content).to.eql(
`
id: 1
data: test

`,
);
});

it('should use custom status code from response', async () => {
class SinkWithStatusCode extends Writable {
statusCode = 404;
writeHead = sinon.spy();
flushHeaders = sinon.spy();

_write(
chunk: any,
encoding: string,
callback: (error?: Error | null) => void,
): void {
callback();
}
}

const result = of('test');
const response = new SinkWithStatusCode();
const request = new PassThrough();
await routerResponseController.sse(
result,
response as unknown as ServerResponse,
request as unknown as IncomingMessage,
);

expect(response.writeHead.firstCall.args[0]).to.equal(404);
request.destroy();
});

it('should write string', async () => {
class Sink extends Writable {
private readonly chunks: string[] = [];
Expand Down
28 changes: 28 additions & 0 deletions packages/core/test/router/sse-stream.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,34 @@ data: hello
});
});

it('sets custom status code when provided', callback => {
const sse = new SseStream();
const sink = new Sink(
(status: number, headers: string | OutgoingHttpHeaders) => {
expect(status).to.equal(404);
callback();
return sink;
},
);

sse.pipe(sink, {
statusCode: 404,
});
});

it('defaults to 200 status code when not provided', callback => {
const sse = new SseStream();
const sink = new Sink(
(status: number, headers: string | OutgoingHttpHeaders) => {
expect(status).to.equal(200);
callback();
return sink;
},
);

sse.pipe(sink);
});

it('allows an eventsource to connect', callback => {
let sse: SseStream;
const server = createServer((req, res) => {
Expand Down
Loading