diff --git a/lib/polaris/updates.ex b/lib/polaris/updates.ex
index b259a4a..91829d4 100644
--- a/lib/polaris/updates.ex
+++ b/lib/polaris/updates.ex
@@ -1003,6 +1003,96 @@ defmodule Polaris.Updates do
     Map.merge(params, state, fn _, s1, s2 -> merge_inner(s1, s2) end)
   end
 
+  @doc """
+  Applies the GaLore algorithm to an optimizer for low-memory
+  training.
+
+  Wraps the given `{init_fn, apply_fn}` optimizer so that the entries of
+  the updates selected by `galore_params` are projected onto a low-rank
+  subspace before they reach the wrapped optimizer, then projected back
+  up and scaled afterwards. All other entries are passed through as-is.
+
+  The projection is recomputed on every call from the singular value
+  decomposition of each selected update.
+
+  ## Options
+
+    * `:rank` - number of rows of `vh` (from the SVD) kept for the
+      projection. Defaults to `128`
+
+    * `:scale` - factor applied to the selected updates after they are
+      projected back up. Defaults to `1.0`
+  """
+  def galore({parent_init_fn, parent_apply_fn}, galore_params, opts \\ []) do
+    opts = Keyword.validate!(opts, rank: 128, scale: 1.0)
+
+    init_fn = fn params ->
+      # On initialization, project down so the parent state is
+      # initialized with the low-rank version.
+      {galore, regular} = Map.split(params, galore_params)
+      {projected, _ortho_matrix} = apply_galore_projection_down(galore, opts)
+      parent_init_fn.(Map.merge(projected, regular))
+    end
+
+    apply_fn = fn updates, state, params ->
+      {galore, regular} = Map.split(updates, galore_params)
+      {projected, ortho_matrix} = apply_galore_projection_down(galore, opts)
+
+      # NOTE(review): `params` is forwarded unprojected while the selected
+      # updates are low-rank; confirm the wrapped optimizer tolerates that.
+      {scaled_updates, new_state} =
+        parent_apply_fn.(Map.merge(projected, regular), state, params)
+
+      {scaled_galore, scaled_regular} = Map.split(scaled_updates, galore_params)
+      galore_updates = apply_galore_projection_up(scaled_galore, ortho_matrix, opts)
+      {Map.merge(galore_updates, scaled_regular), new_state}
+    end
+
+    {init_fn, apply_fn}
+  end
+
+  # Projects every leaf of `params` down with its own orthogonal matrix.
+  #
+  # Returns `{projected, ortho_matrix}`: `ortho_matrix` holds the result of
+  # `get_orthogonal_matrix/2` per leaf, and `projected` holds
+  # `Nx.dot(g, Nx.transpose(ortho))` per leaf.
+  # Assumes every leaf is a rank-2 tensor - TODO confirm.
+  defnp apply_galore_projection_down(params, opts \\ []) do
+    opts = keyword!(opts, rank: 128, scale: 1.0)
+
+    ortho_matrix =
+      deep_new(params, fn g ->
+        get_orthogonal_matrix(g, rank: opts[:rank])
+      end)
+
+    projected =
+      deep_merge(params, ortho_matrix, fn g, ortho ->
+        Nx.dot(g, Nx.transpose(ortho))
+      end)
+
+    {projected, ortho_matrix}
+  end
+
+  # Projects every leaf of `params` back up with the matching entry of
+  # `ortho_matrix` and multiplies the result by `opts[:scale]`.
+  defnp apply_galore_projection_up(params, ortho_matrix, opts \\ []) do
+    opts = keyword!(opts, rank: 128, scale: 1.0)
+
+    deep_merge(params, ortho_matrix, fn g, ortho ->
+      opts[:scale] * Nx.dot(g, ortho)
+    end)
+  end
+
+  # Returns the first `opts[:rank]` rows of `vh` from the reduced SVD of `g`.
+  defnp get_orthogonal_matrix(g, opts \\ []) do
+    opts = keyword!(opts, rank: 128)
+
+    {_u, _s, vh} = Nx.LinAlg.svd(g, full_matrices?: false)
+
+    # Slice by length: an inclusive `0..rank` range keeps `rank + 1` rows.
+    Nx.slice_along_axis(vh, 0, opts[:rank], axis: 0)
+  end
+
   ## Helpers
 
   defnp update_moment(x, moment, decay, order) do