Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Checkpointing Large Runs #239

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions bqskit/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@
from bqskit.passes.group import PassGroup
from bqskit.passes.io.checkpoint import LoadCheckpointPass
from bqskit.passes.io.checkpoint import SaveCheckpointPass
from bqskit.passes.io.intermediate import CheckpointRestartPass
from bqskit.passes.io.intermediate import RestoreIntermediatePass
from bqskit.passes.io.intermediate import SaveIntermediatePass
from bqskit.passes.mapping.apply import ApplyPlacement
Expand Down Expand Up @@ -344,6 +345,7 @@
'ParallelDo',
'LoadCheckpointPass',
'SaveCheckpointPass',
'CheckpointRestartPass',
'SaveIntermediatePass',
'RestoreIntermediatePass',
'GroupSingleQuditGatePass',
Expand Down
30 changes: 26 additions & 4 deletions bqskit/passes/control/foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
collection_filter: Callable[[Operation], bool] | None = None,
replace_filter: ReplaceFilterFn | str = 'always',
batch_size: int | None = None,
blocks_to_run: list[int] = [],
) -> None:
"""
Construct a ForEachBlockPass.
Expand Down Expand Up @@ -127,6 +128,11 @@ def __init__(
Defaults to 'always'. #TODO: address importability

batch_size (int): (Deprecated).

blocks_to_run (List[int]):
A list of blocks to run the ForEachBlockPass body on. By default
you run on all blocks. This is mainly used with checkpointing,
where some blocks have already finished while others have not.
"""
if batch_size is not None:
import warnings
Expand All @@ -140,7 +146,7 @@ def __init__(
self.collection_filter = collection_filter or default_collection_filter
self.replace_filter = replace_filter or default_replace_filter
self.workflow = Workflow(loop_body)

self.blocks_to_run = sorted(blocks_to_run)
if not callable(self.collection_filter):
raise TypeError(
'Expected callable method that maps Operations to booleans for'
Expand Down Expand Up @@ -171,9 +177,20 @@ async def run(self, circuit: Circuit, data: PassData) -> None:

# Collect blocks
blocks: list[tuple[int, Operation]] = []
for cycle, op in circuit.operations_with_cycles():
if self.collection_filter(op):
blocks.append((cycle, op))
if (len(self.blocks_to_run) == 0):
self.blocks_to_run = list(range(circuit.num_operations))

block_ids = self.blocks_to_run.copy()
next_id = block_ids.pop(0)
for i, (cycle, op) in enumerate(circuit.operations_with_cycles()):
if i == next_id:
if self.collection_filter(op):
blocks.append((cycle, op))
if len(block_ids) > 0:
next_id = block_ids.pop(0)
else:
# No more blocks to run on
break

# No blocks, no work
if len(blocks) == 0:
Expand Down Expand Up @@ -212,6 +229,11 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
block_data['model'] = submodel
block_data['point'] = CircuitPoint(cycle, op.location[0])
block_data['calculate_error_bound'] = self.calculate_error_bound
# Need to zero pad block ids for consistency
num_digits = len(str(circuit.num_operations))
block_data['block_num'] = str(
self.blocks_to_run[i],
).zfill(num_digits)
for key in data:
if key.startswith(self.pass_down_key_prefix):
block_data[key] = data[key]
Expand Down
2 changes: 2 additions & 0 deletions bqskit/passes/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

from bqskit.passes.io.checkpoint import LoadCheckpointPass
from bqskit.passes.io.checkpoint import SaveCheckpointPass
from bqskit.passes.io.intermediate import CheckpointRestartPass
from bqskit.passes.io.intermediate import RestoreIntermediatePass
from bqskit.passes.io.intermediate import SaveIntermediatePass

__all__ = [
'CheckpointRestartPass',
'LoadCheckpointPass',
'SaveCheckpointPass',
'SaveIntermediatePass',
Expand Down
170 changes: 124 additions & 46 deletions bqskit/passes/io/intermediate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,24 @@

import logging
import pickle
import shutil
from os import listdir
from os import mkdir
from os.path import exists
from os.path import join
from re import findall
from typing import cast
from typing import Sequence

from bqskit.compiler.basepass import BasePass
from bqskit.compiler.passdata import PassData
from bqskit.ir.circuit import Circuit
from bqskit.ir.gates.circuitgate import CircuitGate
from bqskit.ir.lang.qasm2.qasm2 import OPENQASM2Language
from bqskit.ir.operation import Operation
from bqskit.passes.alias import PassAlias
from bqskit.passes.util.converttou3 import ToU3Pass
from bqskit.utils.typing import is_sequence

_logger = logging.getLogger(__name__)

Expand All @@ -32,6 +38,7 @@ def __init__(
path_to_save_dir: str,
project_name: str | None = None,
save_as_qasm: bool = True,
overwrite: bool = False,
) -> None:
"""
Constructor for the SaveIntermediatePass.
Expand All @@ -57,15 +64,18 @@ def __init__(
else 'unnamed_project'

enum = 1
if exists(self.pathdir + self.projname):
while exists(self.pathdir + self.projname + f'_{enum}'):
enum += 1
self.projname += f'_{enum}'
_logger.warning(
f'Path {path_to_save_dir} already exists, '
f'saving to {self.pathdir + self.projname} '
'instead.',
)
if exists(join(self.pathdir, self.projname)):
if overwrite:
shutil.rmtree(join(self.pathdir, self.projname))
else:
while exists(join(self.pathdir, self.projname + f'_{enum}')):
enum += 1
self.projname += f'_{enum}'
_logger.warning(
f'Path {path_to_save_dir} already exists, '
f'saving to {self.pathdir + self.projname} '
'instead.',
)

mkdir(self.pathdir + self.projname)

Expand Down Expand Up @@ -102,8 +112,8 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
block.params,
)
subcircuit.unfold((0, 0))
await ToU3Pass().run(subcircuit, PassData(subcircuit))
if self.as_qasm:
await ToU3Pass().run(subcircuit, PassData(subcircuit))
with open(block_skeleton + f'{enum}.qasm', 'w') as f:
f.write(OPENQASM2Language().encode(subcircuit))
else:
Expand All @@ -117,7 +127,10 @@ async def run(self, circuit: Circuit, data: PassData) -> None:


class RestoreIntermediatePass(BasePass):
def __init__(self, project_directory: str, load_blocks: bool = True):
def __init__(
self, project_directory: str, load_blocks: bool = True,
as_circuit_gate: bool = False,
):
"""
Constructor for the RestoreIntermediatePass.

Expand All @@ -130,30 +143,20 @@ def __init__(self, project_directory: str, load_blocks: bool = True):
the user must explicitly call load_blocks() themselves. Defaults
to True.

as_circuit_gate (bool): If True, blocks are reloaded as a circuit
gate rather than a circuit.

Raises:
ValueError: If `project_directory` does not exist or if
`structure.pickle` is invalid.
"""
self.proj_dir = project_directory
if not exists(self.proj_dir):
raise TypeError(
f"Project directory '{self.proj_dir}' does not exist.",
)
if not exists(self.proj_dir + '/structure.pickle'):
raise TypeError(
f'Project directory `{self.proj_dir}` does not '
'contain `structure.pickle`.',
)

with open(self.proj_dir + '/structure.pickle', 'rb') as f:
self.structure = pickle.load(f)

if not isinstance(self.structure, list):
raise TypeError('The provided `structure.pickle` is not a list.')

self.block_list: list[str] = []
if load_blocks:
self.reload_blocks()
self.as_circuit_gate = as_circuit_gate
# We will detect automatically if blocks are saved as qasm or pickle
self.saved_as_qasm = False

self.load_blocks = load_blocks

def reload_blocks(self) -> None:
"""
Expand All @@ -164,11 +167,18 @@ def reload_blocks(self) -> None:
ValueError: if there are more block files than indices in the
`structure.pickle`.
"""
files = listdir(self.proj_dir)
files = sorted(listdir(self.proj_dir))
# Files are of the form block_*.pickle or block_*.qasm
self.block_list = [f for f in files if 'block_' in f]
pickle_list = [f for f in self.block_list if '.pickle' in f]
if len(pickle_list) == 0:
self.saved_as_qasm = True
self.block_list = [f for f in self.block_list if '.qasm' in f]
else:
self.block_list = pickle_list
if len(self.block_list) > len(self.structure):
raise ValueError(
'More block files than indicies in `structure.pickle`',
'More block files than indices in `structure.pickle`',
)

async def run(self, circuit: Circuit, data: PassData) -> None:
Expand All @@ -179,21 +189,89 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
ValueError: if a block file and the corresponding index in
`structure.pickle` are differnt lengths.
"""
# If the circuit is empty, just append blocks in order
if circuit.depth == 0:
for block in self.block_list:
# Get block
block_num = int(findall(r'\d+', block)[0])
with open(self.proj_dir + '/' + block) as f:

if not exists(self.proj_dir):
raise TypeError(
f"Project directory '{self.proj_dir}' does not exist.",
)
if not exists(self.proj_dir + '/structure.pickle'):
raise TypeError(
f'Project directory `{self.proj_dir}` does not '
'contain `structure.pickle`.',
)

with open(self.proj_dir + '/structure.pickle', 'rb') as f:
self.structure = pickle.load(f)

if not isinstance(self.structure, list):
raise TypeError('The provided `structure.pickle` is not a list.')

if self.load_blocks:
self.reload_blocks()

# Get circuit from checkpoint, ignore previous circuit
new_circuit = Circuit(circuit.num_qudits, circuit.radixes)
for block in self.block_list:
# Get block
block_num = int(findall(r'\d+', block)[0])
if self.saved_as_qasm:
with open(join(self.proj_dir, block)) as f:
block_circ = OPENQASM2Language().decode(f.read())
# Get location
block_location = self.structure[block_num]
if block_circ.num_qudits != len(block_location):
raise ValueError(
f'{block} and `structure.pickle` locations are '
'different sizes.',
)
# Append to circuit
circuit.append_circuit(block_circ, block_location)
else:
with open(join(self.proj_dir, block), 'rb') as f:
block_circ = pickle.load(f)
# Get location
block_location = self.structure[block_num]
if block_circ.num_qudits != len(block_location):
raise ValueError(
f'{block} and `structure.pickle` locations are '
'different sizes.',
)
# Append to circuit
new_circuit.append_circuit(
block_circ, block_location,
as_circuit_gate=self.as_circuit_gate,
)

circuit.become(new_circuit)
# Check if the circuit has been partitioned, if so, try to replace
# blocks


class CheckpointRestartPass(PassAlias):
def __init__(
self, base_checkpoint_dir: str,
project_name: str,
default_passes: BasePass | Sequence[BasePass],
save_as_qasm: bool = True,
) -> None:
"""Group together one or more `passes`."""
if not is_sequence(default_passes):
default_passes = [cast(BasePass, default_passes)]

if not isinstance(default_passes, list):
default_passes = list(default_passes)

full_checkpoint_dir = join(base_checkpoint_dir, project_name)

# Check if checkpoint files exist
if not exists(join(full_checkpoint_dir, 'structure.pickle')):
_logger.info('Checkpoint does not exist!')
save_pass = SaveIntermediatePass(
base_checkpoint_dir, project_name,
save_as_qasm=save_as_qasm, overwrite=True,
)
default_passes.append(save_pass)
self.passes = default_passes
else:
# Already checkpointed, restore
_logger.info('Restoring from Checkpoint!')
self.passes = [
RestoreIntermediatePass(
full_checkpoint_dir, as_circuit_gate=True,
),
]

def get_passes(self) -> list[BasePass]:
"""Return the passes to be run, see :class:`PassAlias` for more."""
return self.passes
Loading
Loading