diff --git a/src/gen_tcp_server.erl b/src/gen_tcp_server.erl index 64684df..c0e2ce8 100644 --- a/src/gen_tcp_server.erl +++ b/src/gen_tcp_server.erl @@ -47,6 +47,8 @@ addr => any }). -define(DEFAULT_SOCKET_OPTIONS, #{}). +%% Smaller chunks work better with lwIP's limited buffers +-define(MAX_SEND_CHUNK, 1460). %% TCP MSS - fits in single packet without fragmentation %% %% API @@ -130,16 +132,29 @@ handle_info({tcp, Socket, Packet}, State) -> case Handler:handle_receive(Socket, Packet, HandlerState) of {reply, ResponsePacket, ResponseState} -> ?TRACE("Sending reply to endpoint ~p", [socket:peername(Socket)]), - try_send(Socket, ResponsePacket), - {noreply, State#state{handler_state=ResponseState}}; + case try_send(Socket, ResponsePacket) of + ok -> + {noreply, State#state{handler_state=ResponseState}}; + {error, closed} -> + ?TRACE("Connection closed during send, cleaning up", []), + {noreply, State#state{handler_state=ResponseState}}; + {error, _Reason} -> + try_close(Socket), + {noreply, State#state{handler_state=ResponseState}} + end; {noreply, ResponseState} -> ?TRACE("no reply", []), {noreply, State#state{handler_state=ResponseState}}; {close, ResponsePacket} -> ?TRACE("Sending reply to endpoint ~p and closing socket: ~p", [socket:peername(Socket), Socket]), - try_send(Socket, ResponsePacket), - % timer:sleep(500), - try_close(Socket), + case try_send(Socket, ResponsePacket) of + ok -> + try_close(Socket); + {error, closed} -> + ok; %% Already closed, nothing to do + {error, _Reason} -> + try_close(Socket) + end, {noreply, State}; close -> ?TRACE("Closing socket ~p", [Socket]), @@ -168,20 +183,11 @@ try_send(Socket, Packet) when is_binary(Packet) -> "Trying to send binary packet data to socket ~p. Packet (or len): ~p", [ Socket, case byte_size(Packet) < 32 of true -> Packet; _ -> byte_size(Packet) end ]), - case socket:send(Socket, Packet) of - ok -> - ?TRACE("sent.", []), - ok; - {ok, Rest} -> - ?TRACE("sent. remaining: ~p", [Rest]), - try_send(Socket, Rest); - Error -> - io:format("Send failed due to error ~p~n", [Error]) - end; -try_send(Socket, Char) when is_integer(Char) -> - %% TODO handle unicode - ?TRACE("Sending char ~p as ~p", [Char, <>]), - try_send(Socket, <>); + try_send_binary(Socket, Packet); +try_send(Socket, Byte) when is_integer(Byte) -> + %% Handles bytes (0-255) in iolists. Unicode must be pre-encoded to UTF-8. + ?TRACE("Sending byte ~p as ~p", [Byte, <>]), + try_send(Socket, <>); try_send(Socket, List) when is_list(List) -> case is_string(List) of true -> @@ -193,8 +199,46 @@ try_send(Socket, List) when is_list(List) -> try_send_iolist(_Socket, []) -> ok; try_send_iolist(Socket, [H | T]) -> - try_send(Socket, H), - try_send_iolist(Socket, T). + case try_send(Socket, H) of + ok -> + try_send_iolist(Socket, T); + {error, _Reason} = Error -> + Error + end. + +try_send_binary(_Socket, <<>>) -> + ok; +try_send_binary(Socket, Packet) when is_binary(Packet) -> + TotalSize = byte_size(Packet), + ChunkSize = erlang:min(TotalSize, ?MAX_SEND_CHUNK), + <> = Packet, + case socket:send(Socket, Chunk) of + ok -> + %% Give the scheduler a chance to run and let TCP drain + maybe_yield(Rest), + try_send_binary(Socket, Rest); + {ok, Remaining} -> + %% Partial send - combine remaining with rest and retry + try_send_binary(Socket, <>); + {error, closed} -> + %% Only log if we actually had more data to send + case byte_size(Rest) of + 0 -> ok; %% Sent everything, client just closed after - that's fine + _ -> io:format("Connection closed mid-transfer (~p/~p bytes sent)~n", + [ChunkSize, TotalSize]) + end, + {error, closed}; + {error, Reason} -> + io:format("Send error: ~p (chunk: ~p, total: ~p)~n", + [Reason, ChunkSize, TotalSize]), + {error, Reason} + end. + +%% Lightweight yield using receive timeout - works in AtomVM +maybe_yield(<<>>) -> + ok; +maybe_yield(_) -> + receive after 0 -> ok end. is_string([]) -> true; @@ -216,7 +260,6 @@ try_close(Socket) -> set_socket_options(Socket, SocketOptions) -> maps:fold( fun(Option, Value, Accum) -> - erlang:display({setopt, Socket, Option, Value}), ok = socket:setopt(Socket, Option, Value), Accum end, @@ -232,8 +275,10 @@ accept(ControllingProcess, ListenSocket) -> ?TRACE("Accepted connection from ~p", [socket:peername(Connection)]), spawn(fun() -> accept(ControllingProcess, ListenSocket) end), loop(ControllingProcess, Connection); - _Error -> - ?TRACE("Error accepting connection: ~p", [Error]) + Error -> + ?TRACE("Error accepting connection: ~p", [Error]), + timer:sleep(100), + accept(ControllingProcess, ListenSocket) end. diff --git a/src/httpd.erl b/src/httpd.erl index 1b1aa0f..4c93f2c 100644 --- a/src/httpd.erl +++ b/src/httpd.erl @@ -17,7 +17,7 @@ -module(httpd). --export([start/2, start/3, start_link/2, start_link/3, stop/1]). +-export([start/2, start/3, start/4, start_link/2, start_link/3, start_link/4, stop/1]). -behaviour(gen_tcp_server). -export([init/1, handle_receive/3, handle_tcp_closed/2]). @@ -66,7 +66,8 @@ -record(state, { config, pending_request_map = #{}, - ws_socket_map = #{} + ws_socket_map = #{}, + pending_buffer_map = #{} }). %% @@ -75,19 +76,27 @@ -spec start(Port :: portnum(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. start(Port, Config) -> - start(any, Port, Config). + start(any, Port, #{}, Config). -spec start(Address :: address(), Port :: portnum(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. start(Address, Port, Config) -> - gen_tcp_server:start(#{addr => Address, port => Port}, ?MODULE, Config). + start(Address, Port, #{}, Config). + +-spec start(Address :: address(), Port :: portnum(), SocketOptions :: map(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. +start(Address, Port, SocketOptions, Config) -> + gen_tcp_server:start(#{addr => Address, port => Port}, SocketOptions, ?MODULE, Config). -spec start_link(Port :: portnum(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. start_link(Port, Config) -> - start_link(any, Port, Config). + start_link(any, Port, #{}, Config). -spec start_link(Address :: address(), Port :: portnum(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. start_link(Address, Port, Config) -> - gen_tcp_server:start_link(#{addr => Address, port => Port}, ?MODULE, Config). + start_link(Address, Port, #{}, Config). + +-spec start_link(Address :: address(), Port :: portnum(), SocketOptions :: map(), Config :: config()) -> {ok, HTTPD :: pid()} | {error, Reason :: term()}. +start_link(Address, Port, SocketOptions, Config) -> + gen_tcp_server:start_link(#{addr => Address, port => Port}, SocketOptions, ?MODULE, Config). stop(Httpd) -> gen_tcp_server:stop(Httpd). @@ -122,59 +131,76 @@ handle_receive(Socket, Packet, State) -> %% @private handle_http_request(Socket, Packet, State) -> - case maps:get(Socket, State#state.pending_request_map, undefined) of + PendingRequestMap = State#state.pending_request_map, + BufferMap = State#state.pending_buffer_map, + PendingBuffer = maps:get(Socket, BufferMap, <<>>), + AccumulatedPacket = <>, + case maps:get(Socket, PendingRequestMap, undefined) of undefined -> - HttpRequest = parse_http_request(binary_to_list(Packet)), - % ?TRACE("HttpRequest: ~p~n", [HttpRequest]), - #{ - method := Method, - headers := Headers - } = HttpRequest, - case get_protocol(Method, Headers) of - http -> - case init_handler(HttpRequest, State) of - {ok, {Handler, HandlerState, PathSuffix, HandlerConfig}} -> - NewHttpRequest = HttpRequest#{ - handler => Handler, - handler_state => HandlerState, - path_suffix => PathSuffix, - handler_config => HandlerConfig, - socket => Socket - }, - handle_request_state(Socket, NewHttpRequest, State); - Error -> - {close, create_error(?INTERNAL_SERVER_ERROR, Error)} - end; - ws -> - ?TRACE("Protocol is ws", []), - Config = State#state.config, - Path = maps:get(path, HttpRequest), - case get_handler(Path, Config) of - {ok, PathSuffix, EntryConfig} -> - WsHandler = maps:get(handler, EntryConfig), - ?TRACE("Got handler ~p", [WsHandler]), - HandlerConfig = maps:get(handler_config, EntryConfig, #{}), - case WsHandler:start(Socket, PathSuffix, HandlerConfig) of - {ok, WebSocket} -> - ?TRACE("Started web socket handler: ~p", [WebSocket]), - NewWebSocketMap = maps:put(Socket, WebSocket, State#state.ws_socket_map), - NewState = State#state{ws_socket_map = NewWebSocketMap}, - ReplyToken = get_reply_token(maps:get(headers, HttpRequest)), - ReplyHeaders = #{"Upgrade" => "websocket", "Connection" => "Upgrade", "Sec-WebSocket-Accept" => ReplyToken}, - Reply = create_reply(?SWITCHING_PROTOCOLS, ReplyHeaders, <<"">>), - ?TRACE("Sending web socket upgrade reply: ~p", [Reply]), - {reply, Reply, NewState}; + case maybe_parse_http_request(AccumulatedPacket) of + {more, IncompletePacket} -> + NewBufferMap = BufferMap#{Socket => IncompletePacket}, + {noreply, State#state{pending_buffer_map = NewBufferMap}}; + {ok, HttpRequest} -> + CleanBufferMap = maps:remove(Socket, BufferMap), + CleanState = State#state{pending_buffer_map = CleanBufferMap}, + % ?TRACE("HttpRequest: ~p~n", [HttpRequest]), + #{ + method := Method, + headers := Headers + } = HttpRequest, + case get_protocol(Method, Headers) of + http -> + case init_handler(HttpRequest, CleanState) of + {ok, {Handler, HandlerState, PathSuffix, HandlerConfig}} -> + NewHttpRequest = HttpRequest#{ + handler => Handler, + handler_state => HandlerState, + path_suffix => PathSuffix, + handler_config => HandlerConfig, + socket => Socket + }, + handle_request_state(Socket, NewHttpRequest, CleanState); Error -> - ?TRACE("Web socket error: ~p", [Error]), - {close, create_error(?INTERNAL_SERVER_ERROR, {web_socket_error, Error})} + {close, create_error(?INTERNAL_SERVER_ERROR, Error)} end; - Error -> - Error - end + ws -> + ?TRACE("Protocol is ws", []), + Config = CleanState#state.config, + Path = maps:get(path, HttpRequest), + case get_handler(Path, Config) of + {ok, PathSuffix, EntryConfig} -> + WsHandler = maps:get(handler, EntryConfig), + ?TRACE("Got handler ~p", [WsHandler]), + HandlerConfig = maps:get(handler_config, EntryConfig, #{}), + case WsHandler:start(Socket, PathSuffix, HandlerConfig) of + {ok, WebSocket} -> + ?TRACE("Started web socket handler: ~p", [WebSocket]), + NewWebSocketMap = maps:put(Socket, WebSocket, CleanState#state.ws_socket_map), + NewState = CleanState#state{ws_socket_map = NewWebSocketMap}, + ReplyToken = get_reply_token(maps:get(headers, HttpRequest)), + ReplyHeaders = #{"Upgrade" => "websocket", "Connection" => "Upgrade", "Sec-WebSocket-Accept" => ReplyToken}, + Reply = create_reply(?SWITCHING_PROTOCOLS, ReplyHeaders, <<"">>), + ?TRACE("Sending web socket upgrade reply: ~p", [Reply]), + {reply, Reply, NewState}; + Error -> + ?TRACE("Web socket error: ~p", [Error]), + {close, create_error(?INTERNAL_SERVER_ERROR, {web_socket_error, Error})} + end; + Error -> + {close, create_error(?INTERNAL_SERVER_ERROR, {web_socket_error, Error})} + end + end; + {error, Reason} -> + {close, create_error(?BAD_REQUEST, Reason)} end; PendingHttpRequest -> ?TRACE("Packetlen: ~p", [erlang:byte_size(Packet)]), - handle_request_state(Socket, PendingHttpRequest#{body := Packet}, State) + ExistingBody = maps:get(body, PendingHttpRequest, <<>>), + NewBody = <>, + CleanBufferMap = maps:remove(Socket, BufferMap), + CleanState = State#state{pending_buffer_map = CleanBufferMap}, + handle_request_state(Socket, PendingHttpRequest#{body := NewBody}, CleanState) end. %% @private @@ -213,7 +239,7 @@ handle_request_state(Socket, HttpRequest, State) -> {reply, Reply, State#state{pending_request_map = NewPendingRequestMap}}; wait_for_body -> NewPendingRequestMap = PendingRequestMap#{Socket => HttpRequest}, - call_http_req_handler(Socket, HttpRequest, State#state{pending_request_map = NewPendingRequestMap}) + {noreply, State#state{pending_request_map = NewPendingRequestMap}} end. %% @private @@ -284,19 +310,25 @@ call_http_req_handler(Socket, HttpRequest, State) -> update_state(Socket, HttpRequest, HandlerState, State) -> NewHttpRequest = HttpRequest#{handler_state := HandlerState}, PendingRequestMap = State#state.pending_request_map, - NewPendingRequestMap = PendingRequestMap#{Socket := NewHttpRequest}, + NewPendingRequestMap = PendingRequestMap#{Socket => NewHttpRequest}, State#state{pending_request_map = NewPendingRequestMap}. %% @hidden handle_tcp_closed(Socket, State) -> - case maps:get(Socket, State#state.ws_socket_map, undefined) of + NewPendingRequestMap = maps:remove(Socket, State#state.pending_request_map), + NewPendingBufferMap = maps:remove(Socket, State#state.pending_buffer_map), + CleanState = State#state{ + pending_request_map = NewPendingRequestMap, + pending_buffer_map = NewPendingBufferMap + }, + case maps:get(Socket, CleanState#state.ws_socket_map, undefined) of undefined -> - State; + CleanState; WebSocket -> ok = httpd_ws_handler:stop(WebSocket), - NewWebSocketMap = maps:remove(Socket, State#state.ws_socket_map), - State#state{ws_socket_map = NewWebSocketMap} + NewWebSocketMap = maps:remove(Socket, CleanState#state.ws_socket_map), + CleanState#state{ws_socket_map = NewWebSocketMap} end. %% @@ -324,6 +356,29 @@ parse_http_request(Packet) -> } ). +maybe_parse_http_request(Packet) when is_binary(Packet) -> + case find_header_delimiter(Packet) of + nomatch -> + {more, Packet}; + {_Pos, _Len} -> + try + {ok, parse_http_request(binary_to_list(Packet))} + catch + throw:Reason -> + {error, Reason}; + error:Reason -> + {error, Reason} + end + end. + +find_header_delimiter(Packet) -> + case binary:match(Packet, <<"\r\n\r\n">>) of + nomatch -> + binary:match(Packet, <<"\n\n">>); + Match -> + Match + end. + %% @private parse_heading([$\s|Rest], start, Tmp, Accum) -> parse_heading(Rest, start, Tmp, Accum); @@ -521,15 +576,33 @@ create_error(StatusCode, Error) -> create_reply(StatusCode, ContentType, Reply) when is_list(ContentType) orelse is_binary(ContentType) -> create_reply(StatusCode, #{"Content-Type" => ContentType}, Reply); create_reply(StatusCode, Headers, Reply) when is_map(Headers) -> + ReplyLen = erlang:iolist_size(Reply), + HeadersWithLen = ensure_content_length(Headers, ReplyLen), [ <<"HTTP/1.1 ">>, erlang:integer_to_binary(StatusCode), <<" ">>, moniker(StatusCode), <<"\r\n">>, io_lib:format("Server: atomvm-~s\r\n", [get_version_str(erlang:system_info(atomvm_version))]), - to_headers_list(Headers), + to_headers_list(HeadersWithLen), <<"\r\n">>, Reply ]. +%% @private +ensure_content_length(Headers, ReplyLen) -> + LenBin = erlang:integer_to_binary(ReplyLen), + CleanHeaders = remove_content_length_header(Headers), + CleanHeaders#{<<"Content-Length">> => LenBin}. + +%% @private +remove_content_length_header(Headers) -> + KeysToRemove = [ + "Content-Length", + <<"Content-Length">>, + "content-length", + <<"content-length">> + ], + lists:foldl(fun(Key, Acc) -> maps:remove(Key, Acc) end, Headers, KeysToRemove). + %% @private maybe_binary_to_string(Bin) when is_binary(Bin) -> erlang:binary_to_list(Bin); diff --git a/src/httpd_ws_handler.erl b/src/httpd_ws_handler.erl index a33248f..68a7475 100644 --- a/src/httpd_ws_handler.erl +++ b/src/httpd_ws_handler.erl @@ -80,7 +80,8 @@ send(WebSocket, Packet) -> -record(state, { socket, handler_module, - handler_state + handler_state, + frame_buffer = <<>> %% Buffer for incomplete WebSocket frames }). %% @hidden @@ -103,28 +104,36 @@ handle_cast({message, Packet}, State) -> #state{ socket = Socket, handler_module = HandlerModule, - handler_state = HandlerState + handler_state = HandlerState, + frame_buffer = Buffer } = State, ?TRACE("WebSocket received packet ~p", [Packet]), - case parse_frame(Packet) of - {ok, PayloadData} -> - ?TRACE("HandlerModule ~p; PayloadData ~p", [HandlerModule, PayloadData]), + + %% Accumulate packet data into buffer + NewBuffer = <>, + + case parse_frame(NewBuffer) of + {ok, PayloadData, Remaining} -> + ?TRACE("HandlerModule ~p; PayloadData ~p; Remaining ~p bytes", [HandlerModule, PayloadData, byte_size(Remaining)]), case HandlerModule:handle_ws_message(PayloadData, HandlerState) of {reply, Reply, NewHandlerState} -> ?TRACE("Handled WS payload. NewHandlerState: ~p", [NewHandlerState]), do_send(Socket, Reply, text), - {noreply, State#state{handler_state = NewHandlerState}}; + {noreply, State#state{handler_state = NewHandlerState, frame_buffer = Remaining}}; {noreply, NewHandlerState} -> ?TRACE("Handled WS payload. NewHandlerState: ~p", [NewHandlerState]), - {noreply, State#state{handler_state = NewHandlerState}}; + {noreply, State#state{handler_state = NewHandlerState, frame_buffer = Remaining}}; HandleModleError -> ?TRACE("HandleModleError: ~p", [HandleModleError]), socket:close(Socket), {stop, HandleModleError, State} end; + incomplete -> + ?TRACE("Incomplete frame, buffering ~p bytes", [byte_size(NewBuffer)]), + {noreply, State#state{frame_buffer = NewBuffer}}; empty_payload -> ?TRACE("Empty payload.", []), - {noreply, State}; + {noreply, State#state{frame_buffer = <<>>}}; ParseFrameError -> ?TRACE("ParseFrameError: ~p", [ParseFrameError]), socket:close(Socket), @@ -153,49 +162,50 @@ terminate(_Reason, _State) -> %% @private parse_frame(<<0,0,0,0,0,0,0,0,0,0>>) -> empty_payload; +parse_frame(Packet) when byte_size(Packet) < 2 -> + incomplete; parse_frame(Packet) -> try <<_FinOpcode:8, MaskLen:8, Rest/binary>> = Packet, - % Fin = (FinOpcode band 16#80) bsr 7, - % Opcode = FinOpcode band 16#0F, Mask = (MaskLen band 16#80) bsr 7, PayloadLen = MaskLen band 16#7F, - % <> = Packet, - % ?TRACE("FinOpcode: ~p, Fin: ~p, Opcode: ~p, MaskLen: ~p, Mask: ~p, PayloadLen: ~p, Rest: ~p", [FinOpcode, Fin, Opcode, MaskLen, Mask, PayloadLen, Rest]), - ?TRACE("Fin: ~p, Opcode: ~p, Mask: ~p, PayloadLen: ~p, Rest: ~p", [Fin, Opcode, Mask, PayloadLen, Rest]), - case PayloadLen of - 0 -> - {ok, <<"">>}; + + %% Calculate how many bytes we need for the complete frame + {ActualPayloadLen, HeaderSize} = case PayloadLen of 126 -> - case Mask of - 1 -> - <> = Rest, - <> = Rest2, - ?TRACE("MaskingKey: ~p, MaskedPayload: ~p", [MaskingKey, MaskedPayload]), - {ok, unmask(MaskingKey, MaskedPayload)}; - _ -> - <> = Rest, - {ok, <>} + case byte_size(Rest) >= 2 of + true -> + <> = Rest, + {Len, 2}; + false -> + {need_more, 2} end; 127 -> - case Mask of - 1 -> - <> = Rest, - <> = Rest2, - ?TRACE("MaskingKey: ~p, MaskedPayload: ~p", [MaskingKey, MaskedPayload]), - {ok, unmask(MaskingKey, MaskedPayload)}; - _ -> - <> = Rest, - {ok, <>} + case byte_size(Rest) >= 8 of + true -> + <> = Rest, + {Len, 8}; + false -> + {need_more, 8} end; + Len -> + {Len, 0} + end, + + case ActualPayloadLen of + need_more -> + incomplete; _ -> - case Mask of - 1 -> - <> = Rest, - ?TRACE("MaskingKey: ~p, MaskedPayload: ~p", [MaskingKey, MaskedPayload]), - {ok, unmask(MaskingKey, MaskedPayload)}; - _ -> - {ok, Rest} + MaskSize = case Mask of 1 -> 4; _ -> 0 end, + RequiredBytes = HeaderSize + MaskSize + ActualPayloadLen, + + case byte_size(Rest) >= RequiredBytes of + true -> + %% We have enough data to parse the complete frame + parse_complete_frame(PayloadLen, Mask, Rest); + false -> + ?TRACE("Incomplete frame: have ~p bytes, need ~p", [byte_size(Packet), 2 + RequiredBytes]), + incomplete end end catch @@ -204,6 +214,33 @@ parse_frame(Packet) -> {error, Error} end. +%% @private +parse_complete_frame(PayloadLen, Mask, Rest) -> + case PayloadLen of + 0 -> + {ok, <<"">>, Rest}; + 126 -> + <> = Rest, + extract_payload(Mask, MediumPayloadLen, Rest2); + 127 -> + <> = Rest, + extract_payload(Mask, LargePayloadLen, Rest2); + _ -> + extract_payload(Mask, PayloadLen, Rest) + end. + +%% @private +extract_payload(Mask, PayloadLen, Data) -> + case Mask of + 1 -> + <> = Data, + ?TRACE("MaskingKey: ~p, MaskedPayload length: ~p", [MaskingKey, byte_size(MaskedPayload)]), + {ok, unmask(MaskingKey, MaskedPayload), Remaining}; + _ -> + <> = Data, + {ok, Payload, Remaining} + end. + %% @private unmask(MaskingKey, MaskedPayload) -> unmask(MaskingKey, MaskedPayload, 0, []).