aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael McGeagh <michael.mcgeagh@arm.com>2020-12-14 15:51:20 +0000
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-12-17 17:16:02 +0000
commit168954814fb6a1cc5e7b2d44784b24402ef30199 (patch)
tree35693aeee7c291695ba83f27db7f8d81272b787c
parentf842b69d007e70d70fc5cef3b6f1f50b4cabbd90 (diff)
downloadethos-u-vela-168954814fb6a1cc5e7b2d44784b24402ef30199.tar.gz
MLBEDSW-3694 Replace padding with enum
Use an Enum instead of a bytestring to specify VALID or SAME padding Signed-off-by: Michael McGeagh <michael.mcgeagh@arm.com> Change-Id: I4e87f8c32b3bfac176d822a68de061e85a558fce
-rw-r--r--ethosu/vela/graph_optimiser.py22
-rw-r--r--ethosu/vela/operation.py5
-rw-r--r--ethosu/vela/operation_util.py7
-rw-r--r--ethosu/vela/supported_operators.py11
-rw-r--r--ethosu/vela/test/test_supported_operators.py31
-rw-r--r--ethosu/vela/tflite_mapping.py5
6 files changed, 45 insertions, 36 deletions
diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py
index 15d13522..4806001f 100644
--- a/ethosu/vela/graph_optimiser.py
+++ b/ethosu/vela/graph_optimiser.py
@@ -35,6 +35,7 @@ from .operation import create_activation_function
from .operation import NpuBlockType
from .operation import Op
from .operation import Operation
+from .operation import Padding
from .operation_util import create_avgpool_nop
from .softmax import SoftMax
from .tensor import check_quantized_tens_scaling_equal
@@ -147,18 +148,18 @@ def needed_total_padding(input_size, stride, filter_size):
def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims):
ypad = needed_total_padding(int(input_dims[1]), int(stride[1]), int(kernel_size[0]))
xpad = needed_total_padding(int(input_dims[2]), int(stride[2]), int(kernel_size[1]))
- if padding_type == b"SAME":
+ if padding_type == Padding.SAME:
left_pad = (xpad + 0) // 2
right_pad = (xpad + 1) // 2
top_pad = (ypad + 0) // 2
bottom_pad = (ypad + 1) // 2
- elif padding_type == b"VALID":
+ elif padding_type == Padding.VALID:
left_pad = 0
right_pad = 0
top_pad = 0
bottom_pad = 0
else:
- raise UnsupportedFeatureError(f"Unknown padding {padding_type.decode('utf-8')}")
+ raise UnsupportedFeatureError(f"Unknown padding")
padding = (top_pad, left_pad, bottom_pad, right_pad)
skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
return padding, skirt
@@ -166,21 +167,20 @@ def calc_padding_and_skirt(padding_type, kernel_size, stride, input_dims):
def calc_upscaled_padding_and_skirt(padding_type, kernel_size, stride, input_dims, upscaling_factor):
kernel_height, kernel_width = kernel_size[0], kernel_size[1]
- if padding_type == b"SAME":
+ if padding_type == Padding.SAME:
ypad = needed_total_padding(int(input_dims[1]) * upscaling_factor, int(stride[1]), int(kernel_height))
xpad = needed_total_padding(int(input_dims[2]) * upscaling_factor, int(stride[2]), int(kernel_width))
right_pad = max(((xpad + 1) // upscaling_factor) - 1, 0)
bottom_pad = max(((ypad + 1) // upscaling_factor) - 1, 0)
left_pad = max(kernel_width - 1 - right_pad, 0)
top_pad = max(kernel_height - 1 - bottom_pad, 0)
- elif padding_type == b"VALID":
+ elif padding_type == Padding.VALID:
right_pad = max(kernel_width - 2, 0)
bottom_pad = max(kernel_height - 2, 0)
left_pad = kernel_width - 1
top_pad = kernel_height - 1
else:
- raise UnsupportedFeatureError(f"Unknown padding {padding_type.decode('utf-8')}")
-
+ raise UnsupportedFeatureError(f"Unknown padding")
padding = (top_pad, left_pad, bottom_pad, right_pad)
skirt = padding
return padding, skirt
@@ -230,10 +230,10 @@ def convert_resizebilinear_to_2x2_pool(op):
op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
if op.attrs["align_corners"]:
shape_modifier = 1
- op.attrs["padding"] = b"VALID"
+ op.attrs["padding"] = Padding.VALID
else:
shape_modifier = 0
- op.attrs["padding"] = b"SAME"
+ op.attrs["padding"] = Padding.SAME
op.inputs[0].resampling_mode = resampling_mode.NEAREST
upscaled_shape = np.array(op.inputs[0].shape[1:3])
@@ -1034,11 +1034,11 @@ def add_attrs_to_resizebilinear(op, arch, nng):
if not op.attrs["align_corners"] and out_shape == upscaled_shape:
# this means the output is supposed to be a x2 upscale,
# so we need to do SAME padding
- op.attrs["padding"] = b"SAME"
+ op.attrs["padding"] = Padding.SAME
elif op.attrs["align_corners"] and out_shape == [upscaled_shape[0] - 1, upscaled_shape[1] - 1]:
# here we can just run the avg pool without padding and
# produce a (M * 2 - 1, N * 2 - 1) sized output
- op.attrs["padding"] = b"VALID"
+ op.attrs["padding"] = Padding.VALID
else:
return op
input_tensor.resampling_mode = resampling_mode.NEAREST
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 32cba365..afc02d41 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -326,6 +326,11 @@ class Op(Enum):
return self.value.id < other.value.id
+class Padding(Enum):
+ SAME = 0
+ VALID = 1
+
+
class ActivationFunction:
"""Fused activation function"""
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 2fc7622c..a267b2ad 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -22,6 +22,7 @@ from .high_level_command_to_npu_op import ifm_ifm2_correct_order
from .operation import ActivationFunction
from .operation import Op
from .operation import Operation
+from .operation import Padding
from .tensor import create_reshape_tensor
from .tensor import QuantizationParameters
from .tensor import Tensor
@@ -29,7 +30,7 @@ from .tensor import Tensor
def create_avgpool_nop(name: str) -> Operation:
op = Operation(Op.AvgPool, name)
- op.attrs["padding"] = b"VALID"
+ op.attrs["padding"] = Padding.VALID
op.attrs["stride_w"] = 1
op.attrs["stride_h"] = 1
op.attrs["filter_width"] = 1
@@ -48,7 +49,7 @@ def create_depthwise_maxpool(
height = ifm.shape[1] * ifm.shape[2]
width = ifm.shape[3]
ifm_shape = [1, height, width, 1]
- op.attrs["padding"] = b"VALID"
+ op.attrs["padding"] = Padding.VALID
op.attrs["stride_w"] = 1
op.attrs["stride_h"] = 1
op.attrs["filter_width"] = width
@@ -67,7 +68,7 @@ def create_reduce_sum(
name: str, ifm: Tensor, quantization: QuantizationParameters, activation: Optional[ActivationFunction] = None
) -> Operation:
op = Operation(Op.ReduceSum, name)
- op.attrs["padding"] = b"VALID"
+ op.attrs["padding"] = Padding.VALID
op.attrs["stride_w"] = 1
op.attrs["stride_h"] = 1
op.attrs["filter_width"] = 1
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index a1fcf6ab..f2c2eb9e 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -24,6 +24,7 @@ from .data_type import DataType
from .numeric_util import is_integer
from .operation import get_slice_offsets
from .operation import Op
+from .operation import Padding
from .tensor import check_quantized_tens_scaling_equal
from .tflite_mapping import BUILTIN_OPERATOR_UNKNOWN
from .tflite_mapping import optype_to_builtintype
@@ -569,7 +570,7 @@ class SupportedOperators:
@staticmethod
def constraint_tconv_same(op):
"SAME padding: OFM dimensions must equal IFM dimensions multiplied by stride"
- if op.attrs["padding"] == b"SAME":
+ if op.attrs["padding"] == Padding.SAME:
w = op.kernel.stride.x
h = op.kernel.stride.y
ifm_shape = op.ifm.shape
@@ -582,7 +583,7 @@ class SupportedOperators:
def constraint_tconv_valid(op):
"""VALID padding: OFM dimensions must equal IFM dimensions multiplied by stride,
minus difference between kernel size and stride"""
- if op.attrs["padding"] == b"VALID":
+ if op.attrs["padding"] == Padding.VALID:
s_w = op.kernel.stride.x
s_h = op.kernel.stride.y
k_w = op.kernel.width
@@ -626,7 +627,7 @@ class SupportedOperators:
@docstring_format_args(filter_range)
def constraint_filter_range(cls, op):
"Kernel filter values for both width and height must be in the range [{}, {}]"
- if op.attrs["padding"] == b"SAME":
+ if op.attrs["padding"] == Padding.SAME:
w = op.kernel.width
h = op.kernel.height
filter_min, filter_max = cls.filter_range
@@ -656,7 +657,7 @@ class SupportedOperators:
@docstring_format_args(filter_height_range)
def constraint_filter_height_range_valid_pad(op):
"VALID padding: Kernel filter height must be in the range [{}, {}]"
- if op.attrs["padding"] == b"VALID":
+ if op.attrs["padding"] == Padding.VALID:
return SupportedOperators.constraint_filter_height_range(op)
return True, "Op has padding=SAME"
@@ -664,7 +665,7 @@ class SupportedOperators:
@docstring_format_args(filter_product_range)
def constraint_filter_product_range_valid_pad(op):
"VALID padding: Product of kernel filter width and height must be in the range [{}, {}]"
- if op.attrs["padding"] == b"VALID":
+ if op.attrs["padding"] == Padding.VALID:
return SupportedOperators.constraint_filter_product_range(op)
return True, "Op has padding=SAME"
diff --git a/ethosu/vela/test/test_supported_operators.py b/ethosu/vela/test/test_supported_operators.py
index f132eef7..583821a2 100644
--- a/ethosu/vela/test/test_supported_operators.py
+++ b/ethosu/vela/test/test_supported_operators.py
@@ -21,6 +21,7 @@ import numpy as np
from ethosu.vela.data_type import DataType
from ethosu.vela.operation import ActivationFunction
from ethosu.vela.operation import Op
+from ethosu.vela.operation import Padding
from ethosu.vela.supported_operators import SupportedOperators
from ethosu.vela.tensor import create_const_tensor
from ethosu.vela.tensor import QuantizationParameters
@@ -276,7 +277,7 @@ def test_constraint_depth_multiplier():
def test_constraint_tconv_stride():
# Strides must be 2
op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 2, 2, 1], weights_shape=[1, 1, 1, 1])
- op.attrs = {"stride_w": 1, "stride_h": 1, "padding": b"SAME"}
+ op.attrs = {"stride_w": 1, "stride_h": 1, "padding": Padding.SAME}
ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
ifm.quantization = testutil.default_quant_params()
op.add_input_tensor(ifm)
@@ -286,14 +287,14 @@ def test_constraint_tconv_stride():
def test_constraint_tconv_same():
# Valid
op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 2, 2, 1], weights_shape=[1, 1, 1, 1])
- op.attrs = {"stride_w": 2, "stride_h": 2, "padding": b"SAME"}
+ op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.SAME}
ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
ifm.quantization = testutil.default_quant_params()
op.add_input_tensor(ifm)
assert support.is_operator_supported(op)
# Invalid
op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[1, 1, 1, 1])
- op.attrs = {"stride_w": 2, "stride_h": 2, "padding": b"SAME"}
+ op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.SAME}
ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
ifm.quantization = testutil.default_quant_params()
op.add_input_tensor(ifm)
@@ -303,14 +304,14 @@ def test_constraint_tconv_same():
def test_constraint_tconv_valid():
# Valid
op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[4, 4, 1, 1])
- op.attrs = {"stride_w": 2, "stride_h": 2, "padding": b"VALID"}
+ op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.VALID}
ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
ifm.quantization = testutil.default_quant_params()
op.add_input_tensor(ifm)
assert support.is_operator_supported(op)
# Invalid
op = testutil.create_op_with_quant_tensors(Op.Conv2DBackpropInput, [0], [1, 4, 4, 1], weights_shape=[2, 2, 1, 1])
- op.attrs = {"stride_w": 2, "stride_h": 2, "padding": b"VALID"}
+ op.attrs = {"stride_w": 2, "stride_h": 2, "padding": Padding.VALID}
ifm = Tensor([1, 1, 1, 1], DataType.uint8, "ifm")
ifm.quantization = testutil.default_quant_params()
op.add_input_tensor(ifm)
@@ -320,7 +321,7 @@ def test_constraint_tconv_valid():
def test_constraint_matching_in_out_types():
# Valid
op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": b"SAME"}
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 2, "padding": Padding.SAME}
assert support.is_operator_supported(op)
# Invalid. datatypes for ifm and ofm must match (default uint8)
op.ifm.dtype = DataType.int8
@@ -330,7 +331,7 @@ def test_constraint_matching_in_out_types():
def test_constraint_filter_type():
# Filter width/height must be integers
op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": b"SAME"}
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2.5, "filter_height": "2", "padding": Padding.SAME}
assert not support.is_operator_supported(op)
@@ -338,17 +339,17 @@ def test_constraint_filter_range():
# Avg pool restrictions are dependent on padding:
# SAME padding restricts both W and H to max 8
op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 20, "filter_height": 20, "padding": b"SAME"}
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 20, "filter_height": 20, "padding": Padding.SAME}
assert not support.is_operator_supported(op)
# VALID padding limits are much larger
- op.attrs["padding"] = b"VALID"
+ op.attrs["padding"] = Padding.VALID
assert support.is_operator_supported(op)
def test_constraint_filter_height_range_valid_pad():
# Avg pool restrictions are dependent on padding:
op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": b"VALID"}
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": Padding.VALID}
assert support.is_operator_supported(op)
# VALID padding restricts to 256 in filter height
op.attrs["filter_height"] = 257
@@ -358,7 +359,7 @@ def test_constraint_filter_height_range_valid_pad():
def test_constraint_filter_product_height_range_valid_pad():
# Avg pool restrictions are dependent on padding:
op = testutil.create_op_with_quant_tensors(Op.AvgPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": b"VALID"}
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": Padding.VALID}
assert support.is_operator_supported(op)
# VALID padding restricts filter W x H to 256x256
op.attrs["filter_width"] = 257
@@ -368,26 +369,26 @@ def test_constraint_filter_product_height_range_valid_pad():
def test_constraint_filter_height_range():
# Max pool restrictions arent dependent on padding
op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": b"SAME"}
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 2, "filter_height": 256, "padding": Padding.SAME}
assert support.is_operator_supported(op)
# Restricts to 256 in filter height
op.attrs["filter_height"] = 257
assert not support.is_operator_supported(op)
# Doesnt matter if SAME or VALID
- op.attrs["padding"] = b"VALID"
+ op.attrs["padding"] = Padding.VALID
assert not support.is_operator_supported(op)
def test_constraint_filter_product_height_range():
# Max pool restrictions arent dependent on padding
op = testutil.create_op_with_quant_tensors(Op.MaxPool, [1, 8, 8, 8], [1, 8, 8, 8])
- op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": b"SAME"}
+ op.attrs = {"stride_w": 2, "stride_h": 2, "filter_width": 256, "filter_height": 256, "padding": Padding.SAME}
assert support.is_operator_supported(op)
# Restricts filter W x H to 256x256
op.attrs["filter_width"] = 257
assert not support.is_operator_supported(op)
# Doesnt matter if SAME or VALID
- op.attrs["padding"] = b"VALID"
+ op.attrs["padding"] = Padding.VALID
assert not support.is_operator_supported(op)
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index fe582614..cc6053c0 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -24,6 +24,7 @@ import numpy as np
from .data_type import DataType
from .operation import CustomType
from .operation import Op
+from .operation import Padding as opPad
from .tflite import AbsOptions
from .tflite import AddNOptions
from .tflite import AddOptions
@@ -425,8 +426,8 @@ class CustomOptionsSerializer:
padding_map = {
- Padding.SAME: b"SAME",
- Padding.VALID: b"VALID",
+ Padding.SAME: opPad.SAME,
+ Padding.VALID: opPad.VALID,
}
padding_inv_map = inverse_map(padding_map)