Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ROMFPMD with release branch #246

Merged
merged 12 commits into from
Jun 5, 2024
16 changes: 12 additions & 4 deletions src/DFTsolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ template <class OrbitalsType>
DFTsolver<OrbitalsType>::DFTsolver(Hamiltonian<OrbitalsType>* hamiltonian,
ProjectedMatricesInterface* proj_matrices, Energy<OrbitalsType>* energy,
Electrostatic* electrostat, MGmol<OrbitalsType>* mgmol_strategy, Ions& ions,
Rho<OrbitalsType>* rho, DMStrategy* dm_strategy, std::ostream& os)
Rho<OrbitalsType>* rho, DMStrategy<OrbitalsType>* dm_strategy,
std::ostream& os)
: mgmol_strategy_(mgmol_strategy),
hamiltonian_(hamiltonian),
proj_matrices_(proj_matrices),
Expand Down Expand Up @@ -316,7 +317,7 @@ int DFTsolver<OrbitalsType>::solve(OrbitalsType& orbitals,
const bool ortho
= (ct.getOrthoType() == OrthoType::Eigenfunctions || orthof);

if (!ortho)
if (!ortho || !ct.fullyOccupied())
{
// strip dm from the overlap contribution
// dm <- Ls**T * dm * Ls
Expand All @@ -337,7 +338,14 @@ int DFTsolver<OrbitalsType>::solve(OrbitalsType& orbitals,
}
else
{
orbitals.orthonormalizeLoewdin();
bool updateDM = false;
if (!ct.fullyOccupied())
{
orbitals.computeGramAndInvS();
dm_strategy_->dressDM();
updateDM = true;
}
orbitals.orthonormalizeLoewdin(true, nullptr, updateDM);

orbitals_stepper_->restartMixing();
}
Expand Down Expand Up @@ -384,7 +392,7 @@ int DFTsolver<OrbitalsType>::solve(OrbitalsType& orbitals,
mgmol_strategy_->updateHmatrix(orbitals, ions);

// compute new density matrix
dm_strategy_->update();
dm_strategy_->update(orbitals);

incInnerIt();

Expand Down
8 changes: 4 additions & 4 deletions src/DFTsolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#ifndef MGMOL_DFTSOLVER_H
#define MGMOL_DFTSOLVER_H

#include "DMStrategy.h"
#include "DielectricControl.h"
#include "Energy.h"
#include "Hamiltonian.h"
Expand All @@ -23,7 +24,6 @@
class Ions;
class Electrostatic;
class ProjectedMatricesInterface;
class DMStrategy;

template <class OrbitalsType>
class DFTsolver
Expand All @@ -40,7 +40,7 @@ class DFTsolver
Electrostatic* electrostat_;
Ions& ions_;
Rho<OrbitalsType>* rho_;
DMStrategy* dm_strategy_;
DMStrategy<OrbitalsType>* dm_strategy_;

OrbitalsStepper<OrbitalsType>* orbitals_stepper_;

Expand Down Expand Up @@ -72,8 +72,8 @@ class DFTsolver
DFTsolver(Hamiltonian<OrbitalsType>* hamiltonian,
ProjectedMatricesInterface* proj_matrices, Energy<OrbitalsType>* energy,
Electrostatic* electrostat, MGmol<OrbitalsType>* mgmol_strategy,
Ions& ions, Rho<OrbitalsType>* rho, DMStrategy* dm_strategy,
std::ostream& os);
Ions& ions, Rho<OrbitalsType>* rho,
DMStrategy<OrbitalsType>* dm_strategy, std::ostream& os);

~DFTsolver();

Expand Down
5 changes: 3 additions & 2 deletions src/DMStrategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
#ifndef DMSTRATEGY_H
#define DMSTRATEGY_H

template <class OrbitalsType>
class DMStrategy
{
public:
virtual void initialize() = 0;
virtual int update() = 0;
virtual void initialize(OrbitalsType& orbitals) = 0;
virtual int update(OrbitalsType& orbitals) = 0;

virtual ~DMStrategy(){};

Expand Down
33 changes: 18 additions & 15 deletions src/DMStrategyFactory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#include "ReplicatedMatrix.h"

template <>
DMStrategy* DMStrategyFactory<LocGridOrbitals,
DMStrategy<LocGridOrbitals>* DMStrategyFactory<LocGridOrbitals,
dist_matrix::DistMatrix<double>>::createHamiltonianMVP_DMStrategy(MPI_Comm
comm,
std::ostream& os, Ions& ions, Rho<LocGridOrbitals>* rho,
Expand All @@ -13,7 +13,7 @@ DMStrategy* DMStrategyFactory<LocGridOrbitals,
{
if (short_sighted)
{
DMStrategy* dm_strategy
DMStrategy<LocGridOrbitals>* dm_strategy
= new HamiltonianMVP_DMStrategy<VariableSizeMatrix<sparserow>,
ProjectedMatricesSparse, LocGridOrbitals>(comm, os, ions, rho,
energy, electrostat, mgmol_strategy, orbitals);
Expand All @@ -22,18 +22,19 @@ DMStrategy* DMStrategyFactory<LocGridOrbitals,
}
else
{
DMStrategy* dm_strategy = new HamiltonianMVP_DMStrategy<
dist_matrix::DistMatrix<DISTMATDTYPE>,
ProjectedMatrices<dist_matrix::DistMatrix<DISTMATDTYPE>>,
LocGridOrbitals>(
comm, os, ions, rho, energy, electrostat, mgmol_strategy, orbitals);
DMStrategy<LocGridOrbitals>* dm_strategy
= new HamiltonianMVP_DMStrategy<
dist_matrix::DistMatrix<DISTMATDTYPE>,
ProjectedMatrices<dist_matrix::DistMatrix<DISTMATDTYPE>>,
LocGridOrbitals>(comm, os, ions, rho, energy, electrostat,
mgmol_strategy, orbitals);

return dm_strategy;
}
}

template <>
DMStrategy* DMStrategyFactory<ExtendedGridOrbitals,
DMStrategy<ExtendedGridOrbitals>* DMStrategyFactory<ExtendedGridOrbitals,
dist_matrix::DistMatrix<double>>::createHamiltonianMVP_DMStrategy(MPI_Comm
comm,
std::ostream& os, Ions& ions, Rho<ExtendedGridOrbitals>* rho,
Expand All @@ -44,7 +45,7 @@ DMStrategy* DMStrategyFactory<ExtendedGridOrbitals,
{
(void)short_sighted;

DMStrategy* dm_strategy
DMStrategy<ExtendedGridOrbitals>* dm_strategy
= new HamiltonianMVP_DMStrategy<dist_matrix::DistMatrix<DISTMATDTYPE>,
ProjectedMatrices<dist_matrix::DistMatrix<DISTMATDTYPE>>,
ExtendedGridOrbitals>(
Expand All @@ -55,19 +56,21 @@ DMStrategy* DMStrategyFactory<ExtendedGridOrbitals,

#ifdef HAVE_MAGMA
template <>
DMStrategy* DMStrategyFactory<ExtendedGridOrbitals,
DMStrategy<ExtendedGridOrbitals>* DMStrategyFactory<ExtendedGridOrbitals,
ReplicatedMatrix>::createHamiltonianMVP_DMStrategy(MPI_Comm comm,
std::ostream& os, Ions& ions, Rho<ExtendedGridOrbitals>* rho,
Energy<ExtendedGridOrbitals>* energy, Electrostatic* electrostat,
MGmol<ExtendedGridOrbitals>* mgmol_strategy,
ProjectedMatricesInterface* /*proj_matrices*/,
ExtendedGridOrbitals* orbitals, const bool short_sighted)
ProjectedMatricesInterface* /*proj_matrices*/, LocGridOrbitals* orbitals,
const bool short_sighted)
{
(void)short_sighted;

DMStrategy* dm_strategy = new HamiltonianMVP_DMStrategy<ReplicatedMatrix,
ProjectedMatrices<ReplicatedMatrix>, ExtendedGridOrbitals>(
comm, os, ions, rho, energy, electrostat, mgmol_strategy, orbitals);
DMStrategy<ExtendedGridOrbitals>* dm_strategy
= new HamiltonianMVP_DMStrategy<ReplicatedMatrix,
ProjectedMatrices<ReplicatedMatrix>, ExtendedGridOrbitals>(comm, os,
ions, rho, energy, electrostat, mgmol_strategy,
orbitals->getOverlappingGids());

return dm_strategy;
}
Expand Down
27 changes: 14 additions & 13 deletions src/DMStrategyFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@ template <class OrbitalsType, class MatrixType>
class DMStrategyFactory
{
public:
static DMStrategy* create(MPI_Comm comm, std::ostream& os, Ions& ions,
Rho<OrbitalsType>* rho, Energy<OrbitalsType>* energy,
static DMStrategy<OrbitalsType>* create(MPI_Comm comm, std::ostream& os,
Ions& ions, Rho<OrbitalsType>* rho, Energy<OrbitalsType>* energy,
Electrostatic* electrostat, MGmol<OrbitalsType>* mgmol_strategy,
ProjectedMatricesInterface* proj_matrices, OrbitalsType* orbitals)
{
Control& ct = *(Control::instance());
MGmol_MPI& mmpi = *(MGmol_MPI::instance());

DMStrategy* dm_strategy = nullptr;
DMStrategy<OrbitalsType>* dm_strategy = nullptr;
if (ct.DM_solver() == DMNonLinearSolverType::MVP)
{
dm_strategy = new MVP_DMStrategy<OrbitalsType, MatrixType>(comm, os,
ions, rho, energy, electrostat, mgmol_strategy, orbitals,
proj_matrices, ct.use_old_dm());
ions, rho, energy, electrostat, mgmol_strategy,
orbitals->getOverlappingGids(), proj_matrices, ct.use_old_dm());
}
else if (ct.DM_solver() == DMNonLinearSolverType::HMVP)
{
Expand All @@ -51,17 +51,17 @@ class DMStrategyFactory
{
if (mmpi.instancePE0())
std::cout << "Fully occupied strategy" << std::endl;
dm_strategy
= new FullyOccupiedNonOrthoDMStrategy(proj_matrices);
dm_strategy = new FullyOccupiedNonOrthoDMStrategy<OrbitalsType>(
proj_matrices);
}
else
{
if (ct.getOrthoType() == OrthoType::Eigenfunctions)
{
if (mmpi.instancePE0())
std::cout << "EigenDMStrategy..." << std::endl;
dm_strategy = new EigenDMStrategy<OrbitalsType>(
orbitals, proj_matrices);
dm_strategy
= new EigenDMStrategy<OrbitalsType>(proj_matrices);
}
else
{
Expand All @@ -70,7 +70,7 @@ class DMStrategyFactory
if (mmpi.instancePE0())
std::cout << "NonOrthoDMStrategy..." << std::endl;
dm_strategy = new NonOrthoDMStrategy<OrbitalsType>(
orbitals, proj_matrices, ct.dm_mix);
proj_matrices, ct.dm_mix);
}
}
}
Expand All @@ -81,11 +81,12 @@ class DMStrategyFactory
}

private:
static DMStrategy* createHamiltonianMVP_DMStrategy(MPI_Comm comm,
std::ostream& os, Ions& ions, Rho<OrbitalsType>* rho,
static DMStrategy<OrbitalsType>* createHamiltonianMVP_DMStrategy(
MPI_Comm comm, std::ostream& os, Ions& ions, Rho<OrbitalsType>* rho,
Energy<OrbitalsType>* energy, Electrostatic* electrostat,
MGmol<OrbitalsType>* mgmol_strategy,
ProjectedMatricesInterface* proj_matrices, OrbitalsType*, const bool);
ProjectedMatricesInterface* proj_matrices, OrbitalsType* orbitals,
const bool);
};

#endif
27 changes: 13 additions & 14 deletions src/EigenDMStrategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@
#include "LocGridOrbitals.h"
#include "ProjectedMatrices.h"

template <class T>
EigenDMStrategy<T>::EigenDMStrategy(
T* current_orbitals, ProjectedMatricesInterface* proj_matrices)
: current_orbitals_(current_orbitals), proj_matrices_(proj_matrices)
template <class OrbitalsType>
EigenDMStrategy<OrbitalsType>::EigenDMStrategy(
ProjectedMatricesInterface* proj_matrices)
: proj_matrices_(proj_matrices)
{
}

template <class T>
void EigenDMStrategy<T>::initialize()
template <class OrbitalsType>
void EigenDMStrategy<OrbitalsType>::initialize(OrbitalsType& orbitals)
{
update();
update(orbitals);
}

template <class T>
int EigenDMStrategy<T>::update()
template <class OrbitalsType>
int EigenDMStrategy<OrbitalsType>::update(OrbitalsType& orbitals)
{
Control& ct = *(Control::instance());

Expand All @@ -37,14 +37,13 @@ int EigenDMStrategy<T>::update()
= dynamic_cast<
ProjectedMatrices<dist_matrix::DistMatrix<DISTMATDTYPE>>*>(
proj_matrices_);
pmat->updateDMwithEigenstatesAndRotate(
current_orbitals_->getIterativeIndex(), zz);
pmat->updateDMwithEigenstatesAndRotate(orbitals.getIterativeIndex(), zz);

// if( onpe0 && ct.verbose>2 )
// (*MPIdata::sout)<<"get_dm_diag: rotate orbitals "<<endl;
current_orbitals_->multiply_by_matrix(zz);
current_orbitals_->setDataWithGhosts();
current_orbitals_->trade_boundaries();
orbitals.multiply_by_matrix(zz);
orbitals.setDataWithGhosts();
orbitals.trade_boundaries();

return 0;
}
Expand Down
12 changes: 5 additions & 7 deletions src/EigenDMStrategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,17 @@
#include "DMStrategy.h"
class ProjectedMatricesInterface;

template <class T>
class EigenDMStrategy : public DMStrategy
template <class OrbitalsType>
class EigenDMStrategy : public DMStrategy<OrbitalsType>
{
private:
T* current_orbitals_;
ProjectedMatricesInterface* proj_matrices_;

public:
EigenDMStrategy(
T* current_orbitals, ProjectedMatricesInterface* proj_matrices);
EigenDMStrategy(ProjectedMatricesInterface* proj_matrices);

void initialize() override;
int update() override;
void initialize(OrbitalsType& orbitals) override;
int update(OrbitalsType& orbitals) override;

bool needH() const override { return true; }

Expand Down
2 changes: 2 additions & 0 deletions src/ExtendedGridOrbitals.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ void ExtendedGridOrbitals::reset(MasksSet* masks, MasksSet* corrmasks,

void ExtendedGridOrbitals::assign(const ExtendedGridOrbitals& orbitals)
{
assert(proj_matrices_ != nullptr);

assign_tm_.start();

setIterativeIndex(orbitals);
Expand Down
6 changes: 5 additions & 1 deletion src/ExtendedGridOrbitals.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,11 @@ class ExtendedGridOrbitals : public Orbitals
virtual void assign(const ExtendedGridOrbitals& orbitals);
void copyDataFrom(const ExtendedGridOrbitals& src);

ProjectedMatricesInterface* getProjMatrices() { return proj_matrices_; }
ProjectedMatricesInterface* getProjMatrices()
{
assert(proj_matrices_ != nullptr);
return proj_matrices_;
}

const ProjectedMatricesInterface* projMatrices() const
{
Expand Down
20 changes: 17 additions & 3 deletions src/FullyOccupiedNonOrthoDMStrategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,30 @@
// Please also read this link https://github.com/llnl/mgmol/LICENSE

#include "FullyOccupiedNonOrthoDMStrategy.h"
#include "ExtendedGridOrbitals.h"
#include "LocGridOrbitals.h"
#include "ProjectedMatricesInterface.h"

FullyOccupiedNonOrthoDMStrategy::FullyOccupiedNonOrthoDMStrategy(
template <class OrbitalsType>
FullyOccupiedNonOrthoDMStrategy<OrbitalsType>::FullyOccupiedNonOrthoDMStrategy(
ProjectedMatricesInterface* proj_matrices)
: proj_matrices_(proj_matrices)
{
}

void FullyOccupiedNonOrthoDMStrategy::initialize() { update(); }
template <class OrbitalsType>
void FullyOccupiedNonOrthoDMStrategy<OrbitalsType>::initialize(
OrbitalsType& orbitals)
{
update(orbitals);
}

int FullyOccupiedNonOrthoDMStrategy::update()
template <class OrbitalsType>
int FullyOccupiedNonOrthoDMStrategy<OrbitalsType>::update(
OrbitalsType& orbitals)
{
assert(proj_matrices_ != nullptr);
(void)orbitals;

proj_matrices_->setDMto2InvS();

Expand All @@ -34,3 +45,6 @@ int FullyOccupiedNonOrthoDMStrategy::update()

return 0; // success
}

template class FullyOccupiedNonOrthoDMStrategy<LocGridOrbitals>;
template class FullyOccupiedNonOrthoDMStrategy<ExtendedGridOrbitals>;
Loading
Loading