#!/usr/bin/env python3
"""
Generate Debian copyright files from vendored dependencies.

This script scans vendor directories, extracts license and copyright information
using licensecheck and SPDX identifiers, then generates a machine-readable
debian/copyright file following DEP-5 format.

The script:
1. Runs licensecheck on vendor directories to extract licenses and copyrights
2. Looks for SPDX-License-Identifier tags in source files
3. Builds a directory tree and merges nodes with identical license info
4. Generates a DEP-5 format debian/copyright file

Requirements:
    - python3-anytree
    - licensecheck (from devscripts package)
"""

import os
import sys
import re
import subprocess
import argparse
import glob
from collections import defaultdict
from typing import Dict, List, Set, Optional, Tuple
import logging

logger = logging.getLogger(__name__)

try:
    from anytree import NodeMixin, PreOrderIter, PostOrderIter
except ImportError:
    logger.error("Error: python3-anytree is required")
    sys.exit(1)

# Constants
# Lower-case filename prefixes that mark a file as a license file; matched with
# str.startswith() against the lower-cased name in
# VendorLicenseScanner.is_license_file().
LICENSE_FILE_PATTERNS = ['license', 'copying', 'copyright', 'notice', 'unlicense']
# NOTE(review): not referenced in the visible part of this file; presumably
# filename prefixes that should be skipped when scanning - confirm before use.
IGNORE_FILE_PATTERNS = ['readme', 'citation.cff']
# Copyright strings longer than this are discarded by LicenseInfo.add_copyright()
# as probable false positives.
MAX_COPYRIGHT_LENGTH = 100
# Inclusive bounds a four-digit number must fall within to be accepted as a
# copyright year by CopyrightHolder.extract_years().
MIN_YEAR = 1900
MAX_YEAR = 2100
# Scan first 20 lines for SPDX identifier (typically in file header)
SPDX_SCAN_LINES = 20
# Lines starting with "Copyright" are removed from a license text only when they
# appear within this many leading lines (see read_license_file()).
LICENSE_COPYRIGHT_SCAN_LINES = 5

# Placeholder header written by CopyrightFileGenerator when no header could be
# preserved from an existing copyright file. The <REPLACE ...> fields are meant
# to be filled in by hand.
COPYRIGHT_HEADER = """Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/
Upstream-Name: <REPLACE WITH PACKAGE NAME>
Upstream-Contact: <REPLACE WITH UPSTREAM CONTACT>
Source: <REPLACE WITH UPSTREAM SOURCE URL>

Files: *
Copyright: <REPLACE WITH COPYRIGHT HOLDER(S)>
License: <REPLACE WITH LICENSE>
"""


class CopyrightHolder:
    """A copyright holder split into its years, display name and optional email.

    Instances are formatted with ``str()`` for the ``Copyright:`` field of the
    generated file. The static helpers parse raw copyright text into those
    three components and produce a normalized key used to recognise the same
    holder written in slightly different ways.
    """

    def __init__(self, years: List[int], holder_name: str, email: Optional[str] = None):
        """Store de-duplicated, ascending ``years`` plus the holder's name and email."""
        self.years = sorted(set(years))
        self.holder_name = holder_name
        self.email = email

    def __str__(self) -> str:
        """Format as 'YYYY-YYYY Name <email>', 'YYYY Name' or just 'Name'.

        Only the earliest and latest year are shown; gaps in between are not
        represented in the output.
        """
        name_part = self.holder_name
        if self.email:
            # Trailing dots/spaces would otherwise end up right before '<email>'.
            name_part = f"{self.holder_name.strip('. ')} <{self.email}>"

        if not self.years:
            return name_part
        min_year, max_year = min(self.years), max(self.years)
        if min_year == max_year:
            return f"{min_year} {name_part}"
        return f"{min_year}-{max_year} {name_part}"

    @staticmethod
    def better_name(name1: str, name2: str) -> str:
        """Return the better-looking of two spellings of the same holder.

        A spelling wins when it has more words containing both upper- and
        lower-case letters (e.g. "Matsumoto" beats "MATSUMOTO"); ties are
        broken by length, and a full tie keeps ``name1``. An empty name always
        loses to a non-empty one.
        """
        if not name1:
            return name2
        if not name2:
            return name1

        def score_name(name: str) -> Tuple[int, int]:
            # (number of words mixing upper and lower case, total length):
            # natural capitalization first, then the longer (more complete) name.
            mixed_case_words = sum(
                1
                for word in name.split()
                if any(c.isupper() for c in word) and any(c.islower() for c in word)
            )
            return (mixed_case_words, len(name))

        return name1 if score_name(name1) >= score_name(name2) else name2

    @staticmethod
    def extract_years(text: str) -> List[int]:
        """Extract all years mentioned in copyright text, sorted and unique.

        Ranges ("2010-2014") are expanded to every year they cover. A range
        written backwards ("2014-2010") is treated as the same span instead of
        being dropped. Ranges with an end outside MIN_YEAR..MAX_YEAR are
        ignored. Stand-alone years must start with 19 or 20.
        """
        range_pattern = r'\b(\d{4})\s*-\s*(\d{4})\b'
        years = set()

        for match in re.finditer(range_pattern, text):
            first, second = int(match.group(1)), int(match.group(2))
            if MIN_YEAR <= first <= MAX_YEAR and MIN_YEAR <= second <= MAX_YEAR:
                # min/max so that a reversed range still contributes its years.
                years.update(range(min(first, second), max(first, second) + 1))

        # Remove the ranges so their endpoints are not counted a second time
        # (and rejected ranges are not re-read as single years).
        text_without_ranges = re.sub(range_pattern, '', text)
        for match in re.finditer(r'\b((?:19|20)\d{2})\b', text_without_ranges):
            year = int(match.group(1))
            if MIN_YEAR <= year <= MAX_YEAR:
                years.add(year)

        return sorted(years)

    @staticmethod
    def extract_name_and_email(text: str) -> Tuple[str, Optional[str]]:
        """Split copyright text into ``(name, email)`` with the years removed.

        An address in angle brackets is preferred; otherwise the first bare
        ``user@host.tld`` token is used. ``email`` is None when neither is
        present. Separators (commas, dashes, whitespace) left at either end of
        the name after the years were removed are stripped completely, so
        "2010, 2012, 2014 Foo" yields "Foo".
        """
        email = None
        email_match = re.search(r'<([^>]+@[^>]+)>', text)
        if not email_match:
            # Fall back to an address without angle brackets.
            email_match = re.search(r'\b([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})\b', text)
        if email_match:
            email = email_match.group(1).strip()
            # Cut the address (including brackets, if any) out of the text.
            text = text[:email_match.start()] + text[email_match.end():]

        name = re.sub(r'\b\d{4}\s*-\s*\d{4}\b', '', text)
        name = re.sub(r'\b(?:19|20)\d{2}\b', '', name)
        # Strip *runs* of separators: removing several comma-separated years
        # leaves more than one comma behind.
        name = re.sub(r'[\s,\-]+$', '', name)
        name = re.sub(r'^[\s,\-]+', '', name)
        return name.strip(), email

    @staticmethod
    def normalize_name(holder_name: str) -> str:
        """Return a case-insensitive comparison key for a holder name.

        Drops a leading "the", trailing periods, commas and dashes, collapses
        whitespace and lower-cases the result. The email is expected to have
        been removed already (see extract_name_and_email).
        """
        normalized = re.sub(r'^\s*the\s+', '', holder_name, flags=re.IGNORECASE)
        # Trim first so a trailing "Inc. " still loses its period.
        normalized = normalized.strip().rstrip('.')
        normalized = re.sub(r'[,\-]+', ' ', normalized)
        normalized = re.sub(r'\s+', ' ', normalized)
        return normalized.strip().lower()


class LicenseInfo:
    """Licenses and copyright holders recorded for one path of the vendor tree.

    Holders are stored per normalized name so that different spellings of the
    same holder collapse into a single entry whose years are merged.
    An instance is falsy when it carries neither licenses nor holders.
    """

    def __init__(self, path: str, licenses: Set[str], copyrights: Optional[Set[str]] = None):
        """Remember ``path`` and ``licenses`` and parse each raw copyright string."""
        self.file_path = path
        self.licenses = licenses
        # Key: CopyrightHolder.normalize_name(name); value: the CopyrightHolder.
        self.copyright_holders: Dict[str, CopyrightHolder] = {}

        for raw_copyright in copyrights or ():
            self.add_copyright(raw_copyright)

    def __repr__(self) -> str:
        """Summarise whichever of licenses / normalized holder keys are present."""
        if self.licenses and self.copyright_holders:
            return f"Licenses: {self.licenses}, Copyrights: {[str(k) for k in self.copyright_holders.keys()]}"
        if self.licenses:
            return f"Licenses: {self.licenses}"
        if self.copyright_holders:
            return f"Copyrights: {[str(k) for k in self.copyright_holders.keys()]}"
        return "No licenses or copyrights"

    def __bool__(self) -> bool:
        """True when at least one license or one copyright holder is known."""
        if self.licenses:
            return True
        return bool(self.copyright_holders)

    @property
    def license(self) -> str:
        """All licenses sorted and joined with ' or ', or "UNKNOWN" if there are none."""
        if not self.licenses:
            return "UNKNOWN"
        return ' or '.join(sorted(self.licenses))

    def add_copyright(self, copyright_text: str):
        """Parse one raw copyright string and record (or merge) its holder.

        The text is ignored when it is empty, contains non-printable
        characters, exceeds MAX_COPYRIGHT_LENGTH, consists of nothing but a
        four-digit number, or yields no holder name once years and email have
        been removed.
        """
        if not copyright_text:
            return

        # Non-printable characters usually mean the text came out of a binary file.
        if not copyright_text.isprintable():
            logger.warning(
                f"Discarding copyright of {self.file_path} with non-printable characters: {repr(copyright_text)}")
            return

        # Overlong matches are most likely false positives.
        if len(copyright_text) > MAX_COPYRIGHT_LENGTH:
            logger.warning(
                f"Discarding copyright of {self.file_path} that is longer than {MAX_COPYRIGHT_LENGTH} chars: {repr(copyright_text)}")
            return

        cleaned = clean_copyright(copyright_text)
        if not cleaned:
            return
        bare = cleaned.strip()
        if len(bare) == 4 and bare.isdigit():
            # A lone year names no holder.
            return

        years = CopyrightHolder.extract_years(cleaned)
        holder_name, email = CopyrightHolder.extract_name_and_email(cleaned)
        if not holder_name:
            return

        key = CopyrightHolder.normalize_name(holder_name)
        if key not in self.copyright_holders:
            self.copyright_holders[key] = CopyrightHolder(years, holder_name, email)
            return

        # Same holder seen before: union the years, keep the nicer spelling,
        # and fill in the email only if none was known yet.
        known = self.copyright_holders[key]
        known.years = sorted(set(known.years + years))
        known.holder_name = CopyrightHolder.better_name(known.holder_name, holder_name)
        if email and not known.email:
            known.email = email

    def get_copyright_strings(self) -> List[str]:
        """Return the formatted holders, sorted alphabetically, ready for output."""
        return sorted(str(holder) for holder in self.copyright_holders.values())


class VendorNode(NodeMixin):
    """A file or directory of the vendor tree; tree links come from anytree's NodeMixin."""

    def __init__(self, file_path: str, parent: Optional['VendorNode'] = None):
        """Create a node for ``file_path`` and attach it below ``parent`` (if any).

        The node starts with empty license information and with both the
        ``is_dir`` and ``merged`` flags cleared.
        """
        super().__init__()
        self.file_path = file_path
        # Flags are flipped later by the scanner: is_dir once a child is
        # created below this node, merged once children were folded into it.
        self.is_dir = False
        self.merged = False
        self.parent: Optional[VendorNode] = parent
        self.license_info: LicenseInfo = LicenseInfo(file_path, set())

    def __repr__(self) -> str:
        """Show the path followed by the node's license information."""
        return f"{self.file_path} - {self.license_info}"


class VendorLicenseScanner:
    """Scans vendor directories and generates Debian copyright file."""

    SPDX_PATTERN = re.compile(
        r'SPDX-License-Identifier:\s*([A-Za-z0-9.\-+]+(?:\s+(?:AND|OR|WITH)\s+[A-Za-z0-9.\-+]+)*)',
        re.IGNORECASE
    )

    def __init__(self):
        self.roots: List[VendorNode] = []  # Root nodes for each vendor directory
        self.license_texts: Dict[str, str] = {}
        self.licensecheck_cache: Dict[str, Tuple[Optional[str], Optional[str]]] = {}

    def __enter__(self):
        """Context manager entry.

        Returns the scanner itself so it can be bound with ``with ... as scanner``;
        nothing is acquired here.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - cleanup resources."""
        self.licensecheck_cache.clear()
        self.license_texts.clear()
        return False

    @staticmethod
    def is_license_file(filename: str) -> bool:
        """Tell whether ``filename`` looks like a license file.

        The comparison is case-insensitive and matches by prefix against
        LICENSE_FILE_PATTERNS, so callers are expected to pass a bare file
        name rather than a path.
        """
        lowered = filename.lower()
        # str.startswith accepts a tuple of candidate prefixes.
        return lowered.startswith(tuple(LICENSE_FILE_PATTERNS))

    def find_node(self, file_path: str) -> Optional[VendorNode]:
        """Find a node by file_path in the tree."""
        for root in self.roots:
            for node in PreOrderIter(root):
                if node.file_path == file_path:
                    return node
        return None

    def get_all_nodes(self) -> List[VendorNode]:
        """Get all nodes from all trees."""
        all_nodes = []
        for root in self.roots:
            all_nodes.extend(PreOrderIter(root))
        return all_nodes

    def get_or_create_node(self, path: str) -> VendorNode:
        """Return the node for ``path``, creating it (and missing ancestors) when absent."""
        node = self.find_node(path)
        if not node:
            node = self.create_node(path)
        return node

    def create_node(self, path: str) -> VendorNode:
        """Create a new node for ``path`` and hook it into the tree.

        The parent is looked up (or created, recursively) from the directory
        part of ``path``. A node whose path has no usable directory part
        becomes a new root. No check is made for an existing node with the
        same path; use get_or_create_node() for that.
        """
        logger.debug(f"Creating node for {path}...")

        parent_path = os.path.dirname(path)
        # dirname() returns '' for a bare name and the path itself for '/'.
        has_parent = bool(parent_path) and parent_path != path
        parent_node = self.get_or_create_node(parent_path) if has_parent else None

        node = VendorNode(path, parent=parent_node)

        if parent_node is None:
            # No parent: this node starts a tree of its own.
            self.roots.append(node)
        elif parent_node:
            # Anything that has a child is a directory.
            parent_node.is_dir = True

        return node

    def delete_nodes(self, nodes: List[VendorNode]):
        """Remove ``nodes`` from the tree, handing their children to the grandparent.

        Children of a removed node are not deleted: each one is re-attached to
        the removed node's parent. A removed node that is listed in
        ``self.roots`` is dropped from that list as well.
        """
        for victim in nodes:
            logger.debug(f"    Removing node {victim.file_path}")

            # Move the children one level up before the node disappears.
            for orphan in victim.children:
                orphan.parent = victim.parent

            if victim in self.roots:
                logger.debug(f"      Removing root node {victim.file_path}")
                self.roots.remove(victim)

            # Clearing the parent detaches the node from the tree.
            victim.parent = None


    def run_licensecheck(self, dir_path: str):
        """Run licensecheck recursively on ``dir_path`` and cache its findings.

        Each output line is expected in the tab-separated machine format
        ``filename<TAB>license<TAB>copyright``. For every file the cache
        receives a ``(license, copyright)`` tuple in which "UNKNOWN" licenses
        and "*No copyright*" entries are replaced by None, licenses are passed
        through normalize_license() and copyrights through clean_copyright().

        Raises:
            subprocess.CalledProcessError: licensecheck exited with a non-zero
                status (details are logged before re-raising).
            SystemExit: the licensecheck executable is not installed.
        """
        logger.info(f"  Running licensecheck on {dir_path}...")

        command = [
            'licensecheck',
            '--machine',
            '--copyright',
            '--shortname-scheme=debian,spdx',
            '--recursive',
            dir_path,
        ]
        try:
            result = subprocess.run(
                command,
                capture_output=True,
                text=True,
                # Decode explicitly and tolerantly: copyright strings lifted
                # from binary files may contain bytes that are not valid
                # UTF-8, and a strict (locale-dependent) decode would abort
                # the whole scan with UnicodeDecodeError.
                encoding='utf-8',
                errors='replace',
                check=True  # Raises CalledProcessError on non-zero exit
            )
        except subprocess.CalledProcessError as e:
            logger.error(f"Error: licensecheck failed for {dir_path}")
            logger.error(f"  Exit code: {e.returncode}")
            if e.stderr:
                logger.error(f"  Error output: {e.stderr}")
            raise
        except FileNotFoundError:
            logger.error("Error: licensecheck not found - please install the devscripts package")
            sys.exit(1)

        # Split on '\n' only: str.splitlines() would also break on control
        # characters that can occur inside garbage copyright strings.
        for line in result.stdout.split('\n'):
            line = line.strip()
            if not line:
                continue

            # Parse tab-separated format: filename\tlicense\tcopyright
            parts = line.split('\t')
            if len(parts) < 2:
                continue

            file_path = parts[0].strip()
            license_id = parts[1].strip()
            copyright_text = parts[2].strip() if len(parts) > 2 else None

            # licensecheck's placeholders carry no information.
            if license_id == 'UNKNOWN':
                license_id = None
            if copyright_text == '*No copyright*':
                copyright_text = None

            if license_id:
                license_id = normalize_license(license_id)
            if copyright_text:
                copyright_text = clean_copyright(copyright_text)

            self.licensecheck_cache[file_path] = (license_id, copyright_text)

        # The count is the size of the whole cache, not only of this run.
        logger.info(f"  Cached {len(self.licensecheck_cache)} files")

    def get_spdx_licenses(self, file_path: str) -> Set[str]:
        """Return the individual licenses named by the file's SPDX tag (empty set if none)."""
        return split_license_str(self.find_spdx_identifier(file_path))

    def find_spdx_identifier(self, file_path: str) -> Optional[str]:
        """Return the normalized SPDX expression from the head of a file, or None.

        Only the first SPDX_SCAN_LINES lines are inspected and the first match
        wins. Undecodable bytes are ignored. Any error while reading is logged
        as a warning and treated as "no identifier found".
        """
        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as handle:
                # zip() with a range caps the number of lines that are examined.
                for _, header_line in zip(range(SPDX_SCAN_LINES), handle):
                    found = self.SPDX_PATTERN.search(header_line)
                    if found:
                        return normalize_license(found.group(1).strip())
        except Exception as e:
            logger.warning(f"Warning: Could not read SPDX identifier from {file_path}: {e}")
        return None

    def get_licensecheck_licenses_and_copyrights(self, path: str) -> Tuple[Set[str], Set[str]]:
        """Look up ``path`` in the licensecheck cache and split the cached strings.

        Returns a ``(licenses, copyrights)`` pair of sets; both are empty when
        the path is not cached or nothing was recorded for it.
        """
        cached_license, cached_copyright = self.licensecheck_cache.get(path, (None, None))

        licenses = {normalize_license(part) for part in split_license_str(cached_license)}
        copyrights = set(split_copyright_str(cached_copyright))

        logger.debug(f"licensecheck: {path} - licenses: {licenses}, copyrights: {copyrights}")
        return licenses, copyrights

    def process_license_files(self, license_files: List[str]) -> Tuple[Set[str], Set[str]]:
        """Process LICENSE files to determine directory's license and copyright."""
        all_licenses = set()
        all_copyrights = set()

        for lic_file in license_files:
            licenses, copyrights = self.get_licensecheck_licenses_and_copyrights(lic_file)
            all_copyrights.update(copyrights)
            for l in licenses:
                self.read_license_file(lic_file, l)
                all_licenses.add(l)

        return all_licenses, all_copyrights

    def read_license_file(self, license_file: str, license_id: str):
        """Read full license text."""

        # Skip if we already have a license text for this license
        if license_id in self.license_texts:
            return

        try:
            with open(license_file, 'r', encoding='utf-8', errors='ignore') as f:
                text = f.read()
                # Remove copyright lines from first few lines
                lines = text.splitlines()
                cleaned = []
                for i, line in enumerate(lines):
                    if i < LICENSE_COPYRIGHT_SCAN_LINES and re.match(r'^\s*Copyright\s+', line, re.IGNORECASE):
                        continue
                    cleaned.append(line)
                self.license_texts[license_id] = '\n'.join(cleaned)
        except Exception as e:
            logger.warning(f"Warning: Could not read license file {license_file}: {e}")


    def scan_directory(self, dir_path: str):
        """Run licensecheck on ``dir_path`` and build the node tree from its results.

        Files are grouped by their containing directory. For each directory a
        node is created whose license information comes from the license
        files found directly in it; every other file gets a node of its own
        with the licensecheck result, where an SPDX tag in the file replaces
        the licensecheck license.

        Note that all entries of the licensecheck cache are processed, not
        only those added by this call.
        """
        logger.info(f"Scanning {dir_path}...")
        self.run_licensecheck(dir_path)

        # Group files by directory to process them together
        files_by_dir = defaultdict(list)
        for file_path in self.licensecheck_cache:
            files_by_dir[os.path.dirname(file_path)].append(file_path)

        # Sorted order guarantees a directory is handled before anything below
        # it (a parent path is a strict prefix of its descendants).
        for dir_count, (current_dir, file_paths) in enumerate(sorted(files_by_dir.items()), start=1):
            if dir_count % 100 == 0:
                logger.info(f"  Processed {dir_count} directories...")

            dir_node = self.create_node(current_dir)
            # create_node() only flags a node as a directory once a child is
            # created below it. A directory that holds nothing but license
            # files would stay unflagged and later be written as a plain file
            # pattern instead of "<dir>/*", so flag it explicitly.
            dir_node.is_dir = True

            # Separate LICENSE files from other files (original order kept).
            license_files = []
            other_files = []
            for file_path in file_paths:
                if self.is_license_file(os.path.basename(file_path)):
                    license_files.append(file_path)
                else:
                    other_files.append(file_path)

            # License files define the directory's own license and copyright.
            licenses, copyrights = self.process_license_files(license_files)
            dir_node.license_info = LicenseInfo(current_dir, licenses, copyrights)
            logger.debug(str(dir_node))

            # Every other file becomes a node with its own findings.
            for file_path in other_files:
                file_node = self.create_node(file_path)
                licenses, copyrights = self.get_licensecheck_licenses_and_copyrights(file_path)

                # Check for SPDX identifier which overrides licensecheck license
                # (licensecheck doesn't parse SPDX identifiers unfortunately)
                spdx_licenses = self.get_spdx_licenses(file_path)
                if spdx_licenses:
                    licenses = spdx_licenses

                file_node.license_info = LicenseInfo(file_path, licenses, copyrights)

    def merge_nodes_in_all_dirs(self):
        """Walk tree bottom-up and merge nodes where children match."""
        max_iterations = 1000  # Safety limit to prevent infinite loops
        iteration = 0

        while iteration < max_iterations:
            iteration += 1
            merged_any = False

            # Iterate over all roots
            for root in self.roots:
                merged = self.merge_nodes_in_tree(root)
                if merged:
                    merged_any = True

            if not merged_any:
                break

        if iteration == max_iterations:
            logger.warning(f"Merge loop hit maximum iterations ({max_iterations})")

    def merge_nodes_in_tree(self, root: VendorNode) -> bool:
        """Walk one tree bottom-up and fold matching children into their parents.

        Algorithm:
        1. Visit nodes in post-order (children before their parent).
        2. Group each node's children by (licenses, normalized copyright holders).
        3. Largest group first, ask should_merge() whether the group can be
           folded into the node; if so, merge_with_parent() does it.
        4. A group that is not merged inherits the node's licenses when it has
           none of its own, and likewise the node's copyright holders.

        Returns:
            True if any merges were performed
        """
        logger.debug(f"Merging nodes in {root}")
        merged_any = False
        # Formatting every group is costly on large trees; only do it when the
        # message would actually be emitted.
        debug_enabled = logger.isEnabledFor(logging.DEBUG)

        for node in PostOrderIter(root):
            if not node.children:
                continue

            # (frozenset of licenses, frozenset of holder keys) -> children
            groups = defaultdict(list)
            for child in node.children:
                key = (
                    frozenset(child.license_info.licenses),
                    frozenset(child.license_info.copyright_holders),
                )
                groups[key].append(child)

            if not groups:
                continue

            # Check if any groups can be merged
            for group in sorted(groups.values(), key=len, reverse=True):
                if debug_enabled:
                    # Joined outside the f-string: a backslash or a nested
                    # same-type quote inside a replacement field is a syntax
                    # error before Python 3.12.
                    group_lines = "\n  ".join(str(g) for g in group)
                    logger.debug(f"Checking if we can merge\n  {node}\n with:\n  {group_lines}")

                if self.should_merge(group, node):
                    self.merge_with_parent(group, node)
                    merged_any = True
                    continue

                # All members share the same key, so the first one is
                # representative for "has no licenses / holders".
                if not group[0].license_info.licenses:
                    # Inherit license from parent (the set object is shared,
                    # not copied).
                    for member in group:
                        member.license_info.licenses = node.license_info.licenses
                if not group[0].license_info.copyright_holders:
                    # Inherit copyright from parent (the dict object is shared,
                    # not copied).
                    for member in group:
                        member.license_info.copyright_holders = node.license_info.copyright_holders

        return merged_any

    def merge_with_parent(self, group: List[VendorNode], parent: VendorNode):
        """Fold the license information of ``group`` into ``parent`` and remove the group.

        The parent receives a fresh LicenseInfo holding the union of its own
        and the children's licenses and copyright holders (holders are
        copied, never shared). The parent is flagged as merged and the group
        nodes are deleted from the tree.
        """
        combined = LicenseInfo(parent.file_path, parent.license_info.licenses.copy())

        def absorb(holders):
            # Add holders to the combined info, merging with entries of the same key.
            for key, holder in holders.items():
                if key not in combined.copyright_holders:
                    combined.copyright_holders[key] = CopyrightHolder(
                        holder.years.copy(), holder.holder_name, holder.email
                    )
                    continue
                known = combined.copyright_holders[key]
                known.years = sorted(set(known.years + holder.years))
                # Keep the nicer spelling; only fill in a missing email.
                known.holder_name = CopyrightHolder.better_name(known.holder_name, holder.holder_name)
                if holder.email and not known.email:
                    known.email = holder.email

        # The parent's own holders come first, then each child's.
        absorb(parent.license_info.copyright_holders)
        for child in group:
            combined.licenses.update(child.license_info.licenses)
            absorb(child.license_info.copyright_holders)

        parent.license_info = combined
        parent.merged = True

        self.delete_nodes(group)

    @staticmethod
    def should_merge(group: List[VendorNode], parent: VendorNode) -> bool:
        """Decide whether a group of sibling nodes may be folded into ``parent``.

        Only the first member of ``group`` is inspected; the caller is expected
        to pass nodes that share the same licenses and copyright holders.
        The checks run in this order and the first applicable one decides:

        1. Group has neither license nor copyright: merge.
        2. Parent has neither: merge only if the group is more than half of
           the parent's children.
        3. Parent has a license but no copyright holders: merge if the group
           has no license or exactly the parent's licenses.
        4. Parent has both: merge if the group has no license and the same
           holders, no holders and the same licenses, or both match.

        Each outcome is reported on the debug log (green = merge, yellow = skip).

        Returns:
            True if the group should be merged into the parent.
        """
        parent_info = parent.license_info
        group_info = group[0].license_info

        # LicenseInfo is falsy when it has neither licenses nor holders.
        if not group_info:
            # Group has no license or copyright - should merge
            logger.debug(green("  Merging because children have no license or copyright"))
            return True

        if not parent_info:
            # Group has license or copyright but parent does not - only merge
            # if the group contains more than half of the parent's children
            if len(group) > len(parent.children) // 2:
                logger.debug(green(f"  Merging because parent has no license or copyright and group contains more than half of its children ({len(group)}/{len(parent.children)})"))
                return True
            logger.debug(yellow("  Skipping merge - parent has no license or copyright"))
            return False

        if not parent_info.copyright_holders:
            # Parent has license but no copyright - should merge if license
            # matches or the children have no license
            if not group_info.licenses:
                logger.debug(green("  Merging because children have no license and parent has no copyright"))
                return True
            if parent_info.licenses == group_info.licenses:
                logger.debug(green("  Merging because parent and children have same license and parent has no copyright"))
                return True
            logger.debug(yellow("  Skipping merge - parent has license but does not match children"))
            return False

        # Parent has license and copyright
        # Holders are compared by their normalized-name keys only.
        parent_holders = frozenset(parent_info.copyright_holders.keys())
        child_holders = frozenset(group_info.copyright_holders.keys())
        parent_licenses = frozenset(parent_info.licenses)
        child_licenses = frozenset(group_info.licenses)

        if not child_licenses and child_holders == parent_holders:
            logger.debug(green("  Merging because children have no license and parent copyright matches children"))
            return True

        if not child_holders and child_licenses == parent_licenses:
            logger.debug(green("  Merging because children have no copyright and parent license matches children"))
            return True

        if parent_licenses == child_licenses and parent_holders == child_holders:
            logger.debug(green("  Merging because parent has license and copyright and both match children"))
            return True

        logger.debug(yellow("  Skipping merge - parent has license and copyright but do not match children"))
        return False


class CopyrightFileGenerator:
    """Handles reading and writing Debian copyright files."""

    def __init__(self, output_path: str):
        self.output_path = output_path
        self.existing_header: List[str] = []
        self.existing_license_texts: Dict[str, str] = {}

    def read_existing_copyright(self):
        """Read existing copyright file to preserve header and license texts."""
        if not os.path.exists(self.output_path):
            return
        logger.info(f"Reading existing copyright from {self.output_path}...")
        try:
            with open(self.output_path, 'r', encoding='utf-8') as f:
                lines = f.read().splitlines()

            # Find markers
            vendored_idx = next((i for i, l in enumerate(lines) if l.strip() == '# Vendored files'), None)
            license_idx = next((i for i, l in enumerate(lines) if l.strip() == '# License texts'), None)

            if vendored_idx is not None:
                self.existing_header = lines[:vendored_idx]
                while self.existing_header and not self.existing_header[-1].strip():
                    self.existing_header.pop()

            if license_idx is not None:
                # Parse license texts
                current_lic = None
                current_lines = []
                for i in range(license_idx + 1, len(lines)):
                    line = lines[i]
                    if line.startswith('License: '):
                        if current_lic and current_lines:
                            self.existing_license_texts[current_lic] = '\n'.join(current_lines)
                        current_lic = line[9:].strip()
                        current_lines = []
                    elif current_lic:
                        current_lines.append(line)

                if current_lic and current_lines:
                    self.existing_license_texts[current_lic] = '\n'.join(current_lines)

            logger.info(f"  Preserved {len(self.existing_header)} header lines, {len(self.existing_license_texts)} license texts")
        except Exception as e:
            logger.warning(f"Warning: Could not read existing copyright: {e}")

    def generate_copyright_file(self, scanners: List[VendorLicenseScanner]):
        """Write the copyright file from the trees of all ``scanners``.

        Layout of the output:
        1. The preserved header, or the COPYRIGHT_HEADER template, followed by
           a blank line.
        2. '# Vendored files': one paragraph per directory node
           (``<dir>/*``) and one per group of sibling files with identical
           license information, sorted by (first) path. A Copyright field is
           written only when holders are known.
        3. '# License texts': one paragraph per license id. Texts preserved
           from the existing file are written verbatim and take precedence
           over texts collected by the scanners, which are indented by one
           space with blank lines turned into ' .'.
        """
        logger.info(f"Generating {self.output_path}...")

        def relative(path: str) -> str:
            # Drop a leading './' so patterns start at the source root.
            return path[2:] if path.startswith('./') else path

        lines = []

        # Header
        if self.existing_header:
            lines.extend(self.existing_header)
        else:
            lines.extend(COPYRIGHT_HEADER.splitlines())
        # Blank line in both cases: read_existing_copyright() strips trailing
        # blanks from the header, so without it the template variant would be
        # laid out differently on the first run than on every later run.
        lines.append('')

        lines.append('# Vendored files')
        lines.append('')

        # Collect all nodes with license info from all scanners
        all_nodes_with_licenses = []
        all_license_texts = {}
        for scanner in scanners:
            all_nodes_with_licenses.extend(n for n in scanner.get_all_nodes() if n.license_info)
            all_license_texts.update(scanner.license_texts)

        # Separate directories from files
        dir_nodes = [n for n in all_nodes_with_licenses if n.is_dir]
        file_nodes = [n for n in all_nodes_with_licenses if not n.is_dir]

        # Group file nodes with identical license info
        file_groups = self.group_nodes_by_license_info(file_nodes)

        # Each entry: (sort key, list of Files patterns, LicenseInfo)
        entries = []
        for node in dir_nodes:
            path = relative(node.file_path)
            entries.append((path, [path + '/*'], node.license_info))
        for group in file_groups:
            paths = [relative(n.file_path) for n in group]
            # Sort key is the first path in the group
            entries.append((paths[0], paths, group[0].license_info))

        entries.sort(key=lambda e: e[0])

        # Write all entries in sorted order
        for _, paths, info in entries:
            logger.debug(f"Writing {paths[0]}")

            # Continuation lines are indented to align below the first value.
            lines.append(f'Files: {paths[0]}')
            lines.extend(f'       {path}' for path in paths[1:])

            copyrights = info.get_copyright_strings()
            if copyrights:
                lines.append(f'Copyright: {copyrights[0]}')
                lines.extend(f'           {c}' for c in copyrights[1:])

            lines.append(f'License: {info.license}')
            lines.append('')

        # License texts
        lines.append('# License texts')
        lines.append('')

        all_license_ids = set(all_license_texts.keys()) | set(self.existing_license_texts.keys())
        for license_id in sorted(all_license_ids):
            lines.append(f'License: {license_id}')

            if license_id in self.existing_license_texts:
                # Already in file format (indented) - copy unchanged.
                lines.extend(self.existing_license_texts[license_id].splitlines())
            elif license_id in all_license_texts:
                for line in all_license_texts[license_id].splitlines():
                    # Continuation lines need a leading space; an empty line
                    # is written as ' .'.
                    lines.append(' ' + line if line.strip() else ' .')

            lines.append('')

        # Write file
        with open(self.output_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(lines))

        logger.info(f"Generated {self.output_path} with {len(all_nodes_with_licenses)} rules")

    @staticmethod
    def group_nodes_by_license_info(nodes: List[VendorNode]) -> List[List[VendorNode]]:
        """Group nodes with identical license info that share the same parent."""
        # First, group by parent and license info
        parent_groups = defaultdict(lambda: defaultdict(list))

        for node in nodes:
            parent_path = os.path.dirname(node.file_path) if node.parent else None

            # Create a key from license info
            normalized_holders = frozenset(
                node.license_info.copyright_holders.keys()) if node.license_info.copyright_holders else frozenset()
            licenses = frozenset(node.license_info.licenses)
            key = (licenses, normalized_holders)

            parent_groups[parent_path][key].append(node)

        # Flatten to list of groups
        result = []
        for parent_path, license_groups in parent_groups.items():
            for group in license_groups.values():
                result.append(sorted(group, key=lambda n: n.file_path))

        return result


def split_license_str(license_str: str) -> Set[str]:
    """Split a license expression into its individual license identifiers.

    The expression is split on the alternatives separators 'or' and 'and/or'
    (case-insensitive, surrounded by whitespace). Each piece is stripped of
    surrounding whitespace, and empty pieces (e.g. from a dangling separator
    or a whitespace-only input) are discarded.

    Args:
        license_str: License expression such as ``"Expat or Apache-2.0"``.

    Returns:
        The set of individual license identifiers; empty for empty input.
    """
    if not license_str:
        return set()
    pieces = re.split(r'\s+or\s+|\s+and/or\s+', license_str, flags=re.IGNORECASE)
    # Strip each piece and drop empties so no blank "license" leaks to callers
    return {piece.strip() for piece in pieces if piece.strip()}


def split_copyright_str(copyright_str: str) -> Set[str]:
    """Split a licensecheck copyright string into individual holder entries.

    licensecheck concatenates multiple holders with '/'. Each entry is
    stripped of surrounding whitespace, and empty entries (e.g. from a
    leading or trailing '/') are discarded.

    Args:
        copyright_str: Copyright string such as ``"2010 Foo / 2011 Bar"``.

    Returns:
        The set of individual copyright entries; empty for empty input.
    """
    if not copyright_str:
        return set()
    entries = re.split(r'\s*/\s*', copyright_str)
    # Strip each entry and drop empties so no blank holder leaks to callers
    return {entry.strip() for entry in entries if entry.strip()}


def normalize_license(license_id: Optional[str]) -> Optional[str]:
    """Normalize license identifiers (MIT -> Expat, etc).

    The following rewrites are applied:
      * a standalone ``MIT`` or ``MIT~unspecified`` becomes ``Expat``;
      * an uppercase ``OR`` operator becomes ``or``;
      * ``and/or`` (any case) becomes ``or``.

    Identifiers that merely contain "MIT" as one component (``MIT-0``,
    ``MIT-CMU``, ``HPND-MIT-disclaimer``, ``MIT~old`` ...) are different
    licenses and are left untouched.

    Args:
        license_id: License identifier or expression, or None.

    Returns:
        The normalized string, or None when the input is empty or None.
    """
    if not license_id:
        return None
    # '-' and '~' join parts of a single identifier, so "MIT" adjacent to
    # them is not the plain MIT/Expat license and must not be rewritten.
    normalized = re.sub(
        r'(?<![\w~-])MIT(?:~unspecified)?(?![\w~-])', 'Expat', license_id)
    normalized = re.sub(r'\s+OR\s+', ' or ', normalized)
    normalized = re.sub(r'\s+and/or\s+', ' or ', normalized, flags=re.IGNORECASE)
    return normalized

def clean_copyright(text: Optional[str]) -> Optional[str]:
    """Strip a trailing ". See (the) COPYRIGHT" pointer from copyright text.

    Returns the whitespace-trimmed result, or None when the input is empty
    or None, or when nothing but whitespace is left after cleaning.
    """
    if not text:
        return None
    trailer = re.compile(r'\.\s*See\s+(the\s+)?COPYRIGHT\s*$', re.IGNORECASE)
    result = trailer.sub('', text).strip()
    if result:
        return result
    return None

def yellow(text: str) -> str:
    """Wrap text in the ANSI bright-yellow escape (code 93) followed by a reset."""
    return "\033[93m{}\033[0m".format(text)

def green(text: str) -> str:
    """Wrap text in the ANSI bright-green escape (code 92) followed by a reset."""
    return "\033[92m{}\033[0m".format(text)

def main():
    """Command-line entry point: scan vendor directories and write debian/copyright.

    Parses the command line, configures logging to stderr, validates the
    vendor directories and the output location, then scans each vendor
    directory with its own VendorLicenseScanner and hands all scanners to a
    CopyrightFileGenerator.

    Exits with status 1 when no vendor directory can be auto-detected, when a
    given vendor directory is not a directory, or when the directory that
    should contain the copyright file does not exist.
    """
    parser = argparse.ArgumentParser(
        description='Generate Debian copyright file from vendored dependencies.',
        epilog="""
This script scans vendor directories for license and copyright information,
then generates a machine-readable debian/copyright file in DEP-5 format.

The script will:
  1. Run licensecheck on all files in the vendor directories
  2. Look for SPDX-License-Identifier tags in source files (overrides licensecheck)
  3. Build a directory tree and merge nodes with identical license information
  4. Generate or update the debian/copyright file with Files: stanzas
  5. Preserve existing header and License: stanzas from the copyright file

Examples:
  # Auto-detect vendor directories (./vendor*)
  %(prog)s
  
  # Scan specific vendor directories
  %(prog)s vendor vendor_rust
  
  # Output to a different location
  %(prog)s --debian-copyright /path/to/copyright
        """,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument('vendor_dirs', nargs='*',
                        metavar='<vendor-dir>',
                        help='Vendor directories to scan (default: auto-detect ./vendor*)')
    parser.add_argument('--debian-copyright', default='debian/copyright',
                        metavar="<copyright-file>",
                        help='Path to debian/copyright file (default: debian/copyright)')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug logging')

    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format='%(levelname)s: %(message)s',
        stream=sys.stderr
    )

    vendor_dirs = args.vendor_dirs
    if not vendor_dirs:
        # Auto-detection only considers directories, so that stray files such
        # as a vendor tarball next to ./vendor do not abort the run below.
        vendor_dirs = sorted(
            path for path in glob.glob('./vendor*') if os.path.isdir(path))
        if not vendor_dirs:
            logger.error("Error: No vendor directories found")
            sys.exit(1)
        logger.info(f"Auto-detected: {', '.join(vendor_dirs)}")

    # Explicitly given directories must all exist
    for vendor_dir in vendor_dirs:
        if not os.path.isdir(vendor_dir):
            logger.error(
                f"Vendor directory '{vendor_dir}' does not exist or is not a directory")
            sys.exit(1)

    # A bare filename has an empty dirname; that means the current directory,
    # which os.path.isdir('') would wrongly report as missing.
    output_dir = os.path.dirname(args.debian_copyright) or '.'
    if not os.path.isdir(output_dir):
        logger.error(
            f"Output directory '{output_dir}' does not exist or is not a directory")
        sys.exit(1)

    # Create copyright file generator and read existing copyright
    generator = CopyrightFileGenerator(args.debian_copyright)
    generator.read_existing_copyright()

    # Create separate scanner for each vendor directory
    scanners = []
    for vendor_dir in vendor_dirs:
        scanner = VendorLicenseScanner()
        scanner.scan_directory(vendor_dir)
        scanner.merge_nodes_in_all_dirs()
        scanners.append(scanner)

    # Generate copyright file from all scanners
    generator.generate_copyright_file(scanners)

    logger.info("Done! Run 'lrc' to validate the copyright file.")


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()

