Skip to content

Possible updates to PyTorch 1.10 #9

@jhualberta

Description

@jhualberta

Thank you for your lecture; it is very useful.
I know it has already been 5 years since these notes were originally released, and some PyTorch functions have since been deprecated.
The following are possible updates for PyTorch 1.10 and Python 3.11:
schrodinger.py

schrodinger.py

Replaced torch.symeig(H, eigenvectors=True) → ✅ torch.linalg.eigh(H)
#########################################################

import numpy as np 
import torch
torch.set_default_dtype(torch.float64)
import torch.nn as nn
import matplotlib.pyplot as plt

class Schrodinger1D(nn.Module):
    """Learn a 1D potential whose ground-state density matches a target density.

    The Hamiltonian is discretized on ``xmesh`` as ``H = diag(V) + K``, where
    ``K`` is the three-point finite-difference form of ``-(1/2) d^2/dx^2``.
    The potential ``V`` (``self.potential``) is the only trainable parameter
    and is initialized to ``xmesh**2``.

    Args:
        xmesh: 1D tensor of grid points. Only the first spacing
            ``xmesh[1] - xmesh[0]`` is used to build ``K``, so the mesh is
            assumed to be uniform and to contain at least two points.
    """

    def __init__(self, xmesh):
        super().__init__()

        # Registered as buffers (not plain attributes) so that
        # ``model.to(device)`` / ``model.to(dtype)`` moves them together with
        # ``potential``. ``persistent=False`` keeps the state_dict exactly as
        # before (only ``potential``), so saved checkpoints still load.
        self.register_buffer('xmesh', xmesh, persistent=False)
        self.potential = nn.Parameter(xmesh**2)

        nmesh = xmesh.shape[0]
        h2 = (xmesh[1] - xmesh[0]) ** 2
        # Allocate on xmesh's device as well as its dtype so CUDA meshes work.
        factory = dict(dtype=xmesh.dtype, device=xmesh.device)
        # Stencil for -(1/2) psi'': (2 psi_i - psi_{i-1} - psi_{i+1}) / (2 h^2)
        diag = 1 / h2 * torch.ones(nmesh, **factory)
        off_diag = 0.5 / h2 * torch.ones(nmesh - 1, **factory)
        K = torch.diag(diag, diagonal=0) \
            - torch.diag(off_diag, diagonal=1) \
            - torch.diag(off_diag, diagonal=-1)
        self.register_buffer('K', K, persistent=False)

    def _solve(self):
        """Return the ground-state eigenvector of the discretized Hamiltonian.

        ``torch.linalg.eigh`` (the replacement for the removed
        ``torch.symeig``) returns eigenvalues in ascending order together
        with orthonormal eigenvectors, so column 0 is the normalized ground
        state. Its overall sign is arbitrary; this class only uses ``psi**2``.
        """
        H = torch.diag(self.potential) + self.K
        _, eigvecs = torch.linalg.eigh(H)
        return eigvecs[:, 0]

    def forward(self, target):
        """Return the L1 distance between the ground-state density and ``target``."""
        psi = self._solve()
        return (psi**2 - target).abs().sum()

    def plot(self, target):
        """Redraw the target density, current density and scaled potential on the current axes."""
        psi = self._solve().detach()
        # Detach and move to CPU so plotting also works for GPU models/targets.
        x = self.xmesh.detach().cpu().numpy()

        plt.cla()
        plt.plot(x, target.detach().cpu().numpy(), label='Target Density')
        plt.plot(x, psi.square().cpu().numpy(), label='Current Density')
        plt.plot(x, self.potential.detach().cpu().numpy()/10000, label='Potential (V/10000)')
        plt.legend()
        plt.draw()

if __name__ == '__main__':
    # Uniform mesh, plus a triangular target density supported on |x| < 0.5
    # and normalized so that it sums to 1.
    x_lo, x_hi, n_points = -1, 1, 500
    xmesh = torch.linspace(x_lo, x_hi, n_points)

    inside = torch.abs(xmesh) < 0.5
    target = torch.zeros(n_points)
    target[inside] = 1. - torch.abs(xmesh[inside])
    target = (target / torch.norm(target))**2

    model = Schrodinger1D(xmesh)
    optimizer = torch.optim.LBFGS(
        model.parameters(),
        max_iter=10,
        tolerance_change=1E-7,
        tolerance_grad=1E-7,
        line_search_fn='strong_wolfe',
    )

    def evaluate_loss():
        """LBFGS closure: recompute the density mismatch and its gradient."""
        optimizer.zero_grad()
        mismatch = model(target)
        mismatch.backward()
        return mismatch

    # Interactive mode so the figure refreshes after every optimizer step.
    plt.ion()
    for step in range(50):
        step_loss = optimizer.step(evaluate_loss)
        print(step, step_loss.item())
        model.plot(target)
        plt.pause(0.01)

    # Switch back to blocking mode and keep the final state on screen.
    plt.ioff()
    model.plot(target)
    plt.show()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions