diff --git a/native/src/seal/evaluator.cpp b/native/src/seal/evaluator.cpp index ba6d19f8..a17a46a9 100644 --- a/native/src/seal/evaluator.cpp +++ b/native/src/seal/evaluator.cpp @@ -1729,9 +1729,59 @@ namespace seal return; } - // Create a vector of copies of encrypted - vector exp_vector(static_cast(exponent), encrypted); - multiply_many(exp_vector, relin_keys, encrypted, std::move(pool)); + // Precomputed-squares + tree-reduce over set-bit powers. + // For an exponent e, this uses floor(log2(e)) + popcount(e) - 1 + // multiplications at depth floor(log2(e)) + ceil(log2(popcount(e))), + // versus the prior e - 1 multiplications at depth ceil(log2(e)). + int top_bit = 63; + while (top_bit > 0 && !((exponent >> top_bit) & uint64_t(1))) + { + --top_bit; + } + + // Build pow2[k] = encrypted^{2^k} for k = 0..top_bit. + vector pow2; + pow2.reserve(static_cast(top_bit) + 1); + pow2.emplace_back(encrypted); + for (int k = 1; k <= top_bit; k++) + { + Ciphertext next = pow2.back(); + square_inplace(next, pool); + relinearize_inplace(next, relin_keys, pool); + pow2.emplace_back(std::move(next)); + } + + // Collect the powers whose bits are set in exponent. + vector terms; + terms.reserve(pow2.size()); + for (int k = 0; k <= top_bit; k++) + { + if ((exponent >> k) & uint64_t(1)) + { + terms.emplace_back(pow2[static_cast(k)]); + } + } + + // Tree-reduce the selected powers into a single product. + while (terms.size() > 1) + { + vector next_level; + next_level.reserve((terms.size() + 1) / 2); + for (size_t i = 0; i + 1 < terms.size(); i += 2) + { + Ciphertext prod; + multiply(terms[i], terms[i + 1], prod, pool); + relinearize_inplace(prod, relin_keys, pool); + next_level.emplace_back(std::move(prod)); + } + if (terms.size() & size_t(1)) + { + next_level.emplace_back(std::move(terms.back())); + } + terms = std::move(next_level); + } + + encrypted = std::move(terms.front()); } void Evaluator::add_plain_inplace(Ciphertext &encrypted, const Plaintext &plain, MemoryPoolHandle pool) const