#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import with_statement

import os
import sys
import math
import optparse
import webbrowser
import urllib
from time import sleep
from itertools import cycle
from tempfile import NamedTemporaryFile

# Text shown at the top of --help output.
description = (
    'Given a list of numbers indicating the number of nodes\n'
    'in each separate datacenter, outputs a recommended list of tokens to use\n'
    'with Murmur3Partitioner (by default) or with RandomPartitioner.\n'
    'The list contains one token for each node in each datacenter.\n'
)

# Positional arguments: one node count per datacenter.
usage = "%prog <nodes_in_dc1> [<nodes_in_dc2> [...]]"

# Module-level parser; options and defaults are registered on it below.
parser = optparse.OptionParser(usage=usage, description=description)

def part_murmur3(option, opt, value, parser):
    """optparse callback for --murmur3: select the Murmur3 token ring.

    Stores ringoffset = -2**63 and ringrange = 2**64 on ``parser.values``.
    The ``option``, ``opt`` and ``value`` arguments are required by the
    optparse callback signature and are not used.
    """
    settings = parser.values
    settings.ringoffset = -(2 ** 63)
    settings.ringrange = 2 ** 64
def part_random(option, opt, value, parser):
    """optparse callback for --random: select the RandomPartitioner ring.

    Stores ringoffset = 0 and ringrange = 2**127 on ``parser.values``.
    The ``option``, ``opt`` and ``value`` arguments are required by the
    optparse callback signature and are not used.
    """
    settings = parser.values
    settings.ringoffset = 0
    settings.ringrange = 2 ** 127
# --- Partitioner selection -------------------------------------------------
# Each flag runs a callback that stores the matching ringoffset/ringrange
# pair on the parsed values.
parser.add_option(
    '--murmur3',
    action='callback',
    callback=part_murmur3,
    help='Generate tokens for Murmur3Partitioner (default).',
)
parser.add_option(
    '--random',
    action='callback',
    callback=part_random,
    help='Generate tokens for RandomPartitioner.',
)
# Hidden overrides for the raw ring geometry.
parser.add_option(
    '--ringoffset',
    type='int',
    help=optparse.SUPPRESS_HELP,
)
parser.add_option(
    '--ringrange',
    type='int',
    help=optparse.SUPPRESS_HELP,
)

# --- Graph output and replication strategy ---------------------------------
parser.add_option(
    '--graph',
    action='store_true',
    help=('Show a rendering of the generated tokens as line segments '
          'in a circle, colored according to datacenter'),
)
# Hidden strategy switches; both write to opts.strat.
parser.add_option(
    '-n', '--nts',
    action='store_const',
    dest='strat',
    const='nts',
    help=optparse.SUPPRESS_HELP,
)
parser.add_option(
    '-o', '--onts',
    action='store_const',
    dest='strat',
    const='onts',
    help=optparse.SUPPRESS_HELP,
)

# --- Test mode --------------------------------------------------------------
parser.add_option(
    '--test',
    action='store_true',
    help=('Run in test mode, outputting an HTML file '
          'to display various generated ring arrangements.'),
)

# --- Hidden rendering knobs -------------------------------------------------
parser.add_option(
    '--html-output',
    help=optparse.SUPPRESS_HELP,
)
parser.add_option(
    '--browser-wait-time',
    type='float',
    help=optparse.SUPPRESS_HELP,
)
parser.add_option(
    '--test-colors',
    help=optparse.SUPPRESS_HELP,
)
parser.add_option(
    '--test-graphsize',
    type='int',
    help=optparse.SUPPRESS_HELP,
)


parser.set_defaults(
    # Ring geometry: Murmur3 unless a partitioner flag says otherwise.
    ringoffset=-(2 ** 63),
    ringrange=2 ** 64,

    # Replication strategy to optimize for: 'nts' or 'onts'.
    strat='nts',

    # Graph output is off unless --graph (or --test) is given.
    graph=False,

    # Test mode is off unless --test is given.
    test=False,

    # Canvas size, in pixels, of the graph drawn for a normal run.
    graphsize=600,

    # Canvas size, in pixels, of each graph drawn in test mode.
    test_graphsize=200,

    # Destination for the generated HTML: a file path, '-' for stdout, or
    # '*tmp*' to write a temporary file, ask a browser to open it, and then
    # remove it.
    html_output='*tmp*',

    # Seconds to wait after asking the browser to open a temporary HTML
    # file before that file is cleaned up.
    browser_wait_time=5.0,

    # Comma-separated HTML color codes, assigned to datacenters in order.
    test_colors='#000,#00F,#0F0,#F00,#0FF,#FF0,#F0F',
)

class Ring:
    """Compute token assignments for a ring shared by several datacenters.

    Tokens lie in the half-open interval
    ``[ringoffset, ringoffset + ringrange)``.  ``dc_counts`` holds the
    number of nodes in each datacenter.  ``strategy`` picks the layout
    method: ``'nts'`` binds ``calc_offset_tokens_nts`` and ``'onts'`` binds
    ``calc_offset_tokens_onts``; the chosen one is exposed as
    ``calculate_offset_tokens``.  An unknown strategy name raises
    AttributeError.
    """

    # Lower bound on the divisor used to derive the per-datacenter offset,
    # i.e. the offset is never larger than ringrange / 235 in magnitude.
    MIN_DC_OFFSET_DIVIDER = 235
    # Extra factor applied to the divisor so that shifted datacenters stay
    # well inside the gap between neighbouring tokens.
    offset_spacer = 2

    def __init__(self, dc_counts, ringoffset, ringrange, strategy='nts'):
        self.ringoffset = ringoffset
        self.ringrange = ringrange
        self.dc_counts = dc_counts
        self.calculate_offset_tokens = getattr(self, 'calc_offset_tokens_' + strategy)

    def best_per_dc_offset(self):
        """
        Calculate a per-dc offset for NTS DC spacing, such that there is a little
        bit of room between nodes which would otherwise have been at the same token;
        (hopefully) large enough that the difference can show when --graph is used,
        but small enough that there's no chance of the relative ordering changing.

        Returns a non-positive integer (datacenters are shifted backwards
        around the ring).
        """
        lowest_division = len(self.dc_counts) * max(self.dc_counts) * self.offset_spacer
        division = max(lowest_division, self.MIN_DC_OFFSET_DIVIDER)
        return -self.ringrange // division

    def bound_token(self, tok):
        """Wrap a token that fell below the start of the ring back into range."""
        if tok < self.ringoffset:
            tok += self.ringrange
        return tok

    def calc_offset_tokens_nts(self):
        """Lay out each datacenter evenly on its own, then nudge each one.

        Every datacenter gets its nodes spread evenly over the whole ring,
        starting at ``ringoffset``; datacenter number k is then shifted by
        k times the value of ``best_per_dc_offset()`` so that nodes of
        different datacenters do not share a token.

        Returns a list with one sorted token list per datacenter.
        """
        dc_offset = self.best_per_dc_offset()
        dcs = []
        for dcnum, dccount in enumerate(self.dc_counts):
            offset = dcnum * dc_offset
            # "or 1" keeps an empty datacenter from dividing by zero.
            arcsize = self.ringrange // (dccount or 1)
            # Positions are measured from the start of the ring, so the
            # ring's own offset is added directly.  (The previous
            # "- ringoffset % ringrange" form only agreed with this when
            # ringoffset was 0 or minus half the range.)
            dcs.append(sorted(
                self.bound_token(self.ringoffset + n * arcsize + offset)
                for n in range(dccount)
            ))
        return dcs

    def calc_offset_tokens_onts(self):
        """Interleave the nodes of all datacenters evenly around the ring.

        Datacenters are ordered by descending node count, their nodes are
        dealt into an interleaved sequence, and consecutive positions of
        that sequence are spaced evenly starting at ``ringoffset``.

        Returns a list with one token list per datacenter, indexed by the
        datacenter's position in ``dc_counts``.
        """
        dcs_by_count = sorted(enumerate(self.dc_counts), key=lambda d: d[1], reverse=True)
        biggest = dcs_by_count[0][1]
        nodes = [dcnum for (dcnum, dccount) in dcs_by_count for x in range(dccount)]
        layout = [nodes[n] for i in range(biggest) for n in range(i, len(nodes), biggest)]

        final = [[] for x in dcs_by_count]
        for pos, dc in enumerate(layout):
            final[dc].append(self.ringoffset + pos * self.ringrange // len(layout))
        return final


def print_tokens(tokens, tokenwidth, indent=0):
    """Print the tokens of every datacenter to stdout.

    tokens     -- list of token lists, one per datacenter
    tokenwidth -- minimum field width used when printing each token
    indent     -- number of spaces to put in front of every line

    Datacenters and nodes are numbered from 1; node numbers are zero-padded
    to the width of the largest node number in their datacenter.
    """
    indentstr = ' ' * indent
    for dcnum, toklist in enumerate(tokens):
        # A single parenthesised argument prints the same on Python 2 and 3.
        print("%sDC #%d:" % (indentstr, dcnum + 1))
        nwidth = len(str(len(toklist)))
        for tnum, tok in enumerate(toklist):
            print("%s  Node #%0*d: % *d" % (indentstr, nwidth, tnum + 1, tokenwidth, tok))

def calculate_ideal_tokens(datacenters, ringoffset, ringrange, strategy):
    """Build a Ring from the arguments and return its computed token lists.

    datacenters -- node count for each datacenter
    ringoffset  -- lowest token of the ring
    ringrange   -- size of the ring
    strategy    -- layout strategy name understood by Ring
    """
    ring = Ring(datacenters, ringoffset, ringrange, strategy)
    return ring.calculate_offset_tokens()

def file_to_url(path):
    """Return a ``file://`` URL for the given filesystem path.

    The path is made absolute first, and characters other than ``/`` that
    are not URL-safe are percent-encoded.  On Windows (``os.name == 'nt'``)
    a drive-letter path becomes ``file:///C:/dir/file`` and a UNC path
    becomes ``file://host/share/dir/file``.
    """
    # quote() moved between Python 2 and 3; importing here keeps this
    # function usable on both.
    try:
        from urllib.parse import quote
    except ImportError:
        from urllib import quote

    path = os.path.abspath(path)
    if os.name == 'nt':
        drive, rest = os.path.splitdrive(path)
        rest = rest.replace(os.sep, '/')
        if drive[:2] in ('\\\\', '//'):
            # UNC share: the drive part is "\\host\share"; once its
            # separators are flipped it already starts with the "//" that
            # introduces the URL's host.
            return 'file:' + quote(drive.replace(os.sep, '/') + rest, safe='/')
        # Drive-letter path: keep the colon literal so the URL reads
        # file:///C:/... rather than an escaped drive specifier.
        return 'file:///' + quote(drive + rest, safe='/:')
    return 'file://' + quote(path, safe='/')

# Outer HTML page.  %(generated_body)s is replaced with the markup of all
# charts (see RingRenderer.make_html).
html_template = """<!DOCTYPE html>
<html>
<body>

%(generated_body)s

</body>
</html>
"""

# One chart: a square <canvas> plus the script that draws on it.
#   %(id)s               -- element id of the canvas
#   %(size)s             -- used for both the width and the height attribute
#   %(generated_script)s -- JavaScript drawing statements, run with the
#                           canvas's 2d context available as `ctx`
chart_template = """
<canvas id="%(id)s" width="%(size)s" height="%(size)s" style="border:1px solid #c3c3c3;">
    Your browser does not support the canvas element.
</canvas>
<script type="text/javascript">
    var c=document.getElementById("%(id)s");
    var ctx=c.getContext("2d");
%(generated_script)s
</script>
"""

# One line segment drawn from the point (%(center)s, %(center)s) to
# (%(x)s, %(y)s), stroked in %(color)s.
chart_js_template = """
    ctx.beginPath();
    ctx.strokeStyle = "%(color)s";
    ctx.moveTo(%(center)s,%(center)s);
    ctx.lineTo(%(x)s,%(y)s);
    ctx.stroke();
    ctx.closePath();
"""

class RingRenderer:
    """Render token rings as HTML canvas charts.

    Each token is drawn as a line from the centre of a square canvas
    towards the edge; the token's position within the ring determines the
    line's angle.

    ringrange -- size of the token ring (one full turn of the circle)
    graphsize -- width and height of each canvas
    colors    -- color codes, assigned to datacenters in order and reused
                 cyclically when there are more datacenters than colors
    """

    # Fraction of the canvas left as a margin around the ring.
    border_fraction = 0.08

    def __init__(self, ringrange, graphsize, colors):
        self.colors = colors
        self.graphsize = graphsize
        self.ringrange = ringrange
        self.center = graphsize / 2
        self.linelength = graphsize * (1 - self.border_fraction) / 2
        # Radians per token unit.
        self.anglefactor = 2 * math.pi / ringrange

    def calc_coords(self, tokens):
        """Map per-datacenter token lists to per-datacenter (x, y) endpoints.

        Token 0 points straight up; increasing tokens turn clockwise on
        the canvas.
        """
        def endpoint(tok):
            theta = tok * self.anglefactor
            return (self.center + self.linelength * math.sin(theta),
                    self.center - self.linelength * math.cos(theta))

        return [[endpoint(tok) for tok in toklist] for toklist in tokens]

    def make_html(self, tokensets):
        """Return a complete HTML page with one chart per entry of tokensets.

        Every entry of ``tokensets`` is a list of per-datacenter token
        lists; charts get their index in ``tokensets`` as element id.
        """
        midpoint = self.graphsize / 2
        charts = []
        for chart_id, tokens in enumerate(tokensets):
            per_dc_coords = self.calc_coords(tokens)
            segments = [
                chart_js_template % dict(color=color, x=x, y=y, center=midpoint)
                for coords, color in zip(per_dc_coords, cycle(self.colors))
                for x, y in coords
            ]
            charts.append(chart_template % dict(generated_script=''.join(segments),
                                                id=chart_id, size=self.graphsize))
        return html_template % dict(generated_body=''.join(charts))

# ===========================
# Tests

def run_tests(opts):
    """Compute and print tokens for a fixed set of sample clusters.

    Each sample is a list of per-datacenter node counts.  The ring
    geometry and strategy are taken from ``opts.ringoffset``,
    ``opts.ringrange`` and ``opts.strat``.

    Returns the list of computed token sets, one per sample, in order.
    """
    tests = [
        [1],
        [1, 1],
        [2, 2],
        [1, 2, 2],
        [2, 2, 2],
        [2, 0, 0],
        [0, 2, 0],
        [0, 0, 2],
        [2, 2, 0],
        [2, 0, 2],
        [0, 2, 2],
        [0, 0, 1, 1, 0, 1, 1],
        [6],
        [3, 3, 3],
        [9],
        [1, 1, 1, 1],
        [4],
        [3, 3, 6, 4, 2],
    ]

    tokensets = []
    for test in tests:
        # A single parenthesised argument prints the same on Python 2 and 3.
        print("Test %r" % (test,))
        tokens = calculate_ideal_tokens(test, opts.ringoffset, opts.ringrange, opts.strat)
        print_tokens(tokens, len(str(opts.ringrange)) + 1, indent=2)
        tokensets.append(tokens)
    return tokensets

# ===========================

def display_html(html, wait_time):
    """Write html to a temporary .html file and ask a web browser to show it.

    html      -- page content, as text
    wait_time -- seconds to wait before the temporary file is removed

    The file is removed again once the wait is over, even if opening the
    browser raised.
    """
    # Text mode, so a str can be written on Python 3 as well.  delete=False
    # because the file is closed before the browser is pointed at it: a
    # NamedTemporaryFile that is still open cannot be reopened by name on
    # Windows.  Cleanup is therefore done by hand below.
    with NamedTemporaryFile(mode='w', suffix='.html', delete=False) as f:
        f.write(html)
    try:
        webbrowser.open(file_to_url(f.name), new=2)
        # webbrowser.open does not report when the browser has actually
        # read the file, so give it some time before the file goes away.
        sleep(wait_time)
    finally:
        try:
            os.remove(f.name)
        except OSError:
            # Best-effort cleanup: a leftover temp file is not worth
            # failing the run for.
            pass

def write_output(html, opts):
    """Send the generated HTML to the destination named by opts.html_output.

    '*tmp*' -- write a temporary file and open it in a browser, waiting
               opts.browser_wait_time seconds before cleaning up
    '-'     -- write to standard output
    other   -- treat the value as a file path and (over)write that file
    """
    target = opts.html_output

    if target == '*tmp*':
        display_html(html, opts.browser_wait_time)
        return

    if target == '-':
        sys.stdout.write(html)
        return

    with open(target, 'w') as out:
        out.write(html)

def readnum(prompt, min=None, max=None):
    """Ask on the terminal until the user enters an acceptable integer.

    prompt -- text shown to the user (a trailing space is appended)
    min    -- smallest accepted value, or None for no lower bound
    max    -- largest accepted value, or None for no upper bound

    Input that is not an integer, or is outside the bounds, produces an
    explanatory message and the question is asked again.  Returns the
    accepted value as an int.
    """
    # The line-reading builtin is raw_input on Python 2 and input on
    # Python 3; pick whichever exists.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    while True:
        x = read_line(prompt + ' ')
        try:
            val = int(x)
        except ValueError:
            print("Oops, %r is not an integer. Try again.\n" % (x,))
            continue
        if min is not None and val < min:
            print("Oops, the answer must be at least %d. Try again.\n" % (min,))
        elif max is not None and val > max:
            print("Oops, the answer must be at most %d. Try again.\n" % (max,))
        else:
            return val

def get_dc_sizes_interactive():
    """Ask the user for the cluster layout on the terminal.

    First asks for the number of datacenters (at least 1), then for the
    number of nodes in each of them (at least 0).

    Returns the list of node counts, one per datacenter, in order.
    """
    # print("") emits an empty line on both Python 2 and 3; a bare
    # `print` does nothing on Python 3.
    print("Token Generator Interactive Mode")
    print("--------------------------------")
    print("")
    dcs = readnum(" How many datacenters will participate in this Cassandra cluster?", min=1)
    sizes = [
        readnum(" How many nodes are in datacenter #%d?" % (n + 1), min=0)
        for n in range(dcs)
    ]
    print("")
    return sizes

def main(opts, args):
    """Run the token generator with parsed options and positional arguments.

    opts -- option values produced by the module-level parser
    args -- positional arguments: one node count per datacenter; when
            empty (and not in test mode) the counts are asked for
            interactively

    In test mode the built-in sample clusters are computed and graph
    output is forced on.  Otherwise tokens for the requested cluster are
    computed and printed.  When graph output is enabled the rings are
    rendered to HTML and written according to opts.html_output.

    Returns 0.  Non-integer positional arguments are reported through
    parser.error().
    """
    opts.colorlist = [s.strip() for s in opts.test_colors.split(',')]
    if opts.test:
        opts.graph = True
        tokensets = run_tests(opts)
        renderer = RingRenderer(ringrange=opts.ringrange, graphsize=opts.test_graphsize,
                                colors=opts.colorlist)
    else:
        if len(args) == 0:
            args = get_dc_sizes_interactive()
        try:
            # Convert eagerly: a lazy map() would postpone the ValueError
            # until after this try block on Python 3.
            datacenters = [int(arg) for arg in args]
        except ValueError:
            parser.error('Arguments should be integers.')
        renderer = RingRenderer(ringrange=opts.ringrange, graphsize=opts.graphsize,
                                colors=opts.colorlist)
        tokens = calculate_ideal_tokens(datacenters, opts.ringoffset, opts.ringrange, opts.strat)
        print_tokens(tokens, len(str(opts.ringrange)) + 1)
        tokensets = [tokens]

    if opts.graph:
        html = renderer.make_html(tokensets)
        write_output(html, opts)
    return 0

if __name__ == '__main__':
    # Script entry point: parse the command line, run main(), and use its
    # return value as the exit status.
    opts, args = parser.parse_args()
    # Status passed to sys.exit() if the user interrupts the run (Ctrl-C).
    res = -128
    try:
        res = main(opts, args)
    except KeyboardInterrupt:
        pass
    sys.exit(res)
