From e82be7c1a000277b44da7e85c527229a1d5eab2a Mon Sep 17 00:00:00 2001 From: Fredrik Svedberg Date: Mon, 18 Jan 2021 15:21:03 +0100 Subject: [MLBEDSW-2787] Remove op.attrs["rescale"] in softmax.py Added RescaleAdd operation to avoid non-standard attribute "rescale" for Add operation. Also changed ResizeBilinear in the same way. Signed-off-by: Fredrik Svedberg Change-Id: I1d286f63890585c06b8a161df1ff77e3f844a4b9 --- ethosu/vela/high_level_command_to_npu_op.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) (limited to 'ethosu/vela/high_level_command_to_npu_op.py') diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py index 07117025..8e4d33a5 100644 --- a/ethosu/vela/high_level_command_to_npu_op.py +++ b/ethosu/vela/high_level_command_to_npu_op.py @@ -91,6 +91,7 @@ block_traversal_map = { elementwise_op_map = { Op.Mul: NpuElementWiseOp.MUL, Op.Add: NpuElementWiseOp.ADD, + Op.RescaleAdd: NpuElementWiseOp.ADD, Op.Sub: NpuElementWiseOp.SUB, Op.Minimum: NpuElementWiseOp.MIN, Op.Maximum: NpuElementWiseOp.MAX, @@ -386,8 +387,8 @@ def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPooling npu_op = NpuPoolingOperation(pool_op) set_common_op_fields(npu_op, cmd, arch) # Pooling specific info - if op.type == Op.ResizeBilinear and "rescale" in op.attrs: - npu_op.rescale = op.attrs["rescale"] + if op.type == Op.ResizeBilinear: + npu_op.rescale = op.rescale return npu_op @@ -426,8 +427,9 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu output_scale = npu_op.ifm2.quantization.scale_f32 if op.type == Op.LeakyRelu: output_scale = op.attrs["alpha"] - if op.type in (Op.Add, Op.Sub) and "rescale" in op.attrs: - npu_op.rescale = op.attrs.get("rescale") + if op.type == Op.RescaleAdd: + assert op.rescale is not None, f"{op.type} must have rescale" + npu_op.rescale = op.rescale if op.type in (Op.Add, Op.Mul, Op.Sub): if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh): output_scale = 1 / 0x3000 -- cgit v1.2.1