#!/usr/pkg/bin/python3.10

# Audio Tools, a module and set of tools for manipulating audio data
# Copyright (C) 2007-2015  Brian Langenberger

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA

import sys
import os.path
import audiotools
import audiotools.text as _
from operator import or_

# default value for the -j/--joint option defined in the script section below
MAX_CPUS = audiotools.MAX_JOBS
# flag taken from audiotools; verify() uses it to choose str() over unicode()
PY3 = audiotools.PY3


class FailedAudioFile(object):
    """placeholder for a file whose audio class could not be instantiated

    it remembers the format name, the path and the exception raised
    at open time so that the failure surfaces later, when verify()
    is called on it like on any other track
    """

    def __init__(self, class_name, filename, err):
        """class_name is stored as NAME, alongside filename and err"""

        (self.NAME, self.filename, self.err) = (class_name, filename, err)

    def verify(self, progress=None):
        """always raises the exception stored at construction time

        the progress argument is accepted but never used
        """

        raise self.err


def open_file(filename):
    """returns a track object for the file at the given path

    this is much like audiotools.open
    except that file init errors fail differently:
    an audiotools.InvalidFile raised while instantiating the audio class
    is wrapped in a FailedAudioFile and returned instead of propagating

    raises audiotools.UnsupportedFile if the file's type isn't recognized
    or its class doesn't support to_pcm();
    errors from open() itself propagate to the caller
    """

    with open(filename, "rb") as f:
        audio_class = audiotools.file_type(f)

        # guard clause: unknown type, or a type that can't be decoded
        if (audio_class is None) or (not audio_class.supports_to_pcm()):
            raise audiotools.UnsupportedFile(filename)

        # grab the format name first so it's available if init fails
        class_name = audio_class.NAME
        try:
            return audio_class(filename)
        except audiotools.InvalidFile as err:
            return FailedAudioFile(class_name, filename, err)


def get_tracks(args, queued_files, accept_list=None):
    """yields a track object for each usable file found under args

    args is an iterable of paths; plain files are opened with open_file()
    and directories are walked recursively
    queued_files is a set of audiotools.Filename objects already seen;
    it is updated in place so that no file is yielded twice
    accept_list, if given, is an iterable of format names;
    tracks whose NAME isn't among them are dropped

    files that are unsupported or unreadable are skipped silently,
    as are paths that are neither files nor directories
    """

    accepted = None if (accept_list is None) else set(accept_list)

    for path in args:
        if os.path.isfile(path):
            try:
                filename = audiotools.Filename(path)
                if filename in queued_files:
                    # already handled under this or another spelling
                    continue
                # marked as seen before opening, so a file that
                # fails to open isn't retried later
                queued_files.add(filename)
                track = open_file(str(filename))
                if (accepted is None) or (track.NAME in accepted):
                    yield track
            except (audiotools.UnsupportedFile, IOError, OSError):
                continue
        elif os.path.isdir(path):
            for (dirpath, subdirs, basenames) in os.walk(path):
                contents = [os.path.join(dirpath, name) for name in basenames]
                for track in get_tracks(contents,
                                        queued_files,
                                        accept_list=accepted):
                    yield track


def track_number(track, default):
    """returns the track number from the track's metadata

    default is returned instead if the track has no metadata
    or if its metadata has no track number
    """

    metadata = track.get_metadata()
    if metadata is None:
        return default

    number = metadata.track_number
    # compared against None explicitly so a track number of 0 is kept
    return default if (number is None) else number


def verify(progress, track):
    """verifies a single track and returns a (path, format, error) tuple

    path is the track's filename as a unicode string,
    format is the track's NAME
    and error is None on success, the text of the audiotools.InvalidFile
    raised by track.verify() on failure,
    or _.ERR_CANCELLED if interrupted from the keyboard
    """

    try:
        track.verify(progress)
        error = None
    except audiotools.InvalidFile as err:
        error = str(err) if PY3 else unicode(err)
    except KeyboardInterrupt:
        error = _.ERR_CANCELLED

    return (audiotools.Filename(track.filename).__unicode__(),
            track.NAME,
            error)


def display_results(result, is_tty=False):
    """formats a (path, format, error) tuple from verify() as a text line

    a missing error is shown as _.LAB_TRACKVERIFY_OK in green,
    otherwise the error text itself is shown in red;
    is_tty is handed to the text's format() call
    """

    (path, track_type, failure) = result

    if failure is None:
        outcome = audiotools.output_text(_.LAB_TRACKVERIFY_OK,
                                         fg_color="green")
    else:
        outcome = audiotools.output_text(failure, fg_color="red")

    return _.LAB_TRACKVERIFY_RESULT % {"path": path,
                                       "result": outcome.format(is_tty)}


def display_results_tty(result):
    """same as display_results(), but always formats for a TTY"""

    return display_results(result, True)


# returned if the track isn't found in the AccurateRip database
AR_NOT_FOUND = -1

# returned if the track has one or more matches in the AccurateRip database
# but our track doesn't match any of the checksums
AR_MISMATCH = -2

# number of PCM frames placed in front of a track before checksumming:
# taken from the end of the previous track, or silence if there isn't one
PREVIOUS_TRACK_FRAMES = (5880 // 2)
# number of PCM frames placed after a track before checksumming:
# taken from the start of the next track, or silence if there isn't one
NEXT_TRACK_FRAMES = (5880 // 2)


def accuraterip_checksum(progress,
                         track,
                         previous_track,
                         next_track,
                         ar_matches):
    """calculates AccurateRip checksums for one track among separate files

    progress is handed to PCMReaderProgress to report reading progress
    track is the track to checksum
    previous_track and next_track are its neighbors,
    or None if track is the first or last one
    ar_matches is the collection of AccurateRip database entries
    to compare the calculated checksums against

    returns a dict with "filename", "error", "v1" and "v2" keys
    in which "v1" and "v2" are dicts
    with "checksum", "offset" and "confidence" keys;
    on IOError or ValueError while reading, "error" is the error text
    and all the "v1"/"v2" values are None
    otherwise "error" is None and "confidence" is AR_NOT_FOUND
    if ar_matches is empty, or AR_MISMATCH if no checksum matched
    """

    from audiotools import (transfer_data,
                            PCMCat,
                            PCMReaderHead,
                            PCMReaderDeHead,
                            PCMReaderProgress)
    from audiotools.decoders import SameSample
    from audiotools.accuraterip import Checksum, match_offset

    def silence(pcm_frames):
        # a run of zero samples in the same stream format as the track
        return SameSample(sample=0,
                          total_pcm_frames=pcm_frames,
                          sample_rate=track.sample_rate(),
                          channels=track.channels(),
                          channel_mask=int(track.channel_mask()),
                          bits_per_sample=track.bits_per_sample())

    # the tail of the previous track (or silence) comes first
    if previous_track is None:
        leading = silence(PREVIOUS_TRACK_FRAMES)
    else:
        # NOTE(review): this is negative if the previous track is shorter
        # than PREVIOUS_TRACK_FRAMES - confirm the readers tolerate that
        remaining_skip = (previous_track.total_frames() -
                          PREVIOUS_TRACK_FRAMES)
        tail_reader = previous_track.to_pcm()
        # seek as close as possible, then discard whatever is left
        if callable(getattr(tail_reader, "seek", None)):
            remaining_skip -= tail_reader.seek(remaining_skip)
        leading = PCMReaderDeHead(tail_reader, remaining_skip)

    # then the whole of the track itself
    body = track.to_pcm()

    # and finally the head of the next track (or silence)
    if next_track is None:
        trailing = silence(NEXT_TRACK_FRAMES)
    else:
        trailing = PCMReaderHead(next_track.to_pcm(), NEXT_TRACK_FRAMES)

    pcmreaders = [leading, body, trailing]

    checksummer = Checksum(
        total_pcm_frames=track.total_frames(),
        sample_rate=track.sample_rate(),
        is_first=(previous_track is None),
        is_last=(next_track is None),
        pcm_frame_range=PREVIOUS_TRACK_FRAMES + 1 + NEXT_TRACK_FRAMES,
        accurateripv2_offset=PREVIOUS_TRACK_FRAMES)

    # run the unified stream through the checksummer
    try:
        pcmreader = PCMReaderProgress(PCMCat(pcmreaders),
                                      PREVIOUS_TRACK_FRAMES +
                                      track.total_frames() +
                                      NEXT_TRACK_FRAMES,
                                      progress)
        transfer_data(pcmreader.read, checksummer.update)
    except (IOError, ValueError) as err:
        blank = {"checksum": None, "offset": None, "confidence": None}
        return {"filename": audiotools.Filename(track.filename).__unicode__(),
                "error": str(err),
                "v1": dict(blank),
                "v2": dict(blank)}

    # compare the calculated checksums against the AccurateRip matches
    (checksum_v2,
     confidence_v2,
     offset_v2) = match_offset(ar_matches,
                               checksums=[checksummer.checksum_v2()],
                               initial_offset=0)

    (checksum_v1,
     confidence_v1,
     offset_v1) = match_offset(ar_matches=ar_matches,
                               checksums=checksummer.checksums_v1(),
                               initial_offset=-PREVIOUS_TRACK_FRAMES)

    if len(ar_matches) == 0:
        # nothing in the database to compare against
        confidence_v1 = confidence_v2 = AR_NOT_FOUND
    else:
        # the database has entries; a missing confidence means no match
        if confidence_v1 is None:
            confidence_v1 = AR_MISMATCH
        if confidence_v2 is None:
            confidence_v2 = AR_MISMATCH

    return {"filename": audiotools.Filename(track.filename).__unicode__(),
            "error": None,
            "v1": {"checksum": checksum_v1,
                   "offset": offset_v1,
                   "confidence": confidence_v1},
            "v2": {"checksum": checksum_v2,
                   "offset": offset_v2,
                   "confidence": confidence_v2}}


def accuraterip_image_checksum(progress,
                               track,
                               is_first,
                               is_last,
                               ar_matches,
                               displayed_filename,
                               pcm_frames_offset,
                               total_pcm_frames):
    """calculates AccurateRip checksums for one track inside a disc image

    progress is handed to PCMReaderProgress to report reading progress
    track is the file containing the whole image
    is_first and is_last are passed through to the checksummer
    ar_matches is the collection of AccurateRip database entries
    to compare the calculated checksums against
    displayed_filename is returned as the result's "filename"
    pcm_frames_offset is where the track starts in the image, in PCM frames
    total_pcm_frames is the track's length in PCM frames

    returns a dict with "filename", "error", "v1" and "v2" keys
    in which "v1" and "v2" are dicts
    with "checksum", "offset" and "confidence" keys;
    on IOError or ValueError while reading, "error" is the error text
    and all the "v1"/"v2" values are None
    otherwise "error" is None and "confidence" is AR_NOT_FOUND
    if ar_matches is empty, or AR_MISMATCH if no checksum matched
    """

    from audiotools import (transfer_data,
                            PCMReaderProgress,
                            PCMReaderWindow)
    from audiotools.accuraterip import Checksum, match_offset

    reader = track.to_pcm()

    # start the window early enough to include the frames before the track
    pcm_frames_offset -= PREVIOUS_TRACK_FRAMES

    # let a seekable reader jump ahead so fewer frames need discarding
    if callable(getattr(reader, "seek", None)) and (pcm_frames_offset > 0):
        pcm_frames_offset -= reader.seek(pcm_frames_offset)

    checksummer = Checksum(
        total_pcm_frames=total_pcm_frames,
        sample_rate=track.sample_rate(),
        is_first=is_first,
        is_last=is_last,
        pcm_frame_range=PREVIOUS_TRACK_FRAMES + 1 + NEXT_TRACK_FRAMES,
        accurateripv2_offset=PREVIOUS_TRACK_FRAMES)

    # the track plus the extra frames on both of its sides
    window_frames = (PREVIOUS_TRACK_FRAMES +
                     total_pcm_frames +
                     NEXT_TRACK_FRAMES)

    # run the windowed stream through the checksummer
    try:
        window = PCMReaderWindow(reader, pcm_frames_offset, window_frames)
        pcmreader = PCMReaderProgress(window, window_frames, progress)
        transfer_data(pcmreader.read, checksummer.update)
    except (IOError, ValueError) as err:
        blank = {"checksum": None, "offset": None, "confidence": None}
        return {"filename": displayed_filename,
                "error": str(err),
                "v1": dict(blank),
                "v2": dict(blank)}

    # compare the calculated checksums against the AccurateRip matches
    (checksum_v2,
     confidence_v2,
     offset_v2) = match_offset(ar_matches=ar_matches,
                               checksums=[checksummer.checksum_v2()],
                               initial_offset=0)

    (checksum_v1,
     confidence_v1,
     offset_v1) = match_offset(ar_matches=ar_matches,
                               checksums=checksummer.checksums_v1(),
                               initial_offset=-PREVIOUS_TRACK_FRAMES)

    if len(ar_matches) == 0:
        # nothing in the database to compare against
        confidence_v1 = confidence_v2 = AR_NOT_FOUND
    else:
        # the database has entries; a missing confidence means no match
        if confidence_v1 is None:
            confidence_v1 = AR_MISMATCH
        if confidence_v2 is None:
            confidence_v2 = AR_MISMATCH

    return {"filename": displayed_filename,
            "error": None,
            "v1": {"checksum": checksum_v1,
                   "offset": offset_v1,
                   "confidence": confidence_v1},
            "v2": {"checksum": checksum_v2,
                   "offset": offset_v2,
                   "confidence": confidence_v2}}


def accuraterip_display_result(result):
    """formats an AccurateRip result dict as a single line of text

    result is a dict as returned by accuraterip_checksum()
    or accuraterip_image_checksum(), with "filename", "error",
    "v1" and "v2" keys

    if the result holds an error, the line shows the filename
    along with the error's text (or _.ERR_READ_ERROR if the text is empty)
    otherwise it shows whether the track was not found, mismatched
    or found in AccurateRip, with the version 2 confidence
    preferred over version 1 when version 2 has a match
    """

    path = result["filename"]
    error = result["error"]

    if error is not None:
        # report which track failed and why,
        # rather than an anonymous "read error"
        if isinstance(error, bytes):
            # under Python 2, str(err) is a byte string
            # which may not be safely mixed into unicode text
            error = error.decode("utf-8", "replace")
        return _.LAB_TRACKVERIFY_RESULT % \
            {"path": path, "result": error or _.ERR_READ_ERROR}

    confidence_v1 = result["v1"]["confidence"]
    confidence_v2 = result["v2"]["confidence"]

    if (confidence_v1 == AR_NOT_FOUND) and (confidence_v2 == AR_NOT_FOUND):
        # no matches found for disc
        outcome = _.LAB_ACCURATERIP_NOT_FOUND
    elif (confidence_v1 == AR_MISMATCH) and (confidence_v2 == AR_MISMATCH):
        # neither checksum matches
        outcome = _.LAB_ACCURATERIP_MISMATCH
    else:
        # AccurateRip V2 is displayed by default if it has a match,
        # falling back to V1 when only V1 matches
        if confidence_v2 == AR_MISMATCH:
            confidence = confidence_v1
        else:
            confidence = confidence_v2
        outcome = u"%s (%s)" % (
            _.LAB_ACCURATERIP_FOUND,
            _.LAB_ACCURATERIP_CONFIDENCE % (confidence))

    return _.LAB_TRACKVERIFY_RESULT % {"path": path, "result": outcome}


if (__name__ == '__main__'):
    # command-line entry point:
    # without -R, each track is verified against its own format's rules
    # with -R, CD-formatted tracks are checked against AccurateRip instead
    # the exit status is 1 if anything failed or the run was cancelled

    import argparse

    parser = argparse.ArgumentParser(description=_.DESCRIPTION_TRACKVERIFY)

    parser.add_argument("--version",
                        action="version",
                        version="Python Audio Tools %s" % (audiotools.VERSION))

    parser.add_argument("-V", "--verbose",
                        dest="verbosity",
                        choices=audiotools.VERBOSITY_LEVELS,
                        default=audiotools.DEFAULT_VERBOSITY,
                        help=_.OPT_VERBOSE)

    parser.add_argument("-t", "--type",
                        action="append",
                        dest="accept_list",
                        metavar="type",
                        choices=audiotools.TYPE_MAP.keys(),
                        help=_.OPT_TYPE_TRACKVERIFY)

    parser.add_argument("-S", "--no-summary",
                        action="store_true",
                        dest="no_summary",
                        help=_.OPT_NO_SUMMARY)

    parser.add_argument("-R", "--accuraterip",
                        action="store_true",
                        dest="accuraterip",
                        default=False,
                        help=_.OPT_ACCURATERIP)

    parser.add_argument("--cue",
                        dest="cuesheet",
                        metavar="FILENAME",
                        help=_.OPT_CUESHEET_TRACKVERIFY)

    parser.add_argument("-j", "--joint",
                        type=int,
                        default=MAX_CPUS,
                        dest="max_processes",
                        help=_.OPT_JOINT)

    parser.add_argument("filenames",
                        metavar="PATH",
                        nargs="+",
                        help=_.OPT_INPUT_FILENAME_OR_DIR)

    options = parser.parse_args()
    msg = audiotools.Messenger(options.verbosity == "quiet")

    if not options.accuraterip:
        queued_files = set()  # a set of Filename objects already encountered
        queue = audiotools.ExecProgressQueue(msg)

        # the output formatter doesn't change from track to track
        show_result = (display_results_tty
                       if msg.output_isatty() else
                       display_results)

        for track in get_tracks(options.filenames,
                                queued_files,
                                options.accept_list):
            queue.execute(
                function=verify,
                progress_text=(
                    audiotools.Filename(track.filename).__unicode__()),
                completion_output=show_result,
                track=track)

        msg.ansi_clearline()
        try:
            results = queue.run(options.max_processes)
        except KeyboardInterrupt:
            msg.error(_.ERR_CANCELLED)
            sys.exit(1)

        # each result is a (path, format name, error) tuple from verify()
        formats = sorted({r[1] for r in results})
        success_total = len([r for r in results if r[2] is None])
        failure_total = len(results) - success_total

        if (len(formats) > 0) and (not options.no_summary):
            msg.output(_.LAB_TRACKVERIFY_RESULTS)
            msg.output(u"")

            table = audiotools.output_table()

            row = table.row()
            row.add_column(_.LAB_TRACKVERIFY_RESULT_FORMAT, "right")
            row.add_column(u" ")
            row.add_column(_.LAB_TRACKVERIFY_RESULT_SUCCESS, "right")
            row.add_column(u" ")
            row.add_column(_.LAB_TRACKVERIFY_RESULT_FAILURE, "right")
            row.add_column(u" ")
            row.add_column(_.LAB_TRACKVERIFY_RESULT_TOTAL, "right")

            table.divider_row([_.DIV, u" ", _.DIV, u" ", _.DIV, u" ", _.DIV])

            # one row of success/failure counts per format
            for track_format in formats:
                success = len([r for r in results
                               if ((r[1] == track_format) and
                                   (r[2] is None))])
                failure = len([r for r in results
                               if ((r[1] == track_format) and
                                   (r[2] is not None))])
                row = table.row()
                row.add_column(u"%s" % (track_format), "right")
                row.add_column(u" ")
                row.add_column(u"%d" % (success), "right")
                row.add_column(u" ")
                row.add_column(u"%d" % (failure), "right")
                row.add_column(u" ")
                row.add_column(u"%d" % (success + failure), "right")

            table.divider_row([_.DIV, u" ", _.DIV, u" ", _.DIV, u" ", _.DIV])
            row = table.row()
            row.add_column(_.LAB_TRACKVERIFY_SUMMARY, "right")
            row.add_column(u" ")
            row.add_column(u"%d" % (success_total), "right")
            row.add_column(u" ")
            row.add_column(u"%d" % (failure_total), "right")
            row.add_column(u" ")
            row.add_column(u"%d" % (success_total + failure_total), "right")

            for row in table.format(msg.output_isatty()):
                msg.output(row)

        if failure_total > 0:
            sys.exit(1)
    else:
        queued_files = set()  # a set of Filename objects already encountered
        queue = audiotools.ExecProgressQueue(msg)

        # files which couldn't be opened as their detected format
        unreadable_files = []

        def readable_tracks(tracks):
            # get_tracks() may yield FailedAudioFile objects
            # which implement nothing but verify()
            # and so can't be grouped, measured or decoded;
            # report them here and leave them out of the lookup
            for candidate in tracks:
                if isinstance(candidate, FailedAudioFile):
                    unreadable_files.append(candidate)
                    msg.error(u"\"%(path)s\" %(result)s" %
                              {"path": audiotools.Filename(
                                  candidate.filename),
                               "result": candidate.err})
                else:
                    yield candidate

        for tracks in audiotools.group_tracks(
                readable_tracks(get_tracks(options.filenames,
                                           queued_files,
                                           options.accept_list))):
            # perform AccurateRip lookup on album's worth of tracks
            # if tracks are CD formatted
            if ((({t.channels() for t in tracks} == {2}) and
                 ({t.sample_rate() for t in tracks} == {44100}) and
                 ({t.bits_per_sample() for t in tracks} == {16}))):
                if ((len(tracks) == 1) and
                    ((options.cuesheet is not None) or
                     (tracks[0].get_cuesheet() is not None))):
                    # a single file with a cuesheet is treated as a CD image
                    total_frames = tracks[0].total_frames()
                    sample_rate = tracks[0].sample_rate()
                    # cuesheet from command-line takes precedence
                    # over cuesheet embedded in file
                    if options.cuesheet is not None:
                        try:
                            sheet = audiotools.read_sheet(options.cuesheet)
                        except audiotools.SheetException as err:
                            msg.error(err)
                            sys.exit(1)
                    else:
                        sheet = tracks[0].get_cuesheet()

                    # process tracks in CD image individually
                    ar_results = audiotools.accuraterip_sheet_lookup(
                        sheet, total_frames, sample_rate)

                    for track_num in sheet.track_numbers():

                        filename = u"%2.2d - %s" % \
                            (track_num,
                             audiotools.Filename(
                                 tracks[0].filename).basename().__unicode__())

                        offset = int(sheet.track_offset(track_num) *
                                     sample_rate)

                        if sheet.track_length(track_num) is not None:
                            length = int(sheet.track_length(track_num) *
                                         sample_rate)
                        elif track_num > 1:
                            # no length given,
                            # so the track runs to the end of the image
                            length = total_frames - offset
                        else:
                            length = (total_frames -
                                      int(sheet.pre_gap() * sample_rate))

                        queue.execute(
                            function=accuraterip_image_checksum,
                            progress_text=filename,
                            completion_output=accuraterip_display_result,
                            track=tracks[0],
                            is_first=(track_num == 1),
                            is_last=(track_num == len(sheet)),
                            ar_matches=ar_results.get(track_num, []),
                            displayed_filename=filename,
                            pcm_frames_offset=offset,
                            total_pcm_frames=length)
                else:
                    # process each track as if it were part of a CD
                    tracks = audiotools.sorted_tracks(tracks)
                    if options.cuesheet is not None:
                        try:
                            sheet = audiotools.read_sheet(options.cuesheet)
                        except audiotools.SheetException as err:
                            msg.error(err)
                            sys.exit(1)

                        ar_results = audiotools.accuraterip_sheet_lookup(
                            sheet,
                            int(sheet.pre_gap() * tracks[0].sample_rate()) +
                            sum([t.total_frames() for t in tracks]),
                            sample_rate=tracks[0].sample_rate())
                    else:
                        ar_results = audiotools.accuraterip_lookup(tracks)

                    # pair each track with its neighbors,
                    # using None before the first and after the last
                    for (i,
                         previous_track,
                         track,
                         next_track) in zip(range(1, len(tracks) + 1),
                                            [None] + tracks,
                                            tracks,
                                            tracks[1:] + [None]):
                        filename = audiotools.Filename(track.filename)

                        queue.execute(
                            function=accuraterip_checksum,
                            progress_text=filename.__unicode__(),
                            completion_output=accuraterip_display_result,
                            track=track,
                            previous_track=previous_track,
                            next_track=next_track,
                            ar_matches=ar_results.get(
                                track_number(track, i), []))
            else:
                for track in tracks:
                    msg.error(u"\"%(path)s\" %(result)s" %
                              {"path": audiotools.Filename(track.filename),
                               "result": _.ERR_TRACKVERIFY})

        msg.ansi_clearline()

        # handle cancellation the same way as plain verification does
        try:
            results = queue.run(options.max_processes)
        except KeyboardInterrupt:
            msg.error(_.ERR_CANCELLED)
            sys.exit(1)

        table = audiotools.output_table()

        row = table.row()
        row.add_column(u"")
        row.add_column(u"")
        row.add_column(_.LAB_TRACKVERIFY_AR_VERSION1, "center", colspan=5)
        row.add_column(u"")
        row.add_column(_.LAB_TRACKVERIFY_AR_VERSION2, "center", colspan=5)

        row = table.row()
        row.add_column(_.LAB_TRACKVERIFY_AR_TRACK)

        row.add_column(u" ")
        row.add_column(_.LAB_TRACKVERIFY_AR_CONFIDENCE)
        row.add_column(u" ")
        row.add_column(_.LAB_TRACKVERIFY_AR_OFFSET)
        row.add_column(u" ")
        row.add_column(_.LAB_TRACKVERIFY_AR_CHECKSUM)

        row.add_column(u" ")
        row.add_column(_.LAB_TRACKVERIFY_AR_CONFIDENCE)
        row.add_column(u" ")
        row.add_column(_.LAB_TRACKVERIFY_AR_OFFSET)
        row.add_column(u" ")
        row.add_column(_.LAB_TRACKVERIFY_AR_CHECKSUM)

        if sys.stdout.isatty():
            # the filename column gets whatever terminal width
            # the remaining header columns don't need
            filename_max = (msg.terminal_size(sys.stdout)[1] -
                            sum(c.minimum_width() for c in
                                table.__rows__[1].__columns__[1:]))
        else:
            # effectively no maximum
            filename_max = 2 ** 32

        table.divider_row([_.DIV] + [u" ", _.DIV] * 6)

        for result in results:
            filename = result["filename"]
            err = result["error"]

            row = table.row()

            # truncate long filenames if necessary
            filename = audiotools.output_text(filename)
            if len(filename) <= filename_max:
                row.add_column(filename)
            else:
                head = filename.head(filename_max - 1)
                row.add_column(
                    u"%s%s" % (head, u"\u2026" * (filename_max - len(head))))

            if err is None:
                for version in ["v1", "v2"]:
                    checksum = result[version]["checksum"]
                    offset = result[version]["offset"]
                    confidence = result[version]["confidence"]

                    row.add_column(u"")
                    # AR_NOT_FOUND and AR_MISMATCH are negative
                    # and leave the confidence column blank
                    if confidence >= 0:
                        row.add_column(u"%d" % (confidence), "right")
                    else:
                        row.add_column(u"", "right")
                    row.add_column(u"")
                    row.add_column(u"%d" % (offset), "right")
                    row.add_column(u"")
                    row.add_column(u"%8.8X" % (checksum), "right")
            else:
                row.add_column(_.ERR_READ_ERROR, "right", colspan=6)

        for row in table.format(msg.output_isatty()):
            msg.output(row)

        # unreadable files count as failures, as they do without -R
        if (unreadable_files or
                any(r["error"] is not None for r in results)):
            sys.exit(1)
