#!/usr/bin/env python3

# Package name: picks the output directory (src/<project>/) and is the leading
# component of every C symbol prefix emitted into the generated classes.
project = 'ntruprime'

# One row per supported parameter set:
# (name, public-key bytes, secret-key bytes, ciphertext bytes).
_PARAMETER_SETS = (
    ('sntrup653', 994, 1518, 897),
    ('sntrup761', 1158, 1763, 1039),
    ('sntrup857', 1322, 1999, 1184),
    ('sntrup953', 1505, 2254, 1349),
    ('sntrup1013', 1623, 2417, 1455),
    ('sntrup1277', 2067, 3059, 1847),
)

# Maps each generated class name to its record; the record repeats the name
# (used to build the C symbol prefix) next to the three buffer sizes.
primitives = {
    label: {'name': label, 'pkbytes': pk_len, 'skbytes': sk_len, 'cbytes': ct_len}
    for label, pk_len, sk_len, ct_len in _PARAMETER_SETS
}

# Fixed head of the generated src/<project>/kem.py: the imports plus the _KEM
# base class, which looks up the '<prefix>_keypair', '<prefix>_enc' and
# '<prefix>_dec' symbols on `_lib` and calls them through ctypes buffers.
# _KEM itself defines neither `_prefix` nor the *BYTES size constants; they
# come from the per-primitive subclasses appended at the bottom of this script.
# This is a plain (non-f) string and is written out verbatim, so the '%s'
# placeholders are evaluated by the generated module, not by this script.
kem="""from typing import Tuple as _Tuple
import ctypes as _ct
from ._lib import _lib, _check_input


class _KEM:
    def __init__(self) -> None:
        '''
        '''
        self._c_keypair = getattr(_lib, '%s_keypair' % self._prefix)
        self._c_keypair.argtypes = [_ct.c_char_p, _ct.c_char_p]
        self._c_keypair.restype = None
        self._c_enc = getattr(_lib, '%s_enc' % self._prefix)
        self._c_enc.argtypes = [_ct.c_char_p, _ct.c_char_p, _ct.c_char_p]
        self._c_enc.restype = None
        self._c_dec = getattr(_lib, '%s_dec' % self._prefix)
        self._c_dec.argtypes = [_ct.c_char_p, _ct.c_char_p, _ct.c_char_p]
        self._c_dec.restype = None

    def keypair(self) -> _Tuple[bytes, bytes]:
        '''
        Keypair - randomly generates secret key 'sk' and corresponding public key 'pk'.
        Returns:
            pk (bytes): public key
            sk (bytes): secret key
        '''
        pk = _ct.create_string_buffer(self.PUBLICKEYBYTES)
        sk = _ct.create_string_buffer(self.SECRETKEYBYTES)
        self._c_keypair(pk, sk)
        return pk.raw, sk.raw

    def enc(self, pk: bytes) -> _Tuple[bytes, bytes]:
        '''
        Encapsulation - randomly generates a ciphertext 'c' and the corresponding session key 'k' given Alice's public key 'pk'.
        Parameters:
            pk (bytes): public key
        Returns:
            c (bytes): ciphertext
            k (bytes): session key
        '''
        _check_input(pk, self.PUBLICKEYBYTES, 'pk')
        c = _ct.create_string_buffer(self.CIPHERTEXTBYTES)
        k = _ct.create_string_buffer(self.BYTES)
        pk = _ct.create_string_buffer(pk)
        self._c_enc(c, k, pk)
        return c.raw, k.raw

    def dec(self, c: bytes, sk: bytes) -> bytes:
        '''
        Decapsulation - given Alice's secret key 'sk' computes the session key 'k' corresponding to a ciphertext 'c'.
        Parameters:
            c (bytes): ciphertext
            sk (bytes): secret key
        Returns:
            k (bytes): session key
        '''
        _check_input(c, self.CIPHERTEXTBYTES, 'c')
        _check_input(sk, self.SECRETKEYBYTES, 'sk')
        k = _ct.create_string_buffer(self.BYTES)
        c = _ct.create_string_buffer(c)
        sk = _ct.create_string_buffer(sk)
        self._c_dec(k, c, sk)
        return k.raw
"""


# Regenerate src/<project>/kem.py: the static `kem` template first, then one
# _KEM subclass per entry of `primitives`.
# The encoding is pinned so the generated source does not depend on the
# locale's default encoding; newline handling is left at the text-mode default.
with open(f'src/{project}/kem.py', 'w', encoding='utf-8') as f:
    f.write(kem)
    for primitive, params in primitives.items():
        # The dict key becomes the class name, while the C symbol prefix is
        # built from the record's own 'name' field.
        # The final line rebinds the class name to a single instance, so the
        # generated module exposes ready-to-use objects rather than classes.
        f.write(
            '\n\n'
            f'class {primitive}(_KEM):\n'
            f'    PUBLICKEYBYTES = {params["pkbytes"]}\n'
            f'    SECRETKEYBYTES = {params["skbytes"]}\n'
            f'    CIPHERTEXTBYTES = {params["cbytes"]}\n'
            '    BYTES = 32\n'
            f'    _prefix = "{project}_kem_{params["name"]}"\n\n\n'
            f'{primitive} = {primitive}()\n'
        )
