#!/usr/bin/env python
#
# Copyright (C) 2012-2013 Fanout, Inc.
#
# This file is part of Pushpin.
#
# Pushpin is free software: you can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Pushpin is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import sys
import os
import ConfigParser

# Command-line options.  These are scanned by hand before the rest of the
# imports because the config file may name a "libdir" that has to be put on
# sys.path before the handler's own modules (httpinterface, validation, ...)
# can be imported below.
#   --config=PATH   config file to read (default /etc/pushpin/pushpin.conf)
#   --logfile=PATH  log to PATH instead of stdout
#   --verbose       enable DEBUG-level logging
config_file = "/etc/pushpin/pushpin.conf"
log_file = None
verbose = False
for arg in sys.argv:
	if arg.startswith("--config="):
		config_file = arg[9:]  # drop the 9-char "--config=" prefix
	elif arg.startswith("--logfile="):
		log_file = arg[10:]  # drop the 10-char "--logfile=" prefix
	elif arg.startswith("--verbose"):
		verbose = True

# note: ConfigParser.read() silently skips files it cannot open; a missing
# config file only surfaces later, when config.get() finds no section
config = ConfigParser.ConfigParser()
config.read([config_file])

# optional [global] libdir; its "handler" subdirectory is searched first for
# imports (presumably where the handler's Python modules are installed)
libdir = None
if config.has_option("global", "libdir"):
	libdir = config.get("global", "libdir")

if libdir:
	sys.path.insert(0, os.path.join(libdir, "handler"))

import time
import threading
import json
import logging
from logging.handlers import WatchedFileHandler
from base64 import b64decode
from setproctitle import setproctitle
import zmq
import tnetstring
import httpinterface
from validation import validate_publish, validate_http_publish, ValidationError
from conversion import ensure_utf8, convert_json_transport
from statusreasons import get_reason

setproctitle("pushpin-handler")

# reopen stdout file descriptor with write mode
# and 0 as the buffer size (unbuffered)
# (must happen before the StreamHandler below captures sys.stdout)
sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0)

# Log to --logfile when given (WatchedFileHandler reopens the file if it is
# rotated away), otherwise to the unbuffered stdout set up above.
logger = logging.getLogger('handler')
if log_file:
	logger_handler = WatchedFileHandler(log_file)
else:
	logger_handler = logging.StreamHandler(stream=sys.stdout)
if verbose:
	logger.setLevel(logging.DEBUG)
	logger_handler.setLevel(logging.DEBUG)
else:
	logger.setLevel(logging.INFO)
	logger_handler.setLevel(logging.INFO)
formatter = logging.Formatter(fmt='%(levelname)s %(asctime)s.%(msecs)03d %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
logger_handler.setFormatter(formatter)
logger.addHandler(logger_handler)

# Sent as the "from" field of outgoing messages and used as the identity of
# the session DEALER socket.
instance_id = "pushpin-handler_%d" % os.getpid()

# Comma-separated ZeroMQ endpoints: session messages are received from the
# "in stream" specs, replies are published to the "out" specs.
m2a_in_stream_specs = config.get("handler", "m2a_in_stream_specs").split(",")
m2a_out_specs = config.get("handler", "m2a_out_specs").split(",")

# When "true", inspect replies include a "sharing-key" of "<method>|<uri>".
if config.has_option("handler", "share_all") and config.get("handler", "share_all") == "true":
	share_all = True
else:
	share_all = False

# Optional endpoint the stats worker binds a PUB socket on; without it, stats
# messages are received and discarded.
if config.has_option("handler", "stats_spec"):
	stats_spec = config.get("handler", "stats_spec")
else:
	stats_spec = None

# One ZeroMQ context shared by every worker thread (inproc:// endpoints only
# connect within the same context).
ctx = zmq.Context()

class Hold(object):
	"""A client request parked on a channel until it can be answered.

	rid is the (sender, id) pair used to address the client; request is the
	original request data; mode is "response" or "stream"; response is the
	response stored with the hold (sent as-is when a response hold times
	out); auto_cross_origin and jsonp_callback select how replies are
	decorated.  expire_time and last_keepalive start as None and are later
	set to integer Unix timestamps by the workers.
	"""

	def __init__(self, rid, request, mode, response, auto_cross_origin, jsonp_callback):
		# who is waiting, and what they asked for
		self.rid, self.request, self.mode = rid, request, mode
		# what to send them, and how to wrap it
		self.response = response
		self.auto_cross_origin, self.jsonp_callback = auto_cross_origin, jsonp_callback
		# bookkeeping timestamps, filled in by the workers
		self.expire_time = self.last_keepalive = None

class Subscription(object):
	"""Records that this handler has holds on a (mode, channel) pair.

	expire_time, once set, is the Unix time after which the subscription is
	dropped; last_keepalive is when it was last announced on the stats
	socket.  Both start out as None.
	"""

	def __init__(self, mode, channel):
		self.mode, self.channel = mode, channel
		self.expire_time = self.last_keepalive = None

# State shared between the worker threads; accesses are made while holding
# `lock`.
lock = threading.Lock()
# channel -> {rid: Hold} for response-mode holds
response_channels = {}
# channel -> id of the last published item that carried an id; compared
# against a hold's prev-id
response_lastids = {}
# channel -> {rid: Hold} for stream-mode holds
stream_channels = {}
# (mode, channel) -> Subscription
subs = {}

def header_get(headers, name):
	"""Return the value of header *name*, compared case-insensitively.

	*headers* is either a list of [name, value] pairs or a dict.  The first
	match wins; None is returned if the header is absent.
	"""
	wanted = name.lower()
	if isinstance(headers, list):
		pairs = headers
	else:
		pairs = headers.items()
	for pair in pairs:
		if pair[0].lower() == wanted:
			return pair[1]
	return None

def header_remove(headers, name):
	"""Remove every header called *name* (case-insensitive) from *headers*.

	*headers* is either a list of [name, value] pairs or a dict and is
	modified in place.  All matches are removed, not just the first, so that
	header_set() never leaves a differently-cased duplicate behind.
	"""
	lname = name.lower()
	if isinstance(headers, list):
		# slice-assign so callers holding this list see the change
		headers[:] = [i for i in headers if i[0].lower() != lname]
	else:
		# snapshot the keys: the dict is mutated while scanning
		for k in list(headers.keys()):
			if k.lower() == lname:
				del headers[k]

def header_set(headers, name, value):
	"""Set header *name* to *value*, replacing any case-insensitive match.

	*headers* may be a dict or, like header_get()/header_remove() accept, a
	list of [name, value] pairs; for a list the new pair is appended.
	"""
	header_remove(headers, name)
	if isinstance(headers, list):
		# lists are indexed by position, so `headers[name] = value` would fail
		headers.append([name, value])
	else:
		headers[name] = value

def header_names_contains(header_names, name):
	"""Return True if *name* equals any entry of *header_names*, ignoring case."""
	target = name.lower()
	return any(candidate.lower() == target for candidate in header_names)

# Raw HTTP/1.1 response templates filled by http_response() and
# http_response_nolen() (used by reply_http_old).  Mapping keys: code,
# status, headers (already CRLF-joined "Name: value" lines), body.
HTTP_FORMAT = "HTTP/1.1 %(code)s %(status)s\r\n%(headers)s\r\n\r\n%(body)s"
HTTP_FORMAT_NOHEADERS = "HTTP/1.1 %(code)s %(status)s\r\n\r\n%(body)s"

def http_response(body, code, status, headers):
	"""Render a raw HTTP/1.1 response string with a Content-Length header.

	Note: *headers* (a dict) is modified in place -- its Content-Length is
	replaced with len(body).
	"""
	header_set(headers, "Content-Length", str(len(body)))
	header_lines = ["%s: %s" % (k, v) for k, v in headers.items()]
	return HTTP_FORMAT % {
		"code": code,
		"status": status,
		"headers": "\r\n".join(header_lines),
		"body": body,
	}

def http_response_nolen(body, code, status, headers):
	"""Render a raw HTTP/1.1 response string without a Content-Length.

	Any Content-Length in *headers* (a dict) is removed in place.  If no
	headers remain, the header section is omitted entirely.
	"""
	header_remove(headers, "Content-Length")
	payload = {"code": code, "status": status, "body": body}
	if not headers:
		return HTTP_FORMAT_NOHEADERS % payload
	payload["headers"] = "\r\n".join(["%s: %s" % (k, v) for k, v in headers.items()])
	return HTTP_FORMAT % payload

def reply_http_old(sock, rid, code, status, headers, body, nolen=False):
	"""Send a complete HTTP response using the legacy raw-HTTP framing.

	The message is "<sender> <len(id)>:<id>, <raw HTTP response>", where the
	raw response comes from http_response() (with Content-Length) or, when
	*nolen* is true, http_response_nolen() (without it).  Unicode status,
	header names/values and body are UTF-8 encoded first; the caller's
	*headers* dict is copied, not modified.
	"""
	def utf8(value):
		# byte strings (and anything else non-unicode) pass through untouched
		if isinstance(value, unicode):
			return value.encode("utf-8")
		return value

	prefix = "%s %d:%s," % (rid[0], len(rid[1]), rid[1])

	status = utf8(status)
	encoded_headers = dict()
	for hname, hvalue in headers.items():
		encoded_headers[utf8(hname)] = utf8(hvalue)
	body = utf8(body)

	if nolen:
		rendered = http_response_nolen(body, code, status, encoded_headers)
	else:
		rendered = http_response(body, code, status, encoded_headers)

	m_raw = prefix + " " + rendered
	logger.debug("OUT publish: %s" % m_raw)
	sock.send(m_raw)

def reply_http_chunk_old(sock, rid, content):
	"""Send a raw piece of response body using the legacy reply framing.

	The message is "<sender> <len(id)>:<id>, <content>"; *content* is sent
	exactly as given, with no headers or length information added.
	"""
	sender, conn_id = rid[0], rid[1]
	prefix = "%s %d:%s," % (sender, len(conn_id), conn_id)
	sock.send(" ".join((prefix, content)))

def reply_http(sock, rid, code, reason, headers, body, nolen=False):
	"""Send an HTTP response to a client as a "T"-prefixed tnetstring message.

	The message is "<rid[0]> T<tnetstring>" with fields from (instance_id),
	id (rid[1]), code, reason, headers (list of [name, value]), body (only if
	non-empty) and, when *nolen* is true, more=True to mark that further
	body chunks follow.  Content-Length is set to len(body), or removed when
	*nolen* is true.  Unicode reason, header names/values and body are UTF-8
	encoded; the caller's *headers* dict is copied, not modified.
	"""
	def utf8(value):
		# byte strings (and anything else non-unicode) pass through untouched
		if isinstance(value, unicode):
			return value.encode("utf-8")
		return value

	reason = utf8(reason)
	encoded_headers = dict()
	for hname, hvalue in headers.items():
		encoded_headers[utf8(hname)] = utf8(hvalue)
	body = utf8(body)

	if not nolen:
		header_set(encoded_headers, "Content-Length", str(len(body)))
	else:
		header_remove(encoded_headers, "Content-Length")

	out = {}
	out["from"] = instance_id
	out["id"] = rid[1]
	out["code"] = code
	out["reason"] = reason
	out["headers"] = [[hname, hvalue] for hname, hvalue in encoded_headers.items()]
	if body:
		out["body"] = body
	if nolen:
		out["more"] = True

	m_raw = rid[0] + " T" + tnetstring.dumps(out)
	logger.debug("OUT publish: %s" % m_raw)
	sock.send(m_raw)

def reply_http_chunk(sock, rid, content):
	"""Send one more body chunk of a streamed response to client *rid*.

	The message is "<rid[0]> T<tnetstring>" with fields from, id, body and
	more=True (the stream stays open).
	"""
	out = {}
	out["from"] = instance_id
	out["id"] = rid[1]
	out["body"] = content
	out["more"] = True

	m_raw = "%s T%s" % (rid[0], tnetstring.dumps(out))
	logger.debug("OUT publish: %s" % m_raw)
	sock.send(m_raw)

# Response headers that apply_cors_headers() never lists in
# Access-Control-Expose-Headers (names are compared case-insensitively).
simple_headers = set([
	"Cache-control",
	"Content-Language",
	"Content-Length",
	"Content-Type",
	"Expires",
	"Last-Modified",
	"Pragma",
])

def apply_cors_headers(request_headers, response_headers):
	"""Fill in CORS headers that are missing from *response_headers*.

	Headers already present in *response_headers* are left alone; only
	missing ones are added, so *response_headers* is modified in place and
	nothing is returned.  Values are derived from the request where possible:
	the requested method/headers are echoed back, every non-simple,
	non-Access-Control response header is exposed, credentials are allowed,
	and the request Origin (or "*") is allowed.
	"""
	def missing(hname):
		return not header_get(response_headers, hname)

	if missing("Access-Control-Allow-Methods"):
		methods = header_get(request_headers, "Access-Control-Request-Method")
		header_set(response_headers, "Access-Control-Allow-Methods",
			methods or "OPTIONS, HEAD, GET, POST, PUT, DELETE")

	if missing("Access-Control-Allow-Headers"):
		requested = header_get(request_headers, "Access-Control-Request-Headers") or ""
		allowed = [part.strip() for part in requested.split(",") if part.strip()]
		if allowed:
			header_set(response_headers, "Access-Control-Allow-Headers", ", ".join(allowed))

	if missing("Access-Control-Expose-Headers"):
		exposed = []
		for hname in response_headers.keys():
			if header_names_contains(simple_headers, hname):
				continue
			if hname.lower().startswith("access-control-"):
				continue
			if header_names_contains(exposed, hname):
				continue
			exposed.append(hname)
		if exposed:
			header_set(response_headers, "Access-Control-Expose-Headers", ", ".join(exposed))

	if missing("Access-Control-Allow-Credentials"):
		header_set(response_headers, "Access-Control-Allow-Credentials", "true")

	if missing("Access-Control-Allow-Origin"):
		origin = header_get(request_headers, "Origin") or "*"
		header_set(response_headers, "Access-Control-Allow-Origin", origin)

def inspect_worker():
	"""Answer the proxy's inspect requests, always telling it to proxy.

	Each reply carries the request id and "no-proxy": False; when share_all
	is enabled it also carries a "sharing-key" of "<method>|<uri>".  Runs
	forever on a REP socket connected to proxy_inspect_spec.
	"""
	sock = ctx.socket(zmq.REP)
	sock.connect(config.get("handler", "proxy_inspect_spec"))

	while True:
		request = tnetstring.loads(sock.recv())
		logger.debug("IN inspect: %s" % request)

		# read every field up front; a missing one raises KeyError here
		req_id, method, uri = request["id"], request["method"], request["uri"]

		reply = dict()
		reply["id"] = req_id
		reply["no-proxy"] = False  # reply saying to always proxy
		if share_all:
			reply["sharing-key"] = method + '|' + uri

		logger.debug("OUT inspect: %s" % reply)
		sock.send(tnetstring.dumps(reply))

	sock.close()

def accept_worker():
	"""Take over requests handed off by the proxy and park them as holds.

	Each message from proxy_accept_in_spec carries the backend response
	(whose JSON body holds the "hold" instructions), the original
	request-data, and one or more requests to take over.  For every request
	a Hold is created:

	* "response" mode: the hold is added to response_channels and a
	  keep-alive is published to acknowledge the take-over.  If the
	  instruction's prev-id does not match the channel's last published id,
	  this request and all remaining ones are sent back to the proxy for
	  retry instead, and the rest of the batch is skipped.
	* "stream" mode: the initial response (code/reason/headers/body) is sent
	  at once and the hold is added to stream_channels.

	A newly created (mode, channel) subscription is announced on the stats
	socket with a 60 second ttl.  Messages whose instructions cannot be
	parsed are logged and skipped.
	"""
	sock = ctx.socket(zmq.PULL)
	sock.connect(config.get("handler", "proxy_accept_in_spec"))

	out_sock = ctx.socket(zmq.PUB)
	for spec in m2a_out_specs:
		out_sock.connect(spec)

	retry_sock = ctx.socket(zmq.PUSH)
	retry_sock.connect(config.get("handler", "proxy_retry_out_spec"))

	stats_sock = ctx.socket(zmq.PUSH)
	stats_sock.connect('inproc://stats_in')

	while True:
		m_raw = sock.recv()
		m = tnetstring.loads(m_raw)
		logger.debug("IN accept: %s" % m)

		try:
			instruct = json.loads(m["response"]["body"])
			hold = instruct["hold"]
			mode = hold.get("mode")
			if mode is None:
				mode = "response"
			if mode != "response" and mode != "stream":
				raise ValueError("bad mode")
			# only the first listed channel is used
			channel = hold["channels"][0]["name"]
			prev_id = hold["channels"][0].get("prev-id")
			if "channel-prefix" in m:
				channel = m["channel-prefix"] + channel
			if "timeout" in hold:
				timeout = int(hold["timeout"])
			else:
				timeout = 55
			response = instruct.get("response")
			if response is None:
				response = dict()
				response["body"] = ""
			# headers may be given as [name, value] pairs; store them as a dict
			if "headers" in response and isinstance(response["headers"], list):
				d = dict()
				for i in response["headers"]:
					d[i[0]] = i[1]
				response["headers"] = d
			# normalize the body to a byte string
			if "body-bin" in response:
				response["body"] = b64decode(response["body-bin"])
				del response["body-bin"]
			elif "body" in response:
				response["body"] = response["body"].encode("utf-8")
			else:
				response["body"] = ""
		except Exception as e:
			# malformed instructions are dropped, but record why
			logger.debug("failed to parse accept instructions: %r" % (e,))
			continue

		reqs = m["requests"]
		logger.debug("accepting %d requests" % len(reqs))

		for n, req in enumerate(reqs):
			rid = (req["rid"]["sender"], req["rid"]["id"])

			h = Hold(rid, m["request-data"], mode, response, req.get("auto-cross-origin"), req.get("jsonp-callback"))
			now = int(time.time())
			h.last_keepalive = now

			notify_sub = False
			if mode == "response":
				logger.debug("adding response hold on %s" % channel)
				h.expire_time = now + timeout
				lock.acquire()
				if prev_id is not None:
					last_id = response_lastids.get(channel)
					if last_id is not None and last_id != prev_id:
						del response_lastids[channel]
						lock.release()
						# note: we don't need to do a handoff here because we didn't ack to take over yet
						logger.debug("lastid inconsistency (got=%s, expected=%s), retrying" % (prev_id, last_id))
						# Retry this request and every one after it.  Earlier
						# requests in the batch (if any) are already held and
						# acked, so they must not be retried; the remaining ones
						# must not also be held here, hence the break below.
						r = dict()
						r["requests"] = reqs[n:]
						r["request-data"] = m["request-data"]
						if "inspect" in m:
							r["inspect"] = m["inspect"]
						logger.debug("OUT retry: %s" % r)
						r_raw = tnetstring.dumps(r)
						retry_sock.send(r_raw)
						break
				hchannel = response_channels.get(channel)
				if not hchannel:
					hchannel = dict()
					response_channels[channel] = hchannel
				hchannel[rid] = h
				sub_key = (mode, channel)
				sub = subs.get(sub_key)
				if not sub:
					sub = Subscription(mode, channel)
					sub.last_keepalive = now
					subs[sub_key] = sub
					notify_sub = True
				# the channel has a hold again, so cancel any pending expiry
				sub.expire_time = None
				lock.release()

				# ack
				out = dict()
				out['from'] = instance_id
				out['id'] = rid[1]
				out['type'] = 'keep-alive'
				m_raw = rid[0] + ' T' + tnetstring.dumps(out)
				logger.debug('OUT publish (ack response accept): %s' % m_raw)
				out_sock.send(m_raw)
			else: # stream
				# initial reply
				if "code" in response:
					rcode = response["code"]
				else:
					rcode = 200

				if "reason" in response:
					rreason = response["reason"]
				else:
					rreason = get_reason(rcode)

				if "headers" in response:
					rheaders = response["headers"]
				else:
					rheaders = dict()

				# nolen=True: the stream stays open for published content
				reply_http(out_sock, rid, rcode, rreason, rheaders, response.get("body"), True)

				logger.debug("adding stream hold on %s" % channel)

				# bind channel
				lock.acquire()
				hchannel = stream_channels.get(channel)
				if not hchannel:
					hchannel = dict()
					stream_channels[channel] = hchannel
				hchannel[rid] = h
				sub_key = (mode, channel)
				sub = subs.get(sub_key)
				if not sub:
					sub = Subscription(mode, channel)
					sub.last_keepalive = now
					subs[sub_key] = sub
					notify_sub = True
				sub.expire_time = None
				lock.release()

			if notify_sub:
				out = dict()
				out['from'] = instance_id
				out['mode'] = ensure_utf8(mode)
				out['channel'] = ensure_utf8(channel)
				out['ttl'] = 60
				stats_sock.send(tnetstring.dumps(out))

	sock.close()

def push_in_zmq_worker():
	"""Receive publish messages over ZeroMQ and forward the valid ones.

	Binds a PULL socket on push_in_spec.  Each message must be a tnetstring
	that passes validate_publish(); valid messages are re-encoded and pushed
	to inproc://push_in, and invalid ones are logged and dropped.
	"""
	in_sock = ctx.socket(zmq.PULL)
	in_sock.bind(config.get("handler", "push_in_spec"))

	out_sock = ctx.socket(zmq.PUSH)
	out_sock.connect("inproc://push_in")

	while True:
		m_raw = in_sock.recv()
		try:
			try:
				m = tnetstring.loads(m_raw)
			except Exception:
				raise ValidationError("bad format (not a tnetstring)")

			m = validate_publish(m)

		except ValidationError as e:
			logger.debug("warning: %s, dropping" % e)
			# Really drop it.  Falling through would forward the invalid
			# message, re-send the previous message (m still bound from the
			# last iteration), or raise NameError if no message was valid yet.
			continue

		out_sock.send(tnetstring.dumps(m))

	out_sock.linger = 0
	out_sock.close()

def push_in_http_handler(context, m):
	"""Validate an HTTP publish request and forward each of its items.

	context["out_sock"] is the socket items are pushed to (inproc://push_in).
	Each item is reduced to its channel, id and prev-id (UTF-8 encoded, only
	when present) plus any "http-response"/"http-stream" transport, and sent
	as a tnetstring.

	Returns None on success, or the validation error message as a string
	(in which case nothing is forwarded).
	"""
	out_sock = context["out_sock"]

	try:
		m = validate_http_publish(m)
	except ValidationError as e:
		return e.message

	for item in m["items"]:
		out = dict()
		for key in ("channel", "id", "prev-id"):
			value = item.get(key)
			if value is not None:
				out[key] = ensure_utf8(value)
		for transport in ("http-response", "http-stream"):
			if transport in item:
				out[transport] = convert_json_transport(item[transport])
		out_sock.send(tnetstring.dumps(out))

def push_in_http_worker():
	"""Serve the HTTP publish endpoint, forwarding items to inproc://push_in.

	Blocks inside httpinterface.run() with push_in_http_handler as the
	request handler; the socket is closed once run() returns (presumably
	after httpinterface.stop() -- confirm).
	"""
	out_sock = ctx.socket(zmq.PUSH)
	out_sock.connect("inproc://push_in")

	addr = config.get("handler", "push_in_http_addr")
	port = int(config.get("handler", "push_in_http_port"))
	httpinterface.run(addr, port, push_in_http_handler, {"out_sock": out_sock})

	out_sock.linger = 0
	out_sock.close()

def push_in_worker(c):
	"""Relay published items to the clients holding on their channel.

	Binds inproc://push_in (fed by the ZeroMQ and HTTP publish front-ends)
	and notifies *c* once bound, then handles items forever:

	* "http-response": every response-mode hold on the channel is removed
	  and answered (JSONP-wrapped, or with CORS headers applied, as the hold
	  requested); the channel's response subscription is flagged to expire
	  in 60 seconds; the item id, if any, becomes the channel's last id for
	  later prev-id checks.
	* "http-stream": the item's content is sent as a body chunk to every
	  stream-mode hold on the channel.

	Items without a channel are logged and dropped.

	Args:
		c: threading.Condition that the starting thread waits on until the
			inproc bind is done.
	"""
	in_sock = ctx.socket(zmq.PULL)
	in_sock.bind("inproc://push_in")
	c.acquire()
	c.notify()
	c.release()

	out_sock = ctx.socket(zmq.PUB)
	for spec in m2a_out_specs:
		out_sock.connect(spec)

	while True:
		m_raw = in_sock.recv()
		m = tnetstring.loads(m_raw)
		logger.debug("IN publish: %s" % m)
		channel = m.get("channel")
		if channel is None:
			# The HTTP front-end only sets "channel" when the item has one;
			# indexing m["channel"] would raise KeyError and kill this thread,
			# silently stopping all delivery.
			logger.debug("warning: publish item has no channel, dropping")
			continue

		if "http-response" in m:
			with lock:
				hchannel = response_channels.get(channel)
				if hchannel:
					holds = list(hchannel.values())
					del response_channels[channel]
				else:
					holds = list()
				sub_key = ('response', channel)
				sub = subs.get(sub_key)
				if sub and sub.expire_time is None:
					# flag for deletion soon
					sub.expire_time = int(time.time()) + 60
				item_id = m.get("id")
				if item_id is not None:
					response_lastids[channel] = item_id
			logger.debug("relaying to %d subscribers" % len(holds))
			# (not named http_response: that would shadow the module function)
			presponse = m["http-response"]

			pcode = presponse.get("code", 200)

			if "reason" in presponse:
				preason = presponse["reason"]
			else:
				preason = get_reason(pcode)

			if "headers" in presponse:
				pheaders = presponse["headers"]
				if isinstance(pheaders, list):
					d = dict()
					for i in pheaders:
						d[i[0]] = i[1]
					pheaders = d
			else:
				pheaders = dict()

			pbody = presponse.get("body", "")

			for n, h in enumerate(holds):
				headers = dict()
				if h.jsonp_callback:
					# wrap the whole response as a JSON object passed to the callback
					result = dict()
					result["code"] = pcode
					result["reason"] = preason
					result["headers"] = dict()
					if pheaders:
						for k, v in pheaders.iteritems():
							result["headers"][k] = v
					header_set(result["headers"], "Content-Length", str(len(pbody)))
					result["body"] = pbody

					body = h.jsonp_callback + "(" + json.dumps(result) + ");\n"
					header_set(headers, "Content-Type", "application/javascript")
					header_set(headers, "Content-Length", str(len(body)))
					reply_http(out_sock, h.rid, 200, "OK", headers, body)
				else:
					# per-hold copy: CORS headers depend on each hold's request
					if pheaders:
						for k, v in pheaders.iteritems():
							headers[k] = v

					if h.auto_cross_origin:
						apply_cors_headers(h.request["headers"], headers)

					reply_http(out_sock, h.rid, pcode, preason, headers, pbody)

				# brief pause every 10 replies (presumably to avoid flooding the
				# outgoing socket on large fan-outs -- confirm)
				if n % 10 == 0:
					time.sleep(0.0005)

		if "http-stream" in m:
			# TODO: support close action and report immediate unsubscribe via stats
			with lock:
				hchannel = stream_channels.get(channel)
				if hchannel:
					holds = list(hchannel.values())
				else:
					holds = list()
			logger.debug("relaying to %d subscribers" % len(holds))
			# the same content goes to every hold, so look it up once
			content = m["http-stream"].get("content")
			if content:
				for h in holds:
					reply_http_chunk(out_sock, h.rid, content)

	in_sock.close()

def session_worker():
	"""Handle session messages about held requests.

	Messages arrive on a DEALER socket (identity instance_id) connected to
	every m2a_in_stream_specs endpoint.  For each message:

	* type "error" or "cancel": the (from, id) request is removed from every
	  response and stream channel.  A response channel left empty is deleted
	  and its subscription flagged to expire in 60 s; a stream channel left
	  empty is deleted, its subscription dropped immediately, and announced
	  as unavailable on the stats socket.
	* any other type: if the request is not held on any channel, a "cancel"
	  is published back to the sender.
	* messages without a type are ignored.
	"""
	in_sock = ctx.socket(zmq.DEALER)
	in_sock.identity = instance_id
	for spec in m2a_in_stream_specs:
		in_sock.connect(spec)

	out_sock = ctx.socket(zmq.PUB)
	for spec in m2a_out_specs:
		out_sock.connect(spec)

	stats_sock = ctx.socket(zmq.PUSH)
	stats_sock.connect('inproc://stats_in')

	while True:
		m_list = in_sock.recv_multipart()
		# frame 1 is the payload; its first byte is skipped before decoding
		# (presumably a "T" format marker like the one prepended on output --
		# confirm)
		m = tnetstring.loads(m_list[1][1:])
		logger.debug('IN session: %s' % m)
		mtype = m.get('type')
		notify_unsubs = set()
		if mtype is not None and (mtype == 'error' or mtype == 'cancel'):
			rid = (m['from'], m['id'])
			logger.debug('cleaning up subscriber %s' % repr(rid))
			now = int(time.time())
			lock.acquire()
			# drop the request from every response channel, then delete any
			# channel that is now empty (deletion is deferred to a second pass
			# so the dict is not resized while being iterated)
			channels = set()
			for channel, hchannels in response_channels.iteritems():
				channels.add(channel)
				if rid in hchannels:
					del hchannels[rid]
			for channel in channels:
				if channel in response_channels and len(response_channels[channel]) == 0:
					del response_channels[channel]
					sub_key = ('response', channel)
					sub = subs.get(sub_key)
					if sub and sub.expire_time is None:
						# flag for deletion soon
						sub.expire_time = now + 60
			# same for stream channels, except the subscription is removed
			# right away and reported as unavailable below
			channels = set()
			for channel, hchannels in stream_channels.iteritems():
				channels.add(channel)
				if rid in hchannels:
					del hchannels[rid]
			for channel in channels:
				if channel in stream_channels and len(stream_channels[channel]) == 0:
					del stream_channels[channel]
					sub_key = ('stream', channel)
					sub = subs.get(sub_key)
					if sub:
						del subs[sub_key]
						notify_unsubs.add(sub_key)
			lock.release()
		elif mtype is not None:
			# is this a known session?
			rid = (m['from'], m['id'])

			lock.acquire()
			found = False
			for hchannels in response_channels.itervalues():
				if rid in hchannels:
					found = True
					break
			if not found:
				for hchannels in stream_channels.itervalues():
					if rid in hchannels:
						found = True
						break
			lock.release()

			if not found:
				# no such session, send cancel
				out = dict()
				out['from'] = instance_id
				out['id'] = m['id']
				out['type'] = 'cancel'
				m_raw = m['from'] + ' T' + tnetstring.dumps(out)
				logger.debug('OUT publish: %s' % m_raw)
				out_sock.send(m_raw)

		# stats are sent outside the lock
		for sub_key in notify_unsubs:
			out = dict()
			out['from'] = instance_id
			out['mode'] = ensure_utf8(sub_key[0])
			out['channel'] = ensure_utf8(sub_key[1])
			out['unavailable'] = True
			stats_sock.send(tnetstring.dumps(out))

def timeout_worker():
	"""Once per second: expire holds, send keep-alives, and refresh stats.

	Each pass:
	1. Response-mode holds whose expire_time has passed are removed and
	   answered with the response stored on the hold (JSONP-wrapped, or with
	   CORS headers applied, as the hold requested).  A channel left empty is
	   deleted and its response subscription flagged to expire in 60 s.
	2. Every response or stream hold whose last keep-alive was more than
	   30 s ago gets a "keep-alive" message.
	3. Subscriptions whose expire_time has passed are deleted and announced
	   as unavailable on the stats socket; the others are re-announced with
	   a 60 s ttl if their last announcement was more than 30 s ago.
	"""
	out_sock = ctx.socket(zmq.PUB)
	for spec in m2a_out_specs:
		out_sock.connect(spec)

	stats_sock = ctx.socket(zmq.PUSH)
	stats_sock.connect('inproc://stats_in')

	while True:
		# phase 1: collect and unregister expired response holds (under lock)
		now = int(time.time())
		lock.acquire()
		holds = list()
		channels = set()
		for channel, hchannels in response_channels.iteritems():
			channels.add(channel)
			channel_holds = list()
			for h in hchannels.values():
				if h.expire_time and now >= h.expire_time:
					channel_holds.append(h)
			for h in channel_holds:
				del hchannels[h.rid]
			holds.extend(channel_holds)
		# empty channels are deleted in a second pass so response_channels is
		# not resized while being iterated
		for channel in channels:
			if channel in response_channels and len(response_channels[channel]) == 0:
				del response_channels[channel]
				sub_key = ('response', channel)
				sub = subs.get(sub_key)
				if sub and sub.expire_time is None:
					# flag for deletion soon
					sub.expire_time = now + 60
		lock.release()

		# reply to the expired holds outside the lock
		if len(holds) > 0:
			logger.debug("timing out %d subscribers" % len(holds))

			for h in holds:
				if "code" in h.response:
					pcode = h.response["code"]
				else:
					pcode = 200

				if "reason" in h.response:
					preason = h.response["reason"]
				else:
					preason = get_reason(pcode)

				if "headers" in h.response:
					pheaders = h.response["headers"]
				else:
					pheaders = dict()

				if "body" in h.response:
					pbody = h.response["body"]
				else:
					pbody = ""

				headers = dict()
				if h.jsonp_callback:
					# wrap the whole response as a JSON object passed to the callback
					result = dict()
					result["code"] = pcode
					result["reason"] = preason
					result["headers"] = dict()
					if pheaders:
						for k, v in pheaders.iteritems():
							result["headers"][k] = v
					header_set(result["headers"], "Content-Length", str(len(pbody)))
					result["body"] = pbody

					body = h.jsonp_callback + "(" + json.dumps(result) + ");\n"
					header_set(headers, "Content-Type", "application/javascript")
					header_set(headers, "Content-Length", str(len(body)))
					reply_http(out_sock, h.rid, 200, "OK", headers, body)
				else:
					# copy so CORS headers don't leak into the shared h.response
					if pheaders:
						for k, v in pheaders.iteritems():
							headers[k] = v

					if h.auto_cross_origin:
						apply_cors_headers(h.request["headers"], headers)

					reply_http(out_sock, h.rid, pcode, preason, headers, pbody)

		# phase 2: pick holds due for a keep-alive; rids dedupes so each
		# client gets at most one per pass
		now = int(time.time())
		rids = set()
		lock.acquire()
		for channel, hchannels in response_channels.iteritems():
			for h in hchannels.values():
				if h.last_keepalive is None or h.last_keepalive + 30 < now:
					if h.rid not in rids:
						h.last_keepalive = now
						rids.add(h.rid)
		for channel, hchannels in stream_channels.iteritems():
			for h in hchannels.values():
				if h.last_keepalive is None or h.last_keepalive + 30 < now:
					if h.rid not in rids:
						h.last_keepalive = now
						rids.add(h.rid)
		lock.release()

		if len(rids) > 0:
			logger.debug("keep-aliving %d subscribers" % len(rids))
			for rid in rids:
				out = dict()
				out['from'] = instance_id
				out['id'] = rid[1]
				out['type'] = 'keep-alive'
				m_raw = rid[0] + ' T' + tnetstring.dumps(out)
				logger.debug('OUT publish: %s' % m_raw)
				out_sock.send(m_raw)

		# phase 3: expire or refresh subscriptions
		notify_subs = set()
		notify_unsubs = set()
		lock.acquire()
		for sub_key, sub in subs.iteritems():
			if sub.expire_time is not None and now >= sub.expire_time:
				notify_unsubs.add(sub_key)
			elif sub.last_keepalive is None or sub.last_keepalive + 30 < now:
				sub.last_keepalive = now
				notify_subs.add(sub_key)
		for sub_key in notify_unsubs:
			del subs[sub_key]
		lock.release()

		for sub_key in notify_unsubs:
			out = dict()
			out['from'] = instance_id
			out['mode'] = ensure_utf8(sub_key[0])
			out['channel'] = ensure_utf8(sub_key[1])
			out['unavailable'] = True
			stats_sock.send(tnetstring.dumps(out))

		if len(notify_subs) > 0:
			logger.debug('keep-aliving %d subscriptions' % len(notify_subs))
			for sub_key in notify_subs:
				out = dict()
				out['from'] = instance_id
				out['mode'] = ensure_utf8(sub_key[0])
				out['channel'] = ensure_utf8(sub_key[1])
				out['ttl'] = 60
				stats_sock.send(tnetstring.dumps(out))

		time.sleep(1)

def stats_worker(c):
	"""Relay subscription stats from inproc://stats_in to stats_spec.

	Binds inproc://stats_in and notifies *c* once bound, so that other
	threads can connect to it.  Each incoming tnetstring is decoded and, if
	stats_spec is configured, re-published on a PUB socket as
	"sub T<tnetstring>"; without stats_spec, messages are discarded.

	Args:
		c: threading.Condition that the starting thread waits on.
	"""
	in_sock = ctx.socket(zmq.PULL)
	in_sock.bind('inproc://stats_in')
	c.acquire()
	c.notify()
	c.release()

	out_sock = None
	if stats_spec:
		out_sock = ctx.socket(zmq.PUB)
		out_sock.bind(stats_spec)

	while True:
		message = tnetstring.loads(in_sock.recv())
		if not out_sock:
			continue
		m_raw = 'sub T' + tnetstring.dumps(message)
		logger.debug('OUT stats: %s' % m_raw)
		out_sock.send(m_raw)

# Start the workers.  inspect, accept, push_in, push_in_zmq and push_in_http
# threads are non-daemon; stats, session and timeout threads are daemons.
inspect_thread = threading.Thread(target=inspect_worker)
inspect_thread.start()

# we use a condition here to ensure the inproc bind succeeds before progressing
# (the main thread holds c while the worker starts, so the worker's notify can
# only run once c.wait() has released it -- the wakeup cannot be missed).
# stats must be bound first: accept, session and timeout connect to
# inproc://stats_in.
c = threading.Condition()
c.acquire()
stats_thread = threading.Thread(target=stats_worker, args=(c,))
stats_thread.daemon = True
stats_thread.start()
c.wait()
c.release()

accept_thread = threading.Thread(target=accept_worker)
accept_thread.start()

# we use a condition here to ensure the inproc bind succeeds before progressing
# (push_in_zmq and push_in_http connect to inproc://push_in)
c = threading.Condition()
c.acquire()
push_in_thread = threading.Thread(target=push_in_worker, args=(c,))
push_in_thread.start()
c.wait()
c.release()

push_in_zmq_thread = threading.Thread(target=push_in_zmq_worker)
push_in_zmq_thread.start()

push_in_http_thread = threading.Thread(target=push_in_http_worker)
push_in_http_thread.start()

session_thread = threading.Thread(target=session_worker)
session_thread.daemon = True
session_thread.start()

timeout_thread = threading.Thread(target=timeout_worker)
timeout_thread.daemon = True
timeout_thread.start()

# Idle until interrupted (Ctrl-C), then shut down.
try:
	while True:
		time.sleep(60)
except KeyboardInterrupt:
	pass

# NOTE(review): ctx.term() blocks until every socket in the context is
# closed; the non-daemon workers sit in recv() loops, so clean exit relies on
# ZeroMQ interrupting them (ETERM) -- confirm shutdown actually completes.
httpinterface.stop()
ctx.term()
