aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/softmax.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 9e8b846d..5a5396f4 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -377,7 +377,7 @@ class SoftMax:
quantization=one_scale_quant,
),
)
- rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0")
+ rescaled = Tensor(reduce_sum_shape, DataType.int32, mul11_op.name + "_0")
rescaled.quantization = one_scale_quant.clone()
rescaled.quantization.scale_f32 = 2.0
mul11_op.set_output_tensor(rescaled)
@@ -406,7 +406,7 @@ class SoftMax:
mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
mul_op.add_input_tensor(nr_x)
mul_op.add_input_tensor(half_denominator)
- half_denominator_times_x = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
+ half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0")
half_denominator_times_x.quantization = one_scale_quant.clone()
half_denominator_times_x.quantization.scale_f32 = 2.0
mul_op.set_output_tensor(half_denominator_times_x)
@@ -421,7 +421,7 @@ class SoftMax:
mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5))
mul_op.add_input_tensor(nr_x)
mul_op.add_input_tensor(one_minus_half_denominator_times_x)
- to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0")
+ to_rescale = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0")
to_rescale.quantization = one_scale_quant.clone()
to_rescale.quantization.scale_f32 = 2.0
mul_op.set_output_tensor(to_rescale)