import functools
import os.path
from typing import Literal, NamedTuple, Optional, Union, cast, get_args
from .core import Prec, Structure, StructureDouble, StructureFloat, StructureFormat
# pymatgen is an optional dependency: HAVE_PYMATGEN records whether the import
# succeeded so the rest of the module can test the flag before using it.
try:
    from pymatgen.core import Structure as PymatgenStructure
except ImportError:
    HAVE_PYMATGEN = False
else:
    HAVE_PYMATGEN = True
# ASE is an optional dependency: HAVE_ASE records whether the import succeeded
# so the rest of the module can test the flag before using it.
try:
    from ase import Atoms
except ImportError:
    HAVE_ASE = False
else:
    HAVE_ASE = True
@functools.lru_cache(maxsize=2)
def _structure_type(prec: Prec) -> Union[type[StructureFloat], type[StructureDouble]]:
    """
    Map a precision flag onto the matching Structure implementation.

    Lookups are memoised with ``functools.lru_cache``.

    Args:
        prec (Prec): The requested precision.

    Returns:
        type[StructureFloat] | type[StructureDouble]: ``StructureFloat`` for
        ``Prec.single`` and ``StructureDouble`` for ``Prec.double``.

    Raises:
        ValueError: If ``prec`` equals neither ``Prec.single`` nor ``Prec.double``.
    """
    if prec == Prec.single:
        return StructureFloat
    if prec == Prec.double:
        return StructureDouble
    raise ValueError("Invalid prec")
if HAVE_PYMATGEN:
from pymatgen.core.structure import FileFormats
@functools.lru_cache(maxsize=1)
def pymatgen_formats() -> tuple[str, ...]:
    """
    Return the pymatgen file formats exposed by this module.

    The candidates are the members of pymatgen's ``FileFormats`` literal.
    Falsy members (such as an empty-string placeholder) and the ``"aims"``
    format are left out. The result is cached after the first call.

    Returns:
        tuple[str, ...]: The usable pymatgen format names, in declaration order.
    """
    # Compare against the exact name. The previous ``fmt not in ("aims,")``
    # was a substring test against the *string* "aims," (parentheses without a
    # trailing comma do not create a tuple) and only worked by coincidence.
    return tuple(fmt for fmt in get_args(FileFormats) if fmt and fmt != "aims")
[docs]
def to_pymatgen(structure: Structure) -> PymatgenStructure:
    """
    Build a pymatgen Structure from an sqsgen Structure.

    The input is first passed through ``without_vacancies()``; the lattice,
    the per-atom symbols and the fractional coordinates of that result are
    handed to the pymatgen constructor.

    Args:
        structure (Structure): The structure to convert.

    Returns:
        PymatgenStructure: The equivalent pymatgen structure.
    """
    occupied = structure.without_vacancies()
    lattice = occupied.lattice
    symbols = [site.symbol for site in occupied.atoms]
    fractional = occupied.frac_coords
    return PymatgenStructure(
        lattice=lattice,
        species=symbols,
        coords=fractional,
        coords_are_cartesian=False,
    )
[docs]
def from_pymatgen(ps: PymatgenStructure, prec: Prec = Prec.double) -> Structure:
    """
    Build an sqsgen Structure from a pymatgen Structure.

    Args:
        ps (PymatgenStructure): The pymatgen structure to convert.
        prec (Prec, optional): Precision of the resulting structure.
            Defaults to ``Prec.double``.

    Returns:
        Structure: A structure created from the lattice matrix, the fractional
        coordinates and the ``specie.Z`` value of every site in ``ps``.
    """
    structure_cls = _structure_type(prec)
    lattice_matrix = ps.lattice.matrix
    fractional = ps.frac_coords
    atomic_numbers = [site.specie.Z for site in ps]
    return structure_cls(lattice_matrix, fractional, atomic_numbers)
def write_pymatgen(
    structure: PymatgenStructure, filename: str, fmt: FileFormats
) -> None:
    """
    Serialise a pymatgen Structure to disk.

    This is a thin wrapper that delegates to ``structure.to_file``.

    Args:
        structure (PymatgenStructure): The structure to write.
        filename (str): Destination path.
        fmt (FileFormats): pymatgen format name used for serialisation.
    """
    structure.to_file(filename, fmt)
def read_pymatgen(filename: str, fmt: FileFormats) -> PymatgenStructure:
    """
    Load a pymatgen Structure from disk.

    The file is read as text and its contents are parsed with
    ``PymatgenStructure.from_str``.

    Args:
        filename (str): Path of the file to read.
        fmt (FileFormats): pymatgen format name describing the file contents.

    Returns:
        PymatgenStructure: The parsed structure.
    """
    with open(filename) as handle:
        contents = handle.read()
    return PymatgenStructure.from_str(contents, fmt)
if HAVE_ASE:
class AseFileFormat(NamedTuple):
    """Description of a single ASE file format supported by this module."""

    # Name under which the format is addressed (used as the file extension).
    extension: str
    # ASE I/O code string of the format; its last character decides how the
    # file is passed to ASE when reading or writing.
    code: str
@functools.lru_cache(maxsize=1)
def ase_formats() -> dict[str, AseFileFormat]:
    """
    Return the ASE formats supported by this module.

    The result is cached after the first call.

    Returns:
        dict[str, AseFileFormat]: Mapping from format name to its
        ``AseFileFormat`` entry (name plus the I/O code reported by
        ``ase.io.formats.all_formats``). Entries are in alphabetical order;
        names unknown to the installed ASE version are omitted.
    """
    from ase.io.formats import all_formats

    # An ordered tuple (instead of a set) keeps the resulting dict - and thus
    # error messages and ``available_formats()`` - in a reproducible order
    # that does not depend on string-hash randomisation.
    names = (
        "abinit-in",
        "aims",
        "cfg",
        "cif",
        "crystal",
        "cube",
        "dlp4",
        "dmol-car",
        "dmol-incoor",
        "gaussian-in",
        "gen",
        "gpumd",
        "gromacs",
        "gromos",
        "json",
        "magres",
        "nwchem-in",
        "onetep-in",
        "res",
        "rmc6f",
        "struct",
        "sys",
        "traj",
        "v-sim",
        "vasp",
        "xsd",
    )
    return {
        name: AseFileFormat(extension=name, code=all_formats[name].code)
        for name in names
        # Skip formats the installed ASE does not register rather than
        # failing with a KeyError for every ASE code path.
        if name in all_formats
    }
[docs]
def to_ase(structure: Structure) -> Atoms:
    """
    Build an ASE Atoms object from an sqsgen Structure.

    The input is first passed through ``without_vacancies()``; the result is
    created with periodic boundary conditions enabled (``pbc=True``).

    Args:
        structure (Structure): The structure to convert.

    Returns:
        Atoms: The equivalent ASE Atoms object.
    """
    occupied = structure.without_vacancies()
    atomic_numbers = occupied.species
    fractional = occupied.frac_coords
    lattice = occupied.lattice
    return Atoms(
        numbers=atomic_numbers,
        scaled_positions=fractional,
        cell=lattice,
        pbc=True,
    )
[docs]
def from_ase(atom: Atoms, prec: Prec = Prec.double) -> Structure:
    """
    Build an sqsgen Structure from an ASE Atoms object.

    Args:
        atom (Atoms): The ASE Atoms object to convert.
        prec (Prec, optional): Precision of the resulting structure.
            Defaults to ``Prec.double``.

    Returns:
        Structure: A structure created from the cell, the scaled positions and
        the atomic numbers of ``atom``.
    """
    structure_cls = _structure_type(prec)
    cell = atom.cell
    scaled_positions = atom.get_scaled_positions()
    atomic_numbers = atom.numbers
    return structure_cls(cell, scaled_positions, atomic_numbers)
def read_ase(filename: str, fmt: str) -> Atoms:
    """
    Load an ASE Atoms object from disk.

    How the file is handed to ``ase.io.read`` depends on the last character
    of the format's I/O code: ``"S"`` passes the file name itself, ``"F"``
    passes a handle opened in text mode and ``"B"`` a handle opened in
    binary mode.

    Args:
        filename (str): Path of the file to read.
        fmt (str): ASE format name; must be a key of ``ase_formats()``.

    Returns:
        Atoms: The parsed object. If ASE returns a list, its first entry is
        returned.

    Raises:
        ValueError: If ``fmt`` is not supported or its I/O code is not
            recognised.
    """
    from ase.io import read

    registry = ase_formats()
    if fmt not in registry:
        raise ValueError(f"Unsupported ASE format: {fmt}")

    code = registry[fmt].code
    io_kind = code[-1:]
    if io_kind == "S":
        loaded = read(filename, format=fmt)
    elif io_kind == "F":
        with open(filename) as handle:
            loaded = read(handle, format=fmt)
    elif io_kind == "B":
        with open(filename, "rb") as handle:
            loaded = read(handle, format=fmt)
    else:
        raise ValueError(f"Unsupported ASE format code: {code}")

    if isinstance(loaded, list):
        return loaded[0]
    return loaded
def write_ase(atoms: Atoms, filename: str, fmt: str) -> None:
    """
    Serialise an ASE Atoms object to disk.

    How the target is handed to ``ase.io.write`` depends on the last
    character of the format's I/O code: ``"S"`` passes the file name itself,
    ``"F"`` passes a handle opened in text mode and ``"B"`` a handle opened
    in binary mode.

    Args:
        atoms (Atoms): The object to write.
        filename (str): Destination path.
        fmt (str): ASE format name; must be a key of ``ase_formats()``.

    Raises:
        ValueError: If ``fmt`` is not supported or its I/O code is not
            recognised.
    """
    from ase.io import write

    registry = ase_formats()
    if fmt not in registry:
        raise ValueError(f"Unsupported ASE format: {fmt}")

    code = registry[fmt].code
    io_kind = code[-1:]
    if io_kind == "S":
        write(filename, atoms, format=fmt)
    elif io_kind == "F":
        with open(filename, "w") as handle:
            write(handle, atoms, format=fmt)
    elif io_kind == "B":
        with open(filename, "wb") as handle:
            write(handle, atoms, format=fmt)
    else:
        raise ValueError(f"Unsupported ASE format code: {code}")
def sqsgen_formats() -> tuple[str, ...]:
    """
    Return the format names handled by the built-in ``sqsgen`` backend.

    Returns:
        tuple[str, ...]: ``("poscar", "cif", "json")``.
    """
    native_formats = ("poscar", "cif", "json")
    return native_formats
def deduce_format(filename: str) -> tuple[str, str]:
    """
    Deduce backend and file format from the extension(s) of a file name.

    Only the base name is inspected. It is split on ``"."`` (empty pieces are
    ignored) and interpreted as follows:

    * ``<name>.<fmt>``: the ``sqsgen`` backend is assumed.
    * ``<...>.<backend>.<fmt>``: the second-to-last piece selects the backend
      (``sqsgen``, ``ase`` or ``pymatgen``).

    For the ``sqsgen`` backend the extension ``vasp`` is treated as ``poscar``.

    Args:
        filename (str): The file path to inspect.

    Returns:
        tuple[str, str]: The pair ``(backend, fmt)``.

    Raises:
        ValueError: If the base name is empty, has no extension, names an
            unknown backend, or names a format the backend does not support.
        ImportError: If the ``ase`` or ``pymatgen`` backend is requested but
            the corresponding package is not installed.
    """
    basename = os.path.basename(filename)
    if not basename:
        raise ValueError("filename cannot be empty.")

    pieces = [piece for piece in basename.split(".") if piece]
    if len(pieces) < 2:
        raise ValueError("filename must have a valid extension.")

    # Without an explicit backend piece the built-in sqsgen backend is used.
    backend = "sqsgen" if len(pieces) == 2 else pieces[-2]
    fmt = pieces[-1]

    if backend == "sqsgen":
        if fmt == "vasp":
            fmt = "poscar"
        if fmt not in sqsgen_formats():
            raise ValueError(
                "sqsgen only supports one of this formats: "
                + ", ".join(sqsgen_formats())
            )
        return "sqsgen", fmt

    if backend == "ase":
        if not HAVE_ASE:
            raise ImportError(
                "ASE is not installed. Please install it to use this format."
            )
        if fmt not in ase_formats():
            raise ValueError(
                "ase only supports one of this formats: "
                + ", ".join(ase_formats().keys())
            )
        return "ase", fmt

    if backend == "pymatgen":
        if not HAVE_PYMATGEN:
            raise ImportError(
                "Pymatgen is not installed. Please install it to use this format."
            )
        if fmt not in pymatgen_formats():
            raise ValueError(
                "pymatgen only supports one of this formats: "
                + ", ".join(pymatgen_formats())
            )
        return "pymatgen", fmt

    raise ValueError("Cannot deduce file format from extension")
[docs]
def write(
    structure: Structure,
    filename: str,
    fmt: Optional[str] = None,
    backend: Literal["sqsgen", "ase", "pymatgen"] = "sqsgen",
) -> None:
    """
    Write a structure to a file in the specified format.

    The structure is passed through ``without_vacancies()`` before writing.

    Args:
        structure (Structure): The structure to write.
        filename (str): The file path to write to.
        fmt (str, optional): The file format to use. If omitted, both backend
            and format are deduced from ``filename`` via ``deduce_format`` and
            the ``backend`` argument is ignored.
        backend (Literal["sqsgen", "ase", "pymatgen"]): The backend used for
            writing when ``fmt`` is given explicitly. Defaults to ``"sqsgen"``.

    Raises:
        ValueError: If ``fmt`` is given but ``filename`` does not end with it,
            if the backend is unknown, or if the ``sqsgen`` backend does not
            support ``fmt``.
    """
    structure = structure.without_vacancies()

    if fmt is None:
        backend, fmt = deduce_format(filename)
    elif not filename.endswith(fmt):
        # Only the format extension is required. The earlier check combined
        # two ``endswith`` tests with ``or`` and therefore rejected every name
        # not ending in "<backend>.<fmt>". The explicit (backend, fmt) pair is
        # kept as passed in - it used to be swapped here.
        raise ValueError(
            "The filename must end with the format extension if explicitly specified."
        )

    def _write_str(s: Union[str, bytes]) -> None:
        # Text payloads go through a text handle, bytes through a binary one.
        with open(filename, "w" if isinstance(s, str) else "wb") as f:
            f.write(s)

    if backend == "ase":
        write_ase(to_ase(structure), filename, fmt)
    elif backend == "pymatgen":
        write_pymatgen(to_pymatgen(structure), filename, fmt)
    elif backend == "sqsgen":
        if fmt == "cif":
            _write_str(structure.dump(StructureFormat.cif))
        elif fmt in {"vasp", "poscar"}:
            _write_str(structure.dump(StructureFormat.poscar))
        elif fmt == "json":
            _write_str(structure.dump(StructureFormat.json_sqsgen))
        else:
            raise ValueError(
                f"Unsupported format '{fmt}' for sqsgen backend. "
                "Supported formats are: cif, vasp, poscar and json."
            )
    else:
        raise ValueError(f"Unsupported backend {backend}")
[docs]
def read(
    filename: str,
    fmt: Optional[str] = None,
    backend: Literal["sqsgen", "ase", "pymatgen"] = "sqsgen",
    prec: Prec = Prec.double,
) -> Structure:
    """
    Load a structure from a file.

    Args:
        filename (str): The file path to read from.
        fmt (str, optional): The file format to use. If omitted, both backend
            and format are deduced from ``filename`` via ``deduce_format`` and
            the ``backend`` argument is ignored.
        backend (Literal["sqsgen", "ase", "pymatgen"]): The backend used for
            reading when ``fmt`` is given explicitly. Defaults to ``"sqsgen"``.
        prec (Prec): Precision of the resulting structure.

    Returns:
        Structure: The structure read from ``filename``.

    Raises:
        ValueError: If ``fmt`` is given but ``filename`` does not end with it,
            if the backend is unknown, or if the ``sqsgen`` backend does not
            support ``fmt``.
    """
    if fmt is None:
        backend, fmt = deduce_format(filename)
    elif not filename.endswith(fmt):
        raise ValueError(
            "The filename must end with the format extension if explicitly specified."
        )

    def _slurp() -> str:
        # The file is only opened by the branches that parse raw text.
        with open(filename) as handle:
            return handle.read()

    if backend == "ase":
        return from_ase(read_ase(filename, fmt), prec)
    if backend == "pymatgen":
        return from_pymatgen(read_pymatgen(filename, fmt), prec)
    if backend != "sqsgen":
        raise ValueError(
            f"Unsupported backend '{backend}'. Supported backends are: sqsgen, ase, pymatgen."
        )

    if fmt in ("vasp", "poscar"):
        return _structure_type(prec).from_poscar(_slurp())
    if fmt == "json":
        return _structure_type(prec).from_json(_slurp(), StructureFormat.json_sqsgen)
    raise ValueError(
        f"Unsupported format '{fmt}' for sqsgen backend. "
        "Supported formats are: vasp, poscar and json."
    )
@functools.lru_cache(maxsize=1)
def available_formats() -> tuple[str, ...]:
    """
    List every format specifier usable for reading and writing structures.

    The listing starts with the bare sqsgen formats, followed by their
    ``sqsgen.<fmt>`` forms, then ``ase.<fmt>`` entries when ASE is importable
    and ``pymatgen.<fmt>`` entries when pymatgen is importable. The result is
    cached after the first call.

    Returns:
        tuple[str, ...]: All available format specifiers.
    """
    native = sqsgen_formats()
    collected = [*native, *(f"sqsgen.{fmt}" for fmt in native)]
    if HAVE_ASE:
        collected.extend(f"ase.{fmt}" for fmt in ase_formats())
    if HAVE_PYMATGEN:
        collected.extend(f"pymatgen.{fmt}" for fmt in pymatgen_formats())
    return tuple(collected)