278 changes: 154 additions & 124 deletions lib/scholar/linear/logistic_regression.ex
@@ -1,12 +1,11 @@
defmodule Scholar.Linear.LogisticRegression do
@moduledoc """
Logistic regression in both binary and multinomial variants.
Multiclass logistic regression.

Time complexity is $O(N * K * I)$ where $N$ is the number of samples, $K$ is the number of features, and $I$ is the number of iterations.
"""
import Nx.Defn
import Scholar.Shared
alias Scholar.Linear.LinearHelpers

@derive {Nx.Container, containers: [:coefficients, :bias]}
defstruct [:coefficients, :bias]
@@ -15,35 +14,28 @@ defmodule Scholar.Linear.LogisticRegression do
num_classes: [
required: true,
type: :pos_integer,
doc: "number of classes contained in the input tensors."
doc: "Number of output classes."
],
iterations: [
max_iterations: [
type: :pos_integer,
default: 1000,
doc: """
number of iterations of gradient descent performed inside logistic
regression.
"""
doc: "Maximum number of gradient descent iterations to perform."
],
learning_loop_unroll: [
type: :boolean,
default: false,
doc: ~S"""
If `true`, the learning loop is unrolled.
alpha: [
type: {:custom, Scholar.Options, :non_negative_number, []},
default: 1.0,
doc: """
Constant that multiplies the L2 regularization term, controlling regularization strength.
If 0, no regularization is applied.
"""
],
optimizer: [
type: {:custom, Scholar.Options, :optimizer, []},
default: :sgd,
tol: [
type: {:custom, Scholar.Options, :non_negative_number, []},
default: 1.0e-4,
doc: """
The optimizer name or {init, update} pair of functions (see `Polaris.Optimizers` for more details).
Convergence tolerance. If the infinity norm of the gradient is less than `:tol`,
the algorithm is considered to have converged.
"""
],
eps: [
type: :float,
default: 1.0e-8,
doc:
"The convergence tolerance. If the `abs(loss) < size(x) * :eps`, the algorithm is considered to have converged."
]
]

@@ -53,9 +45,6 @@ defmodule Scholar.Linear.LogisticRegression do
Fits a logistic regression model for sample inputs `x` and sample
targets `y`.

Depending on number of classes the function chooses either binary
or multinomial logistic regression.

## Options

#{NimbleOptions.docs(@opts_schema)}
@@ -68,10 +57,6 @@ defmodule Scholar.Linear.LogisticRegression do

* `:bias` - Bias added to the decision function.

* `:mode` - Indicates whether the problem is binary classification (`:num_classes` set to 2)
or multinomial (`:num_classes` is bigger than 2). For binary classification set to `:binary`, otherwise
set to `:multinomial`.

## Examples

iex> x = Nx.tensor([[1.0, 2.0], [3.0, 2.0], [4.0, 7.0]])
@@ -80,134 +65,177 @@ defmodule Scholar.Linear.LogisticRegression do
%Scholar.Linear.LogisticRegression{
coefficients: Nx.tensor(
[
[2.5531527996063232, -0.5531544089317322],
[-0.35652396082878113, 2.3565237522125244]
[0.0915902629494667, -0.09159023314714432],
[-0.1507941037416458, 0.1507941335439682]
]
),
bias: Nx.tensor(
[-0.28847914934158325, 0.28847917914390564]
)
bias: Nx.tensor([-0.06566660106182098, 0.06566664576530457])
}
"""
deftransform fit(x, y, opts \\ []) do
if Nx.rank(x) != 2 do
raise ArgumentError,
"expected x to have shape {n_samples, n_features}, got tensor with shape: #{inspect(Nx.shape(x))}"
"expected x to have shape {num_samples, num_features}, got tensor with shape: #{inspect(Nx.shape(x))}"
end

{n_samples, _} = Nx.shape(x)
y = LinearHelpers.validate_y_shape(y, n_samples, __MODULE__)

opts = NimbleOptions.validate!(opts, @opts_schema)

{optimizer, opts} = Keyword.pop!(opts, :optimizer)

{optimizer_init_fn, optimizer_update_fn} =
case optimizer do
atom when is_atom(atom) -> apply(Polaris.Optimizers, atom, [])
{f1, f2} -> {f1, f2}
end
if Nx.rank(y) != 1 do
raise ArgumentError,
"expected y to have shape {num_samples}, got tensor with shape: #{inspect(Nx.shape(y))}"
end

n = Nx.axis_size(x, -1)
num_classes = opts[:num_classes]
num_samples = Nx.axis_size(x, 0)

coef =
Nx.broadcast(
Nx.tensor(1.0, type: to_float_type(x)),
{n, num_classes}
)
if Nx.axis_size(y, 0) != num_samples do
raise ArgumentError,
"expected x and y to have the same number of samples, got #{num_samples} and #{Nx.axis_size(y, 0)}"
end

bias = Nx.broadcast(Nx.tensor(0, type: to_float_type(x)), {num_classes})
opts = NimbleOptions.validate!(opts, @opts_schema)

coef_optimizer_state = optimizer_init_fn.(coef) |> as_type(to_float_type(x))
bias_optimizer_state = optimizer_init_fn.(bias) |> as_type(to_float_type(x))
type = to_float_type(x)

opts = Keyword.put(opts, :optimizer_update_fn, optimizer_update_fn)
{alpha, opts} = Keyword.pop!(opts, :alpha)
alpha = Nx.tensor(alpha, type: type)
{tol, opts} = Keyword.pop!(opts, :tol)
tol = Nx.tensor(tol, type: type)

fit_n(x, y, coef, bias, coef_optimizer_state, bias_optimizer_state, opts)
fit_n(x, y, alpha, tol, opts)
end

deftransformp as_type(container, target_type) do
Nx.Defn.Composite.traverse(container, fn t ->
type = Nx.type(t)
defnp fit_n(x, y, alpha, tol, opts) do
num_classes = opts[:num_classes]
max_iterations = opts[:max_iterations]
{num_samples, num_features} = Nx.shape(x)

if Nx.Type.float?(type) and not Nx.Type.complex?(type) do
Nx.as_type(t, target_type)
else
t
end
end)
end
type = to_float_type(x)

# Logistic Regression training loop
# Initialize weights and bias with zeros
w =
Nx.broadcast(
Nx.tensor(0.0, type: type),
{num_features, num_classes}
)

defnp fit_n(x, y, coef, bias, coef_optimizer_state, bias_optimizer_state, opts) do
num_samples = Nx.axis_size(x, 0)
iterations = opts[:iterations]
num_classes = opts[:num_classes]
optimizer_update_fn = opts[:optimizer_update_fn]
b = Nx.broadcast(Nx.tensor(0.0, type: type), {num_classes})

# One-hot encoding of target labels
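# For example, y = [0, 2, 1] with num_classes = 3 becomes
# [[1, 0, 0], [0, 0, 1], [0, 1, 0]].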
y_one_hot =
y
|> Nx.new_axis(1)
|> Nx.broadcast({num_samples, num_classes})
|> Nx.equal(Nx.iota({num_samples, num_classes}, axis: 1))

{{final_coef, final_bias}, _} =
while {{coef, bias},
{x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state,
has_converged = Nx.u8(0), iter = 0}},
iter < iterations and not has_converged do
{loss, {coef_grad, bias_grad}} = loss_and_grad(coef, bias, x, y_one_hot)

{coef_updates, coef_optimizer_state} =
optimizer_update_fn.(coef_grad, coef_optimizer_state, coef)

coef = Polaris.Updates.apply_updates(coef, coef_updates)

{bias_updates, bias_optimizer_state} =
optimizer_update_fn.(bias_grad, bias_optimizer_state, bias)
# Define Armijo parameters
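# c is the sufficient-decrease constant, rho the backtracking factor.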
c = Nx.tensor(1.0e-4, type: type)
rho = Nx.tensor(0.5, type: type)

bias = Polaris.Updates.apply_updates(bias, bias_updates)
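# Smallest step size the backtracking search will try, scaled to the
# working float precision so the loop terminates before eta underflows.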
eta_min =
case type do
{:f, 32} -> Nx.tensor(1.0e-6, type: type)
{:f, 64} -> Nx.tensor(1.0e-8, type: type)
_ -> Nx.tensor(1.0e-6, type: type)
end

has_converged = Nx.sum(Nx.abs(loss)) < Nx.size(x) * opts[:eps]
armijo_params = %{
c: c,
rho: rho,
eta_min: eta_min
}

{{coef, bias},
{x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state, has_converged,
iter + 1}}
{coef, bias, _} =
while {w, b,
{alpha, x, y_one_hot, tol, armijo_params, iter = Nx.u32(0), converged? = Nx.u8(0)}},
iter < max_iterations and not converged? do
logits = Nx.dot(x, w) + b
probabilities = softmax(logits)
residuals = probabilities - y_one_hot

# Compute loss
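# Mean cross-entropy plus the L2 penalty:
# loss = -mean(sum(y_one_hot * log_softmax(logits), axis 1)) + alpha * sum(w * w)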
loss =
logits
|> log_softmax()
|> Nx.multiply(y_one_hot)
|> Nx.sum(axes: [1])
|> Nx.mean()
|> Nx.negate()
|> Nx.add(alpha * Nx.sum(w * w))

# Compute gradients
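# For softmax cross-entropy the gradient w.r.t. the logits is
# (probabilities - y_one_hot), so grad_w is x^T residuals / num_samples
# plus the derivative of the L2 term, 2 * alpha * w.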
grad_w = Nx.dot(x, [0], residuals, [0]) / num_samples + 2 * alpha * w
grad_b = Nx.sum(residuals, axes: [0]) / num_samples

# Perform line search to find step size
eta =
armijo_line_search(w, b, alpha, x, y_one_hot, loss, grad_w, grad_b, armijo_params)

w = w - eta * grad_w
b = b - eta * grad_b

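# Converged once the infinity norm of both gradients drops below tol.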
converged? =
Nx.reduce_max(Nx.abs(grad_w)) < tol and Nx.reduce_max(Nx.abs(grad_b)) < tol

{w, b, {alpha, x, y_one_hot, tol, armijo_params, iter + 1, converged?}}
end

%__MODULE__{
coefficients: final_coef,
bias: final_bias
coefficients: coef,
bias: bias
}
end

defnp loss_and_grad(coeff, bias, xs, ys) do
value_and_grad({coeff, bias}, fn {coeff, bias} ->
-Nx.sum(ys * log_softmax(Nx.dot(xs, coeff) + bias), axes: [-1])
end)
defnp armijo_line_search(w, b, alpha, x, y, loss, grad_w, grad_b, armijo_params) do
c = armijo_params[:c]
rho = armijo_params[:rho]
eta_min = armijo_params[:eta_min]

type = to_float_type(x)
dir_w = -grad_w
dir_b = -grad_b
# Directional derivative
slope = Nx.sum(dir_w * grad_w) + Nx.sum(dir_b * grad_b)
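
# Backtrack eta <- rho * eta until the Armijo sufficient-decrease condition
# compute_loss(w + eta * dir_w, b + eta * dir_b, ...) <= loss + c * eta * slope
# holds, or eta falls to eta_min.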

{eta, _} =
while {eta = Nx.tensor(1.0, type: type),
{w, b, alpha, x, y, loss, dir_w, dir_b, slope, c, rho, eta_min}},
compute_loss(w + eta * dir_w, b + eta * dir_b, alpha, x, y) > loss + c * eta * slope and
eta > eta_min do
eta = eta * rho

{eta, {w, b, alpha, x, y, loss, dir_w, dir_b, slope, c, rho, eta_min}}
end

eta
end

defnp compute_loss(w, b, alpha, x, y) do
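# Regularized negative log-likelihood at a candidate (w, b); used by the
# line search above to test the Armijo condition.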
x
|> Nx.dot(w)
|> Nx.add(b)
|> log_softmax()
|> Nx.multiply(y)
|> Nx.sum(axes: [1])
|> Nx.mean()
|> Nx.negate()
|> Nx.add(alpha * Nx.sum(w * w))
end

defnp softmax(logits) do
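# Subtracting the row max before exponentiating avoids overflow and does
# not change the result, since softmax is shift-invariant.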
max = stop_grad(Nx.reduce_max(logits, axes: [1], keep_axes: true))
normalized_exp = (logits - max) |> Nx.exp()
normalized_exp / Nx.sum(normalized_exp, axes: [1], keep_axes: true)
end

defnp log_softmax(x) do
shifted = x - stop_grad(Nx.reduce_max(x, axes: [-1], keep_axes: true))
shifted = x - stop_grad(Nx.reduce_max(x, axes: [1], keep_axes: true))

shifted
|> Nx.exp()
|> Nx.sum(axes: [-1], keep_axes: true)
|> Nx.sum(axes: [1], keep_axes: true)
|> Nx.log()
|> Nx.negate()
|> Nx.add(shifted)
end

# Normalized softmax

defnp softmax(t) do
max = stop_grad(Nx.reduce_max(t, axes: [-1], keep_axes: true))
normalized_exp = (t - max) |> Nx.exp()
normalized_exp / Nx.sum(normalized_exp, axes: [-1], keep_axes: true)
end

@doc """
Makes predictions with the given `model` on inputs `x`.

Expand All @@ -219,14 +247,16 @@ defmodule Scholar.Linear.LogisticRegression do
iex> y = Nx.tensor([1, 0, 1])
iex> model = Scholar.Linear.LogisticRegression.fit(x, y, num_classes: 2)
iex> Scholar.Linear.LogisticRegression.predict(model, Nx.tensor([[-3.0, 5.0]]))
#Nx.Tensor<
s32[1]
[1]
>
Nx.tensor([1])
"""
defn predict(%__MODULE__{coefficients: coeff, bias: bias} = _model, x) do
inter = Nx.dot(x, [1], coeff, [0]) + bias
Nx.argmax(inter, axis: 1)
if Nx.rank(x) != 2 do
raise ArgumentError,
"expected x to have shape {n_samples, n_features}, got tensor with shape: #{inspect(Nx.shape(x))}"
end

logits = Nx.dot(x, coeff) + bias
Nx.argmax(logits, axis: 1)
end

@doc """
Expand All @@ -238,14 +268,14 @@ defmodule Scholar.Linear.LogisticRegression do
iex> y = Nx.tensor([1, 0, 1])
iex> model = Scholar.Linear.LogisticRegression.fit(x, y, num_classes: 2)
iex> Scholar.Linear.LogisticRegression.predict_probability(model, Nx.tensor([[-3.0, 5.0]]))
#Nx.Tensor<
f32[1][2]
[
[6.470913388456623e-11, 1.0]
]
>
Nx.tensor([[0.10075931251049042, 0.8992406725883484]])
"""
defn predict_probability(%__MODULE__{coefficients: coeff, bias: bias} = _model, x) do
softmax(Nx.dot(x, [1], coeff, [0]) + bias)
if Nx.rank(x) != 2 do
raise ArgumentError,
"expected x to have shape {n_samples, n_features}, got tensor with shape: #{inspect(Nx.shape(x))}"
end

softmax(Nx.dot(x, coeff) + bias)
end
end
4 changes: 2 additions & 2 deletions lib/scholar/model_selection.ex
@@ -178,8 +178,8 @@ defmodule Scholar.ModelSelection do
iex> y = Nx.tensor([0, 1, 2, 0, 1, 1, 0])
iex> opts = [
...> num_classes: [3],
...> iterations: [10, 20, 50],
...> optimizer: [Polaris.Optimizers.adam(learning_rate: 0.005), Polaris.Optimizers.adam(learning_rate: 0.01)],
...> max_iterations: [10, 20, 50],
...> alpha: [0.0, 0.1, 1.0],
...> ]
iex> Scholar.ModelSelection.grid_search(x, y, folding_fun, scoring_fun, opts)
"""