#! /usr/bin/python3
# -*- coding: utf-8 -*-

# Copyright (C) 2012-2014 by László Nagy
# This file is part of Bear.
#
# Bear is a tool to generate compilation database for clang tooling.
#
# Bear is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Bear is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
""" This module is responsible to capture the compiler invocation of any
build process. The result of that should be a compilation database.

This implementation is using the LD_PRELOAD or DYLD_INSERT_LIBRARIES
mechanisms provided by the dynamic linker. The related library is implemented
in C language and can be found under 'libear' directory.

The 'libear' library is capturing all child process creation and logging the
relevant information about it into separate files in a specified directory.
The input of the library is therefore the output directory which is passed
as an environment variable.

This module implements the build command execution with the 'libear' library
and the post-processing of the output files, which are condensed into a
(possibly empty) compilation database. """

import argparse
import logging
import functools
import subprocess
import json
import sys
import os
import os.path
import re
import shlex
import itertools


def main():
    """ Command line entry point.

    Parses the arguments, configures the logging and runs the build command
    under interception. The returned value is meant to be the process exit
    code: the value returned by 'capture' when a build command was given,
    0 when only the help was printed, 1 on keyboard interrupt and 127 on
    any other unexpected error. """
    try:
        parser = create_parser()
        args = parser.parse_args()

        logging.basicConfig(format='bear: %(message)s')
        logging.getLogger().setLevel(to_logging_level(args.verbose))
        logging.debug(args)

        # a build command was given: intercept it
        if args.build:
            return capture(args)

        # nothing to run, show the usage instead
        parser.print_help()
        return 0
    except KeyboardInterrupt:
        return 1
    except Exception:
        # top level boundary: report the traceback and signal the failure
        logging.exception("Something unexpected had happened.")
        return 127


def capture(args):
    """ The entry point of build command interception.

    Runs the build command (args.build) with the 'libear' library preloaded,
    reads the exec reports written into a temporary directory and dumps the
    result as JSON into the file named by args.cdb.

    Returns the exit code of the build command. """

    def post_processing(commands):
        """ Convert the intercepted exec calls into database entries.

        When args.raw_entries is set the commands are returned untouched. """
        # run post processing only if that was requested
        if args.raw_entries:
            return commands

        # create entries from the current run: every compiler call might
        # produce multiple entries (one per source file), non compiler
        # calls are filtered out before.
        current = itertools.chain.from_iterable(
            format_entry(command)
            for command in commands if compiler_call(command))
        # read entries from previous run (eagerly, the same file is going
        # to be overwritten later)
        previous = iter([])
        if args.append and os.path.exists(args.cdb):
            with open(args.cdb) as handle:
                previous = iter(json.load(handle))
        # filter out duplicate entries and entries of vanished files
        duplicate = duplicate_check(entry_hash)
        return (entry for entry in itertools.chain(previous, current)
                if os.path.exists(entry['file']) and not duplicate(entry))

    with TemporaryDirectory(prefix='bear-', dir=tempdir()) as tmpdir:
        # run the build command
        environment = setup_environment(tmpdir, args.libear)
        logging.info('run build in environment: %s', environment)
        exit_code = subprocess.call(args.build, env=environment)
        logging.info('build finished with exit code: %s', exit_code)
        # read the intercepted exec calls
        reports = (parse_exec_trace(os.path.join(tmpdir, filename))
                   for filename in sorted(os.listdir(tmpdir)))
        # do post processing. The lazy pipeline is consumed here, before
        # the output file is opened: opening it for writing truncates it,
        # so a failure during processing must not destroy an existing
        # compilation database.
        entries = list(post_processing(itertools.chain.from_iterable(reports)))
        # dump the compilation database
        with open(args.cdb, 'w') as handle:
            json.dump(entries, handle, sort_keys=True, indent=4)
        return exit_code


def setup_environment(destination, ear_library_path):
    """ Create the environment for the build command.

    Returns a copy of the current process environment extended with the
    variables required by the interception: 'BEAR_OUTPUT' carries the
    report directory (destination) and the platform specific preload
    variables carry the location of the 'libear' library. The current
    process environment itself is not modified. """

    if sys.platform == 'darwin':
        preload = {
            'DYLD_INSERT_LIBRARIES': ear_library_path,
            'DYLD_FORCE_FLAT_NAMESPACE': '1'
        }
    else:
        preload = {'LD_PRELOAD': ear_library_path}

    environment = dict(os.environ, BEAR_OUTPUT=destination)
    environment.update(preload)
    return environment


def parse_exec_trace(filename):
    """ Parse the file generated by the 'libear' preloaded library.

    Given filename points to a file which contains the basic report
    generated by the interception library or wrapper command. A single
    report file _might_ contain multiple process creation info.

    The reports are separated by the group separator character, the fields
    of a report by the record separator and the command arguments by the
    unit separator (each argument is terminated by one).

    Yields a dictionary per report with the 'pid', 'ppid', 'function',
    'directory' (all strings) and 'command' (list of strings) keys.
    Reports with less than five fields (eg.: the reporting process was
    killed during the write) are skipped with a warning. """

    GS = chr(0x1d)
    RS = chr(0x1e)
    US = chr(0x1f)
    field_count = 5

    logging.debug('parse exec trace file: %s', filename)
    # the whole content is read at once, no need to keep the file open
    # while the consumer iterates over the reports
    with open(filename, 'r') as handler:
        content = handler.read()

    for group in filter(bool, content.split(GS)):
        records = group.split(RS)
        if len(records) < field_count:
            logging.warning('skip malformed exec report in: %s', filename)
            continue
        yield {
            'pid': records[0],
            'ppid': records[1],
            'function': records[2],
            'directory': records[3],
            # the last element is the leftover after the closing separator
            'command': records[4].split(US)[:-1]
        }


def format_entry(entry):
    """ Generate compilation database entries from an intercepted command.

    The command of the given entry is classified first. When it is a
    compilation (or anything below that in the 'Action' ordering) an entry
    is generated for each source file with the 'directory', 'command' and
    'file' keys. Any other command generates nothing. """

    logging.debug('format this command: %s', entry['command'])
    atoms = classify_parameters(entry['command'])
    if atoms['action'] > Action.Compile:
        return

    # the compiler name is normalized, it does not depend on the source
    compiler = 'c++' if atoms['c++'] else 'cc'
    for source in atoms['files']:
        command = [compiler, '-c'] + atoms['compile_options'] + [source]
        logging.debug('formated as: %s', command)
        directory = entry['directory']
        # create normalized absolute path from the source file name
        if os.path.isabs(source):
            fullname = source
        else:
            fullname = os.path.join(directory, source)
        yield {
            'directory': directory,
            'command': ' '.join(map(shell_escape, command)),
            'file': os.path.normpath(fullname)
        }


def shell_escape(arg):
    """ Escape a single command line argument.

    The arguments are joined by spaces into a single command string,
    therefore every argument which would not survive a shell like word
    splitting unchanged needs to be quoted. Those are the ones containing
    white spaces (Eg.: -D_KEY="Value with spaces"), quotes or backslashes,
    and the empty argument.

    Returns the argument itself when it is safe, otherwise the argument
    wrapped in double quotes with the backslash and quote characters
    escaped by a backslash. """

    def quote(text):
        table = {'\\': '\\\\', '"': '\\"', "'": "\\'"}
        return '"' + ''.join(table.get(char, char) for char in text) + '"'

    try:
        # safe only if the splitting gives back the very same single word
        safe = shlex.split(arg) == [arg]
    except ValueError:
        # unbalanced quote or dangling escape character in the argument
        safe = False

    return arg if safe else quote(arg)


def is_source(filename):
    """ Predicate: tells whether the file name looks like a source file.

    The decision is based on the file extension only, which is compared
    case insensitively against the known C, C++ and Objective-C source
    (and preprocessed source) extensions. """

    extension = os.path.splitext(filename)[1].lower()
    return extension in {
        '.c', '.cc', '.cp', '.cpp', '.cxx', '.c++', '.m', '.mm', '.i', '.ii',
        '.mii'
    }


def compiler_call(entry):
    """ A predicate to decide the entry is a compiler call or not.

    The first element of entry['command'] (the executable, with or without
    directory part) is matched against the known compiler name patterns:
    cc/c++, the gcc family and clang with optional cross compilation
    prefix and version suffix, and llvm-gcc.

    Returns False when the command is empty. """

    patterns = [
        re.compile(r'^([^/]*/)*c(c|\+\+)$'),
        re.compile(r'^([^/]*/)*([^-]*-)*[mg](cc|\+\+)(-\d+(\.\d+){0,2})?$'),
        re.compile(r'^([^/]*/)*([^-]*-)*clang(\+\+)?(-\d+(\.\d+){0,2})?$'),
        re.compile(r'^([^/]*/)*llvm-g(cc|\+\+)$'),
    ]
    command = entry['command']
    # an exec call reported without arguments has no executable to check
    if not command:
        return False
    executable = command[0]
    return any(pattern.match(executable) for pattern in patterns)


def entry_hash(entry):
    """ Create the key which identifies a compilation database entry.

    The key is built from the 'file', 'directory' and 'command' values of
    the entry. Two entries with the same key are considered duplicates. """

    # File name and directory are reverted for faster lookup in set.
    reverted_file = entry['file'][::-1]
    reverted_directory = entry['directory'][::-1]
    # The first word of the command (the compiler) is not part of the key:
    # on OS X the 'cc' and 'c++' compilers are wrappers for 'clang',
    # therefore both calls would be logged for a single compilation.
    arguments = shlex.split(entry['command'])[1:]

    return '<>'.join((reverted_file, reverted_directory, ' '.join(arguments)))


# Use the standard library implementation whenever it is available. The
# availability is detected by the import itself instead of comparing the
# major and minor version numbers independently (such check rejects any
# X.0 and X.1 release after the 3.x series).
try:
    from tempfile import TemporaryDirectory
except ImportError:

    class TemporaryDirectory(object):
        """ This class creates a temporary directory using mkdtemp() (the
        supplied keyword arguments are passed directly to the underlying
        function). The resulting object can be used as a context manager.
        On completion of the context the newly created temporary directory
        and all its contents are removed from the filesystem. """

        def __init__(self, **kwargs):
            from tempfile import mkdtemp
            self.name = mkdtemp(**kwargs)

        def __enter__(self):
            # same as the standard library one: the context value is the
            # name of the directory, not the object itself
            return self.name

        def __exit__(self, _type, _value, _traceback):
            self.cleanup()

        def cleanup(self):
            """ Remove the directory with all of its content. """
            from shutil import rmtree
            if self.name is not None:
                rmtree(self.name)


def duplicate_check(method):
    """ Create a predicate which detects duplicated entries.

    Entries are represented as dictionaries, which are not hashable. The
    given method shall compute a hashable unique key from an entry, the
    keys seen so far are stored in a set.

    Returns a function which takes an entry and returns True when an entry
    with the same key was already passed to it, False otherwise. The set
    of keys and the key method are exposed as the 'state' and 'unique'
    attributes of the returned function. """

    def predicate(entry):
        key = predicate.unique(entry)
        seen = key in predicate.state
        # adding an already stored key does not change the set
        predicate.state.add(key)
        return seen

    predicate.state = set()
    predicate.unique = method
    return predicate


def tempdir():
    """ Return the default temporary directory.

    It is the value of the first environment variable which is set from
    'TMPDIR', 'TEMP' and 'TMP' (in this order), or '/tmp' if none of them
    is set. """

    for variable in ('TMPDIR', 'TEMP', 'TMP'):
        value = os.environ.get(variable)
        if value is not None:
            return value
    return '/tmp'


def create_parser():
    """ Create the command line parser.

    The returned argparse.ArgumentParser understands the generic options
    (version, verbosity, output file, append mode), the 'advanced options'
    group (raw output, 'libear' location) and collects all the remaining
    arguments into 'build' as the command to run. """

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # generic options
    parser.add_argument(
        '--version',
        version='%(prog)s 2.1.1',
        action='version')
    parser.add_argument(
        '--verbose', '-v',
        default=0,
        action='count',
        help="""enable verbose output from '%(prog)s'. A second '-v' increases
                verbosity.""")
    parser.add_argument(
        '--cdb', '-o',
        default="compile_commands.json",
        metavar='<file>',
        help="""The JSON compilation database.""")
    parser.add_argument(
        '--append', '-a',
        help="""appends new entries to existing compilation database.""",
        action='store_true')

    # options which are not needed for the everyday usage
    advanced = parser.add_argument_group('advanced options')
    advanced.add_argument(
        '--disable-filter', '-n',
        action='store_true',
        dest='raw_entries',
        help="""disable filter, unformated output.""")
    advanced.add_argument(
        '--libear', '-l',
        action='store',
        dest='libear',
        default="/usr/lib/arm-linux-gnueabihf/bear/libear.so",
        help="""specify libear file location.""")

    # everything else is the build command itself
    parser.add_argument(
        nargs=argparse.REMAINDER,
        dest='build',
        help="""command to run.""")

    return parser


def to_logging_level(num):
    """ Convert the count of verbose flags to logging level.

    No flag means warnings only, a single flag enables the informational
    messages, anything else enables the debug messages. """
    levels = {0: logging.WARNING, 1: logging.INFO}
    return levels.get(num, logging.DEBUG)


class Action(object):
    """ Enumeration class for compiler action.

    The members are plain, ordered integers, so they can be compared to
    each other. """

    Link = 0
    Compile = 1
    Ignored = 2


def classify_parameters(command):
    """ Parses the command line arguments of the given invocation.

    The command is a list of strings, where the first element is the
    compiler executable. Returns a dictionary with these keys:

    'action'           the greatest 'Action' value implied by the arguments
                       ('Action.Link' if none of them implies any),
    'files'            list of arguments which look like source files,
    'compile_options'  list of arguments which were not recognized as
                       anything else,
    'c++'              True if the executable name refers to a C++ compiler.

    Options which take their value from the next argument and are ignored
    are dropped together with that value. A missing value (the option is
    the last argument) is tolerated. """

    def action(values, value):
        # the action with the greatest value wins
        values['action'] = max(values['action'], value)

    state = {
        'action': Action.Link,
        'files': [],
        'compile_options': [],
        'c++': cplusplus_compiler(command[0])
    }

    args = iter(command[1:])
    for arg in args:
        # compiler action parameters are the most important ones...
        if arg == '-c':
            action(state, Action.Compile)
        elif arg in {'-E', '-S', '-cc1', '-M', '-MM'}:
            action(state, Action.Ignored)
        # some preprocessor parameters are ignored...
        elif arg in {'-MD', '-MMD', '-MG', '-MP'}:
            pass
        elif arg in {'-MF', '-MT', '-MQ'}:
            # skip the value too. The default makes the exhausted iterator
            # harmless: a StopIteration raised here would leak into the
            # generator which calls this function.
            next(args, None)
        # linker options are ignored...
        elif arg in {'-static', '-shared', '-s', '-rdynamic'}:
            pass
        elif re.match(r'^-[lL].+', arg):
            pass
        elif arg in {'-l', '-L', '-u', '-z', '-T', '-Xlinker'}:
            # skip the value too (see the note above about the default)
            next(args, None)
        # optimization and warning options are ignored...
        elif re.match(r'^-[mWO].+', arg):
            pass
        # parameters which look like source files are taken...
        elif re.match(r'^[^-].+', arg) and is_source(arg):
            state['files'].append(arg)
        # and consider everything else as compile option.
        else:
            state['compile_options'].append(arg)

    return state


def cplusplus_compiler(name):
    """ Predicate: tells whether the compiler name refers to a C++ compiler.

    The name is accepted when its last path component ends with '++',
    optionally preceded by dash terminated prefixes and optionally followed
    by a dash and a version number. """

    pattern = r'^([^/]*/)*(\w*-)*(\w+\+\+)(-(\d+(\.\d+){0,3}))?$'
    return re.match(pattern, name) is not None


# Run the command line interface only when this file is executed as a
# script; the value returned by main() becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
