author     Patrik Gustavsson <patrik.gustavsson@arm.com>  2021-08-17 14:26:38 +0200
committer  Patrik Gustavsson <patrik.gustavsson@arm.com>  2021-09-03 13:33:01 +0200
commit     c74682cfd27eb2c203ce4486e712916c45da9881 (patch)
tree       82ff1cefd0ce06d6072f0b1231802e7afa803b1a
parent     5e5a7847b8fc1eb261c7561f44585d2f6b524df3 (diff)
TOSA: Support for AVGPOOL, MAXPOOL and CONV2D
Added support for:
-AVGPOOL and CONV2D with TFLite correspondence
-MAXPOOL
-additional support for replacing RESCALE ops with avgpool

No support for breaking down tensors over the size supported by NPU.

Signed-off-by: Patrik Gustavsson <patrik.gustavsson@arm.com>
Change-Id: I1d2aa50ac30a26283b3e6f1fe88cba1544b7c189
-rw-r--r--ethosu/vela/api.py1
-rw-r--r--ethosu/vela/graph_optimiser_util.py17
-rw-r--r--ethosu/vela/high_level_command_to_npu_op.py6
-rw-r--r--ethosu/vela/operation_util.py2
-rw-r--r--ethosu/vela/register_command_stream_generator.py21
-rw-r--r--ethosu/vela/tflite_graph_optimiser.py17
-rw-r--r--ethosu/vela/tosa_graph_optimiser.py98
-rw-r--r--ethosu/vela/tosa_reader.py26
-rw-r--r--ethosu/vela/tosa_supported_operators.py6
-rw-r--r--ethosu/vela/vela.py2
10 files changed, 144 insertions, 52 deletions
diff --git a/ethosu/vela/api.py b/ethosu/vela/api.py
index 69c60406..d516b8dc 100644
--- a/ethosu/vela/api.py
+++ b/ethosu/vela/api.py
@@ -25,6 +25,7 @@ from typing import Tuple
import numpy
+
API_VERSION_MAJOR = 1
API_VERSION_MINOR = 1
API_VERSION = f"{API_VERSION_MAJOR}.{API_VERSION_MINOR}"
diff --git a/ethosu/vela/graph_optimiser_util.py b/ethosu/vela/graph_optimiser_util.py
index 5e676f18..570c7244 100644
--- a/ethosu/vela/graph_optimiser_util.py
+++ b/ethosu/vela/graph_optimiser_util.py
@@ -15,6 +15,8 @@
# limitations under the License.
# Description:
# Common functions and definitions used during the graph optimization.
+from typing import Tuple
+
from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import VelaError
@@ -132,6 +134,21 @@ def check_format_restrictions(tens, arch):
tens.needs_linear_format = False
+def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]:
+ """
+ Based on explicit padding provided in a PAD operation, returns the corresponding hardware padding
+ that provides equivalent results.
+ """
+ total_padding = needed_total_padding(input_size, stride, filter_size)
+
+ # The bottom/right padding might need downward adjustment depending on stride/input size
+ total_minus_before = total_padding - pad_before
+ output_pad_after = pad_after
+ while output_pad_after > 0 and output_pad_after % stride != total_minus_before % stride:
+ output_pad_after -= 1
+ return pad_before, output_pad_after
+
+
def needed_total_padding(input_size, stride, filter_size):
out_size = (input_size + stride - 1) // stride
needed_input = (out_size - 1) * stride + filter_size
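For reference, the adjustment made by calc_explicit_padding can be checked in isolation. The sketch below is illustrative only; the import path is assumed from the file location above, and the expected values follow from the arithmetic in the function.

from ethosu.vela.graph_optimiser_util import calc_explicit_padding

# Height 5, stride 2, kernel 3, PAD supplies top=1/bottom=1: two rows of
# total padding are needed and the bottom row already lines up.
print(calc_explicit_padding(5, 2, 3, 1, 1))  # -> (1, 1)

# Height 4, stride 2, kernel 3, PAD supplies top=1/bottom=3: only one row
# of total padding is needed, so the bottom padding is trimmed until it
# lines up with an output position.
print(calc_explicit_padding(4, 2, 3, 1, 3))  # -> (1, 2)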
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index f8c9de36..c5d06465 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -204,6 +204,8 @@ def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
return True
if ps.primary_op.type not in (Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL):
return False
+ if ps.primary_op.type == Op.AvgPool and ps.primary_op.explicit_scaling:
+ return False
fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
forced_ofm_quantization = ps.primary_op.forced_output_quantization
use_0 = (
@@ -413,6 +415,10 @@ def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPooling
set_common_op_fields(npu_op, cmd, arch)
# Pooling specific info
npu_op.rescale = op.rescale
+ if op.explicit_scaling:
+ # Note: rescale is reused to carry explicit scaling, so it is not exposed in the external API
+ assert npu_op.rescale is None
+ npu_op.rescale = op.explicit_scaling
return npu_op
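A minimal sketch of the overloading above; ExplicitScaling is modelled as a plain record here, which is an assumption about its exact definition in ethosu/vela/operation.py.

from collections import namedtuple

# Assumed shape of ExplicitScaling, based on its use in this commit.
ExplicitScaling = namedtuple("ExplicitScaling", ["per_channel", "shift", "multiplier"])

def pooling_rescale(op_rescale, op_explicit_scaling):
    # Explicit scaling rides in the existing rescale field so the
    # external API does not have to change.
    if op_explicit_scaling:
        assert op_rescale is None  # the two uses must never collide
        return op_explicit_scaling
    return op_rescale

scaled = pooling_rescale(None, ExplicitScaling(False, [30], [1 << 30]))
print(scaled.multiplier[0], scaled.shift[0])  # 1073741824 30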
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index c51a6b58..4a4fd335 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -39,7 +39,7 @@ def create_avgpool_nop(name: str) -> Operation:
op.attrs["strides"] = [1, 1, 1, 1]
op.attrs["ksize"] = [1, 1, 1, 1]
op.attrs["skirt"] = [0, 0, 0, 0]
- op.attrs["explicit_padding"] = [0, 0, 0, 0]
+ op.attrs["explicit_padding"] = [0, 0, 0, 0] # [top, left, bottom, right]
op.run_on_npu = True
return op
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index d61e5717..6ee0005f 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -71,6 +71,7 @@ from .ethos_u55_regs.ethos_u55_regs import rounding
from .numeric_util import quantise_float32
from .numeric_util import round_away_zero
from .numeric_util import round_up_to_int
+from .operation import ExplicitScaling
from .operation import NpuBlockType
from .range_set import MemoryAccessSet
from .register_command_stream_util import BASE_PTR_INDEX_MEM2MEM
@@ -676,11 +677,18 @@ def generate_ofm_scaling_for_pooling(emit: CommandStreamEmitter, pool_op: NpuPoo
ofm_scale_f64 = np.double(ofm_quant.scale_f32)
scale, shift = scaling.quantise_scale(ifm_scale_f64 / ofm_scale_f64)
elif pool_op.rescale is not None:
- # for ResizeBilinear operations with rescale
- rescale = pool_op.rescale
- rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
- scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits)
- scale = int(round_away_zero(scale * rescale))
+ if type(pool_op.rescale) == ExplicitScaling:
+ # Note: rescale is reused to carry explicit scaling, so it is not exposed in the external API
+ explicit_scaling = pool_op.rescale
+ assert explicit_scaling.per_channel is False
+ scale = explicit_scaling.multiplier[0]
+ shift = explicit_scaling.shift[0]
+ else:
+ # for ResizeBilinear operations with rescale
+ rescale = pool_op.rescale
+ rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
+ scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits)
+ scale = int(round_away_zero(scale * rescale))
else:
# In case avg pool fused with concat or other memory operation, rescaling might be needed.
# kernel height == kernel width == 1 is always true in this case
@@ -896,6 +904,9 @@ def generate_pooling_op(emit: CommandStreamEmitter, npu_op: NpuPoolingOperation,
use_global_scale = (
npu_op.sub_op_type in (NpuPoolingOp.AVERAGE, NpuPoolingOp.REDUCE_SUM) and sum(npu_op.padding) == 0
)
+ # Note: rescale is reused to carry explicit scaling, so it is not exposed in the external API
+ if npu_op.rescale is not None and type(npu_op.rescale) == ExplicitScaling:
+ use_global_scale = not npu_op.rescale.per_channel
generate_common(emit, npu_op, NpuBlockTraversal.DEPTH_FIRST, arch, use_global_scale=use_global_scale)
# Pooling op specific
if use_global_scale:
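The scale/shift selection added above reduces to the following sketch, using the same assumed ExplicitScaling record; the ResizeBilinear rescale path is elided.

from collections import namedtuple

ExplicitScaling = namedtuple("ExplicitScaling", ["per_channel", "shift", "multiplier"])

def ofm_scale_for_pooling(rescale):
    # Explicit (TOSA RESCALE) scaling: OFM scaling is per-tensor, so
    # only element 0 of the multiplier/shift lists is consumed.
    if isinstance(rescale, ExplicitScaling):
        assert rescale.per_channel is False
        return rescale.multiplier[0], rescale.shift[0]
    raise NotImplementedError("ResizeBilinear-style rescale not sketched here")

print(ofm_scale_for_pooling(ExplicitScaling(False, [30], [1 << 30])))  # (1073741824, 30)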
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index ff2f5a08..3f743e43 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -18,7 +18,6 @@
# to do the traversal of the graph.
import math
import uuid
-from typing import Tuple
import numpy as np
@@ -31,6 +30,7 @@ from .data_type import DataType
from .debug_database import DebugDatabase
from .errors import UnsupportedFeatureError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
+from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
@@ -270,21 +270,6 @@ def fix_sg_input_output(op, arch, nng):
return op
-def calc_explicit_padding(input_size, stride, filter_size, pad_before, pad_after) -> Tuple[int, int]:
- """
- Based on explicit padding provided in a PAD operation, returns the corresponding hardware padding
- that provides equivalent results.
- """
- total_padding = needed_total_padding(input_size, stride, filter_size)
- # The top/left padding can be taken as is from the PAD
- output_pad_before = pad_before
- # The bottom/right padding might need downward adjustment depending on stride/input size
- output_pad_after = pad_after
- while output_pad_after > 0 and output_pad_after % stride != (total_padding - pad_before) % stride:
- output_pad_after -= 1
- return output_pad_before, output_pad_after
-
-
def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
k_w, k_h = kernel.dilated_wh()
s_x, s_y = kernel.stride
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index fe18ce35..44e0f8ec 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -19,21 +19,38 @@ from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
+from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
-from .operation import Padding
+from .operation_util import create_avgpool_nop
-def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
+def replace_rescale_with_avg_pool(rescale_op):
+ assert rescale_op.type == Op.Rescale
+
+ avgpool_op = create_avgpool_nop(rescale_op.name + "_avgpool")
+ rescale_op_clone = rescale_op.clone()
+ op = rescale_op
+ op.attrs = avgpool_op.attrs.copy()
+ op.type = Op.AvgPool
+ DebugDatabase.add_optimised(rescale_op_clone, op)
+
+ return op
+
+
+def calc_skirt(kernel, input_shape, explicit_padding):
k_w, k_h = kernel.dilated_wh()
s_x, s_y = kernel.stride
ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
- left_pad, right_pad, top_pad, bottom_pad = explicit_padding
+
+ top, left, bottom, right = explicit_padding
+ top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
+ left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
padding = (top_pad, left_pad, bottom_pad, right_pad)
skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
@@ -42,16 +59,14 @@ def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
def add_padding_fields(op, arch, nng):
if op.run_on_npu:
- if "padding" in op.attrs:
+ if "explicit_padding" in op.attrs:
input_shape = op.ifm_shapes[0]
if op.type == Op.Conv2DBackpropInputSwitchedBias:
# TODO not yet supported, but there will be need for separate handling
assert False
else:
- padding, skirt = calc_padding_and_skirt(
- Padding.EXPLICIT, op.kernel, input_shape, op.attrs.get("padding"),
- )
+ padding, skirt = calc_skirt(op.kernel, input_shape, op.attrs.get("explicit_padding"))
op.attrs["explicit_padding"] = padding
op.attrs["skirt"] = skirt
@@ -104,7 +119,6 @@ def rewrite_rescale(op, arch, nng):
prev_op = ifm.ops[0]
# TODO currently not supported
- assert prev_op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const)
assert len(ifm.consumer_list) == 1
input_zp = op.attrs["input_zp"]
@@ -126,27 +140,26 @@ def rewrite_rescale(op, arch, nng):
print("Error (fuse_rescale): zp of tensors producer/consumer differs unexpectedidly ")
assert False
ifm.quantization.zero_point = input_zp
-
- if not scale32:
- double_round = False
+ ofm.quantization.zero_point = output_zp
+ for s, m in zip(shift, multiplier):
+ # TODO these are the TOSA limitations
+ assert m >= 0
+ assert 2 <= s <= 62
+ # TODO these are the HW limitations
+ assert 0 <= s < (1 << 6)
+ explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)
+
+ if double_round and scale32:
+ rounding_mode = NpuRoundingMode.TFL
+ else:
+ rounding_mode = NpuRoundingMode.NATURAL
if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
assert len(multiplier) == len(shift) == len(prev_op.bias.values)
if ifm.dtype == DataType.int32 and per_channel:
- for s, m in zip(shift, multiplier):
- # TODO these are the TOSA limitations
- assert m >= 0
- assert 2 <= s <= 62
- # TODO these are the HW limitations
- assert 0 <= s < (1 << 6)
- prev_op.explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)
- ofm.quantization.zero_point = output_zp
-
- if double_round:
- prev_op.rounding_mode = NpuRoundingMode.TFL
- else:
- prev_op.rounding_mode = NpuRoundingMode.NATURAL
+ prev_op.explicit_scaling = explicit_scaling
+ prev_op.rounding_mode = rounding_mode
# Bypass op
prev_op.set_output_tensor(ofm)
@@ -155,13 +168,42 @@ def rewrite_rescale(op, arch, nng):
else:
print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
assert False
-
+ # TODO which cases do we need to, and can we, handle as standalone Rescale?
+ # TODO should we try to identify a uint8<->int8 conversion accomplished by 2 RESCALE ops?
+ # (the origin might be the TFLite op QUANTIZE; could those be translated back to QUANTIZE?)
+ # limited to these cases at the moment:
+ elif (
+ (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
+ or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
+ or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
+ ):
+ # Create NOP performing the RESCALE
+ avgpool_op = replace_rescale_with_avg_pool(op)
+ avgpool_op.rounding_mode = rounding_mode
+
+ if per_channel:
+ # TODO
+ avgpool_op.explicit_scaling = explicit_scaling
+ print("Warning, unsupported TOSA Rescale")
+ assert False
+ else:
+ avgpool_op.explicit_scaling = explicit_scaling
else:
print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
assert False
return op
+def fixup_quantization(op, arch, nng):
+ if op.ifm and op.ifm.quantization.zero_point is None:
+ op.ifm.quantization.zero_point = 0
+ if op.ifm2 and op.ifm2.quantization.zero_point is None:
+ op.ifm2.quantization.zero_point = 0
+ if op.ofm and op.ofm.quantization.zero_point is None:
+ op.ofm.quantization.zero_point = 0
+ return op
+
+
def supported_operator_check(op, arch, nng):
op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
return op
@@ -187,10 +229,14 @@ def tosa_optimise_graph(nng, arch):
nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
)
- # Post-processing step
+ # Post-processing step 1
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
nng, sg, arch, [], [rewrite_activation, add_padding_fields],
)
+ # Post-processing step 2
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)
+
return nng
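The new fixup_quantization pass only defaults missing zero points. It can be exercised with stand-in objects; the classes below are illustrative only, not Vela's real Operation/Tensor types.

class _Quant:
    def __init__(self, zero_point=None):
        self.zero_point = zero_point

class _Tensor:
    def __init__(self, zp=None):
        self.quantization = _Quant(zp)

class _Op:
    def __init__(self, ifm, ifm2, ofm):
        self.ifm, self.ifm2, self.ofm = ifm, ifm2, ofm

def fixup_quantization(op):
    # Default any missing zero point to 0, as the pass above does
    # (covering ifm, ifm2 and ofm).
    for tens in (op.ifm, op.ifm2, op.ofm):
        if tens and tens.quantization.zero_point is None:
            tens.quantization.zero_point = 0
    return op

op = fixup_quantization(_Op(_Tensor(), None, _Tensor(7)))
print(op.ifm.quantization.zero_point, op.ofm.quantization.zero_point)  # 0 7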
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index dfed035d..eb317169 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -30,6 +30,7 @@ from .reader_util import clone_and_reshape_tensor
from .reader_util import decode_str
from .reader_util import fixup_tensors
from .tensor import QuantizationParameters
+from .tensor import shape_num_elements
from .tensor import Tensor
from .tflite_mapping import DataType
from .tosa.TosaGraph import TosaGraph as TG
@@ -135,6 +136,22 @@ class TosaSubgraph:
if attr_serializer is not None:
op.attrs = attr_serializer.deserialize(op_data)
+ if "padding" in op.attrs:
+ padding = op.attrs["padding"] # [top, bottom, left, right]
+ op.attrs["explicit_padding"] = (
+ padding[0],
+ padding[2],
+ padding[1],
+ padding[3],
+ ) # [top, left, bottom, right]
+ if "stride" in op.attrs:
+ stride = op.attrs["stride"]
+ if len(stride) == 2:
+ op.attrs["strides"] = (1, stride[0], stride[1], 1)
+ else:
+ # TODO CONV3D: more work to be done
+ print("Unsupported kernel dimensions: ", len(stride))
+ assert False
if "dilation" in op.attrs:
dilation = op.attrs["dilation"]
if len(dilation) == 2:
@@ -160,7 +177,7 @@ class TosaSubgraph:
self.set_tensor_zp(op.ifm, quant_info["input_zp"])
if "weight_zp" in quant_info:
self.set_tensor_zp(op.weights, quant_info["weight_zp"])
- if "ouput_zp" in quant_info:
+ if "output_zp" in quant_info:
self.set_tensor_zp(op.ofm, quant_info["output_zp"])
if "a_zp" in quant_info:
self.set_tensor_zp(op.ifm, quant_info["a_zp"])
@@ -194,7 +211,12 @@ class TosaSubgraph:
data_as_numpy = tens_data.DataAsNumpy()
if tens_dtype in datatype_map_numpy:
np_dtype = datatype_map_numpy[tens_dtype]
- tens.values = np.array(data_as_numpy.view(np_dtype).reshape(shape))
+
+ # TOSA pads the tensor data
+ shape_elements = shape_num_elements(shape)
+ values = np.array(data_as_numpy.view(np_dtype))
+ values = values[0:shape_elements]
+ tens.values = values.reshape(shape)
else:
# int48 is only expected as an accumulated data/output format, int4 not supported
print(f"Error: unsupported/unexpected Tensor type {dtype}, with data")
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 51f80ebd..3b0e6b39 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -29,7 +29,11 @@ class TosaSupportedOperators:
# Categorised lists of supported operators
convolution_ops = set((Op.Conv2DBias,))
convolution_like_ops = convolution_ops
- mac_main_ops = convolution_like_ops
+ max_pooling_ops = Op.op_set(Op.is_maxpool_op)
+ avg_pooling_ops = Op.op_set(Op.is_avgpool_op)
+ pooling_ops = set((Op.ReduceSum,)) | max_pooling_ops | avg_pooling_ops
+
+ mac_main_ops = convolution_like_ops | pooling_ops
type_conversion_ops = set((Op.Rescale,))
relu_ops = set((Op.Clamp, Op.ReluN,))
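The effect of the new set composition is simply that pooling operators now pass the MAC main-op check; a rough standalone equivalent, with strings standing in for the Op enum members:

convolution_like_ops = {"Conv2DBias"}
max_pooling_ops = {"MaxPool"}
avg_pooling_ops = {"AvgPool"}
pooling_ops = {"ReduceSum"} | max_pooling_ops | avg_pooling_ops

mac_main_ops = convolution_like_ops | pooling_ops
print("MaxPool" in mac_main_ops)  # True: MAXPOOL is now accepted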
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 7400b8e9..94487499 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -87,7 +87,7 @@ def process(input_name, enable_debug_db, arch, model_reader_options, compiler_op
output_tfl_filename = output_basename + "_vela.tflite"
if input_name.endswith(".tflite"):
tflite_writer.write_tflite(nng, output_tfl_filename)
- elif input_name.endswith(".tosa"):
+ if input_name.endswith(".tosa"):
rawdata_writer.write_rawdata_output(nng, arch, output_basename)
if enable_debug_db: