Source code for sqsgenerator._optimize

import json
import warnings
from typing import Any, Callable, Optional, Union

from ._adapters import read as _read
from .core import (
    IterationMode,
    LogLevel,
    Prec,
    SqsCallback,
    SqsConfiguration,
    SqsConfigurationDouble,
    SqsConfigurationFloat,
    SqsResultPack,
    Structure,
    SublatticeMode,
)
from .core import (
    optimize as _optimize,
)
from .core import (
    parse_config as _parse_config,
)

# Names exported by `from <this module> import *`; everything else defined
# here is underscore-prefixed.
__all__ = ["optimize", "parse_config"]


def _parse_prec(string: str) -> Prec:
    """
    Parse a string into a Prec enum value.

    Args:
        string (str): The string to parse.

    Returns:
        Prec: The corresponding Prec enum value.
    """
    if (prec := string.lower()) == "double":
        return Prec.double
    elif prec == "single":
        return Prec.single
    else:
        raise ValueError(f"Invalid precision: {string}. Use 'double' or 'single'.")


def _parse_iteration_mode(string: str) -> IterationMode:
    """
    Parse a string into an IterationMode enum value.

    Args:
        string (str): The string to parse.

    Returns:
        IterationMode: The corresponding IterationMode enum value.
    """
    if (mode := string.lower()) == "random":
        return IterationMode.random
    elif mode == "systematic":
        return IterationMode.systematic
    else:
        raise ValueError(
            f"Invalid iteration mode: {string}. Use 'random' or 'systematic'."
        )


def _parse_sublattice_mode(string: str) -> SublatticeMode:
    """
    Parse a string into a SublatticeMode enum value.

    Args:
        string (str): The string to parse.

    Returns:
        SublatticeMode: The corresponding SublatticeMode enum value.
    """
    if (mode := string.lower()) == "split":
        return SublatticeMode.split
    elif mode == "interact":
        return SublatticeMode.interact
    else:
        raise ValueError(
            f"Invalid sublattice mode: {string}. Use 'split' or 'interact'."
        )


def _preprocess_structure(structure_config: dict[str, Any]) -> dict[str, Any]:
    """
    Preprocess the structure configuration dictionary inplace.

    Args:
        structure_config (dict[str, Any]): The structure configuration dictionary containing
            information about the structure, such as lattice, coordinates, species, or a file path.

    Returns:
        dict[str, Any]: The updated structure configuration dictionary with processed structure data.
    """
    if structure_file_path := structure_config.get("file"):
        if (
            "lattice" in structure_config
            and "coords" in structure_config
            and "species" in structure_config
        ):
            warnings.warn(
                f"Structure data provided in both file and inline format. Using file data from {structure_file_path}.",
                UserWarning,
                stacklevel=2,
            )
        structure: Structure = _read(structure_file_path)
        structure_config["lattice"] = structure.lattice.tolist()
        structure_config["coords"] = structure.frac_coords.tolist()
        structure_config["species"] = structure.species

    return structure_config


def parse_config(
    config: Union[dict[str, Any], str], inplace: bool = False
) -> SqsConfiguration:
    """
    Parse a configuration into a SqsConfiguration object.

    Before the dictionary is handed to the core parser, a ``"structure"``
    section that references a file is expanded, and string values of
    ``"prec"``, ``"iteration_mode"`` and ``"sublattice_mode"`` are converted to
    their enum counterparts.

    Args:
        config (dict[str, Any] | str): Configuration dictionary, or a JSON document
            (as a string) that decodes to one.
        inplace (bool, optional): If `True`, the preprocessing is applied directly to
            the passed dictionary. If `False` (default), the dictionary and its
            ``"structure"`` section are copied first, so the caller's data is not modified.

    Returns:
        SqsConfiguration: Parsed configuration object.

    Raises:
        ValueError: If `config` is a string that is not valid JSON, or if one of the
            enum-like settings is a string that does not name a valid option.
    """
    if isinstance(config, str):
        config = json.loads(config)

    if not inplace:
        config = config.copy()

    if structure_config := config.get("structure"):
        if not inplace:
            # `config.copy()` is shallow, so the nested section is still the
            # caller's object; copy it too because _preprocess_structure
            # writes "lattice"/"coords"/"species" into it.
            structure_config = dict(structure_config)
            config["structure"] = structure_config
        _preprocess_structure(structure_config)

    def apply(key: str, f: Callable[[str], Any]) -> None:
        # Only strings need converting. Other values (e.g. enum members left
        # behind by an earlier in-place parse) are passed through unchanged
        # instead of failing on `.lower()`.
        value = config.get(key)
        if isinstance(value, str):
            config[key] = f(value)

    apply("prec", _parse_prec)
    apply("iteration_mode", _parse_iteration_mode)
    apply("sublattice_mode", _parse_sublattice_mode)

    return _parse_config(config)  # type: ignore[return-value]
def optimize(
    config: Union[dict[str, Any], SqsConfiguration, str],
    level: LogLevel = LogLevel.warn,
    callback: Optional[SqsCallback] = None,
) -> SqsResultPack:
    """
    Run the SQS optimization for the given configuration.

    A configuration that is already a `SqsConfigurationFloat` or
    `SqsConfigurationDouble` instance is used directly; anything else is first
    passed through `parse_config`.

    Args:
        config (dict[str, Any] | SqsConfiguration | str): The configuration to optimize,
            either already parsed or in a form accepted by `parse_config`.
        level (LogLevel): The logging level for the optimization process.
            Defaults to `LogLevel.warn`.
        callback (SqsCallback | None): A callback function to monitor the optimization
            progress. Defaults to `None`.

    Returns:
        SqsResultPack: The result of the optimization process.
    """
    if isinstance(config, (SqsConfigurationFloat, SqsConfigurationDouble)):
        parsed = config
    else:
        parsed = parse_config(config)
    return _optimize(parsed, log_level=level, callback=callback)