-
Notifications
You must be signed in to change notification settings - Fork 17
Feat/rkhs #354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ceriottm
wants to merge
15
commits into
master
Choose a base branch
from
feat/rkhs
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Feat/rkhs #354
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
1e8c0b3
Added (untested) RKHS solver
ceriottm 06acbb5
Tried and tested RKHS solver (at least on zundel example)
ceriottm ac41bd8
Mini-cleaning
ceriottm 3805fea
Better solver for the RKHS mode
ceriottm f3de495
Try to commit without stripping the zundel notebook
ceriottm ee5991a
Better comments
ceriottm faeda4a
Black
ceriottm feaa3ce
Added a "QR" mode for the example
ceriottm bdaa221
Merge branch 'master' into feat/rkhs
ceriottm ed2a4fc
Merge branch 'master' into feat/rkhs
ceriottm 076e1fb
black
ceriottm 41ce292
Added example file for water monomer
ceriottm 87a484e
A class for sparse GPR solvers that can be also called outside the tr…
ceriottm 80ccebe
Merge branch 'master' into feat/rkhs
ceriottm 4ae721a
Update .gitattributes
felixmusil File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| *.ipynb filter=nbstripout | ||
| *.ipynb diff=ipynb | ||
|
|
||
| examples/zundel_i-PI.ipynb filter= diff= | ||
| examples/optimized_radial_basis_functions.ipynb filter= diff= | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,10 +10,141 @@ | |
| from ..utils import BaseIO | ||
| from ..lib import compute_sparse_kernel_gradients, compute_sparse_kernel_neg_stress | ||
|
|
||
| import scipy | ||
| import numpy as np | ||
| import ase | ||
|
|
||
|
|
||
class SparseGPRSolver:
    """
    Solver for the sparse GPR (subset-of-regressors) problem

        w = (KNM.T @ KNM + regularizer * KMM)^-1 @ KNM.T @ y

    The inversion is stabilized by a numerical jitter that, when
    ``relative_jitter`` is True, is expressed as a fraction of the largest
    eigenvalue of KMM.

    Parameters
    ----------
    KMM : ndarray of shape (M, M)
        Kernel matrix between the sparse (active) points.
    regularizer : float, optional
        Regularization strength, by default 1.
    jitter : float, optional
        Numerical jitter; for the RKHS solvers it acts as an eigenvalue
        cutoff, for QR/Normal it is added to the diagonal. By default 0.
    solver : str, optional
        One of "RKHS", "RKHS-QR", "QR", "Normal". By default "RKHS".
    relative_jitter : bool, optional
        If True, the jitter is scaled by the largest eigenvalue of KMM.
    """

    def __init__(
        self, KMM, regularizer=1, jitter=0, solver="RKHS", relative_jitter=True
    ):
        self._solver = solver
        self._nM = len(KMM)
        if self._solver in ("RKHS", "RKHS-QR"):
            # eigendecomposition of KMM, sorted by decreasing eigenvalue,
            # used to build the reproducing-kernel features
            self._vk, self._Uk = scipy.linalg.eigh(KMM)
            self._vk = self._vk[::-1]
            self._Uk = self._Uk[:, ::-1]
        elif self._solver in ("QR", "Normal"):
            self._KMM = KMM
            # gets maximum eigenvalue of KMM to scale the numerical jitter
            self._KMM_maxeva = scipy.sparse.linalg.eigsh(
                KMM, k=1, return_eigenvectors=False
            )[0]
        else:
            # BUGFIX: was ValueError("...", solver, "..."), which renders as a
            # tuple instead of a readable message
            raise ValueError(
                f"Solver {solver} not supported. "
                "Possible values are [RKHS, RKHS-QR, QR, Normal]."
            )
        if relative_jitter:
            if self._solver in ("RKHS", "RKHS-QR"):
                self._jitter_scale = self._vk[0]
            elif self._solver in ("QR", "Normal"):
                self._jitter_scale = self._KMM_maxeva
        else:
            self._jitter_scale = 1.0
        self.set_regularizers(regularizer, jitter)

    def set_regularizers(self, regularizer=1.0, jitter=0.0):
        """Update regularizer/jitter and refresh the cached factorizations."""
        self._regularizer = regularizer
        self._jitter = jitter
        if self._solver in ("RKHS", "RKHS-QR"):
            # keep only eigenvalues above the jitter cutoff and build the
            # feature projector P = U_k @ Lambda_k^(-1/2)
            self._nM = len(np.where(self._vk > self._jitter * self._jitter_scale)[0])
            self._PKPhi = self._Uk[:, : self._nM] * 1 / np.sqrt(self._vk[: self._nM])
        elif self._solver == "QR":
            # upper-triangular Cholesky factor of the regularized KMM, stacked
            # below KNM in fit() so that A.T@A reproduces the normal equations
            self._VMM = scipy.linalg.cholesky(
                self._regularizer * self._KMM
                + np.eye(self._nM) * self._jitter_scale * self._jitter
            )
        # accumulators for partial_fit
        self._Cov = np.zeros((self._nM, self._nM))
        self._KY = None

    def partial_fit(self, KNM, Y, accumulate_only=False):
        """Accumulate the normal equations for one batch of training points.

        If ``accumulate_only`` is True, only updates the covariance and the
        projected targets; otherwise also solves for the current weights.
        Only supported for the "RKHS" and "Normal" solvers.
        """
        if len(Y.shape) == 1:
            Y = Y[:, np.newaxis]
        if self._solver == "RKHS":
            Phi = KNM @ self._PKPhi
        elif self._solver == "Normal":
            Phi = KNM
        else:
            raise ValueError(
                "Partial fit can only be realized with solver = [RKHS, Normal]"
            )
        if self._KY is None:
            self._KY = np.zeros((self._nM, Y.shape[1]))

        self._Cov += Phi.T @ Phi
        self._KY += Phi.T @ Y

        if not accumulate_only:
            if self._solver == "RKHS":
                self._weights = self._PKPhi @ scipy.linalg.solve(
                    self._Cov + np.eye(self._nM) * self._regularizer,
                    self._KY,
                    assume_a="pos",
                )
            elif self._solver == "Normal":
                self._weights = scipy.linalg.solve(
                    self._Cov
                    + self._regularizer * self._KMM
                    + np.eye(self._KMM.shape[0]) * self._jitter * self._jitter_scale,
                    self._KY,
                    assume_a="pos",
                )

    def fit(self, KNM, Y):
        """Solve the sparse GPR problem for kernel matrix KNM and targets Y."""
        if len(Y.shape) == 1:
            Y = Y[:, np.newaxis]
        if self._solver == "RKHS":
            Phi = KNM @ self._PKPhi
            self._weights = self._PKPhi @ scipy.linalg.solve(
                Phi.T @ Phi + np.eye(self._nM) * self._regularizer,
                Phi.T @ Y,
                assume_a="pos",
            )
        elif self._solver == "RKHS-QR":
            A = np.vstack(
                [KNM @ self._PKPhi, np.sqrt(self._regularizer) * np.eye(self._nM)]
            )
            Q, R = np.linalg.qr(A)
            # BUGFIX: the targets must be stacked row-wise (vstack) to match
            # the (N + nM) rows of A; the previous hstack raised (or silently
            # produced wrong shapes) whenever N != nM
            self._weights = self._PKPhi @ scipy.linalg.solve_triangular(
                R, Q.T @ np.vstack([Y, np.zeros((self._nM, Y.shape[1]))])
            )
        elif self._solver == "QR":
            A = np.vstack([KNM, self._VMM])
            Q, R = np.linalg.qr(A)
            self._weights = scipy.linalg.solve_triangular(
                R, Q.T @ np.vstack([Y, np.zeros((KNM.shape[1], Y.shape[1]))])
            )
        elif self._solver == "Normal":
            self._weights = scipy.linalg.solve(
                KNM.T @ KNM
                + self._regularizer * self._KMM
                + np.eye(self._nM) * self._jitter * self._jitter_scale,
                KNM.T @ Y,
                assume_a="pos",
            )

    def transform(self, KTM):
        """Predict targets from the kernel matrix between test and sparse points."""
        return KTM @ self._weights
|
|
||
|
|
||
| class KRR(BaseIO): | ||
| """Kernel Ridge Regression model. Only supports sparse GPR | ||
| training for the moment. | ||
|
|
@@ -354,6 +485,7 @@ def train_gap_model( | |
| X_sparse, | ||
| y_train, | ||
| self_contributions, | ||
| solver="Normal", | ||
| grad_train=None, | ||
| lambdas=None, | ||
| jitter=1e-8, | ||
|
|
@@ -428,6 +560,15 @@ def train_gap_model( | |
| jitter : double, optional | ||
| small jitter for the numerical stability of solving the linear system, | ||
| by default 1e-8 | ||
| solver: string, optional | ||
| method used to solve the sparse KRR equations. | ||
| "Normal" uses a least-squares solver for the normal equations: | ||
| (K_NM.T@K_NM + reg*K_MM)^(-1) K_NM.T@Y | ||
| "RKHS" computes first the reproducing kernel features by diagonalizing K_MM | ||
| and computing P_NM = K_NM @ U_MM @ Lam_MM^(-1/2) and then solves the linear | ||
| problem for those (which is usually better conditioned) | ||
| (P_NM.T@P_NM + reg*I)^(-1) P_NM.T@Y | ||
| by default, "Normal" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Document the "QR" mode please - how is it different from "RKHS"? |
||
|
|
||
| Returns | ||
| ------- | ||
|
|
@@ -463,15 +604,12 @@ def train_gap_model( | |
| F /= lambdas[1] / delta | ||
| Y = np.vstack([Y, F]) | ||
|
|
||
| KMM[np.diag_indices_from(KMM)] += jitter | ||
|
|
||
| K = KMM + np.dot(KNM.T, KNM) | ||
| Y = np.dot(KNM.T, Y) | ||
| weights = np.linalg.lstsq(K, Y, rcond=None)[0] | ||
| model = KRR(weights, kernel, X_sparse, self_contributions) | ||
| ssolver = SparseGPRSolver( | ||
| KMM, regularizer=1, jitter=jitter, solver=solver, relative_jitter=False | ||
| ) # in current implementation KMM incorporates regularization so it's better to use an absolute jitter value | ||
| ssolver.fit(KNM, Y) | ||
| model = KRR(ssolver._weights, kernel, X_sparse, self_contributions) | ||
|
|
||
| # avoid memory clogging | ||
| del K, KMM | ||
| K, KMM = [], [] | ||
| del KNM, KMM, solver | ||
|
|
||
| return model | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Normal" feels strange as a name. How about
"Standard", "Direct", or "Direct Least Squares"?