about | summary | refs | log | tree | commit | diff
path: root/ethosu/vela/softmax.py
diff options
context:
space:
mode:
author: Fredrik Svedberg <fredrik.svedberg@arm.com> 2020-09-21 10:34:48 +0200
committer: patrik.gustavsson <patrik.gustavsson@arm.com> 2020-09-21 14:03:36 +0000
commit: d9e38fe2bc0458fdca83dd4932abee6554fe2eb2 (patch)
tree: b625411e1eee8a3e83b94f4f6f3d827a5fe8a6b7 /ethosu/vela/softmax.py
parent: cb33704fcd7859b1c334f996445bba2f4efea5f9 (diff)
download: ethos-u-vela-d9e38fe2bc0458fdca83dd4932abee6554fe2eb2.tar.gz
Fix int8/uint8 softmax mul shape
Fixed incorrect ofm shape for some of the intermediate mul operations in softmax int8/uint8.

Change-Id: I82351c1eb6a66b93280752f4cc00e2d0744d33b2
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Diffstat (limited to 'ethosu/vela/softmax.py')
-rw-r--r-- ethosu/vela/softmax.py | 6
1 files changed, 3 insertions, 3 deletions
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 9e8b846d..5a5396f4 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -377,7 +377,7 @@ class SoftMax:
quantization=one_scale_quant,
),
)
- rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0")
+ rescaled = Tensor(reduce_sum_shape, DataType.int32, mul11_op.name + "_0")
rescaled.quantization = one_scale_quant.clone()
rescaled.quantization.scale_f32 = 2.0
mul11_op.set_output_tensor(rescaled)
@@ -406,7 +406,7 @@ class SoftMax:
mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
mul_op.add_input_tensor(nr_x)
mul_op.add_input_tensor(half_denominator)
- half_denominator_times_x = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
+ half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0")
half_denominator_times_x.quantization = one_scale_quant.clone()
half_denominator_times_x.quantization.scale_f32 = 2.0
mul_op.set_output_tensor(half_denominator_times_x)
@@ -421,7 +421,7 @@ class SoftMax:
mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5))
mul_op.add_input_tensor(nr_x)
mul_op.add_input_tensor(one_minus_half_denominator_times_x)
- to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
+ to_rescale = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0")
to_rescale.quantization = one_scale_quant.clone()
to_rescale.quantization.scale_f32 = 2.0
mul_op.set_output_tensor(to_rescale)