#!/usr/bin/env python3
"""
SecLog Independent Proof Verifier

Verifies a cryptographic proof using ONLY standard crypto libraries.
No SecLog code. No knowledge of internal data structures required.

Dependencies:
    pip install blake3 pynacl

Usage:
    python3 verify_proof.py <bundle-dir>
    python3 verify_proof.py <bundle-dir> --tamper   # flip one bit of the record, watch it fail
"""

import json
import sys
from pathlib import Path

import blake3
from nacl.signing import VerifyKey
from nacl.exceptions import BadSignatureError


def hex_to_bytes(h: str) -> bytes:
    """Decode the hexadecimal string *h* into raw bytes.

    Follows ``bytes.fromhex`` rules: whitespace between byte pairs is
    ignored, ``""`` decodes to ``b""``, and non-hex input raises ValueError.
    The empty-string case matters because ``""`` is used as the fallback
    for proof fields that are absent.
    """
    decoded = bytes.fromhex(h)
    return decoded


def blake3_hash(data: bytes) -> bytes:
    """Return the default-length BLAKE3 digest of *data* (used as the leaf hash)."""
    hasher = blake3.blake3()
    hasher.update(data)
    return hasher.digest()


def blake3_pair(a: bytes, b: bytes) -> bytes:
    """Return BLAKE3(a || b), the hash of two child nodes in left-right order.

    Hashing the joined buffer in one shot yields the same digest as feeding
    *a* and then *b* to an incremental hasher.
    """
    joined = b"".join((a, b))
    return blake3.blake3(joined).digest()


def verify_merkle_proof(leaf_hash: bytes, proof: list) -> bytes:
    """Recompute a Merkle root by folding an inclusion proof onto a leaf hash.

    Args:
        leaf_hash: hash of the leaf whose inclusion is being proven.
        proof: path from the leaf upward; each entry is a mapping with
            "hash" (hex-encoded sibling hash) and "is_right" (truthy when the
            sibling sits to the right of the running node).

    Returns:
        The reconstructed root. With an empty proof this is ``leaf_hash``.

    NOTE(review): parent nodes are blake3_pair(left, right) with no
    leaf/interior domain-separation prefix; this must mirror how the log
    builds its tree -- confirm against the producer.
    """
    node = leaf_hash
    for step in proof:
        neighbour = hex_to_bytes(step["hash"])
        if step["is_right"]:
            left, right = node, neighbour
        else:
            left, right = neighbour, node
        node = blake3_pair(left, right)
    return node


def verify_ed25519(public_key: bytes, signature: bytes, message: bytes) -> bool:
    """Return True iff *signature* is a valid Ed25519 signature of *message*.

    Args:
        public_key: raw Ed25519 verification key.
        signature: raw detached signature.
        message: the exact bytes that were signed.

    Returns:
        True when the signature verifies, False otherwise.

    Malformed inputs count as a failed verification rather than an error.
    PyNaCl raises its ValueError/TypeError subclasses for a wrong-length key
    or signature or for non-bytes input. An example is the empty signature
    produced when a field is missing from proof.json. Treating these as
    failures lets the caller report FAIL instead of dying with a traceback.
    """
    try:
        VerifyKey(public_key).verify(message, signature)
    except (BadSignatureError, ValueError, TypeError):
        return False
    return True


def _first_field(proof: dict, *names: str, default=""):
    """Return the value of the first key in *names* that is present in *proof*.

    Bundles may spell a field either way (e.g. "root" or "merkle_root"). The
    first spelling found wins, and *default* is returned when none is present.
    """
    for name in names:
        if name in proof:
            return proof[name]
    return default


def _load_public_key(bundle_dir: Path, proof: dict) -> bytes:
    """Return the Ed25519 verification key for this bundle.

    <bundle-dir>/public_key.hex takes precedence over proof["public_key"].

    NOTE(review): both sources ship inside the bundle being verified, so a
    forged bundle can carry its own key and still pass the signature checks.
    main() prints the key so it can be compared against one obtained out
    of band.
    """
    pk_path = bundle_dir / "public_key.hex"
    if pk_path.exists():
        return hex_to_bytes(pk_path.read_text(encoding="utf-8").strip())
    return hex_to_bytes(proof["public_key"])


def _print_timing(bundle_dir: Path) -> None:
    """Print the optional throughput summary from <bundle-dir>/timing.json.

    This output is informational only. A missing file prints nothing. An
    unreadable or malformed file prints a short note instead of raising, so
    it cannot turn a VALID verdict into a crash with a non-zero exit.
    """
    timing_path = bundle_dir / "timing.json"
    if not timing_path.exists():
        return
    try:
        with open(timing_path, encoding="utf-8") as f:
            t = json.load(f)
        print(f"Throughput: {t['trades']:,} records in {t['elapsed_ms']}ms"
              f" = {t['tps']:,} records/sec")

        hw = t.get("hardware")
        if hw:
            cores = hw.get("logical_cores", "?")
            print(f"  measured on: {hw.get('cpu_model','?')} "
                  f"({hw.get('arch','?')}, {cores} cores, "
                  f"{hw.get('os','?')} {hw.get('kernel','?')})")
        else:
            print("  measured on: (hardware identity not recorded — "
                  "legacy bundle; do not quote the TPS figure without context)")

        cfg = t.get("config")
        if cfg:
            print(f"  config:      partitions={cfg.get('partition_count','?')}, "
                  f"trades_per_block={cfg.get('trades_per_block','?')}, "
                  f"durability={cfg.get('durability_mode','?')}")
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
        # Best-effort by design: report the problem and let the verdict stand.
        print(f"  (timing.json could not be displayed: {exc!r})")
    print()


def main():
    """Verify a SecLog proof bundle and exit with status 0 (VALID) or 1.

    Command line: ``verify_proof.py <bundle-dir> [--tamper]``. The flag may
    appear before or after the directory. The bundle must contain
    proof.json; public_key.hex and timing.json are optional. Exit status 1
    is also used for usage errors and for bundles that cannot be loaded.

    Four checks are run, and all must pass:
      1. BLAKE3(record) walked up the inclusion proof equals the stated root.
      2. The Ed25519 signature over the stated root verifies.
      3. The Ed25519 anchor signature over the anchor payload verifies.
      4. The chain starts at the stated root and ends at the anchor root.

    With --tamper, one bit of one record byte is flipped before check 1, so
    the failure path can be demonstrated.
    """
    cli_args = sys.argv[1:]
    tamper = "--tamper" in cli_args
    positional = [arg for arg in cli_args if arg != "--tamper"]
    if not positional:
        print("Usage: python3 verify_proof.py <bundle-dir> [--tamper]")
        sys.exit(1)
    bundle_dir = Path(positional[0])

    try:
        with open(bundle_dir / "proof.json", encoding="utf-8") as f:
            proof = json.load(f)
        if not isinstance(proof, dict):
            raise ValueError("proof.json must contain a JSON object")
        public_key = _load_public_key(bundle_dir, proof)
        record = bytearray(hex_to_bytes(proof["record"]))
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # A bundle that cannot even be parsed gets a one-line diagnosis
        # rather than a traceback.
        print(f"Error: cannot load proof bundle {bundle_dir}: {exc!r}",
              file=sys.stderr)
        sys.exit(1)

    print()
    print("SecLog Proof Verifier")
    print("=" * 50)
    print("Libraries: blake3, pynacl (libsodium)")
    # Shown so the key can be checked against a trusted copy; it comes from
    # the bundle itself and is not independently trusted.
    print(f"Public key: {public_key.hex()}")
    print()

    if tamper:
        print(">>> TAMPER: flipping one byte in the record")
        if record:
            # Byte 12 normally; fall back to the last byte for short records
            # so --tamper never raises IndexError.
            record[min(12, len(record) - 1)] ^= 0x01
        else:
            print("    (record is empty; nothing to flip)")
        print()

    # ── Check 1: Hash the record, walk the inclusion proof ────────
    leaf_hash = blake3_hash(bytes(record))
    inclusion_proof = _first_field(proof, "proof", "merkle_proof", default=[])
    computed_root = verify_merkle_proof(leaf_hash, inclusion_proof)
    expected_root = hex_to_bytes(_first_field(proof, "root", "merkle_root"))
    root_ok = computed_root == expected_root

    print(f"1. Record hash + inclusion proof ({len(inclusion_proof)} levels):")
    print(f"   Computed root: {computed_root.hex()[:32]}...")
    print(f"   Expected root: {expected_root.hex()[:32]}...")
    print(f"   {'PASS' if root_ok else 'FAIL'}")

    # ── Check 2: Signature over the root ─────────────────────────
    # Signs the *stated* root; binding to the record comes from check 1.
    sig = hex_to_bytes(_first_field(proof, "signature", "block_signature"))
    sig_ok = verify_ed25519(public_key, sig, expected_root)

    print("2. Signature (Ed25519):")
    print(f"   {'PASS' if sig_ok else 'FAIL'}")

    # ── Check 3: Anchor signature ────────────────────────────────
    anchor_payload = hex_to_bytes(proof.get("anchor_payload", ""))
    anchor_sig = hex_to_bytes(proof.get("anchor_signature", ""))
    anchor_sig_ok = verify_ed25519(public_key, anchor_sig, anchor_payload)

    print("3. Anchor signature (Ed25519):")
    print(f"   {'PASS' if anchor_sig_ok else 'FAIL'}")

    # ── Check 4: Chain integrity ─────────────────────────────────
    # NOTE(review): only the two endpoints are compared. Intermediate links
    # are not re-derived, and anchor_root is not tied to anchor_payload
    # here, so edits to the middle of the chain go unnoticed by this check.
    chain = proof.get("chain") or []
    chain_ok = False
    if chain:
        anchor_root = hex_to_bytes(
            _first_field(proof, "anchor_root", "anchor_merkle_root"))
        chain_ok = (hex_to_bytes(chain[0]) == expected_root
                    and hex_to_bytes(chain[-1]) == anchor_root)

    print(f"4. Chain integrity ({len(chain)} links):")
    print(f"   {'PASS' if chain_ok else 'FAIL'}")

    # ── Verdict ──────────────────────────────────────────────────
    all_pass = root_ok and sig_ok and anchor_sig_ok and chain_ok
    print()
    print("=" * 50)
    if all_pass:
        print("VERDICT: VALID")
        print()
        print("This record is cryptographically proven to exist in the")
        print("log. Any modification to the record, the chain, or the")
        print("signatures would break the proof.")
    else:
        print("VERDICT: INVALID")
        print()
        if not root_ok:
            print("Record does not match the root hash.")
        if not sig_ok:
            print("Signature verification failed.")
        if not anchor_sig_ok:
            print("Anchor signature failed.")
        if not chain_ok:
            print("Chain integrity check failed.")
    print()

    # Informational only; never affects the exit status.
    _print_timing(bundle_dir)

    sys.exit(0 if all_pass else 1)


if __name__ == "__main__":
    # Script entry point only (not on import); main() ends via sys.exit().
    main()
