diff options
-rw-r--r-- | ethosu/vela/softmax.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py index 9e8b846d..5a5396f4 100644 --- a/ethosu/vela/softmax.py +++ b/ethosu/vela/softmax.py @@ -377,7 +377,7 @@ class SoftMax: quantization=one_scale_quant, ), ) - rescaled = Tensor(ifm_exp.shape, DataType.int32, mul11_op.name + "_0") + rescaled = Tensor(reduce_sum_shape, DataType.int32, mul11_op.name + "_0") rescaled.quantization = one_scale_quant.clone() rescaled.quantization.scale_f32 = 2.0 mul11_op.set_output_tensor(rescaled) @@ -406,7 +406,7 @@ class SoftMax: mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5)) mul_op.add_input_tensor(nr_x) mul_op.add_input_tensor(half_denominator) - half_denominator_times_x = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0") + half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0") half_denominator_times_x.quantization = one_scale_quant.clone() half_denominator_times_x.quantization.scale_f32 = 2.0 mul_op.set_output_tensor(half_denominator_times_x) @@ -421,7 +421,7 @@ class SoftMax: mul_op = Operation("MulAct", self.op.name + "_mul%d" % (15 + i * 5)) mul_op.add_input_tensor(nr_x) mul_op.add_input_tensor(one_minus_half_denominator_times_x) - to_rescale = Tensor(ifm_exp.shape, DataType.int32, mul_op.name + "_0") + to_rescale = Tensor(reduce_sum_shape, DataType.int32, mul_op.name + "_0") to_rescale.quantization = one_scale_quant.clone() to_rescale.quantization.scale_f32 = 2.0 mul_op.set_output_tensor(to_rescale) |