Skip to content

Commit

Permalink
feat: can save run notebooks inplace
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloalcain committed Nov 6, 2023
1 parent cd73ebc commit a8ccb4c
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 14 deletions.
22 changes: 16 additions & 6 deletions firenze/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@ def convert(self, value, param, ctx):

@click.command()
@click.argument("notebook-path", type=PathOrS3(exists=True))
@click.option("--output-html-path", type=PathOrS3(), default="output.html")
@click.option("-o", "--output-html-path", type=PathOrS3(), default="output.html")
@click.option("-q", "--quiet", count=True, help="Decrease verbosity.")
@click.option(
"-i", "--in-place", is_flag=True, help="Overwrite the notebook file with the execution."
)
@click.argument("parameters", nargs=-1)
def execute_notebook(notebook_path, output_html_path, quiet, parameters):
def execute_notebook(notebook_path, output_html_path, quiet, in_place, parameters):
parsed_options = parse_options(parameters)
notebook = Notebook.from_path(notebook_path)
notebook.clean()
Expand All @@ -39,13 +42,20 @@ async def execute():
await notebook.async_execute()
done_event.set()

async def write_html():
async def write_while_running():
while not done_event.is_set():
notebook.write_html(output_html_path)
await write()
await asyncio.sleep(5)

await asyncio.gather(asyncio.create_task(execute()), asyncio.create_task(write_html()))
notebook.write_html(output_html_path)
async def write():
if in_place:
notebook.save_notebook(notebook_path)
notebook.write_html(output_html_path)

await asyncio.gather(
asyncio.create_task(execute()), asyncio.create_task(write_while_running())
)
await write()

asyncio.run(execute_and_write())

Expand Down
24 changes: 17 additions & 7 deletions firenze/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,28 @@ def from_s3(cls, s3_path, s3_client=None):
jupyter_notebook = nbformat.reads(data["Body"].read().decode("utf_8"), as_version=4)
return cls(jupyter_notebook)

def write_html(self, file_path):
def save(self, file_path, content):
if file_path.startswith("s3://"):
self.write_html_to_s3(file_path)
self._save_to_s3(file_path, content)
else:
self.write_html_to_local(file_path)
self.save_to_local(file_path, content)

def write_html(self, file_path):
self.save(file_path, self.html)

def save_notebook(self, file_path):
# Serialize the notebook to a string
notebook_str = nbformat.writes(self.jupyter_notebook)
self.save(file_path, notebook_str)

def write_html_to_local(self, file_path):
@staticmethod
def save_to_local(file_path, content):
pathlib.Path(file_path).parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w") as f:
f.write(self.html)
f.write(content)

def write_html_to_s3(self, s3_path):
@staticmethod
def _save_to_s3(s3_path, content):
bucket, key = s3_path.replace("s3://", "").split("/", 1)
s3_client = boto3.client("s3")
s3_client.put_object(Bucket=bucket, Key=key, Body=self.html)
s3_client.put_object(Bucket=bucket, Key=key, Body=content)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "firenze"
version = "0.1.3"
version = "0.1.4"
description = "A lean executor for jupyter notebooks."
authors = ["Pablo Alcain <[email protected]>"]
license = "MIT"
Expand Down
13 changes: 13 additions & 0 deletions tests/test_firenze.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,19 @@ def test_can_write_notebook_html_to_local_file(one_cell_notebook_path):
assert "Dummy text" in tmp.read().decode("utf-8")


@pytest.mark.slow
def test_can_write_notebook_ipynb_to_local_file(one_cell_notebook_path):
with open(one_cell_notebook_path) as f:
jupyter_notebook = nbformat.read(f, as_version=4)
notebook = Notebook(jupyter_notebook, DummyClient(jupyter_notebook))
notebook.execute()

with tempfile.NamedTemporaryFile(delete=True) as tmp:
notebook.save_notebook(tmp.name)
new_notebook = Notebook.from_path(tmp.name)
assert new_notebook.jupyter_notebook == notebook.jupyter_notebook


@pytest.mark.slow
def test_can_write_notebook_html_to_s3_path(mock_bucket, one_cell_notebook_path):
with open(one_cell_notebook_path) as f:
Expand Down

0 comments on commit a8ccb4c

Please sign in to comment.