Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 138 additions & 2 deletions examples/pke/iterative-ckks-bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@

def main():
iterative_bootstrap_example()
iterative_bootstrap_stc_example()

def calculate_approximation_error(result,expected_result):
def calculate_approximation_error(result, expected_result):
if len(result) != len(expected_result):
raise Exception("Cannot compare vectors with different numbers of elements")
# using the infinity norm
# error is abs of the difference of real parts
max_error = max([abs(el1.real - el2.real) for (el1, el2) in zip(result, expected_result)])
# return absolute value of log base2 of the error
return abs(math.log(max_error,2))

def iterative_bootstrap_example():
# Step 1: Set CryptoContext
parameters = CCParamsCKKSRNS()
Expand Down Expand Up @@ -94,7 +96,7 @@ def iterative_bootstrap_example():

result = cryptocontext.Decrypt(ciphertext_after,key_pair.secretKey)
result.SetLength(num_slots)
precision = calculate_approximation_error(result.GetCKKSPackedValue(),ptxt.GetCKKSPackedValue())
precision = math.floor(calculate_approximation_error(result.GetCKKSPackedValue(),ptxt.GetCKKSPackedValue()))
print(f"Bootstrapping precision after 1 iteration: {precision} bits\n")

# Set the precision equal to empirically measured value after many test runs.
Expand All @@ -114,5 +116,139 @@ def iterative_bootstrap_example():
print(f"Bootstrapping precision after 2 iterations: {precision_multiple_iterations} bits\n")
print(f"Number of levels remaining after 2 bootstrappings: {depth - ciphertext_two_iterations.GetLevel()}\n")

def iterative_bootstrap_stc_example():
# Step 1: Set CryptoContext
parameters = CCParamsCKKSRNS()
secret_key_dist = SecretKeyDist.UNIFORM_TERNARY
parameters.SetSecretKeyDist(secret_key_dist)
parameters.SetSecurityLevel(SecurityLevel.HEStd_NotSet)
parameters.SetRingDim(1 << 12)

if get_native_int()==128:
rescale_tech = ScalingTechnique.FIXEDAUTO
dcrt_bits = 78
first_mod = 89
else:
rescale_tech = ScalingTechnique.FLEXIBLEAUTO
dcrt_bits = 59
first_mod = 60

parameters.SetScalingModSize(dcrt_bits)
parameters.SetScalingTechnique(rescale_tech)
parameters.SetFirstModSize(first_mod)

# Here, we specify the number of iterations to run bootstrapping.
# Note that we currently only support 1 or 2 iterations.
# Two iterations should give us approximately double the precision of one iteration.
num_iterations = 2

level_budget = [3, 3]
bsgs_dim = [0,0]

levels_available_after_bootstrap = 10 + level_budget[1]
depth = levels_available_after_bootstrap + FHECKKSRNS.GetBootstrapDepth(9, level_budget, secret_key_dist)
parameters.SetMultiplicativeDepth(depth)

# Generate crypto context
cryptocontext = GenCryptoContext(parameters)

# Enable features that you wish to use. Note, we must enable FHE to use bootstrapping.

cryptocontext.Enable(PKESchemeFeature.PKE)
cryptocontext.Enable(PKESchemeFeature.KEYSWITCH)
cryptocontext.Enable(PKESchemeFeature.LEVELEDSHE)
cryptocontext.Enable(PKESchemeFeature.ADVANCEDSHE)
cryptocontext.Enable(PKESchemeFeature.FHE)

ring_dim = cryptocontext.GetRingDimension()
print(f"CKKS is using ring dimension {ring_dim}\n\n")

# Step 2: Precomputations for bootstrapping
# We use a sparse packing
num_slots = 8
cryptocontext.EvalBootstrapSetup(level_budget, bsgs_dim, num_slots, 0, True, True)

# Step 3: Key generation
key_pair = cryptocontext.KeyGen()
cryptocontext.EvalMultKeyGen(key_pair.secretKey)
# Generate bootstrapping keys.
cryptocontext.EvalBootstrapKeyGen(key_pair.secretKey, num_slots)

# Step 4: Encoding and encryption of inputs
# Generate random input
x = [random.uniform(0, 2) for i in range(num_slots)]

""" Encoding as plaintexts
We specify the number of slots as num_slots to achieve a performance improvement.
We use the other default values of depth 1, levels 0, and no params.
Alternatively, you can also set batch size as a parameter in the CryptoContext as follows:
parameters.SetBatchSize(num_slots);
Here, we assume all ciphertexts in the cryptoContext will have num_slots slots.
We start with a depleted ciphertext that has used up all of its levels."""
ptxt = cryptocontext.MakeCKKSPackedPlaintext(x, 1, depth-1-level_budget[1], None, num_slots)
ptxt.SetLength(num_slots)
print(f"Input: {ptxt}")

# Encrypt the encoded vectors
ciph = cryptocontext.Encrypt(key_pair.publicKey, ptxt)

# Step 5: Measure the precision of a single bootstrapping operation.
ciphertext_after = cryptocontext.EvalBootstrap(ciph)

result = cryptocontext.Decrypt(ciphertext_after, key_pair.secretKey)
result.SetLength(num_slots)
precision = math.floor(calculate_approximation_error(result.GetCKKSPackedValue(), ptxt.GetCKKSPackedValue()))
print(f"Bootstrapping precision after 1 iteration: {precision} bits\n")

# Set the precision equal to empirically measured value after many test runs.
precision -= 5
print(f"Precision input to algorithm: {precision}\n")

# Step 6: Run bootstrapping with multiple iterations
ciphertext_two_iterations = cryptocontext.EvalBootstrap(ciph, num_iterations, precision)

result_two_iterations = cryptocontext.Decrypt(ciphertext_two_iterations, key_pair.secretKey)
result_two_iterations.SetLength(num_slots)
actual_result = result_two_iterations.GetCKKSPackedValue()

print(f"Output after two interations of bootstrapping: {actual_result}\n")
precision_multiple_iterations = calculate_approximation_error(actual_result, ptxt.GetCKKSPackedValue())

print(f"Bootstrapping precision after 2 iterations: {precision_multiple_iterations} bits\n")
print(f"Number of levels remaining after 2 bootstrappings: "
f"{depth - ciphertext_two_iterations.GetLevel() - ciphertext_two_iterations.GetNoiseScaleDeg() - 1}\n\n")

#---------------------------------------------------------------------------------------------------------------------
# When using EvalBootstrap for 2 iterations with STC first, it may be beneficial to scale down the default correction
# factor to achieve a higher final precision. This behavior is specifically pronounced for sparse packing. As the
# number of slots increases, the difference between the default correction factor and the best empirical correction
# factor decreases. For full packing at full security for CKKS bootstrapping, this variant of CKKS bootstrapping
# has better precision than the ModRaise-first variant without any change to the default correction factor.

cryptocontext.SetCKKSBootCorrectionFactor(cryptocontext.GetCKKSBootCorrectionFactor() - 5)
print(f"Correction factor used: {cryptocontext.GetCKKSBootCorrectionFactor()}")

ciphertext_after = cryptocontext.EvalBootstrap(ciph)
result = cryptocontext.Decrypt(key_pair.secretKey, ciphertext_after)
result.SetLength(num_slots)
precision = math.floor(calculate_approximation_error(result.GetCKKSPackedValue(), ptxt.GetCKKSPackedValue()))
print(f"Bootstrapping precision after 1 iteration: {precision}\n\n")

# Set precision equal to empirically measured value after many test runs. One could add a buffer to reduce this value as below.
precision -= 5
print(f"Precision input to 2nd iteration: {precision}\n")

ciphertext_two_iterations = cryptocontext.EvalBootstrap(ciph, num_iterations, precision)
result_two_iterations = cryptocontext.Decrypt(key_pair.secretKey, ciphertext_two_iterations)
actual_result = result_two_iterations.GetCKKSPackedValue()

print(f"Output after two iterations of bootstrapping: {actual_result}\n")
precision_multiple_iterations = calculate_approximation_error(actual_result, ptxt.GetCKKSPackedValue())

# Output the precision of bootstrapping after two iterations. It should be approximately double the original precision.
print(f"Bootstrapping precision after 2 iterations: {precision_multiple_iterations}\n")
print(f"Number of levels remaining after 2 bootstrappings: "
f"{depth - ciphertext_two_iterations.GetLevel() - (ciphertext_two_iterations.GetNoiseScaleDeg() - 1)}\n\n")

if __name__ == "__main__":
main()
68 changes: 68 additions & 0 deletions examples/pke/simple-ckks-bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,73 @@ def simple_bootstrap_example():
result.SetLength(encoded_length)
print(f"Output after bootstrapping: {result}")

def simple_bootstrap_stc_example():
parameters = CCParamsCKKSRNS()

secret_key_dist = SecretKeyDist.UNIFORM_TERNARY
parameters.SetSecretKeyDist(secret_key_dist)

parameters.SetSecurityLevel(SecurityLevel.HEStd_NotSet)
parameters.SetRingDim(1<<12)

if get_native_int()==128:
rescale_tech = ScalingTechnique.FIXEDAUTO
dcrt_bits = 78
first_mod = 89
else:
rescale_tech = ScalingTechnique.FLEXIBLEAUTO
dcrt_bits = 59
first_mod = 60

parameters.SetScalingModSize(dcrt_bits)
parameters.SetScalingTechnique(rescale_tech)
parameters.SetFirstModSize(first_mod)

level_budget = [4, 4]

levels_available_after_bootstrap = 10 + level_budget[1]

depth = levels_available_after_bootstrap + FHECKKSRNS.GetBootstrapDepth({level_budget[0], 0}, secret_key_dist)

parameters.SetMultiplicativeDepth(depth)

cryptocontext = GenCryptoContext(parameters)
cryptocontext.Enable(PKESchemeFeature.PKE)
cryptocontext.Enable(PKESchemeFeature.KEYSWITCH)
cryptocontext.Enable(PKESchemeFeature.LEVELEDSHE)
cryptocontext.Enable(PKESchemeFeature.ADVANCEDSHE)
cryptocontext.Enable(PKESchemeFeature.FHE)

ring_dim = cryptocontext.GetRingDimension()
# This is the mazimum number of slots that can be used full packing.
num_slots = int(ring_dim / 2)
print(f"CKKS is using ring dimension {ring_dim}")

cryptocontext.EvalBootstrapSetup(level_budget, [0, 0], num_slots, 0, True, True)

key_pair = cryptocontext.KeyGen()
cryptocontext.EvalMultKeyGen(key_pair.secretKey)
cryptocontext.EvalBootstrapKeyGen(key_pair.secretKey, num_slots)

x = [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0]
encoded_length = len(x)

ptxt = cryptocontext.MakeCKKSPackedPlaintext(x, 1, depth-1-levelBudget[1], None, num_slots)
ptxt.SetLength(encoded_length)

print(f"Input: {ptxt}")

ciph = cryptocontext.Encrypt(key_pair.publicKey, ptxt)

print(f"Initial number of levels remaining: {depth - ciph.GetLevel()}")

ciphertext_after = cryptocontext.EvalBootstrap(ciph)

print(f"Number of levels remaining after bootstrapping: {depth - ciphertext_after.GetLevel() - (ciphertext_after.GetNoiseScaleDeg() - 1)}")

result = cryptocontext.Decrypt(ciphertext_after,key_pair.secretKey)
result.SetLength(encoded_length)
print(f"Output after bootstrapping: {result}")

if __name__ == '__main__':
main()
Loading