diff options
Diffstat (limited to 'ethosu/vela/softmax.py')
-rw-r--r-- | ethosu/vela/softmax.py | 24 |
1 file changed, 14 insertions, 10 deletions
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index eb97c792..7c23f472 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -391,7 +391,9 @@ class SoftMax:
         F2_one = create_const_tensor(
             "F2_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 29)], np.int32, quantization=no_scale_quant
         )
-        two = create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
+        four = create_const_tensor(
+            "four_const", [1, 1, 1, 1], DataType.int32, [4], np.int32, quantization=no_scale_quant
+        )
         for i in range(3):
             # PASS 13, 18, 23 - MUL
             mul_op = Operation("MulAct", self.op.name + "_mul%d" % (13 + i * 5))
@@ -416,10 +418,10 @@ class SoftMax:
             to_rescale.quantization = one_scale_quant.clone()
             to_rescale.quantization.scale_f32 = 2.0
             mul_op.set_output_tensor(to_rescale)
-            # PASS 16, 21, 26 - SHL
-            shl_op = Operation("SHL", self.op.name + "_shl%d" % (16 + i * 5))
+            # PASS 16, 21, 26 - MUL
+            shl_op = Operation("MulAct", self.op.name + "_mul%d" % (16 + i * 5))
             shl_op.add_input_tensor(to_rescale)
-            shl_op.add_input_tensor(two)
+            shl_op.add_input_tensor(four)
             to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0")
             to_add.quantization = no_scale_quant
             shl_op.set_output_tensor(to_add)
@@ -431,13 +433,15 @@ class SoftMax:
             nr_x.quantization = one_scale_quant
             add_op.set_output_tensor(nr_x)
 
-        # PASS 28 - SHL
-        shl28_op = Operation("SHL", self.op.name + "_shl28")
-        shl28_op.add_input_tensor(nr_x)
-        shl28_op.add_input_tensor(one)
-        scale_factor = Tensor(reduce_sum_shape, DataType.int32, shl28_op.name + "_0")
+        # PASS 28 - Multiply
+        mul28_op = Operation("MulAct", self.op.name + "_mul28")
+        mul28_op.add_input_tensor(nr_x)
+        mul28_op.add_input_tensor(
+            create_const_tensor("two_const", [1, 1, 1, 1], DataType.int32, [2], np.int32, quantization=no_scale_quant)
+        )
+        scale_factor = Tensor(reduce_sum_shape, DataType.int32, mul28_op.name + "_0")
         scale_factor.quantization = one_scale_quant
-        shl28_op.set_output_tensor(scale_factor)
+        mul28_op.set_output_tensor(scale_factor)
 
         # PASS 29 - Multiply
         mul_op = Operation("MulAct", self.op.name + "_mul29")