#! /usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0-only
#
# Extract FIPS 140 cryptographic module (and hmac) from a kernel image.
#
# Copyright © 2025, Oracle and/or its affiliates.
#

import argparse
import os
import re
import shutil
import subprocess
import sys
import tempfile

# Path of the extract-vmlinux helper; it is expected to live in the same
# directory as this script.
extract_vmlinux = os.path.join(os.path.dirname(__file__), 'extract-vmlinux')

parser = argparse.ArgumentParser(
    description='Extract the FIPS 140 cryptographic module (fips140.ko) and '
                'its HMAC (fips140.hmac) from a kernel image into the '
                'current directory.',
)
parser.add_argument(
    'kernel',
    help='kernel image: an uncompressed ELF vmlinux, or an image such as '
         'vmlinuz that extract-vmlinux can unpack',
)

args = parser.parse_args()

# Set to True whenever a section could not be extracted; the script then
# exits with status 1 after attempting every extraction.
warnings = False


def is_elf(path):
    """Report whether ``readelf -h`` accepts *path* as an ELF file.

    All of readelf's output is discarded; only its exit status is used.
    Returns True when readelf exits with status 0, False otherwise.
    """
    completed = subprocess.run(
        ['readelf', '-h', path],
        stderr=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
    )
    return not completed.returncode


def get_section_range(path, section):
    """Look up an ELF section by name in the output of ``readelf -W -S``.

    Returns a tuple ``(offset, size)``, both parsed from hexadecimal.
    These are the two columns following the address column in readelf's
    section table (Off and Size). Returns None when no row names *section*.
    If several rows match, the first one wins.
    readelf failures propagate as ``subprocess.CalledProcessError``.
    """
    section_table = subprocess.check_output(
        ['readelf', '-W', '-S', path],
        stderr=subprocess.DEVNULL,
        text=True,
    )

    # Matches rows like:
    #   [ 1] .text  PROGBITS  <address> <off> <size> ...
    # capturing the name, the offset and the size.
    header_row = re.compile(
        r'^\s*\[\s*\d+\]\s+(\S+)\s+\S+\s+[0-9A-Fa-f]+\s+'
        r'([0-9A-Fa-f]+)\s+([0-9A-Fa-f]+)\b'
    )

    candidates = (header_row.match(row) for row in section_table.splitlines())
    hit = next(
        (m for m in candidates if m is not None and m.group(1) == section),
        None,
    )
    if hit is None:
        return None

    offset_hex, size_hex = hit.group(2, 3)
    return int(offset_hex, 16), int(size_hex, 16)


def recover_vmlinux(kernel_path, vmlinux_path):
    """Produce an ELF vmlinux at *vmlinux_path* from the image at *kernel_path*.

    Runs the helper named by the module-level ``extract_vmlinux`` with
    *kernel_path* as its argument and its stdout redirected into
    *vmlinux_path*. The result is accepted only if the helper exits with
    status 0 and ``is_elf`` accepts the written file.

    Raises:
        SystemExit: if *kernel_path* is not a regular file, if the helper
            cannot be executed (e.g. missing or not executable), or if no
            valid ELF file was produced. The message carries user hints.
    """
    if not os.path.isfile(kernel_path):
        raise SystemExit(f"error: input file not found: {kernel_path}")

    with open(vmlinux_path, 'wb') as f:
        try:
            ret = subprocess.call([extract_vmlinux, kernel_path], stdout=f)
        except OSError as err:
            # A missing or non-executable helper raises here; report it
            # like the other failures instead of dumping a traceback.
            raise SystemExit(
                f"error: failed to run {extract_vmlinux}: {err}"
            ) from err

    if ret == 0 and is_elf(vmlinux_path):
        return

    # The hint is chosen purely from the input path name.
    if 'aarch64' in kernel_path.lower() or 'arm64' in kernel_path.lower():
        hint = (
            "hint: on arm64, /boot/vmlinuz is typically Image.gz or vmlinuz.efi, "
            "which does not embed the vmlinux ELF section table."
        )
    else:
        hint = (
            "hint: this input is not a directly usable ELF vmlinux, and "
            "extract-vmlinux could not recover one from it."
        )

    raise SystemExit(
        f"error: failed to recover an ELF vmlinux from {kernel_path}\n"
        f"{hint}\n"
        "hint: pass the uncompressed vmlinux file instead."
    )

with tempfile.TemporaryDirectory() as tmp_dir:
    vmlinux_path = os.path.join(tmp_dir, 'vmlinux')
    recover_vmlinux(args.kernel, vmlinux_path)

    def extract_section(input_path, output_path, section):
        """Copy ELF *section* of *input_path* to *output_path*.

        The bytes are first staged in a file of the same name inside the
        temporary directory. They are moved to *output_path* only when the
        section is non-empty and was read in full. Otherwise a warning is
        printed to stderr and the module-level ``warnings`` flag is set.
        """
        def warn(reason):
            global warnings
            print(f"warning: failed to extract {output_path}; {reason}", file=sys.stderr)
            warnings = True

        staged_path = os.path.join(tmp_dir, output_path)

        found = get_section_range(input_path, section)
        if not found:
            warn(f"missing section {section}?")
            return

        offset, size = found
        with open(input_path, 'rb') as src, open(staged_path, 'wb') as dst:
            src.seek(offset)
            dst.write(src.read(size))

        # Compare the bytes actually written with the size readelf
        # reported, to catch sections extending past end-of-file.
        written = os.path.getsize(staged_path)
        if size == 0:
            warn(f"empty section {section}?")
            return
        if written != size:
            warn(f"short section read for {section}")
            return

        shutil.move(staged_path, output_path)
        print(f"extracted {output_path}")

    for out_name, sec_name in (
        ('fips140.ko', '__fips140_module'),
        ('fips140.hmac', '__fips140_digest'),
    ):
        extract_section(vmlinux_path, out_name, sec_name)

if warnings:
    sys.exit(1)
