Skip to content

Commit

Permalink
Allow for directories to be given as args
Browse files Browse the repository at this point in the history
Will recursively format all files in directories.
  • Loading branch information
bsamseth committed Jul 15, 2020
1 parent f0cdfce commit 03cab97
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
2 changes: 1 addition & 1 deletion blackbricks/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.0"
__version__ = "0.4.1"
38 changes: 27 additions & 11 deletions blackbricks/blackbricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
import itertools
import os
from typing import List
from typing import List, Tuple

import black
import sqlparse
Expand Down Expand Up @@ -91,21 +91,34 @@ def version_callback(version_requested: bool):
raise typer.Exit()


def filenames_callback(filenames: List[str]):
# Validate file paths:
for filename in filenames:
try:
with open(filename) as f:
pass
except FileNotFoundError:
def filenames_callback(paths: Tuple[str]):
"""Resolve the paths given into valid file names
Directories are recursively added, similarly to how black operates.
"""
paths = list(paths)
file_paths = []
while paths:
path = os.path.abspath(paths.pop())

if not os.path.exists(path):
typer.echo(
typer.style("Error:", fg=typer.colors.RED)
+ " No such file or directory: "
+ typer.style(filename, fg=typer.colors.CYAN)
+ typer.style(path, fg=typer.colors.CYAN)
)
raise typer.Exit(1)

return filenames
if os.path.isdir(path):

# Recursively add all the files/dirs in path to the paths to be consumed.
paths.extend([os.path.join(path, f) for f in os.listdir(path)])

else:

file_paths.append(path)

return file_paths


def mutually_exclusive(names, values):
Expand Down Expand Up @@ -205,7 +218,7 @@ def main(
)

output = (
f"{HEADER}\n\n"
f"{HEADER}\n"
+ f"\n\n{COMMAND}\n\n".join(
"".join(line.rstrip() + "\n" for line in cell.splitlines()).rstrip()
for cell in output_cells
Expand All @@ -229,6 +242,9 @@ def main(
with open(filename, "w") as f:
for line in output.splitlines():
f.write(line.rstrip() + "\n")

if output != content:
typer.secho(f"reformatted {filename}", bold=True)
elif check and output != content:
typer.secho(f"would reformat {filename}", bold=True)

Expand Down
1 change: 0 additions & 1 deletion test_notebooks/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Databricks notebook source

from pyspark.sql import SQLContext

sqlContext = SQLContext(spark)
Expand Down

0 comments on commit 03cab97

Please sign in to comment.