#!/usr/bin/python3

# script to generate Unattended-Upgrade::Origins-Pattern
# part of MX Updater package


import argparse
import sys
from datetime import datetime
import subprocess
from typing import List, Dict, Optional


def run_apt_cache_policy() -> str:
    """
    Run ``apt-cache policy`` and return its standard output.

    Failures are not propagated to the caller: an error message is printed
    to stderr and the process exits with status 1.

    :return: Output of the apt-cache policy command
    :raises SystemExit: If apt-cache exits non-zero or cannot be started
        (e.g. the executable is missing)
    """
    try:
        result = subprocess.run(
            ['apt-cache', 'policy'],
            capture_output=True,
            text=True,
            check=True
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing/unexecutable apt-cache binary, which would
        # otherwise escape and be misreported by the caller.
        print(f"Error running apt-cache policy: {e}", file=sys.stderr)
        sys.exit(1)
    return result.stdout


def get_architectures() -> set:
    """
    Collect the dpkg architectures enabled on this system.

    The native architecture reported by ``dpkg --print-architecture`` is
    combined with every foreign architecture listed by
    ``dpkg --print-foreign-architectures``. If either dpkg call exits
    non-zero, an error is printed to stderr and the process exits with
    status 1.

    :return: Set of enabled architecture names
    """
    def _query_dpkg(option):
        # Run dpkg with a single query option and return its stripped stdout.
        completed = subprocess.run(
            ['dpkg', option],
            capture_output=True,
            text=True,
            check=True
        )
        return completed.stdout.strip()

    try:
        native = _query_dpkg('--print-architecture')
        foreign = _query_dpkg('--print-foreign-architectures').splitlines()
    except subprocess.CalledProcessError as err:
        print(f"Error getting architectures: {err}", file=sys.stderr)
        sys.exit(1)

    return {native, *foreign}


def parse_apt_cache_policy(policy_output: str,
                           allowed_architectures: Optional[set] = None
                           ) -> List[Dict[str, str]]:
    """
    Parse apt-cache policy output into a structured list of repository entries.

    An entry starts at a line whose first token is a PIN >= 1, e.g.
    ``500 http://deb.debian.org/debian bookworm/main amd64 Packages``. When a
    ``Packages`` token is present, it and the token before it (presumably
    the architecture) are dropped from the stored URI. A following
    ``release k=v,...`` line supplies the entry's attributes, and an
    ``origin <site>`` line completes the entry.

    An entry is kept only if its release line has a ``b`` value contained in
    *allowed_architectures* and an origin line follows it. Entries with equal
    release attributes and the same non-empty site are merged: the later
    entry's URI is appended to the first entry's ``'uris'`` list.

    :param policy_output: Output string from apt-cache policy
    :param allowed_architectures: Architectures whose entries are accepted;
        when None, the result of get_architectures() is used
    :return: List of dicts with keys 'pin' (int), 'uri', 'uris', 'attrs'
        (release attributes) and 'site'
    """
    if allowed_architectures is None:
        allowed_architectures = get_architectures()

    repositories = []
    current_repo = None

    for line in policy_output.splitlines():
        line = line.strip()
        parts = line.split()

        if parts and parts[0].isdigit() and int(parts[0]) >= 1:
            # New package-file entry: "<pin> <uri ...> [<arch> Packages]"
            pin = int(parts[0])
            if 'Packages' in parts:
                # Exclude the architecture token and "Packages"
                uri = ' '.join(parts[1:parts.index('Packages') - 1])
            else:
                # No "Packages" token: keep everything after the PIN
                uri = ' '.join(parts[1:])

            current_repo = {
                'pin': pin,
                'uri': uri,
                'uris': [uri],  # all URIs merged into this entry
                'attrs': {},
                'site': '',
            }

        elif line.startswith('release '):
            if current_repo is None:
                continue
            attrs = {}
            for attr in line[len('release '):].split(','):
                # Split on the first '=' only, so values containing '='
                # are kept intact.
                key, sep, value = attr.partition('=')
                if sep:
                    attrs[key] = value

            if attrs.get('b') in allowed_architectures:
                current_repo['attrs'] = attrs
            else:
                # Missing or non-allowed architecture: drop this entry
                current_repo = None

        elif line.startswith('origin '):
            # attrs is non-empty only after a release line passed the
            # architecture check above (it then contains 'b'); skip entries
            # that never did instead of reusing another entry's attributes.
            if current_repo is None or not current_repo['attrs']:
                continue
            attrs = current_repo['attrs']
            site = parts[1] if len(parts) > 1 else ""
            current_repo['site'] = site

            # Only entries with a non-empty site are merged
            matching_repo = None
            if site:
                matching_repo = next(
                    (repo for repo in repositories
                     if repo['attrs'] == attrs and repo['site'] == site),
                    None
                )

            if matching_repo is None:
                # New unique origin
                repositories.append(current_repo)
            elif current_repo['uri'] not in matching_repo['uris']:
                # Same origin reached through another URI
                matching_repo['uris'].append(current_repo['uri'])

            # Entry complete; a repeated origin line must not append it again
            current_repo = None

    return repositories


def filter_repositories(repositories: List[Dict[str, str]],
                        filter_origins: Optional[List[str]] = None,
                        min_pin: int = 500) -> List[Dict[str, str]]:
    """
    Filter repositories based on origin and PIN.

    A repository is kept when its release origin (``attrs['o']``, empty if
    absent) is not in *filter_origins* and its PIN is at least *min_pin*.

    :param repositories: List of repository dictionaries
    :param filter_origins: Origins to filter out; None means the defaults
        'Debian', 'Debian Backports' and 'MX repository'
    :param min_pin: Minimum PIN value to include
    :return: Filtered list of repositories, in their original order
    """
    if filter_origins is None:
        # Built per call so no mutable default is shared between calls
        filter_origins = ['Debian', 'Debian Backports', 'MX repository']
    excluded = set(filter_origins)  # built once for O(1) membership tests

    return [
        repo for repo in repositories
        if repo.get('attrs', {}).get('o', '') not in excluded
        and repo.get('pin', 0) >= min_pin
    ]


def generate_origins_pattern(repositories: List[Dict[str, str]],
                             no_uris: bool = False) -> str:
    """
    Generate the Unattended-Upgrade::Origins-Pattern stanza.

    Each repository yields a pattern built from its release attributes
    ``o``, ``a``, ``n``, ``l`` (empty when missing), plus ``site`` when one
    is set. Repositories that yield the same pattern are emitted once, in
    first-seen order. The commented PIN/URI lines of all of them are listed
    above that single entry, without repeating identical (PIN, URI) pairs.

    :param repositories: List of filtered repository dictionaries
    :param no_uris: If True, omit the commented PIN/URI lines
    :return: Formatted Origins-Pattern string, ending with a newline
    """
    timestamp = datetime.now().astimezone().strftime("%a, %d %b %Y %H:%M:%S %z")

    # pattern -> [(pin, uri), ...]; dicts preserve insertion order, so
    # patterns come out in the order they are first encountered.
    grouped = {}
    for repo in repositories:
        attrs = repo.get('attrs', {})
        pattern = ','.join(f'{k}={attrs.get(k, "")}' for k in ('o', 'a', 'n', 'l'))
        site = repo.get('site', '')
        if site:
            pattern += f',site={site}'

        entries = grouped.setdefault(pattern, [])
        if not no_uris:
            for uri in repo.get('uris', []):
                entry = (repo['pin'], uri)
                # Same URI can appear for several architectures/components
                if entry not in entries:
                    entries.append(entry)

    output = [
        f"// generated by update-Origins-Pattern at {timestamp}",
        "//",
        "Unattended-Upgrade::Origins-Pattern {",
    ]
    for pattern, entries in grouped.items():
        output.extend(f"//  {pin} {uri}" for pin, uri in entries)
        output.append(f'    "{pattern}";')
    output.append("};\n")

    return "\n".join(output)


def positive_int(value):
    """
    argparse ``type=`` converter that only accepts integers >= 1.

    :param value: Raw command-line string
    :return: The parsed integer
    :raises argparse.ArgumentTypeError: If *value* is not an integer or is
        less than 1
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value}")

    if number >= 1:
        return number
    raise argparse.ArgumentTypeError(f"invalid non-positive int value: {value}")


def create_parser(default_file):
    """
    Build the command-line parser for this script.

    :param default_file: Path shown in the help text of ``-d/--default``
    :return: Configured argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(
        description="Update the Unattended-Upgrade::Origins-Pattern for non-standard repositories"
    )

    # Filtering and formatting options
    parser.add_argument(
        '-o', '--origins',
        nargs='+',
        default=['Debian', 'Debian Backports', 'MX repository'],
        help="Origins to filter out (default: 'Debian', 'Debian Backports', 'MX repository')",
    )
    parser.add_argument(
        '-p', '--pin',
        type=positive_int,
        default=500,
        help='Minimum positive PIN to include (default: 500)',
    )
    parser.add_argument(
        '-n', '--no-uris',
        action="store_true",
        help="Don't include pin/uri lines within the generated output",
    )

    # Output destination: the default file (-d) or an explicit path, not both
    destination = parser.add_mutually_exclusive_group()
    destination.add_argument(
        '-d', '--default',
        action="store_true",
        help=f'Use default file path (default: {default_file})',
    )
    destination.add_argument(
        'file',
        nargs='?',
        help='Custom file to be updated',
    )

    return parser


def main():
    """
    Entry point: generate the Origins-Pattern stanza and write it out.

    The stanza is written to the default apt.conf.d file when ``-d`` is
    given, to the positional ``file`` when one is supplied, and to stdout
    otherwise. Parsing and file-writing failures are reported on stderr and
    end the process with exit status 1.
    """
    default_file = '/etc/apt/apt.conf.d/51unattended-upgrades-origins'
    args = create_parser(default_file).parse_args()

    # Resolve the destination; None means "print to stdout"
    if args.default:
        target = default_file
    else:
        target = args.file or None

    try:
        repos = parse_apt_cache_policy(run_apt_cache_policy())
    except Exception as exc:
        # Top-level boundary: report any parsing failure and exit non-zero
        print(f"Error parsing apt-cache policy: {exc}", file=sys.stderr)
        sys.exit(1)

    stanza = generate_origins_pattern(
        filter_repositories(repos, filter_origins=args.origins, min_pin=args.pin),
        args.no_uris,
    )

    if not target:
        print(stanza)
        return

    try:
        with open(target, 'w') as handle:
            handle.write(stanza)
        print(f"Origins-Pattern written to '{target}'")
    except PermissionError:
        print(f"[Errno 13] Permission denied: '{target}'", file=sys.stderr)
        sys.exit(1)
    except OSError as exc:  # IOError is an alias of OSError
        print(f"Error writing to file: {exc}", file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
