#! /usr/bin/python3
# -*- coding: utf-8 -*-

# Copyright (C) 2012-2014 by László Nagy
# This file is part of Bear.
#
# Bear is a tool to generate compilation database for clang tooling.
#
# Bear is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Bear is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


""" This module is responsible for capturing the compiler invocations of any
build process. The result of that should be a compilation database.

This implementation is using the LD_PRELOAD or DYLD_INSERT_LIBRARIES
mechanisms provided by the dynamic linker. The related library is implemented
in C language and can be found under 'libear' directory.

The 'libear' library is capturing all child process creation and logging the
relevant information about it into separate files in a specified directory.
The input of the library is therefore the output directory which is passed
as an environment variable.

This module implements the build command execution with the 'libear' library
and the post-processing of the output files, which are condensed into a
(possibly empty) compilation database. """


import argparse
import logging
import functools
import subprocess
import json
import sys
import os
import os.path
import re
import shlex
import itertools


# Pairs of (alias, environment variable name). 'run_build' walks this list
# and sets every listed variable in the environment of the build command:
# the 'ENV_OUTPUT' alias receives the output directory, the 'ENV_PRELOAD'
# alias receives the library path and any other alias receives '1'.
ENVIRONMENTS = [("ENV_OUTPUT", "BEAR_OUTPUT")]

# NOTE(review): the condition below is a constant (an empty string literal
# is falsy), so the first branch is always taken here. It looks like a
# placeholder filled in when this script is generated for a given platform
# -- confirm against the build system before touching it.
if not '':
    ENVIRONMENTS.extend([("ENV_PRELOAD", "LD_PRELOAD")])
else:
    ENVIRONMENTS.extend([("ENV_PRELOAD", "DYLD_INSERT_LIBRARIES"),
                          ("ENV_FLAT", "DYLD_FORCE_FLAT_NAMESPACE")])


def main():
    """ Entry point of the script.

    Parses the command line, configures the logging and runs the build
    when a build command was given. Without a build command the help
    text is printed.

    :return: the value returned by 'run', 0 when only the help was
        printed, 1 on keyboard interrupt and 127 on any other exception.
    """
    try:
        parser = create_parser()
        args = parser.parse_args()

        logging.basicConfig(format='bear: %(message)s')
        logging.getLogger().setLevel(to_logging_level(args.verbose))
        logging.debug(args)

        if args.build:
            return run(args)

        # nothing to execute, tell the user how to call the tool
        parser.print_help()
        return 0
    except KeyboardInterrupt:
        return 1
    except Exception:
        # the traceback is attached to the message by 'logging.exception'
        logging.exception("Something unexpected had happened.")
        return 127


def run(args):
    """ Run the build command and write the compilation database.

    The build is executed with the 'libear' library preloaded, which
    writes its trace files into a temporary directory. Those files are
    parsed, post-processed (unless 'args.raw_entries' is set) and dumped
    as JSON into 'args.cdb'.

    :param args: the parsed command line arguments. The attributes read
        are 'build', 'libear', 'cdb', 'append' and 'raw_entries'.
    :return: the exit code of the build command.
    """

    def post_processing(commands):
        """ Turn the intercepted exec calls into database entries. """
        # run post processing only if that was requested
        if args.raw_entries:
            return commands

        # create entries from the current run: a sequence of entry
        # generators from the exec calls, non compiler calls are
        # filtered out before.
        current = itertools.chain.from_iterable(
            format_entry(x) for x in commands if compiler_call(x))
        # read entries from previous run
        if args.append and os.path.exists(args.cdb):
            with open(args.cdb) as handle:
                previous = iter(json.load(handle))
        else:
            previous = iter([])
        # filter out duplicate entries from both
        duplicate = duplicate_check(entry_hash)
        return (entry
                for entry in itertools.chain(previous, current)
                if os.path.exists(entry['file']) and not duplicate(entry))

    with TemporaryDirectory(prefix='bear-', dir=tempdir()) as tmpdir:
        # run the build command
        exit_code = run_build(args.build, tmpdir, args.libear)
        logging.debug('build finished with exit code: {0}'.format(exit_code))
        # read the intercepted exec calls
        commands = (parse_exec_trace(os.path.join(tmpdir, filename))
                    for filename
                    in sorted(os.listdir(tmpdir)))
        # do post processing. The lazy pipeline is consumed here, before
        # the output file is opened: opening it for writing truncates it,
        # so a failure during the processing would otherwise destroy the
        # already existing compilation database.
        entries = list(
            post_processing(itertools.chain.from_iterable(commands)))
        # dump the compilation database
        with open(args.cdb, 'w+') as handle:
            json.dump(entries, handle, sort_keys=True, indent=4)
        return exit_code


def run_build(command, destination, libear):
    """ Execute the original build command with 'libear' preloaded.

    A copy of the current environment is extended with the variables
    listed in ENVIRONMENTS, then the command is executed with it. The
    exec calls are logged by the preloaded library.

    :param command: the build command as a list of arguments.
    :param destination: directory where the library shall write its files.
    :param libear: path of the library to preload.
    :return: the exit code of the build command.
    """
    # values which are not a plain '1' flag
    special = {'ENV_PRELOAD': libear, 'ENV_OUTPUT': destination}

    environment = dict(os.environ)
    for alias, key in ENVIRONMENTS:
        environment[key] = special.get(alias, '1')

    return subprocess.call(command, env=environment)


def parse_exec_trace(filename):
    """ Parse the file generated by the 'libear' preloaded library.

    The content is a sequence of groups separated by the GS character.
    A group is a list of records separated by RS, and the command record
    is a list of arguments, each one terminated by US.

    Groups with fewer than five records (eg.: the writer was interrupted)
    are skipped instead of failing the whole run with an IndexError.

    :param filename: path of the trace file.
    :return: generator of dictionaries with the 'pid', 'ppid', 'function',
        'directory' and 'command' keys. The 'command' value is a list of
        strings, the others are strings.
    """
    GS = chr(0x1d)
    RS = chr(0x1e)
    US = chr(0x1f)
    # the whole content is read up front, so the file does not stay open
    # while the consumer is processing the yielded values
    with open(filename, 'r') as handler:
        content = handler.read()

    for group in filter(bool, content.split(GS)):
        records = group.split(RS)
        if len(records) < 5:
            logging.debug('incomplete exec trace ignored in: %s', filename)
            continue
        yield {'pid': records[0],
               'ppid': records[1],
               'function': records[2],
               'directory': records[3],
               # every argument is followed by US, drop the empty tail
               'command': records[4].split(US)[:-1]}


def format_entry(entry):
    """ Generate the compilation database entries of an exec call.

    An entry is generated for every source file of the command, but only
    when the action of the command is compilation or linking.

    :param entry: dictionary with the 'directory' and 'command' keys.
    :return: generator of dictionaries with the 'directory', 'command'
        and 'file' keys.
    """
    def normalize(directory, name):
        """ Create normalized absolute path from input filename. """
        if os.path.isabs(name):
            return os.path.normpath(name)
        return os.path.normpath(os.path.join(directory, name))

    atoms = classify_parameters(entry['command'])
    if atoms['action'] > Action.Compile:
        return

    sources = (name for name in atoms.get('files', [])
               if is_source_file(name))
    for source in sources:
        yield {'directory': entry['directory'],
               'command': ' '.join(
                   [shell_escape(arg) for arg in entry['command']]),
               'file': normalize(entry['directory'], source)}


def shell_escape(arg):
    """ Escape a single command line argument.

    The major challenge is to deal with white spaces, which are used by
    the shell as separator. (Eg.: -D_KEY="Value with spaces")

    :param arg: one argument of the command.
    :return: the argument itself when 'shlex.split' keeps it in one piece,
        otherwise the argument wrapped into double quotes, with the
        backslash and the quote characters escaped by a backslash.
    """
    def quote(text):
        table = {'\\': '\\\\', '"': '\\"', "'": "\\'"}
        return '"' + ''.join([table.get(c, c) for c in text]) + '"'

    try:
        needs_quoting = len(shlex.split(arg)) > 1
    except ValueError:
        # 'shlex.split' refuses unbalanced quotes and a trailing backslash.
        # Such an argument can't be written as it is, so quote it instead
        # of aborting the whole run.
        needs_quoting = True

    return quote(arg) if needs_quoting else arg


def is_source_file(filename):
    """ A predicate to decide whether the filename is a source file.

    The decision is made on the (case sensitive) file extension only.

    :param filename: name or path of the file.
    :return: True when the extension is one of the accepted ones.
    """
    extension = os.path.splitext(filename)[1]
    return extension in {
        '.c', '.C', '.cc', '.CC', '.cxx', '.cp', '.cpp', '.c++',
        '.m', '.mm',
        '.i', '.ii', '.mii'}


def compiler_call(entry):
    """ A predicate to decide whether the entry is a compiler call.

    The first element of the command is matched against the known compiler
    names: 'cc', 'c++', the 'gcc', 'g++', 'clang' and 'clang++' names with
    optional prefixes and version suffix, 'llvm-gcc' and 'llvm-g++'. A
    leading directory part is accepted in every case.

    :param entry: dictionary with the 'command' key, which is a list of
        strings.
    :return: True when the executable looks like a compiler. False when
        it does not, or when the command is empty.
    """
    patterns = (
        r'^([^/]*/)*c(c|\+\+)$',
        r'^([^/]*/)*([^-]*-)*g(cc|\+\+)(-\d+(\.\d+){0,2})?$',
        r'^([^/]*/)*([^-]*-)*clang(\+\+)?(-\d+(\.\d+){0,2})?$',
        r'^([^/]*/)*llvm-g(cc|\+\+)$',
    )
    command = entry['command']
    if not command:
        # an exec call without any argument has no executable name to check
        return False
    executable = command[0]
    return any(re.match(pattern, executable) for pattern in patterns)


def entry_hash(entry):
    """ Implement unique hash method for compilation database entries.

    :param entry: dictionary with the 'file', 'directory' and 'command'
        keys, all of them with string value.
    :return: a string built from the three values.
    """
    parts = (
        # the file name and the directory are stored reversed (the
        # original intent was a faster lookup in the set)
        entry['file'][::-1],
        entry['directory'][::-1],
        # On OS X the 'cc' and 'c++' compilers are wrappers for 'clang',
        # therefore both calls would be logged. To avoid this the first
        # word of the command is left out from the hash.
        ' '.join(shlex.split(entry['command'])[1:]),
    )
    return '<>'.join(parts)


# 'tempfile.TemporaryDirectory' is available since Python 3.2. The version is
# compared as a tuple: checking the major and the minor number one by one
# would reject every release which has a bigger major, but a smaller minor
# number (eg.: 4.0).
if sys.version_info >= (3, 2):
    from tempfile import TemporaryDirectory
else:
    class TemporaryDirectory(object):
        """ This class creates a temporary directory using mkdtemp() (the
        supplied keyword arguments are passed directly to the underlying
        function). The resulting object can be used as a context manager.
        On completion of the context the newly created temporary directory
        and all its contents are removed from the filesystem. """
        def __init__(self, **kwargs):
            from tempfile import mkdtemp
            # path of the created directory
            self.name = mkdtemp(**kwargs)

        def __enter__(self):
            """ Return the path of the directory. """
            return self.name

        def __exit__(self, _type, _value, _traceback):
            """ Remove the directory, exceptions are not suppressed. """
            self.cleanup()

        def cleanup(self):
            """ Remove the directory and all its contents. """
            from shutil import rmtree
            if self.name is not None:
                rmtree(self.name)


def duplicate_check(method):
    """ Create a predicate which detects duplicated entries.

    Entries are represented as dictionaries, which have no default hash
    method. The given method is used to create a hashable value from an
    entry, and those values are stored in a set.

    :param method: callable which creates a hashable value from an entry.
    :return: a callable which takes an entry and returns True when an
        entry with the same hash value was already passed to it, False
        otherwise. The hash method and the set of the seen values are
        available as the 'unique' and 'state' attributes of it.
    """
    def predicate(entry):
        key = predicate.unique(entry)
        seen = key in predicate.state
        # adding an already stored value leaves the set unchanged
        predicate.state.add(key)
        return seen

    predicate.state = set()
    predicate.unique = method
    return predicate


def tempdir():
    """ Return the default temporary directory.

    The TMPDIR, TEMP and TMP environment variables are checked in this
    order and the value of the first one which is set is returned. When
    none of them is set the result is '/tmp'.
    """
    for variable in ('TMPDIR', 'TEMP', 'TMP'):
        if variable in os.environ:
            return os.environ[variable]
    return '/tmp'


def create_parser():
    """ Create the command line parser of the tool.

    The general options are '--version', '--verbose', '--cdb' and
    '--append'. The 'advanced options' group has the '--disable-filter'
    and '--libear' options. Every remaining argument is collected into
    'build', which is the command to run.

    :return: the configured 'argparse.ArgumentParser' object.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # general options
    parser.add_argument('--version',
                        version='%(prog)s 2.0.4',
                        action='version')
    parser.add_argument('--verbose', '-v',
                        default=0,
                        action='count',
                        help="""enable verbose output from '%(prog)s'. A second '-v' increases
                verbosity.""")
    parser.add_argument('--cdb', '-o',
                        default="compile_commands.json",
                        metavar='<file>',
                        help="""The JSON compilation database.""")
    parser.add_argument('--append', '-a',
                        action='store_true',
                        help="""appends new entries to existing compilation database.""")

    # options for testing and for non default installations
    advanced = parser.add_argument_group('advanced options')
    advanced.add_argument('--disable-filter', '-n',
                          action='store_true',
                          dest='raw_entries',
                          help="""disable filter, unformated output.""")
    advanced.add_argument('--libear', '-l',
                          action='store',
                          dest='libear',
                          default="/usr/lib/arm-linux-gnueabihf/bear/libear.so",
                          help="""specify libear file location.""")

    # the build command itself is everything after the options
    parser.add_argument(nargs=argparse.REMAINDER,
                        dest='build',
                        help="""command to run.""")

    return parser


def to_logging_level(num):
    """ Convert the count of verbose flags to logging level.

    :param num: how many times the verbose flag was given.
    :return: 'logging.WARNING' for zero, 'logging.INFO' for one and
        'logging.DEBUG' for anything else.
    """
    if num == 0:
        return logging.WARNING
    return logging.INFO if num == 1 else logging.DEBUG


class Action(object):
    """ Enumeration class for compiler action.

    The values are plain integers in increasing order, therefore they
    can be compared to each other. """
    Link = 0
    Compile = 1
    Preprocess = 2
    Info = 3


def classify_parameters(command):
    """ Parses the command line arguments of the given invocation.

    To run analysis from a compilation command, first it disassembles the
    compilation command. Classifies the parameters into groups and throws
    away those which are not relevant.

    :param command: the full command as a list of strings, the first
        element is the compiler executable.
    :return: dictionary of the collected values. The 'action' (the biggest
        Action value requested, Action.Link when none was) and the 'cxx'
        keys are always present. The 'files', 'language', 'output',
        'archs_seen' and 'compile_options' keys are present only when a
        related argument was found.

    When an option which requires a value is the last argument, the
    classification stops there and the values collected so far are
    returned.
    """
    def create_tasks(state):
        """ This method creates the list of pattern and action pairs.
            The matching starts from the top of the list, when the first
            match happens the action is executed.
        """
        def regex(pattern, action):
            regexp = re.compile(pattern)

            def evaluate(iterator):
                found = regexp.match(iterator.current())
                if found:
                    action(state, iterator, found)
                    return True
                return False
            return evaluate

        def anyof(opts, action):
            def evaluate(iterator):
                if iterator.current() in opts:
                    action(state, iterator, None)
                    return True
                return False
            return evaluate

        return [
            #
            regex(r'^-(E|MM?)$', take_action(Action.Preprocess)),
            anyof({'-c'}, take_action(Action.Compile)),
            anyof({'-print-prog-name'}, take_action(Action.Info)),
            anyof({'-cc1'}, take_action(Action.Info)),
            #
            anyof({'-arch'}, take_two('archs_seen')),
            #
            anyof({'-filelist'}, take_from_file('files')),
            regex(r'^[^-].+', take_one('files')),
            #
            anyof({'-x'}, take_second('language')),
            #
            anyof({'-o'}, take_second('output')),
            #
            anyof({'-write-strings',
                   '-v'}, take_one('compile_options')),
            anyof({'-ftrapv-handler',
                   '--sysroot',
                   '-target'}, take_two('compile_options')),
            regex(r'^-isysroot', take_two('compile_options')),
            regex(r'^-m(32|64)$', take_one('compile_options')),
            regex(r'^-mios-simulator-version-min(.*)',
                  take_joined('compile_options')),
            regex(r'^-stdlib(.*)', take_joined('compile_options')),
            regex(r'^-mmacosx-version-min(.*)',
                  take_joined('compile_options')),
            regex(r'^-miphoneos-version-min(.*)',
                  take_joined('compile_options')),
            regex(r'^-O[1-3]$', take_one('compile_options')),
            anyof({'-O'}, take_as('-O1', 'compile_options')),
            anyof({'-Os'}, take_as('-O2', 'compile_options')),
            regex(r'^-[DIU](.*)$', take_joined('compile_options')),
            anyof({'-nostdinc'}, take_one('compile_options')),
            regex(r'^-std=', take_one('compile_options')),
            regex(r'^-include', take_two('compile_options')),
            anyof({'-idirafter',
                   '-imacros',
                   '-iprefix',
                   '-isystem',
                   '-iwithprefix',
                   '-iwithprefixbefore'}, take_two('compile_options')),
            regex(r'^-m.*', take_one('compile_options')),
            regex(r'^-iquote(.*)', take_joined('compile_options')),
            regex(r'^-Wno-', take_one('compile_options')),
            # ignore
            regex(r'^-framework$', take_two()),
            regex(r'^-fobjc-link-runtime(.*)', take_joined()),
            regex(r'^-[lL]', take_one()),
            regex(r'^-M[TF]$', take_two()),
            regex(r'^-[eu]$', take_two()),
            anyof({'-fsyntax-only',
                   '-save-temps'}, take_one()),
            anyof({'-install_name',
                   '-exported_symbols_list',
                   '-current_version',
                   '-compatibility_version',
                   '-init',
                   '-seg1addr',
                   '-bundle_loader',
                   '-multiply_defined',
                   '--param',
                   '--serialize-diagnostics'}, take_two()),
            anyof({'-sectorder'}, take_four()),
            #
            regex(r'^-[fF](.+)$', take_one('compile_options'))
        ]

    def take_n(count=1, *keys):
        """ Append the current and the next 'count - 1' arguments to the
            lists stored under the given keys. """
        def take(values, iterator, _match):
            updates = []
            updates.append(iterator.current())
            for _ in range(count - 1):
                updates.append(iterator.next())
            for key in keys:
                current = values.get(key, [])
                values.update({key: current + updates})
        return take

    def take_one(*keys):
        return take_n(1, *keys)

    def take_two(*keys):
        return take_n(2, *keys)

    def take_four(*keys):
        return take_n(4, *keys)

    def take_joined(*keys):
        """ Append the current argument, and the next one too when the
            first group of the match is empty (the value was given as a
            separate argument). """
        def take(values, iterator, match):
            updates = []
            updates.append(iterator.current())
            if not match.group(1):
                updates.append(iterator.next())
            for key in keys:
                current = values.get(key, [])
                values.update({key: current + updates})
        return take

    def take_from_file(*keys):
        """ Take the next argument as a file name, and store the stripped
            lines of that file under the given keys. """
        def take(values, iterator, _match):
            with open(iterator.next()) as handle:
                current = [line.strip() for line in handle.readlines()]
                for key in keys:
                    values[key] = current
        return take

    def take_as(value, *keys):
        """ Append the given value (instead of the current argument) to
            the lists stored under the given keys. """
        def take(values, _iterator, _match):
            updates = [value]
            for key in keys:
                current = values.get(key, [])
                values.update({key: current + updates})
        return take

    def take_second(*keys):
        """ Store the next argument under the given keys. """
        def take(values, iterator, _match):
            current = iterator.next()
            for key in keys:
                values[key] = current
        return take

    def take_action(action):
        """ Keep the bigger one from the stored and the given action. """
        def take(values, _iterator, _match):
            key = 'action'
            current = values[key]
            values[key] = max(current, action)
        return take

    state = {'action': Action.Link,
             'cxx': _is_cplusplus_compiler(command[0])}

    # the task list depends only on the state, build it once instead of
    # doing it again for every argument
    tasks = create_tasks(state)
    arguments = Arguments(command)
    try:
        for _ in arguments:
            for task in tasks:
                if task(arguments):
                    break
    except StopIteration:
        # An action asked for the value of the last argument. The for
        # statement handles only the exception of its own iteration step,
        # so this one would leave the function. When the caller is a
        # generator it would silently finish it, or would be converted to
        # RuntimeError by newer interpreters (PEP 479).
        pass
    return state


class Arguments(object):
    """ An iterator wrapper around compiler arguments.

    Python iterators implement only the 'next' method, but this one
    has the 'current' query method as well.
    """
    def __init__(self, args):
        """ Takes the full command line, but iterates on the parameters only.
        """
        parameters = args[1:]
        self.__sequence = parameters
        self.__size = len(parameters)
        # the position is before the first parameter until 'next' is called
        self.__current = -1

    def __iter__(self):
        """ The object is its own iterator.
        """
        return self

    def __next__(self):
        """ Iterator step for python 3.x, same as 'next'.
        """
        return self.next()

    def next(self):
        """ Iterator step for python 2.x. Moves to the next parameter and
        returns it, or raises StopIteration when there is no more.
        """
        self.__current = self.__current + 1
        return self.current()

    def current(self):
        """ Returns the parameter at the current position, or raises
        StopIteration when the position is behind the last one.
        """
        if self.__current < self.__size:
            return self.__sequence[self.__current]
        raise StopIteration


def _is_cplusplus_compiler(name):
    """ Returns true when the compiler name refers to a C++ compiler.

    The name is accepted when, after the optional directory part and
    dash separated prefixes, the executable name ends with '++', which
    might be followed by a dash and a version number.
    """
    pattern = r'^([^/]*/)*(\w*-)*(\w+\+\+)(-(\d+(\.\d+){0,3}))?$'
    return re.match(pattern, name) is not None


if __name__ == "__main__":
    sys.exit(main())
