#!/usr/bin/env python
import os, re, sys

# Help text that usage() writes to stderr.  "{program}" is a str.format
# placeholder filled with the script's basename.  The trailing [1:] slice
# removes the newline that immediately follows the opening triple quote.
USAGE = """
Usage:

    {program} add NAME

        Add a new test with the given NAME in the test directory and then
        update the list of tests in Main.hs (warning: discards manual edits).

    {program} update

        Update the list of tests in Main.hs (warning: discards manual edits).

"""[1:]

# Module name substituted for "{library}" in TEST_TEMPLATE.
LIBRARY = "System.Directory"
# Directory that add() writes new tests into and that update() scans.
TEST_DIR = "tests"
# File extension shared by the test modules and the generated main module.
TEST_EXT = ".hs"
# Skeleton written by add() for a new test module.  "{name}" and "{library}"
# are str.format placeholders; the doubled braces "{{" / "}}" are format
# escapes that come out as single literal braces.  [1:] drops the newline
# after the opening triple quote.
TEST_TEMPLATE = """
{{-# LANGUAGE CPP #-}}
module {name} where
#include "util.inl"
import {library}

main :: TestEnv -> IO ()
main _t = do

"""[1:]
# Basename (without extension) of the file that update() regenerates.
MAIN_NAME = "Main"
# Skeleton of the generated main module.  "{imports}" and "{runs}" receive the
# concatenated MAIN_IMPORT_TEMPLATE / MAIN_RUN_TEMPLATE lines, each of which
# already ends in a newline.  [1:-1] strips the newline after the opening
# triple quote and the one before the closing triple quote.
MAIN_TEMPLATE = """
module Main (main) where
import qualified Util as T
{imports}
main :: IO ()
main = T.testMain $ \ _t -> do
{runs}
"""[1:-1]
# One line per test, formatted with name=<test module name>.
MAIN_IMPORT_TEMPLATE = "import qualified {name}\n"
MAIN_RUN_TEMPLATE = '  T.isolatedRun _t "{name}" {name}.main\n'
# Regular expression for module names that update() leaves out of the test
# list: exactly "Main", or any name containing "Util".
BLACKLIST = "^(Main|.*Util.*)$"

# Name this script was invoked as; used as a prefix in messages and in USAGE.
program = os.path.basename(sys.argv[0])

def usage():
    """Write the usage text to stderr and terminate with exit status 2."""
    message = USAGE.format(program=program)
    sys.stderr.write(message)
    sys.exit(2)

def add(name):
    """Create a new test module called NAME and regenerate the main module.

    The new file is TEST_DIR/NAME + TEST_EXT, filled in from TEST_TEMPLATE.
    After it is written, update() is called so the test list includes it,
    and a confirmation line is printed to stdout.

    The process exits with status 1 (after a message on stderr) when:
      * NAME is empty or does not start with an uppercase letter,
      * NAME matches BLACKLIST, or
      * the target file already exists.
    """
    def fail(reason, subject):
        # Shared error path: "<program>: <reason>: <subject>" then exit 1.
        sys.stderr.write("{0}: {1}: {2}\n".format(program, reason, subject))
        sys.exit(1)

    # Slice instead of indexing so that an empty NAME is reported as an
    # invalid name rather than raising IndexError.
    if not name[:1].isupper():
        fail("must start with a capital letter", name)

    # update() skips every module matching BLACKLIST, so such a file would be
    # created but never run as a test (and "Main" would be overwritten by the
    # update() call below).  Refuse it up front.
    if re.search(BLACKLIST, name):
        fail("name is reserved and would not be listed as a test", name)

    filename = os.path.join(TEST_DIR, name + TEST_EXT)
    if os.path.exists(filename):
        fail("test already exists", filename)

    contents = TEST_TEMPLATE.format(name=name, library=LIBRARY)
    # Binary mode plus an explicit encode keeps the output identical on
    # Python 2 and Python 3 and avoids platform newline translation.
    with open(filename, "wb") as handle:
        handle.write(contents.encode("utf8"))

    update()
    print("{0}: test added: {1}".format(program, filename))

def update():
    """Regenerate the main module from the test files found in TEST_DIR.

    A directory entry counts as a test when its extension is TEST_EXT, its
    base name starts with an uppercase letter, and the base name does not
    match BLACKLIST.  The tests are listed in sorted order.  The file
    TEST_DIR/MAIN_NAME + TEST_EXT is overwritten unconditionally.
    """
    split_entries = (os.path.splitext(entry) for entry in os.listdir(TEST_DIR))
    modules = sorted(
        stem
        for stem, suffix in split_entries
        if suffix == TEST_EXT
        # stem[:1] is empty (hence not upper) for an empty stem.
        and stem[:1].isupper()
        and not re.search(BLACKLIST, stem)
    )

    import_lines = [MAIN_IMPORT_TEMPLATE.format(name=module)
                    for module in modules]
    run_lines = [MAIN_RUN_TEMPLATE.format(name=module)
                 for module in modules]
    contents = MAIN_TEMPLATE.format(
        imports="".join(import_lines),
        runs="".join(run_lines),
    )

    main_path = os.path.join(TEST_DIR, MAIN_NAME + TEST_EXT)
    with open(main_path, "wb") as out:
        out.write(contents.encode("utf8"))

# Command-line dispatch.  usage() does not return (it exits with status 2),
# so each check below acts as a guard for the statements that follow it.
if len(sys.argv) < 2:
    usage()

command = sys.argv[1]
if command == "add":
    # Exactly one NAME argument is required; checking for "!= 3" also covers
    # a missing NAME, which would otherwise raise IndexError on sys.argv[2].
    if len(sys.argv) != 3:
        usage()
    add(sys.argv[2])
elif command == "update":
    if len(sys.argv) != 2:
        usage()
    update()
else:
    # Unknown command: show the help text instead of silently doing nothing.
    usage()
