Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 24 additions & 22 deletions lectures/Discriminative Classification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ using Random, Plots, LaTeXStrings
using MarkdownLiteral: @mdx

# ╔═╡ 616e84d7-063d-4d9d-99e4-56aecf3c7ee4
using Distributions, ExponentialFamily, LinearAlgebra, LogExpFunctions, StatsFuns, BayesBase, Optim
using Distributions, ExponentialFamily, LinearAlgebra, LogExpFunctions, StatsFuns, BayesBase, Optim, StableRNGs

# ╔═╡ 25eefb10-d294-11ef-0734-2daf18636e8e
title("Discriminative Classification")
Expand Down Expand Up @@ -74,7 +74,7 @@ Our task will be the same as in the preceding class on (generative) classificati

# ╔═╡ 4ceede48-a4d5-446b-bb34-26cec4af357a
begin
N_bond = @bindname N Slider(9:200; default=120, show_value=true)
N_bond = @bindname N Slider(8:200; default=120, show_value=true)
end

# ╔═╡ 7e7cab21-09ab-4d06-9716-ab7864b229ab
Expand Down Expand Up @@ -452,7 +452,7 @@ Note that we get a full predictive posterior distribution over the assignment of

# ╔═╡ 98ef7093-f8ed-4a44-a153-0a64ab483f65
md"""
## Implementation Issues
## Implementation
"""

# ╔═╡ 0045e569-dc3c-4998-86da-9d96f599c599
Expand Down Expand Up @@ -680,24 +680,24 @@ md"""

# ╔═╡ fcec3c3a-8b0b-4dfd-b010-66abbf330069
function generate_dataset(N::Int64)
Random.seed!(1234)
rng = StableRNG(984289)
# Generate dataset {(x1,y1),...,(xN,yN)}
# x is a 2d feature vector [x1;x2]
# y ∈ {false,true} is a binary class label
# p(x|y) is multi-modal (mixture of uniform and Gaussian distributions)
# srand(123)
X = Matrix{Float64}(undef,2,N); y = Vector{Bool}(undef,N)
for n=1:N
if (y[n]=(rand()>0.6)) # p(y=true) = 0.6
if (y[n]=(rand(rng)>0.6)) # p(y=true) = 0.6
# Sample class 1 conditional distribution
if rand()<0.5
X[:,n] = [6.0; 0.5] .* rand(2) .+ [3.0; 6.0]
if rand(rng)<0.5
X[:,n] = [6.0; 0.5] .* rand(rng, 2) .+ [3.0; 6.0]
else
X[:,n] = sqrt(0.5) * randn(2) .+ [5.5, 0.0]
X[:,n] = sqrt(0.5) * randn(rng, 2) .+ [5.5, 0.0]
end
else
# Sample class 2 conditional distribution
X[:,n] = randn(2) .+ [1., 4.]
X[:,n] = randn(rng, 2) .+ [1., 4.]
end
end

Expand Down Expand Up @@ -843,6 +843,7 @@ MarkdownLiteral = "736d6165-7244-6769-4267-6b50796e6954"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
Expand All @@ -854,17 +855,18 @@ LaTeXStrings = "~1.4.0"
LogExpFunctions = "~0.3.29"
MarkdownLiteral = "~0.1.2"
Optim = "~1.13.2"
Plots = "~1.40.17"
Plots = "~1.40.20"
StableRNGs = "~1.0.4"
StatsFuns = "~1.5.0"
"""

# ╔═╡ 00000000-0000-0000-0000-000000000002
PLUTO_MANIFEST_TOML_CONTENTS = """
# This file is machine-generated - editing it directly is not advised

julia_version = "1.12.1"
julia_version = "1.12.2"
manifest_format = "2.0"
project_hash = "6c3dd68af2e626ed33e0c26a2dd1c4e83dc100b4"
project_hash = "88b4b8c9292999e875bc97e1e761ba61613bdfda"

[[deps.ADTypes]]
git-tree-sha1 = "27cecae79e5cc9935255f90c53bb831cc3c870d7"
Expand Down Expand Up @@ -1237,7 +1239,7 @@ version = "0.7.16"
[[deps.Downloads]]
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
version = "1.6.0"
version = "1.7.0"

[[deps.EnumX]]
git-tree-sha1 = "bddad79635af6aec424f53ed8aad5d7555dc6f00"
Expand Down Expand Up @@ -1562,7 +1564,7 @@ version = "0.6.4"
[[deps.LibCURL_jll]]
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "OpenSSL_jll", "Zlib_jll", "nghttp2_jll"]
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
version = "8.11.1+1"
version = "8.15.0+0"

[[deps.LibGit2]]
deps = ["LibGit2_jll", "NetworkOptions", "Printf", "SHA"]
Expand Down Expand Up @@ -1749,7 +1751,7 @@ version = "1.5.0"
[[deps.OpenSSL_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
version = "3.5.1+0"
version = "3.5.4+0"

[[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl"]
Expand Down Expand Up @@ -2047,9 +2049,9 @@ version = "2.6.1"

[[deps.StableRNGs]]
deps = ["Random"]
git-tree-sha1 = "95af145932c2ed859b63329952ce8d633719f091"
git-tree-sha1 = "4f96c596b8c8258cc7d3b19797854d368f243ddc"
uuid = "860ef19b-820b-49d6-a774-d7a799459cd3"
version = "1.0.3"
version = "1.0.4"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
Expand Down Expand Up @@ -2453,9 +2455,9 @@ uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
version = "1.64.0+1"

[[deps.p7zip_jll]]
deps = ["Artifacts", "Libdl"]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
version = "17.5.0+2"
version = "17.7.0+0"

[[deps.x264_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
Expand Down Expand Up @@ -2545,12 +2547,12 @@ version = "1.9.2+0"
# ╠═a759653c-0da4-40b7-9e9e-1e3d2e4df4ea
# ╠═ad196ae6-c65e-4aaa-b0cc-bd72daa41952
# ╠═616e84d7-063d-4d9d-99e4-56aecf3c7ee4
# ╠═fcec3c3a-8b0b-4dfd-b010-66abbf330069
# ╟─fcec3c3a-8b0b-4dfd-b010-66abbf330069
# ╟─a65ca01a-0e9a-42cb-b1d7-648102a77eb5
# ╟─b8790891-1546-48e0-9f96-e09eade31c12
# ╟─b48f8800-473d-48e4-ab78-eb07653db7a5
# ╠═1bfac9c5-e5cf-4a70-b077-11bb00cb1482
# ╠═fd908bf5-71a1-4ae8-8416-cc1fdf084dcb
# ╟─1bfac9c5-e5cf-4a70-b077-11bb00cb1482
# ╟─fd908bf5-71a1-4ae8-8416-cc1fdf084dcb
# ╟─cf829697-6283-4d2f-b0dd-bbfbd689a145
# ╠═5e4bb719-ea9b-4a30-8800-5d753f405fd1
# ╠═6b56ec96-4b9d-4281-bd63-061df324867f
Expand Down
Loading
Loading