aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tosa_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- ethosu/vela/tosa_graph_optimiser.py 98
1 files changed, 72 insertions, 26 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index fe18ce35..44e0f8ec 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -19,21 +19,38 @@ from . import rewrite_graph
from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
+from .graph_optimiser_util import calc_explicit_padding
from .graph_optimiser_util import needed_total_padding
from .graph_optimiser_util import set_ifm_ofm_op_shapes
from .graph_optimiser_util import set_tensor_equivalence
from .operation import ExplicitScaling
from .operation import NpuBlockType
from .operation import Op
-from .operation import Padding
+from .operation_util import create_avgpool_nop
-def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
+def replace_rescale_with_avg_pool(rescale_op):  # Turn a Rescale op into an AvgPool op in place and return it.
+ assert rescale_op.type == Op.Rescale  # only Rescale ops may be converted
+
+ avgpool_op = create_avgpool_nop(rescale_op.name + "_avgpool")  # template op; only its attrs are used below
+ rescale_op_clone = rescale_op.clone()  # taken before mutation; presumably a snapshot of the original op - confirm clone() semantics
+ op = rescale_op  # alias: the incoming op object itself is rewritten, not replaced
+ op.attrs = avgpool_op.attrs.copy()  # copy so the template's attrs dict is not shared
+ op.type = Op.AvgPool
+ DebugDatabase.add_optimised(rescale_op_clone, op)  # registers the (pre-rewrite clone, rewritten op) pair
+
+ return op  # same object as rescale_op, now typed AvgPool
+
+
+def calc_skirt(kernel, input_shape, explicit_padding):
k_w, k_h = kernel.dilated_wh()
s_x, s_y = kernel.stride
ypad = needed_total_padding(int(input_shape.height), int(s_y), int(k_h))
xpad = needed_total_padding(int(input_shape.width), int(s_x), int(k_w))
- left_pad, right_pad, top_pad, bottom_pad = explicit_padding
+
+ top, left, bottom, right = explicit_padding
+ top_pad, bottom_pad = calc_explicit_padding(int(input_shape.height), int(s_y), int(k_h), int(top), int(bottom))
+ left_pad, right_pad = calc_explicit_padding(int(input_shape.width), int(s_x), int(k_w), int(left), int(right))
padding = (top_pad, left_pad, bottom_pad, right_pad)
skirt = (top_pad, left_pad, ypad - top_pad, xpad - left_pad)
@@ -42,16 +59,14 @@ def calc_padding_and_skirt(padding_type, kernel, input_shape, explicit_padding):
def add_padding_fields(op, arch, nng):
if op.run_on_npu:
- if "padding" in op.attrs:
+ if "explicit_padding" in op.attrs:
input_shape = op.ifm_shapes[0]
if op.type == Op.Conv2DBackpropInputSwitchedBias:
# TODO not yet supported, but there will be need for separate handling
assert False
else:
- padding, skirt = calc_padding_and_skirt(
- Padding.EXPLICIT, op.kernel, input_shape, op.attrs.get("padding"),
- )
+ padding, skirt = calc_skirt(op.kernel, input_shape, op.attrs.get("explicit_padding"))
op.attrs["explicit_padding"] = padding
op.attrs["skirt"] = skirt
@@ -104,7 +119,6 @@ def rewrite_rescale(op, arch, nng):
prev_op = ifm.ops[0]
# TODO currently not supported
- assert prev_op.type not in (Op.Placeholder, Op.SubgraphInput, Op.Const)
assert len(ifm.consumer_list) == 1
input_zp = op.attrs["input_zp"]
@@ -126,27 +140,26 @@ def rewrite_rescale(op, arch, nng):
print("Error (fuse_rescale): zp of tensors producer/consumer differs unexpectedidly ")
assert False
ifm.quantization.zero_point = input_zp
-
- if not scale32:
- double_round = False
+ ofm.quantization.zero_point = output_zp
+ for s, m in zip(shift, multiplier):
+ # TODO these are the TOSA limitations
+ assert m >= 0
+ assert 2 <= s <= 62
+ # TODO these are the HW limitations
+ assert 0 <= s < (1 << 6)
+ explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)
+
+ if double_round and scale32:
+ rounding_mode = NpuRoundingMode.TFL
+ else:
+ rounding_mode = NpuRoundingMode.NATURAL
if prev_op.type.is_depthwise_conv2d_op() or prev_op.type.is_conv2d_op() or prev_op.type == Op.FullyConnected:
assert len(multiplier) == len(shift) == len(prev_op.bias.values)
if ifm.dtype == DataType.int32 and per_channel:
- for s, m in zip(shift, multiplier):
- # TODO these are the TOSA limitations
- assert m >= 0
- assert 2 <= s <= 62
- # TODO these are the HW limitations
- assert 0 <= s < (1 << 6)
- prev_op.explicit_scaling = ExplicitScaling(per_channel, shift, multiplier)
- ofm.quantization.zero_point = output_zp
-
- if double_round:
- prev_op.rounding_mode = NpuRoundingMode.TFL
- else:
- prev_op.rounding_mode = NpuRoundingMode.NATURAL
+ prev_op.explicit_scaling = explicit_scaling
+ prev_op.rounding_mode = rounding_mode
# Bypass op
prev_op.set_output_tensor(ofm)
@@ -155,13 +168,42 @@ def rewrite_rescale(op, arch, nng):
else:
print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
assert False
-
+ # TODO which are the cases we need to and can do standalone Rescale?
+ # TODO should we try to identify a conversion uint8<->int8 accomplished by 2 RESCALE ops?
+ # origin might be TFLite op QUANTIZE, should we look to see if they can be translated to QUANTIZE?
+ # limited to these at the moment:
+ elif (
+ (ifm.dtype == DataType.int8 and ofm.dtype == DataType.int8)
+ or (ifm.dtype == DataType.uint8 and ofm.dtype == DataType.int8)
+ or (ifm.dtype == DataType.int8 and ofm.dtype == DataType.uint8)
+ ):
+ # Create NOP performing the RESCALE
+ avgpool_op = replace_rescale_with_avg_pool(op)
+ avgpool_op.rounding_mode = rounding_mode
+
+ if per_channel:
+ # TODO
+ avgpool_op.explicit_scaling = explicit_scaling
+ print("Warning, unsupported TOSA Rescale")
+ assert False
+ else:
+ avgpool_op.explicit_scaling = explicit_scaling
else:
print("Warning, unsupported fusing of TOSA Rescale previous operator is of type:", prev_op.type)
assert False
return op
+def fixup_quantization(op, arch, nng):  # Graph-rewrite callback: default any missing (None) zero point on ifm/ifm2/ofm to 0; returns op. arch and nng are unused here.
+ if op.ifm and op.ifm.quantization.zero_point is None:
+ op.ifm.quantization.zero_point = 0
+ if op.ifm2 and op.ifm2.quantization.zero_point is None:
+ op.ifm2.quantization.zero_point = 0  # the guard above tested ifm2, so ifm2 (not ifm) is the tensor to patch
+ if op.ofm and op.ofm.quantization.zero_point is None:
+ op.ofm.quantization.zero_point = 0
+ return op
+
+
def supported_operator_check(op, arch, nng):
op.run_on_npu = arch.tosa_supported_operators.is_operator_supported(op)
return op
@@ -187,10 +229,14 @@ def tosa_optimise_graph(nng, arch):
nng, sg, arch, [], op_rewrite_list, rewrite_unsupported=False,
)
- # Post-processing step
+ # Post-processing step 1
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
nng, sg, arch, [], [rewrite_activation, add_padding_fields],
)
+ # Post-processing step 2
+ for idx, sg in enumerate(nng.subgraphs):
+ nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(nng, sg, arch, [], [fixup_quantization],)
+
return nng