aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/register_command_stream_generator.py
diff options
context:
space:
mode:
authorJacob Bohlin <jacob.bohlin@arm.com>2020-06-29 11:58:50 +0200
committertim.hall <tim.hall@arm.com>2020-07-10 14:20:20 +0000
commit9fbc4913fe5dd0e6090532963f8612449936d994 (patch)
tree01fa3f202116cc4b767077f5e0a3977cf2185559 /ethosu/vela/register_command_stream_generator.py
parent0b9ca78dbde3f5c29f368577de1f621845e711a2 (diff)
downloadethos-u-vela-9fbc4913fe5dd0e6090532963f8612449936d994.tar.gz
MLBEDSW-1497: Add Quantize operator support
Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com> Change-Id: Iaf4d7ab9c32b0d783072c5f131a61bfebe77cc16
Diffstat (limited to 'ethosu/vela/register_command_stream_generator.py')
-rw-r--r--ethosu/vela/register_command_stream_generator.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 4bbea01e..28bc6b79 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -430,6 +430,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
rounding_mode = rounding.TRUNCATE
fmf = primary_op.attrs.get("fused_memory_function", None)
faf = primary_op.attrs.get("fused_activation_function", None)
+ fused_quantize = any(op.type == "Quantize" for op in ps.ops)
# Specifies which operand to apply scaling to in bitexact elementwise ADD/SUB
op_to_scale = 0
@@ -628,6 +629,11 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
scale = (1 << shift) * 3 * multiplier
else:
scale = int(round_away_zero(scale * rescale))
+ elif fused_quantize:
+ # Quantize op requires different scaling
+ ifm_scale_f64 = np.double(cmd.ifm_tensor.quantization.scale_f32)
+ ofm_scale_f64 = np.double(cmd.ofm_tensor.quantization.scale_f32)
+ scale, shift = scaling.quantise_scale(ifm_scale_f64 / ofm_scale_f64)
else:
# In case avg pool fused with concat or other memory operation, rescaling might be needed.
# k_height == k_width == 1 is allways true in this case
@@ -846,7 +852,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
if tens is None:
continue
- need_zero_point = (faf is not None) or (fmf == "ConcatSliceWrite")
+ need_zero_point = (faf is not None) or (fmf == "ConcatSliceWrite") or fused_quantize
if (
primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear")) and not need_zero_point
) or tens.quantization is None: