Skip to content

Commit cb56e85

Browse files
committed
Fetch session from conn
1 parent 70f015b commit cb56e85

File tree

4 files changed

+73
-82
lines changed

4 files changed

+73
-82
lines changed

lib/phoenix/endpoint.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ defmodule Phoenix.Endpoint do
643643
end
644644

645645
@doc false
646-
defmacro __before_compile__(%{module: module}) do
646+
defmacro __before_compile__(_env) do
647647
quote do
648648
defoverridable call: 2
649649

lib/phoenix/endpoint/supervisor.ex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ defmodule Phoenix.Endpoint.Supervisor do
8484
config_children(mod, secret_conf, default_conf) ++
8585
warmup_children(mod) ++
8686
pubsub_children(mod, conf) ++
87-
socket_children(mod, conf, :child_spec) ++
87+
socket_children(mod, :child_spec) ++
8888
server_children(mod, conf, server?) ++
89-
socket_children(mod, conf, :drainer_spec) ++
89+
socket_children(mod, :drainer_spec) ++
9090
watcher_children(mod, conf, server?)
9191

9292
Supervisor.init(children, strategy: :one_for_one)
@@ -118,7 +118,7 @@ defmodule Phoenix.Endpoint.Supervisor do
118118
end
119119
end
120120

121-
defp socket_children(endpoint, conf, fun) do
121+
defp socket_children(endpoint, fun) do
122122
for {socket, opts} <- endpoint.__start_sockets__(),
123123
# TODO is this the correct place for this?
124124
# Needs to know transport specific config

lib/phoenix/socket/transport.ex

Lines changed: 11 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -290,39 +290,23 @@ defmodule Phoenix.Socket.Transport do
290290
:user_agent,
291291
:x_headers,
292292
:sec_websocket_headers,
293-
:auth_token
293+
:auth_token,
294+
:session
294295
] ->
295296
key
296297

297-
{:session, session} ->
298-
{:session, init_session(session)}
299-
300298
{_, _} = pair ->
301299
pair
302300

303301
other ->
304302
raise ArgumentError,
305-
":connect_info keys are expected to be one of :peer_data, :trace_context_headers, :x_headers, :user_agent, :sec_websocket_headers, :uri, or {:session, config}, " <>
303+
":connect_info keys are expected to be one of :peer_data, :trace_context_headers, :x_headers, :user_agent, :sec_websocket_headers, :uri, or :session, " <>
306304
"optionally followed by custom keyword pairs, got: #{inspect(other)}"
307305
end)
308306

309307
[connect_info: connect_info] ++ config
310308
end
311309

312-
# The original session_config is returned in addition to init value so we can
313-
# access special config like :csrf_token_key downstream.
314-
defp init_session(session_config) when is_list(session_config) do
315-
key = Keyword.fetch!(session_config, :key)
316-
store = Plug.Session.Store.get(Keyword.fetch!(session_config, :store))
317-
init = store.init(Keyword.drop(session_config, [:store, :key]))
318-
csrf_token_key = Keyword.get(session_config, :csrf_token_key, "_csrf_token")
319-
{key, store, {csrf_token_key, init}}
320-
end
321-
322-
defp init_session({_, _, _} = mfa) do
323-
{:mfa, mfa}
324-
end
325-
326310
@doc """
327311
Runs the code reloader if enabled.
328312
"""
@@ -496,7 +480,7 @@ defmodule Phoenix.Socket.Transport do
496480
497481
The CSRF check can be disabled by setting the `:check_csrf` option to `false`.
498482
"""
499-
def connect_info(conn, endpoint, keys, opts \\ []) do
483+
def connect_info(conn, _endpoint, keys, opts \\ []) do
500484
for key <- keys, into: %{} do
501485
case key do
502486
:peer_data ->
@@ -517,8 +501,8 @@ defmodule Phoenix.Socket.Transport do
517501
:sec_websocket_headers ->
518502
{:sec_websocket_headers, fetch_headers(conn, "sec-websocket-")}
519503

520-
{:session, session} ->
521-
{:session, connect_session(conn, endpoint, session, opts)}
504+
:session ->
505+
{:session, connect_session(conn, opts)}
522506

523507
:auth_token ->
524508
{:auth_token, conn.private[:phoenix_transport_auth_token]}
@@ -529,28 +513,15 @@ defmodule Phoenix.Socket.Transport do
529513
end
530514
end
531515

532-
defp connect_session(conn, endpoint, {key, store, {csrf_token_key, init}}, opts) do
533-
conn = Plug.Conn.fetch_cookies(conn)
516+
defp connect_session(conn, opts) do
517+
session = Plug.Conn.get_session(conn)
534518
check_csrf = Keyword.get(opts, :check_csrf, true)
519+
csrf_token_key = Keyword.get(opts, :csrf_token_key, "_csrf_token")
535520

536-
with cookie when is_binary(cookie) <- conn.cookies[key],
537-
conn = put_in(conn.secret_key_base, endpoint.config(:secret_key_base)),
538-
{_, session} <- store.get(conn, cookie, init),
539-
true <- not check_csrf or csrf_token_valid?(conn, session, csrf_token_key) do
521+
if not check_csrf or csrf_token_valid?(conn, session, csrf_token_key) do
540522
session
541523
else
542-
_ -> nil
543-
end
544-
end
545-
546-
defp connect_session(conn, endpoint, {:mfa, {module, function, args}}, opts) do
547-
case apply(module, function, args) do
548-
session_config when is_list(session_config) ->
549-
connect_session(conn, endpoint, init_session(session_config), opts)
550-
551-
other ->
552-
raise ArgumentError,
553-
"the MFA given to `session_config` must return a keyword list, got: #{inspect(other)}"
524+
nil
554525
end
555526
end
556527

test/phoenix/integration/websocket_channels_test.exs

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -162,52 +162,72 @@ defmodule Phoenix.Integration.WebSocketChannelsTest do
162162
def handle_error(conn, :rate_limit), do: Plug.Conn.send_resp(conn, 429, "Too many requests")
163163
end
164164

165-
defmodule Endpoint do
166-
use Phoenix.Endpoint, otp_app: :phoenix
165+
defmodule SetSession do
166+
import Plug.Conn
167167

168-
@session_config store: :cookie,
169-
key: "_hello_key",
170-
signing_salt: "change_me"
168+
def init(opts), do: opts
171169

172-
socket "/ws", UserSocket,
173-
websocket: [
174-
check_origin: ["//example.com"],
175-
timeout: 200,
176-
error_handler: {UserSocket, :handle_error, []}
177-
]
170+
def call(conn, _) do
171+
conn
172+
|> put_session(:from_session, "123")
173+
|> send_resp(200, Plug.CSRFProtection.get_csrf_token())
174+
end
175+
end
178176

179-
socket "/ws/admin", UserSocket,
180-
websocket: [
181-
check_origin: ["//example.com"],
182-
timeout: 200
183-
]
177+
defmodule Router do
178+
use Phoenix.Router
179+
import Phoenix.Socket.Router
180+
181+
get "/", SetSession, []
182+
183+
scope "/ws" do
184+
socket "/", UserSocket,
185+
websocket: [
186+
check_origin: ["//example.com"],
187+
timeout: 200,
188+
error_handler: {UserSocket, :handle_error, []}
189+
]
190+
191+
socket "/admin", UserSocket,
192+
websocket: [
193+
check_origin: ["//example.com"],
194+
timeout: 200
195+
]
184196

185-
socket "/ws/connect_info", UserSocketConnectInfo,
186-
websocket: [
187-
check_origin: ["//example.com"],
188-
timeout: 200,
189-
connect_info: [
190-
:trace_context_headers,
191-
:x_headers,
192-
:peer_data,
193-
:uri,
194-
:user_agent,
195-
:sec_websocket_headers,
196-
session: @session_config,
197-
signing_salt: "salt",
197+
socket "/connect_info", UserSocketConnectInfo,
198+
websocket: [
199+
check_origin: ["//example.com"],
200+
timeout: 200,
201+
connect_info: [
202+
:trace_context_headers,
203+
:x_headers,
204+
:peer_data,
205+
:uri,
206+
:user_agent,
207+
:sec_websocket_headers,
208+
:session,
209+
signing_salt: "salt"
210+
]
198211
]
212+
end
213+
end
214+
215+
defmodule Endpoint do
216+
use Phoenix.Endpoint,
217+
otp_app: :phoenix,
218+
start_sockets: [
219+
{UserSocket, []},
220+
{UserSocketConnectInfo, []}
199221
]
200222

201-
plug Plug.Session, @session_config
223+
plug Plug.Session,
224+
store: :cookie,
225+
key: "_hello_key",
226+
signing_salt: "change_me"
227+
202228
plug :fetch_session
203229
plug Plug.CSRFProtection
204-
plug :put_session
205-
206-
defp put_session(conn, _) do
207-
conn
208-
|> put_session(:from_session, "123")
209-
|> send_resp(200, Plug.CSRFProtection.get_csrf_token())
210-
end
230+
plug Router
211231
end
212232

213233
setup %{adapter: adapter} do
@@ -540,7 +560,7 @@ defmodule Phoenix.Integration.WebSocketChannelsTest do
540560
{:ok, _} = WebsocketClient.connect(self(), @vsn_path, @serializer)
541561
end)
542562

543-
assert log == ""
563+
refute log =~ "UserSocket"
544564
end
545565

546566
test "logs and filter params on join and handle_in" do

0 commit comments

Comments
 (0)