diff --git a/lectures/Discriminative Classification.jl b/lectures/Discriminative Classification.jl index 005119f..28de3fd 100644 --- a/lectures/Discriminative Classification.jl +++ b/lectures/Discriminative Classification.jl @@ -34,7 +34,7 @@ using Random, Plots, LaTeXStrings using MarkdownLiteral: @mdx # ╔═╡ 616e84d7-063d-4d9d-99e4-56aecf3c7ee4 -using Distributions, ExponentialFamily, LinearAlgebra, LogExpFunctions, StatsFuns, BayesBase, Optim +using Distributions, ExponentialFamily, LinearAlgebra, LogExpFunctions, StatsFuns, BayesBase, Optim, StableRNGs # ╔═╡ 25eefb10-d294-11ef-0734-2daf18636e8e title("Discriminative Classification") @@ -74,7 +74,7 @@ Our task will be the same as in the preceding class on (generative) classificati # ╔═╡ 4ceede48-a4d5-446b-bb34-26cec4af357a begin - N_bond = @bindname N Slider(9:200; default=120, show_value=true) + N_bond = @bindname N Slider(8:200; default=120, show_value=true) end # ╔═╡ 7e7cab21-09ab-4d06-9716-ab7864b229ab @@ -452,7 +452,7 @@ Note that we get a full predictive posterior distribution over the assignment of # ╔═╡ 98ef7093-f8ed-4a44-a153-0a64ab483f65 md""" -## Implementation Issues +## Implementation """ # ╔═╡ 0045e569-dc3c-4998-86da-9d96f599c599 @@ -680,7 +680,7 @@ md""" # ╔═╡ fcec3c3a-8b0b-4dfd-b010-66abbf330069 function generate_dataset(N::Int64) - Random.seed!(1234) + rng = StableRNG(984289) # Generate dataset {(x1,y1),...,(xN,yN)} # x is a 2d feature vector [x1;x2] # y ∈ {false,true} is a binary class label @@ -688,16 +688,16 @@ function generate_dataset(N::Int64) # srand(123) X = Matrix{Float64}(undef,2,N); y = Vector{Bool}(undef,N) for n=1:N - if (y[n]=(rand()>0.6)) # p(y=true) = 0.6 + if (y[n]=(rand(rng)>0.6)) # p(y=true) = 0.6 # Sample class 1 conditional distribution - if rand()<0.5 - X[:,n] = [6.0; 0.5] .* rand(2) .+ [3.0; 6.0] + if rand(rng)<0.5 + X[:,n] = [6.0; 0.5] .* rand(rng, 2) .+ [3.0; 6.0] else - X[:,n] = sqrt(0.5) * randn(2) .+ [5.5, 0.0] + X[:,n] = sqrt(0.5) * randn(rng, 2) .+ [5.5, 0.0] end 
else # Sample class 2 conditional distribution - X[:,n] = randn(2) .+ [1., 4.] + X[:,n] = randn(rng, 2) .+ [1., 4.] end end @@ -843,6 +843,7 @@ MarkdownLiteral = "736d6165-7244-6769-4267-6b50796e6954" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] @@ -854,7 +855,8 @@ LaTeXStrings = "~1.4.0" LogExpFunctions = "~0.3.29" MarkdownLiteral = "~0.1.2" Optim = "~1.13.2" -Plots = "~1.40.17" +Plots = "~1.40.20" +StableRNGs = "~1.0.4" StatsFuns = "~1.5.0" """ @@ -862,9 +864,9 @@ StatsFuns = "~1.5.0" PLUTO_MANIFEST_TOML_CONTENTS = """ # This file is machine-generated - editing it directly is not advised -julia_version = "1.12.1" +julia_version = "1.12.2" manifest_format = "2.0" -project_hash = "6c3dd68af2e626ed33e0c26a2dd1c4e83dc100b4" +project_hash = "88b4b8c9292999e875bc97e1e761ba61613bdfda" [[deps.ADTypes]] git-tree-sha1 = "27cecae79e5cc9935255f90c53bb831cc3c870d7" @@ -1237,7 +1239,7 @@ version = "0.7.16" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" +version = "1.7.0" [[deps.EnumX]] git-tree-sha1 = "bddad79635af6aec424f53ed8aad5d7555dc6f00" @@ -1562,7 +1564,7 @@ version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "OpenSSL_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.11.1+1" +version = "8.15.0+0" [[deps.LibGit2]] deps = ["LibGit2_jll", "NetworkOptions", "Printf", "SHA"] @@ -1749,7 +1751,7 @@ version = "1.5.0" [[deps.OpenSSL_jll]] deps = ["Artifacts", "Libdl"] uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.5.1+0" +version = "3.5.4+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl"] @@ -2047,9 +2049,9 @@ version = "2.6.1" 
[[deps.StableRNGs]] deps = ["Random"] -git-tree-sha1 = "95af145932c2ed859b63329952ce8d633719f091" +git-tree-sha1 = "4f96c596b8c8258cc7d3b19797854d368f243ddc" uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" -version = "1.0.3" +version = "1.0.4" [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] @@ -2453,9 +2455,9 @@ uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" version = "1.64.0+1" [[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.5.0+2" +version = "17.7.0+0" [[deps.x264_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -2545,12 +2547,12 @@ version = "1.9.2+0" # ╠═a759653c-0da4-40b7-9e9e-1e3d2e4df4ea # ╠═ad196ae6-c65e-4aaa-b0cc-bd72daa41952 # ╠═616e84d7-063d-4d9d-99e4-56aecf3c7ee4 -# ╠═fcec3c3a-8b0b-4dfd-b010-66abbf330069 +# ╟─fcec3c3a-8b0b-4dfd-b010-66abbf330069 # ╟─a65ca01a-0e9a-42cb-b1d7-648102a77eb5 # ╟─b8790891-1546-48e0-9f96-e09eade31c12 # ╟─b48f8800-473d-48e4-ab78-eb07653db7a5 -# ╠═1bfac9c5-e5cf-4a70-b077-11bb00cb1482 -# ╠═fd908bf5-71a1-4ae8-8416-cc1fdf084dcb +# ╟─1bfac9c5-e5cf-4a70-b077-11bb00cb1482 +# ╟─fd908bf5-71a1-4ae8-8416-cc1fdf084dcb # ╟─cf829697-6283-4d2f-b0dd-bbfbd689a145 # ╠═5e4bb719-ea9b-4a30-8800-5d753f405fd1 # ╠═6b56ec96-4b9d-4281-bd63-061df324867f diff --git a/mlss/Bayesian Machine Learning.jl b/mlss/Bayesian Machine Learning.jl index d97d50f..49f3a2e 100644 --- a/mlss/Bayesian Machine Learning.jl +++ b/mlss/Bayesian Machine Learning.jl @@ -766,6 +766,33 @@ N_bond # ╔═╡ 679ef9d1-cc1c-4fc1-bf82-caa967c196c2 example("Bayesian Logistic Regression (Classification)",header_level=2) +# ╔═╡ 2be6440b-4f94-4632-8f9b-0d20bf69ac39 +TODO("introduction") + +# ╔═╡ 13cca61b-ee30-4cbc-b267-d77b1f51be6c +begin + N2_bond = @bindname N2 Slider(8:200; default=120, show_value=true) +end + +# ╔═╡ 3f6cf81b-8d43-4c46-a3dc-c7517f53322c +md""" +### Generative classification +""" + +# ╔═╡ 
ed550951-58ad-4b29-baac-1cc6b8681b4b +TODO("Feel free to remove this if not needed") + +# ╔═╡ 27d3631b-cd58-4098-9ea8-b8bbdb7bb8c2 +N2_bond + +# ╔═╡ a893baee-217a-4dfe-9641-2d98cd769956 +md""" +### Discriminative classification +""" + +# ╔═╡ 4d38fa93-ddac-4e73-b2c5-f1d8c6fb9b38 +N2_bond + # ╔═╡ b5d0f64a-82bf-4fe3-a1c8-696d6ef29d11 md""" See [Bayesian Logistic Regression lecture](https://bmlip.github.io/course/lectures/Discriminative%20Classification.html#Challenge-Revisited:-Bayesian-Logistic-Regression). @@ -835,9 +862,6 @@ priors = [ Beta(8., 13.) ]; -# ╔═╡ d1d2bb84-7083-435a-9c19-4c02074143e3 - - # ╔═╡ 9c751f8e-f7ed-464f-b63c-41e318bbff2d precomputed_tosses = rand(StableRNG(234), secret_distribution, 500) @@ -1130,7 +1154,7 @@ D_sample_controls = PlutoUI.ExperimentalLayout.Div( """ -) +); # ╔═╡ 49879bbf-ab9a-4bf0-b174-0a5be6eb0005 D_sample_controls @@ -1144,6 +1168,153 @@ D_sample_controls # ╔═╡ bd0058fe-3b38-49f5-af3c-c1e7678dd431 D_sample_controls +# ╔═╡ 830e8d28-2b0a-48f2-829b-6254fa6de065 +md""" +### Discriminative classification +""" + +# ╔═╡ 6c6b7c68-2e5f-44f8-be0b-11777e46a767 +function generate_dataset(N::Int64) + rng = StableRNG(984289) + # Generate dataset {(x1,y1),...,(xN,yN)} + # x is a 2d feature vector [x1;x2] + # y ∈ {false,true} is a binary class label + # p(x|y) is multi-modal (mixture of uniform and Gaussian distributions) + # srand(123) + X = Matrix{Float64}(undef,2,N); y = Vector{Bool}(undef,N) + for n=1:N + if (y[n]=(rand(rng)>0.6)) # p(y=true) = 0.6 + # Sample class 1 conditional distribution + if rand(rng)<0.5 + X[:,n] = [6.0; 0.5] .* rand(rng, 2) .+ [3.0; 6.0] + else + X[:,n] = sqrt(0.5) * randn(rng, 2) .+ [5.5, 0.0] + end + else + # Sample class 2 conditional distribution + X[:,n] = randn(rng, 2) .+ [1., 4.] 
+ end + end + + return (X, y) +end + +# ╔═╡ d5955286-15c0-4723-b418-da54f675c59e +X, y = generate_dataset(N2); # Generate data set, collect in matrix X and vector y + +# ╔═╡ fee6e410-3e62-42ae-80ba-8e774b7ceb1e +X_c1 = X[:,findall(.!y)]' # Split X based on class label + +# ╔═╡ b75a418f-471c-421e-9eca-0f4157716cea +X_c2 = X[:,findall(y)]' + +# ╔═╡ c1492ac3-4692-4bb9-8c51-ce3981af2aea +X_test = [3.75; 1.0]; # Features of 'new' data point + +# ╔═╡ b11a03b6-3091-4031-a59c-5ae5a4cace3f +function plot_dataset() + result = scatter(X_c1[:,1], X_c1[:,2],markersize=4, label=L"y=0", xlabel=L"x_1", ylabel=L"x_2", xlims=(-1.6, 9), ylims=(-2, 7)) + scatter!(X_c2[:,1], X_c2[:,2],markersize=4, label=L"y=1") + scatter!([X_test[1]], [X_test[2]], markersize=7, marker=:star, label=L"y=?") + plot!(legend=:bottomright) + return result +end + +# ╔═╡ ccbf6bd8-3b77-4a3b-9272-6e63495ddaf3 +plot_dataset() + +# ╔═╡ 4093dea9-2ad3-4384-a414-06a6726d4660 +let + d1 = fit_mle(MvNormal, X_c1') + d2 = fit_mle(MvNormal, X_c2') + + plot_dataset() + + xrange = range(-1.6, 9; length=20) + yrange = range(-2, 7; length=15) + + contour!( + xrange, yrange, + (x,y) -> pdf(d1, [x,y]); + opacity=.4, + color=:blues, + ) + + + contour!( + xrange, yrange, + (x,y) -> pdf(d2, [x,y]); + opacity=.4, + color=:red, + colorbar=nothing, + ) +end + +# ╔═╡ f2340f0e-a170-4386-a616-47edd748704d +""" +Computes the predictive posterior eq. B-4.152 using the given approximation to the sigmoid function. 
+""" +function predictive_posterior(x, weight_posterior) + λsq = π / 8 + wN = mean(weight_posterior) + μ = wN' * x + σ = x' * cov(weight_posterior) * x + query_point = μ / (sqrt(inv(λsq) + σ )) + return normcdf(0, 1, query_point) +end + +# ╔═╡ 6e6abac8-8800-4f33-8530-3a8f2c797259 +logσ(x) = -softplus(x) + +# ╔═╡ 36d658eb-2244-4290-adc6-84ca4327931e +function log_likelihood(w, X, y) + return sum(logσ.((2*y .- 1) .* (X' * w))) +end + +# ╔═╡ 598a32fa-99cc-4a4c-a622-6fb992916e6a +""" +This function computes the posterior distribution over regression weights using the Laplace Approximation. We use `logσ` as a numerically stable alternative to `logistic`, and we avoid matrix inversions by computing the precision matrix of the posterior distribution instead of the covariance. + +The math in this function corresponds to eq. B-4.143 +""" +function bayesian_discrimination_boundary(prior_w, X::Matrix, y::Vector{Bool}) + m_0 = mean(prior_w) + p_0 = precision(prior_w) + negative_unnormalized_posterior = w -> -log_likelihood(w, X, y) - logpdf(prior_w, w) + MAP_w = Optim.minimizer(optimize(negative_unnormalized_posterior, zeros(3))) + σ_n = logistic.((2y .- 1) .* (X' * MAP_w)) + inv_Σ = p_0 + for i in 1:length(y) + slice = view(X, :, i) + inv_Σ .+= σ_n[i] * (1.0 - σ_n[i]) .* (slice * slice') + end + + return MvNormalMeanPrecision(MAP_w, inv_Σ) +end + +# ╔═╡ 60095e6e-be1e-411c-8c89-753bc8604950 +let + X_ext = vcat(X, ones(1, length(y))) + + # Define a prior distribution over parameters, play with this to see the result change! + prior = MvNormalMeanCovariance(zeros(3), 100 .* diagm(ones(3))) + posterior = bayesian_discrimination_boundary(prior, X_ext, y) + + # Plot 50% boundary + θ = mean(posterior) + disc_boundary(x1) = -1 / θ[2] * (θ[1]*x1 + θ[3]) + plot_dataset() + plot!([-2., 10.], disc_boundary; label="Discr. 
boundary", linewidth=2) + + # Plot heatmap + xrange = range(-1.6, 9; length=50) + yrange = range(-2, 7; length=30) + heatmap!( + xrange, yrange, (x,y) -> predictive_posterior([x, y, 1], posterior); + alpha=0.5, color=:redblue, + ) +end + # ╔═╡ 00000000-0000-0000-0000-000000000001 PLUTO_PROJECT_TOML_CONTENTS = """ [deps] @@ -3089,6 +3260,16 @@ version = "1.9.2+0" # ╟─1211336b-5fb0-415e-92a8-6ba2b061cb43 # ╟─74640c85-8589-4121-8fdf-d71cb29532b8 # ╟─679ef9d1-cc1c-4fc1-bf82-caa967c196c2 +# ╟─2be6440b-4f94-4632-8f9b-0d20bf69ac39 +# ╟─13cca61b-ee30-4cbc-b267-d77b1f51be6c +# ╟─ccbf6bd8-3b77-4a3b-9272-6e63495ddaf3 +# ╟─3f6cf81b-8d43-4c46-a3dc-c7517f53322c +# ╟─ed550951-58ad-4b29-baac-1cc6b8681b4b +# ╟─27d3631b-cd58-4098-9ea8-b8bbdb7bb8c2 +# ╟─4093dea9-2ad3-4384-a414-06a6726d4660 +# ╟─a893baee-217a-4dfe-9641-2d98cd769956 +# ╟─4d38fa93-ddac-4e73-b2c5-f1d8c6fb9b38 +# ╟─60095e6e-be1e-411c-8c89-753bc8604950 # ╟─b5d0f64a-82bf-4fe3-a1c8-696d6ef29d11 # ╟─47842de0-d17e-460e-b3b7-b2e642569e25 # ╟─b273c8bc-3819-4f63-801a-acf0ee78ef1d @@ -3101,10 +3282,9 @@ version = "1.9.2+0" # ╠═3f8fd1c3-202e-45a6-ab03-5229863db297 # ╠═3987d441-b9c8-4bb1-8b2d-0cc78d78819e # ╟─7a764a14-a5df-4f76-8836-f0a571fc3519 -# ╠═c28b7130-f7fb-41ee-852e-9964b091d7fb +# ╟─c28b7130-f7fb-41ee-852e-9964b091d7fb # ╠═9da43d0f-e605-41b7-9bc6-db5be95bc87f # ╠═e47b6eb6-2bb3-4c2d-bda6-f1535f2f94c4 -# ╠═d1d2bb84-7083-435a-9c19-4c02074143e3 # ╠═9c751f8e-f7ed-464f-b63c-41e318bbff2d # ╠═3a903a4d-1fb0-4566-8151-9c86dfc40ceb # ╠═e99e7650-bb72-4576-8f2a-c3994533b644 @@ -3131,5 +3311,16 @@ version = "1.9.2+0" # ╠═26369851-1d00-4f48-9e64-6b576af61066 # ╠═280c69a5-b7a4-400f-a810-3b846ff27ec2 # ╠═0a81b382-b01b-459a-8955-9ec8640a57d1 +# ╟─830e8d28-2b0a-48f2-829b-6254fa6de065 +# ╟─6c6b7c68-2e5f-44f8-be0b-11777e46a767 +# ╠═d5955286-15c0-4723-b418-da54f675c59e +# ╠═fee6e410-3e62-42ae-80ba-8e774b7ceb1e +# ╠═b75a418f-471c-421e-9eca-0f4157716cea +# ╠═c1492ac3-4692-4bb9-8c51-ce3981af2aea +# ╟─b11a03b6-3091-4031-a59c-5ae5a4cace3f +# 
╟─598a32fa-99cc-4a4c-a622-6fb992916e6a +# ╟─f2340f0e-a170-4386-a616-47edd748704d +# ╟─6e6abac8-8800-4f33-8530-3a8f2c797259 +# ╟─36d658eb-2244-4290-adc6-84ca4327931e # ╟─00000000-0000-0000-0000-000000000001 # ╟─00000000-0000-0000-0000-000000000002