#!/usr/bin/python

# Audio Tools, a module and set of tools for manipulating audio data
# Copyright (C) 2007-2014  Brian Langenberger

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA


import sys
import os
import os.path
import operator
import audiotools
import audiotools.text as _


def cmp_files(progress, audiofile1, audiofile2):
    """Compares the PCM data of two audio file objects

    progress is handed to audiotools.to_pcm_progress()
    along with audiofile1 while it is being decoded

    Returns (path1, path2, mismatch) tuple

    where mismatch is:
    None if both paths refer to the same file
    or audiotools.pcm_frame_cmp() reports no difference,
    the int of the first PCM mismatch if the streams differ,
    -1 if sample rate, bits-per-sample or channel count differ,
    -2 if some error occurs during the comparison."""

    path1 = audiofile1.filename
    path2 = audiofile2.filename

    try:
        if os.path.samefile(path1, path2):
            # a file always matches itself, so skip decoding it
            return (path1, path2, None)

        same_parameters = (
            (audiofile1.sample_rate() == audiofile2.sample_rate()) and
            (audiofile1.bits_per_sample() == audiofile2.bits_per_sample()) and
            (audiofile1.channels() == audiofile2.channels()))

        if not same_parameters:
            return (path1, path2, -1)

        return (path1,
                path2,
                audiotools.pcm_frame_cmp(
                    audiotools.to_pcm_progress(audiofile1, progress),
                    audiofile2.to_pcm()))
    except (IOError, OSError, ValueError, KeyboardInterrupt):
        # os.path.samefile() reports a missing or unreadable path
        # with OSError, which is distinct from IOError on Python 2,
        # so it must be listed explicitly to yield the error result
        return (path1, path2, -2)
    except audiotools.DecodingError:
        return (path1, path2, -2)


def cmp_result(result, is_tty=False):
    """Returns a unicode status line for a cmp_files() result

    result is a (path1, path2, mismatch) tuple
    and is_tty is passed to the status text's format() method

    The line is the LAB_TRACKCMP_CMP label for both paths,
    followed by " : " and a colored status chosen from mismatch:
    None -> OK, >= 0 -> mismatch at frame (mismatch + 1),
    -1 -> parameter mismatch, anything else -> error."""

    (path1, path2, mismatch) = result

    # the leading "file1 <> file2" portion is common to every outcome
    comparison = (_.LAB_TRACKCMP_CMP %
                  {"file1": audiotools.Filename(path1),
                   "file2": audiotools.Filename(path2)})

    if (mismatch is None):
        status = audiotools.output_text(_.LAB_TRACKCMP_OK,
                                        fg_color="green")
    elif (mismatch >= 0):
        # mismatch is 0-based while the displayed frame number is 1-based
        status = audiotools.output_text(
            _.LAB_TRACKCMP_MISMATCH % {"frame_number": mismatch + 1},
            fg_color="red")
    elif (mismatch == -1):
        status = audiotools.output_text(_.LAB_TRACKCMP_PARAM_MISMATCH,
                                        fg_color="red")
    else:
        status = audiotools.output_text(_.LAB_TRACKCMP_ERROR,
                                        fg_color="red")

    return comparison + u" : " + status.format(is_tty)


def cmp_result_tty(result):
    """Returns cmp_result() for result with is_tty set to True

    a single-argument variant of cmp_result()
    for use as a completion output callback."""

    return cmp_result(result, True)


def image_compare(progress, image_audiofile, track_audiofile,
                  image_filename, track_filename,
                  pcm_frames_offset, total_pcm_frames):
    """Compares one window of an image file against a single track

    The window starts pcm_frames_offset PCM frames into
    image_audiofile's stream and is total_pcm_frames long;
    it is compared against track_audiofile's whole stream
    while progress is passed to audiotools.PCMReaderProgress

    Returns (mismatch, image_filename, track_filename) tuple
    with both filenames converted to unicode,
    where mismatch is the result of audiotools.pcm_frame_cmp()
    or -2 if an IOError, ValueError, KeyboardInterrupt
    or audiotools.DecodingError occurs."""

    image_name = unicode(image_filename)
    track_name = unicode(track_filename)

    try:
        # opening the decoder and seeking happen inside the try block
        # so that failures there are also reported as -2
        image_pcmreader = image_audiofile.to_pcm()

        # if image_pcmreader has seek(),
        # use it to reduce the amount of frames to skip
        # (the value seek() returns is deducted from the offset,
        #  presumably the position actually reached - TODO confirm)
        seek = getattr(image_pcmreader, "seek", None)
        if callable(seek):
            pcm_frames_offset -= seek(pcm_frames_offset)

        mismatch = audiotools.pcm_frame_cmp(
            audiotools.PCMReaderWindow(image_pcmreader,
                                       pcm_frames_offset,
                                       total_pcm_frames),
            audiotools.PCMReaderProgress(track_audiofile.to_pcm(),
                                         total_pcm_frames,
                                         progress))
    except (IOError, ValueError, KeyboardInterrupt, audiotools.DecodingError):
        mismatch = -2

    return (mismatch, image_name, track_name)


def image_compare_raw(progress, source_filename,
                      sample_rate, channels, channel_mask, bits_per_sample,
                      track_audiofile, image_filename, track_filename,
                      pcm_frames_offset, total_pcm_frames):
    """Compares one window of a raw PCM file against a single track

    source_filename is a file of raw PCM data described by
    sample_rate, channels, channel_mask and bits_per_sample;
    the window starts pcm_frames_offset PCM frames in
    and is total_pcm_frames long
    it is compared against track_audiofile's whole stream
    while progress is passed to audiotools.PCMReaderProgress

    Returns (mismatch, image_filename, track_filename) tuple
    with both filenames converted to unicode,
    where mismatch is the result of audiotools.pcm_frame_cmp()
    or -2 if an IOError, ValueError, KeyboardInterrupt
    or audiotools.DecodingError occurs,
    the same error value image_compare() uses."""

    image_name = unicode(image_filename)
    track_name = unicode(track_filename)

    # size of one PCM frame in the raw file
    bytes_per_pcm_frame = channels * (bits_per_sample // 8)

    try:
        with open(source_filename, "rb") as f:
            # skip initial offset
            f.seek(pcm_frames_offset * bytes_per_pcm_frame)

            mismatch = audiotools.pcm_frame_cmp(
                audiotools.PCMReaderHead(
                    audiotools.PCMReader(file=f,
                                         sample_rate=sample_rate,
                                         channels=channels,
                                         channel_mask=channel_mask,
                                         bits_per_sample=bits_per_sample),
                    total_pcm_frames),
                audiotools.PCMReaderProgress(track_audiofile.to_pcm(),
                                             total_pcm_frames,
                                             progress))
    except (IOError, ValueError, KeyboardInterrupt, audiotools.DecodingError):
        mismatch = -2

    return (mismatch, image_name, track_name)


def image_compare_results(result, is_tty=False):
    """Returns a unicode status line for an image comparison result

    result is a (mismatch, image_name, track_name) tuple
    and is_tty is passed to the status text's format() method

    The line is the LAB_TRACKCMP_CMP label for both names,
    followed by " : " and a colored status chosen from mismatch:
    None -> OK, >= 0 -> mismatch at frame (mismatch + 1),
    negative -> error."""

    (mismatch, image_name, track_name) = result

    if (mismatch is None):
        status = audiotools.output_text(_.LAB_TRACKCMP_OK,
                                        fg_color="green")
    elif (mismatch >= 0):
        # mismatch is 0-based while the displayed frame number is 1-based
        status = audiotools.output_text(
            _.LAB_TRACKCMP_MISMATCH % {"frame_number": mismatch + 1},
            fg_color="red")
    else:
        # negative values are error codes rather than frame positions
        # and must not be displayed as a frame number
        status = audiotools.output_text(_.LAB_TRACKCMP_ERROR,
                                        fg_color="red")

    return u"%s : %s" % (_.LAB_TRACKCMP_CMP %
                         {"file1": image_name, "file2": track_name},
                         status.format(is_tty))


def image_compare_results_tty(result):
    """Returns image_compare_results() for result with is_tty set to True

    a single-argument variant of image_compare_results()
    for use as a completion output callback."""

    return image_compare_results(result, True)


# command-line entry point
# with exactly 2 arguments, compares two files or two directories
# with more than 2 arguments, compares a disk image against its tracks
# exits with status 1 on any mismatch, error or usage problem
if (__name__ == '__main__'):
    import argparse

    parser = argparse.ArgumentParser(description=_.DESCRIPTION_TRACKCMP)

    parser.add_argument("--version",
                        action="version",
                        version="Python Audio Tools %s" % (audiotools.VERSION))

    parser.add_argument("-V", "--verbose",
                        action="store",
                        dest="verbosity",
                        choices=audiotools.VERBOSITY_LEVELS,
                        default=audiotools.DEFAULT_VERBOSITY,
                        help=_.OPT_VERBOSE)

    parser.add_argument("-j", "--joint",
                        type=int,
                        default=audiotools.MAX_JOBS,
                        dest="max_processes",
                        help=_.OPT_JOINT)

    parser.add_argument("-S", "--no-summary",
                        action="store_true",
                        dest="no_summary",
                        help=_.OPT_NO_SUMMARY)

    # one required path plus one-or-more further paths,
    # so at least 2 paths are always given
    parser.add_argument("filename",
                        metavar="PATH",
                        help=_.OPT_INPUT_FILENAME_OR_IMAGE)

    parser.add_argument("filenames",
                        metavar="PATH",
                        nargs="+",
                        help=_.OPT_INPUT_FILENAME_OR_DIR)

    options = parser.parse_args()

    msg = audiotools.Messenger("trackcmp", options)

    args = [options.filename] + options.filenames

    if (options.max_processes < 1):
        msg.error(_.ERR_INVALID_JOINT)
        sys.exit(1)

    if (len(args) == 2):
        if (os.path.isfile(args[0]) and os.path.isfile(args[1])):
            # comparing two files

            audiofiles = audiotools.open_files(args,
                                               messenger=msg,
                                               sorted=False)
            if (len(audiofiles) != 2):
                msg.error(_.ERR_TRACKCMP_TYPE_MISMATCH)
                sys.exit(1)
            else:
                (path1, path2, mismatch) = cmp_files(None,
                                                     audiofiles[0],
                                                     audiofiles[1])
                # matching files produce no output and a 0 exit status
                if (mismatch is not None):
                    msg.output(cmp_result((path1, path2, mismatch),
                                          msg.output_isatty()))
                    sys.exit(1)
        elif (os.path.isdir(args[0]) and os.path.isdir(args[1])):
            # comparing two directories

            to_compare = []
            results = []

            # filename -> opened audio file,
            # for the regular files directly inside each directory
            files1 = {f.filename: f for f in
                      audiotools.open_files(
                          [path for path in
                           [os.path.join(args[0], f) for f in
                            os.listdir(args[0])] if os.path.isfile(path)],
                          sorted=False,
                          messenger=msg)}

            files2 = {f.filename: f for f in
                      audiotools.open_files(
                          [path for path in
                           [os.path.join(args[1], f) for f in
                            os.listdir(args[1])] if os.path.isfile(path)],
                          sorted=False,
                          messenger=msg)}

            # first, attempt to match files by their stream characteristics
            streams1 = {}
            streams2 = {}

            for (files, streams) in [(files1, streams1),
                                     (files2, streams2)]:
                for f in files.values():
                    streams.setdefault((f.bits_per_sample(),
                                        f.channels(),
                                        f.sample_rate(),
                                        f.total_frames()), []).append(f)

            # anything with matching specs
            # and only a single possible match per directory
            # is queued for comparison
            for specs in set(streams1.keys()) & set(streams2.keys()):
                if (((len(streams1[specs]) == 1) and
                     (len(streams2[specs]) == 1))):
                    file1 = streams1[specs][0]
                    file2 = streams2[specs][0]

                    # remove matched files from lists
                    del(files1[file1.filename])
                    del(files2[file2.filename])

                    # queue up comparison job
                    to_compare.append((file1, file2))

            # then, attempt to match leftover files by metadata
            # such as album_number and track_number
            metadatas1 = {}
            metadatas2 = {}

            for (files, metadatas) in [(files1, metadatas1),
                                       (files2, metadatas2)]:
                for f in files.values():
                    m = f.get_metadata()
                    if (m is not None):
                        metadatas.setdefault((m.track_number,
                                              m.album_number), []).append(f)
                    else:
                        # files without metadata all share one key
                        metadatas.setdefault((None,
                                              None), []).append(f)

            for metadata in set(metadatas1.keys()) & set(metadatas2.keys()):
                if (((len(metadatas1[metadata]) == 1) and
                     (len(metadatas2[metadata]) == 1))):
                    file1 = metadatas1[metadata][0]
                    file2 = metadatas2[metadata][0]

                    # remove matched files from lists
                    del(files1[file1.filename])
                    del(files2[file2.filename])

                    # queue up comparison job
                    to_compare.append((file1, file2))

            # anything left over is marked as a missing file
            # (each leftover is reported as missing from the other directory)
            for (files, directory) in [(files1, args[1]), (files2, args[0])]:
                for filename in files:
                    msg.info(
                        audiotools.output_text(
                            _.LAB_TRACKCMP_MISSING %
                            {"filename": audiotools.Filename(
                             os.path.basename(filename)),
                             "directory": audiotools.Filename(directory)},
                            fg_color="red").format(msg.info_isatty()))
                    sys.stdout.flush()
                    # a non-None third field counts as a failure below
                    results.append((filename, None, 0))

            queue = audiotools.ExecProgressQueue(
                audiotools.ProgressDisplay(msg))

            for (track1, track2) in sorted(to_compare,
                                           key=lambda f: f[0].filename):
                queue.execute(
                    function=cmp_files,
                    progress_text=_.LAB_TRACKCMP_CMP %
                    {"file1": audiotools.Filename(track1.filename),
                     "file2": audiotools.Filename(track2.filename)},
                    completion_output=(cmp_result_tty
                                       if msg.output_isatty() else
                                       cmp_result),
                    audiofile1=track1,
                    audiofile2=track2)

            try:
                results.extend(queue.run(options.max_processes))
            except KeyboardInterrupt:
                msg.error(_.ERR_CANCELLED)
                sys.exit(1)

            # results are (path1, path2, mismatch) tuples
            # in which only a mismatch of None is a success
            successes = len([r for r in results if r[2] is None])
            failures = len(results) - successes

            if (not options.no_summary):
                msg.output(_.LAB_TRACKCMP_RESULTS)
                msg.output(u"")

                table = audiotools.output_table()
                row = table.row()
                row.add_column(_.LAB_TRACKCMP_HEADER_SUCCESS, "right")
                row.add_column(u" ")
                row.add_column(_.LAB_TRACKCMP_HEADER_FAILURE, "right")
                row.add_column(u" ")
                row.add_column(_.LAB_TRACKCMP_HEADER_TOTAL, "right")

                table.divider_row([_.DIV, u" ", _.DIV, u" ", _.DIV])

                row = table.row()
                row.add_column(unicode(successes), "right")
                row.add_column(u" ")
                row.add_column(unicode(failures), "right")
                row.add_column(u" ")
                row.add_column(unicode(successes + failures), "right")

                for row in table.format(msg.output_isatty()):
                    msg.output(row)

            if (failures > 0):
                sys.exit(1)
        else:
            # comparison mismatch
            # (the two paths are not both files nor both directories)
            msg.error((_.LAB_TRACKCMP_CMP %
                       {"file1": audiotools.Filename(args[0]),
                        "file2": audiotools.Filename(args[1])}) +
                      u" : " +
                      audiotools.output_text(
                          _.LAB_TRACKCMP_TYPE_MISMATCH,
                          fg_color="red").format(msg.error_isatty()))
            sys.exit(1)
    elif (len(args) > 2):
        # possibly comparing disk image against tracks
        # once sorted by length, the last file is the candidate image
        audiofiles = sorted(
            audiotools.open_files(args, messenger=msg, sorted=False),
            key=lambda t: t.total_frames())

        # an image and at least one track are required,
        # which also keeps audiofiles[-1] from failing on an empty list,
        # and the tracks' lengths must add up to the image's length
        if ((len(audiofiles) < 2) or
            (sum([t.total_frames() for t in audiofiles[0:-1]]) !=
             audiofiles[-1].total_frames())):
            msg.usage(_.USAGE_TRACKCMP_CDIMAGE)
            sys.exit(1)

        cd_image = audiofiles[-1]
        tracks = audiofiles[0:-1]

        image_name = audiotools.Filename(cd_image.filename)

        # all tracks should have the same album number and track total
        tracks = audiotools.sorted_tracks(tracks)

        queue = audiotools.ExecProgressQueue(audiotools.ProgressDisplay(msg))

        # running offset of the current track within the image,
        # in PCM frames
        pcm_offset = 0

        if (cd_image.seekable()):
            for track in tracks:
                track_name = audiotools.Filename(track.filename)
                queue.execute(
                    function=image_compare,
                    progress_text=_.LAB_TRACKCMP_CMP %
                    {"file1": image_name, "file2": track_name},
                    completion_output=(image_compare_results_tty
                                       if msg.output_isatty()
                                       else image_compare_results),
                    image_audiofile=cd_image,
                    track_audiofile=track,
                    image_filename=image_name,
                    track_filename=track_name,
                    pcm_frames_offset=pcm_offset,
                    total_pcm_frames=track.total_frames())

                pcm_offset += track.total_frames()

            try:
                # every job's mismatch field must be None to succeed
                if ({r[0] for r in queue.run(options.max_processes)} !=
                    {None}):
                    sys.exit(1)
            except KeyboardInterrupt:
                msg.error(_.ERR_CANCELLED)
                sys.exit(1)
        else:
            import tempfile

            # if file isn't seekable

            # decode it to a single PCM blob of binary data
            temp_blob = tempfile.NamedTemporaryFile()
            try:
                cache_progress = audiotools.SingleProgressDisplay(
                    msg, _.LAB_CACHING_FILE)
                try:
                    audiotools.transfer_framelist_data(
                        audiotools.PCMReaderProgress(
                            cd_image.to_pcm(),
                            cd_image.total_frames(),
                            cache_progress.update),
                        temp_blob.write)
                except audiotools.DecodingError as err:
                    cache_progress.clear_rows()
                    msg.error(unicode(err))
                    sys.exit(1)
                except KeyboardInterrupt:
                    cache_progress.clear_rows()
                    msg.error(_.ERR_CANCELLED)
                    sys.exit(1)

                cache_progress.clear_rows()
                # make the written data visible to the jobs
                # which reopen the blob by name
                temp_blob.flush()

                # compare the blob using multiple jobs
                for track in tracks:
                    track_name = audiotools.Filename(track.filename)
                    queue.execute(
                        function=image_compare_raw,
                        progress_text=_.LAB_TRACKCMP_CMP %
                        {"file1": image_name, "file2": track_name},
                        completion_output=(image_compare_results_tty
                                           if msg.output_isatty()
                                           else image_compare_results),
                        source_filename=temp_blob.name,
                        sample_rate=cd_image.sample_rate(),
                        channels=cd_image.channels(),
                        channel_mask=int(cd_image.channel_mask()),
                        bits_per_sample=cd_image.bits_per_sample(),
                        track_audiofile=track,
                        image_filename=image_name,
                        track_filename=track_name,
                        pcm_frames_offset=pcm_offset,
                        total_pcm_frames=track.total_frames())

                    pcm_offset += track.total_frames()

                try:
                    # every job's mismatch field must be None to succeed
                    if ({r[0] for r in queue.run(options.max_processes)} !=
                        {None}):
                        sys.exit(1)
                except KeyboardInterrupt:
                    msg.error(_.ERR_CANCELLED)
                    sys.exit(1)
            finally:
                # then delete the blob when finished,
                # including when leaving via sys.exit() or an error
                temp_blob.close()
    else:
        msg.usage(_.USAGE_TRACKCMP_FILES)
        sys.exit(1)
