diff options
author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2020-09-21 10:34:48 +0200 |
---|---|---|
committer | patrik.gustavsson <patrik.gustavsson@arm.com> | 2020-09-21 14:03:36 +0000 |
commit | d9e38fe2bc0458fdca83dd4932abee6554fe2eb2 (patch) | |
tree | b625411e1eee8a3e83b94f4f6f3d827a5fe8a6b7 /ethosu | |
parent | cb33704fcd7859b1c334f996445bba2f4efea5f9 (diff) | |
download | ethos-u-vela-d9e38fe2bc0458fdca83dd4932abee6554fe2eb2.tar.gz |
Fix int8/uint8 softmax mul shape
Fixed incorrect ofm shape for some of the intermediate mul
operations in softmax int8/uint8.
Change-Id: I82351c1eb6a66b93280752f4cc00e2d0744d33b2
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Diffstat (limited to 'ethosu')
-rw-r--r-- | ethosu/vela/softmax.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py index 9e8b846d..5a5396f4 100644 --- a/ethosu/vela/softmax.py +++ b/ethosu/vela/softmax.py @@ -377,7 +377,7 @@ class SoftMax: quantization=one_scale_quant, ), ) - rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0") + rescaled = Tensor(reduce_sum_shape, DataType.int32, mul11_op.name + "_0") rescaled.quantization = one_scale_quant.clone() rescaled.quantization.scale_f32 = 2.0 mul11_op.set_output_tensor(rescaled) @@ -406,7 +406,7 @@ class SoftMax: mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5)) mul_op.add_input_tensor(nr_x) mul_op.add_input_tensor(half_denominator) - half_denominator_times_x = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0") + half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0") half_denominator_times_x.quantization = one_scale_quant.clone() half_denominator_times_x.quantization.scale_f32 = 2.0 mul_op.set_output_tensor(half_denominator_times_x) @@ -421,7 +421,7 @@ class SoftMax: mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5)) mul_op.add_input_tensor(nr_x) mul_op.add_input_tensor(one_minus_half_denominator_times_x) - to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0") + to_rescale = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0") to_rescale.quantization = one_scale_quant.clone() to_rescale.quantization.scale_f32 = 2.0 mul_op.set_output_tensor(to_rescale) |