Skip to content

Commit

Permalink
Schedule each batch separately
Browse files Browse the repository at this point in the history
  • Loading branch information
danr committed Feb 20, 2024
1 parent 807771f commit 803b8fa
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 35 deletions.
9 changes: 6 additions & 3 deletions cellpainter/cellpainter/commandlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ def AddSimDelays(cmd: commands.Command) -> commands.Command:
return program, expected_ends

def check_correspondence(command: Command, **ends: dict[int, float]):
return
# pbutils.pr(ends)

by_id: dict[int, Command] = {
i: c
for c in command.universe()
Expand All @@ -244,8 +245,10 @@ def check_correspondence(command: Command, **ends: dict[int, float]):
if end_a == -1: end_a = 'missing'
if end_b == -1: end_b = 'missing'
cmd = by_id.get(k)
mismatches += [{src_a: end_a, src_b: end_b, 'cmd': cmd}]
# if not cmd or not isinstance(cmd.peel_meta(), Checkpoint):
if src_a == 'optimizer_ends' and end_a == 'missing':
pass
else:
mismatches += [{src_a: end_a, src_b: end_b, 'cmd': cmd}]

if mismatches:
raise ValueError(f'Correspondence check failed {len(mismatches)=} ({" ".join(ends.keys())}) {mismatches=}')
Expand Down
14 changes: 10 additions & 4 deletions cellpainter/cellpainter/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def collect(self: Command) -> list[tuple[Command, Metadata]]:
(collected_cmd, collected_metadata.merge(self.metadata))
for collected_cmd, collected_metadata in self.command.collect()
]
# case OptimizeSection():
# raise ValueError(f'collect({self})')
case _:
return [(self, Metadata())]

Expand All @@ -120,7 +122,7 @@ def is_noop(self: Command) -> bool:
return float(s.offset) == 0.0
case SeqCmd():
return all(cmd.is_noop() for cmd in self.commands)
case Fork() | Meta():
case Fork() | Meta() | OptimizeSection():
return self.command.is_noop()
case _:
return False
Expand All @@ -139,7 +141,7 @@ def transform(self: Command, f: Callable[[Command], Command], reverse: bool=Fals
if reverse:
inner_commands = list(reversed(inner_commands))
return f(self.replace(commands=inner_commands))
case Fork() | Meta():
case Fork() | Meta() | OptimizeSection():
return f(self.replace(command=self.command.transform(f, reverse=reverse)))
case _:
return f(self)
Expand All @@ -154,7 +156,7 @@ def universe(self: Command) -> Iterator[Command]:
case SeqCmd():
for cmd in self.commands:
yield from cmd.universe()
case Fork() | Meta():
case Fork() | Meta() | OptimizeSection():
yield from self.command.universe()
case _:
pass
Expand Down Expand Up @@ -304,7 +306,7 @@ def assign_ids(self: Command) -> Command:
def F(cmd: Command) -> Command:
nonlocal count
match cmd:
case SeqCmd() | Fork() | Meta():
case SeqCmd() | Fork() | Meta() | OptimizeSection():
return cmd
case _:
count += 1
Expand Down Expand Up @@ -392,6 +394,10 @@ def Seq(*commands: Command) -> Command:
return flat[0]
return SeqCmd(flat)

@dataclass(frozen=True)
class OptimizeSection(Command):
command: Command

@dataclass(frozen=True)
class Idle(Command):
secs: Symbolic | float | int = 0.0
Expand Down
29 changes: 24 additions & 5 deletions cellpainter/cellpainter/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,23 @@ def optimize(cmd: Command) -> tuple[Command, dict[int, float]]:
cmd = cmd.make_resource_checkpoints()
cmd = cmd.align_forks()
cmd = cmd.assign_ids()
opt = optimal_env(cmd)
cmd = cmd.resolve(opt.env)
return cmd, opt.expected_ends

ends: dict[int, float] = {}
subst: dict[str, float] = {}

def Opt(cmd: Command) -> Command:
nonlocal ends, subst
if isinstance(cmd, OptimizeSection):
cmd_inst = cmd.command.resolve(subst)
opt = optimal_env(cmd_inst)
ends |= opt.expected_ends
subst |= opt.env
pbutils.pr(opt)
return cmd_inst.resolve(opt.env)
else:
return cmd

return OptimizeSection(cmd).transform(Opt), ends

@dataclass(frozen=True)
class Ids:
Expand All @@ -54,9 +68,14 @@ class OptimalResult:

def optimal_env(cmd: Command, unsat_core: bool=False, explain_mode: bool=False) -> OptimalResult:
if unsat_core:
pbutils.pr(cmd)
# pbutils.pr(cmd)
pass

variables = cmd.free_vars()

if not variables:
return OptimalResult({}, {})

ids = Ids()

Resolution = 4
Expand Down Expand Up @@ -227,7 +246,7 @@ def run(cmd: Command, begin: Symbolic, *, is_main: bool) -> Symbolic:
print(s.unsat_core())
raise ValueError('Impossible to schedule!')
else:
raise ValueError('Optimization says unsat, but unsat core version says sat')
raise ValueError(f'Optimization says unsat, but unsat core version says {check}')

# add the constraints with most important first (lexicographic optimization order)
for _prio, terms in sorted(maximize_terms.items(), reverse=True):
Expand Down
15 changes: 7 additions & 8 deletions cellpainter/cellpainter/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def execute(cmd: Command, runtime: Runtime, metadata: Metadata):
execute(c, runtime, metadata)

case Idle():
secs = cmd.secs
assert isinstance(secs, (float, int))
secs = cmd.secs.unwrap()
entry = entry.merge(Metadata(est=round(secs, 3)))
with runtime.timeit(entry):
runtime.sleep(secs)
Expand All @@ -50,8 +49,7 @@ def execute(cmd: Command, runtime: Runtime, metadata: Metadata):
runtime.checkpoint(cmd.name, entry)

case WaitForCheckpoint():
plus_secs = cmd.plus_secs
assert isinstance(plus_secs, (float, int))
plus_secs = cmd.plus_secs.unwrap()
t0 = runtime.wait_for_checkpoint(cmd.name)
desired_point_in_time = t0 + plus_secs
delay = desired_point_in_time - runtime.monotonic()
Expand Down Expand Up @@ -245,7 +243,7 @@ def fork():
raise ValueError(f'Unknown command {cmd}')

if effect := cmd.effect():
runtime.apply_effect(effect, entry)
runtime.apply_effect(effect, entry, fatal_errors=runtime.config.name == 'simulate')

@contextlib.contextmanager
def make_runtime(config: RuntimeConfig, program: Program) -> Iterator[Runtime]:
Expand All @@ -259,10 +257,11 @@ def make_runtime(config: RuntimeConfig, program: Program) -> Iterator[Runtime]:
def simulate_program(program: Program, sim_delays: dict[int, float] = {}, log_filename: str | None=None) -> DB:
program, expected_ends = commandlib.prepare_program(program, sim_delays=sim_delays)

with pbutils.timeit('quicksim'):
quicksim_ends, _checkpoints = commandlib.quicksim(program.command, {}, cast(Any, estimate))
if 1:
with pbutils.timeit('quicksim'):
quicksim_ends, _checkpoints = commandlib.quicksim(program.command, {}, cast(Any, estimate))

commandlib.check_correspondence(program.command, optimizer_ends=expected_ends, quicksim_ends=quicksim_ends)
commandlib.check_correspondence(program.command, optimizer_ends=expected_ends, quicksim_ends=quicksim_ends)

cmd = program.command
with pbutils.timeit('simulating'):
Expand Down
5 changes: 2 additions & 3 deletions cellpainter/cellpainter/moves.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,9 +624,8 @@ def effect(self, world: World) -> dict[str, str | None]:
effects[lid_Bi + ' get'] = PutLidOn(source=Bi, target=B21)
effects[lid_Bi + ' put'] = TakeLidOff(source=B21, target=Bi)

lid_Bi = f'lid-B{i} [base B15]'
effects[lid_Bi + ' get'] = PutLidOn(source=Bi, target=B15)
effects[lid_Bi + ' put'] = TakeLidOff(source=B15, target=Bi)
effects[lid_Bi + ' get [base B15]'] = PutLidOn(source=Bi, target=B15)
effects[lid_Bi + ' put [base B15]'] = TakeLidOff(source=B15, target=Bi)

for k in list(effects.keys()):
effects[k + ' transfer'] = effects[k]
Expand Down
18 changes: 8 additions & 10 deletions cellpainter/cellpainter/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Max,
Min,
WaitAssumption,
OptimizeSection,
)
from .commandlib import Interleaving
from .moves import World
Expand Down Expand Up @@ -567,7 +568,11 @@ def paint_batch(batch: list[Plate], protocol_config: ProtocolConfig) -> Command:
if not prev_step:
# idle_ref = WaitForCheckpoint(f'batch {batch_index}')
incu_delay = [
WaitForCheckpoint(f'batch {batch_index}') + f'{plate_desc} incu delay {ix}'
WaitForCheckpoint(f'batch {batch_index}') + f'{plate_desc} incu delay {ix}',
WaitForResource('incu', assume='no wait'),
WaitForResource('wash', assume='no wait'),
WaitForResource('disp', assume='no wait'),
WaitForResource('blue', assume='no wait'),
]
wash_delay = [
WaitForCheckpoint(f'batch {batch_index}') + f'{plate_desc} first wash delay'
Expand Down Expand Up @@ -968,21 +973,14 @@ def desc(p: Plate | None, step: str, substep: str) -> Desc | None:
def cell_paint_program(batch_sizes: list[int], protocol_config: ProtocolConfig) -> Program:
cmds: list[Command] = []
plates = define_plates(batch_sizes)
program = Seq()
for batch in plates:
batch_cmds = paint_batch(
batch,
protocol_config=protocol_config,
)
cmds += [batch_cmds]

program = OptimizeSection(program >> batch_cmds)
world0 = initial_world(pbutils.flatten(plates), protocol_config)
program = Seq(*cmds)
program = Seq(
Checkpoint('run'),
# we now do test comm at start of each batch
program,
Duration('run', OptPrio.total_time)
)
return Program(
command=program,
world0=world0,
Expand Down
11 changes: 9 additions & 2 deletions cellpainter/cellpainter/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,15 @@ def rename(self, old: str, new: str) -> Symbolic:
self.offset,
)

def resolve(self, env: dict[str, float] = {}) -> float:
return sum(env[x] for x in self.var_names) + self.offset
def resolve(self, env: dict[str, float] = {}) -> Symbolic:
var_names: list[str] = []
offset = self.offset
for x in self.var_names:
if x in env:
offset += env[x]
else:
var_names += [x]
return Symbolic(var_names, offset)

def var_set(self) -> set[str]:
return set(self.var_names)
Expand Down

0 comments on commit 803b8fa

Please sign in to comment.