Contents

# Define variables
n = 3
N = 2^n
X = var(['X{}'.format(i) for i in range(n)]) # Define the variables X_0, X-1,.....
u = var(['u{}'.format(i) for i in range(n)]) # Define the variables u_0, u-1,.....

# Function to get binary representation as a list of bits
def bits(i,n):
    return list(map(int, format(i,'0{}b'.format(n))))

def bits_reverse(i, n):
    bits_list = bits(i, n)
    return list(reversed(bits_list))

# Define the eq_tilde function
def eq_tilde(bits_i, u_vector):
    result=1
    for bit,u in zip(bits_i,u_vector):
        result *= (1-bit)*(1-u) + bit*u
    return result

# Coefficients of the polynomial
a = [var('a{}'.format(i)) for i in range(N)] # Coefficients a_0, a_1, ...., a_(N-1)

# MLE polynomial
f_tilde = sum(a[i]*eq_tilde(bits(i,n), u) for i in range(N))
show(f_tilde)

# Generate all combinations of (1-u[i]) and u[i] based on binary representation
def generate_c_vector(n, u):
    c_vector = []
    for i in range(2^n):  # Loop over all binary numbers from 0 to 2^n - 1
        binary = list(map(int, format(i, f'0{n}b')))  # Binary representation of i
        binary_reverse = list(reversed(binary))  # Reverse the binary representation
        product = 1
        for j, bit in enumerate(binary_reverse):
            if bit == 0:
                product *= (1 - u[j])  # Use (1 - u[j]) for 0
            else:
                product *= u[j]  # Use u[j] for 1
        c_vector.append(product)
    return c_vector

# Compute the c vector
c = generate_c_vector(n, u)

# Display the c vector
show(c)

# SageMath Implementation for Polynomial Encoding using Lagrange Basis

# Define the finite field and subgroup H
p = 17  # Example prime (p must be chosen appropriately)
F = GF(p)  # Finite field F_p
omega = F(3)  # Example primitive 8th root of unity in F_p
H = [omega^i for i in range(8)]  # Multiplicative subgroup H of size 8

# Define Lagrange Basis Polynomials
def lagrange_basis(H, X):
    basis = []
    for i in range(len(H)):
        Li = 1
        for j in range(len(H)):
            if i != j:
                Li *= (X - H[j]) / (H[i] - H[j])  # Lagrange basis polynomial
        basis.append(Li)
    return basis

# Symbolic variable for the polynomial
X = polygen(F, 'X')

# Compute the Lagrange basis polynomials
L = lagrange_basis(H, X)
show(L)

# Vector c as computed earlier
c = [F(c_val) for c_val in [1, 2, 3, 4, 5, 6, 7, 8]]  # Example values for c vector
show(c)
# Compute the polynomial encoding c(X)
c_X = sum(c[i] * L[i] for i in range(len(c)))

# Compute the polynomial encoding c(X) step by step
# c_X = 0  # Start with an empty polynomial
# print("Building c(X) step by step:")
# for i in range(len(c)):
#     show(c[i])
#     show(L[i])
#     term = c[i] * L[i]  # Compute the current term
#     show(term)
#     c_X += term  # Add the current term to the polynomial
#     show(term)
#     '\n'
# #     print(f"Step {i+1}: Adding term {term}")
# #     print(f"Partial sum: {c_X}\n")

# # Final polynomial
# print("Final polynomial c(X):")
# show(c_X)

# Display the resulting polynomial c(X)
show(c_X)

# Verification: Check that c(omega^i) = c_i
verification = [c_X(H[i]) == c[i] for i in range(len(H))]
show(verification)


# Define subsets H0, H1, H2 based on the Group Tower relationship
H0 = [H[0]]  # {1}
H1 = [H[0], H[4]]  # {1, ω^4}
H2 = [H[0], H[2], H[4], H[6]]  # {1, ω^2, ω^4, ω^6}
H3 = H  # Full set {1, ω, ω^2, ..., ω^7}

# Define the vanishing polynomials v_H(X) and v_Hi(X)
X = polygen(F, 'X')  # Define X as a polynomial variable

def vanishing_polynomial(domain):
    """Compute the vanishing polynomial for a given domain."""
    v = 1
    for alpha in domain:
        v *= (X - alpha)
    return v

# Compute vanishing polynomials
v_H = vanishing_polynomial(H3)  # Full vanishing polynomial for H
v_H0 = vanishing_polynomial(H0)
v_H1 = vanishing_polynomial(H1)
v_H2 = vanishing_polynomial(H2)

# Compute the selector polynomials s_0(X), s_1(X), s_2(X)
s0 =  v_H/v_H0
s1 = v_H/v_H1
s2 =  v_H/v_H2

# Display the selector polynomials
show(s0)
show(s1)
show(s2)
\(\displaystyle -a_{0} {\left(u_{0} - 1\right)} {\left(u_{1} - 1\right)} {\left(u_{2} - 1\right)} + a_{4} u_{0} {\left(u_{1} - 1\right)} {\left(u_{2} - 1\right)} + a_{2} {\left(u_{0} - 1\right)} u_{1} {\left(u_{2} - 1\right)} - a_{6} u_{0} u_{1} {\left(u_{2} - 1\right)} + a_{1} {\left(u_{0} - 1\right)} {\left(u_{1} - 1\right)} u_{2} - a_{5} u_{0} {\left(u_{1} - 1\right)} u_{2} - a_{3} {\left(u_{0} - 1\right)} u_{1} u_{2} + a_{7} u_{0} u_{1} u_{2}\)
\(\displaystyle \left[-{\left(u_{0} - 1\right)} {\left(u_{1} - 1\right)} {\left(u_{2} - 1\right)}, u_{0} {\left(u_{1} - 1\right)} {\left(u_{2} - 1\right)}, {\left(u_{0} - 1\right)} u_{1} {\left(u_{2} - 1\right)}, -u_{0} u_{1} {\left(u_{2} - 1\right)}, {\left(u_{0} - 1\right)} {\left(u_{1} - 1\right)} u_{2}, -u_{0} {\left(u_{1} - 1\right)} u_{2}, -{\left(u_{0} - 1\right)} u_{1} u_{2}, u_{0} u_{1} u_{2}\right]\)
\(\displaystyle \left[14 X^{7} + 11 X^{6} + X^{5} + 5 X^{4} + 3 X^{3} + 13 X^{2} + 10 X + 12, X^{7} + 4 X^{6} + 4 X^{5} + 5 X^{4} + 10 X^{3} + 4 X^{2} + 13 X + 10, 9 X^{7} + 5 X^{6} + 7 X^{5} + 6 X^{3} + 7 X^{2} + 4 X + 13, X^{7} + 11 X^{6} + 10 X^{4} + 10 X^{3} + 6 X^{2} + 10 X + 3, 5 X^{7} + 2 X^{6} + 3 X^{5} + 4 X^{4} + 10 X^{3} + 5 X + 5, 3 X^{7} + X^{6} + 15 X^{5} + 3 X^{4} + 7 X^{2} + 4 X + 1, 14 X^{7} + 3 X^{6} + X^{5} + 2 X^{4} + 11 X^{3} + 5 X^{2} + 4 X + 11, 4 X^{7} + 14 X^{6} + 3 X^{5} + 5 X^{4} + X^{3} + 9 X^{2} + X + 14\right]\)
\(\displaystyle \left[1, 2, 3, 4, 5, 6, 7, 8\right]\)
\(\displaystyle 16 X^{7} + 6 X^{6} + 13 X^{5} + 11 X^{4} + 12 X^{3} + 11 X^{2} + 3 X + 14\)
\(\displaystyle \left[\mathrm{True}, \mathrm{True}, \mathrm{True}, \mathrm{True}, \mathrm{True}, \mathrm{True}, \mathrm{True}, \mathrm{True}\right]\)
\(\displaystyle X^{7} + 2 X^{6} + 11 X^{5} + 4 X^{4} + 16 X^{3} + 7 X^{2} + 8 X + 13\)
\(\displaystyle X^{6} + 15 X^{5} + 2 X^{4} + 13 X^{3} + 15 X^{2} + 15 X + 16\)
\(\displaystyle X^{4} + 5 X^{3} + 4 X^{2} + 12 X + 1\)
# Define the finite field and its elements
p = 17  # Prime modulus
F = GF(p)  # Finite field F_p
generator = F.multiplicative_generator() # the root unity of F_p
print("the root unity of F_p: ", generator)

# find the order of the generator of F_p
def find_order(generator, field):
    order = 1
    power = generator
    while power != 1:
        power *= generator
        order += 1
    return order

# return the root of unity in subgroup H, where |H| = N
def root_of_H(generator, generator_order, N):
    return generator**(generator_order / N)

generator_order = find_order(generator, F)
omega = root_of_H(generator, generator_order, 8)
print("the root unity of subgroup H: ", omega)
H = [omega^i for i in range(8)]  # Subgroup H of size 8 (2^3)
show(H)

# Define variables
X = polygen(F, 'X')  # Polynomial variable
N = len(H)  # Size of the group H

# Define vanishing polynomial of H
def vanishing_polynomial(H):
    v_H = F(1)
    for h in H:
        v_H *= (X - h)
    return v_H

v_H = vanishing_polynomial(H)

# Define polynomials a(X) and c(X)
a_coeffs = [F.random_element() for _ in range(N)]  # Random coefficients for a(X)
c_coeffs = [F.random_element() for _ in range(N)]  # Random coefficients for c(X)

a_X = sum(a_coeffs[i] * X^i for i in range(N))  # a(X)
c_X = sum(c_coeffs[i] * X^i for i in range(N))  # c(X)

# Compute P(X) = a(X) * c(X)
P_X = a_X * c_X

print("P_X: ", P_X)

# Compute the decomposition of P(X)
q_X = P_X // v_H  # Quotient
r_X = P_X % v_H   # Remainder
v = sum(P_X(h) for h in H)  # v = sum of P(ω) over H
g_X = (r_X - v / N) // X    # g(X) 

# Verify the decomposition
decomposed_P_X = q_X * v_H + X * g_X + (v / N)
assert P_X == decomposed_P_X  # Ensure the decomposition holds

# Verification at a challenge ζ
zeta = F.random_element()  # Random challenge ζ
lhs = a_X(zeta) * c_X(zeta)  # a(ζ) * c(ζ)
rhs = q_X(zeta) * v_H(zeta) + zeta * g_X(zeta) + (v / N)  # q(ζ) * v_H(ζ) + ζ * g(ζ) + v/N

# Output results
print("a(X):", a_X)
print("c(X):", c_X)
print("P(X):", P_X)
print("q(X):", q_X)
print("g(X):", g_X)
print("v:", v)
print("Verification at ζ:")
print("LHS (a(ζ) * c(ζ)):", lhs)
print("RHS (q(ζ) * v_H(ζ) + ζ * g(ζ) + v/N):", rhs)
assert lhs == rhs, "Verification failed!"
the root unity of F_p:  3
the root unity of subgroup H:  9
\(\displaystyle \left[1, 9, 13, 15, 16, 8, 4, 2\right]\)
P_X:  15*X^14 + 4*X^13 + 15*X^12 + 11*X^11 + X^10 + 15*X^9 + 11*X^8 + 3*X^7 + 14*X^6 + 14*X^5 + 3*X^4 + 6*X^3 + 5*X^2 + 7*X + 13
a(X): 10*X^7 + 7*X^6 + 11*X^5 + 15*X^4 + 9*X^3 + 10*X^2 + 15*X + 4
c(X): 10*X^7 + 7*X^6 + 6*X^5 + 15*X^4 + 6*X^3 + 15*X^2 + 14*X + 16
P(X): 15*X^14 + 4*X^13 + 15*X^12 + 11*X^11 + X^10 + 15*X^9 + 11*X^8 + 3*X^7 + 14*X^6 + 14*X^5 + 3*X^4 + 6*X^3 + 5*X^2 + 7*X + 13
q(X): 15*X^6 + 4*X^5 + 15*X^4 + 11*X^3 + X^2 + 15*X + 11
g(X): 3*X^6 + 12*X^5 + X^4 + X^3 + 6*X + 5
v: 5
Verification at ζ:
LHS (a(ζ) * c(ζ)): 15
RHS (q(ζ) * v_H(ζ) + ζ * g(ζ) + v/N): 15
# Define the finite field and its elements
p = 17  # Prime modulus
F = GF(p)  # Finite field F_p
omega = F(9)  # Primitive 8th root of unity in F_p
H = [omega^i for i in range(8)]  # Subgroup H of size 8 (2^3)

# Define the vanishing polynomial for H
def vanishing_polynomial(H):
    v_H = F(1)
    for h in H:
        v_H *= (X - h)
    return v_H

# Lagrange basis polynomials
def lagrange_basis(H, X):
    basis = []
    for i in range(len(H)):
        Li = F(1)
        for j in range(len(H)):
            if i != j:
                Li *= (X - H[j]) / (H[i] - H[j])
        basis.append(Li)
    return basis

# Define variables
X = polygen(F, 'X')  # Polynomial variable
N = len(H)

# Random coefficients for a(X) and c(X)
a_coeffs = [F.random_element() for _ in range(N)]
c_coeffs = [F.random_element() for _ in range(N)]


# Construct a(X) and c(X)
L = lagrange_basis(H, X)
a_X = sum(a_coeffs[i] * L[i] for i in range(N))
c_X = sum(c_coeffs[i] * L[i] for i in range(N))

# Compute the auxiliary vector z_i
z = [a_coeffs[0] * c_coeffs[0]]  # z_0
for i in range(1, N):
    z.append(z[i-1] + a_coeffs[i] * c_coeffs[i])

# Compute z(X) using Lagrange basis polynomials
z_X = sum(z[i] * L[i] for i in range(N))

# Compute the final sum v
v = z[-1]

# Polynomial constraints
# Constraint 1: L0(X) * (z(X) - a(X) * c_0)
constraint1 = L[0] * (z_X - a_X * c_coeffs[0])

# Constraint 2: (X - 1) * (z(X) - z(ω^{-1} * X) - a(X) * c(X))
omega_inv = omega^(-1)
z_shifted = z_X.subs(X=omega_inv * X)
constraint2 = (X - 1) * (z_X - z_shifted - a_X * c_X)

# Constraint 3: LN-1(X) * (z(X) - v)
constraint3 = L[-1] * (z_X - v)

# Display the constraints
print("Constraint 1:")
show(constraint1)

print("Constraint 2:")
show(constraint2)

print("Constraint 3:")
show(constraint3)

# Evaluate constraints point-by-point in H to verify correctness
verification_points = H  # Points in the subgroup H

def verify_constraint_at_points(constraint, points):
    """Check if a constraint is satisfied at all points in the given domain."""
    return all(constraint.subs(X=point) == 0 for point in points)

# Check if constraints are satisfied at all points in H
constraint1_valid = verify_constraint_at_points(constraint1, verification_points)
constraint2_valid = verify_constraint_at_points(constraint2, verification_points)
constraint3_valid = verify_constraint_at_points(constraint3, verification_points)

# Display results
if constraint1_valid:
    print("Constraint 1 is satisfied!")
else:
    print("Constraint 1 failed!")

if constraint2_valid:
    print("Constraint 2 is satisfied!")
else:
    print("Constraint 2 failed!")

if constraint3_valid:
    print("Constraint 3 is satisfied!")
else:
    print("Constraint 3 failed!")

# Assert that all constraints are satisfied
assert constraint1_valid, "Constraint 1 failed!"
assert constraint2_valid, "Constraint 2 failed!"
assert constraint3_valid, "Constraint 3 failed!"

print("All constraints are satisfied!")
Constraint 1:
\(\displaystyle 13 X^{14} + 10 X^{13} + 14 X^{12} + 15 X^{11} + 2 X^{10} + 2 X^{9} + 3 X^{8} + 4 X^{6} + 7 X^{5} + 3 X^{4} + 2 X^{3} + 15 X^{2} + 15 X + 14\)
Constraint 2:
\(\displaystyle 7 X^{15} + 6 X^{14} + 11 X^{13} + 9 X^{12} + 7 X^{11} + 6 X^{10} + 6 X^{9} + 6 X^{8} + 10 X^{7} + 11 X^{6} + 6 X^{5} + 8 X^{4} + 10 X^{3} + 11 X^{2} + 11 X + 11\)
Constraint 3:
\(\displaystyle 5 X^{14} + 14 X^{13} + 11 X^{12} + 16 X^{11} + 12 X^{10} + 16 X^{9} + 2 X^{8} + 12 X^{6} + 3 X^{5} + 6 X^{4} + X^{3} + 5 X^{2} + X + 15\)
Constraint 1 is satisfied!
Constraint 2 is satisfied!
Constraint 3 is satisfied!
All constraints are satisfied!
# Define the finite field and its elements
p = 17  # Prime modulus
F = GF(p)  # Finite field F_p
omega = F(9)  # Primitive 8th root of unity in F_p
H = [omega^i for i in range(8)]  # Subgroup H of size 8 (2^3)
N = len(H)
n = log(len(H), 2)


# Define the vanishing polynomial for H
def vanishing_polynomial(H):
    v_H = F(1)
    for h in H:
        v_H *= (X - h)
    return v_H

# Polynomial variable
X = polygen(F, 'X')
v_H = vanishing_polynomial(H)

# Define the lagrange basis polynomials
def lagrange_basis(H, X):
    basis = []
    for i in range(len(H)):
        Li = F(1)
        for j in range(len(H)):
            if i != j:
                Li *= (X - H[j]) / (H[i] - H[j])
        basis.append(Li)
    return basis

L = lagrange_basis(H, X)  # Lagrange basis polynomials

u = [F.random_element() for _ in range(n)]
print("u: ", u)

# Random coefficients for a(X) and c(X)
a_coeffs = [F.random_element() for _ in range(len(H))]
c_coeffs = generate_c_vector(n, u)
# c_coeffs = [F.random_element() for _ in range(len(H))]
print("c_coeffs: ", c_coeffs)

# Construct a(X) and c(X)
a_X = sum(a_coeffs[i] * L[i] for i in range(len(H)))
c_X = sum(c_coeffs[i] * L[i] for i in range(len(H)))

# Recursive auxiliary polynomial z(X)
z = [a_coeffs[0] * c_coeffs[0]]
for i in range(1, len(H)):
    z.append(z[i-1] + a_coeffs[i] * c_coeffs[i])

z_X = sum(z[i] * L[i] for i in range(len(H)))
v = z[-1]  # Final sum

def compute_s_polynomial(H, X):
    s_polynomials = []
    N = len(H)
    n = log(N, 2)
    s_polynomials.append(X**(2**(n - 1)) + F(1))
    for i in range(n - 2, -1, -1):
        temp = X**(2**i) + F(1)
        s_polynomials.append(s_polynomials[-1] * temp)
    s_polynomials.reverse()
    return s_polynomials

# # Define the polynomials p_i(X)
p = []
s_polys = compute_s_polynomial(H, X)

for i in range(n + 1):
    if i == 0:
        p.append(s_polys[0] * (c_X - prod([(F(1) - u_i) for u_i in u])))
    else:
        temp_X = c_X.subs(X=(omega^(2^(n - i)) * X))
        p.append(s_polys[i - 1] * (c_X * u[n - i] - temp_X * (F(1) - u[n - i])))

# Define h_1(X), h_2(X), h_3(X)
h1_X = L[0] * (z_X - a_X * c_coeffs[0])
h2_X = (X - 1) * (z_X - z_X.subs(X=omega^(-1) * X) - a_X * c_X)
h3_X = L[-1] * (z_X - v)

# Aggregate all polynomials into h(X)
alpha = F.random_element()  # Random alpha
h_X = sum(alpha^i * p[i] for i in range(len(p))) + alpha^(len(p)) * h1_X + alpha^(len(p) + 1) * h2_X + alpha^(len(p) + 2) * h3_X

# Verify correctness of h(X)
t_X = h_X // v_H  # Quotient polynomial
remainder = h_X % v_H  # Remainder

# Display results
print("Polynomial h(X):")
show(h_X)

print("Quotient polynomial t(X):")
show(t_X)

print("Remainder:")
show(remainder)

# Verify that h(X) is divisible by v_H(X)
if remainder != 0: print("h(X) is not divisible by v_H(X)!")
else :
    print("Verification successful: h(X) is divisible by v_H(X).")
u:  [16, 14, 3]
c_coeffs:  [1, 8, 12, 11, 7, 5, 16, 9]
Polynomial h(X):
\(\displaystyle 8 X^{15} + 10 X^{14} + 13 X^{13} + 12 X^{12} + 10 X^{11} + 11 X^{10} + 11 X^{9} + 3 X^{8} + 9 X^{7} + 7 X^{6} + 4 X^{5} + 5 X^{4} + 7 X^{3} + 6 X^{2} + 6 X + 14\)
Quotient polynomial t(X):
\(\displaystyle 8 X^{7} + 10 X^{6} + 13 X^{5} + 12 X^{4} + 10 X^{3} + 11 X^{2} + 11 X + 3\)
Remainder:
\(\displaystyle 0\)
Verification successful: h(X) is divisible by v_H(X).
# Generate random challenge ζ ensuring ζ is not in H
while True:
    zeta = F.random_element()
    if zeta not in H:
        break

# Compute values of polynomials at ζ
f_zeta = a_X(zeta) * c_X(zeta)
c_zeta = c_X(zeta)
z_zeta = z_X(zeta)

# Compute quotient t(ζ) = h(ζ) // v_H(ζ)
v_H_zeta = v_H(zeta)
if v_H_zeta == 0:
    raise ValueError("Vanishing polynomial evaluates to zero at ζ, choose a different ζ.")

# Compute t(ζ)
t_zeta = h_X(zeta) // v_H_zeta

# Compute s_i(ζ) values
s = [v_H(zeta) / vanishing_polynomial(H[:2**i])(zeta) for i in range(len(H))]

# Compute L_0(ζ) and L_{N-1}(ζ)
L0_zeta = v_H(zeta) / (N * (zeta - 1))
LN_minus1_zeta = v_H(zeta) / (N * (omega * zeta - 1))

# Compute p_i(ζ) values
p_zeta = []
for i in range(n + 1):
    if i == 0:
        p_zeta.append(s[0] * (c_zeta - prod([(F(1) - u_i) for u_i in u])))
    else:
        p_zeta.append(s[i] * (c_zeta * u[n - i] - c_X.subs(X=(omega^(2^(n - i)) * zeta)) * (F(1) - u[n - i])))

# Compute h_1(ζ), h_2(ζ), h_3(ζ)
h1_zeta = L0_zeta * (z_zeta - a_X(zeta) * c_coeffs[0])
h2_zeta = (zeta - 1) * (z_zeta - z_X.subs(X=omega^(-1) * zeta) - a_X(zeta) * c_X(zeta))
h3_zeta = LN_minus1_zeta * (z_zeta - v)

# Aggregate polynomial evaluations into h(ζ)
alpha = F.random_element()
h_zeta = sum(alpha^i * p_zeta[i] for i in range(len(p_zeta))) + alpha^len(p_zeta) * h1_zeta + alpha^(len(p_zeta) + 1) * h2_zeta + alpha^(len(p_zeta) + 2) * h3_zeta

# Verify correctness
t_zeta = h_zeta // v_H_zeta
# remainder = h_zeta % v_H_zeta

# Display results
print("h(ζ):", h_zeta)
print("t(ζ):", t_zeta)
print("Remainder:", remainder)

# Ensure h(ζ) is divisible by v_H(ζ)
if remainder != 0:
    print("Verification failed: h(ζ) is not divisible by v_H(ζ)!")
else :
    print("Verification successful: h(ζ) is divisible by v_H(ζ).")
h(ζ): 4
t(ζ): 15
Remainder: 0
Verification successful: h(ζ) is divisible by v_H(ζ).
# Define the finite field and subgroup H
p = 17  # Prime modulus
F = GF(p)  # Finite field F_p
omega = F(9)  # Primitive N-th root of unity in F_p
N = 8  # Size of H, N = 2^n
H = [omega^i for i in range(N)]  # Multiplicative subgroup H of size N

# Polynomial variable
X = polygen(F, 'X')

# Define the vanishing polynomial v_H(X)
def vanishing_polynomial(H):
    v_H = F(1)
    for h in H:
        v_H *= (X - h)
    return v_H

v_H = vanishing_polynomial(H)

# Define the Lagrange basis polynomials L_i(X)
def lagrange_basis(H, v_H):
    basis = []
    for i in range(len(H)):
        Li = (v_H / (X - H[i])) * (1 / v_H.derivative()(H[i]))
        basis.append(Li)
    return basis

L = lagrange_basis(H, v_H)

# Define the MLE Polynomial coefficients
a_coeffs = [F.random_element() for _ in range(N)]  # Random coefficients a_0, ..., a_{N-1}

# Construct the univariate polynomial a(X)
a_X = sum(a_coeffs[i] * L[i] for i in range(N))

# Display the univariate polynomial a(X)
print("Univariate Polynomial a(X):")
show(a_X)

# Simulate the SRS of KZG10 (Commitment scheme)
# Example SRS: {g^1, g^2, ..., g^N}, where g is a generator
g = F.random_element()  # Random generator in F_p
SRS = [g^i for i in range(N)]  # SRS elements

# Commit a(X)
commitment = sum(a_coeffs[i] * SRS[i] for i in range(N))
print("Commitment of a(X):")
show(commitment)
Univariate Polynomial a(X):
\(\displaystyle 10 X^{7} + 14 X^{6} + 9 X^{5} + 5 X^{4} + 16 X^{3} + 3 X^{2} + 8 X + 3\)
Commitment of a(X):
\(\displaystyle 10\)
# Step 1: Define Common Input Struct
class CommonInput:
    def __init__(self, commitment_a, evaluation_point_u, polynomial_value_v):
        """
        Initialize the common input for the protocol.
        :param commitment_a: Commitment of a(X) (Ca)
        :param evaluation_point_u: Evaluation point (u_0, u_1, ..., u_{n-1})
        :param polynomial_value_v: Value of MLE polynomial at the evaluation point
        """
        self.commitment_a = commitment_a
        self.evaluation_point_u = evaluation_point_u
        self.polynomial_value_v = polynomial_value_v


# Step 2: Define the Prover Class
class Prover:
    def __init__(self, a_coeffs, lagrange_basis, subgroup_H, vanishing_poly, finite_field, common_input: CommonInput):
        """
        Initialize the prover with the polynomial coefficients and other required inputs.
        :param a_coeffs: Coefficients of a(X)
        :param lagrange_basis: Lagrange basis polynomials
        :param subgroup_H: Subgroup H
        :param vanishing_poly: Vanishing polynomial v_H(X)
        """
        self.a_coeffs = a_coeffs
        self.lagrange_basis = lagrange_basis
        self.subgroup_H = subgroup_H
        self.vanishing_poly = vanishing_poly
        self.polynomial_value = common_input.polynomial_value_v # v = f(u)
        self.finite_field = finite_field 
    
    def compute_selector_polynomials(self):
        """
        Construct selector polynomials s_i(X).
        """
        s_polynomials = []
        N = len(self.subgroup_H)
        n = log(N, 2)
        s_polynomials.append(X**(2**(n - 1)) + F(1))
        for i in range(n - 2, -1, -1):
            temp = X**(2**i) + F(1)
            s_polynomials.append(s_polynomials[-1] * temp)
        s_polynomials.reverse()
        print("s_polys: ", s_polys)
        return s_polynomials
    
    def compute_constraint_polynomials(self, c_X, evaluation_point_u):
        """
        Construct constraint polynomials p_0(X), ..., p_n(X).
        """
        p_polys = []
        s_polys = self.compute_selector_polynomials()
        n = len(evaluation_point_u)

        # Construct p_0(X)
        p_0 = s_polys[0] * (c_X - prod([(1 - u) for u in evaluation_point_u]))
        p_polys.append(p_0)

        # Construct p_k(X) for k = 1 to n
        omega = self.subgroup_H[1]
        for k in range(1, n + 1):
            # omega_k = self.subgroup_H[k]
            term = s_polys[k - 1] * ((evaluation_point_u[n - k] * c_X) - (1 - evaluation_point_u[n - k]) * c_X.subs(X=(omega^(2^(n - k))) * X))
            p_polys.append(term)
        return p_polys
    
    def compute_accumulation_polynomial(self, c_X):
        """
        Construct accumulation polynomial z(X).
        """
        # Ensure c_X is a proper polynomial
        c_poly = c_X.numerator()  # Extract the numerator as a polynomial
        c_coeffs = [c_poly.subs(X=x) for x in self.subgroup_H]
    
        # Compute z(X) coefficients
        z_coeffs = [self.a_coeffs[0] * c_coeffs[0]]  # First term
        for i in range(1, len(self.a_coeffs)):
            z_coeffs.append(z_coeffs[-1] + self.a_coeffs[i] * c_coeffs[i])

        # Construct z(X) as a linear combination of Lagrange basis
        z_X = sum(z_coeffs[i] * self.lagrange_basis[i] for i in range(len(self.a_coeffs)))
        return z_X
    
    def compute_constraint_h_polynomials(self, z_X, a_X, c_X):
        """
        Construct constraint polynomials h_0(X), h_1(X), h_2(X).
        """
        # Ensure c_X is a proper polynomial
        c_poly = c_X.numerator()  # Extract the numerator as a polynomial
        
        # Construct h_0(X)
        h0_X = self.lagrange_basis[0] * (z_X - self.a_coeffs[0] * c_poly.subs(X=1))
    
        # Construct h_1(X)
        z_prev = z_X.subs(X=self.subgroup_H[-1] * X)  # z(ω^-1 * X)
        h1_X = (X - 1) * (z_X - z_prev - a_X * c_X)
    
        # Construct h_2(X)
        h2_X = self.lagrange_basis[-1] * (z_X - self.polynomial_value)

        return h0_X, h1_X, h2_X 

    
    def compute_aggregation_polynomial(self, constraint_polys, h_polys, alpha):
        """
        Construct aggregation polynomial h(X).
        """
        n = len(constraint_polys)
        h_X = sum(alpha**i * constraint_polys[i] for i in range(n))
        h_X += alpha**n * h_polys[0] + alpha**(n + 1) * h_polys[1] + alpha**(n + 2) * h_polys[2]
        return h_X

    def compute_t_polynomial(self, h_X):
        """
        Compute quotient polynomial t(X) satisfying h(X) = t(X) * v_H(X).
        """
        if self.vanishing_poly == 0:
            raise ZeroDivisionError("Vanishing polynomial v_H(X) evaluated to 0.")
        t_X = h_X // self.vanishing_poly
        return t_X
    
    def compute_c_polynomial(self, evaluation_point_u):
        """
        Construct the polynomial c(X) using the provided evaluation point.
        :param evaluation_point_u: The evaluation point (u_0, u_1, ..., u_{n-1})
        :return: c(X)
        """
        def eq_tilde(bits_i, u_vector):
            result = F(1)
            for bit, u in zip(bits_i, u_vector):
                result *= (1 - bit) * (1 - u) + bit * u
            return result

        # Compute coefficients c_i
        c_coeffs = [eq_tilde(list(reversed(list(map(int, f"{i:03b}")))), evaluation_point_u) for i in range(len(self.subgroup_H))]

        # Construct c(X)
        c_X = sum(c_coeffs[i] * self.lagrange_basis[i] for i in range(len(self.subgroup_H)))
        return c_X

    def commit_c_polynomial(self, c_X):
        """
        Compute the commitment of c(X).
        :param c_X: The polynomial c(X)
        :return: Commitment of c(X)
        """
        # Convert c_X to a polynomial
        c_poly = c_X.numerator()  # Get the numerator to ensure it's a polynomial
        g = F.random_element()  # Random generator for SRS
        # TODO: should set SRS be a parameter
        SRS = [g**i for i in range(len(self.subgroup_H))]

        # Get the coefficients of the polynomial
        c_coeffs = c_poly.list()  # Retrieve the list of coefficients
        # Pad coefficients with zeros to match the length of the SRS
        c_coeffs += [F(0)] * (len(self.subgroup_H) - len(c_coeffs))

        # Compute the commitment as a linear combination with the SRS
        commitment_c = sum(c_coeffs[i] * SRS[i] for i in range(len(self.subgroup_H)))
        return commitment_c

    def round_2(self, c_X, a_X, evaluation_point_u):
        """
        Perform all steps in Round 2.
        """
        # Step 1: Compute selector polynomials
        s_polys = self.compute_selector_polynomials()

        # Step 2: Compute constraint polynomials
        constraint_polys = self.compute_constraint_polynomials(c_X, evaluation_point_u)

        # Step 3: Compute accumulation polynomial z(X)
        z_X = self.compute_accumulation_polynomial(c_X)
        z_value = [z_X.subs(X=x) for x in self.subgroup_H]

        # Step 4: Compute constraint h polynomials
        h_polys = self.compute_constraint_h_polynomials(z_X, a_X, c_X)

        # Step 5: Compute aggregation polynomial h(X)
        alpha = F.random_element()  # Random challenge
        h_X = self.compute_aggregation_polynomial(constraint_polys, h_polys, alpha)

        # Step 6: Compute quotient polynomial t(X)
        t_X = self.compute_t_polynomial(h_X)

        return t_X, h_X, z_X
    
    def round_3(self, verifier, c_X, z_X, t_X, a_X):
        """
        Perform Round 3 of the protocol.
        :param verifier: Verifier instance providing ζ
        :param c_X: Polynomial c(X)
        :param z_X: Accumulation polynomial z(X)
        :param t_X: Quotient polynomial t(X)
        :param a_X: Original polynomial a(X)
        :return: Evaluations and KZG10 proofs
        """
        # Step 1: Verifier sends random evaluation point ζ
        zeta = verifier.generate_valid_zeta()  # Verifier provides ζ

        # Step 2: Calculate values of s_i(X) at ζ
        s_polys = self.compute_selector_polynomials()
        s_values = [s(zeta) for s in s_polys]

        # Step 3: Define new domain D and coset D'
        D = self.subgroup_H
        D_prime = [zeta * d for d in D]
#         for d in D_prime:
#             if self.vanishing_poly(d) == 0:
#                 raise ZeroDivisionError(f"v_H({d}) = 0, cannot compute t(X).")


        # Step 4: Evaluate c(X), z(X), t(X), a(X) at D'
        c_values = [c_X(d) for d in D_prime]
        z_values = [z_X(d) for d in D_prime]
        t_values = [t_X(d) for d in D_prime]
        a_values = [a_X(d) for d in D_prime]

        # Step 5: Send evaluations and KZG10 proofs
        evaluations = {
            "c(X)": c_values,
            "z(X)": z_values,
            "t(X)": t_values,
            "a(X)": a_values
        }

        # Generate KZG10 proofs for the evaluations
        kzg_proofs = {
            "c(X)": self.generate_kzg_proof(c_X, D_prime),
            "z(X)": self.generate_kzg_proof(z_X, [zeta, zeta / self.subgroup_H[-1]]),
            "t(X)": self.generate_kzg_proof(t_X, D_prime),
            "a(X)": self.generate_kzg_proof(a_X, D_prime),
        }
        
        # Return evaluations and KZG proofs
        evaluations = {
            "c(X)": c_values,
            "z(X)": z_values,
            "t(X)": t_values,
            "a(X)": a_values
        }
        return evaluations, kzg_proofs
    
    def generate_kzg_proof(self, polynomial, points):
        """
        Generate KZG10 proofs for the given polynomial at the specified points.
        :param polynomial: Polynomial for which the proof is generated
        :param points: Points where the polynomial is evaluated
        :return: List of proofs
        """
        # Ensure the polynomial is in the proper polynomial ring
        poly_ring = PolynomialRing(self.finite_field, 'X')
        polynomial = poly_ring(polynomial)

        # Generate SRS large enough for the polynomial's degree
        max_degree = polynomial.degree()
        SRS = [self.finite_field.random_element() ** i for i in range(max_degree + 1)]

        proofs = []
        for point in points:
            # Compute the divisor
            divisor = polynomial - polynomial(point)

            # Convert the numerator into the finite field polynomial ring directly
            numerator = divisor.numerator()
            numerator_in_ring = poly_ring(numerator)

            # Generate the proof as a sum of SRS coefficients
            proof = sum(numerator_in_ring[i] * SRS[i] for i in range(numerator_in_ring.degree() + 1))
            proofs.append(proof)

        return proofs

    
class Verifier:
    def __init__(self, finite_field, vanishing_poly, subgroup_H):
        self.finite_field = finite_field  # Finite field F_p
        self.vanishing_poly = vanishing_poly  # Vanishing polynomial v_H(X)
        self.subgroup_H = subgroup_H  # Subgroup H of size 2^n
        self.random_zeta = None  # Random evaluation point ζ
    
    def generate_random_alpha(self):
        """
        Generate and send random scalar α for aggregation.
        """
        self.random_alpha = self.finite_field.random_element()
        return self.random_alpha
    
    
    def generate_valid_zeta(self):
        while True:
            zeta = self.finite_field.random_element()
            if all(self.vanishing_poly(zeta * d) != 0 for d in self.subgroup_H):
                self.random_zeta = zeta
                return zeta

    def verify_commitment(self, commitment, evaluation_point, polynomial, proof):
        """
        Verify a single KZG commitment.
        :param commitment: Commitment of the polynomial
        :param evaluation_point: Evaluation point ζ
        :param polynomial: Polynomial to verify
        :param proof: KZG proof
        :return: True if verification succeeds, False otherwise
        """
        # Simulate KZG verification (replace with actual verification in implementation)
        return hash((commitment, evaluation_point, polynomial, proof)) % 2 == 1

    def verify_constraint_equation(self, t_zeta, h_polys, p_polys, alpha, v_H_zeta):
        """
        Verify the final constraint equation:
        t(ζ) * v_H(ζ) = sum of constraint evaluations.
        """
        # Compute the left-hand side
        lhs = t_zeta * v_H_zeta

        # Compute the right-hand side
        rhs = sum(alpha**i * p_polys[i] for i in range(len(p_polys)))
        rhs += alpha**len(p_polys) * h_polys[0]
        rhs += alpha**(len(p_polys) + 1) * h_polys[1]
        rhs += alpha**(len(p_polys) + 2) * h_polys[2]

        return lhs == rhs

    def verify_proof(self, proof, zeta, alpha, v_H, lagrange_basis):
        """
        Perform the verification of the proof π.
        :param proof: Proof π containing all elements
        :param zeta: Evaluation point ζ
        :param alpha: Aggregation scalar α
        :param v_H: Vanishing polynomial v_H(X)
        :param lagrange_basis: Lagrange basis polynomials
        :return: True if all verifications succeed, False otherwise
        """
        # Parse the proof
        C_t, C_z, C_c = proof["C_t"], proof["C_z"], proof["C_c"]
        a_zeta, z_zeta, z_omega_zeta, t_zeta = proof["a(ζ)"], proof["z(ζ)"], proof["z(ζ/ω)"], proof["t(ζ)"]
        c_values = proof["c_values"]
        kzg_proofs = proof["kzg_proofs"]

        # Verify individual KZG commitments
        if not self.verify_commitment(C_c, zeta, "c(X)", kzg_proofs["c(X)"]):
            return False
        if not self.verify_commitment(C_z, zeta, "z(X)", kzg_proofs["z(X)"]):
            return False
        if not self.verify_commitment(C_t, zeta, "t(X)", kzg_proofs["t(X)"]):
            return False

        # Compute constraint polynomials at ζ
        p_polys = [
            lagrange_basis[0] * (c_values[0] - (1 - zeta)),
            lagrange_basis[1] * (zeta - z_omega_zeta),
            lagrange_basis[-1] * (z_zeta - a_zeta)
        ]

        # Compute h polynomials at ζ
        h_polys = [
            lagrange_basis[0] * (z_zeta - a_zeta),
            (z_zeta - z_omega_zeta) - t_zeta,
            lagrange_basis[-1] * (t_zeta - a_zeta)
        ]

        # Verify the final constraint equation
        return self.verify_constraint_equation(t_zeta, h_polys, p_polys, alpha, v_H(zeta))


# Initialize Parameters
p = 17  # Prime modulus
F = GF(p)  # Finite field
omega = F(9)  # Primitive root of unity
N = 8  # Size of the subgroup H
H = [omega^i for i in range(N)]  # Subgroup H
X = polygen(F, 'X')

# Compute vanishing polynomial and Lagrange basis
v_H = vanishing_polynomial(H)
L = lagrange_basis(H, v_H)

# Random coefficients for a(X)
a_coeffs = [F.random_element() for _ in range(N)]
a_X = sum(a_coeffs[i] * L[i] for i in range(N))

# Commitment of a(X)
g = F.random_element()  # Random generator for SRS
SRS = [g^i for i in range(N)]
commitment_a = sum(a_coeffs[i] * SRS[i] for i in range(N))

# Define common inputs
evaluation_point_u = [F.random_element() for _ in range(3)]  # Random evaluation point
c_coeffs = [eq_tilde(list(reversed(list(map(int, f"{i:03b}")))), evaluation_point_u) for i in range(len(H))]
print("a_coeffs: ", a_coeffs)
print("c_coeffs: ", c_coeffs)
polynomial_value_v = sum(a_coeffs[i] * c_coeffs[i] for i in range(N))
print("polynomial_value_v: ", polynomial_value_v)
# polynomial_value_v = a_X(evaluation_point_u[0])  # Example: value of a(X) at evaluation point
common_input = CommonInput(commitment_a, evaluation_point_u, polynomial_value_v)

# Prover Round 1: Compute c(X) and its commitment
prover = Prover(a_coeffs, L, H, v_H, F, common_input)
c_X = prover.compute_c_polynomial(evaluation_point_u)
test_c_values = [c_X.subs(X=e) for e in H]
commitment_c = prover.commit_c_polynomial(c_X)

# Display results
print("c(X):")
show(c_X)
print("Commitment of c(X):")
show(commitment_c)

# Perform Round 2
t_X, h_X, z_X = prover.round_2(c_X, a_X, evaluation_point_u)

# Display results
print("Quotient Polynomial t(X):")
show(t_X)
print("Aggregation Polynomial h(X):")
show(h_X)
print("Accumulation Polynomial z(X):")
show(z_X)

# Initialize Verifier with subgroup H and vanishing polynomial
verifier = Verifier(F, vanishing_poly=prover.vanishing_poly, subgroup_H=prover.subgroup_H)

# Validate v_H(X) at D'
# v_H_values = [prover.vanishing_poly(zeta * d) for d in prover.subgroup_H]
# if any(v == 0 for v in v_H_values):
#     print("Error: v_H(X) evaluates to 0 at some points in D'.")

# Perform Round 3 with valid zeta
try:
    evaluations, kzg_proofs = prover.round_3(verifier, c_X, z_X, t_X, a_X)
    print("Evaluations at coset D':", evaluations)
    print("KZG10 Proofs:", kzg_proofs)
except ZeroDivisionError as e:
    print("Error during Round 3:", e)
except ValueError as e:
    print("Validation Error:", e)
a_coeffs:  [1, 10, 9, 7, 12, 6, 11, 12]
c_coeffs:  [2, 6, 11, 16, 0, 0, 0, 0]
polynomial_value_v:  1
c(X):
\(\displaystyle 6 X^{7} + 6 X^{6} + X^{5} + X^{4} + 9 X^{3} + 13 X^{2} + 2 X + 15\)
Commitment of c(X):
\(\displaystyle 10\)
s_polys:  [X^7 + X^6 + X^5 + X^4 + X^3 + X^2 + X + 1, X^6 + X^4 + X^2 + 1, X^4 + 1]
s_polys:  [X^7 + X^6 + X^5 + X^4 + X^3 + X^2 + X + 1, X^6 + X^4 + X^2 + 1, X^4 + 1]
Quotient Polynomial t(X):
\(\displaystyle 5 X^{7} + 10 X^{6} + 5 X^{5} + 13 X^{4} + 13 X^{3} + 6 X^{2} + 4 X + 12\)
Aggregation Polynomial h(X):
\(\displaystyle 5 X^{15} + 10 X^{14} + 5 X^{13} + 13 X^{12} + 13 X^{11} + 6 X^{10} + 4 X^{9} + 12 X^{8} + 12 X^{7} + 7 X^{6} + 12 X^{5} + 4 X^{4} + 4 X^{3} + 11 X^{2} + 13 X + 5\)
Accumulation Polynomial z(X):
\(\displaystyle 10 X^{7} + 7 X^{6} + 16 X^{5} + 4 X^{4} + 13 X^{3} + 4 X + 16\)
s_polys:  [X^7 + X^6 + X^5 + X^4 + X^3 + X^2 + X + 1, X^6 + X^4 + X^2 + 1, X^4 + 1]
Evaluations at coset D': {'c(X)': [10, 1, 0, 1, 6, 15, 11, 8], 'z(X)': [11, 13, 11, 16, 0, 10, 12, 4], 't(X)': [8, 7, 10, 1, 14, 16, 12, 11], 'a(X)': [6, 1, 0, 16, 3, 6, 12, 7]}
KZG10 Proofs: {'c(X)': [5, 14, 15, 14, 9, 0, 4, 7], 'z(X)': [2, 0], 't(X)': [3, 4, 1, 10, 14, 12, 16, 0], 'a(X)': [4, 9, 10, 11, 7, 4, 15, 3]}