from ideal_one_way import *

# PRNG
def make_prng(key):
    """Return a keyed, counter-based pseudo-random generator as a callable.

    The returned zero-argument function keeps a private counter that starts
    at 1 and is incremented *before* each use, so its n-th call returns
    ``one_way(n + 1, key)``. Assuming ``one_way`` is deterministic, two
    generators built from the same `key` yield identical sequences, which is
    what lets sender and receiver share a keystream.
    """
    counter = 1

    def draw():
        nonlocal counter
        counter += 1
        return one_way(counter, key)

    return draw

def a_stream_cipher_out(key, plaintexts=(5, 6, 7, 8)):
    """Encrypt `plaintexts` with a keyed PRNG keystream and broadcast them to Bob.

    The message count is sent first, masked with one keystream value. Each
    plaintext is then XORed with the next two keystream values. A receiver
    must draw from ``make_prng(key)`` in exactly this order to decrypt.

    Args:
        key: shared key used to seed ``make_prng``.
        plaintexts: sized sequence of values to send; defaults to the
            original demo message. A tuple default avoids Python's
            shared-mutable-default pitfall.
    """
    keystream = make_prng(key)  # named so the builtin ``next`` is not shadowed

    broadcast("alice", "bob", xor(len(plaintexts), keystream()))
    for plaintext in plaintexts:
        # Two keystream draws per value; the order must match the receiver's.
        broadcast("alice", "bob", xor(xor(plaintext, keystream()), keystream()))


def b_stream_cipher_in(key):
    """Receive Alice's stream-cipher broadcast under `key` and print the values.

    Mirrors ``a_stream_cipher_out``: the first keystream value unmasks the
    message count, then every ciphertext is unmasked with the next two
    keystream values, drawn in the same order the sender used.
    """
    keystream = make_prng(key)
    count = xor(receive("alice", "bob"), keystream())
    # Assuming xor(x, k) is self-inverse in k, re-applying the sender's two
    # keystream values (same order) recovers each plaintext.
    recovered = [
        xor(xor(receive("alice", "bob"), keystream()), keystream())
        for _ in range(count)
    ]
    print(recovered)

run(
    lambda: a_stream_cipher_out(10),
    lambda: b_stream_cipher_in(10),
)

def a_ecb_out(key, plaintexts=(5, 6, 7, 8)):
    """Encrypt each value independently (ECB-style) under `key` and send to Bob.

    The encrypted count of `plaintexts` is sent first, followed by one
    ciphertext per plaintext, all broadcast from "alice" to "bob".

    Args:
        key: shared key passed to ``inverse_pair`` to obtain the encryption half.
        plaintexts: sized sequence of values to send; defaults to the original
            demo message. A tuple default avoids shared mutable state.

    NOTE: every value, the count included, goes through the same fixed
    ``enc``. Assuming ``one_way`` is deterministic, equal plaintexts therefore
    produce equal ciphertexts. That leakage is the classic weakness of ECB.
    """
    # The pair must depend only on `key` (Bob derives it independently), so
    # compute it once instead of on every encryption. Only the first half is
    # used for encryption.
    enc_key, _ = inverse_pair(key)

    def enc(n):
        return one_way(n, enc_key)

    broadcast("alice", "bob", enc(len(plaintexts)))
    for plaintext in plaintexts:
        broadcast("alice", "bob", enc(plaintext))

def b_ecb_in(key):
    """Receive Alice's ECB broadcast under `key` and print the decrypted values.

    The first message is the encrypted count. Each following message is one
    independently encrypted value.
    """
    # Derive the pair once rather than per message; only the second half is
    # needed to decrypt.
    _, dec_key = inverse_pair(key)

    def dec(m):
        # Presumably undoes one_way(n, a) for the matching first half `a` of
        # the pair; the demo relies on this round-trip.
        return one_way(m, dec_key)

    count = dec(receive("alice", "bob"))
    plaintexts = [dec(receive("alice", "bob")) for _ in range(count)]
    print(plaintexts)

# ECB demo: both parties pass key 42 to inverse_pair, so Bob's decryption half
# matches Alice's encryption half. `run` comes from ideal_one_way (not shown).
run(
    lambda: a_ecb_out(42),
    lambda: b_ecb_in(42),
)
