#! /usr/bin/env python3

"""\
%(prog)s <yodafile1> [<yodafile2> ...] -o <yodaoutfile> -r <yodareffile>

Combine the central values of multiple YODA files into an envelope
and write out the resulting envelope into a separate output YODA file.
If a reference file is specified, the central values of the result are taken
from that file; otherwise the midpoint between the edges of the envelope
is used.

Examples:

%(prog)s -o out.yoda file1.yoda file2.yoda
  Write out a file with the envelope constructed from file1 and file2 and with the central value set to the midpoint
  between the envelope edges.

%(prog)s -o out.yoda -r reffile.yoda file1.yoda file2.yoda
  Write out a file with the envelope constructed from reffile, file1 and file2, with the central value set to the
  central values from reffile.

"""


## Parse command line args
import argparse, sys
parser = argparse.ArgumentParser(usage=__doc__)
parser.add_argument("INFILES", nargs="+",
                    help="input YODA files whose central values define the envelope")
parser.add_argument("-r", "--ref", default=None, dest="REF_FILE",
                    help="file to be used for central values, midpoint is used if none is specified")
parser.add_argument("-o", "--output", default="-", dest="OUTPUT_FILE", metavar="PATH",
                    help="write output to specified path")
parser.add_argument("--add-to-ref", action="store_true", default=False, dest="KEEP_ERRS",
                    help="keep the errors of the reference histo (default is to overwrite)")
parser.add_argument("--err-label", dest="ERR_LABEL", default="env",
                    help="error source label to be used for Estimates (default 'env')")
parser.add_argument("-m", "--match", dest="MATCH", metavar="PATH", default=None,
                    help="only write out histograms whose path matches this regex")
parser.add_argument("-M", "--unmatch", dest="UNMATCH", metavar="PATH", default=None,
                    help="exclude histograms whose path matches this regex")
parser.add_argument("-q", "--quiet", dest="VERBOSITY", action="store_const", const=0, default=1,
                    help="reduce printouts to errors-only")
parser.add_argument("-v", "--debug", dest="VERBOSITY", action="store_const", const=2, default=1,
                    help="increase printouts to include debug info")
args = parser.parse_args()


# Total number of files that are read (the reference file counts as one of them);
# this is also the denominator of the "[i/N]" progress messages.
nfiles = len(args.INFILES) + (args.REF_FILE is not None)

# An envelope needs at least two *distinct* inputs. Repeating a path (e.g. passing
# the reference file again as an input) does not add information, so count unique paths.
distinctFiles = set(args.INFILES)
if args.REF_FILE is not None:
    distinctFiles.add(args.REF_FILE)
if len(distinctFiles) < 2:
    sys.stderr.write("ERROR: Need at least two input files to construct an envelope!\n")
    sys.exit(1)

import yoda, math


def add2envelope(data, aos, keepAOs=None):
    """Fold the central values of ``aos`` into the running envelope ``data``.

    Every analysis object is first converted with ``mkInert`` (keeping its path).
    Its central values are then merged element-wise into ``data[name]``, which
    holds ``[lower_edges, upper_edges]``:

    * Scatter objects contribute the values along their last axis (``dim()-1``),
    * Estimate0D objects contribute their single value,
    * anything else contributes ``vals()``.

    The first time a name is seen, both edges are set to its central values.
    If ``keepAOs`` is a dict, the inert copies are stored in it by name,
    overwriting any existing entry.
    """
    for name, obj in aos.items():
        inert = obj.mkInert(obj.path())
        if keepAOs is not None:
            keepAOs[name] = inert

        kind = inert.type()
        if 'Scatter' in kind:
            values = inert.vals(inert.dim() - 1)
        elif 'Estimate0D' in kind:
            values = [inert.val()]
        else:
            values = inert.vals()

        if name not in data:
            # First occurrence: a zero-width envelope around the central values.
            data[name] = [values, values]
            continue

        prevLo, prevHi = data[name]
        # zip() stops at the shorter sequence, exactly like the element-wise merge.
        data[name] = [
            [min(a, b) for a, b in zip(prevLo, values)],
            [max(a, b) for a, b in zip(prevHi, values)],
        ]


def updateErrors(data, aos, useMidPoint, keepErrs=None, errLabel=None):
    """Turn the envelope edges in ``data`` into uncertainties on ``aos`` (in place).

    For every point/bin of each object in ``aos``, the envelope edges
    ``lo``/``hi`` (from ``data[aoname]``, as built by ``add2envelope``) become the
    asymmetric errors ``cen - lo`` (down) and ``hi - cen`` (up).

    Args:
        data: dict mapping AO name to ``[lower_edges, upper_edges]``.
        aos: dict mapping AO name to an inert YODA object; modified in place.
        useMidPoint: if True, the central value is first replaced by
            ``0.5 * (lo + hi)``; otherwise the object's existing value is kept.
        keepErrs: if True, keep existing errors. For Scatters they are combined
            in quadrature with the envelope errors. For Estimates the unlabelled
            source ``""`` is renamed to ``"old_err"``. If False, existing errors
            are overwritten (Scatters) or removed (Estimates).
            Defaults to the ``--add-to-ref`` command-line flag.
        errLabel: error-source label used for Estimates.
            Defaults to the ``--err-label`` command-line option.

    Raises:
        KeyError: if an object in ``aos`` has no entry in ``data``.
        IndexError: if an object has more points/bins than its envelope entry.
    """
    if keepErrs is None:
        keepErrs = args.KEEP_ERRS
    if errLabel is None:
        errLabel = args.ERR_LABEL

    for aoname, ao in aos.items():
        all_lo, all_hi = data[aoname]
        isScat = 'Scatter' in ao.type()
        isE0D = 'Estimate0D' in ao.type()
        # Scatters carry the dependent (value) coordinate on their last axis.
        valAxis = ao.dim() - 1

        if isScat:
            nEntries = ao.numPoints()
        elif isE0D:
            nEntries = 1
        else:
            nEntries = ao.numBins()

        for i in range(nEntries):
            lo = all_lo[i]
            hi = all_hi[i]

            if isScat:
                ao_i = ao.point(i)
            elif isE0D:
                ao_i = ao
            else:
                # NOTE(review): global bin index i+1 presumably skips the underflow
                # bin, which lines up with vals() for 1D binnings; verify for 2D+.
                ao_i = ao.bin(i + 1)

            cen = ao_i.val(valAxis) if isScat else ao_i.val()
            if useMidPoint:
                cen = 0.5 * (lo + hi)
                if isScat:
                    ao_i.setVal(valAxis, cen)
                else:
                    ao_i.setVal(cen)

            errlo = cen - lo
            errhi = hi - cen

            if isScat:
                if keepErrs:
                    # Read the existing errors on the same axis they are written
                    # back to, and add them in quadrature.
                    oldlo, oldhi = ao_i.errs(valAxis)
                    errlo = math.sqrt(errlo**2 + oldlo**2)
                    errhi = math.sqrt(errhi**2 + oldhi**2)
                ao_i.setErrs(valAxis, (errlo, errhi))
            else:
                if keepErrs:
                    if ao_i.hasSource(""):
                        # Preserve the unlabelled error under a distinct name.
                        ao_i.renameSource("", "old_err")
                else:
                    ao_i.rmErrs()
                # Estimate errors are (down, up) with the down component negative.
                ao_i.setErr((-errlo, errhi), errLabel)

# Progress messages must not be interleaved with the YODA data when the output
# is streamed to stdout (the default "-", which YODA's writer treats as stdout),
# so send them to stderr in that case.
logstream = sys.stderr if args.OUTPUT_FILE == "-" else sys.stdout

aos_out = { }  # AOs written at the end: the first input file's, replaced by the reference file's where given
envdata = { }  # AO name -> [lower edges, upper edges] of the envelope
aos_in = None
for i, filename in enumerate(args.INFILES):
    if args.VERBOSITY > 0:
        msg = "Adding data file {:s} [{:d}/{:d}]".format(filename, i+1, nfiles)
        logstream.write(msg + "\n")

    # Release the previous file's objects before reading the next one.
    del aos_in
    aos_in = yoda.read(filename, True, args.MATCH, args.UNMATCH)
    # Only the first file's (inert) objects are kept as output templates.
    add2envelope(envdata, aos_in, aos_out if i == 0 else None)

if args.REF_FILE is not None:
    # Take the central values from the specified reference file: its objects
    # also enter the envelope and overwrite same-named output templates.
    if args.VERBOSITY > 0:
        msg = "Adding data file {:s} [{:d}/{:d}]".format(args.REF_FILE, nfiles, nfiles)
        logstream.write(msg + "\n")

    del aos_in
    aos_in = yoda.read(args.REF_FILE, True, args.MATCH, args.UNMATCH)
    add2envelope(envdata, aos_in, aos_out)

# Without a reference file, the envelope midpoint becomes the central value.
useMidPoint = args.REF_FILE is None
updateErrors(envdata, aos_out, useMidPoint)

yoda.write(aos_out, args.OUTPUT_FILE)
