#!/usr/bin/env python3

'''
Wrapper script for running the Rumur model checker.

This script is intended to be installed alongside the `rumur` binary from the
Rumur model checker. It can then be used to generate, compile and run a checker
for a model in a single step, as an alternative to running the checker
generation, compilation and execution steps manually.
'''

import atexit
import os
import platform
import shutil
import subprocess as sp
import sys
import tempfile
from typing import Optional

def which(cmd: str) -> Optional[str]:
  '''
  Equivalent of shell `which`

  If `cmd` contains a directory component, that path is checked directly;
  otherwise each directory on PATH is searched for an executable named `cmd`.

  The standard library is used rather than running an external `which`
  program. Such a program is not installed on every system (e.g. minimal
  container images), and its output and exit-status conventions vary between
  platforms.

  :param cmd: Name or path of the executable to locate.
  :returns: Path to the executable, or None if it could not be found.
  '''
  return shutil.which(cmd)

# The C compiler used to build generated checkers: the program named by the CC
# environment variable when that is set, or `cc` otherwise, resolved via
# `which`. None if it could not be located.
CC = which(os.environ['CC'] if 'CC' in os.environ else 'cc')

def categorise(cc: str) -> str:
  '''
  Determine the vendor of a given C compiler

  A small test program is compiled with `cc` and executed. The program reports
  which vendor macros (`__clang__`, `__GNUC__`) the compiler predefines.

  :param cc: Name or path of the C compiler to probe.
  :returns: 'clang', 'gcc' or 'unknown'. 'unknown' is also returned when the
    test program cannot be compiled or run (including when `cc` does not exist
    or the scratch area does not permit execution).
  '''

  # Create a temporary area to compile a test file. The context manager
  # guarantees it is removed even if compiling or running the probe raises.
  with tempfile.TemporaryDirectory() as tmp:

    # Setup the test file
    src = os.path.join(tmp, 'test.c')
    with open(src, 'wt') as f:
      f.write('#include <stdio.h>\n'
              '#include <stdlib.h>\n'
              'int main(void) {\n'
              '#ifdef __clang__\n'
              '  printf("clang\\n");\n'
              '#elif defined(__GNUC__)\n'
              '  printf("gcc\\n");\n'
              '#else\n'
              '  printf("unknown\\n");\n'
              '#endif\n'
              '  return EXIT_SUCCESS;\n'
              '}\n')

    # Compile it. An OSError (e.g. `cc` is missing or not executable) is just
    # another way for compilation to fail.
    aout = os.path.join(tmp, 'a.out')
    try:
      cc_proc = sp.run([cc, '-o', aout, src], universal_newlines=True,
        stdout=sp.DEVNULL, stderr=sp.DEVNULL)
    except OSError:
      return 'unknown'
    if cc_proc.returncode != 0:
      return 'unknown'

    # Run it. OSError covers e.g. a temporary directory mounted `noexec`.
    try:
      return sp.check_output([aout], universal_newlines=True).strip()
    except (sp.CalledProcessError, OSError):
      return 'unknown'

def supports(flag: str) -> bool:
  '''
  Test whether the C compiler accepts a particular command line option.

  A trivial program is fed to CC on stdin and compiled with `flag` appended.
  The output is discarded to the null device, and compiler diagnostics are
  suppressed.

  :param flag: The command line option to try.
  :returns: True if compilation succeeded with the option, False otherwise.
  '''
  trivial_program = b'int main(void) { return 0; }'
  command = [CC, '-o', os.devnull, '-x', 'c', '-', flag]
  outcome = sp.run(command, stderr=sp.DEVNULL, input=trivial_program)
  return outcome.returncode == 0

def optimisation_flags() -> [str]:
  '''
  Assemble the optimisation options to pass to CC on this host.

  The list always starts with -O3. Host-specific tuning (-march=native,
  -mtune=native) and link-time optimisation (-flto) are added in that order,
  each only when the compiler accepts it. GCC additionally gets
  -fwhole-program, plus -mcx16 on x86-64 hosts.

  :returns: The list of compiler options.
  '''
  result = ['-O3']

  # host architecture/CPU tuning and link-time optimisation, where supported
  probes = ('-march=native', '-mtune=native', '-flto')
  result.extend(option for option in probes if supports(option))

  if categorise(CC) == 'gcc':
    # allow GCC to perform more advanced interprocedural optimisations
    result.append('-fwhole-program')
    if platform.machine() in ('amd64', 'x86_64'):
      # GCC needs a personal invitation to use cmpxchg16b
      result.append('-mcx16')

  return result

def main(args: [str]) -> int:
  '''
  Generate, compile and run a Rumur checker.

  :param args: The full command line, with the program name in args[0]. The
    remaining arguments are passed through to `rumur`.
  :returns: An exit status. This is Rumur's own status if checker generation
    fails, 0 if the checker compiled and ran successfully, and -1 otherwise.
  '''

  # Find the Rumur binary, preferring one on PATH and falling back to one next
  # to this script
  rumur_bin = which('rumur')
  if rumur_bin is None:
    rumur_bin = which(os.path.join(os.path.dirname(__file__), 'rumur'))
  if rumur_bin is None:
    sys.stderr.write('rumur binary not found\n')
    return -1

  # if the user asked for help or version information, run Rumur directly
  for arg in args[1:]:
    if arg.startswith(('-h', '--h', '--vers')):
      os.execv(rumur_bin, [rumur_bin] + args[1:])

  if CC is None:
    sys.stderr.write('no C compiler found\n')
    return -1

  # NOTE: the status messages below are flushed explicitly. The child processes
  # write directly to our inherited stdout, so when stdout is a pipe or file
  # unflushed messages would appear out of order (typically all at the end).

  # Generate the checker
  print('Generating the checker...', flush=True)
  rumur_proc = sp.run([rumur_bin] + args[1:] + ['--output', '/dev/stdout'],
    stdin=sp.PIPE, stdout=sp.PIPE)
  if rumur_proc.returncode != 0:
    return rumur_proc.returncode
  checker_c = rumur_proc.stdout

  # Setup a temporary directory in which to build the checker. It is removed
  # when this block is left, however that happens.
  with tempfile.TemporaryDirectory() as tmp:

    # Compile the checker
    print('Compiling the checker...', flush=True)
    aout = os.path.join(tmp, 'a.out')
    argv = [CC, '-std=c11'] + optimisation_flags() + ['-o', aout, '-x', 'c',
      '-', '-lpthread']
    # XXX: these architectures do not have a double-word CAS, so need libatomic
    # support
    if platform.machine() in ('mips', 'mips64', 'ppc', 'ppc64', 's390x',
        'riscv', 'riscv32', 'riscv64'):
      argv.append('-latomic')
    cc_proc = sp.run(argv, input=checker_c)
    if cc_proc.returncode != 0:
      return -1

    # Run the checker
    print('Running the checker...', flush=True)
    checker_proc = sp.run([aout])
    if checker_proc.returncode != 0:
      return -1

  return 0

if __name__ == '__main__':
  # Run the wrapper. An interactive interrupt (Ctrl-C) exits with the
  # conventional status 130 rather than dumping a traceback.
  try:
    exit_status = main(sys.argv)
  except KeyboardInterrupt:
    exit_status = 130
  sys.exit(exit_status)
