#!/usr/bin/env python3

import argparse
import ast
from collections.abc import Iterable
from glob import iglob
from itertools import (
    chain,
    islice,
)
import multiprocessing
from os.path import abspath
import re
import sys
import tokenize

from twisted.python.modules import theSystemPath


# Names of top-level modules treated as part of the Python standard library.
# The list contains both the Python 2 and the Python 3 spellings of modules
# that were renamed (for example `ConfigParser` and `configparser`, `Queue`
# and `queue`). It is the input from which the `StandardLibraries` pattern is
# built.
python_standard_libs = [
    '__builtin__',
    '__future__',
    '__main__',
    '_dummy_thread',
    '_thread',
    '_winreg',
    'abc',
    'aepack',
    'aetools',
    'aetypes',
    'aifc',
    'AL',
    'al',
    'anydbm',
    'applesingle',
    'argparse',
    'array',
    'ast',
    'asynchat',
    'asyncio',
    'asyncore',
    'atexit',
    'audioop',
    'autoGIL',
    'base64',
    'BaseHTTPServer',
    'Bastion',
    'bdb',
    'binascii',
    'binhex',
    'bisect',
    'bsddb',
    'buildtools',
    'builtins',
    'bz2',
    'calendar',
    'Carbon',
    'cd',
    'cfmfile',
    'cgi',
    'CGIHTTPServer',
    'cgitb',
    'chunk',
    'cmath',
    'cmd',
    'code',
    'codecs',
    'codeop',
    'collections',
    'ColorPicker',
    'colorsys',
    'commands',
    'compileall',
    'compiler',
    'concurrent',
    'ConfigParser',
    'configparser',
    'contextlib',
    'Cookie',
    'cookielib',
    'copy',
    'copy_reg',
    'copyreg',
    'cPickle',
    'cProfile',
    'crypt',
    'cStringIO',
    'csv',
    'ctypes',
    'curses',
    'datetime',
    'dbhash',
    'dbm',
    'decimal',
    'DEVICE',
    'difflib',
    'dircache',
    'dis',
    'distutils',
    'dl',
    'doctest',
    'DocXMLRPCServer',
    'dumbdbm',
    'dummy_thread',
    'dummy_threading',
    'EasyDialogs',
    'email',
    'encodings',
    'ensurepip',
    'enum',
    'errno',
    'exceptions',
    'faulthandler',
    'fcntl',
    'filecmp',
    'fileinput',
    'findertools',
    'fl',
    'FL',
    'flp',
    'fm',
    'fnmatch',
    'formatter',
    'fpectl',
    'fpformat',
    'fractions',
    'FrameWork',
    'ftplib',
    'functools',
    'future_builtins',
    'gc',
    'gdbm',
    'gensuitemodule',
    'getopt',
    'getpass',
    'gettext',
    'GL',
    'gl',
    'glob',
    'grp',
    'gzip',
    'hashlib',
    'heapq',
    'hmac',
    'hotshot',
    'html',
    'htmlentitydefs',
    'htmllib',
    'HTMLParser',
    'http',
    'httplib',
    'ic',
    'icopen',
    'imageop',
    'imaplib',
    'imgfile',
    'imghdr',
    'imp',
    'importlib',
    'imputil',
    'inspect',
    'io',
    'ipaddress',
    'itertools',
    'jpeg',
    'json',
    'keyword',
    'lib2to3',
    'linecache',
    'locale',
    'logging',
    'lzma',
    'macerrors',
    'MacOS',
    'macostools',
    'macpath',
    'macresource',
    'mailbox',
    'mailcap',
    'marshal',
    'math',
    'md5',
    'mhlib',
    'mimetools',
    'mimetypes',
    'MimeWriter',
    'mimify',
    'MiniAEFrame',
    'mmap',
    'modulefinder',
    'msilib',
    'msvcrt',
    'multifile',
    'multiprocessing',
    'mutex',
    'Nav',
    'netrc',
    'new',
    'nis',
    'nntplib',
    'numbers',
    'operator',
    'optparse',
    'os',
    'ossaudiodev',
    'parser',
    'pathlib',
    'pdb',
    'pickle',
    'pickletools',
    'pipes',
    'PixMapWrapper',
    'pkgutil',
    'platform',
    'plistlib',
    'popen2',
    'poplib',
    'posix',
    'posixfile',
    'pprint',
    'profile',
    'pstats',
    'pty',
    'pwd',
    'py_compile',
    'pyclbr',
    'pydoc',
    'queue',
    'Queue',
    'quopri',
    'random',
    're',
    'readline',
    'repr',
    'reprlib',
    'resource',
    'rexec',
    'rfc822',
    'rlcompleter',
    'robotparser',
    'runpy',
    'sched',
    'ScrolledText',
    'select',
    'selectors',
    'sets',
    'sgmllib',
    'sha',
    'shelve',
    'shlex',
    'shutil',
    'signal',
    'SimpleHTTPServer',
    'SimpleXMLRPCServer',
    'site',
    'smtpd',
    'smtplib',
    'sndhdr',
    'socket',
    'socketserver',
    'SocketServer',
    'spwd',
    'sqlite3',
    'ssl',
    'stat',
    'statistics',
    'statvfs',
    'string',
    'StringIO',
    'stringprep',
    'struct',
    'subprocess',
    'sunau',
    'sunaudiodev',
    'SUNAUDIODEV',
    'symbol',
    'symtable',
    'sys',
    'sysconfig',
    'syslog',
    'tabnanny',
    'tarfile',
    'telnetlib',
    'tempfile',
    'termios',
    'test',
    'textwrap',
    'thread',
    'threading',
    'time',
    'timeit',
    'Tix',
    'Tkinter',
    'tkinter',
    'token',
    'tokenize',
    'trace',
    'traceback',
    'tracemalloc',
    'ttk',
    'tty',
    'turtle',
    'turtledemo',
    'types',
    'typing',
    'unicodedata',
    'unittest',
    'urllib',
    'urllib2',
    'urlparse',
    'user',
    'UserDict',
    'UserList',
    'UserString',
    'uu',
    'uuid',
    'venv',
    'videoreader',
    'W',
    'warnings',
    'wave',
    'weakref',
    'webbrowser',
    'whichdb',
    'winreg',
    'winsound',
    'wsgiref',
    'xdrlib',
    'xml',
    'xmlrpc',
    'xmlrpclib',
    'zipapp',
    'zipfile',
    'zipimport',
    'zlib',
]


def flatten(*things):
    """Lazily flatten arbitrarily nested iterables into a stream of leaves.

    Strings and bytes are kept whole rather than being split into their
    characters, and anything that is not iterable is passed through as-is.

    For example::

      >>> sorted(flatten([1, 2, {3, 4, (5, 6)}]))
      [1, 2, 3, 4, 5, 6]

    :return: An iterator.
    """
    def _leaves(obj):
        # Text and bytes are iterable, but iterating them only ever produces
        # more text and bytes, so descending into them would never end. They
        # are therefore leaves, exactly like non-iterable objects.
        is_leaf = (
            isinstance(obj, (bytes, str)) or
            not isinstance(obj, Iterable))
        if is_leaf:
            # Wrap the leaf in a single-item iterator so that it can be
            # chained together with its siblings.
            return iter((obj,))
        # A container: flatten each member and splice the results together.
        return chain.from_iterable(_leaves(member) for member in obj)

    return _leaves(things)


class Pattern:
    """A collection of glob-like patterns matched against dotted names.

    Each pattern string holds one or more alternatives separated by ``|``.
    Within an alternative, ``*`` matches one or more characters other than
    ``.``, ``**`` matches one or more characters of any kind (dots included),
    and everything else is matched literally. A name matches when any one
    alternative of any pattern matches it in its entirety.
    """

    # The compiled regular expression. It stays None until `compile` is
    # called, either explicitly or implicitly by the first call to `match`.
    _matcher = None

    def __init__(self, *patterns):
        """Collect `patterns`, flattening nested iterables of strings."""
        super().__init__()
        self.patterns = tuple(flatten(patterns))

    def _compile_pattern(self, pattern):
        """Generate one regular expression per alternative in `pattern`.

        :raise ValueError: If a run of three or more ``*`` is found.
        """
        for part in pattern.split("|"):
            expr = []
            # Each component is either a run of stars or a run of non-stars.
            for component in re.findall('([*]+|[^*]+)', part):
                if component == "*":
                    expr.append("[^.]+")
                elif component == "**":
                    expr.append(".+")
                elif component.count("*") >= 3:
                    raise ValueError(component)
                else:
                    expr.append(re.escape(component))
            yield "".join(expr)

    def compile(self):
        """Compile all patterns into a single regular expression.

        :raise ValueError: If any pattern is malformed.
        """
        alternatives = chain.from_iterable(
            map(self._compile_pattern, self.patterns))
        self._matcher = re.compile(
            r"(%s)\Z" % "|".join(alternatives),
            re.MULTILINE | re.DOTALL)

    def match(self, name):
        """Return True if `name` is matched in full by any pattern.

        The patterns are compiled on first use if `compile` has not already
        been called, so calling `compile` beforehand is an optimisation and
        not a requirement.
        """
        if self._matcher is None:
            self.compile()
        return self._matcher.match(name) is not None

    def __iter__(self):
        """Iterate over the original (uncompiled) pattern strings."""
        yield from self.patterns


# Outcomes of checking a name against an action. ALLOWED and DENIED are
# definite verdicts; INDIFFERENT is what an action's `check` returns when its
# pattern does not match the name.
ALLOWED, DENIED, INDIFFERENT = True, False, None


class Action:
    """Base class for policy actions; holds the `Pattern` an action uses."""

    def __init__(self, *patterns):
        """Wrap `patterns` (strings or nested iterables) in a `Pattern`."""
        super().__init__()
        self.pattern = Pattern(patterns)

    def compile(self):
        """Compile the underlying pattern."""
        self.pattern.compile()

    def __iter__(self):
        """Iterate over the pattern strings belonging to this action."""
        for entry in self.pattern.patterns:
            yield entry


class Allow(Action):
    """An action that permits names matching its pattern."""

    def check(self, name):
        """Return ALLOWED if `name` matches, otherwise INDIFFERENT."""
        return ALLOWED if self.pattern.match(name) else INDIFFERENT


class Deny(Action):
    """An action that forbids names matching its pattern."""

    def check(self, name):
        """Return DENIED if `name` matches, otherwise INDIFFERENT."""
        return DENIED if self.pattern.match(name) else INDIFFERENT


class Rule:
    """An import policy assembled from `Allow` and `Deny` actions.

    A name is permitted only when an `Allow` matches it and no `Deny` does:
    denials take precedence, and names that nothing matches are denied.
    """

    def __init__(self, *actions):
        """Merge `actions` into one combined `Allow` and one combined `Deny`.

        :param actions: Any number of `Allow` and `Deny` instances.
        :raise ValueError: If an action is neither an `Allow` nor a `Deny`.
        """
        super().__init__()
        # Partition with lists rather than sets so that validation never
        # needs to hash its arguments (an unhashable stray argument is then
        # reported as the ValueError below, not as a TypeError) and so that
        # patterns and error messages keep the order in which they were given.
        allows = [act for act in actions if isinstance(act, Allow)]
        denies = [act for act in actions if isinstance(act, Deny)]
        leftover = [
            act for act in actions
            if not isinstance(act, (Allow, Deny))
        ]
        if leftover:
            raise ValueError(
                "Expected Allow or Deny instance, got: %s"
                % ", ".join(map(repr, leftover)))
        # Actions are iterable over their pattern strings, so these flatten
        # into a single Allow and a single Deny covering all the patterns.
        self.allow = Allow(allows)
        self.deny = Deny(denies)

    def compile(self):
        """Compile both the allow and the deny patterns."""
        self.allow.compile()
        self.deny.compile()

    def check(self, name):
        """Return ALLOWED or DENIED for `name`; never INDIFFERENT."""
        if self.deny.check(name) is DENIED:
            # An explicit denial always wins.
            return DENIED
        if self.allow.check(name) is ALLOWED:
            return ALLOWED
        # Nothing matched: deny by default.
        return DENIED

    def __or__(self, other):
        """Return a new rule combining this one with `other`.

        :param other: Another rule, or a single `Allow` or `Deny`.
        :raise ValueError: If `other` is none of those.
        """
        if isinstance(other, self.__class__):
            return self.__class__(
                self.allow, self.deny, other.allow, other.deny)
        else:
            return self.__class__(
                self.allow, self.deny, other)


# Common module patterns: every standard library module by name, plus
# anything that lives beneath one of them.
StandardLibraries = Pattern(
    "{0}|{0}.**".format(libname) for libname in python_standard_libs)


def files(*patterns):
    """Return the set of file names matching any of the glob `patterns`.

    Patterns may be given as strings or as nested iterables of strings, and
    ``**`` in a pattern matches recursively.

    :return: A `frozenset` of the matched file names.
    """
    matched = set()
    for pattern in flatten(patterns):
        matched.update(iglob(pattern, recursive=True))
    return frozenset(matched)


#
# Tests
#
# These sets are computed when this module is loaded, by globbing patterns
# that are relative to the current working directory.
#

# Unit tests and test-support code found beneath `maas/`.
TestsUnit = files(
    "maas/**/tests/**/*.py",
    "maas/**/testing/**/*.py",
    "maas/**/testing.py")

# Everything beneath `integrate/`.
TestsIntegration = files(
    "integrate/**/*.py")

# The union of the two sets above; subtracted from the code sets below.
TestsAll = TestsUnit | TestsIntegration


#
# Code
#


# Non-test code beneath `maas/client/flesh/`.
CodeFlesh = files("maas/client/flesh/**/*.py") - TestsAll

# All remaining code beneath `maas/`: neither flesh nor tests.
CodeOther = files("maas/**/*.py") - CodeFlesh - TestsAll

# The `setup.py` script, if present in the current directory.
Setup = files("setup.py")


#
# Make some checks.
#
# Each entry pairs a set of file names with the `Rule` that every name
# imported by those files is checked against. A name that no `Allow` matches
# is denied.
#

checks = [
    (
        # Unit tests and test-support code.
        TestsUnit,
        Rule(
            Allow("aiohttp|aiohttp.**"),
            Allow("django|django.**"),
            Allow("fixtures|fixtures.**"),
            Allow("macaroonbakery|macaroonbakery.**"),
            Allow("maas.client|maas.client.**"),
            Allow("multidict|multidict.**"),
            Allow("pkg_resources|pkg_resources.**"),
            Allow("testscenarios|testscenarios.**"),
            Allow("testtools|testtools.**"),
            Allow("twisted|twisted.**"),
            Allow("yaml|yaml.**"),
            Allow(StandardLibraries),
        ),
    ),
    (
        # Integration tests: only the client, testtools and the standard
        # library.
        TestsIntegration,
        Rule(
            Allow("maas.client|maas.client.**"),
            Allow("testtools|testtools.**"),
            Allow(StandardLibraries),
        ),
    ),
    (
        # Non-test code beneath maas/client/flesh/.
        CodeFlesh,
        Rule(
            Allow("argcomplete|argcomplete.**"),
            Allow("colorclass|colorclass.**"),
            Allow("ipdb|ipdb.**"),
            Allow("IPython|IPython.**"),
            Allow("maas.client|maas.client.**"),
            # Matches the name `pytz` exactly; nothing beneath it.
            Allow("pytz"),
            Allow("terminaltables|terminaltables.**"),
            Allow("yaml|yaml.**"),
            Allow(StandardLibraries),
        ),
    ),
    (
        # All other non-test code beneath maas/.
        CodeOther,
        Rule(
            Allow("aiohttp|aiohttp.**"),
            Allow("maas.client|maas.client.**"),
            Allow("macaroonbakery|macaroonbakery.**"),
            # The next three match only these exact names.
            Allow("oauthlib.oauth1"),
            Allow("pytz"),
            Allow("bson"),
            Allow(StandardLibraries),
            # Don't let non-flesh code import flesh. A Deny takes precedence
            # over the broader `maas.client.**` Allow above.
            Deny("maas.client.flesh|maas.client.flesh.**"),
        ),
    ),
    (
        # setup.py: only os.path and setuptools.
        Setup,
        Rule(
            Allow("os.path|os.path.**"),
            Allow("setuptools|setuptools.**"),
        ),
    ),
]


def extract(node, modinfo):
    """Yield the dotted name of everything imported beneath an AST node.

    ``import a.b`` yields ``a.b`` and ``from a import b`` yields ``a.b``.
    Relative imports are resolved against `modinfo` into absolute names.

    :param node: The AST node to search, e.g. a parsed module.
    :param modinfo: An object with a dotted ``name`` attribute and an
        ``isPackage()`` method, describing the module `node` was parsed from.
    """
    for found in ast.walk(node):
        if isinstance(found, ast.Import):
            yield from (alias.name for alias in found.names)
        elif isinstance(found, ast.ImportFrom):
            if found.level == 0:
                # Absolute import: the source module is named in full.
                origin = found.module
            else:
                # Relative import: climb `level` steps up from this module. A
                # package counts as its own first step, so climbs one fewer.
                climb = found.level
                if modinfo.isPackage():
                    climb -= 1
                origin = modinfo.name
                if climb != 0:
                    origin = ".".join(origin.split(".")[:-climb])
                if found.module is not None:
                    # e.g. from ..some.module import some_name
                    origin = "%s.%s" % (origin, found.module)
            for alias in found.names:
                yield "%s.%s" % (origin, alias.name)


def _batch(objects, size):
    """Generate batches of `size` elements from `objects`.

    Each batch is a list of `size` elements exactly, except for the last which
    may contain fewer than `size` elements.
    """
    objects = iter(objects)
    batch = lambda: list(islice(objects, size))
    return iter(batch, [])


def _expand(checks):
    """Generate `(rule, batch-of-filenames)` for the given checks.

    It batches filenames to reduce the serialise/unserialise overhead when
    calling out to a pooled process.
    """
    for filenames, rule in checks:
        rule.compile()  # Compile or it's slow.
        for filenames in _batch(filenames, 100):
            yield rule, filenames


def _scan1(rule, filename, modinfo):
    """Check the imports of a single file against `rule`.

    :return: A ``(filename, allowed, denied)`` tuple in which `allowed` and
        `denied` are sets of imported names.
    """
    with tokenize.open(filename) as source:
        tree = ast.parse(source.read())
    allowed, denied = set(), set()
    # Deduplicate first so that each name is checked only once.
    for name in set(extract(tree, modinfo)):
        (allowed if rule.check(name) else denied).add(name)
    return filename, allowed, denied


# Maps the absolute file name of each module found by walking the system path
# to its `t.p.modules.PythonModule` instance. Every process in the
# multiprocessing pool shares this, so it has to be built ahead of time.
_packages = {
    modinfo.filePath.path: modinfo
    for modinfo in theSystemPath.walkModules(importPackages=False)
}


def _scan(rule, filenames, packages=_packages):
    """Check each file in `filenames` against `rule`.

    :param packages: A mapping from absolute file name to module information;
        each file name is looked up here after being made absolute.
    :return: A list with one `_scan1` result per file name.
    """
    results = []
    for filename in filenames:
        modinfo = packages[abspath(filename)]
        results.append(_scan1(rule, filename, modinfo))
    return results


def scan(checks):
    """Check many files against their rules using a pool of processes.

    Yields one ``(filename, allowed, denied)`` tuple per scanned file.
    """
    with multiprocessing.Pool() as pool:
        # `starmap` returns only once every batch has been processed.
        batches = pool.starmap(_scan, _expand(checks))
        yield from chain.from_iterable(batches)


# The reporting helpers are chosen once, when this module is loaded: plain
# text when standard output is not a terminal, and text decorated with ANSI
# colour escape sequences when it is.
if not sys.stdout.isatty():

    def print_filename(filename):
        """Print the name of a scanned file."""
        print(filename)

    def print_allowed(name):
        """Print an imported name that was allowed."""
        print(" allowed:", name)

    def print_denied(name):
        """Print an imported name that was denied."""
        print("  denied:", name)

else:

    def print_filename(filename):
        """Print the name of a scanned file, in cyan."""
        print('\x1b[36m' + filename + '\x1b[39m')

    def print_allowed(name):
        """Print an imported name that was allowed, labelled in green."""
        print(" \x1b[32mallowed:\x1b[39m", name)

    def print_denied(name):
        """Print an imported name that was denied, labelled in red."""
        print("  \x1b[31mdenied:\x1b[39m", name)


def main(args):
    """Check the imports of all configured files and report the results.

    :param args: Command-line arguments, without the program name.
    :return: ``True`` if no imported name was denied, ``False`` otherwise.
    """
    # The automatic help option is disabled and then re-added by hand so that
    # -h/--help still works but is left out of the list of options.
    parser = argparse.ArgumentParser(
        description="Statically check imports against policy.",
        add_help=False)
    parser.add_argument(
        "--show-allowed", action="store_true", dest="show_allowed",
        help="Log allowed imports.")
    parser.add_argument(
        "--hide-denied", action="store_false", dest="show_denied",
        help="Do not log denied imports.")
    parser.add_argument(
        "--hide-summary", action="store_false", dest="show_summary",
        help="Do not show the summary of allowed and denied imports.")
    parser.add_argument(
        "-h", "--help", action="help", help=argparse.SUPPRESS)
    options = parser.parse_args(args)

    allowedcount, deniedcount = 0, 0
    for filename, allowed, denied in scan(checks):
        allowedcount += len(allowed)
        deniedcount += len(denied)
        report_allowed = options.show_allowed and len(allowed) != 0
        report_denied = options.show_denied and len(denied) != 0
        # Name the file only when something will be listed beneath it.
        if report_allowed or report_denied:
            print_filename(filename)
        if report_allowed:
            for imported_name in sorted(allowed):
                print_allowed(imported_name)
        if report_denied:
            for imported_name in sorted(denied):
                print_denied(imported_name)

    if options.show_summary:
        # Put a blank line between any per-file output and the summary.
        printed_details = (
            (options.show_allowed and allowedcount != 0) or
            (options.show_denied and deniedcount != 0))
        if printed_details:
            print()

        print(allowedcount, "imported names were ALLOWED.")
        print(deniedcount, "imported names were DENIED.")

    return deniedcount == 0


if __name__ == '__main__':
    # Exit with status 0 when nothing was denied, and status 1 otherwise.
    succeeded = main(sys.argv[1:])
    raise SystemExit(0 if succeeded else 1)
