From 5c5162b2a69a39e42de7565431b5242773f1353a Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Thu, 7 Mar 2024 14:38:40 -0500
Subject: [PATCH 1/5] Outline basic algorithm

---
 lib/polaris/updates.ex | 42 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/lib/polaris/updates.ex b/lib/polaris/updates.ex
index b259a4a..da1fda6 100644
--- a/lib/polaris/updates.ex
+++ b/lib/polaris/updates.ex
@@ -1003,6 +1003,48 @@ defmodule Polaris.Updates do
     Map.merge(params, state, fn _, s1, s2 -> merge_inner(s1, s2) end)
   end
 
+  @doc """
+  Applies the GaLore algorithm to an optimizer for low-memory
+  training.
+  """
+  def galore({parent_init_fn, parent_apply_fn}, opts \\ []) do
+    init_fn = fn params ->
+      # on initialization, project down so we initialize parent
+      # state with low-rank version
+      {projected, ortho_matrix} = apply_galore_projection_down(params)
+      parent_init_fn.(projected)
+    end
+
+    apply_fn = fn updates, %{galore_matrix: ortho_matrix, state: state}, params ->
+      {projected, ortho_matrix} = apply_galore_projection_down(updates)
+      {scaled_updates, new_state} = parent_apply_fn.(projected)
+      updates = apply_galore_projection_up(scaled_updates, ortho_matrix)
+      {updates, new_state}
+    end
+  end
+
+  defnp apply_galore_projection_down(params, opts \\ []) do
+    opts = keyword!(opts, scale: 1.0)
+
+    ortho_matrix = deep_new(params, fn g ->
+      get_orthogonal_matrix(g)
+    end)
+
+    projected = deep_merge(params, ortho_matrix, fn g, ortho ->
+      Nx.dot(Nx.transpose(ortho_matrix), g)
+    end)
+
+    {projected, ortho_matrix}
+  end
+
+  defnp apply_galore_projection_down(params, ortho_matrix, opts \\ []) do
+    opts = keyword!(opts, scale: 1.0)
+
+    deep_merge(params, ortho_matrix, fn g, ortho ->
+      Nx.dot(ortho, g)
+    end)
+  end
+
   ## Helpers
 
   defnp update_moment(x, moment, decay, order) do

From b210e8039832216b5201a20ea24d847815d0c3fa Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Thu, 7 Mar 2024 14:46:46 -0500
Subject: [PATCH 2/5] Fix all warnings

---
 lib/polaris/updates.ex | 41 +++++++++++++++++++++++++++--------------
 1 file changed, 27 insertions(+), 14 deletions(-)

diff --git a/lib/polaris/updates.ex b/lib/polaris/updates.ex
index da1fda6..e63f7e2 100644
--- a/lib/polaris/updates.ex
+++ b/lib/polaris/updates.ex
@@ -1008,43 +1008,56 @@ defmodule Polaris.Updates do
   training.
   """
   def galore({parent_init_fn, parent_apply_fn}, opts \\ []) do
+    opts = Keyword.validate!(opts, rank: 128, scale: 1.0)
+
     init_fn = fn params ->
       # on initialization, project down so we initialize parent
       # state with low-rank version
-      {projected, ortho_matrix} = apply_galore_projection_down(params)
+      {projected, _ortho_matrix} = apply_galore_projection_down(params, opts)
       parent_init_fn.(projected)
     end
 
-    apply_fn = fn updates, %{galore_matrix: ortho_matrix, state: state}, params ->
-      {projected, ortho_matrix} = apply_galore_projection_down(updates)
-      {scaled_updates, new_state} = parent_apply_fn.(projected)
-      updates = apply_galore_projection_up(scaled_updates, ortho_matrix)
+    apply_fn = fn updates, state, params ->
+      {projected, ortho_matrix} = apply_galore_projection_down(updates, opts)
+      {scaled_updates, new_state} = parent_apply_fn.(projected, state, params)
+      updates = apply_galore_projection_up(scaled_updates, ortho_matrix, opts)
       {updates, new_state}
     end
+
+    {init_fn, apply_fn}
   end
 
   defnp apply_galore_projection_down(params, opts \\ []) do
-    opts = keyword!(opts, scale: 1.0)
+    opts = keyword!(opts, rank: 128, scale: 1.0)
 
-    ortho_matrix = deep_new(params, fn g ->
-      get_orthogonal_matrix(g)
-    end)
+    ortho_matrix =
+      deep_new(params, fn g ->
+        get_orthogonal_matrix(g, rank: opts[:rank])
+      end)
 
-    projected = deep_merge(params, ortho_matrix, fn g, ortho ->
-      Nx.dot(Nx.transpose(ortho_matrix), g)
-    end)
+    projected =
+      deep_merge(params, ortho_matrix, fn g, ortho ->
+        Nx.dot(g, Nx.transpose(ortho))
+      end)
 
     {projected, ortho_matrix}
   end
 
-  defnp apply_galore_projection_down(params, ortho_matrix, opts \\ []) do
+  defnp apply_galore_projection_up(params, ortho_matrix, opts \\ []) do
     opts = keyword!(opts, scale: 1.0)
 
     deep_merge(params, ortho_matrix, fn g, ortho ->
-      Nx.dot(ortho, g)
+      opts[:scale] * Nx.dot(g, ortho)
     end)
   end
 
+  defnp get_orthogonal_matrix(g, opts \\ []) do
+    opts = keyword!(opts, rank: 128)
+
+    {_u, _s, vh} = Nx.LinAlg.svd(g)
+    vh[[0..opts[:rank], ..]]
+  end
+
   ## Helpers
 
   defnp update_moment(x, moment, decay, order) do

From 03921775a632e7d59ff539eca70bbbe762e5d8bf Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Thu, 7 Mar 2024 15:11:04 -0500
Subject: [PATCH 3/5] Fixees

---
 lib/polaris/updates.ex | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/polaris/updates.ex b/lib/polaris/updates.ex
index e63f7e2..5443257 100644
--- a/lib/polaris/updates.ex
+++ b/lib/polaris/updates.ex
@@ -1044,7 +1044,7 @@ defmodule Polaris.Updates do
   end
 
   defnp apply_galore_projection_up(params, ortho_matrix, opts \\ []) do
-    opts = keyword!(opts, scale: 1.0)
+    opts = keyword!(opts, rank: 128, scale: 1.0)
 
     deep_merge(params, ortho_matrix, fn g, ortho ->
       opts[:scale] * Nx.dot(g, ortho)

From a3838f10e5e26bfa52e0f332eedc06de72cdc3c4 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Thu, 7 Mar 2024 16:06:12 -0500
Subject: [PATCH 4/5] Add a way to filter params

---
 lib/polaris/updates.ex | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/lib/polaris/updates.ex b/lib/polaris/updates.ex
index 5443257..5b2bc67 100644
--- a/lib/polaris/updates.ex
+++ b/lib/polaris/updates.ex
@@ -1007,21 +1007,24 @@ defmodule Polaris.Updates do
   Applies the GaLore algorithm to an optimizer for low-memory
   training.
   """
-  def galore({parent_init_fn, parent_apply_fn}, opts \\ []) do
+  def galore({parent_init_fn, parent_apply_fn}, galore_params, opts \\ []) do
     opts = Keyword.validate!(opts, rank: 128, scale: 1.0)
 
     init_fn = fn params ->
       # on initialization, project down so we initialize parent
       # state with low-rank version
-      {projected, _ortho_matrix} = apply_galore_projection_down(params, opts)
-      parent_init_fn.(projected)
+      {galore, regular} = Map.split(params, galore_params)
+      {projected, _ortho_matrix} = apply_galore_projection_down(galore, opts)
+      parent_init_fn.(Map.merge(projected, regular))
     end
 
     apply_fn = fn updates, state, params ->
-      {projected, ortho_matrix} = apply_galore_projection_down(updates, opts)
-      {scaled_updates, new_state} = parent_apply_fn.(projected, state, params)
-      updates = apply_galore_projection_up(scaled_updates, ortho_matrix, opts)
-      {updates, new_state}
+      {galore, regular} = Map.split(updates, galore_params)
+      {projected, ortho_matrix} = apply_galore_projection_down(galore, opts)
+      {scaled_updates, new_state} = parent_apply_fn.(Map.merge(projected, regular), state, params)
+      {galore, regular} = Map.split(scaled_updates, galore_params)
+      galore_updates = apply_galore_projection_up(galore, ortho_matrix, opts)
+      {Map.merge(galore_updates, regular), new_state}
     end
 
     {init_fn, apply_fn}

From 125de8615c2472cee8f3c4db1965cf871e5dd564 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Sat, 9 Mar 2024 09:03:01 -0500
Subject: [PATCH 5/5] Do not use full matrices

---
Review note (ignored by `git am`): besides passing `full_matrices?: false`,
this revision fixes an off-by-one in the row selection. The previous slice
`vh[[0..opts[:rank], ..]]` used an inclusive range and therefore kept
`rank + 1` rows of `vh`; `Nx.slice_along_axis(vh, 0, opts[:rank], axis: 0)`
keeps exactly `rank` rows. A `rank` larger than the number of rows of `vh`
still raises, as before; clamping is left for a follow-up.

 lib/polaris/updates.ex | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lib/polaris/updates.ex b/lib/polaris/updates.ex
index 5b2bc67..91829d4 100644
--- a/lib/polaris/updates.ex
+++ b/lib/polaris/updates.ex
@@ -1057,7 +1057,7 @@ defmodule Polaris.Updates do
   defnp get_orthogonal_matrix(g, opts \\ []) do
     opts = keyword!(opts, rank: 128)
 
-    {_u, _s, vh} = Nx.LinAlg.svd(g)
-    vh[[0..opts[:rank], ..]]
+    {_u, _s, vh} = Nx.LinAlg.svd(g, full_matrices?: false)
+    Nx.slice_along_axis(vh, 0, opts[:rank], axis: 0)
   end
 