#!/usr/libexec/platform-python
#
# Copyright (c) 2026, Oracle and/or its affiliates.
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
#
# This code is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 2 only, as
# published by the Free Software Foundation.
#
# This code is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
# version 2 for more details (a copy is included in the LICENSE file that
# accompanied this code).
#
# You should have received a copy of the GNU General Public License version
# 2 along with this work; if not, see https://www.gnu.org/licenses/.
#
# Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
# or visit www.oracle.com if you need additional information or have any
# questions.
"""cgstat: collect and summarize cgroup CPU and memory statistics."""

import argparse
import json
import os
import subprocess
import sys
from copy import deepcopy
from pathlib import Path
from typing import List

from cglib import cgstat_alarm as cgalarm
from cglib import cgstat_util as cgutil
from cglib.cgstat_alarm import alarm, alarm_exclusive, print_alarm, setup_alarm
from cglib.cgstat_cpu import get_cpu_stat_v1, get_cpu_stat_v2
from cglib.cgstat_mem import (
    get_cur_swap_v2,
    get_mem_current_v2,
    get_mem_events_v2,
    get_mem_failcnt_v1,
    get_mem_stat_v1,
    get_mem_stat_v2,
    get_mem_usage_v1,
    get_numa_stat_v1,
    get_numa_stat_v2,
)
from cglib.cgstat_util import debug

VERSION = "1.0.0"
# Persistent settings: read by main() and written by the `configure`
# subcommand.
CONFIG_FILE = "/etc/oled/cgstat.json"
# Default cgroup root. This exact path is probed via its filesystem type;
# any other root is probed by inspecting its directory layout.
DEFAULT_PATH = "/sys/fs/cgroup"
# Every --memory type; offered when the cgroup version cannot be determined.
TYPE_KEYS = ["stat", "current", "events", "numa", "swap", "usage", "failcnt"]
# --memory types accepted per cgroup version (key: cgroup version number).
VALID_TYPE_CHOICES = {
    1: ["stat", "usage", "failcnt", "numa"],
    2: ["stat", "current", "events", "numa", "swap"],
}

# Fallback values used by normalize_config() for missing/invalid settings.
DEFAULT_TOP = 10
DEFAULT_NUMA_DIFF = 20  # Numa difference in %
DEFAULT_MEM_FAULTS = 1  # 0/1 flag for memory fault alarms (see --memfault)
DEFAULT_CPU_FAULTS = 1  # 0/1 flag for CPU fault alarms (see --cpufault)

# Base configuration; normalize_config() overlays stored values on a copy.
DEFAULT_CONFIG = {
    "path": DEFAULT_PATH,
    "top": DEFAULT_TOP,
    "numadiff": DEFAULT_NUMA_DIFF,
    "memfault": DEFAULT_MEM_FAULTS,
    "cpufault": DEFAULT_CPU_FAULTS,
}

# cgroup v2 interface files expected at a v2 root. Despite the name these
# are files, not controllers; detect_cgroup_version() needs at least two of
# them present to treat a root as v2.
V2_CONTROLLER = {
    "cgroup.controllers",
    "cgroup.subtree_control",
    "cgroup.procs",
    "cgroup.events",
    "cgroup.type",
    "cgroup.stat",
}

# cgroup v1 controller names; any of them present as a subdirectory of the
# root counts as v1 evidence in detect_cgroup_version().
V1_CONTROLLER = {
    "cpu",
    "cpuacct",
    "cpuset",
    "memory",
    "blkio",
    "systemd",
}

# Collector function names most recently run by raise_all_alarm().
FUNCTIONS: List[str] = []


def default_config_for_version():
    """Return a fresh, independent copy of DEFAULT_CONFIG.

    A deep copy is returned so callers may mutate the result without
    altering the module-level defaults.
    """
    return deepcopy(DEFAULT_CONFIG)


def load_cgstat_conf(conf_path):
    """Load cgstat configuration from disk.

    Args:
        conf_path: path of the JSON configuration file.

    Returns:
        The decoded JSON object if it is a dict, otherwise None. A missing or
        unreadable file yields None silently so the built-in defaults apply;
        undecodable content or a non-object top level is reported before
        returning None.
    """
    try:
        with open(conf_path, "r", encoding="utf-8") as conf:
            data = json.load(conf)
    except OSError:
        # No config yet (or not readable): callers fall back to defaults.
        return None
    except ValueError as exc:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both are
        # ValueError subclasses); a corrupt file must not abort the tool.
        print(f"Unable to parse json from {conf_path}: {exc}")
        return None
    if not isinstance(data, dict):
        print(
            f"Ignoring {conf_path}: expected a JSON object,"
            f" got {type(data).__name__}"
        )
        return None
    return data


def write_cgstat_conf(conf_path, conf_dict):
    """Write cgstat configuration to disk as indented, key-sorted JSON.

    Missing parent directories are created first.

    Args:
        conf_path: destination file path.
        conf_dict: JSON-serialisable configuration mapping.

    Raises:
        ValueError: if the parent directory cannot be created or the file
            cannot be written; the underlying OSError is chained as cause.
    """
    conf_dir = os.path.dirname(conf_path)
    try:
        # makedirs is inside the try so that directory-creation failures
        # (e.g. permission denied) surface as ValueError like write errors.
        if conf_dir:
            os.makedirs(conf_dir, exist_ok=True)
        with open(conf_path, "w", encoding="utf-8") as conf:
            json.dump(conf_dict, conf, indent=2, sort_keys=True)
            conf.write("\n")
    except OSError as exc:
        raise ValueError(
            f"Unable to write config to {conf_path}: {exc}"
        ) from exc


def _coerce_int(value, default, minimum, maximum=None):
    """Return ``int(value)`` if it lies in [minimum, maximum], else *default*.

    A *maximum* of None means no upper bound. Values ``int()`` cannot
    convert (including infinite/NaN floats that JSON may produce) also
    yield *default*.
    """
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError):
        return default
    if number < minimum or (maximum is not None and number > maximum):
        return default
    return number


def normalize_config(conf_dict):
    """Normalize loaded config values and fill defaults.

    Args:
        conf_dict: configuration as loaded from disk. None, an empty value,
            or anything that is not a dict is ignored and defaults are used.

    Returns:
        A new dict containing every default key; unknown keys from
        *conf_dict* are carried over unchanged. ``path`` is a non-empty
        string, ``top`` a positive int, ``numadiff`` an int in 1..100 and
        ``cpufault``/``memfault`` are 0 or 1. Invalid values are replaced by
        their defaults.
    """
    config = default_config_for_version()
    # Only a mapping can be merged; other JSON shapes would crash update().
    if isinstance(conf_dict, dict):
        config.update(conf_dict)
    config["path"] = str(config.get("path") or DEFAULT_PATH)

    config["top"] = _coerce_int(
        config.get("top", DEFAULT_TOP), DEFAULT_TOP, 1
    )
    config["numadiff"] = _coerce_int(
        config.get("numadiff", DEFAULT_NUMA_DIFF), DEFAULT_NUMA_DIFF, 1, 100
    )
    config["cpufault"] = _coerce_int(
        config.get("cpufault", DEFAULT_CPU_FAULTS), DEFAULT_CPU_FAULTS, 0, 1
    )
    config["memfault"] = _coerce_int(
        config.get("memfault", DEFAULT_MEM_FAULTS), DEFAULT_MEM_FAULTS, 0, 1
    )
    return config


def print_conf(conf_dict):
    """Print cgstat configuration as aligned ``key: value`` lines.

    Keys are printed in a fixed order (path, top, numadiff, memfault,
    cpufault), left-justified to a common width so the colons line up.

    Raises:
        KeyError: if *conf_dict* lacks one of those keys.
    """
    keys = ("path", "top", "numadiff", "memfault", "cpufault")
    width = max(len(key) for key in keys)
    for key in keys:
        print(f"{key:<{width}}: {conf_dict[key]}")


def get_cgroup_status(print_version=True):
    """Return the cgroup version mounted at DEFAULT_PATH.

    The filesystem type is read with ``stat -fc %T``: ``cgroup2fs`` means
    v2, while ``tmpfs`` or ``cgroup`` means v1. When *print_version* is true
    a version banner and a blank line are printed. The process exits with
    status 1 if the root is missing, cannot be inspected, or has an
    unrecognised filesystem type.

    Returns:
        2 or 1, the detected cgroup version.
    """
    root = DEFAULT_PATH
    if not os.path.exists(root):
        print("Error: cgroup is not configured (/sys/fs/cgroup missing).")
        sys.exit(1)

    try:
        raw = subprocess.check_output(["stat", "-fc", "%T", root])
        fs_type = raw.decode().strip()
    except (
        subprocess.CalledProcessError,
        OSError,
        UnicodeDecodeError,
    ) as error:
        print(f"Error: cgroup is not configured. Details: {error}")
        sys.exit(1)

    detected = {"cgroup2fs": 2, "tmpfs": 1, "cgroup": 1}.get(fs_type)
    if detected is None:
        print(f"Error: cgroup is not configured (Type: {fs_type})")
        sys.exit(1)

    if print_version:
        print(f"Cgroup Version: v{detected}")
        print("")
    return detected


def cpu_stat_v1(path, count, run_config):
    """
    Collect and report CPU statistics from a cgroup v1 hierarchy.

    Args:
        path: cgroup root directory.
        count: number of top entries to report.
        run_config: effective runtime configuration dict.
    """
    tracing = debug()
    if tracing:
        print(f"Component: cpu | Type: stat | N: {count}")
        print(path, count)
    get_cpu_stat_v1(path, count, run_config)


def memory_stat_v1(path, count, run_config):
    """
    Collect and report memory statistics from a cgroup v1 hierarchy.

    Args:
        path: cgroup root directory.
        count: number of top entries to report.
        run_config: effective runtime configuration dict.
    """
    verbose_mode = debug()
    if verbose_mode:
        print(f"Component: memory | Type: stat | N: {count}")
        print(path, count)
    get_mem_stat_v1(path, count, run_config)


def memory_usage_v1(path, count, run_config):
    """
    Collect and report memory usage from a cgroup v1 hierarchy.

    Args:
        path: cgroup root directory.
        count: number of top entries to report.
        run_config: effective runtime configuration dict.
    """
    show_trace = debug()
    if show_trace:
        print(f"Component: memory | Type: usage | N: {count}")
        print(path, count)
    get_mem_usage_v1(path, count, run_config)


def memory_failcnt_v1(path, count, run_config):
    """
    Collect and report memory fail counters from a cgroup v1 hierarchy.

    Args:
        path: cgroup root directory.
        count: number of top entries to report.
        run_config: effective runtime configuration dict.
    """
    tracing = debug()
    if tracing:
        print(f"Component: memory | Type: failcount | N: {count}")
        print(path, count)
    get_mem_failcnt_v1(path, count, run_config)


def memory_numa_v1(path, count, run_config):
    """
    Collect and report per-NUMA memory usage from a cgroup v1 hierarchy.

    Args:
        path: cgroup root directory.
        count: number of top entries to report.
        run_config: effective runtime configuration dict.
    """
    verbose_mode = debug()
    if verbose_mode:
        print(f"Component: memory | Type: numa | N: {count}")
        print(path, count)
    get_numa_stat_v1(path, count, run_config)


def memory_swap_v1(path, count, run_config):
    """
    cgroup v1, Explain that swap statistics are not reported separately.

    Only a warning and a hint are printed. The arguments are accepted so the
    signature matches the other ``memory_*`` collectors, and are ignored.
    """
    del path, count, run_config
    print("Warning: CGroup v1 does not support separate swap stats.")
    print("Tip: Check memory stats in CGroup v1 for swap usage info.")


def cpu_stat_v2(path, count, run_config):
    """
    Collect and report CPU statistics from a cgroup v2 hierarchy.

    Args:
        path: cgroup root directory.
        count: number of top entries to report.
        run_config: effective runtime configuration dict.
    """
    tracing = debug()
    if tracing:
        print(f"Component: cpu | Type: stat | N: {count}")
        print(path, count)
    get_cpu_stat_v2(path, count, run_config)


def memory_stat_v2(path, count, run_config):
    """
    Collect and report memory statistics from a cgroup v2 hierarchy.

    Args:
        path: cgroup root directory.
        count: number of top entries to report.
        run_config: effective runtime configuration dict.
    """
    verbose_mode = debug()
    if verbose_mode:
        print(f"Component: memory | Type: stat | N: {count}")
        print(path, count)
    get_mem_stat_v2(path, count, run_config)


def memory_current_v2(path, count, run_config):
    """
    Collect and report current memory usage from a cgroup v2 hierarchy.

    Args:
        path: cgroup root directory.
        count: number of top entries to report.
        run_config: effective runtime configuration dict.
    """
    show_trace = debug()
    if show_trace:
        print(f"Component: memory | Type: current | N: {count}")
        print(path, count)
    get_mem_current_v2(path, count, run_config)


def memory_events_v2(path, count, run_config):
    """
    Collect and report memory event counters from a cgroup v2 hierarchy.

    Args:
        path: cgroup root directory.
        count: number of top entries to report.
        run_config: effective runtime configuration dict.
    """
    tracing = debug()
    if tracing:
        print(f"Component: memory | Type: events | N: {count}")
        print(path, count)
    get_mem_events_v2(path, count, run_config)


def memory_numa_v2(path, count, run_config):
    """
    Collect and report per-NUMA memory usage from a cgroup v2 hierarchy.

    Args:
        path: cgroup root directory.
        count: number of top entries to report.
        run_config: effective runtime configuration dict.
    """
    verbose_mode = debug()
    if verbose_mode:
        print(f"Component: memory | Type: numa | N: {count}")
        print(path, count)
    get_numa_stat_v2(path, count, run_config)


def memory_swap_v2(path, count, run_config):
    """
    Collect and report swap usage from a cgroup v2 hierarchy.

    Args:
        path: cgroup root directory.
        count: number of top entries to report.
        run_config: effective runtime configuration dict.
    """
    show_trace = debug()
    if show_trace:
        print(f"Component: memory | Type: swap | N: {count}")
        print(path, count)
    get_cur_swap_v2(path, count, run_config)


def detect_cgroup_version(cgroup_root, print_version=True):
    """
    Detect the cgroup version of an arbitrary root from its layout.

    A root containing at least two V2_CONTROLLER interface files and no
    V1_CONTROLLER subdirectory is v2; one with such a subdirectory and fewer
    than two v2 files is v1. With *print_version* the verdict is printed.

    Returns:
        2 or 1.

    Raises:
        ValueError: if *cgroup_root* is not an existing directory.

    The process exits with status 1 when both kinds of evidence, or
    neither, are present.
    """
    base = Path(cgroup_root)
    if not base.is_dir():
        raise ValueError(f"Not a directory: {base}")

    names = {child.name for child in base.iterdir()}
    v2_like = len(names & V2_CONTROLLER) >= 2
    v1_like = any((base / name).is_dir() for name in names & V1_CONTROLLER)

    if v2_like and v1_like:
        print("Both v2 root markers and v1 controller dirs present.")
        sys.exit(1)
    if not (v2_like or v1_like):
        print("Insufficient evidence (snapshot may be incomplete)")
        sys.exit(1)

    version = 2 if v2_like else 1
    if print_version:
        print(f"{cgroup_root} : CGroup Version {version}")
    return version


def raise_all_alarm(run_config, version, path):
    """
    Run every statistics collector for the given cgroup version.

    The module-level FUNCTIONS list is overwritten in place with the v2
    collector names when ``version == 2`` and with the v1 names otherwise.
    Each listed name that exists in this module is then called as
    ``collector(path, run_config["top"], run_config)``; missing names are
    skipped silently. Presumably the collectors themselves record alarms
    (TODO confirm in cglib).
    """
    v2_collectors = [
        "cpu_stat_v2",
        "memory_stat_v2",
        "memory_current_v2",
        "memory_events_v2",
        "memory_numa_v2",
        "memory_swap_v2",
    ]
    v1_collectors = [
        "cpu_stat_v1",
        "memory_stat_v1",
        "memory_usage_v1",
        "memory_failcnt_v1",
        "memory_numa_v1",
    ]
    this_module = sys.modules[__name__]
    # Slice assignment mutates the shared list object in place.
    FUNCTIONS[:] = v2_collectors if version == 2 else v1_collectors

    for name in FUNCTIONS:
        if not hasattr(this_module, name):
            continue
        collector = getattr(this_module, name)
        collector(path, run_config["top"], run_config)


def memory_main(args, run_config, version, path):
    """
    Dispatch the ``--memory`` request to its version-specific collector.

    The collector name is ``memory_<type>_v<version>``. If no such function
    exists in this module an error is printed instead.
    """
    func_name = f"memory_{args.memory}_v{version}"
    if args.verbose:
        print(f"[DEBUG] Calling function: {func_name}")
    this_module = sys.modules[__name__]
    if not hasattr(this_module, func_name):
        message = (
            f"Error: Function {func_name} is not implemented "
            "for this version."
        )
        print(message)
        return
    handler = getattr(this_module, func_name)
    handler(path, run_config["top"], run_config)


def cpu_main(args, run_config, version, path):
    """
    Dispatch the ``--cpu`` request to its version-specific collector.

    The collector name is ``cpu_stat_v<version>``. If no such function
    exists in this module an error is printed instead.
    """
    func_name = f"cpu_stat_v{version}"
    if args.verbose:
        print(f"[DEBUG] Calling function: {func_name}")
    this_module = sys.modules[__name__]
    if not hasattr(this_module, func_name):
        message = (
            f"Error: Function {func_name} is not implemented "
            "for this version."
        )
        print(message)
        return
    handler = getattr(this_module, func_name)
    handler(path, run_config["top"], run_config)


def detect_version_for_path(path, print_version=True):
    """Return the cgroup version (1 or 2) for the cgroup root *path*.

    The default root is classified by its filesystem type
    (get_cgroup_status); any other root by its directory layout
    (detect_cgroup_version).
    """
    if path != DEFAULT_PATH:
        if debug() and print_version:
            print("Configuring CGroup root path to:", path)
        return detect_cgroup_version(path, print_version=print_version)
    return get_cgroup_status(print_version=print_version)


def preparse_path(argv):
    """Return the ``-p/--path`` value from *argv*, or None if absent.

    Every other argument is ignored, so this can run before the full,
    version-aware parser is built.
    """
    probe = argparse.ArgumentParser(add_help=False)
    probe.add_argument("-p", "--path", default=None)
    namespace = probe.parse_known_args(argv)[0]
    return vars(namespace)["path"]


def resolve_runtime_config(config, args):
    """Merge stored config with CLI overrides.

    Args:
        config: configuration loaded from disk (may be None or partial); it
            is normalized first so every key has a valid value.
        args: parsed argparse namespace. Attributes that are missing or None
            leave the stored value untouched.

    Returns:
        A new dict with the effective runtime settings.

    Raises:
        ValueError: if ``--top`` is given but is not a positive integer.
    """
    run_config = normalize_config(config)

    # An empty --path is treated as "not given" so the stored root applies.
    cli_path = getattr(args, "path", None)
    if cli_path:
        run_config["path"] = cli_path

    top = getattr(args, "top", None)
    if top is not None:
        if top <= 0:
            raise ValueError("--top must be a positive integer")
        run_config["top"] = top

    for key in ("numadiff", "memfault", "cpufault"):
        value = getattr(args, key, None)
        if value is not None:
            run_config[key] = value

    return run_config


def configure_cgstat(current_conf, options):
    """Apply configure subcommand changes to cgstat config.

    Args:
        current_conf: the currently stored configuration (normalized here).
        options: parsed ``configure`` arguments.

    Returns:
        ``(conf, changed)``: the resulting configuration and whether it was
        modified and therefore written to CONFIG_FILE. The configuration is
        printed when ``--show`` is given or nothing changed.

    Raises:
        ValueError: on an out-of-range option value, or when
            write_cgstat_conf reports a write failure.
    """
    conf = normalize_config(current_conf)
    changed = False
    if options.reset_defaults:
        conf = default_config_for_version()
        # A reset is itself a change and must be persisted even when no
        # other option accompanies it.
        changed = True

    if options.path is not None:
        conf["path"] = options.path
        changed = True
        # An empty path is normalized back to the default root.
        conf = normalize_config(conf)

    if options.top is not None:
        if options.top <= 0:
            raise ValueError("--top must be a positive integer")
        conf["top"] = options.top
        changed = True

    if options.numadiff is not None:
        if options.numadiff <= 0 or options.numadiff > 100:
            raise ValueError(
                "--numadiff percentage must be a positive integer"
                " in range [1-100]"
            )
        conf["numadiff"] = options.numadiff
        changed = True

    if options.memfault is not None:
        if options.memfault < 0 or options.memfault > 1:
            raise ValueError("--memfault must be a boolean")
        conf["memfault"] = options.memfault
        changed = True

    if options.cpufault is not None:
        if options.cpufault < 0 or options.cpufault > 1:
            raise ValueError("--cpufault must be a boolean")
        conf["cpufault"] = options.cpufault
        changed = True

    if changed:
        write_cgstat_conf(CONFIG_FILE, conf)

    # Print only after a successful write so a failed save never displays
    # settings that were not actually stored.
    if options.show or not changed:
        print_conf(conf)

    return conf, changed


def build_parser(version=None):
    """Build cgstat CLI parser.

    Args:
        version: detected cgroup version (1 or 2) used to restrict the
            ``--memory`` choices; any other value (e.g. None when detection
            failed) offers every type in TYPE_KEYS.

    Returns:
        The configured ``argparse.ArgumentParser``, including the
        ``configure`` subcommand.
    """
    memory_choices = VALID_TYPE_CHOICES.get(version, TYPE_KEYS)

    parser = argparse.ArgumentParser(
        prog="oled cgstat",
        description="Display cgroup CPU and memory statistics and alarms.",
    )
    parser.add_argument(
        "-p", "--path", default=None, help="Path to cgroup root"
    )
    parser.add_argument(
        "-A", "--alarm", action="store_true", help="Raise an alarm on anomaly"
    )
    parser.add_argument(
        "-c",
        "--cpu",
        action="store_true",
        help="Collect CPU cgroup statistics",
    )
    parser.add_argument(
        "-m",
        "--memory",
        nargs="?",
        const="stat",
        choices=memory_choices,
        help="Collect Memory cgroup statistics (default: stat)",
    )
    parser.add_argument(
        "-n",
        "--top",
        type=int,
        default=None,
        help="Number of top entries to show",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Enable verbose/debug output",
    )

    subparsers = parser.add_subparsers(dest="command")
    config_parser = subparsers.add_parser(
        "configure", help="Configure cgstat defaults."
    )
    config_parser.add_argument(
        "--show",
        action="store_true",
        help="Show current cgstat configuration.",
    )
    config_parser.add_argument(
        "--reset-defaults",
        action="store_true",
        help="Reset cgstat configuration to defaults.",
    )
    config_parser.add_argument(
        "--path", type=str, help="Set default cgroup root path."
    )
    config_parser.add_argument(
        "--top", type=int, help="Set default top-N result count"
    )
    config_parser.add_argument(
        "--numadiff",
        type=int,
        help="Set default %% for NUMA difference to trigger an alarm.",
    )
    config_parser.add_argument(
        "--memfault",
        type=int,
        help="Set default configuration to enable memory fault alarms.",
    )
    config_parser.add_argument(
        "--cpufault",
        type=int,
        help="Set default configuration to enable CPU fault alarms.",
    )
    return parser


def main(argv=None):
    """Main entry point for ``oled cgstat``.

    Args:
        argv: argument list without the program name; defaults to
            ``sys.argv[1:]``.
    """
    import contextlib
    import io

    argv = sys.argv[1:] if argv is None else argv
    path_hint = preparse_path(argv) or DEFAULT_PATH
    # Best-effort probe used only to tailor the --memory choices. Its
    # diagnostics are discarded: the authoritative detection below reports
    # them when the version is actually needed, and `configure` does not
    # need it at all.
    try:
        with contextlib.redirect_stdout(io.StringIO()):
            version_for_help = detect_version_for_path(
                path_hint, print_version=False
            )
    except (ValueError, OSError, SystemExit):
        version_for_help = None
    parser = build_parser(version_for_help)

    options = parser.parse_args(argv)
    loaded_conf = load_cgstat_conf(CONFIG_FILE)
    conf = normalize_config(loaded_conf)

    if options.command == "configure":
        try:
            configure_cgstat(conf, options)
        except ValueError as exc:
            print(exc)
            sys.exit(1)
        return

    try:
        run_config = resolve_runtime_config(conf, options)
    except ValueError as exc:
        parser.error(str(exc))

    # Detect the version of the same root the collectors will read, and
    # enable debug output first so detection's verbose messages appear.
    path = run_config["path"]
    cgutil.DEBUG = options.verbose
    try:
        version = detect_version_for_path(path)
    except (ValueError, OSError) as exc:
        # e.g. a --path that is not a directory or cannot be listed.
        print(f"Error: {exc}")
        sys.exit(1)

    cgalarm.ALARM = setup_alarm(options)

    if alarm_exclusive(alarm()):
        raise_all_alarm(run_config, version, path)

    if options.memory:
        memory_main(options, run_config, version, path)

    if options.cpu:
        cpu_main(options, run_config, version, path)

    if options.alarm:
        print_alarm(alarm())


if __name__ == "__main__":
    # Script entry point: run the CLI on the process arguments.
    main(sys.argv[1:])
