Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
[pytest]
addopts = -s -v --durations=10
filterwarnings =
error
ignore::DeprecationWarning
Comment on lines +3 to +5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't familiar with these options, looks useful!

39 changes: 29 additions & 10 deletions pyscf_ipu/nanoDFT/nanoDFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,13 @@ def get_JK(density_matrix, ERI, dense_ERI, backend):

return diff_JK

def _nanoDFT(state, ERI, grid_AO, grid_weights, opts, mol):
def _nanoDFT(state, ERI, grid_AO, grid_weights, profile_performance, opts, mol):

if profile_performance is not None and opts.backend == "ipu":
print("[INFO] Running nanoDFT with performance profiling.")
grid_weights, start = utils.get_ipu_cycles(grid_weights)


# Utilize the IPU's MIMD parallelism to compute the electron repulsion integrals (ERIs) in parallel.
#if opts.backend == "ipu": state.ERI = electron_repulsion_integrals(state.input_floats, state.input_ints, mol, opts.threads_int, opts.intv)
#else: pass # Compute on CPU.
Expand All @@ -156,6 +162,10 @@ def _nanoDFT(state, ERI, grid_AO, grid_weights, opts, mol):
# Perform DFT iterations.
log = jax.lax.fori_loop(0, opts.its, partial(nanoDFT_iteration, opts=opts, mol=mol), [state.density_matrix, V_xc, diff_JK, state.O, H_core, state.L_inv, # all (N, N) matrices
state.E_nuc, state.mask, ERI, grid_weights, grid_AO, state.diis_history, log])[-1]
if profile_performance is not None and opts.backend == "ipu":
log["energy"], end = utils.get_ipu_cycles(log["energy"])

return log["matrices"], H_core, log["energy"], (start.array, end.array)

return log["matrices"], H_core, log["energy"]

Expand Down Expand Up @@ -256,7 +266,7 @@ def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):

return state, n_electrons_half, E_nuc, N, L_inv, grid_weights, grid_coords, grid_AO

def nanoDFT(mol, opts):
def nanoDFT(mol, opts, profile_performance=None):
# Init DFT tensors on CPU using PySCF.
state, n_electrons_half, E_nuc, N, L_inv, _grid_weights, grid_coords, grid_AO = init_dft_tensors_cpu(mol, opts)

Expand Down Expand Up @@ -314,17 +324,22 @@ def nanoDFT(mol, opts):

ERI = [nonzero_distinct_ERI, nonzero_indices]
eri_in_axes = [0,0]
#jitted_nanoDFT = jax.jit(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend)
# jitted_nanoDFT = jax.jit(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend)
jitted_nanoDFT = jax.pmap(partial(_nanoDFT, opts=opts, mol=mol), backend=opts.backend,
in_axes=(None, eri_in_axes, 0, 0),
in_axes=(None, eri_in_axes, 0, 0, None),
axis_name="p")
print(grid_AO.shape, grid_weights.shape)
vals = jitted_nanoDFT(state, ERI, grid_AO, grid_weights)
logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU

vals = jitted_nanoDFT(state, ERI, grid_AO, grid_weights, profile_performance)

if profile_performance is not None and opts.backend == "ipu":
logged_matrices, H_core, logged_energies, _ = [np.asarray(a[0]).astype(np.float64) for a in vals]
ipu_cycles_stamps = vals[3]
else:
logged_matrices, H_core, logged_energies = [np.asarray(a[0]).astype(np.float64) for a in vals] # Ensure CPU

# It's cheap to compute energy/hlgap on CPU in float64 from the logged values/matrices.
logged_E_xc = logged_energies[:, 3].copy()
print(logged_energies[:, 0] * HARTREE_TO_EV)
# print(logged_energies[:, 0] * HARTREE_TO_EV)
density_matrices, diff_JKs, H = [logged_matrices[:, i] for i in range(3)]
energies, hlgaps = np.zeros((opts.its, 5)), np.zeros(opts.its)
for i in range(opts.its):
Expand All @@ -333,6 +348,10 @@ def nanoDFT(mol, opts):
energies, logged_energies, hlgaps = [a * HARTREE_TO_EV for a in [energies, logged_energies, hlgaps]]
mo_energy, mo_coeff = np.linalg.eigh(L_inv @ H[-1] @ L_inv.T)
mo_coeff = L_inv.T @ mo_coeff

if profile_performance is not None and opts.backend == "ipu":
return energies, (logged_energies, hlgaps, mo_energy, mo_coeff, grid_coords, _grid_weights), ipu_cycles_stamps

return energies, (logged_energies, hlgaps, mo_energy, mo_coeff, grid_coords, _grid_weights)

def DIIS(i, H, density_matrix, O, diis_history, opts):
Expand Down Expand Up @@ -610,8 +629,8 @@ def nanoDFT_options(

from pyscf_ipu.experimental.device import has_ipu
import os
if has_ipu() and "JAX_IPU_USE_MODEL" in os.environ:
args.dense_ERI = True
# if has_ipu() and "JAX_IPU_USE_MODEL" in os.environ:
# args.dense_ERI = True
args = namedtuple('DFTOptionsImmutable',vars(args).keys())(**vars(args)) # make immutable
if not args.float32:
jax.config.update('jax_enable_x64', not float32)
Expand Down
15 changes: 15 additions & 0 deletions pyscf_ipu/nanoDFT/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,18 @@ def prepare(val):
writer.append_data(imageio.v2.imread(f'{images_subdir}num_error{i}.jpg'))
writer.close()
print("Numerical error visualisation saved in", gif_path)


from tessellate_ipu import tile_map, ipu_cycle_count, tile_put_sharded
from typing import List

def get_ipu_cycles(data_to_be_sharded: List[float], num_items_to_be_sharded: int = 1) -> tuple:
    """Pass the leading entries of an array through the IPU cycle counter.

    The first ``num_items_to_be_sharded`` entries of ``data_to_be_sharded`` are
    sharded across tiles (entry ``i`` goes to tile ``i``), handed to
    ``ipu_cycle_count``, and then written back into the same positions of the
    input array.

    Args:
        data_to_be_sharded: Array whose leading entries are routed through the
            cycle counter. Despite the ``List[float]`` annotation it must
            support ``.at[idx].set(...)`` updates and slicing.
            NOTE(review): presumably a JAX array -- confirm against callers.
        num_items_to_be_sharded: How many leading entries to shard; one tile is
            used per entry.

    Returns:
        A ``(data, cycles_count)`` pair: the input array with its leading
        entries replaced by the values returned from ``ipu_cycle_count``, and
        the cycle-count object returned by ``ipu_cycle_count``.
    """
    leading = data_to_be_sharded[0:num_items_to_be_sharded]
    # One tile per sharded entry: tile index == entry index.
    tiles = tuple(range(len(leading)))
    sharded = tile_put_sharded(leading, tiles)
    sharded, cycles_count = ipu_cycle_count(sharded)
    stamped = sharded.array
    # Write the values that went through the counter back into the input.
    # NOTE(review): presumably this keeps the counter op tied to the data that
    # flows into the rest of the computation -- confirm.
    for idx in tiles:
        data_to_be_sharded = data_to_be_sharded.at[idx].set(stamped[idx])

    return data_to_be_sharded, cycles_count
77 changes: 77 additions & 0 deletions test/test_benchmark_performance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from subprocess import call, Popen
import os
import pytest
import numpy as np
import jax

from pyscf_ipu.nanoDFT.nanoDFT import nanoDFT_options, nanoDFT, build_mol

from tessellate_ipu import tile_map, tile_put_replicated, tile_put_sharded
from tessellate_ipu import ipu_cycle_count


# def test_basic_demonstration():
# dummy = np.random.rand(2,3).astype(np.float32)
# dummier = np.random.rand(2,3).astype(np.float32)

# @jax.jit
# def jitted_inner_test(dummy, dummier):
# tiles = tuple(range(len(dummy)))
# dummy = tile_put_sharded(dummy, tiles)
# tiles = tuple(range(len(dummier)))
# dummier = tile_put_sharded(dummier, tiles)

# dummy, dummier, start = ipu_cycle_count(dummy, dummier)
# out = tile_map(jax.lax.add_p, dummy, dummier)
# out, end = ipu_cycle_count(out)

# return out, start, end

# _, start, end = jitted_inner_test(dummy, dummier)
# print("Start cycle count:", start, start.shape)
# print("End cycle count:", end, end.shape)
# print("Diff cycle count:", end.array - start.array)

# assert True

# Cycles per second assumed when converting a cycle count into the
# "Estimated time of execution on Bow-IPU" figure printed below.
_BOW_IPU_CYCLES_PER_SECOND = 1.85*1e9


def _profile_nanoDFT_and_report(opts, mol_str):
    """Run nanoDFT with profiling enabled and print the elapsed IPU cycles.

    Builds the molecule from ``mol_str`` and ``opts.basis``, runs ``nanoDFT``
    with ``profile_performance=True``, and prints the difference between the
    end and start cycle stamps, both raw and converted to seconds.
    """
    mol = build_mol(mol_str, opts.basis)

    _, _, ipu_cycles_stamps = nanoDFT(mol, opts, profile_performance=True)

    start, end = ipu_cycles_stamps
    start = np.asarray(start)
    end = np.asarray(end)

    # Only the first element of the stamp difference is reported.
    diff = (end - start)[0][0][0]
    print("----------------------------------------------------------------------------")
    print(" Diff cycle count:", diff)
    print(" Diff cycle count [M]:", diff/1e6)
    print("Estimated time of execution on Bow-IPU [seconds]:", diff/_BOW_IPU_CYCLES_PER_SECOND)
    print("----------------------------------------------------------------------------")


@pytest.mark.parametrize("molecule", ["methane", "benzene"])
def test_dense_eri(molecule):
    """Profile nanoDFT on the IPU backend without overriding the ERI options.

    This is a benchmark: it makes no assertion and passes when the run
    completes without raising.
    """
    opts, mol_str = nanoDFT_options(float32 = True, mol_str=molecule, backend="ipu")
    _profile_nanoDFT_and_report(opts, mol_str)


@pytest.mark.parametrize("molecule", ["methane", "benzene", "c20"])
def test_sparse_eri(molecule):
    """Profile nanoDFT on the IPU backend with ``dense_ERI=False``.

    Uses ``eri_threshold=1e-9``. This is a benchmark: it makes no assertion
    and passes when the run completes without raising.
    """
    opts, mol_str = nanoDFT_options(float32 = True, mol_str=molecule, backend="ipu", dense_ERI=False, eri_threshold=1e-9)
    _profile_nanoDFT_and_report(opts, mol_str)