aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Tim Hall <tim.hall@arm.com> 2022-07-21 11:46:03 +0100
committer tim.hall <tim.hall@arm.com> 2022-07-23 16:56:07 +0000
commit 885033b5bf2f6513b438f273b2bc71964f0c6c59 (patch)
tree c52a6c5bbe1c6f4295aa94206b80a37a60fcf182
parent 47c7636586be265eed9e352e6ad4c090a02fb31f (diff)
download ethos-u-vela-885033b5bf2f6513b438f273b2bc71964f0c6c59.tar.gz
MLBEDSW-4157: Add RESIZE_NEAREST_NEIGHBOR support
- Changed ResizeBilinear to support ResizeNearestNeighbor as well for 1x1 IFM, IFM equal OFM, and non-align corners - Added support for ResizeNearestNeighbor with align corners by converting to a DepthwiseConv - Updated supported operator unit tests - Added is_resize() helper function and some associated refactoring Signed-off-by: Tim Hall <tim.hall@arm.com> Change-Id: Id5bdf2a25e8aa6a4f28b7236250abf768141ce37
-rw-r--r-- ethosu/vela/api.py 2
-rw-r--r-- ethosu/vela/high_level_command_stream_generator.py 2
-rw-r--r-- ethosu/vela/high_level_command_to_npu_op.py 30
-rw-r--r-- ethosu/vela/operation.py 8
-rw-r--r-- ethosu/vela/pass_packing.py 5
-rw-r--r-- ethosu/vela/register_command_stream_generator.py 2
-rw-r--r-- ethosu/vela/test/test_tflite_supported_operators.py 160
-rw-r--r-- ethosu/vela/tflite_graph_optimiser.py 172
-rw-r--r-- ethosu/vela/tflite_mapping.py 2
-rw-r--r-- ethosu/vela/tflite_supported_operators.py 18
10 files changed, 238 insertions, 163 deletions
diff --git a/ethosu/vela/api.py b/ethosu/vela/api.py
index 399fd46d..26ca291d 100644
--- a/ethosu/vela/api.py
+++ b/ethosu/vela/api.py
@@ -374,7 +374,7 @@ class NpuPoolingOperation(NpuBlockOperation):
def __init__(self, pooling_op_type: NpuPoolingOp):
super().__init__(NpuOperationType.Pooling)
self.sub_op_type: NpuPoolingOp = pooling_op_type
- # Set to a float value for ResizeBilinear operations (affects scaling), else to None
+ # Set to a float value for ResizeBilinear/NearestNeighbor operations (affects scaling), else to None
self.rescale: Optional[float] = None
diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py
index a52bdc37..7e13b62f 100644
--- a/ethosu/vela/high_level_command_stream_generator.py
+++ b/ethosu/vela/high_level_command_stream_generator.py
@@ -85,7 +85,7 @@ def generate_high_level_commands_for_sched_op(sched_op, schedule):
upscaling = 1
if sched_op.op_type == Op.Conv2DBackpropInputSwitchedBias:
upscaling = ofm_shape.height // ifm.shape.height
- elif sched_op.op_type == Op.ResizeBilinear:
+ elif sched_op.op_type.is_resize_op():
upscaling = round_up_divide(ofm_shape.height, ifm.shape.height)
# Get kernel height and height dilation
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index e6bfc1c4..2ce150fc 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -129,7 +129,7 @@ def ifm_ifm2_correct_order(ifm_shape: List[int], ifm2_shape: List[int]) -> bool:
def get_rounding_mode(op: Operation, fused_quantize: bool) -> NpuRoundingMode:
"""Specifies type of rounding to be used"""
rounding_mode = NpuRoundingMode.TFL
- if op.type == Op.ResizeBilinear:
+ if op.type.is_resize_op():
rounding_mode = NpuRoundingMode.NATURAL
elif (
op.type.npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise)
@@ -201,17 +201,6 @@ def get_mem_limits_for_regions(arch: ArchitectureFeatures) -> Dict[int, int]:
return mem_limits
-def get_upscale(op: Operation) -> NpuResamplingMode:
- upscale = NpuResamplingMode.NONE
- if op.type == Op.ResizeBilinear:
- # perform nearest neighbor upscale
- upscale = NpuResamplingMode.NEAREST
- elif op.type == Op.Conv2DBackpropInputSwitchedBias:
- # perform insert zero upscale
- upscale = NpuResamplingMode.TRANSPOSE
- return upscale
-
-
def get_ifm_depth(npu_block_type: NpuBlockType, ifm_box: Box, ofm_box: Box) -> int:
if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct, NpuBlockType.ReduceSum):
block = ifm_box.get_block()
@@ -224,7 +213,7 @@ def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
"""Checks if quantization should use 0 as zero point"""
if tens.dtype == DataType.int32 and is_ifm_tensor:
return True
- if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
+ if ps.primary_op.type not in (Op.AvgPool, Op.CLZ, Op.SHL) and not ps.primary_op.type.is_resize_op():
return False
if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
return False
@@ -435,10 +424,9 @@ def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPooling
"""Converts the command to NpuPoolingOperation"""
ps = cmd.ps
op = ps.primary_op
- pool_op = NpuPoolingOp.AVERAGE
if op.type.is_maxpool_op():
pool_op = NpuPoolingOp.MAX
- elif op.type.is_avgpool_op() or op.type == Op.ResizeBilinear:
+ elif op.type.is_avgpool_op() or op.type.is_resize_op():
pool_op = NpuPoolingOp.AVERAGE
elif op.type == Op.ReduceSum:
pool_op = NpuPoolingOp.REDUCE_SUM
@@ -485,18 +473,18 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu
set_common_op_fields(npu_op, cmd, arch)
# Check if output scale needs to be overridden
output_scale = None
- if op.type == Op.Add and "resizebilinear" in op.attrs:
+ if op.type == Op.Add and op.original_type.is_resize_op():
# Force output scale same as the input scale for
- # resizebilinear 1x1 that is converted to add
+ # resizebilinear/nearestneighbor 1x1 that is converted to add
output_scale = npu_op.ifm2.quantization.scale_f32
- if op.type == Op.Abs:
+ elif op.type == Op.Abs:
output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
- if op.type == Op.LeakyRelu:
+ elif op.type == Op.LeakyRelu:
output_scale = op.attrs["alpha"]
- if op.type in (Op.RescaleAdd, Op.RescaleMul):
+ elif op.type in (Op.RescaleAdd, Op.RescaleMul):
assert op.rescale is not None, f"{op.type} must have rescale"
npu_op.rescale = op.rescale
- if op.type in (Op.Add, Op.Mul, Op.Sub):
+ elif op.type in (Op.Add, Op.Mul, Op.Sub):
if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
output_scale = 1 / 0x3000
if output_scale is not None:
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index f3eace7e..1a34d0e1 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -248,8 +248,9 @@ class Op(Enum):
RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
RescaleMul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
Reshape = OperatorInfo(indices=NNG_IFM_INDICES)
+ # resize ops map to pooling operations unless explicitly converted to other operations in the graph optimiser
ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
- ResizeNearestNeighbor = OperatorInfo()
+ ResizeNearestNeighbor = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
ReverseSequence = OperatorInfo()
ReverseV2 = OperatorInfo()
Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=NNG_IFM_WEIGHTS_INDICES)
@@ -364,6 +365,9 @@ class Op(Enum):
def is_concat_op(self):
return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped, Op.Pack)
+ def is_resize_op(self):
+ return self in (Op.ResizeBilinear, Op.ResizeNearestNeighbor)
+
def needs_bias(self):
return bool(self.info.indices.biases)
@@ -467,6 +471,7 @@ class Operation:
__slots__ = (
"type",
+ "original_type",
"name",
"op_index",
"attrs",
@@ -497,6 +502,7 @@ class Operation:
def __init__(self, op_type: Op, name: str):
self.type = op_type
+ self.original_type = op_type
self.name = name
self.attrs: Dict[str, Any] = {}
self.inputs: List[Optional[Tensor]] = []
diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py
index 050b0965..988e52e6 100644
--- a/ethosu/vela/pass_packing.py
+++ b/ethosu/vela/pass_packing.py
@@ -61,10 +61,9 @@ mac_main_ops = set(
Op.AvgPool,
Op.MaxPool,
Op.ReduceSum,
- # deconvolution
- Op.ResizeBilinear,
)
-)
+ # resize ops use pooling operations unless explicitly converted to other operations prior to pass packing
+) | Op.op_set(Op.is_resize_op)
binary_elem_wise_main_ops = Op.op_set(Op.is_binary_elementwise_op)
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 12a36caf..a8d1ddff 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -706,7 +706,7 @@ def generate_ofm_scaling_for_pooling(emit: CommandStreamEmitter, pool_op: NpuPoo
scale = explicit_scaling.multiplier[0]
shift = explicit_scaling.shift[0]
else:
- # for ResizeBilinear operations with rescale
+ # for ResizeBilinear/NearestNeighbor operations with rescale
rescale = pool_op.rescale
rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits)
diff --git a/ethosu/vela/test/test_tflite_supported_operators.py b/ethosu/vela/test/test_tflite_supported_operators.py
index ab12e417..89c27997 100644
--- a/ethosu/vela/test/test_tflite_supported_operators.py
+++ b/ethosu/vela/test/test_tflite_supported_operators.py
@@ -306,84 +306,88 @@ def test_constraint_filter_product_height_range():
assert not support.is_operator_supported(op)
-def test_constraint_bilinear_resize():
- # IFM W and H == 1
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 1, 1, 8], [1, 8, 8, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
- assert support.is_operator_supported(op)
-
- # IFM == OFM
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 8, 8, 8], [1, 8, 8, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
- assert support.is_operator_supported(op)
-
- # IFM x2 == OFM ; align_corners = False
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
- assert support.is_operator_supported(op)
-
- # IFM x4 == OFM ; align_corners = False
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 16, 16, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [16, 16], np.int32))
- assert support.is_operator_supported(op)
-
- # IFM x8 == OFM ; align_corners = False
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 32, 32, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [32, 32], np.int32))
- assert support.is_operator_supported(op)
-
- # IFM -1 x2 == OFM -1 ; align_corners = True
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 7, 7, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
- op.attrs["align_corners"] = True
- assert support.is_operator_supported(op)
-
- # IFM -1 x4 == OFM -1 ; align_corners = True
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 13, 13, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [13, 13], np.int32))
- op.attrs["align_corners"] = True
- assert support.is_operator_supported(op)
-
- # IFM -1 x8 == OFM -1 ; align_corners = True
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 25, 25, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [25, 25], np.int32))
- op.attrs["align_corners"] = True
- assert support.is_operator_supported(op)
-
- # Invalid case - upscale size
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 17, 17, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [17, 17], np.int32))
- assert not support.is_operator_supported(op)
-
- # Invalid case - upscale size with align corners
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 15, 15, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [15, 15], np.int32))
- op.attrs["align_corners"] = True
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_bilinear_resize_size():
- # Invalid case - size != ofm size
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_bilinear_resize_attrs():
- # Invalid case - both align corners and half-pixel centers
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
- op.attrs["align_corners"] = True
- op.attrs["half_pixel_centers"] = True
- assert not support.is_operator_supported(op)
-
-
-def test_constraint_bilinear_resize_hpc():
- # Invalid case - half-pixel centers (not supported)
- op = testutil.create_op_with_quant_tensors(Op.ResizeBilinear, [1, 4, 4, 8], [1, 8, 8, 8])
- op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
- op.attrs["half_pixel_centers"] = True
- assert not support.is_operator_supported(op)
+def test_constraint_resize():
+ for resize_op in Op.op_set(Op.is_resize_op):
+ # IFM W and H == 1
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 1, 1, 8], [1, 8, 8, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+ assert support.is_operator_supported(op)
+
+ # IFM == OFM
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 8, 8, 8], [1, 8, 8, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+ assert support.is_operator_supported(op)
+
+ # IFM x2 == OFM ; align_corners = False
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+ assert support.is_operator_supported(op)
+
+ # IFM x4 == OFM ; align_corners = False
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 16, 16, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [16, 16], np.int32))
+ assert support.is_operator_supported(op)
+
+ # IFM x8 == OFM ; align_corners = False
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 32, 32, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [32, 32], np.int32))
+ assert support.is_operator_supported(op)
+
+ # IFM -1 x2 == OFM -1 ; align_corners = True
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 7, 7, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
+ op.attrs["align_corners"] = True
+ assert support.is_operator_supported(op)
+
+ # IFM -1 x4 == OFM -1 ; align_corners = True
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 13, 13, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [13, 13], np.int32))
+ op.attrs["align_corners"] = True
+ assert support.is_operator_supported(op)
+
+ # IFM -1 x8 == OFM -1 ; align_corners = True
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 25, 25, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [25, 25], np.int32))
+ op.attrs["align_corners"] = True
+ assert support.is_operator_supported(op)
+
+ # Invalid case - upscale size
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 17, 17, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [17, 17], np.int32))
+ assert not support.is_operator_supported(op)
+
+ # Invalid case - upscale size with align corners
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 15, 15, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [15, 15], np.int32))
+ op.attrs["align_corners"] = True
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_resize_size():
+ for resize_op in Op.op_set(Op.is_resize_op):
+ # Invalid case - size != ofm size
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [7, 7], np.int32))
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_resize_attrs():
+ for resize_op in Op.op_set(Op.is_resize_op):
+ # Invalid case - both align corners and half-pixel centers
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+ op.attrs["align_corners"] = True
+ op.attrs["half_pixel_centers"] = True
+ assert not support.is_operator_supported(op)
+
+
+def test_constraint_resize_half_pixel_centers():
+ for resize_op in Op.op_set(Op.is_resize_op):
+ # Invalid case - half-pixel centers (not supported)
+ op = testutil.create_op_with_quant_tensors(resize_op, [1, 4, 4, 8], [1, 8, 8, 8])
+ op.add_input_tensor(create_const_tensor("size", [2], DataType.int32, [8, 8], np.int32))
+ op.attrs["half_pixel_centers"] = True
+ assert not support.is_operator_supported(op)
def test_constraint_concat_pass():
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index d2899c4c..ed8fa1e3 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -279,10 +279,9 @@ def fixup_conv2d_backprop(op, arch, nng):
# Convert the op to an elementwise add
-def convert_resizebilinear_1x1_to_add(op):
- op.type = Op.Add
+def convert_resize_1x1_to_add(op):
+ op.type = Op.Add # original_type will stay as Op.ResizeBilinear or Op.ResizeNearestNeighbor
op.name = op.name + "_add"
- op.attrs["resizebilinear"] = True
# Create an input tensor filled with zeros
shape = op.ofm_shapes[0].as_list()
tens = Tensor(shape, op.inputs[0].dtype, op.inputs[1].name + "_add")
@@ -301,12 +300,103 @@ def convert_resizebilinear_1x1_to_add(op):
return op
-# Convert ResizeBilinear to a number of 2x2 nearest neighbor upscaling and one avgpool op with kernel size dependent
-# on the upscaling factor. Avgpool kernel limit of 8x8 when padding is applied limits upscaling to 8x8.
-def convert_resizebilinear_to_upscale_and_average_pool(op):
+# Convert ResizeNearestNeighbor with align corners to a depthwise convolution. The IFM will already have been upscaled
+# apart from the final x2 scaling which will be done as part of this operation. The kernel contains a single coefficient
+# to select the appropriate nearest neighbor value
+def convert_resizenn_ac_to_depthwise_conv(op, upscale_factor):
+ ifm = op.ifm
+ ofm = op.ofm
+ output_depth = ofm.shape[-1]
+ dw_op_attrs = {
+ "padding": Padding.VALID,
+ "stride_h": 1,
+ "stride_w": 1,
+ "strides": (1, 1, 1, 1),
+ "depth_multiplier": 1,
+ "channel_multiplier": 1,
+ "dilation_h_factor": 1,
+ "dilation_w_factor": 1,
+ "dilation": (1, 1, 1, 1),
+ }
+
+ # change resizenearestneighbor to depthwise
+ op.type = Op.DepthwiseConv2DBias
+ op.attrs.update(dw_op_attrs)
+ op.set_input_tensor(ifm, 0) # ifm tensor index
+ op.activation = None
+
+ # add input resample to resize by x2
+ op.ifm_resampling_mode = resampling_mode.NEAREST
+
+ # don't care about the rounding mode as it is nearest neighbor
+
+ # setup weight tensor
+ weight_quant = QuantizationParameters()
+ weight_quant.scale_f32 = 1.0 # no scaling as only a single non-zero coeff to select the desired value
+ weight_quant.zero_point = 0
+ weight_quant.quant_dim = 0
+ ofm_dtype = ofm.dtype
+ if ofm_dtype == DataType.uint8:
+ weight_value_dtype = np.uint8
+ weight_quant.quant_min = 0
+ weight_quant.quant_max = (1 << ofm_dtype.bits) - 1
+ else:
+ if ofm_dtype == DataType.int8:
+ weight_value_dtype = np.int8
+ else:
+ assert ofm_dtype == DataType.int16
+ weight_value_dtype = np.int16
+
+ weight_quant.quant_min = -(1 << (ofm_dtype.bits - 1))
+ weight_quant.quant_max = (1 << (ofm_dtype.bits - 1)) - 1
+
+ weight_shape = [upscale_factor, upscale_factor, output_depth, output_depth] # HWIO
+
+ # the single non-zero coefficient used to select the desired value needs to be placed in the 'centre value', which
+ # is calculated by finding the 'centre position' ('*' in the diagram below) and then choosing the 'value' that is
+ # below-and-right (i.e. next) to it (D).
+ # 0---1---2
+ # | A | B |
+ # 1---*---+
+ # | C | D |
+ # 2---+---+
+ weight_values = [0] * (upscale_factor * upscale_factor)
+ centre_coeff = (upscale_factor // 2) * upscale_factor + (upscale_factor // 2)
+ weight_values[centre_coeff] = 1
+
+ # add weight tensor, this will discard the size tensor of the resize op
+ op.set_input_tensor(
+ create_const_tensor(
+ "weights",
+ weight_shape,
+ ofm.dtype,
+ np.array(weight_values).reshape(weight_shape),
+ value_dtype=weight_value_dtype,
+ quantization=weight_quant,
+ ),
+ 1, # inputs tensor weight index
+ )
+
+ # set up the bias tensor by assigning None and then calling the fix-up function to create a suitable tensor.
+ # need to append the bias tensor as resize ops only have 2 inputs
+ assert len(op.inputs) == 2
+ op.inputs.append(None)
+ fixup_bias_tensors(op, None, None)
+
+ # finally update the shapes in case we've changed the tensor shapes or connections
+ op.set_ifm_ofm_shapes()
+
+ return op
+
+
+# Convert ResizeBilinear/NearestNeighbor to a number of 1x1 average pools with nearest neighbor x2 upscaling and one
+# final average pool with a kernel size that depends upon the resize op's upscaling factor (x2, x4 or x8). The maximum
+# upscale factor is limited to x8 because of the 8x8 kernel size limit for average pool with padding.
+def convert_resize_to_upscale_and_average_pool(op):
pre_op = op
outputs = op.outputs
dtype = op.ifm.dtype
+
op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 1, 1, 1)})
op.attrs["padding"] = Padding.SAME # doesn't really matter as the kernel is 1x1
op.ifm_resampling_mode = resampling_mode.NEAREST
@@ -321,14 +411,14 @@ def convert_resizebilinear_to_upscale_and_average_pool(op):
# between different versions of numpy. This consistency ensures that the kernel dimensions are kept integral
n = int(np.log2(upscale_factor))
- # Perform 2x2 upscaling n-1 times
+ # Perform x2 upscaling n-1 times
scaled_op = pre_op
for count in range(n - 1):
if count > 0:
scaled_op = op.clone(f"_{count}")
scaled_op.inputs[0] = pre_op.outputs[0]
- # Nearest neighbor 2x2 upscaling
+ # Nearest neighbor x2 upscaling
upscaled_shape = upscaled_shape * 2
shape = op.ofm_shapes[0].as_list()
shape[1:3] = upscaled_shape
@@ -339,17 +429,30 @@ def convert_resizebilinear_to_upscale_and_average_pool(op):
scaled_op.set_ifm_ofm_shapes()
- # Last 2x2 upscaling also applies avgpool with kernel size dependent on the upscaling factor and adds
- # padding to the right and bottom.
+ # Last x2 upscaling
if n > 1:
scaled_op = op.clone(f"_{n-1}")
scaled_op.inputs[0] = pre_op.outputs[0]
- if op.attrs["align_corners"]:
- scaled_op.attrs["padding"] = Padding.VALID
- else:
- scaled_op.attrs["padding"] = Padding.EXPLICIT
- scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
- scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
+
+ if scaled_op.original_type == Op.ResizeBilinear:
+ if scaled_op.attrs["align_corners"]:
+ # no padding
+ scaled_op.attrs["padding"] = Padding.VALID
+ else:
+ # padding to the right and bottom (limits average pool to 8x8 kernel)
+ scaled_op.attrs["padding"] = Padding.EXPLICIT
+ scaled_op.attrs["explicit_padding"] = [0, 0, upscale_factor - 1, upscale_factor - 1]
+
+ # kernel size dependent on the upscaling factor
+ scaled_op.attrs.update({"ksize": (1, upscale_factor, upscale_factor, 1)})
+ else: # Op.ResizeNearestNeighbor
+ if scaled_op.attrs["align_corners"]:
+ # use depthwise conv to select the correct value
+ scaled_op = convert_resizenn_ac_to_depthwise_conv(scaled_op, upscale_factor)
+ else:
+ # keep 1x1 kernel and average pool
+ pass
+
scaled_op.outputs = outputs
scaled_op.outputs[0].ops = [scaled_op]
scaled_op.set_ifm_ofm_shapes()
@@ -357,16 +460,16 @@ def convert_resizebilinear_to_upscale_and_average_pool(op):
return op
-def fixup_resizebilinear(op, arch, nng):
- if op.type == Op.ResizeBilinear and op.run_on_npu:
+def fixup_resize(op, arch, nng):
+ if op.type.is_resize_op() and op.run_on_npu:
if op.ifm_shapes[0] == op.ofm_shapes[0]:
- # Bypass nop resizebilinear
+ # Bypass the resize op which is essentially a NOP
op.inputs = op.inputs[:1]
op.type = Op.Identity
elif op.ifm_shapes[0].height == 1 and op.ifm_shapes[0].width == 1:
- convert_resizebilinear_1x1_to_add(op)
+ convert_resize_1x1_to_add(op)
else:
- convert_resizebilinear_to_upscale_and_average_pool(op)
+ convert_resize_to_upscale_and_average_pool(op)
return op
@@ -1130,31 +1233,6 @@ def convert_pad(op: Operation, arch, nng):
return avgpool_op
-def add_attrs_to_resizebilinear(op, arch, nng):
- if op.type == Op.ResizeBilinear and op.run_on_npu:
- input_shape = op.ifm_shapes[0]
- upscaled_height = input_shape.height * 2
- upscaled_width = input_shape.width * 2
- out_shape = op.ofm_shapes[0]
- if not op.attrs["align_corners"] and out_shape.height == upscaled_height and out_shape.width == upscaled_width:
- # this means the output is supposed to be a x2 upscale,
- # so we need to do SAME padding
- op.attrs["padding"] = Padding.SAME
- elif (
- op.attrs["align_corners"]
- and out_shape.height == (upscaled_height - 1)
- and out_shape.width == (upscaled_width - 1)
- ):
- # here we can just run the avg pool without padding and
- # produce a (M * 2 - 1, N * 2 - 1) sized output
- op.attrs["padding"] = Padding.VALID
- else:
- return op
- op.ifm_resampling_mode = resampling_mode.NEAREST
- op.attrs.update({"strides": (1, 1, 1, 1), "ksize": (1, 2, 2, 1)})
- return op
-
-
def fixup_bias_tensors(op, arch, nng):
if op.type.needs_bias() and op.bias is None:
# Op has no bias, add bias tensor filled with zeros
@@ -1577,7 +1655,7 @@ def tflite_optimise_graph(nng, arch):
fixup_conv2d_backprop,
fixup_relus_with_differing_ifm_ofm_scaling,
reorder_depthwise_weights,
- fixup_resizebilinear,
+ fixup_resize,
fixup_bias_tensors,
fixup_asymmetric_weights,
convert_mul_max_to_abs_or_lrelu,
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index bf155b9c..39b08b9e 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -799,7 +799,7 @@ builtin_operator_map = {
BuiltinOperator.RESIZE_NEAREST_NEIGHBOR: (
Op.ResizeNearestNeighbor,
OptionsSerializer("ResizeNearestNeighborOptions", ("align_corners", "half_pixel_centers")),
- TFLITE_NO_INDICES,
+ TFLITE_IFM_INDICES,
),
BuiltinOperator.LEAKY_RELU: (Op.LeakyRelu, OptionsSerializer("LeakyReluOptions", ("alpha",)), TFLITE_IFM_INDICES),
BuiltinOperator.SQUARED_DIFFERENCE: (
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index 01d2e61f..90d93d0f 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -58,7 +58,7 @@ class TFLiteSupportedOperators:
max_pooling_ops = Op.op_set(Op.is_maxpool_op)
avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
- resizing_ops = set((Op.ResizeBilinear,))
+ resizing_ops = Op.op_set(Op.is_resize_op)
fc_vector_products = set(
(
Op.QuantizedMatMul,
@@ -242,10 +242,10 @@ class TFLiteSupportedOperators:
# Resizing specific checks:
for op_type in TFLiteSupportedOperators.resizing_ops:
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bilinear_resize)
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bilinear_resize_size)
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bilinear_resize_attrs)
- self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_bilinear_resize_hpc)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_size)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_attrs)
+ self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_half_pixel_centers)
# Vector Product specific checks:
for op_type in TFLiteSupportedOperators.fc_vector_products:
@@ -589,7 +589,7 @@ class TFLiteSupportedOperators:
return True, "Op has padding=SAME"
@staticmethod
- def constraint_bilinear_resize(op):
+ def constraint_resize(op):
"""The width and height of the IFM and OFM must match one of the following criteria:
IFM W and H must both be 1
IFM must match OFM
@@ -625,7 +625,7 @@ class TFLiteSupportedOperators:
return valid, f"Op has ifm_shape={ifm_shape}, ofm_shape={ofm_shape} and align_corners={align_corners}"
@staticmethod
- def constraint_bilinear_resize_size(op):
+ def constraint_resize_size(op):
"The size tensor must match the output tensor shape"
valid = False
ofm_shape = op.ofm.shape
@@ -640,7 +640,7 @@ class TFLiteSupportedOperators:
return valid, f"Op has size={size_h}x{size_w} and ofm_shape={ofm_shape}."
@staticmethod
- def constraint_bilinear_resize_attrs(op):
+ def constraint_resize_attrs(op):
"Both align_corners and half_pixel_centers can't be True"
valid = True
align_corners = op.attrs.get("align_corners", False)
@@ -651,7 +651,7 @@ class TFLiteSupportedOperators:
return valid, "Op has both align_corners and half_pixel_centers set to True."
@staticmethod
- def constraint_bilinear_resize_hpc(op):
+ def constraint_resize_half_pixel_centers(op):
"half_pixel_centers are not supported"
valid = True
if op.attrs.get("half_pixel_centers", False):