# Soprano - a library to crack crystals! by Simone Sturniolo
# Copyright (C) 2016 - Science and Technology Facility Council
# Soprano is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# Soprano is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""
CLI to extract and summarise dipolar couplings.
TODO
* implement symmetry/label-based averaging.
* implement averaging over functional groups.
* implement rotational averaging.
"""
# Module authorship metadata (author and maintainer are the same person).
__author__ = __maintainer__ = "J. Kane Shenton"
__email__ = "kane.shenton@stfc.ac.uk"
__date__ = "August 10, 2022"
import logging
import click
import click_log
import numpy as np
import pandas as pd
from ase import io
from soprano.data.nmr import _get_isotope_list
from soprano.properties.labeling import MagresViewLabels
from soprano.properties.nmr import *
from soprano.scripts.cli_utils import (
    DIPOLAR_OPTIONS,
    NO_CIF_LABEL_WARNING,
    add_options,
    apply_df_filtering,
    expand_aliases,
    print_results,
    sortdf,
    units_rename,
    viewimages,
)
from soprano.selection import AtomSelection
from soprano.utils import has_cif_labels
HEADER = """
##################################################
#  Extracting Dipolar couplingsfrom magres file  #
"""
FOOTER = """
#  End of dipolar coupling extraction            #
##################################################
"""
# Column aliases expanded by ``expand_aliases`` for the ``include`` option.
# The "essential" entry of each mapping is passed to ``apply_df_filtering``
# as ``essential_columns``.

# Aliases for the pairwise-coupling table.
dipolar_aliases = dict(
    minimal=[
        "pair",
        "label_i", "label_j",
        "isotope_i", "isotope_j",
        "D", "alpha", "beta",
    ],
    essential=["pair", "D"],
)

# Aliases for the per-site RSS table (used when ``rss_flag`` is set).
rss_aliases = dict(
    minimal=["index", "label", "isotope", "D_RSS"],
    essential=["index", "D_RSS"],
)
# --- logging ----------------------------------------------------------------
# Module-level logger used throughout this CLI command.
logger = logging.getLogger("cli")
# Route ``warnings.warn`` messages through the logging framework as well.
logging.captureWarnings(True)
# Apply click_log's basic configuration to the CLI logger.
click_log.basic_config(logger)
@click.command()
# one or more structure files (any format ase.io.read understands)
@click.argument("files", nargs=-1, type=click.Path(exists=True), required=True)
@add_options(DIPOLAR_OPTIONS)
def dipolar(
    files,
    selection_i=None,
    selection_j=None,
    rss_flag=False,
    rss_cutoff=5.0,
    self_coupling=False,
    isonuclear=False,
    output=None,
    output_format=None,
    merge=False,
    isotopes=None,
    precision=3,
    sortby=None,
    sort_order="ascending",
    include=None,
    exclude=None,
    query=None,
    view=False,
    verbosity=0,
    **kwargs,
):
    """
    Extract and summarise dipolar couplings from structure files.
    Usage:
    soprano dipolar seedname.{magres|cif|POSCAR|etc}
    \f
    Developer notes (click hides everything after the form feed above from
    ``--help``).

    Each path in ``files`` is read with ``ase.io.read``. Files raising
    ``OSError`` are logged and skipped. For each readable structure:

    * if ``rss_flag`` is set, ``extract_dipolar_RSS_couplings`` is called for
      the atoms in ``selection_i`` (``rss_cutoff`` is passed as ``cutoff``);
    * otherwise ``extract_dipolar_couplings`` is called for the
      ``selection_i`` / ``selection_j`` atoms.

    The resulting table is filtered (``include`` / ``exclude`` / ``query``)
    and tagged with a ``file`` column. Tables that end up empty are dropped.
    With ``merge``, the remaining tables are concatenated into one. Each table
    is then sorted (``sortby`` / ``sort_order``), its columns are renamed via
    ``units_rename``, and everything is handed to ``print_results``.

    Args:
        files: paths of the structure files to process.
        selection_i, selection_j: optional selection strings passed to
            ``AtomSelection.from_selection_string``. If omitted, all atoms
            are selected.
        rss_flag: extract per-site RSS couplings instead of pairwise ones.
        isotopes: isotope overrides forwarded to the extraction routines.
            ``None`` means an empty mapping.
        precision: pandas display precision and rounding of the ``v`` column.
        merge: concatenate the per-file tables into a single table.
        view: pass the structures that produced results to ``viewimages``.
        verbosity: 0 -> WARNING, 1 -> INFO, >=2 -> DEBUG logging.

    Raises:
        ValueError: if ``selection_j`` is combined with ``rss_flag``.
    """
    if verbosity == 0:
        logging.basicConfig(level=logging.WARNING)
    elif verbosity == 1:
        logging.basicConfig(level=logging.INFO)
    else:
        logging.basicConfig(level=logging.DEBUG)

    # This incompatibility depends only on the options, so reject it before
    # reading any file. Checking it only after a successful read meant it
    # never fired when every file was unreadable.
    if rss_flag and selection_j:
        raise ValueError("Cannot use --rss flag with selection_j")

    # Avoid a shared mutable default argument.
    if isotopes is None:
        isotopes = {}

    # NOTE(review): ``self_coupling`` is accepted but not used anywhere in
    # this function. Confirm whether it should be forwarded to the
    # extraction routines.

    # set pandas print precision
    pd.set_option("display.precision", precision)
    # make sure we output all rows, even if there are lots!
    pd.set_option("display.max_rows", None)

    dfs = []
    images = []
    for fname in files:
        logger.info(HEADER)
        logger.info(fname)
        try:
            atoms = io.read(fname)
        except OSError:
            logger.error(f"Could not read file {fname}, skipping.")
            continue

        # Inform user of best practice RE CIF labels
        if not has_cif_labels(atoms):
            logger.info(NO_CIF_LABEL_WARNING)

        # Default to every atom; narrow down with the selection strings.
        sel_i = AtomSelection.all(atoms)
        sel_j = AtomSelection.all(atoms)
        if selection_i:
            logger.info(f"\nSelecting atoms based on selection string: {selection_i}")
            sel_i = AtomSelection.from_selection_string(atoms, selection_i)
        if selection_j:
            logger.info(f"\nSelecting atoms based on selection string: {selection_j}")
            sel_j = AtomSelection.from_selection_string(atoms, selection_j)

        # --- rss or pairs --- #
        if rss_flag:
            df = extract_dipolar_RSS_couplings(
                atoms,
                isotopes=isotopes,
                cutoff=rss_cutoff,
                isonuclear=isonuclear,
                sel_i=sel_i,
            )
            aliases = rss_aliases
        else:
            df = extract_dipolar_couplings(
                atoms,
                sel_i=sel_i,
                sel_j=sel_j,
                isotopes=isotopes,
                rss_cutoff=rss_cutoff,
                isonuclear=isonuclear,
            )
            # Round the 'v' column element-wise. Entries are presumably
            # array-like, hence apply() rather than Series.round (TODO confirm).
            df["v"] = df["v"].apply(lambda x: np.round(x, precision))
            aliases = dipolar_aliases

        df = apply_df_filtering(
            df,
            expand_aliases(include, aliases),
            exclude,
            query,
            essential_columns=aliases["essential"],
            logger=logger,
        )

        # add file info
        df["file"] = fname
        if len(df) > 0:
            dfs.append(df)
            images.append(atoms)
        else:
            logger.info(f"No couplings left for {fname} after filtering.")
        # Close the banner opened by HEADER for every processed file.
        logger.info(FOOTER)

    if view:
        # TODO: make sure indices match the original/df indices (currently don't if organic!)
        viewimages(images)

    # pd.concat raises on an empty list, so only merge when there is something
    # to merge.
    if merge and dfs:
        dfs = [pd.concat(dfs, axis=0)]
    if not dfs:
        logger.warning("No dipolar couplings to report.")

    # Sort each table, then rename columns to include units where applicable.
    dfs = [
        sortdf(df, sortby, sort_order).rename(columns=units_rename) for df in dfs
    ]
    # write to file(s)
    print_results(dfs, output, output_format, verbosity > 0)