diff --git a/bra12.py b/bra12.py new file mode 100644 index 0000000..6fb3f80 --- /dev/null +++ b/bra12.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass + +from utils.bra import keygen, encrypt, decrypt, add, mult + + +@dataclass +class Context: + n: int + q: int + L: int + + +class Bra12: + def __init__(self, ctx): + self.ctx = ctx + pk, evks, sk = keygen(ctx.L, ctx.n, ctx.q) + self.pk = pk + self.evks = evks + self.sk = sk + + def encrypt(self, message): + ciphertext = encrypt(self.pk, message, self.ctx.n, self.ctx.q) + return Ciphertext(ciphertext, self.evks, self.ctx.q) + + def decrypt(self, ciphertext): + return decrypt(self.sk, ciphertext.inner, self.ctx.q) + + +class Ciphertext: + def __init__(self, inner, evks, q): + self.inner = inner + self.evks = evks + self.q = q + + def __add__(self, other): + evk = self.evks[-1] + return Ciphertext( + add(evk, self.inner, other.inner, self.q), + # Remove the key we've used used above (NOTE: An operation might not equal one circuit level) + self.evks[:-1], + self.q + ) + + def __mul__(self, other): + evk = self.evks[-1] + return Ciphertext( + mult(evk, self.inner, other.inner, self.q), + # Remove the key we've used used above (NOTE: An operation might not equal one circuit level) + self.evks[:-1], + self.q + ) diff --git a/main.py b/main.py new file mode 100644 index 0000000..787c61c --- /dev/null +++ b/main.py @@ -0,0 +1,64 @@ +import sys + +from bra12 import Context, Bra12 +from utils.core import tests as tests_utils_core +from utils.regev import tests as tests_utils_regev +from utils.bra import tests as tests_utils_bra + + +def main(): + n = 3 + q = 2 ** 16 + L = 1 + ctx = Context(n, q, L) + bra = Bra12(ctx) + # Addition 1 + m_1 = 0 + m_2 = 0 + c_1 = bra.encrypt(m_1) + c_2 = bra.encrypt(m_2) + res = bra.decrypt(c_1 + c_2) + print(f'{m_1} + {m_2} = {res}') + assert res == (m_1 + m_2) % 2 + # Addition 2 + m_1 = 0 + m_2 = 1 + c_1 = bra.encrypt(m_1) + c_2 = bra.encrypt(m_2) + res = bra.decrypt(c_1 + c_2) + print(f'{m_1} + {m_2} = {res}') + assert res == (m_1 + m_2) % 2 + # Multiplication 1 + m_1 = 1 + m_2 = 0 + c_1 = bra.encrypt(m_1) + c_2 = bra.encrypt(m_2) + res = bra.decrypt(c_1 * c_2) + print(f'{m_1} * {m_2} = {res}') + assert res == (m_1 * m_2) % 2 + # Multiplication 2 + m_1 = 1 + m_2 = 1 + c_1 = bra.encrypt(m_1) + c_2 = bra.encrypt(m_2) + res = bra.decrypt(c_1 * c_2) + print(f'{m_1} * {m_2} = {res}') + assert res == (m_1 * m_2) % 2 + # Multiplication 3 + m_1 = 0 + m_2 = 0 + c_1 = bra.encrypt(m_1) + c_2 = bra.encrypt(m_2) + res = bra.decrypt(c_1 * c_2) + print(f'{m_1} * {m_2} = {res}') + assert res == (m_1 * m_2) % 2 + + +if __name__ == '__main__': + if len(sys.argv) == 1: + main() + if len(sys.argv) > 1 and 'test' in sys.argv[1]: + tests_utils_core() + tests_utils_regev() + tests_utils_bra() + print('Success: 3/3')