From c30f495dc013a73e371dd8053a0381e4707ab309 Mon Sep 17 00:00:00 2001 From: Tim Hall Date: Mon, 15 Jun 2020 20:47:35 +0100 Subject: Code clean-up using black and flake8 - No functional change Signed-off-by: Tim Hall Change-Id: I5ab1198b9d092cd041fa9b85b2dee9900d299bfc --- ethosu/vela/supported_operators.py | 52 +++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 29 deletions(-) (limited to 'ethosu/vela/supported_operators.py') diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 5676ba1c..55e718e9 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -15,7 +15,8 @@ # limitations under the License. # Description: # The SupportedOperators class which is a collection of all supported operators and parameter checks. -from .data_type import BaseType, DataType +from .data_type import BaseType +from .data_type import DataType class SupportedOperators: @@ -51,17 +52,7 @@ class SupportedOperators: self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs")) self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum")) self.binary_elem_wise_add_mul_sub = set( - ( - "AddAct", - "MulAct", - "SubAct", - "QuantizedAdd", - "QuantizedSub", - "QuantizedMul", - "Mul", - "Add", - "Sub", - ) + ("AddAct", "MulAct", "SubAct", "QuantizedAdd", "QuantizedSub", "QuantizedMul", "Mul", "Add", "Sub",) ) self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops @@ -201,13 +192,13 @@ class SupportedOperators: return False elif op.attrs["padding"] == b"VALID": kernel_h, kernel_w = weight_tensor.shape[0], weight_tensor.shape[1] - if ((ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) - or (ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0))): + if (ofm_h != (ifm_h) * stride_h + max(kernel_h - stride_h, 0)) or ( + ofm_w != (ifm_w) * stride_w + max(kernel_w - stride_w, 0) + ): return False return self.check_convolution_restrictions(op) - def check_pooling_restrictions(self, op): # check stride if op.attrs["stride_w"] > 3 or op.attrs["stride_h"] > 3: @@ -226,8 +217,9 @@ class SupportedOperators: # check kernel size if op.attrs["padding"] == b"SAME" and (op.attrs["filter_width"] > 8 or op.attrs["filter_height"] > 8): return False - if (op.attrs["padding"] == b"VALID" and - (op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256)): + if op.attrs["padding"] == b"VALID" and ( + op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256 + ): return False if op.type in self.max_pooling_ops: @@ -259,31 +251,33 @@ class SupportedOperators: # check data type ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm() # input and output datatype must match for these operators - if (op.type in self.binary_elem_wise_min_max_ops | self.unary_elem_wise_main_ops and - ifm_tensor.dtype != ofm_tensor.dtype): + if ( + op.type in self.binary_elem_wise_min_max_ops | self.unary_elem_wise_main_ops + and ifm_tensor.dtype != ofm_tensor.dtype + ): return False - if (op.type in self.binary_elem_wise_add_mul_sub): + if op.type in self.binary_elem_wise_add_mul_sub: # both inputs must have same type - if (ifm_tensor.dtype != ifm2_tensor.dtype): + if ifm_tensor.dtype != ifm2_tensor.dtype: return False # signed input check - if (ifm_tensor.dtype.type & BaseType.Signed): + if ifm_tensor.dtype.type & BaseType.Signed: # output must be signed - if (ofm_tensor.dtype.type & BaseType.Unsigned): + if ofm_tensor.dtype.type & BaseType.Unsigned: return False # and 8, 16 or 32-bit - if (ofm_tensor.element_size() not in (1, 2, 4)): + if ofm_tensor.element_size() not in (1, 2, 4): return False # unsigned input check, output must be same type or int32 - if (ifm_tensor.dtype.type & BaseType.Unsigned and not - (ifm_tensor.dtype == ofm_tensor.dtype or - ofm_tensor.dtype == DataType.int32)): + if ifm_tensor.dtype.type & BaseType.Unsigned and not ( + ifm_tensor.dtype == ofm_tensor.dtype or ofm_tensor.dtype == DataType.int32 + ): return False # check batch size if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1: - return False - if op.type in self.binary_elem_wise_main_ops: # if op type is unary, ifm2_tensor is None + return False + if op.type in self.binary_elem_wise_main_ops: # if op type is unary, ifm2_tensor is None if len(ifm2_tensor.shape) > 2 and ifm2_tensor.shape[0] != 1: return False -- cgit v1.2.1