#!python
# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the NiBabel package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""
Output a summary table for neuroimaging files (resolution, dimensionality, etc.)
"""
from __future__ import division, print_function, absolute_import

import re
import sys

import numpy as np
import nibabel as nib

from math import ceil
from optparse import OptionParser, Option
from io import StringIO
from nibabel.py3k import asunicode

# Module metadata: authorship, copyright holder and license name
__author__ = 'Yaroslav Halchenko'
__copyright__ = 'Copyright (c) 2011-2016 Yaroslav Halchenko ' \
                'and NiBabel contributors'
__license__ = 'MIT'


# global verbosity switch: main() overwrites it with the -v/--verbose count
# and verbose() compares the level of each message against it
verbose_level = 0
# the value is also interpolated into the help text of --all-counts
MAX_UNIQUE = 1000  # maximal number of unique values to report for --counts

def _err(msg=None):
    """To return a string to signal "error" in output table"""
    if msg is None:
        msg = 'error'
    return '!' + msg

def verbose(l, msg):
    """Print `msg` if level `l` does not exceed the global `verbose_level`

    Parameters
    ----------
    l : int
      Verbosity level of the message.  It is also the number of spaces
      the printed message gets indented by
    msg : str
      Message to print
    """
    # TODO: consider using nibabel's logger
    if l > int(verbose_level):
        # too detailed for the requested verbosity
        return
    indent = ' ' * l
    print("%s%s" % (indent, msg))


def error(msg, exit_code):
    """Print `msg` to stderr and terminate the program

    Parameters
    ----------
    msg : str
      Message to report
    exit_code : int
      Exit status handed to `sys.exit`

    Raises
    ------
    SystemExit
      Always, carrying `exit_code`
    """
    # print() is a function in this module (print_function is imported from
    # __future__), so the Python 2 "print >> stream" form would raise
    # TypeError instead of writing the message
    print(msg, file=sys.stderr)
    sys.exit(exit_code)


def table2string(table, out=None):
    """Given list of lists figure out their common widths and print to out

    A cell may start with alignment markup: ``@l`` (left), ``@r`` (right),
    ``@c`` (center, also used for cells without markup) or ``@w`` (wide --
    the cell is ignored while computing the width of its column).

    Parameters
    ----------
    table : list of lists of strings
      What is aimed to be printed.  It is not modified: short rows are
      padded with empty cells on an internal copy.  Cells which are not
      strings are converted with `str`
    out : None or stream
      Where to print. If None -- will print and return string

    Returns
    -------
    string if out was None

    Raises
    ------
    ValueError
      If a cell starts with '@' which is not followed by one of l, r, c, w
    """

    print2string = out is None
    if print2string:
        out = StringIO()

    # equalize number of elements in each row; work on a copy so that
    # the rows handed in by the caller are left untouched
    nelements_max = max(len(row) for row in table) if len(table) else 0
    rows = [[str(item) for item in row] + [''] * (nelements_max - len(row))
            for row in table]

    # figure out lengths within each column
    # eat whole entry while computing width for @w (for wide)
    markup_strip = re.compile('^@([lrc]|w.*)')
    col_width = [max(len(markup_strip.sub('', item)) for item in column)
                 for column in zip(*rows)]

    lines = []
    for row in rows:
        cells = []
        for width, item in zip(col_width, row):
            if item.startswith('@'):
                # slicing (not indexing) so that a bare '@' is reported as
                # an unknown alignment rather than failing with IndexError
                align = item[1:2]
                item = item[2:]
                if align not in ('l', 'r', 'c', 'w'):
                    raise ValueError('Unknown alignment %s. Known are l,r,c' %
                                     align)
            else:
                align = 'c'

            # split the padding as for a centered cell; an odd leftover
            # space goes to the left
            nspacesl = max(int(ceil((width - len(item)) / 2.0)), 0)
            nspacesr = max(width - nspacesl - len(item), 0)

            # 'c' and 'w' keep the centered split
            if align == 'l':
                nspacesl, nspacesr = 0, nspacesl + nspacesr
            elif align == 'r':
                nspacesl, nspacesr = nspacesl + nspacesr, 0

            cells.append(' ' * nspacesl + item + ' ' * nspacesr)
        # columns are separated by a single space; no trailing whitespace
        lines.append(' '.join(cells).rstrip() + '\n')
    out.write(asunicode(''.join(lines)))

    if print2string:
        value = out.getvalue()
        out.close()
        return value


def ap(l, format_, sep=', '):
    """Format every element of `l` with `format_` and join them with `sep`

    Parameters
    ----------
    l : iterable or '-'
      Values to format.  The '-' placeholder (what `safe_get` returns on
      failure) is passed through unchanged
    format_ : str
      %-style format applied to each element
    sep : str, optional
      Separator placed between the formatted elements

    Returns
    -------
    str
    """
    if l == '-':
        return l
    return sep.join(format_ % x for x in l)


def safe_get(obj, name):
    """Call ``obj.get_<name>()`` and return its result, or '-' on any failure

    Parameters
    ----------
    obj : object
      Object expected to provide a ``get_<name>`` method
    name : str
      Name of the getter without the ``get_`` prefix

    Returns
    -------
    Whatever the getter returned, or the string '-' if looking the getter
    up or calling it raised an exception (reported at verbosity level 2)
    """
    try:
        return getattr(obj, 'get_' + name)()
    except Exception as exc:
        verbose(2, "get_%s() failed -- %s" % (name, exc))
    return '-'


def get_opt_parser():
    """Create the command line parser of this script

    Returns
    -------
    OptionParser
      Parser with the -v, -H, -s, -c, --all-counts and -z options
      registered
    """
    # use module docstring for help output
    usage = "%s [OPTIONS] [FILE ...]\n\n" % sys.argv[0] + __doc__
    parser = OptionParser(usage=usage, version="%prog " + nib.__version__)

    parser.add_option(
        "-v", "--verbose", action="count",
        dest="verbose", default=0,
        help="Make more noise.  Could be specified multiple times")

    parser.add_option(
        "-H", "--header-fields",
        dest="header_fields", default='',
        help="Header fields (comma separated) to be printed as well (if present)")

    parser.add_option(
        "-s", "--stats",
        action="store_true", dest='stats', default=False,
        help="Output basic data statistics")

    parser.add_option(
        "-c", "--counts",
        action="store_true", dest='counts', default=False,
        help="Output counts - number of entries for each numeric value "
             "(useful for int ROI maps)")

    parser.add_option(
        "--all-counts",
        action="store_true", dest='all_counts', default=False,
        help="Output all counts, even if number of unique values > %d" % MAX_UNIQUE)

    parser.add_option(
        "-z", "--zeros",
        action="store_true", dest='stats_zeros', default=False,
        help="Include zeros into output basic data statistics (--stats, --counts)")

    return parser


def proc_file(f, opts):
    """Gather the summary table row for a single file

    Parameters
    ----------
    f : str
      Name of the file, loaded with ``nib.load``
    opts : object
      Parsed command line options.  Attributes read here are
      ``header_fields``, ``stats``, ``counts``, ``all_counts`` and
      ``stats_zeros``

    Returns
    -------
    row : list of str
      Cells for `table2string`, some of them carrying ``@l`` markup.  The
      first cell is the file name; if the file could not be loaded it is
      followed only by 'failed'
    """
    verbose(1, "Loading %s" % f)

    row = ["@l%s" % f]
    try:
        vol = nib.load(f)
        h = vol.header
    except Exception as e:
        row += ['failed']
        verbose(2, "Failed to gather information -- %s" % str(e))
        return row

    row += [str(safe_get(h, 'data_dtype')),
            '@l[%s]' % ap(safe_get(h, 'data_shape'), '%3g'),
            '@l%s' % ap(safe_get(h, 'zooms'), '%.2f', 'x')]
    # Slope: reported only if it differs from the (1, 0) and (None, None)
    # pairs
    if hasattr(h, 'has_data_slope') and \
            (h.has_data_slope or h.has_data_intercept) and \
            h.get_slope_inter() not in [(1.0, 0.0), (None, None)]:
        row += ['@l*%.3g+%.3g' % h.get_slope_inter()]
    else:
        row += ['']

    if hasattr(h, 'extensions') and len(h.extensions):
        row += ['@l#exts: %d' % len(h.extensions)]
    else:
        row += ['']

    if opts.header_fields:
        # signals "all fields"
        if opts.header_fields == 'all':
            # TODO: might vary across file types, thus prior sensing
            # would be needed
            header_fields = h.keys()
        else:
            header_fields = opts.header_fields.split(',')

        # a dedicated loop variable keeps the file name in `f` intact
        for field in header_fields:
            if not field:  # skip empty
                continue
            try:
                row += [str(h[field])]
            except (KeyError, ValueError):
                row += [_err()]

    # flag files whose qform and sform differ
    try:
        if (hasattr(h, 'get_qform') and hasattr(h, 'get_sform') and
                (h.get_qform() != h.get_sform()).any()):
            row += ['sform']
        else:
            row += ['']
    except Exception as e:
        verbose(2, "Failed to obtain qform or sform -- %s" % str(e))
        if isinstance(h, nib.AnalyzeHeader):
            row += ['']
        else:
            row += [_err()]

    if opts.stats or opts.counts:
        # We are doomed to load data
        try:
            d = vol.get_data()
            if not opts.stats_zeros:
                d = d[np.nonzero(d)]
            else:
                # at least flatten it -- functionality below doesn't
                # depend on the original shape, so let's use a flat view
                d = d.reshape(-1)
            if opts.stats:
                # just # of elements
                row += ["@l[%d]" % np.prod(d.shape)]
                # stats; min/max are undefined for empty data
                row += ['@l[%.2g, %.2g]' % (np.min(d), np.max(d))
                        if len(d) else '-']
            if opts.counts:
                items, inv = np.unique(d, return_inverse=True)
                # same limit as advertised in the help of --all-counts
                if len(items) > MAX_UNIQUE and not opts.all_counts:
                    counts = _err("%d uniques. Use --all-counts" % len(items))
                else:
                    freq = np.bincount(inv)
                    counts = " ".join("%g:%d" % (value, count)
                                      for value, count in zip(items, freq))
                row += ["@l" + counts]
        except IOError as e:
            verbose(2, "Failed to obtain stats/counts -- %s" % str(e))
            row += [_err()]
    return row


def main():
    """Show must go on

    Parse the command line, gather one row per given file with `proc_file`
    and print the resulting table.  Sets the global `verbose_level` from
    the -v/--verbose count.
    """

    parser = get_opt_parser()
    opts, files = parser.parse_args()

    global verbose_level
    verbose_level = opts.verbose

    if verbose_level < 3:
        # suppress nibabel format-compliance warnings.  Go through
        # setLevel() rather than assigning .level so that the logging
        # module also invalidates its cached level checks.
        # 50 == logging.CRITICAL
        nib.imageglobals.logger.setLevel(50)

    rows = [proc_file(f, opts) for f in files]

    print(table2string(rows))

# run the command line interface only when executed as a script
if __name__ == '__main__':
    main()
