diff --git a/pyproject.toml b/pyproject.toml index 46659c0..d0c8fa4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ description = "ABACUS agent tools for connecting LLMs to first-principles calcul requires-python = ">=3.11" dependencies = [ - "abacustest>=0.4.51", + "abacustest>=0.4.55", "numpy<2.0", "pymatgen>=2025.5.28", "bohr-agent-sdk==0.1.121", diff --git a/src/abacusagent/modules/dos.py b/src/abacusagent/modules/dos.py index 0555cf5..1b3257d 100644 --- a/src/abacusagent/modules/dos.py +++ b/src/abacusagent/modules/dos.py @@ -1,18 +1,19 @@ from pathlib import Path -from typing import Dict, Any, List, Literal +from typing import Dict, Any, List, Literal, Optional, Tuple from abacusagent.init_mcp import mcp from abacusagent.modules.submodules.dos import abacus_dos_run as _abacus_dos_run +from abacusagent.modules.submodules.dos import plot_write_dos_pdos as _plot_write_dos_pdos + @mcp.tool() def abacus_dos_run( abacus_inputs_dir: Path, - pdos_mode: Literal['species', 'species+shell', 'species+orbital'] = 'species+shell', + pdos_mode: Literal['atoms', 'species', 'species+shell', 'species+orbital'] = 'species+shell', + pdos_atom_indices: Optional[List[int]] = None, dos_edelta_ev: float = 0.01, dos_sigma: float = 0.07, - dos_scale: float = 0.01, - dos_emin_ev: float = None, - dos_emax_ev: float = None, - dos_nche: int = None, + dos_emin_ev: float = -10.0, + dos_emax_ev: float = 10.0, ) -> Dict[str, Any]: """Run the DOS and PDOS calculation. @@ -24,27 +25,58 @@ def abacus_dos_run( Args: abacus_inputs_dir: Path to the ABACUS input files, which contains the INPUT, STRU, KPT, and pseudopotential or orbital files. pdos_mode: Mode of plotted PDOS file. + - "atoms": PDOS of a list of atoms will be plotted. - "species": Total PDOS of any species will be plotted in a picture. - "species+shell": PDOS for any shell (s, p, d, f, g,...) of any species will be plotted. PDOS of a shell of a species willbe plotted in a subplot. - "species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot. + pdos_atom_indices: A list of atom indices, only used if pdos_mode is "atoms". dos_edelta_ev: Step size in writing Density of States (DOS) in eV. dos_sigma: Width of the Gaussian factor when obtaining smeared Density of States (DOS) in eV. - dos_scale: Defines the energy range of DOS output as (emax-emin)*(1+dos_scale), centered at (emax+emin)/2. - This parameter will be used when dos_emin_ev and dos_emax_ev are not set. - dos_emin_ev: Minimal range for Density of States (DOS) in eV. - dos_emax_ev: Maximal range for Density of States (DOS) in eV. - dos_nche: The order of Chebyshev expansions when using Stochastic Density Functional Theory (SDFT) to calculate DOS. - + dos_emin_ev: Minimal range for Density of States (DOS) in eV. Default is -10.0. + dos_emax_ev: Maximal range for Density of States (DOS) in eV. Default is 10.0. + Returns: Dict[str, Any]: A dictionary containing: - dos_fig_path: Path to the plotted DOS. - pdos_fig_path: Path to the plotted PDOS. Only for LCAO basis. + - dos_data_path: Path to the data used in plotting DOS. + - pdos_data_paths: Path to the data used in plotting PDOS. Only for LCAO basis. - scf_work_path: Path to the work directory of SCF calculation. - scf_normal_end: If the SCF calculation ended normally. - scf_steps: Number of steps of SCF iteration. - scf_converge: If the SCF calculation converged. - scf_energy: The calculated energy of SCF calculation. - - nscf_work_path: Path to the work directory of NSCF calculation. - - nscf_normal_end: If the SCF calculation ended normally. + - nscf_work_path: Path to the work directory of NSCF calculation + """ + return _abacus_dos_run(abacus_inputs_dir, pdos_mode, pdos_atom_indices, dos_edelta_ev, dos_sigma, dos_emin_ev, dos_emax_ev) + +def plot_write_dos_pdos( + scf_job_path: Path, + nscf_job_path: Path, + mode: Literal[ + "species", "species+shell", "species+orbital", "atoms" + ] = "species+shell", + pdos_atom_indices: Optional[List[int]] = None, + dos_emin_ev: float = -10.0, + dos_emax_ev: float = 5.0, +) -> Tuple[List[str], List[str]]: + """ + Plot DOS, PDOS and write data used in plotting to files using SCF and NSCF job directories from abacus_dos_run. + + Args: + scf_job_path (Path): Path to the SCF job directory of the DOS calculation + nscf_job_path (Path): Path to the NSCF job directory of the DOS calculation + mode: Mode for plotting PDOS and write PDOS data. + - "atoms": PDOS of a list of atoms will be plotted. + - "species": Total PDOS of any species will be plotted in a picture. + - "species+shell": PDOS for any shell (s, p, d, f, g,...) of any species will be plotted. PDOS of a shell of a species willbe plotted in a subplot. + - "species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot. + pdos_atom_indices: A list of atom indices, only used if pdos_mode is "atoms". + pdos_atom_indices (List[int], optional): List of atom indices for atom-specific PDOS. Only valid for 'atoms' mode. + dos_emin_ev (float): Minimum energy for DOS and PDOS plots. + dos_emax_ev (float): Maximum energy for DOS and PDOS plots. + + Returns: + Tuple[List[str], List[str]]: Tuple containing list of plot file paths and data file paths. """ - return _abacus_dos_run(abacus_inputs_dir, pdos_mode, dos_edelta_ev, dos_sigma, dos_scale, dos_emin_ev, dos_emax_ev, dos_nche) + return _plot_write_dos_pdos(scf_job_path, nscf_job_path, mode, pdos_atom_indices, dos_emin_ev, dos_emax_ev) diff --git a/src/abacusagent/modules/submodules/dos.py b/src/abacusagent/modules/submodules/dos.py index 13c197e..d2f14db 100644 --- a/src/abacusagent/modules/submodules/dos.py +++ b/src/abacusagent/modules/submodules/dos.py @@ -1,81 +1,59 @@ import os -import re -import numpy as np -import matplotlib.pyplot as plt +import glob from abacustest.lib_prepare.abacus import ReadInput, WriteInput from abacustest.lib_collectdata.collectdata import RESULT from abacustest.lib_model.comm import check_abacus_inputs +from abacustest.lib_model.comm_dos import DOSData, PDOSData, l_map, orbital_names from pathlib import Path -from typing import Dict, Any, List, Literal - -from abacusagent.modules.util.comm import generate_work_path, link_abacusjob, run_abacus, has_chgfile, collect_metrics -from abacusagent.modules.util.chemical_elements import MAX_ANGULAR_MOMENTUM_OF_ELEMENTS - - -angular_momentum_map = ['s', 'p', 'd', 'f', 'g'] -color_map = { - 's': '#FF5733', - 'p': '#33FF57', - 'd': '#3357FF', - 'f': '#F033FF', - 'g': '#33FFF0' -} - -orbital_rep_map = { - 's': 's', - 'px': r'$p_x$', - 'py': r'$p_y$', - 'pz': r'$p_z$', - 'dz^2': r'$d_{z^2}$', - 'dxz': r'$d_{xz}$', - 'dyz': r'$d_{yz}$', - 'dxy': r'$d_{xy}$', - 'dx^2-y^2': r'$d_{x^2-y^2}$', - 'fz^3': r'$f_{z^3}$', - 'fxz^2': r'$f_{xz^2}$', - 'fyz^2': r'$f_{yz^2}$', - 'fzx^2-zy^2': r'$f_{zx^2-zy^2}$', - 'fxyz': r'$f_{xyz}$', - 'fx^3-3*xy^2': r'$f_{x^3-3xy^2}$', - 'f3yx^2-y^3': r'$f_{3yx^2-y^3}$' -} +from typing import Dict, Any, List, Literal, Optional, Tuple + +from abacusagent.modules.util.comm import ( + generate_work_path, + link_abacusjob, + run_abacus, + has_chgfile, +) + + def abacus_dos_run( abacus_inputs_dir: Path, - pdos_mode: Literal['species', 'species+shell', 'species+orbital'] = 'species+shell', + pdos_mode: Literal[ + "atoms", "species", "species+shell", "species+orbital" + ] = "species+shell", + pdos_atom_indices: Optional[List[int]] = None, dos_edelta_ev: float = 0.01, dos_sigma: float = 0.07, - dos_scale: float = 0.01, - dos_emin_ev: float = None, - dos_emax_ev: float = None, - dos_nche: int = None, + dos_emin_ev: float = -10.0, + dos_emax_ev: float = 10.0, ) -> Dict[str, Any]: """Run the DOS and PDOS calculation. - - This function will firstly run a SCF calculation with out_chg set to 1, + + This function will firstly run a SCF calculation with out_chg set to 1, then run a NSCF calculation with init_chg set to 'file' and out_dos set to 1 or 2. If the INPUT parameter "basis_type" is "PW", then out_dos will be set to 1, and only DOS will be calculated and plotted. If the INPUT parameter "basis_type" is "LCAO", then out_dos will be set to 2, and both DOS and PDOS will be calculated and plotted. - + Args: abacus_inputs_dir: Path to the ABACUS input files, which contains the INPUT, STRU, KPT, and pseudopotential or orbital files. pdos_mode: Mode of plotted PDOS file. + - "atoms": PDOS of a list of atoms will be plotted. - "species": Total PDOS of any species will be plotted in a picture. - "species+shell": PDOS for any shell (s, p, d, f, g,...) of any species will be plotted. PDOS of a shell of a species willbe plotted in a subplot. - - “species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot. + - "species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot. + pdos_atom_indices: A list of atom indices, only used if pdos_mode is "atoms". dos_edelta_ev: Step size in writing Density of States (DOS) in eV. - dos_sigma: Width of the Gaussian factor when obtaining smeared Density of States (DOS) in eV. - dos_scale: Defines the energy range of DOS output as (emax-emin)*(1+dos_scale), centered at (emax+emin)/2. - This parameter will be used when dos_emin_ev and dos_emax_ev are not set. - dos_emin_ev: Minimal range for Density of States (DOS) in eV. - dos_emax_ev: Maximal range for Density of States (DOS) in eV. - dos_nche: The order of Chebyshev expansions when using Stochastic Density Functional Theory (SDFT) to calculate DOS. - + dos_sigma: Width of the Gaussian factor when obtaining smeared Density of States (DOS) in eV. + dos_emin_ev: Minimal range for Density of States (DOS) in eV. Default is -10.0. + dos_emax_ev: Maximal range for Density of States (DOS) in eV. Default is 10.0. + Returns: Dict[str, Any]: A dictionary containing: - dos_fig_path: Path to the plotted DOS. - pdos_fig_path: Path to the plotted PDOS. Only for LCAO basis. + - dos_data_path: Path to the data used in plotting DOS. + - pdos_data_paths: Path to the data used in plotting PDOS. Only for LCAO basis. - scf_work_path: Path to the work directory of SCF calculation. - scf_normal_end: If the SCF calculation ended normally. - scf_steps: Number of steps of SCF iteration. @@ -88,55 +66,68 @@ def abacus_dos_run( is_valid, msg = check_abacus_inputs(abacus_inputs_dir) if not is_valid: raise RuntimeError(f"Invalid ABACUS input files: {msg}") - + input_file = os.path.join(abacus_inputs_dir, "INPUT") input_params = ReadInput(input_file) nspin = input_params.get("nspin", 1) if nspin in [4]: - raise ValueError("Currently DOS calculation can only be plotted using for nspin=1 and nspin=2") + raise ValueError( + "Currently DOS calculation can only be plotted using for nspin=1 and nspin=2" + ) + print("Performing SCF calculation...") metrics_scf = abacus_dos_run_scf(abacus_inputs_dir) - metrics_nscf = abacus_dos_run_nscf(metrics_scf["scf_work_path"], - dos_edelta_ev=dos_edelta_ev, - dos_sigma=dos_sigma, - dos_scale=dos_scale, - dos_emin_ev=dos_emin_ev, - dos_emax_ev=dos_emax_ev, - dos_nche=dos_nche) - - fig_paths = plot_dos_pdos(metrics_scf["scf_work_path"], - metrics_nscf["nscf_work_path"], - metrics_nscf["nscf_work_path"], - nspin, - pdos_mode) + + print("Performing NSCF calculation...") + metrics_nscf = abacus_dos_run_nscf( + metrics_scf["scf_work_path"], + dos_edelta_ev=dos_edelta_ev, + dos_sigma=dos_sigma, + ) + + fig_paths, dos_pdos_data_paths = plot_write_dos_pdos( + metrics_scf["scf_work_path"], + metrics_nscf["nscf_work_path"], + pdos_mode, + pdos_atom_indices, + dos_emin_ev, + dos_emax_ev, + ) return_dict = {"dos_fig_path": fig_paths[0]} + return_dict["dos_data_path"] = dos_pdos_data_paths[0] try: - return_dict['pdos_fig_path'] = fig_paths[1] + return_dict["pdos_fig_path"] = fig_paths[1] + return_dict["pdos_data_path"] = dos_pdos_data_paths[1] except: - pass # Do nothing if PDOS file is not plotted + pass # Do nothing if PDOS file is not plotted return_dict.update(metrics_scf) return_dict.update(metrics_nscf) return return_dict except Exception as e: + import traceback + + traceback.print_exc() return {"message": f"Calculating DOS and PDOS failed: {e}"} -def abacus_dos_run_scf(abacus_inputs_dir: Path, - force_run: bool = False) -> Dict[str, Any]: + +def abacus_dos_run_scf( + abacus_inputs_dir: Path, force_run: bool = False +) -> Dict[str, Any]: """ Run the SCF calculation to generate the charge density file. If the charge file already exists, it will skip the SCF calculation. - + Args: abacus_inputs_dir: Path to the ABACUS input files, which contains the INPUT, STRU, KPT, and pseudopotential or orbital files. force_run: If True, it will run the SCF calculation even if the charge file already exists. - + Returns: Dict[str, Any]: A dictionary containing the work path, normal end status, SCF steps, convergence status, and energies. """ - + input_param = ReadInput(os.path.join(abacus_inputs_dir, "INPUT")) # check if charge file has been generated if has_chgfile(abacus_inputs_dir) and not force_run: @@ -144,9 +135,7 @@ def abacus_dos_run_scf(abacus_inputs_dir: Path, work_path = abacus_inputs_dir else: work_path = generate_work_path() - link_abacusjob(src=abacus_inputs_dir, - dst=work_path, - copy_files=["INPUT"]) + link_abacusjob(src=abacus_inputs_dir, dst=work_path, copy_files=["INPUT"]) input_param = ReadInput(os.path.join(work_path, "INPUT")) input_param["calculation"] = "scf" @@ -156,432 +145,150 @@ def abacus_dos_run_scf(abacus_inputs_dir: Path, run_abacus(work_path) rs = RESULT(path=work_path, fmt="abacus") - + return { "scf_work_path": Path(work_path).absolute(), "scf_normal_end": rs["normal_end"], "scf_steps": rs["scf_steps"], "scf_converge": rs["converge"], - "scf_energy": rs["energy"] + "scf_energy": rs["energy"], } -def abacus_dos_run_nscf(abacus_inputs_dir: Path, - dos_edelta_ev: float = None, - dos_sigma: float = None, - dos_scale: float = None, - dos_emin_ev: float = None, - dos_emax_ev: float = None, - dos_nche: int = None,) -> Dict[str, Any]: - + +def abacus_dos_run_nscf( + abacus_inputs_dir: Path, dos_edelta_ev: float = None, dos_sigma: float = None +) -> Dict[str, Any]: work_path = generate_work_path() - link_abacusjob(src=abacus_inputs_dir, - dst=work_path, - copy_files=["INPUT"]) - + link_abacusjob( + src=abacus_inputs_dir, + dst=work_path, + copy_files=["INPUT", "KPT"] + + glob.glob(os.path.join(abacus_inputs_dir, "OUT.*")), + exclude=["*log", "*json"], + ) + input_param = ReadInput(os.path.join(work_path, "INPUT")) input_param["calculation"] = "nscf" input_param["init_chg"] = "file" + input_param["out_chg"] = -1 if input_param.get("basis_type", "pw") == "lcao": - input_param["out_dos"] = 2 # only for LCAO basis, and will output DOS and PDOS + input_param["out_dos"] = 2 # only for LCAO basis, and will output DOS and PDOS else: input_param["out_dos"] = 1 - + for dos_param, value in { "dos_edelta_ev": dos_edelta_ev, "dos_sigma": dos_sigma, - "dos_scale": dos_scale, - "dos_emin_ev": dos_emin_ev, - "dos_emax_ev": dos_emax_ev, - "dos_nche": dos_nche }.items(): if value is not None: input_param[dos_param] = value - - + WriteInput(input_param, os.path.join(work_path, "INPUT")) - + run_abacus(work_path) - + rs = RESULT(path=work_path, fmt="abacus") - + return { "nscf_work_path": Path(work_path).absolute(), - "nscf_normal_end": rs["normal_end"] + "nscf_normal_end": rs["normal_end"], } -def parse_pdos_file(file_path): - """Parse the PDOS file and extract energy values and orbital data.""" - with open(file_path, 'r') as f: - content = f.read() - - energy_match = re.search(r'(.*?)', content, re.DOTALL) - if not energy_match: - raise ValueError("Energy values not found in the file.") - - energy_text = energy_match.group(1) - energy_values = np.array([float(line.strip()) for line in energy_text.strip().split()]) - - orbital_pattern = re.compile(r'(.*?)', re.DOTALL) - orbitals = [] - - for match in orbital_pattern.finditer(content): - index, atom_index, species, l, m, z, orbital_content = match.groups() - - data_match = re.search(r'(.*?)', orbital_content, re.DOTALL) - if data_match: - data_text = data_match.group(1) - data_values = np.array([float(line.strip()) for line in data_text.strip().split()]) - - orbitals.append({ - 'index': int(index), - 'atom_index': int(atom_index), - 'species': species, - 'l': int(l), - 'm': int(m), - 'z': int(z), - 'data': data_values - }) - - return energy_values, orbitals - -def parse_log_file(file_path): - """Parse Fermi energy from log file and convert to eV.""" - ry_to_ev = 13.605698066 - fermi_energy = None - - with open(file_path, 'r') as f: - for line in f: - if "Fermi energy is" in line: - match = re.search(r'Fermi energy is\s*([\d.-]+)', line) - if match: - fermi_energy = float(match.group(1)) - - if fermi_energy is None: - raise ValueError("Fermi energy not found in log file") - - return fermi_energy * ry_to_ev - -def parse_basref_file(file_path): - """Parse basref file to create mapping for custom labels.""" - label_map = {} - - with open(file_path, 'r') as f: - for line in f: - line = line.strip() - if not line or line.startswith('#'): - continue - - parts = line.split() - if len(parts) >= 6: - # Add 1 to atom_index as per requirement - atom_index = int(parts[0]) + 1 - species = parts[1] - l = int(parts[2]) - m = int(parts[3]) - z = int(parts[4]) - symbol = parts[5] - - key = (atom_index, species, l, m, z) - label_map[key] = f'{species}{atom_index}({symbol})' - - return label_map +def plot_write_dos_pdos( + scf_job_path: Path, + nscf_job_path: Path, + mode: Literal[ + "species", "species+shell", "species+orbital", "atoms" + ] = "species+shell", + pdos_atom_indices: Optional[List[int]] = None, + dos_emin_ev: float = -10.0, + dos_emax_ev: float = 5.0, +) -> Tuple[List[str], List[str]]: + """ + Plot DOS, PDOS and write data used in plotting to files using SCF and NSCF job directories from abacus_dos_run. -def plot_pdos(energy_values, orbitals, fermi_level, label_map, output_dir, nspin, mode, dpi=300): - """Plot PDOS data separated by atom/species with custom labels.""" - # Create output directory if it doesn't exist - os.makedirs(output_dir, exist_ok=True) - - # Shift energy values by Fermi level - shifted_energy = energy_values - fermi_level - - # Group orbitals by atom_index and species - atom_species_groups = {} - for orbital in orbitals: - key = (orbital['atom_index'], orbital['species']) - if key not in atom_species_groups: - atom_species_groups[key] = [] - atom_species_groups[key].append(orbital) - - if mode == "species": - pdos_pic_file = plot_pdos_species(shifted_energy, orbitals, output_dir, nspin, dpi) - elif mode == "species+shell": - pdos_pic_file = plot_pdos_species_shell(shifted_energy, orbitals, output_dir, nspin, dpi) - elif mode == "species+orbital": - pdos_pic_file = plot_pdos_species_orbital(shifted_energy, orbitals, output_dir, nspin, label_map, dpi) - else: - raise ValueError(f"Not allowed mode {mode}") - - return pdos_pic_file - -def plot_pdos_species(shifted_energy, orbitals, output_dir, nspin, dpi): - species = {} - for orbital in orbitals: - species_one = orbital['species'] - if species_one not in species.keys(): - species[species_one] = orbital['data'] - else: - species[species_one] += orbital['data'] - - num_species = len(species) - plt.plot(figsize=(10, 6)) - for species_name, pdos_data in species.items(): - if nspin == 1: - plt.plot(shifted_energy, pdos_data, label=species_name, linewidth=1.0) - elif nspin == 2: - plt.plot(shifted_energy, pdos_data[::2], label=f'{species_name} ' + r'$\uparrow$', linestyle='-', linewidth=1.0) - plt.plot(shifted_energy, -pdos_data[1::2], label=f'{species_name} ' + r'$\downarrow$', linestyle='--', linewidth=1.0) - - plt.axvline(x=0, color='black', linestyle=':', linewidth=1.0) - plt.xlabel('Energy (eV)', fontsize=10) - plt.ylabel(r"States ($eV^{-1}$)", fontsize=10) - plt.xlim(max(min(shifted_energy), -20), min(20, max(shifted_energy))) - if nspin == 1: - plt.ylim(bottom=0) - plt.legend(fontsize=8, ncol=nspin) - plt.grid(alpha=0.3) - plt.title('Projected density of States of different species') - - pdos_pic_file = os.path.join(output_dir, 'PDOS.png') - plt.savefig(pdos_pic_file, dpi=dpi) - plt.close() - - return Path(pdos_pic_file).absolute() - -def plot_pdos_species_shell(shifted_energy, orbitals, output_dir, nspin, dpi): - species_shells = {} - for orbital in orbitals: - species = orbital['species'] - if species not in species_shells.keys(): - species_shells[species] = {} # Initialize species kind - - angular_momentum = angular_momentum_map[orbital['l']] - # The orbital with higher angular momentum than in realistic atoms will be ignored. - if angular_momentum_map.index(angular_momentum) <= angular_momentum_map.index(MAX_ANGULAR_MOMENTUM_OF_ELEMENTS[orbital['species']]): - if angular_momentum not in species_shells[species].keys(): - species_shells[species][angular_momentum] = orbital['data'] # Initialize DOS for angular momentum of a species - else: - species_shells[species][angular_momentum] += orbital['data'] # Add DOS of a angular momentum of a species - - # Plot PDOS for each species and each shell - num_species = len(species_shells) - fig, axes = plt.subplots(nrows=num_species, ncols=1, figsize=(8, 4*num_species)) - if num_species == 1: - axes = [axes] - - for species_idx, (species, pdos_data_dict) in enumerate(species_shells.items()): - ax = axes[species_idx] - - for l, pdos_data in pdos_data_dict.items(): - if nspin == 1: - ax.plot(shifted_energy, pdos_data, color=color_map[l], label=f'{species}-{l}', linewidth=1.0) - elif nspin == 2: - ax.plot(shifted_energy, pdos_data[::2], color=color_map[l], label=f'{species}-{l}' + r' $\uparrow$', linestyle='-', linewidth=1.0) - ax.plot(shifted_energy, -pdos_data[1::2], color=color_map[l], label=f'{species}-{l}' + r' $\downarrow$', linestyle='--', linewidth=1.0) - - ax.axvline(x=0, color='black', linestyle=':', linewidth=1.0) - ax.set_title(f'PDOS for {species}', fontsize=12, pad=10) - ax.set_ylabel(r"States ($eV^{-1}$)", fontsize=10) - ax.set_xlim(max(min(shifted_energy), -20), min(20, max(shifted_energy))) - #if nspin == 1: - # ax.set_ylim(bottom=0) - ax.legend(fontsize=8, ncol=nspin) - ax.grid(alpha=0.3) - - #ax.set_ylim(bottom=0) - - axes[-1].set_xlabel('Energy (eV)', fontsize=10) - - plt.tight_layout() - pdos_pic_file = os.path.join(output_dir, 'PDOS.png') - plt.savefig(pdos_pic_file, dpi=dpi, bbox_inches='tight') - plt.close() - - return Path(pdos_pic_file).absolute() - -def plot_pdos_species_orbital(shifted_energy, orbitals, output_dir, nspin, label_map, dpi): - - plt.rcParams["text.usetex"] = False - plt.rcParams["axes.prop_cycle"] = plt.cycler("color", plt.cm.tab20.colors) - - orbital_label = {} - for (atom_index, species, l, m, z), full_label in label_map.items(): - if species not in orbital_label.keys(): - orbital_label[species] = {} - if str(l) not in orbital_label[species].keys(): - orbital_label[species][str(l)] = {} - if str(m) not in orbital_label[species][str(l)].keys(): - orbital_name = full_label.split('(')[1].split(')')[0] - if orbital_name in orbital_rep_map.keys(): - orbital_label[species][str(l)][str(m)] = orbital_rep_map[orbital_name] - else: - orbital_label[species][str(l)][str(m)] = orbital_name - else: - pass - - species_orbitals = {} - for orbital in orbitals: - species = orbital['species'] - if species not in species_orbitals.keys(): - species_orbitals[species] = {} - - angular_momentum = angular_momentum_map[orbital['l']] - # The orbital with higher angular momentum than in realistic atoms will be ignored. - if angular_momentum_map.index(angular_momentum) <= angular_momentum_map.index(MAX_ANGULAR_MOMENTUM_OF_ELEMENTS[orbital['species']]): - if angular_momentum not in species_orbitals[species].keys(): - species_orbitals[species][angular_momentum] = {} - - mag_quantum_num = orbital['m'] - if mag_quantum_num not in species_orbitals[species][angular_momentum].keys(): - species_orbitals[species][angular_momentum][mag_quantum_num] = orbital['data'] - else: - species_orbitals[species][angular_momentum][mag_quantum_num] += orbital['data'] - - total_subplots = 0 - for species, species_pdos in species_orbitals.items(): - total_subplots += len(species_pdos) - fig, axes = plt.subplots(nrows=total_subplots, ncols=1, figsize=(8, 4*total_subplots)) - - subplot_count = 0 - for species, species_pdos in species_orbitals.items(): - for angular_momentum, species_shell_pdos in species_pdos.items(): - for m, species_orbital_pdos in species_shell_pdos.items(): - ax = axes[subplot_count] - orbital_name = orbital_label[species][str(angular_momentum_map.index(angular_momentum))][str(m)] - if nspin == 1: - ax.plot(shifted_energy, species_orbital_pdos, label=f'{orbital_name}', linewidth=1.0) - elif nspin == 2: - ax.plot(shifted_energy, species_orbital_pdos[::2], label=f'{orbital_name} '+r'$\uparrow$', linestyle='-', linewidth=1.0) - ax.plot(shifted_energy, -species_orbital_pdos[1::2], label=f'{orbital_name} '+r'$\downarrow$', linestyle='--', linewidth=1.0) - - ax.axvline(x=0, color='black', linestyle=':', linewidth=1.0) - ax.set_title(f'PDOS for {species}-{angular_momentum}', fontsize=12, pad=10) - ax.set_xlim(max(min(shifted_energy), -20), min(20, max(shifted_energy))) - if nspin == 1: - ax.set_ylim(bottom=0) - ax.set_ylabel(r"States ($eV^{-1}$)", fontsize=10) - ax.legend(fontsize=8, ncol=nspin) - ax.grid(alpha=0.3) - - subplot_count += 1 - - axes[-1].set_xlabel('Energy (eV)', fontsize=10) - - plt.tight_layout() - pdos_pic_file = os.path.join(output_dir, 'PDOS.png') - plt.savefig(pdos_pic_file, dpi=dpi, bbox_inches='tight') - plt.close() - - return Path(pdos_pic_file).absolute() - -def plot_dos(file_path: List[Path], - fermi_level: float, - output_file: str = 'DOS.png', - nspin: Literal[1, 2] = 1, - dpi: int=300): - """Plot total DOS from DOS1_smearing.dat and DOS2_smearing (if nspin=2) file.""" - # Read first two columns from file - data = np.loadtxt(file_path[0], usecols=(0, 1)) - energy = data[:, 0] - fermi_level # Shift by Fermi level - dos = data[:, 1] - if nspin == 2: - data = np.loadtxt(file_path[1], usecols=(0, 1)) - dos_dn = data[:, 1] - - # Determine energy limits based on data within x range - x_min, x_max = max(min(energy), -20), min(20, max(energy)) - - # Create plot - plt.figure(figsize=(8, 6)) - if nspin == 1: - plt.plot(energy, dos, linestyle='-') - elif nspin == 2: - plt.plot(energy, dos, linestyle='-', label='spin up') - plt.plot(energy, -dos_dn, linestyle='--', label='spin down') - plt.axvline(x=0, color='k', linestyle='--', alpha=0.5) - plt.xlabel('Energy (eV)') - plt.ylabel(r'States ($eV^{-1}$)') - plt.title('Density of States') - plt.grid(True, alpha=0.3) - plt.xlim(x_min, x_max) - #plt.ylim(y_min, y_max) - #plt.legend() - - # Save plot - os.makedirs(os.path.dirname(output_file), exist_ok=True) - plt.savefig(output_file, dpi=dpi, bbox_inches='tight') - plt.close() - - return Path(output_file).absolute() - -def plot_dos_pdos(scf_job_path: Path, - nscf_job_path: Path, - output_dir: Path, - nspin: Literal[1, 2] = 1, - mode: Literal['species', 'species+shell', 'species+orbital'] = 'species+shell', - dpi=300) -> List[str]: - """Plot DOS and PDOS from the NSCF job path. - Args: - nscf_job_path (Path): Path to the NSCF job directory containing the OUT.* files. - output_dir (Path): Directory where the output plots will be saved. - dpi (int): Dots per inch for the saved plots. - - Returns: - List[str]: List of paths to the generated plot files. + scf_job_path (Path): Path to the SCF job directory of the DOS calculation + nscf_job_path (Path): Path to the NSCF job directory of the DOS calculation + mode: Mode for plotting PDOS and write PDOS data. + - "atoms": PDOS of a list of atoms will be plotted. + - "species": Total PDOS of any species will be plotted in a picture. + - "species+shell": PDOS for any shell (s, p, d, f, g,...) of any species will be plotted. PDOS of a shell of a species willbe plotted in a subplot. + - "species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot. + pdos_atom_indices: A list of atom indices, only used if pdos_mode is "atoms". + pdos_atom_indices (List[int], optional): List of atom indices for atom-specific PDOS. Only valid for 'atoms' mode. + dos_emin_ev (float): Minimum energy for DOS and PDOS plots. + dos_emax_ev (float): Maximum energy for DOS and PDOS plots. """ + work_path = generate_work_path() - """ input_param = ReadInput(os.path.join(nscf_job_path, "INPUT")) - input_dir = os.path.join(nscf_job_path, "OUT." + input_param.get("suffix","ABACUS")) - basis_type = input_param.get('basis_type', 'pw') - - # Construct file paths based on input directory - pdos_file = os.path.join(input_dir, "PDOS") - log_file = os.path.join(input_dir, "running_nscf.log") - basref_file = os.path.join(input_dir, "Orbital") - dos_file = [os.path.join(input_dir, "DOS1_smearing.dat")] - dos_output = os.path.join(output_dir, "DOS.png") - if nspin == 2: - dos_file += [os.path.join(input_dir, "DOS2_smearing.dat")] - - # Validate input files exist - for file_path in [log_file, dos_file[0]]: - if not os.path.exists(file_path): - print(f"Error: File not found - {file_path}") - raise FileNotFoundError(f"Required file not found: {file_path}") - if nspin == 2: - if not os.path.exists(dos_file[1]): - print(f"Error: File not found - {dos_file[1]}") - raise FileNotFoundError(f"Required file not found: {dos_file[1]}") + basis_type = input_param.get("basis_type", "pw") + results = RESULT(fmt="abacus", path=scf_job_path) + efermi = results['efermi'] + + # Construct file paths + dos_plot_file = os.path.join(work_path, "DOS.png") + dos_data_file = os.path.join(work_path, "DOS.dat") + + dosdata = DOSData.ReadFromAbacusJob(str(nscf_job_path), efermi) + dosdata.plot_dos( + dos_emin_ev, + dos_emax_ev, + "Density of States", + dos_plot_file, + ) + dosdata.write_dos(dos_data_file) + + all_plot_files = [Path(dos_plot_file).absolute()] + dos_pdos_data_files = [Path(dos_data_file).absolute()] - fermi_level = collect_metrics(scf_job_path, ['efermi'])['efermi'] - - # Plot DOS and get file path - dos_plot_file = plot_dos(dos_file, fermi_level, dos_output, nspin, dpi) - all_plot_files = [dos_plot_file] - print("DOS file plotted") - # Plot PDOS (only for LCAO basis - PW basis doesn't support PDOS in ABACUS LTSv3.10) - if os.path.exists(pdos_file) and os.path.exists(basref_file): - if basis_type != 'pw': - label_map = parse_basref_file(basref_file) - energy_values, orbitals = parse_pdos_file(pdos_file) - pdos_plot_file = plot_pdos(energy_values, orbitals, fermi_level, label_map, output_dir, nspin, mode, dpi) - + # Plot PDOS using PDOSData class (only for LCAO basis) + if basis_type != "pw": + try: + # Load PDOS data using PDOSData class + pdos_data = PDOSData.ReadFromAbacusJob(str(nscf_job_path), efermi) + pdos_plot_file = Path(os.path.join(work_path, "PDOS.png")).absolute() + pdos_data_file = Path(os.path.join(work_path, "PDOS.dat")).absolute() + + # Plot PDOS based on mode + if mode == "species": + pdos_data.plot_species_pdos(dos_emin_ev, dos_emax_ev, pdos_plot_file) + pdos_data.write_species_pdos(pdos_data_file) + elif mode == "species+shell": + pdos_data.plot_species_shell_pdos(dos_emin_ev, dos_emax_ev, pdos_plot_file) + pdos_data.write_species_shell_pdos(pdos_data_file) + elif mode == "species+orbital": + pdos_data.plot_species_orbital_pdos(dos_emin_ev, dos_emax_ev, pdos_plot_file) + pdos_data.write_species_orbital_pdos(pdos_data_file) + elif mode == "atoms": + if pdos_atom_indices is None or len(pdos_atom_indices) == 0: + raise ValueError( + "For 'atoms' mode, pdos_atom_indices must be provided" + ) + pdos_data.plot_atoms_pdos(pdos_atom_indices, dos_emin_ev, dos_emax_ev, pdos_plot_file) + pdos_data.write_atoms_pdos(pdos_atom_indices, pdos_data_file) + else: + raise ValueError(f"Unsupported mode: {mode}") + # Combine file paths into a single list - all_plot_files.append(pdos_plot_file) - else: - print(f"Warning: PDOS calculation not supported for PW basis type, skipping PDOS plotting") - elif os.path.exists(pdos_file) and not os.path.exists(basref_file): - print(f"Warning: PDOS file exists but Orbital file not found - {basref_file}, skipping PDOS plotting") - elif not os.path.exists(pdos_file) and os.path.exists(basref_file): - print(f"Warning: Orbital file exists but PDOS file not found - {pdos_file}, skipping PDOS plotting") + all_plot_files.append(Path(pdos_plot_file).absolute()) + dos_pdos_data_files.append(Path(pdos_data_file).absolute()) + + except Exception as e: + import traceback + traceback.print_exc() + print(f"Warning: Failed to plot PDOS: {e}") + print("Skipping PDOS plotting") else: - print("Warning: Both PDOS and Orbital files not found, skipping PDOS plotting") + print( + f"Warning: PDOS calculation not supported for PW basis type, skipping PDOS plotting" + ) print("Plots generated:") for file in all_plot_files: print(f"- {file}") - - return all_plot_files + + return all_plot_files, dos_pdos_data_files diff --git a/src/abacusagent/modules/tool_wrapper.py b/src/abacusagent/modules/tool_wrapper.py index 3195121..b9adbf4 100644 --- a/src/abacusagent/modules/tool_wrapper.py +++ b/src/abacusagent/modules/tool_wrapper.py @@ -382,13 +382,12 @@ def abacus_dos_run( relax_precision: Literal['low', 'medium', 'high'] = 'medium', relax_method: Literal["cg", "bfgs", "bfgs_trad", "cg_bfgs", "sd", "fire"] = "cg", fixed_axes: Literal["None", "volume", "shape", "a", "b", "c", "ab", "ac", "bc"] = None, - pdos_mode: Literal['species', 'species+shell', 'species+orbital'] = 'species+shell', + pdos_mode: Literal['atoms', 'species', 'species+shell', 'species+orbital'] = 'species+shell', + pdos_atom_indices: Optional[List[int]] = None, dos_edelta_ev: float = 0.01, dos_sigma: float = 0.07, - dos_scale: float = 0.01, - dos_emin_ev: float = None, - dos_emax_ev: float = None, - dos_nche: int = None, + dos_emin_ev: float = -10.0, + dos_emax_ev: float = 10.0, ) -> Dict[str, Any]: """ Run the DOS and PDOS calculation. @@ -432,21 +431,22 @@ def abacus_dos_run( - ac: fix both a and c axes - bc: fix both b and c axes pdos_mode: Mode of plotted PDOS file. + - "atoms": PDOS of a list of atoms will be plotted. - "species": Total PDOS of any species will be plotted in a picture. - "species+shell": PDOS for any shell (s, p, d, f, g,...) of any species will be plotted. PDOS of a shell of a species willbe plotted in a subplot. - - “species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot. + - "species+orbital": Orbital-resolved PDOS will be plotted. PDOS of orbitals in the same shell of a species will be plotted in a subplot. + pdos_atom_indices: A list of atom indices, only used if pdos_mode is "atoms". dos_edelta_ev: Step size in writing Density of States (DOS) in eV. dos_sigma: Width of the Gaussian factor when obtaining smeared Density of States (DOS) in eV. - dos_scale: Defines the energy range of DOS output as (emax-emin)*(1+dos_scale), centered at (emax+emin)/2. - This parameter will be used when dos_emin_ev and dos_emax_ev are not set. - dos_emin_ev: Minimal range for Density of States (DOS) in eV. - dos_emax_ev: Maximal range for Density of States (DOS) in eV. - dos_nche: The order of Chebyshev expansions when using Stochastic Density Functional Theory (SDFT) to calculate DOS. - + dos_emin_ev: Minimal range for Density of States (DOS) in eV. Default is -10.0. + dos_emax_ev: Maximal range for Density of States (DOS) in eV. Default is 10.0. + Returns: Dict[str, Any]: A dictionary containing: - dos_fig_path: Path to the plotted DOS. - pdos_fig_path: Path to the plotted PDOS. Only for LCAO basis. + - dos_data_path: Path to the data used in plotting DOS. + - pdos_data_paths: Path to the data used in plotting PDOS. Only for LCAO basis. - scf_normal_end: If the SCF calculation ended normally. - scf_converge: If the SCF calculation converged. - scf_energy: The calculated energy of SCF calculation. @@ -474,15 +474,16 @@ def abacus_dos_run( dos_results = _abacus_dos_run(abacus_inputs_dir, pdos_mode, + pdos_atom_indices, dos_edelta_ev, dos_sigma, - dos_scale, dos_emin_ev, - dos_emax_ev, - dos_nche) + dos_emax_ev) return {'dos_fig_path': dos_results.get('dos_fig_path', None), 'pdos_fig_path': dos_results.get('pdos_fig_path', None), + 'dos_data_path': dos_results.get('dos_data_path', None), + 'pdos_data_paths': dos_results.get('pdos_data_paths', None), 'scf_normal_end': dos_results.get('scf_normal_end', None), 'scf_converge': dos_results.get('scf_converge', None), 'scf_energy': dos_results.get('scf_energy', None), diff --git a/src/abacusagent/modules/util/comm.py b/src/abacusagent/modules/util/comm.py index 5645f7b..9935b94 100644 --- a/src/abacusagent/modules/util/comm.py +++ b/src/abacusagent/modules/util/comm.py @@ -7,6 +7,7 @@ import json import traceback import uuid +import shutil import glob from abacustest.lib_prepare.abacus import ReadInput @@ -233,7 +234,7 @@ def link_abacusjob(src: str, exclude (Optional[List[str]]): List of files to exclude. If None, no files are excluded. copy_files (List[str]): List of files to copy from src to dst. Default is ["INPUT", "STRU", "KPT"]. overwrite (bool): If True, existing files in the destination will be overwritten. Default is True. - exclude_directories (bool): If True, directories will be excluded from linking. Default is False. + exclude_directories (bool): If True, directories will be excluded from linking or copying. Default is False. Notes: - If somes files are included in both include and exclude, the file will be excluded. @@ -250,42 +251,45 @@ def link_abacusjob(src: str, if include is None: include = ["*"] - include_files = [] + include_paths = [] for pattern in include: - include_files.extend(src.glob(pattern)) + include_paths.extend(src.glob(pattern)) if exclude is None: exclude = [] - exclude_files = [] + exclude_paths = [] for pattern in exclude: - exclude_files.extend(src.glob(pattern)) + exclude_paths.extend(src.glob(pattern)) os.makedirs(dst, exist_ok=True) # Remove excluded files from included files - include_files = [f for f in include_files if f not in exclude_files] - if not include_files: + include_paths = [f for f in include_paths if f not in exclude_paths] + if not include_paths: traceback.print_stack() print("No files to link after applying include and exclude patterns.\n", f"Include patterns: {include}, Exclude patterns: {exclude}, Source: {src}, Destination: {dst}\n", f"Files in source: {list(src.glob('*'))}" ) else: - for file in include_files: - if file == dst: + for path in include_paths: + if path == dst: continue - if exclude_directories and os.path.isdir(file): + if exclude_directories and path.is_dir(): continue - - dst_file = dst / file.name - if dst_file.exists(): + + dst_path = dst / path.name + if dst_path.exists(): if overwrite: - dst_file.unlink() + dst_path.unlink() else: continue - if str(file.name) in copy_files: - os.system(f"cp {file} {dst_file}") + if str(path.name) in copy_files or str(path) in copy_files: + if os.path.isfile(path): + shutil.copy(path, dst_path) + else: + shutil.copytree(path, dst_path) else: - os.symlink(file, dst_file) + os.symlink(path, dst_path) def generate_work_path(create: bool = True) -> str: """ diff --git a/tests/integrate_test/data/ref_results.json b/tests/integrate_test/data/ref_results.json index bd9879f..a7a5ec0 100644 --- a/tests/integrate_test/data/ref_results.json +++ b/tests/integrate_test/data/ref_results.json @@ -122,6 +122,12 @@ "scf_energy": -1688.5908919 } }, + "test_abacus_dos_run_atoms": + { + "result": { + "scf_energy": -1688.5908919 + } + }, "test_abacus_dos_run_species_nspin2": { "result": { @@ -140,6 +146,12 @@ "scf_energy": -3220.427369911483 } }, + "test_abacus_dos_run_atoms_nspin2": + { + "result": { + "scf_energy": -3220.427369911483 + } + }, "test_abacus_dos_run_pw_nspin1": { "result": { diff --git a/tests/integrate_test/test_dos.py b/tests/integrate_test/test_dos.py index 3731f08..e7b4902 100644 --- a/tests/integrate_test/test_dos.py +++ b/tests/integrate_test/test_dos.py @@ -46,15 +46,19 @@ def test_abacus_dos_run_species(self): pdos_mode='species', dos_edelta_ev = 0.01, dos_sigma = 0.07, - dos_scale = 0.01, - dos_emin_ev = -20, - dos_emax_ev = 20) + dos_emin_ev=-20, + dos_emax_ev=20) + print(outputs) dos_fig_path = outputs['dos_fig_path'] + dos_data_path = outputs['dos_data_path'] pdos_fig_path = outputs['pdos_fig_path'] + pdos_data_path = outputs['pdos_data_path'] self.assertIsInstance(dos_fig_path, get_path_type()) + self.assertIsInstance(dos_data_path, get_path_type()) self.assertIsInstance(pdos_fig_path, get_path_type()) + self.assertIsInstance(pdos_data_path, get_path_type()) self.assertTrue(outputs['scf_normal_end']) self.assertTrue(outputs['scf_converge']) self.assertTrue(outputs['nscf_normal_end']) @@ -76,10 +80,14 @@ def test_abacus_dos_run_species_shell(self): pdos_mode='species+shell') dos_fig_path = outputs['dos_fig_path'] + dos_data_path = outputs['dos_data_path'] pdos_fig_path = outputs['pdos_fig_path'] + pdos_data_path = outputs['pdos_data_path'] self.assertIsInstance(dos_fig_path, get_path_type()) + self.assertIsInstance(dos_data_path, get_path_type()) self.assertIsInstance(pdos_fig_path, get_path_type()) + self.assertIsInstance(pdos_data_path, get_path_type()) self.assertTrue(outputs['scf_normal_end']) self.assertTrue(outputs['scf_converge']) self.assertTrue(outputs['nscf_normal_end']) @@ -101,15 +109,50 @@ def test_abacus_dos_run_species_orbital(self): pdos_mode='species+orbital') dos_fig_path = outputs['dos_fig_path'] + dos_data_path = outputs['dos_data_path'] pdos_fig_path = outputs['pdos_fig_path'] + pdos_data_path = outputs['pdos_data_path'] self.assertIsInstance(dos_fig_path, get_path_type()) + self.assertIsInstance(dos_data_path, get_path_type()) self.assertIsInstance(pdos_fig_path, get_path_type()) + self.assertIsInstance(pdos_data_path, get_path_type()) self.assertTrue(outputs['scf_normal_end']) self.assertTrue(outputs['scf_converge']) self.assertTrue(outputs['nscf_normal_end']) self.assertAlmostEqual(outputs['scf_energy'], ref_results['scf_energy']) + def test_abacus_dos_run_atoms(self): + """ + Test the abacus_dos_run function with PDOS plotting mode set to different species and shell. + """ + test_func_name = inspect.currentframe().f_code.co_name + ref_results = load_test_ref_result(test_func_name) + + test_work_dir = self.test_path / test_func_name + shutil.copytree(self.abacus_inputs_dir_nacl_prim, test_work_dir) + shutil.copy2(self.stru_dos_nacl_prim, test_work_dir / "STRU") + shutil.copy2(self.input_dos_nacl_prim, test_work_dir / "INPUT") + + outputs = abacus_dos_run(test_work_dir, + pdos_mode='atoms', + pdos_atom_indices = [1, 2]) + + dos_fig_path = outputs['dos_fig_path'] + dos_data_path = outputs['dos_data_path'] + pdos_fig_path = outputs['pdos_fig_path'] + pdos_data_path = outputs['pdos_data_path'] + + self.assertIsInstance(dos_fig_path, get_path_type()) + self.assertIsInstance(dos_data_path, get_path_type()) + self.assertIsInstance(pdos_fig_path, get_path_type()) + self.assertIsInstance(pdos_data_path, get_path_type()) + self.assertTrue(outputs['scf_normal_end']) + self.assertTrue(outputs['scf_converge']) + self.assertTrue(outputs['nscf_normal_end']) + self.assertAlmostEqual(outputs['scf_energy'], ref_results['scf_energy']) + + def test_abacus_dos_run_species_nspin2(self): """ Test the abacus_dos_run function with nspin=2 case and PDOS plotting mode set to different species. @@ -125,15 +168,18 @@ def test_abacus_dos_run_species_nspin2(self): pdos_mode='species', dos_edelta_ev = 0.01, dos_sigma = 0.07, - dos_scale = 0.01, dos_emin_ev = -20, dos_emax_ev = 20) dos_fig_path = outputs['dos_fig_path'] + dos_data_path = outputs['dos_data_path'] pdos_fig_path = outputs['pdos_fig_path'] + pdos_data_path = outputs['pdos_data_path'] self.assertIsInstance(dos_fig_path, get_path_type()) + self.assertIsInstance(dos_data_path, get_path_type()) self.assertIsInstance(pdos_fig_path, get_path_type()) + self.assertIsInstance(pdos_data_path, get_path_type()) self.assertTrue(outputs['scf_normal_end']) self.assertTrue(outputs['scf_converge']) self.assertTrue(outputs['nscf_normal_end']) @@ -154,10 +200,14 @@ def test_abacus_dos_run_species_shell_nspin2(self): pdos_mode='species+shell') dos_fig_path = outputs['dos_fig_path'] + dos_data_path = outputs['dos_data_path'] pdos_fig_path = outputs['pdos_fig_path'] + pdos_data_path = outputs['pdos_data_path'] self.assertIsInstance(dos_fig_path, get_path_type()) + self.assertIsInstance(dos_data_path, get_path_type()) self.assertIsInstance(pdos_fig_path, get_path_type()) + self.assertIsInstance(pdos_data_path, get_path_type()) self.assertTrue(outputs['scf_normal_end']) self.assertTrue(outputs['scf_converge']) self.assertTrue(outputs['nscf_normal_end']) @@ -178,10 +228,47 @@ def test_abacus_dos_run_species_orbital_nspin2(self): pdos_mode='species+orbital') dos_fig_path = outputs['dos_fig_path'] + dos_data_path = outputs['dos_data_path'] pdos_fig_path = outputs['pdos_fig_path'] + pdos_data_path = outputs['pdos_data_path'] self.assertIsInstance(dos_fig_path, get_path_type()) + self.assertIsInstance(dos_data_path, get_path_type()) self.assertIsInstance(pdos_fig_path, get_path_type()) + self.assertIsInstance(pdos_data_path, get_path_type()) + self.assertTrue(outputs['scf_normal_end']) + self.assertTrue(outputs['scf_converge']) + self.assertTrue(outputs['nscf_normal_end']) + self.assertAlmostEqual(outputs['scf_energy'], ref_results['scf_energy']) + + def test_abacus_dos_run_atoms_nspin2(self): + """ + Test the abacus_dos_run function with nspin=2 case and PDOS plotting mode set to different species. + """ + test_func_name = inspect.currentframe().f_code.co_name + ref_results = load_test_ref_result(test_func_name) + + test_work_dir = self.test_path / test_func_name + shutil.copytree(self.abacus_inputs_dir_fe_bcc_prim, test_work_dir) + shutil.copy2(self.stru_dos_fe_bcc_prim, test_work_dir / "STRU") + + outputs = abacus_dos_run(test_work_dir, + pdos_mode='atoms', + pdos_atom_indices = [1], + dos_edelta_ev = 0.01, + dos_sigma = 0.07, + dos_emin_ev = -20, + dos_emax_ev = 20) + + dos_fig_path = outputs['dos_fig_path'] + dos_data_path = outputs['dos_data_path'] + pdos_fig_path = outputs['pdos_fig_path'] + pdos_data_path = outputs['pdos_data_path'] + + self.assertIsInstance(dos_fig_path, get_path_type()) + self.assertIsInstance(dos_data_path, get_path_type()) + self.assertIsInstance(pdos_fig_path, get_path_type()) + self.assertIsInstance(pdos_data_path, get_path_type()) self.assertTrue(outputs['scf_normal_end']) self.assertTrue(outputs['scf_converge']) self.assertTrue(outputs['nscf_normal_end']) @@ -201,10 +288,13 @@ def test_abacus_dos_run_pw_nspin1(self): shutil.copy2(self.input_dos_pw_nacl_prim, test_work_dir / "INPUT") outputs = abacus_dos_run(test_work_dir) + print(outputs) dos_fig_path = outputs['dos_fig_path'] + dos_data_path = outputs['dos_data_path'] self.assertIsInstance(dos_fig_path, get_path_type()) + self.assertIsInstance(dos_data_path, get_path_type()) self.assertTrue(outputs['scf_normal_end']) self.assertTrue(outputs['scf_converge']) self.assertTrue(outputs['nscf_normal_end']) @@ -226,8 +316,10 @@ def test_abacus_dos_run_pw_nspin2(self): print(outputs) dos_fig_path = outputs['dos_fig_path'] + dos_data_path = outputs['dos_data_path'] self.assertIsInstance(dos_fig_path, get_path_type()) + self.assertIsInstance(dos_data_path, get_path_type()) self.assertTrue(outputs['scf_normal_end']) self.assertTrue(outputs['scf_converge']) self.assertTrue(outputs['nscf_normal_end']) diff --git a/tests/integrate_test/test_tool_wrapper.py b/tests/integrate_test/test_tool_wrapper.py index c8f33cb..0039271 100644 --- a/tests/integrate_test/test_tool_wrapper.py +++ b/tests/integrate_test/test_tool_wrapper.py @@ -175,15 +175,20 @@ def test_run_abacus_calculation_dos(self): fixed_axes=None, pdos_mode='species+shell', dos_edelta_ev=0.01, - dos_sigma=0.07, - dos_scale=0.01) + dos_sigma=0.07) print(outputs) dos_fig_path = outputs['dos_fig_path'] pdos_fig_path = outputs['pdos_fig_path'] + dos_data_path = outputs['dos_data_path'] + pdos_data_paths = outputs['pdos_data_paths'] self.assertIsInstance(dos_fig_path, get_path_type()) + self.assertIsInstance(dos_data_path, get_path_type()) self.assertIsInstance(pdos_fig_path, get_path_type()) + for pdos_data_path in pdos_data_paths: + self.assertIsInstance(pdos_data_path, get_path_type()) + self.assertTrue(outputs['scf_normal_end']) self.assertTrue(outputs['scf_converge']) self.assertTrue(outputs['nscf_normal_end']) diff --git a/tests/test_dos.py b/tests/test_dos.py index 0d498f7..041ff71 100644 --- a/tests/test_dos.py +++ b/tests/test_dos.py @@ -1,42 +1,46 @@ import unittest -import os, sys, glob +import os, sys, glob, shutil from pathlib import Path + os.environ["ABACUSAGENT_MODEL"] = "test" -from abacusagent.modules.submodules.dos import plot_dos_pdos as mkplots +from abacusagent.modules.submodules.dos import plot_write_dos_pdos as mkplots + class TestPlotDos(unittest.TestCase): def setUp(self): self.data_dir = Path(__file__).parent / "plot_dos" - def tearDown(self): for pngfile in glob.glob(os.path.join(self.data_dir, "*.png")): os.remove(pngfile) - + def test_run_dos(self): """ Test the run_dos function with a valid input. """ # ignore the screen output - sys.stdout = open(os.devnull, 'w') + sys.stdout = open(os.devnull, "w") # Call the run_dos function - results_test = mkplots(self.data_dir, self.data_dir, self.data_dir, dpi=20) - results_ref = [Path(self.data_dir) / "DOS.png", - Path(self.data_dir) / "PDOS.png"] - - self.assertListEqual([Path(p) for p in results_test], results_ref) + results_figs, results_datas = mkplots(self.data_dir, self.data_dir, "species", dos_emin_ev=-1, dos_emax_ev=1) + + output_dir = Path(glob.glob("*plot_write_dos_pdos*")[0]).absolute() + results_figs_ref = [ + output_dir / "DOS.png", + output_dir / "PDOS.png", + ] + results_datas_ref = [ + output_dir / "DOS.dat", + output_dir / "PDOS.dat", + ] + + self.assertListEqual([Path(p) for p in results_figs], results_figs_ref) + self.assertListEqual([Path(p) for p in results_datas], results_datas_ref) if os.path.exists(self.data_dir / "metrics.json"): os.remove(self.data_dir / "metrics.json") - - - - - - - - + for dir in glob.glob("*plot_write_dos_pdos*"): + shutil.rmtree(dir)