Skip to content

Commit

Permalink
Fix maxOS CI test failures (#263)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli committed Jan 15, 2024
1 parent 8659a03 commit d457c6c
Showing 1 changed file with 73 additions and 47 deletions.
120 changes: 73 additions & 47 deletions test/dynamics/solvers/test_solver_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,12 @@ def test_hamiltonian_model(self):
t_span = [0.0, 1.42]
signals = [Signal(3.0)]

res1 = ham_solver.solve(t_span=t_span, y0=y0, signals=signals)
res1 = ham_solver.solve(t_span=t_span, y0=y0, signals=signals, atol=1e-11, rtol=1e-11)

self.ham_model.signals = signals
res2 = solve_lmde(generator=self.ham_model, t_span=t_span, y0=y0)
res2 = solve_lmde(generator=self.ham_model, t_span=t_span, y0=y0, atol=1e-11, rtol=1e-11)

self.assertAllClose(res1.y, res2.y)
self.assertAllClose(res1.y[-1], res2.y[-1], atol=1e-8, rtol=1e-8)

def test_lindblad_model(self):
"""Test for solver with only static lindblad information."""
Expand All @@ -317,12 +317,14 @@ def test_lindblad_model(self):
t_span = [0.0, 1.42]
signals = [Signal(3.0)]

res1 = lindblad_solver.solve(t_span=t_span, y0=y0, signals=signals)
res1 = lindblad_solver.solve(t_span=t_span, y0=y0, signals=signals, atol=1e-12, rtol=1e-12)

self.lindblad_model.signals = (signals, None)
res2 = solve_lmde(generator=self.lindblad_model, t_span=t_span, y0=y0)
res2 = solve_lmde(
generator=self.lindblad_model, t_span=t_span, y0=y0, atol=1e-12, rtol=1e-12
)

self.assertAllClose(res1.y, res2.y)
self.assertAllClose(res1.y[-1], res2.y[-1], atol=1e-8, rtol=1e-8)

def test_td_lindblad_model(self):
"""Test for solver with time-dependent lindblad information."""
Expand All @@ -339,12 +341,16 @@ def test_td_lindblad_model(self):
t_span = [0.0, 1.42]
signals = ([Signal(3.0)], [Signal(2.0)])

res1 = td_lindblad_solver.solve(t_span=t_span, y0=y0, signals=signals)
res1 = td_lindblad_solver.solve(
t_span=t_span, y0=y0, signals=signals, atol=1e-12, rtol=1e-12
)

self.td_lindblad_model.signals = signals
res2 = solve_lmde(generator=self.td_lindblad_model, t_span=t_span, y0=y0)
res2 = solve_lmde(
generator=self.td_lindblad_model, t_span=t_span, y0=y0, atol=1e-12, rtol=1e-12
)

self.assertAllClose(res1.y, res2.y)
self.assertAllClose(res1.y[-1], res2.y[-1], atol=1e-8, rtol=1e-8)

def test_rwa_ham_model(self):
"""Test correct handling of RWA for a Hamiltonian model."""
Expand All @@ -361,13 +367,13 @@ def test_rwa_ham_model(self):
t_span = [0.0, 1.0]
signals = [Signal(1.0, carrier_freq=5.0)]

res1 = rwa_ham_solver.solve(t_span=t_span, y0=y0, signals=signals)
res1 = rwa_ham_solver.solve(t_span=t_span, y0=y0, signals=signals, atol=1e-12, rtol=1e-12)

self.ham_model.signals = signals
rwa_ham_model = rotating_wave_approximation(self.ham_model, cutoff_freq=5.0)
res2 = solve_lmde(generator=rwa_ham_model, t_span=t_span, y0=y0)
res2 = solve_lmde(generator=rwa_ham_model, t_span=t_span, y0=y0, atol=1e-12, rtol=1e-12)

self.assertAllClose(res1.y, res2.y)
self.assertAllClose(res1.y[-1], res2.y[-1], atol=1e-8, rtol=1e-8)

def test_rwa_lindblad_model(self):
"""Test correct handling of RWA for Lindblad model without
Expand All @@ -387,13 +393,17 @@ def test_rwa_lindblad_model(self):
t_span = [0.0, 1.0]
signals = [Signal(1.0, carrier_freq=5.0)]

res1 = rwa_lindblad_solver.solve(t_span=t_span, y0=y0, signals=signals)
res1 = rwa_lindblad_solver.solve(
t_span=t_span, y0=y0, signals=signals, atol=1e-12, rtol=1e-12
)

self.lindblad_model.signals = (signals, None)
rwa_lindblad_model = rotating_wave_approximation(self.lindblad_model, cutoff_freq=5.0)
res2 = solve_lmde(generator=rwa_lindblad_model, t_span=t_span, y0=y0)
res2 = solve_lmde(
generator=rwa_lindblad_model, t_span=t_span, y0=y0, atol=1e-12, rtol=1e-12
)

self.assertAllClose(res1.y, res2.y)
self.assertAllClose(res1.y[-1], res2.y[-1], atol=1e-8, rtol=1e-8)

def test_rwa_td_lindblad_model(self):
"""Test correct handling of RWA for Lindblad model with
Expand All @@ -414,13 +424,17 @@ def test_rwa_td_lindblad_model(self):
t_span = [0.0, 1.0]
signals = ([Signal(1.0, carrier_freq=5.0)], [Signal(1.0, carrier_freq=5.0)])

res1 = rwa_td_lindblad_solver.solve(t_span=t_span, y0=y0, signals=signals)
res1 = rwa_td_lindblad_solver.solve(
t_span=t_span, y0=y0, signals=signals, atol=1e-12, rtol=1e-12
)

self.td_lindblad_model.signals = signals
rwa_td_lindblad_model = rotating_wave_approximation(self.td_lindblad_model, cutoff_freq=5.0)
res2 = solve_lmde(generator=rwa_td_lindblad_model, t_span=t_span, y0=y0)
res2 = solve_lmde(
generator=rwa_td_lindblad_model, t_span=t_span, y0=y0, atol=1e-12, rtol=1e-12
)

self.assertAllClose(res1.y, res2.y)
self.assertAllClose(res1.y[-1], res2.y[-1], atol=1e-8, rtol=1e-8)

def test_signals_are_None(self):
"""Test the model signals return to being None after simulation."""
Expand Down Expand Up @@ -554,15 +568,19 @@ def test_vec_lindblad_statevector(self):
y0=Statevector([0.0, 1.0]),
signals=[Signal(1.0, 5.0)],
method=self.method,
atol=1e-12,
rtol=1e-12,
)
results2 = self.lindblad_solver.solve(
t_span=[0.0, 1.0],
y0=Statevector([0.0, 1.0]),
signals=[Signal(1.0, 5.0)],
method=self.method,
atol=1e-12,
rtol=1e-12,
)
self.assertTrue(isinstance(results.y[-1], DensityMatrix))
self.assertAllClose(results.y[-1].data, results2.y[-1].data)
self.assertAllClose(results.y[-1].data, results2.y[-1].data, atol=1e-8, rtol=1e-8)

def test_array_vectorized_lindblad(self):
"""Test Lindblad solver is array-vectorized."""
Expand Down Expand Up @@ -634,18 +652,18 @@ def test_hamiltonian_lindblad_SuperOp_consistency(self):
t_span=[0.0, 0.432],
y0=SuperOp(np.eye(4)),
signals=[Signal(1.0, 5.0)],
atol=1e-10,
rtol=1e-10,
atol=1e-12,
rtol=1e-12,
method=self.method,
)
results2 = self.vec_lindblad_solver_no_diss.solve(
t_span=[0.0, 0.432],
y0=SuperOp(np.eye(4)),
signals=[Signal(1.0, 5.0)],
atol=1e-10,
rtol=1e-10,
atol=1e-12,
rtol=1e-12,
)
self.assertAllClose(results.y[-1].data, results2.y[-1].data)
self.assertAllClose(results.y[-1].data, results2.y[-1].data, atol=1e-8, rtol=1e-8)

def test_lindblad_solver_consistency(self):
"""Test consistency of lindblad solver with dissipators specified
Expand Down Expand Up @@ -812,9 +830,9 @@ def test_static_simulation(self, model):
y0=Statevector([0.0, 1.0]),
schedules=sched,
signals=None,
test_tol=1e-9,
atol=1e-12,
rtol=1e-12,
test_tol=1e-8,
atol=1e-11,
rtol=1e-11,
)

@unpack
Expand All @@ -841,9 +859,9 @@ def test_one_channel_simulation(self, model):
y0=Statevector([1.0, 0.0]),
schedules=sched,
signals=[sig],
test_tol=1e-9,
atol=1e-12,
rtol=1e-12,
test_tol=1e-8,
atol=1e-11,
rtol=1e-11,
)

@unpack
Expand Down Expand Up @@ -899,7 +917,7 @@ def test_two_channel_list_simulation(self, model):
y0=[Statevector([1.0, 0.0]), DensityMatrix([0.0, 1.0])],
schedules=[sched0, sched1],
signals=signals,
test_tol=1e-9,
test_tol=1e-8,
atol=1e-12,
rtol=1e-12,
)
Expand Down Expand Up @@ -939,8 +957,8 @@ def test_two_channel_SuperOp_simulation(self, model):
schedules=sched,
signals=signals,
test_tol=1e-8,
atol=1e-12,
rtol=1e-12,
atol=1e-11,
rtol=1e-11,
)

def test_4_channel_schedule(self):
Expand Down Expand Up @@ -993,9 +1011,9 @@ def test_4_channel_schedule(self):
y0=Statevector([1.0, 0.0]),
schedules=schedule,
signals=signals,
test_tol=1e-9,
atol=1e-13,
rtol=1e-13,
test_tol=1e-8,
atol=1e-12,
rtol=1e-12,
)

def test_rwa_ham_solver(self):
Expand Down Expand Up @@ -1025,14 +1043,23 @@ def test_rwa_ham_solver(self):
sig = Signal(0.9, carrier_freq=5.0)

res_pulse = ham_pulse_solver.solve(
t_span=[0, 0.4], y0=Statevector([0.0, 1.0]), signals=schedule, method=self.method
t_span=[0, 0.4],
y0=Statevector([0.0, 1.0]),
signals=schedule,
method=self.method,
atol=1e-12,
rtol=1e-12,
)
res_signal = ham_solver.solve(
t_span=[0, 0.4], y0=Statevector([0.0, 1.0]), signals=[sig], method=self.method
t_span=[0, 0.4],
y0=Statevector([0.0, 1.0]),
signals=[sig],
method=self.method,
atol=1e-12,
rtol=1e-12,
)

self.assertAllClose(res_pulse.t, res_signal.t, atol=1e-14, rtol=1e-14)
self.assertAllClose(res_pulse.y, res_signal.y, atol=1e-14, rtol=1e-14)
self.assertAllClose(res_pulse.y[-1], res_signal.y[-1], atol=1e-8, rtol=1e-8)

def test_rwa_lindblad_solver(self):
"""Test RWA for pulse solver with Lindblad information."""
Expand Down Expand Up @@ -1081,8 +1108,7 @@ def test_rwa_lindblad_solver(self):
rtol=1e-12,
)

self.assertAllClose(res_pulse.t, res_signal.t, atol=1e-14, rtol=1e-14)
self.assertAllClose(res_pulse.y, res_signal.y, atol=1e-14, rtol=1e-14)
self.assertAllClose(res_pulse.y[-1], res_signal.y[-1], atol=1e-8, rtol=1e-8)

def test_list_simulation_mixing_types(self):
"""Test correct formatting when input states have the same shape.
Expand Down Expand Up @@ -1133,9 +1159,9 @@ def test_list_simulation_mixing_types(self):
y0=[np.eye(2, dtype=complex), DensityMatrix([0.0, 1.0])],
schedules=[sched0, sched1],
signals=signals,
test_tol=1e-9,
atol=1e-12,
rtol=1e-12,
test_tol=1e-8,
atol=1e-11,
rtol=1e-11,
)

def test_channel_without_instructions(self):
Expand All @@ -1158,8 +1184,8 @@ def test_channel_without_instructions(self):
schedules=sched,
signals=signals,
test_tol=1e-8,
atol=1e-12,
rtol=1e-12,
atol=1e-11,
rtol=1e-11,
)

def _compare_schedule_to_signals(
Expand Down

0 comments on commit d457c6c

Please sign in to comment.