From d9e38fe2bc0458fdca83dd4932abee6554fe2eb2 Mon Sep 17 00:00:00 2001 From: Fredrik Svedberg Date: Mon, 21 Sep 2020 10:34:48 +0200 Subject: Fix int8/uint8 softmax mul shape Fixed incorrect ofm shape for some of the intermediate mul operations in softmax int8/uint8. Change-Id: I82351c1eb6a66b93280752f4cc00e2d0744d33b2 Signed-off-by: Fredrik Svedberg --- ethosu/vela/softmax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py index 9e8b846d..5a5396f4 100644 --- a/ethosu/vela/softmax.py +++ b/ethosu/vela/softmax.py @@ -377,7 +377,7 @@ class SoftMax: quantization=one_scale_quant, ), ) - rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0") + rescaled = Tensor(reduce_sum_shape, DataType.int32, mul11_op.name + "_0") rescaled.quantization = one_scale_quant.clone() rescaled.quantization.scale_f32 = 2.0 mul11_op.set_output_tensor(rescaled) @@ -406,7 +406,7 @@ class SoftMax: mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5)) mul_op.add_input_tensor(nr_x) mul_op.add_input_tensor(half_denominator) - half_denominator_times_x = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0") + half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0") half_denominator_times_x.quantization = one_scale_quant.clone() half_denominator_times_x.quantization.scale_f32 = 2.0 mul_op.set_output_tensor(half_denominator_times_x) @@ -421,7 +421,7 @@ class SoftMax: mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5)) mul_op.add_input_tensor(nr_x) mul_op.add_input_tensor(one_minus_half_denominator_times_x) - to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0") + to_rescale = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0") to_rescale.quantization = one_scale_quant.clone() to_rescale.quantization.scale_f32 = 2.0 mul_op.set_output_tensor(to_rescale) -- cgit v1.2.1