diff --git a/lib/scholar/linear/logistic_regression.ex b/lib/scholar/linear/logistic_regression.ex index c49887e2..9230aee8 100644 --- a/lib/scholar/linear/logistic_regression.ex +++ b/lib/scholar/linear/logistic_regression.ex @@ -1,12 +1,11 @@ defmodule Scholar.Linear.LogisticRegression do @moduledoc """ - Logistic regression in both binary and multinomial variants. + Multiclass logistic regression. Time complexity is $O(N * K * I)$ where $N$ is the number of samples, $K$ is the number of features, and $I$ is the number of iterations. """ import Nx.Defn import Scholar.Shared - alias Scholar.Linear.LinearHelpers @derive {Nx.Container, containers: [:coefficients, :bias]} defstruct [:coefficients, :bias] @@ -15,35 +14,28 @@ defmodule Scholar.Linear.LogisticRegression do num_classes: [ required: true, type: :pos_integer, - doc: "number of classes contained in the input tensors." + doc: "Number of output classes." ], - iterations: [ + max_iterations: [ type: :pos_integer, default: 1000, - doc: """ - number of iterations of gradient descent performed inside logistic - regression. - """ + doc: "Maximum number of gradient descent iterations to perform." ], - learning_loop_unroll: [ - type: :boolean, - default: false, - doc: ~S""" - If `true`, the learning loop is unrolled. + alpha: [ + type: {:custom, Scholar.Options, :non_negative_number, []}, + default: 1.0, + doc: """ + Constant that multiplies the L2 regularization term, controlling regularization strength. + If 0, no regularization is applied. """ ], - optimizer: [ - type: {:custom, Scholar.Options, :optimizer, []}, - default: :sgd, + tol: [ + type: {:custom, Scholar.Options, :non_negative_number, []}, + default: 1.0e-4, doc: """ - The optimizer name or {init, update} pair of functions (see `Polaris.Optimizers` for more details). + Convergence tolerance. If the infinity norm of the gradient is less than `:tol`, + the algorithm is considered to have converged. """ - ], - eps: [ - type: :float, - default: 1.0e-8, - doc: - "The convergence tolerance. If the `abs(loss) < size(x) * :eps`, the algorithm is considered to have converged." ] ] @@ -53,9 +45,6 @@ defmodule Scholar.Linear.LogisticRegression do Fits a logistic regression model for sample inputs `x` and sample targets `y`. - Depending on number of classes the function chooses either binary - or multinomial logistic regression. - ## Options #{NimbleOptions.docs(@opts_schema)} @@ -68,10 +57,6 @@ defmodule Scholar.Linear.LogisticRegression do * `:bias` - Bias added to the decision function. - * `:mode` - Indicates whether the problem is binary classification (`:num_classes` set to 2) - or multinomial (`:num_classes` is bigger than 2). For binary classification set to `:binary`, otherwise - set to `:multinomial`. 
- ## Examples iex> x = Nx.tensor([[1.0, 2.0], [3.0, 2.0], [4.0, 7.0]]) @@ -80,134 +65,177 @@ defmodule Scholar.Linear.LogisticRegression do %Scholar.Linear.LogisticRegression{ coefficients: Nx.tensor( [ - [2.5531527996063232, -0.5531544089317322], - [-0.35652396082878113, 2.3565237522125244] + [0.0915902629494667, -0.09159023314714432], + [-0.1507941037416458, 0.1507941335439682] ] ), - bias: Nx.tensor( - [-0.28847914934158325, 0.28847917914390564] - ) + bias: Nx.tensor([-0.06566660106182098, 0.06566664576530457]) } """ deftransform fit(x, y, opts \\ []) do if Nx.rank(x) != 2 do raise ArgumentError, - "expected x to have shape {n_samples, n_features}, got tensor with shape: #{inspect(Nx.shape(x))}" + "expected x to have shape {num_samples, num_features}, got tensor with shape: #{inspect(Nx.shape(x))}" end - {n_samples, _} = Nx.shape(x) - y = LinearHelpers.validate_y_shape(y, n_samples, __MODULE__) - - opts = NimbleOptions.validate!(opts, @opts_schema) - - {optimizer, opts} = Keyword.pop!(opts, :optimizer) - - {optimizer_init_fn, optimizer_update_fn} = - case optimizer do - atom when is_atom(atom) -> apply(Polaris.Optimizers, atom, []) - {f1, f2} -> {f1, f2} - end + if Nx.rank(y) != 1 do + raise ArgumentError, + "expected y to have shape {num_samples}, got tensor with shape: #{inspect(Nx.shape(y))}" + end - n = Nx.axis_size(x, -1) - num_classes = opts[:num_classes] + num_samples = Nx.axis_size(x, 0) - coef = - Nx.broadcast( - Nx.tensor(1.0, type: to_float_type(x)), - {n, num_classes} - ) + if Nx.axis_size(y, 0) != num_samples do + raise ArgumentError, + "expected x and y to have the same number of samples, got #{num_samples} and #{Nx.axis_size(y, 0)}" + end - bias = Nx.broadcast(Nx.tensor(0, type: to_float_type(x)), {num_classes}) + opts = NimbleOptions.validate!(opts, @opts_schema) - coef_optimizer_state = optimizer_init_fn.(coef) |> as_type(to_float_type(x)) - bias_optimizer_state = optimizer_init_fn.(bias) |> as_type(to_float_type(x)) + type = to_float_type(x) - opts = Keyword.put(opts, :optimizer_update_fn, optimizer_update_fn) + {alpha, opts} = Keyword.pop!(opts, :alpha) + alpha = Nx.tensor(alpha, type: type) + {tol, opts} = Keyword.pop!(opts, :tol) + tol = Nx.tensor(tol, type: type) - fit_n(x, y, coef, bias, coef_optimizer_state, bias_optimizer_state, opts) + fit_n(x, y, alpha, tol, opts) end - deftransformp as_type(container, target_type) do - Nx.Defn.Composite.traverse(container, fn t -> - type = Nx.type(t) + defnp fit_n(x, y, alpha, tol, opts) do + num_classes = opts[:num_classes] + max_iterations = opts[:max_iterations] + {num_samples, num_features} = Nx.shape(x) - if Nx.Type.float?(type) and not Nx.Type.complex?(type) do - Nx.as_type(t, target_type) - else - t - end - end) - end + type = to_float_type(x) - # Logistic Regression training loop + # Initialize weights and bias with zeros + w = + Nx.broadcast( + Nx.tensor(0.0, type: type), + {num_features, num_classes} + ) - defnp fit_n(x, y, coef, bias, coef_optimizer_state, bias_optimizer_state, opts) do - num_samples = Nx.axis_size(x, 0) - iterations = opts[:iterations] - num_classes = opts[:num_classes] - optimizer_update_fn = opts[:optimizer_update_fn] + b = Nx.broadcast(Nx.tensor(0.0, type: type), {num_classes}) + # One-hot encoding of target labels y_one_hot = y |> Nx.new_axis(1) |> Nx.broadcast({num_samples, num_classes}) |> Nx.equal(Nx.iota({num_samples, num_classes}, axis: 1)) - {{final_coef, final_bias}, _} = - while {{coef, bias}, - {x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state, - has_converged = 
Nx.u8(0), iter = 0}},
-            iter < iterations and not has_converged do
-        {loss, {coef_grad, bias_grad}} = loss_and_grad(coef, bias, x, y_one_hot)
-
-        {coef_updates, coef_optimizer_state} =
-          optimizer_update_fn.(coef_grad, coef_optimizer_state, coef)
-
-        coef = Polaris.Updates.apply_updates(coef, coef_updates)
-
-        {bias_updates, bias_optimizer_state} =
-          optimizer_update_fn.(bias_grad, bias_optimizer_state, bias)
+    # Armijo line-search constants: sufficient-decrease factor c and backtracking factor rho
+    c = Nx.tensor(1.0e-4, type: type)
+    rho = Nx.tensor(0.5, type: type)

-        bias = Polaris.Updates.apply_updates(bias, bias_updates)
+    eta_min =
+      case type do
+        {:f, 32} -> Nx.tensor(1.0e-6, type: type)
+        {:f, 64} -> Nx.tensor(1.0e-8, type: type)
+        _ -> Nx.tensor(1.0e-6, type: type)
+      end

-        has_converged = Nx.sum(Nx.abs(loss)) < Nx.size(x) * opts[:eps]
+    armijo_params = %{
+      c: c,
+      rho: rho,
+      eta_min: eta_min
+    }

-        {{coef, bias},
-         {x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state, has_converged,
-          iter + 1}}
+    {coef, bias, _} =
+      while {w, b,
+             {alpha, x, y_one_hot, tol, armijo_params, iter = Nx.u32(0), converged? = Nx.u8(0)}},
+            iter < max_iterations and not converged? do
+        logits = Nx.dot(x, w) + b
+        probabilities = softmax(logits)
+        residuals = probabilities - y_one_hot
+
+        # Cross-entropy loss with the L2 penalty (same objective as compute_loss/5)
+        loss =
+          logits
+          |> log_softmax()
+          |> Nx.multiply(y_one_hot)
+          |> Nx.sum(axes: [1])
+          |> Nx.mean()
+          |> Nx.negate()
+          |> Nx.add(alpha * Nx.sum(w * w))
+
+        # Analytic gradients: x^T(residuals) / n, plus the L2 term for the weights
+        grad_w = Nx.dot(x, [0], residuals, [0]) / num_samples + 2 * alpha * w
+        grad_b = Nx.sum(residuals, axes: [0]) / num_samples
+
+        # Backtracking line search for a step size satisfying the Armijo condition
+        eta =
+          armijo_line_search(w, b, alpha, x, y_one_hot, loss, grad_w, grad_b, armijo_params)
+
+        w = w - eta * grad_w
+        b = b - eta * grad_b
+
+        converged? =
+          Nx.reduce_max(Nx.abs(grad_w)) < tol and Nx.reduce_max(Nx.abs(grad_b)) < tol
+
+        {w, b, {alpha, x, y_one_hot, tol, armijo_params, iter + 1, converged?}}
       end

     %__MODULE__{
-      coefficients: final_coef,
-      bias: final_bias
+      coefficients: coef,
+      bias: bias
     }
   end

-  defnp loss_and_grad(coeff, bias, xs, ys) do
-    value_and_grad({coeff, bias}, fn {coeff, bias} ->
-      -Nx.sum(ys * log_softmax(Nx.dot(xs, coeff) + bias), axes: [-1])
-    end)
+  defnp armijo_line_search(w, b, alpha, x, y, loss, grad_w, grad_b, armijo_params) do
+    c = armijo_params[:c]
+    rho = armijo_params[:rho]
+    eta_min = armijo_params[:eta_min]
+
+    type = to_float_type(x)
+    dir_w = -grad_w
+    dir_b = -grad_b
+    # Directional derivative of the loss along the descent direction
+    slope = Nx.sum(dir_w * grad_w) + Nx.sum(dir_b * grad_b)
+
+    {eta, _} =
+      while {eta = Nx.tensor(1.0, type: type),
+             {w, b, alpha, x, y, loss, dir_w, dir_b, slope, c, rho, eta_min}},
+            compute_loss(w + eta * dir_w, b + eta * dir_b, alpha, x, y) > loss + c * eta * slope and
+              eta > eta_min do
+        eta = eta * rho
+
+        {eta, {w, b, alpha, x, y, loss, dir_w, dir_b, slope, c, rho, eta_min}}
+      end
+
+    eta
+  end
+
+  defnp compute_loss(w, b, alpha, x, y) do
+    x
+    |> Nx.dot(w)
+    |> Nx.add(b)
+    |> log_softmax()
+    |> Nx.multiply(y)
+    |> Nx.sum(axes: [1])
+    |> Nx.mean()
+    |> Nx.negate()
+    |> Nx.add(alpha * Nx.sum(w * w))
+  end
+
+  defnp softmax(logits) do
+    max = stop_grad(Nx.reduce_max(logits, axes: [1], keep_axes: true))
+    normalized_exp = (logits - max) |> Nx.exp()
+    normalized_exp / Nx.sum(normalized_exp, axes: [1], keep_axes: true)
   end

   defnp log_softmax(x) do
-    shifted = x - stop_grad(Nx.reduce_max(x, axes: [-1], keep_axes: true))
+    shifted = x - stop_grad(Nx.reduce_max(x, axes: [1], keep_axes: true))

     shifted
     |> Nx.exp()
-    |> Nx.sum(axes: [-1], keep_axes: true)
+    |> Nx.sum(axes: [1], keep_axes: true)
     |> Nx.log()
     |> Nx.negate()
     |> Nx.add(shifted)
   end

-  # Normalized softmax
-
-  defnp softmax(t) do
-    max = stop_grad(Nx.reduce_max(t, axes: [-1], keep_axes: true))
-    normalized_exp = (t - max) |> Nx.exp()
-    normalized_exp / Nx.sum(normalized_exp, axes: [-1], keep_axes: true)
-  end
-
   @doc """
   Makes predictions with the given `model` on inputs `x`.

@@ -219,14 +247,16 @@
       iex> y = Nx.tensor([1, 0, 1])
       iex> model = Scholar.Linear.LogisticRegression.fit(x, y, num_classes: 2)
       iex> Scholar.Linear.LogisticRegression.predict(model, Nx.tensor([[-3.0, 5.0]]))
-      #Nx.Tensor<
-        s32[1]
-        [1]
-      >
+      Nx.tensor([1])
   """
   defn predict(%__MODULE__{coefficients: coeff, bias: bias} = _model, x) do
-    inter = Nx.dot(x, [1], coeff, [0]) + bias
-    Nx.argmax(inter, axis: 1)
+    if Nx.rank(x) != 2 do
+      raise ArgumentError,
+            "expected x to have shape {num_samples, num_features}, got tensor with shape: #{inspect(Nx.shape(x))}"
+    end
+
+    logits = Nx.dot(x, coeff) + bias
+    Nx.argmax(logits, axis: 1)
   end

   @doc """
@@ -238,14 +268,14 @@
       iex> y = Nx.tensor([1, 0, 1])
       iex> model = Scholar.Linear.LogisticRegression.fit(x, y, num_classes: 2)
       iex> Scholar.Linear.LogisticRegression.predict_probability(model, Nx.tensor([[-3.0, 5.0]]))
-      #Nx.Tensor<
-        f32[1][2]
-        [
-          [6.470913388456623e-11, 1.0]
-        ]
-      >
+      Nx.tensor([[0.10075931251049042, 0.8992406725883484]])
   """
   defn predict_probability(%__MODULE__{coefficients: coeff, bias: bias} = _model, x) do
-    softmax(Nx.dot(x, [1], coeff, [0]) + bias)
+    if Nx.rank(x) != 2 do
+      raise ArgumentError,
+            "expected x to have shape {num_samples, num_features}, got tensor with shape: #{inspect(Nx.shape(x))}"
+    end
+
+    softmax(Nx.dot(x, coeff) + bias)
   end
 end
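
Reviewer note (not part of the patch): the heart of this change is the swap from Polaris optimizers to plain gradient descent with Armijo backtracking. The rule `armijo_line_search/9` encodes is: start from a unit step and keep multiplying by `rho` until the sufficient-decrease condition `f(w + eta * d) <= f(w) + c * eta * slope` holds, where `slope` is the directional derivative along `d`. A minimal sketch of the same rule on the scalar function f(w) = w^2, in plain Elixir (the names `f`, `w`, `dir` are illustrative only):

```elixir
f = fn w -> w * w end

w = 3.0
grad = 2 * w          # f'(3) = 6
dir = -grad           # steepest-descent direction
slope = dir * grad    # directional derivative, here -36
{c, rho} = {1.0e-4, 0.5}

eta =
  Enum.reduce_while(1..30, 1.0, fn _, eta ->
    # Sufficient decrease: f(w + eta * dir) <= f(w) + c * eta * slope
    if f.(w + eta * dir) <= f.(w) + c * eta * slope do
      {:halt, eta}
    else
      {:cont, eta * rho}
    end
  end)

# eta == 0.5, and w + eta * dir == 0.0 is the exact minimizer of f
```

Unlike this sketch, the `defnp` version also carries the `eta_min` floor in its `while` condition, so the loop terminates (taking the step with a tiny `eta`) even when sufficient decrease is never observed.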
diff --git a/lib/scholar/model_selection.ex b/lib/scholar/model_selection.ex
index 6303971e..20ce653f 100644
--- a/lib/scholar/model_selection.ex
+++ b/lib/scholar/model_selection.ex
@@ -178,8 +178,8 @@ defmodule Scholar.ModelSelection do
       iex> y = Nx.tensor([0, 1, 2, 0, 1, 1, 0])
       iex> opts = [
       ...>   num_classes: [3],
-      ...>   iterations: [10, 20, 50],
-      ...>   optimizer: [Polaris.Optimizers.adam(learning_rate: 0.005), Polaris.Optimizers.adam(learning_rate: 0.01)],
+      ...>   max_iterations: [10, 20, 50],
+      ...>   alpha: [0.0, 0.1, 1.0],
       ...> ]
       iex> Scholar.ModelSelection.grid_search(x, y, folding_fun, scoring_fun, opts)
   """
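
Reviewer note (not part of the patch): the doctest above elides `folding_fun` and `scoring_fun`, which are defined earlier in the `Scholar.ModelSelection` moduledoc. For trying the new hyperparameter grid locally, a plausible pairing is sketched below; the 3-arity `scoring_fun` shape and `k_fold_split/2` mirror the existing `grid_search` docs, but treat the exact details as assumptions:

```elixir
x = Nx.iota({7, 2})
y = Nx.tensor([0, 1, 2, 0, 1, 1, 0])

# Split the samples into 3 folds
folding_fun = fn x -> Scholar.ModelSelection.k_fold_split(x, 3) end

# Fit on the train split with the given hyperparameters, score on the test split
scoring_fun = fn x, y, hyperparams ->
  {x_train, x_test} = x
  {y_train, y_test} = y
  model = Scholar.Linear.LogisticRegression.fit(x_train, y_train, hyperparams)
  y_pred = Scholar.Linear.LogisticRegression.predict(model, x_test)
  [Scholar.Metrics.Classification.accuracy(y_test, y_pred)]
end

Scholar.ModelSelection.grid_search(x, y, folding_fun, scoring_fun,
  num_classes: [3],
  max_iterations: [10, 20, 50],
  alpha: [0.0, 0.1, 1.0]
)
```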
diff --git a/test/scholar/linear/logistic_regression_test.exs b/test/scholar/linear/logistic_regression_test.exs
index 8fc2d374..3248f0cd 100644
--- a/test/scholar/linear/logistic_regression_test.exs
+++ b/test/scholar/linear/logistic_regression_test.exs
@@ -6,11 +6,11 @@ defmodule Scholar.Linear.LogisticRegressionTest do
   test "Iris Data Set - multinomial logistic regression test" do
     {x_train, x_test, y_train, y_test} = iris_data()

-    model = LogisticRegression.fit(x_train, y_train, num_classes: 3)
+    model = LogisticRegression.fit(x_train, y_train, num_classes: 3, alpha: 0.0)
     res = LogisticRegression.predict(model, x_test)
     accuracy = Scholar.Metrics.Classification.accuracy(res, y_test)

-    assert Nx.greater_equal(accuracy, 0.96) == Nx.u8(1)
+    assert Nx.to_number(accuracy) >= 0.96
   end

   describe "errors" do
@@ -40,28 +40,14 @@
       fn -> LogisticRegression.fit(x, y) end
     end

-    test "when :optimizer is invalid" do
+    test "when :max_iterations is not a positive integer" do
       x = Nx.tensor([[1, 2], [3, 4]])
       y = Nx.tensor([1, 2])

       assert_raise NimbleOptions.ValidationError,
-                   "invalid value for :optimizer option: expected :optimizer to be either a valid 0-arity function in Polaris.Optimizers or a valid {init_fn, update_fn} tuple",
+                   "invalid value for :max_iterations option: expected positive integer, got: 0",
                    fn ->
-                     LogisticRegression.fit(x, y,
-                       num_classes: 2,
-                       optimizer: :invalid_optimizer
-                     )
-                   end
-    end
-
-    test "when :iterations is not a positive integer" do
-      x = Nx.tensor([[1, 2], [3, 4]])
-      y = Nx.tensor([1, 2])
-
-      assert_raise NimbleOptions.ValidationError,
-                   "invalid value for :iterations option: expected positive integer, got: 0",
-                   fn ->
-                     LogisticRegression.fit(x, y, num_classes: 2, iterations: 0)
+                     LogisticRegression.fit(x, y, num_classes: 2, max_iterations: 0)
                    end
     end
@@ -70,7 +56,7 @@
       y = Nx.tensor([1, 2])

       assert_raise ArgumentError,
-                   "expected x to have shape {n_samples, n_features}, got tensor with shape: {2}",
+                   "expected x to have shape {num_samples, num_features}, got tensor with shape: {2}",
                    fn -> LogisticRegression.fit(x, y, num_classes: 2) end
     end

@@ -79,22 +65,41 @@
       y = Nx.tensor([[0, 1], [1, 0]])

       assert_raise ArgumentError,
-                   "Scholar.Linear.LogisticRegression expected y to have shape {n_samples}, got tensor with shape: {2, 2}",
+                   """
+                   expected y to have shape {num_samples}, \
+                   got tensor with shape: {2, 2}\
+                   """,
                    fn -> LogisticRegression.fit(x, y, num_classes: 2) end
     end
   end

-  describe "column target tests" do
-    @tag :wip
-    test "column target" do
-      {x_train, _, y_train, _} = iris_data()
-
-      model = LogisticRegression.fit(x_train, y_train, num_classes: 3)
-      pred = LogisticRegression.predict(model, x_train)
-      col_model = LogisticRegression.fit(x_train, y_train |> Nx.new_axis(-1), num_classes: 3)
-      col_pred = LogisticRegression.predict(col_model, x_train)
-      assert model == col_model
-      assert pred == col_pred
+  describe "linearly separable data" do
+    test "1D" do
+      key = Nx.Random.key(12)
+      {x1, key} = Nx.Random.uniform(key, -2, -1, shape: {1000, 1})
+      {x2, _key} = Nx.Random.uniform(key, 1, 2, shape: {1000, 1})
+      x = Nx.concatenate([x1, x2])
+      y1 = Nx.broadcast(0, {1000})
+      y2 = Nx.broadcast(1, {1000})
+      y = Nx.concatenate([y1, y2])
+      model = LogisticRegression.fit(x, y, num_classes: 2)
+      y_pred = LogisticRegression.predict(model, x)
+      accuracy = Scholar.Metrics.Classification.accuracy(y, y_pred)
+      assert Nx.to_number(accuracy) == 1.0
+    end
+
+    test "2D" do
+      key = Nx.Random.key(12)
+      {x1, key} = Nx.Random.uniform(key, -2, -1, shape: {1000, 2})
+      {x2, _key} = Nx.Random.uniform(key, 1, 2, shape: {1000, 2})
+      x = Nx.concatenate([x1, x2])
+      y1 = Nx.broadcast(0, {1000})
+      y2 = Nx.broadcast(1, {1000})
+      y = Nx.concatenate([y1, y2])
+      model = LogisticRegression.fit(x, y, num_classes: 2)
+      y_pred = LogisticRegression.predict(model, x)
+      accuracy = Scholar.Metrics.Classification.accuracy(y, y_pred)
+      assert Nx.to_number(accuracy) == 1.0
     end
   end
 end
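
Reviewer note (not part of the patch): since `fit_n` now uses hand-derived gradients instead of `value_and_grad/2`, a worthwhile follow-up test is to cross-check them against autodiff of the same objective. A sketch of such a test follows; the module and test names are made up, and the bias gradient is omitted for brevity:

```elixir
defmodule Scholar.Linear.LogisticRegressionGradTest do
  use ExUnit.Case, async: true
  import Nx.Defn

  # Same objective as compute_loss/5: mean cross-entropy plus alpha * ||w||^2
  defn loss(w, x, y_one_hot, alpha) do
    logits = Nx.dot(x, w)
    shifted = logits - Nx.reduce_max(logits, axes: [1], keep_axes: true)
    log_probs = shifted - Nx.log(Nx.sum(Nx.exp(shifted), axes: [1], keep_axes: true))
    -Nx.mean(Nx.sum(y_one_hot * log_probs, axes: [1])) + alpha * Nx.sum(w * w)
  end

  # Hand-derived gradient used in fit_n: x^T(p - y) / n + 2 * alpha * w
  defn manual_grad(w, x, y_one_hot, alpha) do
    logits = Nx.dot(x, w)
    exp = Nx.exp(logits - Nx.reduce_max(logits, axes: [1], keep_axes: true))
    probs = exp / Nx.sum(exp, axes: [1], keep_axes: true)
    Nx.dot(x, [0], probs - y_one_hot, [0]) / Nx.axis_size(x, 0) + 2 * alpha * w
  end

  defn autodiff_grad(w, x, y_one_hot, alpha) do
    grad(w, fn w -> loss(w, x, y_one_hot, alpha) end)
  end

  test "analytic weight gradient agrees with autodiff" do
    x = Nx.tensor([[1.0, 2.0], [3.0, 2.0], [4.0, 7.0]])
    y_one_hot = Nx.tensor([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
    w = Nx.tensor([[0.1, -0.2], [0.3, 0.0]])
    alpha = Nx.tensor(0.5)

    manual = manual_grad(w, x, y_one_hot, alpha)
    auto = autodiff_grad(w, x, y_one_hot, alpha)
    assert Nx.to_number(Nx.all_close(manual, auto, atol: 1.0e-6)) == 1
  end
end
```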