Skip to content

Commit

Permalink
add Dllib python style check (#3654)
Browse files Browse the repository at this point in the history
  • Loading branch information
Le-Zheng authored Dec 6, 2021
1 parent fd8b82f commit e0ecd2e
Show file tree
Hide file tree
Showing 50 changed files with 1,914 additions and 1,353 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/style-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ jobs:
run: bash python/nano/test/run-nano-codestyle-test.sh
env:
ANALYTICS_ZOO_ROOT: ${{ github.workspace }}

- name: Dllib style checking
run: bash python/dllib/dev/lint-python

- name: Orca style checking
run: bash python/orca/dev/test/lint-python
Expand Down
2 changes: 1 addition & 1 deletion python/dllib/dev/lint-python
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
PYTHON_ROOT_DIR="$SCRIPT_DIR/.."
echo $PYTHON_ROOT_DIR
PATHS_TO_CHECK="."
PATHS_TO_CHECK="$SCRIPT_DIR/../src"
PEP8_REPORT_PATH="$PYTHON_ROOT_DIR/dev/pep8-report.txt"
PYLINT_REPORT_PATH="$PYTHON_ROOT_DIR/dev/pylint-report.txt"
PYLINT_INSTALL_INFO="$PYTHON_ROOT_DIR/dev/pylint-info.txt"
Expand Down
2 changes: 1 addition & 1 deletion python/dllib/dev/pep8-1.7.0.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
RERAISE_COMMA_REGEX = re.compile(r'raise\s+\w+\s*,.*,\s*\w+\s*$')
ERRORCODE_REGEX = re.compile(r'\b[A-Z]\d{3}\b')
DOCSTRING_REGEX = re.compile(r'u?r?["\']')
EXTRANEOUS_WHITESPACE_REGEX = re.compile(r'[[({] | []}),;:]')
EXTRANEOUS_WHITESPACE_REGEX = re.compile(r'[\[({] | [\]}),;:]')
WHITESPACE_AFTER_COMMA_REGEX = re.compile(r'[,;:]\s*(?: |\t)')
COMPARE_SINGLETON_REGEX = re.compile(r'(\bNone|\bFalse|\bTrue)?\s*([=!]=)'
r'\s*(?(1)|(None|False|True))\b')
Expand Down
1 change: 0 additions & 1 deletion python/dllib/src/bigdl/dllib/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

77 changes: 39 additions & 38 deletions python/dllib/src/bigdl/dllib/contrib/onnx/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,52 +19,53 @@


def calc_output_shape(input, kernel, padding=0, stride=1, dilation=1, ceil_mode=False):
def dilated_kernel_size(kernel, dilation):
return kernel + (kernel - 1) * (dilation - 1)
rounding = math.ceil if ceil_mode else math.floor
out = (input + 2 * padding - dilated_kernel_size(kernel, dilation)) / stride + 1
out = int(rounding(out))
return out
def dilated_kernel_size(kernel, dilation):
return kernel + (kernel - 1) * (dilation - 1)

rounding = math.ceil if ceil_mode else math.floor
out = (input + 2 * padding - dilated_kernel_size(kernel, dilation)) / stride + 1
out = int(rounding(out))
return out


def parse_node_attr(node_proto):
attrs = {}
attr_proto = node_proto.attribute
attrs = {}
attr_proto = node_proto.attribute

for attr in attr_proto:
for field in ['f', 'i', 's']:
if attr.HasField(field):
attrs[attr.name] = getattr(attr, field)
for attr in attr_proto:
for field in ['f', 'i', 's']:
if attr.HasField(field):
attrs[attr.name] = getattr(attr, field)

# Needed for supporting python version > 3.5
if isinstance(attrs[attr.name], bytes):
attrs[attr.name] = attrs[attr.name].decode(encoding='utf-8')
# Needed for supporting python version > 3.5
if isinstance(attrs[attr.name], bytes):
attrs[attr.name] = attrs[attr.name].decode(encoding='utf-8')

for field in ['floats', 'ints', 'strings']:
if list(getattr(attr, field)):
assert attr.name not in attrs, "Only one type of attr is allowed"
attrs[attr.name] = tuple(getattr(attr, field))
for field in ['floats', 'ints', 'strings']:
if list(getattr(attr, field)):
assert attr.name not in attrs, "Only one type of attr is allowed"
attrs[attr.name] = tuple(getattr(attr, field))

for field in ['t', 'g']:
if attr.HasField(field):
attrs[attr.name] = getattr(attr, field)
for field in ['tensors', 'graphs']:
if list(getattr(attr, field)):
raise NotImplementedError()
if attr.name not in attrs:
raise ValueError("Cannot parse attribute: \n{}\n.".format(attr))
for field in ['t', 'g']:
if attr.HasField(field):
attrs[attr.name] = getattr(attr, field)
for field in ['tensors', 'graphs']:
if list(getattr(attr, field)):
raise NotImplementedError()
if attr.name not in attrs:
raise ValueError("Cannot parse attribute: \n{}\n.".format(attr))

return attrs
return attrs


def parse_tensor_data(tensor_proto):
try:
from onnx.numpy_helper import to_array
except ImportError:
raise ImportError("Onnx and protobuf need to be installed.")
if len(tuple(tensor_proto.dims)) > 0:
np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
else:
# If it is a scalar tensor
np_array = np.array([to_array(tensor_proto)])
return np_array
try:
from onnx.numpy_helper import to_array
except ImportError:
raise ImportError("Onnx and protobuf need to be installed.")
if len(tuple(tensor_proto.dims)) > 0:
np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
else:
# If it is a scalar tensor
np_array = np.array([to_array(tensor_proto)])
return np_array
6 changes: 4 additions & 2 deletions python/dllib/src/bigdl/dllib/contrib/onnx/onnx_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def load_graph(self, graph_proto):
root_nodes.append((name, op_type))
prev_modules = [dummy_root]

bigdl_module, outputs_shape = self._make_module_from_onnx_node(op_type, inputs, prev_modules, attrs, outputs)
bigdl_module, outputs_shape = self._make_module_from_onnx_node(op_type, inputs,
prev_modules, attrs,
outputs)

assert len(outputs) == len(outputs_shape)

Expand Down Expand Up @@ -108,4 +110,4 @@ def load(model_path):

def load_model_proto(model_proto):
loader = OnnxLoader()
return loader.load_graph(model_proto.graph)
return loader.load_graph(model_proto.graph)
Loading

0 comments on commit e0ecd2e

Please sign in to comment.