Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions qcfdl/block_encoding/qsp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
### Mottonen Block Encoding

import numpy as np
import scipy as sp

import pennylane as qml
import pennylane.estimator as qre

Expand Down Expand Up @@ -57,12 +59,27 @@ def qsp_block_encoding_resource(
num_wires = int(np.ceil(np.log2(state_vector.shape[0])))

max_poly_deg = 2 ** (num_wires)
poly_degs = np.geomspace(2, num_wires, max_poly_deg)

xs_vals = np.sin(np.arange(max_poly_deg))
for poly_deg in np.round(poly_degs).astype(int):
cheb = np.polynomial.Chebyshev.fit(xs_vals, state_vector, poly_deg, domain=[-1, 1])
approx_on_s = cheb(xs_vals)
ts_vals = np.sin(2 * np.pi * np.linspace(0, 1, max_poly_deg))

ix_vals = np.argsort(ts_vals)
xs_vals, ys_vals = ts_vals[ix_vals], state_vector[ix_vals]
for poly_deg in np.geomspace(2, max_poly_deg, num_wires).astype(int):
if poly_deg < 1000: # 2 ** 10 is the maxium we use for the chebyshev fit
indx = np.linspace(0, len(xs_vals) - 1, min(len(xs_vals), 4 * poly_deg)).astype(int)
cheb = np.polynomial.Chebyshev.fit(
xs_vals[indx], ys_vals[indx], deg=poly_deg, domain=[-1, 1]
)
else: # Use DCT for larger poly_deg to avoid memory issues with the chebyshev fit
indx = np.linspace(0, len(xs_vals) - 1, min(2 * poly_deg, len(xs_vals))).astype(int)
max_points = min(poly_deg, int(2 ** 10))
xs_vals2 = sp.interpolate.interp1d(
xs_vals[indx], ys_vals[indx], kind='linear', bounds_error=False, fill_value="extrapolate"
)(np.cos(np.pi * (np.arange(max_points) + 0.5) / max_points))
coeffs = sp.fft.dct(xs_vals2, type=2, norm='backward') / max_points
coeffs[0] /= 2
cheb = np.polynomial.Chebyshev(coeffs, domain=[-1, 1])

approx_on_s = cheb(ts_vals)
error = np.max(np.abs(approx_on_s - state_vector))
if error <= approx_error:
break
Expand All @@ -78,5 +95,4 @@ def qsp_block_encoding_resource(
f"Poly deg: {poly_deg} \t| Linf Error: {np.round(error, 6)} \t| \n"
f"Resources: " + str(res)
)

return res.gate_counts
15 changes: 13 additions & 2 deletions qcfdl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,23 @@ def load_state_vector_resources(
return min_methods, min_errors


from time import time

if __name__ == "__main__":
state_vector = np.random.normal(loc=0, scale=1, size=1024)
state_vector = np.random.RandomState(32).normal(loc=0, scale=1, size=int(2 ** 20))
state_vector = state_vector / np.linalg.norm(state_vector)

# compute_resources(state_vector, max_wires=12, verbose=True, prob_type="block_encoding")
print("Computing resources for state preparation... \n")
time0 = time()
compute_resources(state_vector, max_wires=50, verbose=True, prob_type="state_prep")
time1 = time()
# compute_split_resources(state_vector, methods=["Fourier Serier Loader", "Matrix Product State", "Sum of Slaters"], max_wires=12, max_splits=2, verbose=True)
# min_methods, min_errors = load_state_vector_resources(state_vector, max_splits=3, max_num_wires=2, opt_param="t_gates")
# print(min_methods, min_errors)
print("Total time for state preparation:", time1 - time0, "\n")

print("Computing resources for block encoding... \n")
time2 = time()
compute_resources(state_vector, max_wires=50, verbose=True, prob_type="block_encoding")
time3 = time()
print("Total time for block encoding:", time3 - time2)