Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
*.ipynb filter=nbstripout
*.ipynb diff=ipynb

examples/zundel_i-PI.ipynb filter= diff=
examples/optimized_radial_basis_functions.ipynb filter= diff=

156 changes: 147 additions & 9 deletions bindings/rascal/models/krr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,141 @@
from ..utils import BaseIO
from ..lib import compute_sparse_kernel_gradients, compute_sparse_kernel_neg_stress

import scipy
import numpy as np
import ase


class SparseGPRSolver:
    """Solver for the sparse GPR (subset-of-regressors) linear system.

    Computes the weights

        b = (KNM.T @ KNM + regularizer * KMM)^-1 @ KNM.T @ Y

    where KMM is the kernel matrix between the M sparse (inducing) points
    and KNM the kernel matrix between the N training points and the sparse
    points.

    The inversion is stabilized by a numerical jitter. When
    ``relative_jitter`` is True the jitter is interpreted as a fraction of
    the largest eigenvalue of KMM, otherwise as an absolute value.

    Parameters
    ----------
    KMM : ndarray of shape (M, M)
        Kernel matrix between the sparse points.
    regularizer : float, optional
        Regularization strength, by default 1.
    jitter : float, optional
        Numerical jitter used to stabilize the inversion, by default 0.
    solver : str, optional
        One of "RKHS", "RKHS-QR", "QR", "Normal" (see ``fit`` for the
        meaning of each), by default "RKHS".
    relative_jitter : bool, optional
        Whether ``jitter`` is scaled by the largest eigenvalue of KMM,
        by default True.
    """

    def __init__(
        self, KMM, regularizer=1, jitter=0, solver="RKHS", relative_jitter=True
    ):
        self._solver = solver
        self._nM = len(KMM)
        # populated by fit()/partial_fit()
        self._weights = None

        if solver in ("RKHS", "RKHS-QR"):
            # Eigendecomposition of KMM, sorted by decreasing eigenvalue,
            # used below to build the (whitened) RKHS feature projector.
            self._vk, self._Uk = scipy.linalg.eigh(KMM)
            self._vk = self._vk[::-1]
            self._Uk = self._Uk[:, ::-1]
        elif solver in ("QR", "Normal"):
            self._KMM = KMM
            # Largest eigenvalue of KMM, used to scale the numerical jitter.
            self._KMM_maxeva = scipy.sparse.linalg.eigsh(
                KMM, k=1, return_eigenvectors=False
            )[0]
        else:
            # single formatted message (the original passed three positional
            # args to ValueError, yielding a malformed tuple-style message)
            raise ValueError(
                "Solver "
                + str(solver)
                + " not supported. Possible values are [RKHS, RKHS-QR, QR, Normal]."
            )

        if relative_jitter:
            if self._solver in ("RKHS", "RKHS-QR"):
                self._jitter_scale = self._vk[0]
            else:  # "QR" or "Normal"
                self._jitter_scale = self._KMM_maxeva
        else:
            self._jitter_scale = 1.0
        self.set_regularizers(regularizer, jitter)

    def set_regularizers(self, regularizer=1.0, jitter=0.0):
        """Update the regularizer/jitter and recompute solver-dependent
        factorizations. Also resets the accumulators used by
        ``partial_fit``."""
        self._regularizer = regularizer
        self._jitter = jitter
        if self._solver in ("RKHS", "RKHS-QR"):
            # Keep only the eigenvectors above the jitter threshold; PKPhi
            # maps sparse-point space to whitened RKHS features
            # U_MM @ Lam_MM^(-1/2).
            self._nM = len(np.where(self._vk > self._jitter * self._jitter_scale)[0])
            self._PKPhi = self._Uk[:, : self._nM] * 1 / np.sqrt(self._vk[: self._nM])
        elif self._solver == "QR":
            # Cholesky factor V with V.T @ V = reg*KMM + jitter*I; stacking
            # V below KNM lets a single QR solve the regularized problem.
            self._VMM = scipy.linalg.cholesky(
                self._regularizer * self._KMM
                + np.eye(self._nM) * self._jitter_scale * self._jitter
            )
        self._Cov = np.zeros((self._nM, self._nM))
        self._KY = None

    def partial_fit(self, KNM, Y, accumulate_only=False):
        """Accumulate one batch of (KNM, Y) into the normal equations.

        Only the "RKHS" and "Normal" solvers support iterative
        accumulation. If ``accumulate_only`` is True, the weights are not
        recomputed after this batch.
        """
        if len(Y.shape) == 1:
            Y = Y[:, np.newaxis]
        if self._solver == "RKHS":
            Phi = KNM @ self._PKPhi
        elif self._solver == "Normal":
            Phi = KNM
        else:
            raise ValueError(
                "Partial fit can only be realized with solver = [RKHS, Normal]"
            )
        if self._KY is None:
            self._KY = np.zeros((self._nM, Y.shape[1]))

        self._Cov += Phi.T @ Phi
        self._KY += Phi.T @ Y

        if not accumulate_only:
            if self._solver == "RKHS":
                self._weights = self._PKPhi @ scipy.linalg.solve(
                    self._Cov + np.eye(self._nM) * self._regularizer,
                    self._KY,
                    assume_a="pos",
                )
            elif self._solver == "Normal":
                self._weights = scipy.linalg.solve(
                    self._Cov
                    + self._regularizer * self._KMM
                    + np.eye(self._KMM.shape[0]) * self._jitter * self._jitter_scale,
                    self._KY,
                    assume_a="pos",
                )

    def fit(self, KNM, Y):
        """Solve for the weights given the full KNM and targets Y.

        "Normal" solves the normal equations directly; "RKHS" solves them
        in the whitened RKHS feature basis; "RKHS-QR" and "QR" use a QR
        factorization of the (feature or kernel) matrix stacked on top of
        the regularization term, which is numerically more stable.
        """
        if len(Y.shape) == 1:
            Y = Y[:, np.newaxis]
        if self._solver == "RKHS":
            Phi = KNM @ self._PKPhi
            self._weights = self._PKPhi @ scipy.linalg.solve(
                Phi.T @ Phi + np.eye(self._nM) * self._regularizer,
                Phi.T @ Y,
                assume_a="pos",
            )
        elif self._solver == "RKHS-QR":
            A = np.vstack(
                [KNM @ self._PKPhi, np.sqrt(self._regularizer) * np.eye(self._nM)]
            )
            Q, R = np.linalg.qr(A)
            # BUGFIX: the right-hand side must be stacked vertically
            # (vstack) to match the N + nM rows of A; the previous hstack
            # only happened to work when N == nM.
            self._weights = self._PKPhi @ scipy.linalg.solve_triangular(
                R, Q.T @ np.vstack([Y, np.zeros((self._nM, Y.shape[1]))])
            )
        elif self._solver == "QR":
            A = np.vstack([KNM, self._VMM])
            Q, R = np.linalg.qr(A)
            self._weights = scipy.linalg.solve_triangular(
                R, Q.T @ np.vstack([Y, np.zeros((KNM.shape[1], Y.shape[1]))])
            )
        elif self._solver == "Normal":
            self._weights = scipy.linalg.solve(
                KNM.T @ KNM
                + self._regularizer * self._KMM
                + np.eye(self._nM) * self._jitter * self._jitter_scale,
                KNM.T @ Y,
                assume_a="pos",
            )

    def transform(self, KTM):
        """Predict targets for the points whose kernel with the sparse
        points is KTM, i.e. return KTM @ weights."""
        return KTM @ self._weights


class KRR(BaseIO):
"""Kernel Ridge Regression model. Only supports sparse GPR
training for the moment.
Expand Down Expand Up @@ -354,6 +485,7 @@ def train_gap_model(
X_sparse,
y_train,
self_contributions,
solver="Normal",
grad_train=None,
lambdas=None,
jitter=1e-8,
Expand Down Expand Up @@ -428,6 +560,15 @@ def train_gap_model(
jitter : double, optional
small jitter for the numerical stability of solving the linear system,
by default 1e-8
solver: string, optional
method used to solve the sparse KRR equations.
"Normal" uses a least-squares solver for the normal equations:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Normal" feels strange as a name. How about Standard or Direct or Direct Least Square?

(K_NM.T@K_NM + K_MM)^(-1) K_NM.T@Y
"RKHS" computes first the reproducing kernel features by diagonalizing K_MM
and computing P_NM = K_NM @ U_MM @ Lam_MM^(-1/2) and then solves the linear
problem for those (which is usually better conditioned)
(P_NM.T@P_NM + reg*I)^(-1) P_NM.T@Y
by default, "Normal"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document the "QR" mode please - how is it different from "RKHS"?


Returns
-------
Expand Down Expand Up @@ -463,15 +604,12 @@ def train_gap_model(
F /= lambdas[1] / delta
Y = np.vstack([Y, F])

KMM[np.diag_indices_from(KMM)] += jitter

K = KMM + np.dot(KNM.T, KNM)
Y = np.dot(KNM.T, Y)
weights = np.linalg.lstsq(K, Y, rcond=None)[0]
model = KRR(weights, kernel, X_sparse, self_contributions)
ssolver = SparseGPRSolver(
KMM, regularizer=1, jitter=jitter, solver=solver, relative_jitter=False
) # in current implementation KMM incorporates regularization so it's better to use an absolute jitter value
ssolver.fit(KNM, Y)
model = KRR(ssolver._weights, kernel, X_sparse, self_contributions)

# avoid memory clogging
del K, KMM
K, KMM = [], []
del KNM, KMM, solver

return model
Loading