#!/usr/bin/env python3
# SPDX-License-Identifier: (Apache-2.0 OR MIT)

import datetime
import io
import json
import os
import random
from time import mktime
from timeit import timeit

import rapidjson
import simplejson
import ujson
from tabulate import tabulate

import orjson

# Restrict this process to CPUs 0 and 1 before any timing is done.
# os.sched_setaffinity is only available on some Unix platforms; where it is
# missing, run unpinned instead of aborting with AttributeError.
if hasattr(os, "sched_setaffinity"):
    os.sched_setaffinity(os.getpid(), {0, 1})

def _shuffled_year(year):
    """Build the fixture dict for a single year.

    The dict maps the local-time epoch second of each of the first 365 days
    of ``year`` (as an int key) to its 1-based day number, plus one str key
    ``"other"`` mapped to 0. Items are inserted in shuffled order.
    """
    jan_first = datetime.date(year, 1, 1)
    pairs = []
    for offset in range(365):
        day = jan_first + datetime.timedelta(days=offset)
        pairs.append((int(mktime(day.timetuple())), offset + 1))
    # One str key among the int keys, so every dict has mixed key types.
    pairs.append(("other", 0))
    random.shuffle(pairs)
    return dict(pairs)


# One dict per year from 1920 through 2019.
data_as_obj = [_shuffled_year(year) for year in range(1920, 2020)]

# Benchmark configuration: number of dumps() calls per timeit measurement,
# the libraries to time (in report order), and the result-table header.
ITERATIONS = 500

LIBRARIES = ("orjson", "ujson", "rapidjson", "simplejson", "json")

headers = ("Library", "str keys (ms)", "int keys (ms)", "int keys sorted (ms)")

# Round-trip the fixture through JSON so that every key comes back as a str;
# this is both the "str keys" benchmark input and the reference value used
# for the correctness check.
data_as_str = orjson.loads(
    orjson.dumps(data_as_obj, option=orjson.OPT_NON_STR_KEYS)
)

# Report how large the serialized document is, as produced by orjson.
output_in_kib = len(orjson.dumps(data_as_str)) / 1024
print(f"{output_in_kib:,.0f}KiB output (orjson)")


def per_iter_latency(val):
    """Convert a total timeit duration into milliseconds per iteration.

    ``val`` is the total time in seconds for ``ITERATIONS`` calls, or ``None``
    when the measurement was skipped; ``None`` is passed through unchanged.
    """
    return None if val is None else (val * 1000) / ITERATIONS


def test_correctness(serialized):
    """Return True if ``serialized`` decodes (via orjson) to ``data_as_str``."""
    decoded = orjson.loads(serialized)
    return decoded == data_as_str


def _format_ms(latency):
    """Render a per-iteration latency (ms) as a table cell.

    Only a missing measurement (``None``) yields an empty cell; a measured
    value of ``0.0`` is still printed instead of being treated as absent.
    """
    return "" if latency is None else f"{latency:,.2f}"


# For each library, time three workloads with timeit (ITERATIONS calls each):
#   * time_as_str        - dumping the str-keyed fixture
#   * time_as_obj        - dumping the fixture with int (and one str) keys
#   * time_as_obj_sorted - same, with keys sorted
# A workload a library cannot run is recorded as None and shown as a blank
# cell. Output of str-returning libraries is encoded to UTF-8 inside the timed
# callable; orjson's dumps() result is used as-is.
table = []
for lib_name in LIBRARIES:
    print(f"{lib_name}...")
    if lib_name == "json":
        time_as_str = timeit(
            lambda: json.dumps(data_as_str).encode("utf-8"),
            number=ITERATIONS,
        )
        time_as_obj = timeit(
            lambda: json.dumps(data_as_obj).encode("utf-8"),
            number=ITERATIONS,
        )
        time_as_obj_sorted = (
            None  # TypeError: '<' not supported between instances of 'str' and 'int'
        )
        correct = False
    elif lib_name == "simplejson":
        time_as_str = timeit(
            lambda: simplejson.dumps(data_as_str).encode("utf-8"),
            number=ITERATIONS,
        )
        time_as_obj = timeit(
            lambda: simplejson.dumps(data_as_obj).encode("utf-8"),
            number=ITERATIONS,
        )
        time_as_obj_sorted = timeit(
            lambda: simplejson.dumps(data_as_obj, sort_keys=True).encode("utf-8"),
            number=ITERATIONS,
        )
        correct = test_correctness(
            simplejson.dumps(data_as_obj, sort_keys=True).encode("utf-8")
        )
    elif lib_name == "ujson":
        time_as_str = timeit(
            lambda: ujson.dumps(data_as_str).encode("utf-8"),
            number=ITERATIONS,
        )
        time_as_obj = timeit(
            lambda: ujson.dumps(data_as_obj).encode("utf-8"),
            number=ITERATIONS,
        )
        time_as_obj_sorted = None  # segfault
        correct = False
    elif lib_name == "rapidjson":
        time_as_str = timeit(
            lambda: rapidjson.dumps(data_as_str).encode("utf-8"),
            number=ITERATIONS,
        )
        # Non-str-key workloads are not measured for rapidjson.
        time_as_obj = None
        time_as_obj_sorted = None
        correct = False
    elif lib_name == "orjson":
        time_as_str = timeit(
            lambda: orjson.dumps(data_as_str, None, orjson.OPT_NON_STR_KEYS),
            number=ITERATIONS,
        )
        time_as_obj = timeit(
            lambda: orjson.dumps(data_as_obj, None, orjson.OPT_NON_STR_KEYS),
            number=ITERATIONS,
        )
        time_as_obj_sorted = timeit(
            lambda: orjson.dumps(
                data_as_obj, None, orjson.OPT_NON_STR_KEYS | orjson.OPT_SORT_KEYS
            ),
            number=ITERATIONS,
        )
        correct = test_correctness(
            orjson.dumps(
                data_as_obj, None, orjson.OPT_NON_STR_KEYS | orjson.OPT_SORT_KEYS
            )
        )
    else:
        raise NotImplementedError

    time_as_str = per_iter_latency(time_as_str)
    time_as_obj = per_iter_latency(time_as_obj)
    # Only report the sorted-keys timing when the sorted output decoded back
    # to the reference document.
    if not correct:
        time_as_obj_sorted = None
    else:
        time_as_obj_sorted = per_iter_latency(time_as_obj_sorted)

    table.append(
        (
            lib_name,
            _format_ms(time_as_str),
            _format_ms(time_as_obj),
            _format_ms(time_as_obj_sorted),
        )
    )

# Render the results as a GitHub-flavored Markdown table.
buf = io.StringIO()
buf.write(tabulate(table, headers, tablefmt="github"))
buf.write("\n")

print(buf.getvalue())
