diff --git a/pyproject.toml b/pyproject.toml
index 46659c0..d0c8fa4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ description = "ABACUS agent tools for connecting LLMs to first-principles calcul
requires-python = ">=3.11"
dependencies = [
- "abacustest>=0.4.51",
+ "abacustest>=0.4.55",
"numpy<2.0",
"pymatgen>=2025.5.28",
"bohr-agent-sdk==0.1.121",
diff --git a/src/abacusagent/modules/dos.py b/src/abacusagent/modules/dos.py
index 0555cf5..1b3257d 100644
--- a/src/abacusagent/modules/dos.py
+++ b/src/abacusagent/modules/dos.py
@@ -1,18 +1,19 @@
from pathlib import Path
-from typing import Dict, Any, List, Literal
+from typing import Dict, Any, List, Literal, Optional, Tuple
from abacusagent.init_mcp import mcp
from abacusagent.modules.submodules.dos import abacus_dos_run as _abacus_dos_run
+from abacusagent.modules.submodules.dos import plot_write_dos_pdos as _plot_write_dos_pdos
+
@mcp.tool()
def abacus_dos_run(
abacus_inputs_dir: Path,
- pdos_mode: Literal['species', 'species+shell', 'species+orbital'] = 'species+shell',
+ pdos_mode: Literal['atoms', 'species', 'species+shell', 'species+orbital'] = 'species+shell',
+ pdos_atom_indices: Optional[List[int]] = None,
dos_edelta_ev: float = 0.01,
dos_sigma: float = 0.07,
- dos_scale: float = 0.01,
- dos_emin_ev: float = None,
- dos_emax_ev: float = None,
- dos_nche: int = None,
+ dos_emin_ev: float = -10.0,
+ dos_emax_ev: float = 10.0,
) -> Dict[str, Any]:
"""Run the DOS and PDOS calculation.
@@ -24,27 +25,58 @@ def abacus_dos_run(
Args:
abacus_inputs_dir: Path to the ABACUS input files, which contains the INPUT, STRU, KPT, and pseudopotential or orbital files.
pdos_mode: Mode of plotted PDOS file.
+ - "atoms": PDOS of a list of atoms will be plotted.
- "species": Total PDOS of any species will be plotted in a picture.
- "species+shell": PDOS for any shell (s, p, d, f, g,...) of any species will be plotted. PDOS of a shell of a species willbe plotted in a subplot.
- "species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot.
+ pdos_atom_indices: A list of atom indices, only used if pdos_mode is "atoms".
dos_edelta_ev: Step size in writing Density of States (DOS) in eV.
dos_sigma: Width of the Gaussian factor when obtaining smeared Density of States (DOS) in eV.
- dos_scale: Defines the energy range of DOS output as (emax-emin)*(1+dos_scale), centered at (emax+emin)/2.
- This parameter will be used when dos_emin_ev and dos_emax_ev are not set.
- dos_emin_ev: Minimal range for Density of States (DOS) in eV.
- dos_emax_ev: Maximal range for Density of States (DOS) in eV.
- dos_nche: The order of Chebyshev expansions when using Stochastic Density Functional Theory (SDFT) to calculate DOS.
-
+ dos_emin_ev: Minimal range for Density of States (DOS) in eV. Default is -10.0.
+ dos_emax_ev: Maximal range for Density of States (DOS) in eV. Default is 10.0.
+
Returns:
Dict[str, Any]: A dictionary containing:
- dos_fig_path: Path to the plotted DOS.
- pdos_fig_path: Path to the plotted PDOS. Only for LCAO basis.
+ - dos_data_path: Path to the data used in plotting DOS.
+ - pdos_data_paths: Path to the data used in plotting PDOS. Only for LCAO basis.
- scf_work_path: Path to the work directory of SCF calculation.
- scf_normal_end: If the SCF calculation ended normally.
- scf_steps: Number of steps of SCF iteration.
- scf_converge: If the SCF calculation converged.
- scf_energy: The calculated energy of SCF calculation.
- - nscf_work_path: Path to the work directory of NSCF calculation.
- - nscf_normal_end: If the SCF calculation ended normally.
+ - nscf_work_path: Path to the work directory of NSCF calculation
+ """
+ return _abacus_dos_run(abacus_inputs_dir, pdos_mode, pdos_atom_indices, dos_edelta_ev, dos_sigma, dos_emin_ev, dos_emax_ev)
+
+def plot_write_dos_pdos(
+ scf_job_path: Path,
+ nscf_job_path: Path,
+ mode: Literal[
+ "species", "species+shell", "species+orbital", "atoms"
+ ] = "species+shell",
+ pdos_atom_indices: Optional[List[int]] = None,
+ dos_emin_ev: float = -10.0,
+ dos_emax_ev: float = 5.0,
+) -> Tuple[List[str], List[str]]:
+ """
+ Plot DOS, PDOS and write data used in plotting to files using SCF and NSCF job directories from abacus_dos_run.
+
+ Args:
+ scf_job_path (Path): Path to the SCF job directory of the DOS calculation
+ nscf_job_path (Path): Path to the NSCF job directory of the DOS calculation
+ mode: Mode for plotting PDOS and write PDOS data.
+ - "atoms": PDOS of a list of atoms will be plotted.
+ - "species": Total PDOS of any species will be plotted in a picture.
+ - "species+shell": PDOS for any shell (s, p, d, f, g,...) of any species will be plotted. PDOS of a shell of a species willbe plotted in a subplot.
+ - "species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot.
+ pdos_atom_indices: A list of atom indices, only used if pdos_mode is "atoms".
+ pdos_atom_indices (List[int], optional): List of atom indices for atom-specific PDOS. Only valid for 'atoms' mode.
+ dos_emin_ev (float): Minimum energy for DOS and PDOS plots.
+ dos_emax_ev (float): Maximum energy for DOS and PDOS plots.
+
+ Returns:
+ Tuple[List[str], List[str]]: Tuple containing list of plot file paths and data file paths.
"""
- return _abacus_dos_run(abacus_inputs_dir, pdos_mode, dos_edelta_ev, dos_sigma, dos_scale, dos_emin_ev, dos_emax_ev, dos_nche)
+ return _plot_write_dos_pdos(scf_job_path, nscf_job_path, mode, pdos_atom_indices, dos_emin_ev, dos_emax_ev)
diff --git a/src/abacusagent/modules/submodules/dos.py b/src/abacusagent/modules/submodules/dos.py
index 13c197e..d2f14db 100644
--- a/src/abacusagent/modules/submodules/dos.py
+++ b/src/abacusagent/modules/submodules/dos.py
@@ -1,81 +1,59 @@
import os
-import re
-import numpy as np
-import matplotlib.pyplot as plt
+import glob
from abacustest.lib_prepare.abacus import ReadInput, WriteInput
from abacustest.lib_collectdata.collectdata import RESULT
from abacustest.lib_model.comm import check_abacus_inputs
+from abacustest.lib_model.comm_dos import DOSData, PDOSData, l_map, orbital_names
from pathlib import Path
-from typing import Dict, Any, List, Literal
-
-from abacusagent.modules.util.comm import generate_work_path, link_abacusjob, run_abacus, has_chgfile, collect_metrics
-from abacusagent.modules.util.chemical_elements import MAX_ANGULAR_MOMENTUM_OF_ELEMENTS
-
-
-angular_momentum_map = ['s', 'p', 'd', 'f', 'g']
-color_map = {
- 's': '#FF5733',
- 'p': '#33FF57',
- 'd': '#3357FF',
- 'f': '#F033FF',
- 'g': '#33FFF0'
-}
-
-orbital_rep_map = {
- 's': 's',
- 'px': r'$p_x$',
- 'py': r'$p_y$',
- 'pz': r'$p_z$',
- 'dz^2': r'$d_{z^2}$',
- 'dxz': r'$d_{xz}$',
- 'dyz': r'$d_{yz}$',
- 'dxy': r'$d_{xy}$',
- 'dx^2-y^2': r'$d_{x^2-y^2}$',
- 'fz^3': r'$f_{z^3}$',
- 'fxz^2': r'$f_{xz^2}$',
- 'fyz^2': r'$f_{yz^2}$',
- 'fzx^2-zy^2': r'$f_{zx^2-zy^2}$',
- 'fxyz': r'$f_{xyz}$',
- 'fx^3-3*xy^2': r'$f_{x^3-3xy^2}$',
- 'f3yx^2-y^3': r'$f_{3yx^2-y^3}$'
-}
+from typing import Dict, Any, List, Literal, Optional, Tuple
+
+from abacusagent.modules.util.comm import (
+ generate_work_path,
+ link_abacusjob,
+ run_abacus,
+ has_chgfile,
+)
+
+
def abacus_dos_run(
abacus_inputs_dir: Path,
- pdos_mode: Literal['species', 'species+shell', 'species+orbital'] = 'species+shell',
+ pdos_mode: Literal[
+ "atoms", "species", "species+shell", "species+orbital"
+ ] = "species+shell",
+ pdos_atom_indices: Optional[List[int]] = None,
dos_edelta_ev: float = 0.01,
dos_sigma: float = 0.07,
- dos_scale: float = 0.01,
- dos_emin_ev: float = None,
- dos_emax_ev: float = None,
- dos_nche: int = None,
+ dos_emin_ev: float = -10.0,
+ dos_emax_ev: float = 10.0,
) -> Dict[str, Any]:
"""Run the DOS and PDOS calculation.
-
- This function will firstly run a SCF calculation with out_chg set to 1,
+
+ This function will firstly run a SCF calculation with out_chg set to 1,
then run a NSCF calculation with init_chg set to 'file' and out_dos set to 1 or 2.
If the INPUT parameter "basis_type" is "PW", then out_dos will be set to 1, and only DOS will be calculated and plotted.
If the INPUT parameter "basis_type" is "LCAO", then out_dos will be set to 2, and both DOS and PDOS will be calculated and plotted.
-
+
Args:
abacus_inputs_dir: Path to the ABACUS input files, which contains the INPUT, STRU, KPT, and pseudopotential or orbital files.
pdos_mode: Mode of plotted PDOS file.
+ - "atoms": PDOS of a list of atoms will be plotted.
- "species": Total PDOS of any species will be plotted in a picture.
- "species+shell": PDOS for any shell (s, p, d, f, g,...) of any species will be plotted. PDOS of a shell of a species willbe plotted in a subplot.
- - “species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot.
+ - "species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot.
+ pdos_atom_indices: A list of atom indices, only used if pdos_mode is "atoms".
dos_edelta_ev: Step size in writing Density of States (DOS) in eV.
- dos_sigma: Width of the Gaussian factor when obtaining smeared Density of States (DOS) in eV.
- dos_scale: Defines the energy range of DOS output as (emax-emin)*(1+dos_scale), centered at (emax+emin)/2.
- This parameter will be used when dos_emin_ev and dos_emax_ev are not set.
- dos_emin_ev: Minimal range for Density of States (DOS) in eV.
- dos_emax_ev: Maximal range for Density of States (DOS) in eV.
- dos_nche: The order of Chebyshev expansions when using Stochastic Density Functional Theory (SDFT) to calculate DOS.
-
+ dos_sigma: Width of the Gaussian factor when obtaining smeared Density of States (DOS) in eV.
+ dos_emin_ev: Minimal range for Density of States (DOS) in eV. Default is -10.0.
+ dos_emax_ev: Maximal range for Density of States (DOS) in eV. Default is 10.0.
+
Returns:
Dict[str, Any]: A dictionary containing:
- dos_fig_path: Path to the plotted DOS.
- pdos_fig_path: Path to the plotted PDOS. Only for LCAO basis.
+ - dos_data_path: Path to the data used in plotting DOS.
+ - pdos_data_paths: Path to the data used in plotting PDOS. Only for LCAO basis.
- scf_work_path: Path to the work directory of SCF calculation.
- scf_normal_end: If the SCF calculation ended normally.
- scf_steps: Number of steps of SCF iteration.
@@ -88,55 +66,68 @@ def abacus_dos_run(
is_valid, msg = check_abacus_inputs(abacus_inputs_dir)
if not is_valid:
raise RuntimeError(f"Invalid ABACUS input files: {msg}")
-
+
input_file = os.path.join(abacus_inputs_dir, "INPUT")
input_params = ReadInput(input_file)
nspin = input_params.get("nspin", 1)
if nspin in [4]:
- raise ValueError("Currently DOS calculation can only be plotted using for nspin=1 and nspin=2")
+ raise ValueError(
+ "Currently DOS calculation can only be plotted using for nspin=1 and nspin=2"
+ )
+ print("Performing SCF calculation...")
metrics_scf = abacus_dos_run_scf(abacus_inputs_dir)
- metrics_nscf = abacus_dos_run_nscf(metrics_scf["scf_work_path"],
- dos_edelta_ev=dos_edelta_ev,
- dos_sigma=dos_sigma,
- dos_scale=dos_scale,
- dos_emin_ev=dos_emin_ev,
- dos_emax_ev=dos_emax_ev,
- dos_nche=dos_nche)
-
- fig_paths = plot_dos_pdos(metrics_scf["scf_work_path"],
- metrics_nscf["nscf_work_path"],
- metrics_nscf["nscf_work_path"],
- nspin,
- pdos_mode)
+
+ print("Performing NSCF calculation...")
+ metrics_nscf = abacus_dos_run_nscf(
+ metrics_scf["scf_work_path"],
+ dos_edelta_ev=dos_edelta_ev,
+ dos_sigma=dos_sigma,
+ )
+
+ fig_paths, dos_pdos_data_paths = plot_write_dos_pdos(
+ metrics_scf["scf_work_path"],
+ metrics_nscf["nscf_work_path"],
+ pdos_mode,
+ pdos_atom_indices,
+ dos_emin_ev,
+ dos_emax_ev,
+ )
return_dict = {"dos_fig_path": fig_paths[0]}
+ return_dict["dos_data_path"] = dos_pdos_data_paths[0]
try:
- return_dict['pdos_fig_path'] = fig_paths[1]
+ return_dict["pdos_fig_path"] = fig_paths[1]
+ return_dict["pdos_data_path"] = dos_pdos_data_paths[1]
except:
- pass # Do nothing if PDOS file is not plotted
+ pass # Do nothing if PDOS file is not plotted
return_dict.update(metrics_scf)
return_dict.update(metrics_nscf)
return return_dict
except Exception as e:
+ import traceback
+
+ traceback.print_exc()
return {"message": f"Calculating DOS and PDOS failed: {e}"}
-def abacus_dos_run_scf(abacus_inputs_dir: Path,
- force_run: bool = False) -> Dict[str, Any]:
+
+def abacus_dos_run_scf(
+ abacus_inputs_dir: Path, force_run: bool = False
+) -> Dict[str, Any]:
"""
Run the SCF calculation to generate the charge density file.
If the charge file already exists, it will skip the SCF calculation.
-
+
Args:
abacus_inputs_dir: Path to the ABACUS input files, which contains the INPUT, STRU, KPT, and pseudopotential or orbital files.
force_run: If True, it will run the SCF calculation even if the charge file already exists.
-
+
Returns:
Dict[str, Any]: A dictionary containing the work path, normal end status, SCF steps, convergence status, and energies.
"""
-
+
input_param = ReadInput(os.path.join(abacus_inputs_dir, "INPUT"))
# check if charge file has been generated
if has_chgfile(abacus_inputs_dir) and not force_run:
@@ -144,9 +135,7 @@ def abacus_dos_run_scf(abacus_inputs_dir: Path,
work_path = abacus_inputs_dir
else:
work_path = generate_work_path()
- link_abacusjob(src=abacus_inputs_dir,
- dst=work_path,
- copy_files=["INPUT"])
+ link_abacusjob(src=abacus_inputs_dir, dst=work_path, copy_files=["INPUT"])
input_param = ReadInput(os.path.join(work_path, "INPUT"))
input_param["calculation"] = "scf"
@@ -156,432 +145,150 @@ def abacus_dos_run_scf(abacus_inputs_dir: Path,
run_abacus(work_path)
rs = RESULT(path=work_path, fmt="abacus")
-
+
return {
"scf_work_path": Path(work_path).absolute(),
"scf_normal_end": rs["normal_end"],
"scf_steps": rs["scf_steps"],
"scf_converge": rs["converge"],
- "scf_energy": rs["energy"]
+ "scf_energy": rs["energy"],
}
-def abacus_dos_run_nscf(abacus_inputs_dir: Path,
- dos_edelta_ev: float = None,
- dos_sigma: float = None,
- dos_scale: float = None,
- dos_emin_ev: float = None,
- dos_emax_ev: float = None,
- dos_nche: int = None,) -> Dict[str, Any]:
-
+
+def abacus_dos_run_nscf(
+ abacus_inputs_dir: Path, dos_edelta_ev: float = None, dos_sigma: float = None
+) -> Dict[str, Any]:
work_path = generate_work_path()
- link_abacusjob(src=abacus_inputs_dir,
- dst=work_path,
- copy_files=["INPUT"])
-
+ link_abacusjob(
+ src=abacus_inputs_dir,
+ dst=work_path,
+ copy_files=["INPUT", "KPT"]
+ + glob.glob(os.path.join(abacus_inputs_dir, "OUT.*")),
+ exclude=["*log", "*json"],
+ )
+
input_param = ReadInput(os.path.join(work_path, "INPUT"))
input_param["calculation"] = "nscf"
input_param["init_chg"] = "file"
+ input_param["out_chg"] = -1
if input_param.get("basis_type", "pw") == "lcao":
- input_param["out_dos"] = 2 # only for LCAO basis, and will output DOS and PDOS
+ input_param["out_dos"] = 2 # only for LCAO basis, and will output DOS and PDOS
else:
input_param["out_dos"] = 1
-
+
for dos_param, value in {
"dos_edelta_ev": dos_edelta_ev,
"dos_sigma": dos_sigma,
- "dos_scale": dos_scale,
- "dos_emin_ev": dos_emin_ev,
- "dos_emax_ev": dos_emax_ev,
- "dos_nche": dos_nche
}.items():
if value is not None:
input_param[dos_param] = value
-
-
+
WriteInput(input_param, os.path.join(work_path, "INPUT"))
-
+
run_abacus(work_path)
-
+
rs = RESULT(path=work_path, fmt="abacus")
-
+
return {
"nscf_work_path": Path(work_path).absolute(),
- "nscf_normal_end": rs["normal_end"]
+ "nscf_normal_end": rs["normal_end"],
}
-def parse_pdos_file(file_path):
- """Parse the PDOS file and extract energy values and orbital data."""
- with open(file_path, 'r') as f:
- content = f.read()
-
- energy_match = re.search(r'(.*?)', content, re.DOTALL)
- if not energy_match:
- raise ValueError("Energy values not found in the file.")
-
- energy_text = energy_match.group(1)
- energy_values = np.array([float(line.strip()) for line in energy_text.strip().split()])
-
- orbital_pattern = re.compile(r'(.*?)', re.DOTALL)
- orbitals = []
-
- for match in orbital_pattern.finditer(content):
- index, atom_index, species, l, m, z, orbital_content = match.groups()
-
- data_match = re.search(r'(.*?)', orbital_content, re.DOTALL)
- if data_match:
- data_text = data_match.group(1)
- data_values = np.array([float(line.strip()) for line in data_text.strip().split()])
-
- orbitals.append({
- 'index': int(index),
- 'atom_index': int(atom_index),
- 'species': species,
- 'l': int(l),
- 'm': int(m),
- 'z': int(z),
- 'data': data_values
- })
-
- return energy_values, orbitals
-
-def parse_log_file(file_path):
- """Parse Fermi energy from log file and convert to eV."""
- ry_to_ev = 13.605698066
- fermi_energy = None
-
- with open(file_path, 'r') as f:
- for line in f:
- if "Fermi energy is" in line:
- match = re.search(r'Fermi energy is\s*([\d.-]+)', line)
- if match:
- fermi_energy = float(match.group(1))
-
- if fermi_energy is None:
- raise ValueError("Fermi energy not found in log file")
-
- return fermi_energy * ry_to_ev
-
-def parse_basref_file(file_path):
- """Parse basref file to create mapping for custom labels."""
- label_map = {}
-
- with open(file_path, 'r') as f:
- for line in f:
- line = line.strip()
- if not line or line.startswith('#'):
- continue
-
- parts = line.split()
- if len(parts) >= 6:
- # Add 1 to atom_index as per requirement
- atom_index = int(parts[0]) + 1
- species = parts[1]
- l = int(parts[2])
- m = int(parts[3])
- z = int(parts[4])
- symbol = parts[5]
-
- key = (atom_index, species, l, m, z)
- label_map[key] = f'{species}{atom_index}({symbol})'
-
- return label_map
+def plot_write_dos_pdos(
+ scf_job_path: Path,
+ nscf_job_path: Path,
+ mode: Literal[
+ "species", "species+shell", "species+orbital", "atoms"
+ ] = "species+shell",
+ pdos_atom_indices: Optional[List[int]] = None,
+ dos_emin_ev: float = -10.0,
+ dos_emax_ev: float = 5.0,
+) -> Tuple[List[str], List[str]]:
+ """
+ Plot DOS, PDOS and write data used in plotting to files using SCF and NSCF job directories from abacus_dos_run.
-def plot_pdos(energy_values, orbitals, fermi_level, label_map, output_dir, nspin, mode, dpi=300):
- """Plot PDOS data separated by atom/species with custom labels."""
- # Create output directory if it doesn't exist
- os.makedirs(output_dir, exist_ok=True)
-
- # Shift energy values by Fermi level
- shifted_energy = energy_values - fermi_level
-
- # Group orbitals by atom_index and species
- atom_species_groups = {}
- for orbital in orbitals:
- key = (orbital['atom_index'], orbital['species'])
- if key not in atom_species_groups:
- atom_species_groups[key] = []
- atom_species_groups[key].append(orbital)
-
- if mode == "species":
- pdos_pic_file = plot_pdos_species(shifted_energy, orbitals, output_dir, nspin, dpi)
- elif mode == "species+shell":
- pdos_pic_file = plot_pdos_species_shell(shifted_energy, orbitals, output_dir, nspin, dpi)
- elif mode == "species+orbital":
- pdos_pic_file = plot_pdos_species_orbital(shifted_energy, orbitals, output_dir, nspin, label_map, dpi)
- else:
- raise ValueError(f"Not allowed mode {mode}")
-
- return pdos_pic_file
-
-def plot_pdos_species(shifted_energy, orbitals, output_dir, nspin, dpi):
- species = {}
- for orbital in orbitals:
- species_one = orbital['species']
- if species_one not in species.keys():
- species[species_one] = orbital['data']
- else:
- species[species_one] += orbital['data']
-
- num_species = len(species)
- plt.plot(figsize=(10, 6))
- for species_name, pdos_data in species.items():
- if nspin == 1:
- plt.plot(shifted_energy, pdos_data, label=species_name, linewidth=1.0)
- elif nspin == 2:
- plt.plot(shifted_energy, pdos_data[::2], label=f'{species_name} ' + r'$\uparrow$', linestyle='-', linewidth=1.0)
- plt.plot(shifted_energy, -pdos_data[1::2], label=f'{species_name} ' + r'$\downarrow$', linestyle='--', linewidth=1.0)
-
- plt.axvline(x=0, color='black', linestyle=':', linewidth=1.0)
- plt.xlabel('Energy (eV)', fontsize=10)
- plt.ylabel(r"States ($eV^{-1}$)", fontsize=10)
- plt.xlim(max(min(shifted_energy), -20), min(20, max(shifted_energy)))
- if nspin == 1:
- plt.ylim(bottom=0)
- plt.legend(fontsize=8, ncol=nspin)
- plt.grid(alpha=0.3)
- plt.title('Projected density of States of different species')
-
- pdos_pic_file = os.path.join(output_dir, 'PDOS.png')
- plt.savefig(pdos_pic_file, dpi=dpi)
- plt.close()
-
- return Path(pdos_pic_file).absolute()
-
-def plot_pdos_species_shell(shifted_energy, orbitals, output_dir, nspin, dpi):
- species_shells = {}
- for orbital in orbitals:
- species = orbital['species']
- if species not in species_shells.keys():
- species_shells[species] = {} # Initialize species kind
-
- angular_momentum = angular_momentum_map[orbital['l']]
- # The orbital with higher angular momentum than in realistic atoms will be ignored.
- if angular_momentum_map.index(angular_momentum) <= angular_momentum_map.index(MAX_ANGULAR_MOMENTUM_OF_ELEMENTS[orbital['species']]):
- if angular_momentum not in species_shells[species].keys():
- species_shells[species][angular_momentum] = orbital['data'] # Initialize DOS for angular momentum of a species
- else:
- species_shells[species][angular_momentum] += orbital['data'] # Add DOS of a angular momentum of a species
-
- # Plot PDOS for each species and each shell
- num_species = len(species_shells)
- fig, axes = plt.subplots(nrows=num_species, ncols=1, figsize=(8, 4*num_species))
- if num_species == 1:
- axes = [axes]
-
- for species_idx, (species, pdos_data_dict) in enumerate(species_shells.items()):
- ax = axes[species_idx]
-
- for l, pdos_data in pdos_data_dict.items():
- if nspin == 1:
- ax.plot(shifted_energy, pdos_data, color=color_map[l], label=f'{species}-{l}', linewidth=1.0)
- elif nspin == 2:
- ax.plot(shifted_energy, pdos_data[::2], color=color_map[l], label=f'{species}-{l}' + r' $\uparrow$', linestyle='-', linewidth=1.0)
- ax.plot(shifted_energy, -pdos_data[1::2], color=color_map[l], label=f'{species}-{l}' + r' $\downarrow$', linestyle='--', linewidth=1.0)
-
- ax.axvline(x=0, color='black', linestyle=':', linewidth=1.0)
- ax.set_title(f'PDOS for {species}', fontsize=12, pad=10)
- ax.set_ylabel(r"States ($eV^{-1}$)", fontsize=10)
- ax.set_xlim(max(min(shifted_energy), -20), min(20, max(shifted_energy)))
- #if nspin == 1:
- # ax.set_ylim(bottom=0)
- ax.legend(fontsize=8, ncol=nspin)
- ax.grid(alpha=0.3)
-
- #ax.set_ylim(bottom=0)
-
- axes[-1].set_xlabel('Energy (eV)', fontsize=10)
-
- plt.tight_layout()
- pdos_pic_file = os.path.join(output_dir, 'PDOS.png')
- plt.savefig(pdos_pic_file, dpi=dpi, bbox_inches='tight')
- plt.close()
-
- return Path(pdos_pic_file).absolute()
-
-def plot_pdos_species_orbital(shifted_energy, orbitals, output_dir, nspin, label_map, dpi):
-
- plt.rcParams["text.usetex"] = False
- plt.rcParams["axes.prop_cycle"] = plt.cycler("color", plt.cm.tab20.colors)
-
- orbital_label = {}
- for (atom_index, species, l, m, z), full_label in label_map.items():
- if species not in orbital_label.keys():
- orbital_label[species] = {}
- if str(l) not in orbital_label[species].keys():
- orbital_label[species][str(l)] = {}
- if str(m) not in orbital_label[species][str(l)].keys():
- orbital_name = full_label.split('(')[1].split(')')[0]
- if orbital_name in orbital_rep_map.keys():
- orbital_label[species][str(l)][str(m)] = orbital_rep_map[orbital_name]
- else:
- orbital_label[species][str(l)][str(m)] = orbital_name
- else:
- pass
-
- species_orbitals = {}
- for orbital in orbitals:
- species = orbital['species']
- if species not in species_orbitals.keys():
- species_orbitals[species] = {}
-
- angular_momentum = angular_momentum_map[orbital['l']]
- # The orbital with higher angular momentum than in realistic atoms will be ignored.
- if angular_momentum_map.index(angular_momentum) <= angular_momentum_map.index(MAX_ANGULAR_MOMENTUM_OF_ELEMENTS[orbital['species']]):
- if angular_momentum not in species_orbitals[species].keys():
- species_orbitals[species][angular_momentum] = {}
-
- mag_quantum_num = orbital['m']
- if mag_quantum_num not in species_orbitals[species][angular_momentum].keys():
- species_orbitals[species][angular_momentum][mag_quantum_num] = orbital['data']
- else:
- species_orbitals[species][angular_momentum][mag_quantum_num] += orbital['data']
-
- total_subplots = 0
- for species, species_pdos in species_orbitals.items():
- total_subplots += len(species_pdos)
- fig, axes = plt.subplots(nrows=total_subplots, ncols=1, figsize=(8, 4*total_subplots))
-
- subplot_count = 0
- for species, species_pdos in species_orbitals.items():
- for angular_momentum, species_shell_pdos in species_pdos.items():
- for m, species_orbital_pdos in species_shell_pdos.items():
- ax = axes[subplot_count]
- orbital_name = orbital_label[species][str(angular_momentum_map.index(angular_momentum))][str(m)]
- if nspin == 1:
- ax.plot(shifted_energy, species_orbital_pdos, label=f'{orbital_name}', linewidth=1.0)
- elif nspin == 2:
- ax.plot(shifted_energy, species_orbital_pdos[::2], label=f'{orbital_name} '+r'$\uparrow$', linestyle='-', linewidth=1.0)
- ax.plot(shifted_energy, -species_orbital_pdos[1::2], label=f'{orbital_name} '+r'$\downarrow$', linestyle='--', linewidth=1.0)
-
- ax.axvline(x=0, color='black', linestyle=':', linewidth=1.0)
- ax.set_title(f'PDOS for {species}-{angular_momentum}', fontsize=12, pad=10)
- ax.set_xlim(max(min(shifted_energy), -20), min(20, max(shifted_energy)))
- if nspin == 1:
- ax.set_ylim(bottom=0)
- ax.set_ylabel(r"States ($eV^{-1}$)", fontsize=10)
- ax.legend(fontsize=8, ncol=nspin)
- ax.grid(alpha=0.3)
-
- subplot_count += 1
-
- axes[-1].set_xlabel('Energy (eV)', fontsize=10)
-
- plt.tight_layout()
- pdos_pic_file = os.path.join(output_dir, 'PDOS.png')
- plt.savefig(pdos_pic_file, dpi=dpi, bbox_inches='tight')
- plt.close()
-
- return Path(pdos_pic_file).absolute()
-
-def plot_dos(file_path: List[Path],
- fermi_level: float,
- output_file: str = 'DOS.png',
- nspin: Literal[1, 2] = 1,
- dpi: int=300):
- """Plot total DOS from DOS1_smearing.dat and DOS2_smearing (if nspin=2) file."""
- # Read first two columns from file
- data = np.loadtxt(file_path[0], usecols=(0, 1))
- energy = data[:, 0] - fermi_level # Shift by Fermi level
- dos = data[:, 1]
- if nspin == 2:
- data = np.loadtxt(file_path[1], usecols=(0, 1))
- dos_dn = data[:, 1]
-
- # Determine energy limits based on data within x range
- x_min, x_max = max(min(energy), -20), min(20, max(energy))
-
- # Create plot
- plt.figure(figsize=(8, 6))
- if nspin == 1:
- plt.plot(energy, dos, linestyle='-')
- elif nspin == 2:
- plt.plot(energy, dos, linestyle='-', label='spin up')
- plt.plot(energy, -dos_dn, linestyle='--', label='spin down')
- plt.axvline(x=0, color='k', linestyle='--', alpha=0.5)
- plt.xlabel('Energy (eV)')
- plt.ylabel(r'States ($eV^{-1}$)')
- plt.title('Density of States')
- plt.grid(True, alpha=0.3)
- plt.xlim(x_min, x_max)
- #plt.ylim(y_min, y_max)
- #plt.legend()
-
- # Save plot
- os.makedirs(os.path.dirname(output_file), exist_ok=True)
- plt.savefig(output_file, dpi=dpi, bbox_inches='tight')
- plt.close()
-
- return Path(output_file).absolute()
-
-def plot_dos_pdos(scf_job_path: Path,
- nscf_job_path: Path,
- output_dir: Path,
- nspin: Literal[1, 2] = 1,
- mode: Literal['species', 'species+shell', 'species+orbital'] = 'species+shell',
- dpi=300) -> List[str]:
- """Plot DOS and PDOS from the NSCF job path.
-
Args:
- nscf_job_path (Path): Path to the NSCF job directory containing the OUT.* files.
- output_dir (Path): Directory where the output plots will be saved.
- dpi (int): Dots per inch for the saved plots.
-
- Returns:
- List[str]: List of paths to the generated plot files.
+ scf_job_path (Path): Path to the SCF job directory of the DOS calculation
+ nscf_job_path (Path): Path to the NSCF job directory of the DOS calculation
+ mode: Mode for plotting PDOS and write PDOS data.
+ - "atoms": PDOS of a list of atoms will be plotted.
+ - "species": Total PDOS of any species will be plotted in a picture.
+ - "species+shell": PDOS for any shell (s, p, d, f, g,...) of any species will be plotted. PDOS of a shell of a species willbe plotted in a subplot.
+ - "species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot.
+ pdos_atom_indices: A list of atom indices, only used if pdos_mode is "atoms".
+ pdos_atom_indices (List[int], optional): List of atom indices for atom-specific PDOS. Only valid for 'atoms' mode.
+ dos_emin_ev (float): Minimum energy for DOS and PDOS plots.
+ dos_emax_ev (float): Maximum energy for DOS and PDOS plots. """
+ work_path = generate_work_path()
- """
input_param = ReadInput(os.path.join(nscf_job_path, "INPUT"))
- input_dir = os.path.join(nscf_job_path, "OUT." + input_param.get("suffix","ABACUS"))
- basis_type = input_param.get('basis_type', 'pw')
-
- # Construct file paths based on input directory
- pdos_file = os.path.join(input_dir, "PDOS")
- log_file = os.path.join(input_dir, "running_nscf.log")
- basref_file = os.path.join(input_dir, "Orbital")
- dos_file = [os.path.join(input_dir, "DOS1_smearing.dat")]
- dos_output = os.path.join(output_dir, "DOS.png")
- if nspin == 2:
- dos_file += [os.path.join(input_dir, "DOS2_smearing.dat")]
-
- # Validate input files exist
- for file_path in [log_file, dos_file[0]]:
- if not os.path.exists(file_path):
- print(f"Error: File not found - {file_path}")
- raise FileNotFoundError(f"Required file not found: {file_path}")
- if nspin == 2:
- if not os.path.exists(dos_file[1]):
- print(f"Error: File not found - {dos_file[1]}")
- raise FileNotFoundError(f"Required file not found: {dos_file[1]}")
+ basis_type = input_param.get("basis_type", "pw")
+ results = RESULT(fmt="abacus", path=scf_job_path)
+ efermi = results['efermi']
+
+ # Construct file paths
+ dos_plot_file = os.path.join(work_path, "DOS.png")
+ dos_data_file = os.path.join(work_path, "DOS.dat")
+
+ dosdata = DOSData.ReadFromAbacusJob(str(nscf_job_path), efermi)
+ dosdata.plot_dos(
+ dos_emin_ev,
+ dos_emax_ev,
+ "Density of States",
+ dos_plot_file,
+ )
+ dosdata.write_dos(dos_data_file)
+
+ all_plot_files = [Path(dos_plot_file).absolute()]
+ dos_pdos_data_files = [Path(dos_data_file).absolute()]
- fermi_level = collect_metrics(scf_job_path, ['efermi'])['efermi']
-
- # Plot DOS and get file path
- dos_plot_file = plot_dos(dos_file, fermi_level, dos_output, nspin, dpi)
- all_plot_files = [dos_plot_file]
-
print("DOS file plotted")
- # Plot PDOS (only for LCAO basis - PW basis doesn't support PDOS in ABACUS LTSv3.10)
- if os.path.exists(pdos_file) and os.path.exists(basref_file):
- if basis_type != 'pw':
- label_map = parse_basref_file(basref_file)
- energy_values, orbitals = parse_pdos_file(pdos_file)
- pdos_plot_file = plot_pdos(energy_values, orbitals, fermi_level, label_map, output_dir, nspin, mode, dpi)
-
+ # Plot PDOS using PDOSData class (only for LCAO basis)
+ if basis_type != "pw":
+ try:
+ # Load PDOS data using PDOSData class
+ pdos_data = PDOSData.ReadFromAbacusJob(str(nscf_job_path), efermi)
+ pdos_plot_file = Path(os.path.join(work_path, "PDOS.png")).absolute()
+ pdos_data_file = Path(os.path.join(work_path, "PDOS.dat")).absolute()
+
+ # Plot PDOS based on mode
+ if mode == "species":
+ pdos_data.plot_species_pdos(dos_emin_ev, dos_emax_ev, pdos_plot_file)
+ pdos_data.write_species_pdos(pdos_data_file)
+ elif mode == "species+shell":
+ pdos_data.plot_species_shell_pdos(dos_emin_ev, dos_emax_ev, pdos_plot_file)
+ pdos_data.write_species_shell_pdos(pdos_data_file)
+ elif mode == "species+orbital":
+ pdos_data.plot_species_orbital_pdos(dos_emin_ev, dos_emax_ev, pdos_plot_file)
+ pdos_data.write_species_orbital_pdos(pdos_data_file)
+ elif mode == "atoms":
+ if pdos_atom_indices is None or len(pdos_atom_indices) == 0:
+ raise ValueError(
+ "For 'atoms' mode, pdos_atom_indices must be provided"
+ )
+ pdos_data.plot_atoms_pdos(pdos_atom_indices, dos_emin_ev, dos_emax_ev, pdos_plot_file)
+ pdos_data.write_atoms_pdos(pdos_atom_indices, pdos_data_file)
+ else:
+ raise ValueError(f"Unsupported mode: {mode}")
+
# Combine file paths into a single list
- all_plot_files.append(pdos_plot_file)
- else:
- print(f"Warning: PDOS calculation not supported for PW basis type, skipping PDOS plotting")
- elif os.path.exists(pdos_file) and not os.path.exists(basref_file):
- print(f"Warning: PDOS file exists but Orbital file not found - {basref_file}, skipping PDOS plotting")
- elif not os.path.exists(pdos_file) and os.path.exists(basref_file):
- print(f"Warning: Orbital file exists but PDOS file not found - {pdos_file}, skipping PDOS plotting")
+ all_plot_files.append(Path(pdos_plot_file).absolute())
+ dos_pdos_data_files.append(Path(pdos_data_file).absolute())
+
+ except Exception as e:
+ import traceback
+ traceback.print_exc()
+ print(f"Warning: Failed to plot PDOS: {e}")
+ print("Skipping PDOS plotting")
else:
- print("Warning: Both PDOS and Orbital files not found, skipping PDOS plotting")
+ print(
+ f"Warning: PDOS calculation not supported for PW basis type, skipping PDOS plotting"
+ )
print("Plots generated:")
for file in all_plot_files:
print(f"- {file}")
-
- return all_plot_files
+
+ return all_plot_files, dos_pdos_data_files
diff --git a/src/abacusagent/modules/tool_wrapper.py b/src/abacusagent/modules/tool_wrapper.py
index 3195121..b9adbf4 100644
--- a/src/abacusagent/modules/tool_wrapper.py
+++ b/src/abacusagent/modules/tool_wrapper.py
@@ -382,13 +382,12 @@ def abacus_dos_run(
relax_precision: Literal['low', 'medium', 'high'] = 'medium',
relax_method: Literal["cg", "bfgs", "bfgs_trad", "cg_bfgs", "sd", "fire"] = "cg",
fixed_axes: Literal["None", "volume", "shape", "a", "b", "c", "ab", "ac", "bc"] = None,
- pdos_mode: Literal['species', 'species+shell', 'species+orbital'] = 'species+shell',
+ pdos_mode: Literal['atoms', 'species', 'species+shell', 'species+orbital'] = 'species+shell',
+ pdos_atom_indices: Optional[List[int]] = None,
dos_edelta_ev: float = 0.01,
dos_sigma: float = 0.07,
- dos_scale: float = 0.01,
- dos_emin_ev: float = None,
- dos_emax_ev: float = None,
- dos_nche: int = None,
+ dos_emin_ev: float = -10.0,
+ dos_emax_ev: float = 10.0,
) -> Dict[str, Any]:
"""
Run the DOS and PDOS calculation.
@@ -432,21 +431,22 @@ def abacus_dos_run(
- ac: fix both a and c axes
- bc: fix both b and c axes
pdos_mode: Mode of plotted PDOS file.
+ - "atoms": PDOS of a list of atoms will be plotted.
- "species": Total PDOS of any species will be plotted in a picture.
- "species+shell": PDOS for any shell (s, p, d, f, g,...) of any species will be plotted. PDOS of a shell of a species willbe plotted in a subplot.
- - “species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot.
+ - "species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot.
+ pdos_atom_indices: A list of atom indices, only used if pdos_mode is "atoms".
dos_edelta_ev: Step size in writing Density of States (DOS) in eV.
dos_sigma: Width of the Gaussian factor when obtaining smeared Density of States (DOS) in eV.
- dos_scale: Defines the energy range of DOS output as (emax-emin)*(1+dos_scale), centered at (emax+emin)/2.
- This parameter will be used when dos_emin_ev and dos_emax_ev are not set.
- dos_emin_ev: Minimal range for Density of States (DOS) in eV.
- dos_emax_ev: Maximal range for Density of States (DOS) in eV.
- dos_nche: The order of Chebyshev expansions when using Stochastic Density Functional Theory (SDFT) to calculate DOS.
-
+ dos_emin_ev: Minimal range for Density of States (DOS) in eV. Default is -10.0.
+ dos_emax_ev: Maximal range for Density of States (DOS) in eV. Default is 10.0.
+
Returns:
Dict[str, Any]: A dictionary containing:
- dos_fig_path: Path to the plotted DOS.
- pdos_fig_path: Path to the plotted PDOS. Only for LCAO basis.
+ - dos_data_path: Path to the data used in plotting DOS.
+ - pdos_data_paths: Path to the data used in plotting PDOS. Only for LCAO basis.
- scf_normal_end: If the SCF calculation ended normally.
- scf_converge: If the SCF calculation converged.
- scf_energy: The calculated energy of SCF calculation.
@@ -474,15 +474,16 @@ def abacus_dos_run(
dos_results = _abacus_dos_run(abacus_inputs_dir,
pdos_mode,
+ pdos_atom_indices,
dos_edelta_ev,
dos_sigma,
- dos_scale,
dos_emin_ev,
- dos_emax_ev,
- dos_nche)
+ dos_emax_ev)
return {'dos_fig_path': dos_results.get('dos_fig_path', None),
'pdos_fig_path': dos_results.get('pdos_fig_path', None),
+ 'dos_data_path': dos_results.get('dos_data_path', None),
+ 'pdos_data_paths': dos_results.get('pdos_data_paths', None),
'scf_normal_end': dos_results.get('scf_normal_end', None),
'scf_converge': dos_results.get('scf_converge', None),
'scf_energy': dos_results.get('scf_energy', None),
diff --git a/src/abacusagent/modules/util/comm.py b/src/abacusagent/modules/util/comm.py
index 5645f7b..9935b94 100644
--- a/src/abacusagent/modules/util/comm.py
+++ b/src/abacusagent/modules/util/comm.py
@@ -7,6 +7,7 @@
import json
import traceback
import uuid
+import shutil
import glob
from abacustest.lib_prepare.abacus import ReadInput
@@ -233,7 +234,7 @@ def link_abacusjob(src: str,
exclude (Optional[List[str]]): List of files to exclude. If None, no files are excluded.
copy_files (List[str]): List of files to copy from src to dst. Default is ["INPUT", "STRU", "KPT"].
overwrite (bool): If True, existing files in the destination will be overwritten. Default is True.
- exclude_directories (bool): If True, directories will be excluded from linking. Default is False.
+ exclude_directories (bool): If True, directories will be excluded from linking or copying. Default is False.
Notes:
- If somes files are included in both include and exclude, the file will be excluded.
@@ -250,42 +251,45 @@ def link_abacusjob(src: str,
if include is None:
include = ["*"]
- include_files = []
+ include_paths = []
for pattern in include:
- include_files.extend(src.glob(pattern))
+ include_paths.extend(src.glob(pattern))
if exclude is None:
exclude = []
- exclude_files = []
+ exclude_paths = []
for pattern in exclude:
- exclude_files.extend(src.glob(pattern))
+ exclude_paths.extend(src.glob(pattern))
os.makedirs(dst, exist_ok=True)
# Remove excluded files from included files
- include_files = [f for f in include_files if f not in exclude_files]
- if not include_files:
+ include_paths = [f for f in include_paths if f not in exclude_paths]
+ if not include_paths:
traceback.print_stack()
print("No files to link after applying include and exclude patterns.\n",
f"Include patterns: {include}, Exclude patterns: {exclude}, Source: {src}, Destination: {dst}\n",
f"Files in source: {list(src.glob('*'))}"
)
else:
- for file in include_files:
- if file == dst:
+ for path in include_paths:
+ if path == dst:
continue
- if exclude_directories and os.path.isdir(file):
+ if exclude_directories and path.is_dir():
continue
-
- dst_file = dst / file.name
- if dst_file.exists():
+
+ dst_path = dst / path.name
+ if dst_path.exists():
if overwrite:
- dst_file.unlink()
+ dst_path.unlink()
else:
continue
- if str(file.name) in copy_files:
- os.system(f"cp {file} {dst_file}")
+ if str(path.name) in copy_files or str(path) in copy_files:
+ if os.path.isfile(path):
+ shutil.copy(path, dst_path)
+ else:
+ shutil.copytree(path, dst_path)
else:
- os.symlink(file, dst_file)
+ os.symlink(path, dst_path)
def generate_work_path(create: bool = True) -> str:
"""
diff --git a/tests/integrate_test/data/ref_results.json b/tests/integrate_test/data/ref_results.json
index bd9879f..a7a5ec0 100644
--- a/tests/integrate_test/data/ref_results.json
+++ b/tests/integrate_test/data/ref_results.json
@@ -122,6 +122,12 @@
"scf_energy": -1688.5908919
}
},
+ "test_abacus_dos_run_atoms":
+ {
+ "result": {
+ "scf_energy": -1688.5908919
+ }
+ },
"test_abacus_dos_run_species_nspin2":
{
"result": {
@@ -140,6 +146,12 @@
"scf_energy": -3220.427369911483
}
},
+ "test_abacus_dos_run_atoms_nspin2":
+ {
+ "result": {
+ "scf_energy": -3220.427369911483
+ }
+ },
"test_abacus_dos_run_pw_nspin1":
{
"result": {
diff --git a/tests/integrate_test/test_dos.py b/tests/integrate_test/test_dos.py
index 3731f08..e7b4902 100644
--- a/tests/integrate_test/test_dos.py
+++ b/tests/integrate_test/test_dos.py
@@ -46,15 +46,19 @@ def test_abacus_dos_run_species(self):
pdos_mode='species',
dos_edelta_ev = 0.01,
dos_sigma = 0.07,
- dos_scale = 0.01,
- dos_emin_ev = -20,
- dos_emax_ev = 20)
+ dos_emin_ev=-20,
+ dos_emax_ev=20)
+ print(outputs)
dos_fig_path = outputs['dos_fig_path']
+ dos_data_path = outputs['dos_data_path']
pdos_fig_path = outputs['pdos_fig_path']
+ pdos_data_path = outputs['pdos_data_path']
self.assertIsInstance(dos_fig_path, get_path_type())
+ self.assertIsInstance(dos_data_path, get_path_type())
self.assertIsInstance(pdos_fig_path, get_path_type())
+ self.assertIsInstance(pdos_data_path, get_path_type())
self.assertTrue(outputs['scf_normal_end'])
self.assertTrue(outputs['scf_converge'])
self.assertTrue(outputs['nscf_normal_end'])
@@ -76,10 +80,14 @@ def test_abacus_dos_run_species_shell(self):
pdos_mode='species+shell')
dos_fig_path = outputs['dos_fig_path']
+ dos_data_path = outputs['dos_data_path']
pdos_fig_path = outputs['pdos_fig_path']
+ pdos_data_path = outputs['pdos_data_path']
self.assertIsInstance(dos_fig_path, get_path_type())
+ self.assertIsInstance(dos_data_path, get_path_type())
self.assertIsInstance(pdos_fig_path, get_path_type())
+ self.assertIsInstance(pdos_data_path, get_path_type())
self.assertTrue(outputs['scf_normal_end'])
self.assertTrue(outputs['scf_converge'])
self.assertTrue(outputs['nscf_normal_end'])
@@ -101,15 +109,50 @@ def test_abacus_dos_run_species_orbital(self):
pdos_mode='species+orbital')
dos_fig_path = outputs['dos_fig_path']
+ dos_data_path = outputs['dos_data_path']
pdos_fig_path = outputs['pdos_fig_path']
+ pdos_data_path = outputs['pdos_data_path']
self.assertIsInstance(dos_fig_path, get_path_type())
+ self.assertIsInstance(dos_data_path, get_path_type())
self.assertIsInstance(pdos_fig_path, get_path_type())
+ self.assertIsInstance(pdos_data_path, get_path_type())
self.assertTrue(outputs['scf_normal_end'])
self.assertTrue(outputs['scf_converge'])
self.assertTrue(outputs['nscf_normal_end'])
self.assertAlmostEqual(outputs['scf_energy'], ref_results['scf_energy'])
+ def test_abacus_dos_run_atoms(self):
+ """
+ Test the abacus_dos_run function with PDOS plotting mode set to different species and shell.
+ """
+ test_func_name = inspect.currentframe().f_code.co_name
+ ref_results = load_test_ref_result(test_func_name)
+
+ test_work_dir = self.test_path / test_func_name
+ shutil.copytree(self.abacus_inputs_dir_nacl_prim, test_work_dir)
+ shutil.copy2(self.stru_dos_nacl_prim, test_work_dir / "STRU")
+ shutil.copy2(self.input_dos_nacl_prim, test_work_dir / "INPUT")
+
+ outputs = abacus_dos_run(test_work_dir,
+ pdos_mode='atoms',
+ pdos_atom_indices = [1, 2])
+
+ dos_fig_path = outputs['dos_fig_path']
+ dos_data_path = outputs['dos_data_path']
+ pdos_fig_path = outputs['pdos_fig_path']
+ pdos_data_path = outputs['pdos_data_path']
+
+ self.assertIsInstance(dos_fig_path, get_path_type())
+ self.assertIsInstance(dos_data_path, get_path_type())
+ self.assertIsInstance(pdos_fig_path, get_path_type())
+ self.assertIsInstance(pdos_data_path, get_path_type())
+ self.assertTrue(outputs['scf_normal_end'])
+ self.assertTrue(outputs['scf_converge'])
+ self.assertTrue(outputs['nscf_normal_end'])
+ self.assertAlmostEqual(outputs['scf_energy'], ref_results['scf_energy'])
+
+
def test_abacus_dos_run_species_nspin2(self):
"""
Test the abacus_dos_run function with nspin=2 case and PDOS plotting mode set to different species.
@@ -125,15 +168,18 @@ def test_abacus_dos_run_species_nspin2(self):
pdos_mode='species',
dos_edelta_ev = 0.01,
dos_sigma = 0.07,
- dos_scale = 0.01,
dos_emin_ev = -20,
dos_emax_ev = 20)
dos_fig_path = outputs['dos_fig_path']
+ dos_data_path = outputs['dos_data_path']
pdos_fig_path = outputs['pdos_fig_path']
+ pdos_data_path = outputs['pdos_data_path']
self.assertIsInstance(dos_fig_path, get_path_type())
+ self.assertIsInstance(dos_data_path, get_path_type())
self.assertIsInstance(pdos_fig_path, get_path_type())
+ self.assertIsInstance(pdos_data_path, get_path_type())
self.assertTrue(outputs['scf_normal_end'])
self.assertTrue(outputs['scf_converge'])
self.assertTrue(outputs['nscf_normal_end'])
@@ -154,10 +200,14 @@ def test_abacus_dos_run_species_shell_nspin2(self):
pdos_mode='species+shell')
dos_fig_path = outputs['dos_fig_path']
+ dos_data_path = outputs['dos_data_path']
pdos_fig_path = outputs['pdos_fig_path']
+ pdos_data_path = outputs['pdos_data_path']
self.assertIsInstance(dos_fig_path, get_path_type())
+ self.assertIsInstance(dos_data_path, get_path_type())
self.assertIsInstance(pdos_fig_path, get_path_type())
+ self.assertIsInstance(pdos_data_path, get_path_type())
self.assertTrue(outputs['scf_normal_end'])
self.assertTrue(outputs['scf_converge'])
self.assertTrue(outputs['nscf_normal_end'])
@@ -178,10 +228,47 @@ def test_abacus_dos_run_species_orbital_nspin2(self):
pdos_mode='species+orbital')
dos_fig_path = outputs['dos_fig_path']
+ dos_data_path = outputs['dos_data_path']
pdos_fig_path = outputs['pdos_fig_path']
+ pdos_data_path = outputs['pdos_data_path']
self.assertIsInstance(dos_fig_path, get_path_type())
+ self.assertIsInstance(dos_data_path, get_path_type())
self.assertIsInstance(pdos_fig_path, get_path_type())
+ self.assertIsInstance(pdos_data_path, get_path_type())
+ self.assertTrue(outputs['scf_normal_end'])
+ self.assertTrue(outputs['scf_converge'])
+ self.assertTrue(outputs['nscf_normal_end'])
+ self.assertAlmostEqual(outputs['scf_energy'], ref_results['scf_energy'])
+
+ def test_abacus_dos_run_atoms_nspin2(self):
+ """
+ Test the abacus_dos_run function with nspin=2 case and PDOS plotting mode set to different species.
+ """
+ test_func_name = inspect.currentframe().f_code.co_name
+ ref_results = load_test_ref_result(test_func_name)
+
+ test_work_dir = self.test_path / test_func_name
+ shutil.copytree(self.abacus_inputs_dir_fe_bcc_prim, test_work_dir)
+ shutil.copy2(self.stru_dos_fe_bcc_prim, test_work_dir / "STRU")
+
+ outputs = abacus_dos_run(test_work_dir,
+ pdos_mode='atoms',
+ pdos_atom_indices = [1],
+ dos_edelta_ev = 0.01,
+ dos_sigma = 0.07,
+ dos_emin_ev = -20,
+ dos_emax_ev = 20)
+
+ dos_fig_path = outputs['dos_fig_path']
+ dos_data_path = outputs['dos_data_path']
+ pdos_fig_path = outputs['pdos_fig_path']
+ pdos_data_path = outputs['pdos_data_path']
+
+ self.assertIsInstance(dos_fig_path, get_path_type())
+ self.assertIsInstance(dos_data_path, get_path_type())
+ self.assertIsInstance(pdos_fig_path, get_path_type())
+ self.assertIsInstance(pdos_data_path, get_path_type())
self.assertTrue(outputs['scf_normal_end'])
self.assertTrue(outputs['scf_converge'])
self.assertTrue(outputs['nscf_normal_end'])
@@ -201,10 +288,13 @@ def test_abacus_dos_run_pw_nspin1(self):
shutil.copy2(self.input_dos_pw_nacl_prim, test_work_dir / "INPUT")
outputs = abacus_dos_run(test_work_dir)
+ print(outputs)
dos_fig_path = outputs['dos_fig_path']
+ dos_data_path = outputs['dos_data_path']
self.assertIsInstance(dos_fig_path, get_path_type())
+ self.assertIsInstance(dos_data_path, get_path_type())
self.assertTrue(outputs['scf_normal_end'])
self.assertTrue(outputs['scf_converge'])
self.assertTrue(outputs['nscf_normal_end'])
@@ -226,8 +316,10 @@ def test_abacus_dos_run_pw_nspin2(self):
print(outputs)
dos_fig_path = outputs['dos_fig_path']
+ dos_data_path = outputs['dos_data_path']
self.assertIsInstance(dos_fig_path, get_path_type())
+ self.assertIsInstance(dos_data_path, get_path_type())
self.assertTrue(outputs['scf_normal_end'])
self.assertTrue(outputs['scf_converge'])
self.assertTrue(outputs['nscf_normal_end'])
diff --git a/tests/integrate_test/test_tool_wrapper.py b/tests/integrate_test/test_tool_wrapper.py
index c8f33cb..0039271 100644
--- a/tests/integrate_test/test_tool_wrapper.py
+++ b/tests/integrate_test/test_tool_wrapper.py
@@ -175,15 +175,20 @@ def test_run_abacus_calculation_dos(self):
fixed_axes=None,
pdos_mode='species+shell',
dos_edelta_ev=0.01,
- dos_sigma=0.07,
- dos_scale=0.01)
+ dos_sigma=0.07)
print(outputs)
dos_fig_path = outputs['dos_fig_path']
pdos_fig_path = outputs['pdos_fig_path']
+ dos_data_path = outputs['dos_data_path']
+ pdos_data_paths = outputs['pdos_data_paths']
self.assertIsInstance(dos_fig_path, get_path_type())
+ self.assertIsInstance(dos_data_path, get_path_type())
self.assertIsInstance(pdos_fig_path, get_path_type())
+ for pdos_data_path in pdos_data_paths:
+ self.assertIsInstance(pdos_data_path, get_path_type())
+
self.assertTrue(outputs['scf_normal_end'])
self.assertTrue(outputs['scf_converge'])
self.assertTrue(outputs['nscf_normal_end'])
diff --git a/tests/test_dos.py b/tests/test_dos.py
index 0d498f7..041ff71 100644
--- a/tests/test_dos.py
+++ b/tests/test_dos.py
@@ -1,42 +1,46 @@
import unittest
-import os, sys, glob
+import os, sys, glob, shutil
from pathlib import Path
+
os.environ["ABACUSAGENT_MODEL"] = "test"
-from abacusagent.modules.submodules.dos import plot_dos_pdos as mkplots
+from abacusagent.modules.submodules.dos import plot_write_dos_pdos as mkplots
+
class TestPlotDos(unittest.TestCase):
def setUp(self):
self.data_dir = Path(__file__).parent / "plot_dos"
-
def tearDown(self):
for pngfile in glob.glob(os.path.join(self.data_dir, "*.png")):
os.remove(pngfile)
-
+
def test_run_dos(self):
"""
Test the run_dos function with a valid input.
"""
# ignore the screen output
- sys.stdout = open(os.devnull, 'w')
+ sys.stdout = open(os.devnull, "w")
# Call the run_dos function
- results_test = mkplots(self.data_dir, self.data_dir, self.data_dir, dpi=20)
- results_ref = [Path(self.data_dir) / "DOS.png",
- Path(self.data_dir) / "PDOS.png"]
-
- self.assertListEqual([Path(p) for p in results_test], results_ref)
+ results_figs, results_datas = mkplots(self.data_dir, self.data_dir, "species", dos_emin_ev=-1, dos_emax_ev=1)
+
+ output_dir = Path(glob.glob("*plot_write_dos_pdos*")[0]).absolute()
+ results_figs_ref = [
+ output_dir / "DOS.png",
+ output_dir / "PDOS.png",
+ ]
+ results_datas_ref = [
+ output_dir / "DOS.dat",
+ output_dir / "PDOS.dat",
+ ]
+
+ self.assertListEqual([Path(p) for p in results_figs], results_figs_ref)
+ self.assertListEqual([Path(p) for p in results_datas], results_datas_ref)
if os.path.exists(self.data_dir / "metrics.json"):
os.remove(self.data_dir / "metrics.json")
-
-
-
-
-
-
-
-
+ for dir in glob.glob("*plot_write_dos_pdos*"):
+ shutil.rmtree(dir)