#! /usr/bin/python3
from datetime import datetime, timezone
import sys
from getopt import gnu_getopt, GetoptError
import ldap3

# TCP port handed to ldap3.Server when connecting to the queried host.
port = 2135

# Search base handed to the ldap3 Reader for the job query.
base_dn = 'Mds-Vo-name=local,o=grid'

# Attributes requested for every nordugrid-job entry.
attributes = [
    'nordugrid-job-' + suffix
    for suffix in ('submissiontime', 'completiontime', 'status')
]

class StatsBin(object):
    """Tally of how many jobs were seen with each status."""

    def __init__(self):
        # Maps a job status to the number of jobs added with that status.
        self._counts = {}

    def add(self, job_status):
        """Count one more job having the status *job_status*."""
        self._counts[job_status] = self._counts.get(job_status, 0) + 1

    def statuses(self):
        """Return a view of all statuses that have been counted so far."""
        return self._counts.keys()

    def status_count(self, job_status):
        """Return the number of jobs counted for *job_status* (0 if none)."""
        try:
            return self._counts[job_status]
        except KeyError:
            return 0

    def report(self, intro):
        """Write a one-line summary to stdout, introduced by *intro*.

        The line lists the counted statuses in sorted order, or says
        'no jobs.' when nothing has been counted.
        """
        emit = sys.stdout.write
        emit(intro)
        emit(': ')
        if not self._counts:
            emit('no jobs.\n')
            return
        fragments = ['%d jobs %s' % (self._counts[name], name)
                     for name in sorted(self._counts)]
        emit(', '.join(fragments))
        emit('.\n')

def check(host, bin_durations, relevant_status, warn_limits, crit_limits):
    """Query the job entries on *host*, print statistics and exit.

    Every nordugrid-job entry found under ``base_dn`` is tallied by status.
    Jobs without a completion time go into a "not yet completed" tally;
    completed jobs are counted in every bin whose duration (in seconds)
    is larger than the time elapsed since their completion.

    Parameters:
      host             server name passed to ldap3.Server.
      bin_durations    durations in seconds, one statistics bin each.
      relevant_status  job statuses whose counts are summed and compared
                       against the limits (matched case-insensitively).
      warn_limits      per-duration minimum counts; falling below yields
                       exit status 1.
      crit_limits      per-duration minimum counts; falling below yields
                       exit status 2.

    The limits are paired with the durations by position.  The function
    writes a summary line followed by one report line per tally to stdout
    and never returns: it always ends in sys.exit() with status 0, 1 or 2.
    """
    # Job statuses are lower-cased below, so the statuses to look for must
    # be normalised the same way; a set keeps a status named twice from
    # being summed twice.
    relevant = set(status.lower() for status in relevant_status)

    inq_stats_bin = StatsBin()
    stats = dict((bin_duration, StatsBin()) for bin_duration in bin_durations)

    lserver = ldap3.Server(host, port)
    lconn = ldap3.Connection(lserver, auto_bind=True)
    try:
        job_def = ldap3.ObjectDef('nordugrid-job', lconn)
        query = '(objectClass=nordugrid-job)'
        reader = ldap3.Reader(lconn, job_def, base_dn, query,
                              attributes=attributes)
        reader.search()
        now = datetime.now(timezone.utc)
        for ent in reader:
            completion_time = ent['nordugrid-job-completiontime'].value
            status_value = ent['nordugrid-job-status'].value
            # An entry lacking a status is tallied as 'unknown' instead of
            # aborting the whole check on None.lower().
            job_status = status_value.lower() if status_value else 'unknown'
            if completion_time:
                age = (now - completion_time).total_seconds()
                # Iterate the bins themselves so that a duration listed
                # twice does not count the same job twice in one bin.
                for duration, stats_bin in stats.items():
                    if age < duration:
                        stats_bin.add(job_status)
            else:
                inq_stats_bin.add(job_status)
    finally:
        # Always release the server connection, also when the query fails.
        lconn.unbind()

    status_code = 0
    status_msg = []
    for d, w, c in zip(bin_durations, warn_limits, crit_limits):
        count = sum(stats[d].status_count(s) for s in relevant)
        if count < c:
            status_code = 2
            status_msg.append('%d jobs in %g s is too low' % (count, d))
        elif count < w:
            # Never downgrade an already critical result to a warning.
            status_code = max(status_code, 1)
            status_msg.append('%d jobs in %g s is low' % (count, d))
        elif len(status_msg) < 2 and (c or w):
            # Mention a passing bin only if it has a non-zero limit and
            # fewer than two messages have been collected so far.
            status_msg.append('%d jobs in %g s is ok' % (count, d))
    sys.stdout.write((', '.join(status_msg) or 'ok') + '\n')
    inq_stats_bin.report('Not yet completed')
    for d in bin_durations:
        stats[d].report('The last %6d s' % d)
    sys.exit(status_code)

def main():
    """Parse the command line and run the check.

    Options:
      -H HOST        server to query (mandatory).
      -T T1,T2,...   comma-separated bin durations in seconds.
      -s S1+S2+...   '+'-separated job statuses counted against the limits.
      -w W1,W2,...   comma-separated warning limits, one per duration.
      -c C1,C2,...   comma-separated critical limits, one per duration.

    Limit lists shorter than the list of durations are padded with 0.
    Any usage error (unknown option, malformed number, missing -H,
    positional arguments, too many limits) is reported on stderr and the
    process exits with status 64.
    """
    try:
        bin_durations = [300, 3600, 21600, 86400, 259200]
        warn_limits = [0, 1]
        warn_given = False
        crit_limits = []
        relevant_status = ['finished']
        host = None
        opts, args = gnu_getopt(sys.argv[1:], 'H:T:w:c:s:')
        for opt, arg in opts:
            try:
                if opt == '-H':
                    host = arg
                elif opt == '-T':
                    bin_durations = [float(t) for t in arg.split(',')]
                elif opt == '-s':
                    relevant_status = arg.split('+')
                elif opt == '-c':
                    crit_limits = [int(c) for c in arg.split(',')]
                elif opt == '-w':
                    warn_limits = [int(c) for c in arg.split(',')]
                    warn_given = True
                else:
                    # gnu_getopt only returns options from the spec above;
                    # an explicit raise also survives running under -O.
                    raise AssertionError('unhandled option %s' % opt)
            except ValueError:
                # A malformed number is a usage error, not a crash.
                raise GetoptError(
                    'Invalid argument for %s: %s' % (opt, arg)) from None
        if host is None:
            raise GetoptError('-H is mandatory.')
        if args != []:
            raise GetoptError('No positional arguments expected.')
        n = len(bin_durations)
        if not warn_given:
            # The built-in default has two entries; it must not make a
            # shorter -T list fail with "More warning limits ...".
            warn_limits = warn_limits[:n]
        if len(warn_limits) > n:
            raise GetoptError('More warning limits than durations.')
        if len(crit_limits) > n:
            raise GetoptError('More critical limits than durations.')
        # Pad with 0, which can never be undercut by a job count.
        warn_limits += [0] * (n - len(warn_limits))
        crit_limits += [0] * (n - len(crit_limits))
        check(host, bin_durations, relevant_status, warn_limits, crit_limits)
    except GetoptError as exn:
        sys.stderr.write('%s\n' % exn)
        sys.exit(64)

# Run the check only when executed as a script, so that importing this
# module does not parse sys.argv or contact a server.
if __name__ == '__main__':
    main()
