import hmac
from hashlib import sha256
from Crypto.Cipher import AES
from sage.all import *

# NIST P-384 domain parameters, written in 32-bit groups (most significant
# word first) so individual words can be compared against a reference table.
# Curve equation: y^2 = x^3 + a*x + b over GF(p).

# Field prime.
p = 0xffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_fffffffe_ffffffff_00000000_00000000_ffffffff
# Curve coefficients.
a = -3
b = 0xb3312fa7_e23ee7e4_988e056b_e3f82d19_181d9c6e_fe814112_0314088f_5013875a_c656398d_8a2ed19d_2a85c8ed_d3ec2aef
# Order of the base point; all scalar arithmetic in this script is mod q.
q = 0xffffffff_ffffffff_ffffffff_ffffffff_ffffffff_ffffffff_c7634d81_f4372ddf_581a0db2_48b0a77a_ecec196a_ccc52973
# Affine coordinates of the base point G.
Gx = 0xaa87ca22_be8b0537_8eb1c71e_f320ad74_6e1d3b62_8ba79b98_59f741e0_82542a38_5502f25d_bf55296c_3a545e38_72760ab7
Gy = 0x3617de4a_96262c6f_5d9e98bf_9292dc29_f8f41dbd_289a147c_e9da3113_b5f0b8c0_0a60b1ce_1d7e819d_7a431d7c_90ea0e5f

# Signed messages. msgs[i] is the message that belongs to sigs[i]; each one is
# fed to HMAC-SHA256 as the message when the hash scalar e_i is rebuilt.
msgs = [
    b"https://www.youtube.com/watch?v=LaX6EIkk_pQ",
    b"https://www.youtube.com/watch?v=wK4wA0aKvg8",
    b"https://www.youtube.com/watch?v=iq90nHs3Gbs",
    b"https://www.youtube.com/watch?v=zTKADhU__sw",
]

# Signatures as (r, s) pairs, in the same order as msgs.
#   r: integer that the verification step compares against x * p + y for the
#      affine coordinates (x, y) of a curve point; its byte serialisation is
#      also used as the HMAC key when rebuilding e_i.
#   s: scalar used as k = s - sk * e (mod q) when checking a key candidate.
sigs = [(317707421133410288073354603009480426136391906002873302709570879761947103070512898051132583840618463139472027601216698251294206460344755339051109898589809987983731707077909099505833365567522347006453766545663380230105595126817790425, 25185752159924706126981435669717936861361993674900106138337831137838509453749313533989197233649309651483579988978205), (417548456675579988606680466439690234874946492911623920447331037240230655879606626325624623314611471522814787475988129078726743347417903386362824681134780863810523742180718053363084828145812067731683272119151061828749117659255650820, 27618563118772187320593702066291845973666620541831283288991142064228070314197536489147588491763843793593821643513457), (703771273054730080235579285501232710659154148145979519264450072512823561624248636822569827736905476306443746390214567198923437156846958456303186787370323078966806939434118158768394748234214487029382926999880135374613932395712372460, 27052092405825396792237011211691900251888872753276208811631357208317438773416505653305767076226992282260977625878007), (821717323558426535455119744526279609022144869806906586662554363968363839151910768914318502227461974453838258550953434850776924606792184210954238562503515009237179979646111655773804054528212491391076376250546737439142144165942539844, 28870411728276849847003745583242490365442899058004875752358198407125701328587711166784961247940279464305857022011977)]
# Encrypted payload: the first 8 bytes are passed to AES-CTR as the nonce and
# the remaining bytes are the ciphertext that gets decrypted.
ct = b'iXm\x982\xc5\xf23\x85\x88\x91\x0c\x7f\xdc\x1b,\x1b\x82\x9d\xcd\x00 BWn\xad\n\xc3`\xe7\x8e\xfc`%\x9cQ\x12E\x97\x97\xa5\xd5t\x8b\x87v\xb4\xcf\x8d'

# Imported once here instead of on every iteration of the solver loop.
from fpylll import IntegerMatrix, Enumeration, GSO

# Byte-order guesses to try, as (r_endian, e_endian): r_endian is the byte
# order used to serialise r into the HMAC key, e_endian the byte order used
# to read the HMAC digest back as an integer.
configs = [
    ('big', 'big'),
    ('little', 'big'),
    ('big', 'little'),
    ('little', 'little'),
]

# Weight of the embedding coordinate.
# NOTE(review): the attack presumes every nonce is below 2**256, so that each
# nonce difference is below M in absolute value -- confirm against the signer.
M = 2**256

# Upper bound on the number of lattice points collected per enumeration.
# fpylll's Enumeration keeps only the single best vector unless told
# otherwise, and the shortest vector need not be the one carrying +-M.
MAX_SOLUTIONS = 1000

# Squared enumeration radius. The target is (dk_0, ..., dk_{n-1}, +-M) with
# every |dk_i| < M, so its squared norm is below (n + 1) * M**2.
radius = float(len(sigs) * M * M)

# Curve and base point are loop-invariant; build them once.
E = EllipticCurve(GF(p), [a, b])
G = E(Gx, Gy)


def _challenge_values(r_endian, e_endian):
    """Return the hash scalars e_i (mod q) for one byte-order guess.

    e_i is HMAC-SHA256 keyed with r_i serialised to 1337 bytes in
    ``r_endian`` order, applied to msgs[i], read back in ``e_endian`` order.
    Raises OverflowError if some r does not fit in 1337 bytes.
    """
    values = []
    for msg, (r_int, _s) in zip(msgs, sigs):
        r_bytes = int(r_int).to_bytes(1337, r_endian)
        digest = hmac.new(r_bytes, msg, sha256).digest()
        values.append(int.from_bytes(digest, e_endian) % q)
    return values


def _embedding_basis(de, ds):
    """Return the generator matrix of the embedding lattice.

    Rows are q * I (padded with a zero), (de | 0) and (ds | M), so the
    lattice contains (ds - sk * de mod q | M) = (dk | M).
    """
    n = len(de)
    rows = [[q if col == row else 0 for col in range(n)] + [0]
            for row in range(n)]
    rows.append(list(de) + [0])
    rows.append(list(ds) + [M])
    return Matrix(ZZ, rows)


def _short_lattice_vectors(basis, sq_radius):
    """Return lattice vectors of squared norm at most ``sq_radius``.

    ``basis`` must have linearly independent rows. At most MAX_SOLUTIONS
    vectors are returned. fpylll raises if the ball contains no vector.
    """
    A = IntegerMatrix.from_matrix(basis)
    gso = GSO.Mat(A)
    gso.update_gso()
    enum = Enumeration(gso, nr_solutions=MAX_SOLUTIONS)
    found = enum.enumerate(0, A.nrows, sq_radius, 0)
    # Each solution is a coefficient vector with respect to the basis rows.
    return [vector(ZZ, [int(c) for c in coeffs]) * basis
            for _dist, coeffs in found]


def _signature_matches(sk, e0):
    """Return True if ``sk`` reproduces r of the first signature."""
    r_target, s0 = sigs[0]
    k0 = (s0 - sk * e0) % q
    R0 = k0 * G
    if R0.is_zero():
        # The point at infinity has no affine coordinates.
        return False
    x, y = R0.xy()
    return int(x) * p + int(y) == r_target


def _decrypt_flag(sk):
    """Return the plaintext if a key derived from ``sk`` yields a flag.

    The AES key is the low 128 bits of ``sk`` in either byte order; the
    first 8 bytes of ``ct`` are the CTR nonce. Returns None when neither
    key produces a plaintext containing a flag marker.
    """
    low = int(sk & ((1 << 128) - 1))
    for key_endian in ('big', 'little'):
        key_bytes = low.to_bytes(16, key_endian)
        try:
            cipher = AES.new(key_bytes, AES.MODE_CTR, nonce=ct[:8])
            pt = cipher.decrypt(ct[8:])
        except ValueError:
            continue
        if b'flag' in pt or b'ctf' in pt or b'CTF' in pt:
            return pt
    return None


for r_endian, e_endian in configs:
    print(f"Testing r={r_endian}, e={e_endian}")

    # Reconstruct e values
    try:
        e_vals = _challenge_values(r_endian, e_endian)
    except (OverflowError, ValueError) as err:
        print(f"Error computing e: {err}")
        continue
    s_vals = [s for _r, s in sigs]

    # Differences of consecutive signatures: ds_i = dk_i + sk * de_i (mod q).
    de = [(x - y) % q for x, y in zip(e_vals, e_vals[1:])]
    ds = [(x - y) % q for x, y in zip(s_vals, s_vals[1:])]

    # sk is solved from the first difference, which needs de[0] invertible.
    try:
        de0_inv = inverse_mod(de[0], q)
    except ZeroDivisionError:
        print("Enumeration failed: de[0] is not invertible mod q")
        continue

    B = _embedding_basis(de, ds)

    print("Running BKZ-60...")
    reduced_basis = B.BKZ(block_size=60)
    # The generators are linearly dependent, so the reduced matrix contains
    # zero rows; drop them so the GSO is built on a genuine basis.
    B_sage = Matrix(
        ZZ, [row for row in reduced_basis.rows() if not row.is_zero()]
    )

    # Enumeration approach
    print("Running Enumeration...")
    try:
        print(f"Enumerating with radius {radius}...")
        candidates = _short_lattice_vectors(B_sage, radius)
    except Exception as exc:
        print(f"Enumeration failed: {exc}")
        continue
    print(f"Found {len(candidates)} vectors in radius")

    for vec in candidates:
        if abs(vec[-1]) != M:
            continue
        # Normalise to (-dk, -M) so that sk = (ds_0 + vec_0) / de_0.
        if vec[-1] == M:
            vec = -vec

        sk_candidate = (ds[0] + vec[0]) * de0_inv % q

        # Verify signature
        if not _signature_matches(sk_candidate, e_vals[0]):
            continue
        print(f"Signature verified! sk: {sk_candidate}")

        # Decrypt
        pt = _decrypt_flag(sk_candidate)
        if pt is not None:
            print(f"SUCCESS! Flag: {pt}")
            # Raised outside any try block so the exit cannot be swallowed.
            raise SystemExit(0)

