author    Fredrik Svedberg <fredrik.svedberg@arm.com>    2022-09-27 14:13:01 +0200
committer Fredrik Svedberg <fredrik.svedberg@arm.com>    2022-10-04 09:20:33 +0000
commit    4a434cba156cdfb2613b7ebe4d4a4ec9f85ba616 (patch)
tree      47ebce38221b92b8eeb34e8b5f558223dcd4d3e3
parent    dda4caed56d2cd3a9d5927bf405859c1777ac909 (diff)
download  ethos-u-vela-4a434cba156cdfb2613b7ebe4d4a4ec9f85ba616.tar.gz
MLBEDSW-6969 Remove RescaleAdd and RescaleMul operators
Removed RescaleAdd and RescaleMul operators in favour of
Operation.explicit_scaling and removed Operation.rescale.

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: Idccd8851731d4bb8d4e84970e0fd6b409d7d4e45
-rw-r--r--  ethosu/vela/cascade_builder.py                      8
-rw-r--r--  ethosu/vela/high_level_command_to_npu_op.py        21
-rw-r--r--  ethosu/vela/operation.py                            7
-rw-r--r--  ethosu/vela/operation_util.py                      19
-rw-r--r--  ethosu/vela/register_command_stream_generator.py   24
-rw-r--r--  ethosu/vela/softmax.py                              19
-rw-r--r--  ethosu/vela/tflite_graph_optimiser.py              13
-rw-r--r--  ethosu/vela/tosa_reader.py                          4
-rw-r--r--  ethosu/vela/tosa_supported_operators.py             1

9 files changed, 42 insertions, 74 deletions
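
For orientation before the per-file hunks: the commit replaces the dedicated RescaleAdd and RescaleMul op types with plain Add/Mul operations that carry their output rescale in Operation.explicit_scaling. Below is a minimal sketch of the new idiom, modelled on the convert_prelu hunk further down and using only names that appear in this diff (Operation, Op and ExplicitScaling from ethosu/vela/operation.py); it assumes the vela package is installed and importable, and the op name is a placeholder:

    from ethosu.vela.operation import ExplicitScaling, Op, Operation

    # Old idiom: a dedicated op type carried the rescale as a (multiplier, shift) tuple
    #   add_op = Operation(Op.RescaleAdd, "example_add")
    #   add_op.rescale = (1, 0)  # no scale or shift

    # New idiom: a plain Add, with the same information held in explicit_scaling
    add_op = Operation(Op.Add, "example_add")
    add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1])  # no scaling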
diff --git a/ethosu/vela/cascade_builder.py b/ethosu/vela/cascade_builder.py
index 3c105374..1f5dc504 100644
--- a/ethosu/vela/cascade_builder.py
+++ b/ethosu/vela/cascade_builder.py
@@ -18,6 +18,7 @@
# Groups Operators in a schedule together to form Cascades.
from collections import namedtuple
+from .high_level_command_to_npu_op import ifm_ifm2_correct_order
from .numeric_util import round_up
from .operation import NpuBlockType
from .operation import Op
@@ -169,17 +170,14 @@ class CascadeBuilder:
@staticmethod
def element_wise_cascading_conformity(sched_op):
"""Check the inputs of the op to see if it's a candidate for cascading."""
- # Cascading sub-operators of Softmax results in a crash when handling Sub and RescaleAdd ops
ifm = sched_op.parent_op.ifm
ifm2 = sched_op.parent_op.ifm2
- if sched_op.op_type in [Op.RescaleAdd]:
- return False
-
+ # Cascading elementwise operations with reverse operand order is not handled
if sched_op.parent_op.type.is_binary_elementwise_op() and ifm and ifm2:
# We cannot rule out cascadability if at least one IFM is constant
- return Op.Const in (ifm.ops[0], ifm2.ops[0])
+ return Op.Const in (ifm.ops[0].type, ifm2.ops[0].type) and ifm_ifm2_correct_order(ifm.shape, ifm2.shape)
else:
# Either one IFM is not variable or it is not a binary elementwise op - we cannot rule out cascadability
return True
diff --git a/ethosu/vela/high_level_command_to_npu_op.py b/ethosu/vela/high_level_command_to_npu_op.py
index 7923e371..974d980c 100644
--- a/ethosu/vela/high_level_command_to_npu_op.py
+++ b/ethosu/vela/high_level_command_to_npu_op.py
@@ -92,9 +92,7 @@ dtype_map = {
# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_op_map = {
Op.Mul: NpuElementWiseOp.MUL,
- Op.RescaleMul: NpuElementWiseOp.MUL,
Op.Add: NpuElementWiseOp.ADD,
- Op.RescaleAdd: NpuElementWiseOp.ADD,
Op.Sub: NpuElementWiseOp.SUB,
Op.Minimum: NpuElementWiseOp.MIN,
Op.Maximum: NpuElementWiseOp.MAX,
@@ -312,11 +310,7 @@ def use_zero_point_0(ps, tens: Tensor, is_ifm_tensor: bool) -> bool:
(
ps.primary_op.activation is None
or forced_ofm_quantization is not None
- or (
- ps.primary_op.type.is_avgpool_op()
- and ps.primary_op.activation.op_type.is_relu_op()
- and not ps.primary_op.rescale
- )
+ or (ps.primary_op.type.is_avgpool_op() and ps.primary_op.activation.op_type.is_relu_op())
)
and (ps.primary_op.memory_function != Op.ConcatSliceWrite)
and not fused_quantize
@@ -461,7 +455,7 @@ def create_npu_activation(op: Operation) -> NpuActivation:
act = NpuActivation(act_op)
act.min = op.activation.min
act.max = op.activation.max
- if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.rescale:
+ if act_op is NpuActivationOp.NONE_OR_RELU and op.type.is_avgpool_op() and not op.explicit_scaling:
quant = op.ofm.quantization
if quant and quant.zero_point: # Zero point is not 0
scale_f32 = 1 if quant.scale_f32 is None else quant.scale_f32
@@ -544,10 +538,8 @@ def create_npu_pool_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> NpuPooling
npu_op = NpuPoolingOperation(pool_op)
set_common_op_fields(npu_op, cmd, arch)
# Pooling specific info
- npu_op.rescale = op.rescale
if op.explicit_scaling:
# Note: reuse of rescale for explicit scaling to not expose this in the external API
- assert npu_op.rescale is None
npu_op.rescale = op.explicit_scaling
return npu_op
@@ -588,7 +580,11 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu
set_common_op_fields(npu_op, cmd, arch)
# Check if output scale needs to be overridden
output_scale = None
- if op.type == Op.Add and op.original_type.is_resize_op():
+ if op.explicit_scaling is not None:
+ assert not op.explicit_scaling.per_channel
+ assert op.type in (Op.Add, Op.Mul, Op.Sub)
+ npu_op.rescale = (op.explicit_scaling.multiplier[0], op.explicit_scaling.shift[0])
+ elif op.type == Op.Add and op.original_type.is_resize_op():
# Force output scale same as the input scale for
# resizebilinear/nearestneighbor 1x1 that is converted to add
output_scale = npu_op.ifm2.quantization.scale_f32
@@ -596,9 +592,6 @@ def create_npu_elementwise_op(cmd: NpuStripe, arch: ArchitectureFeatures) -> Npu
output_scale = npu_op.ifm.quantization.scale_f32 / npu_op.ofm.quantization.scale_f32
elif op.type == Op.LeakyRelu:
output_scale = op.attrs["alpha"]
- elif op.type in (Op.RescaleAdd, Op.RescaleMul):
- assert op.rescale is not None, f"{op.type} must have rescale"
- npu_op.rescale = op.rescale
elif op.type in (Op.Add, Op.Mul, Op.Sub):
if op.activation is not None and op.activation.op_type in (Op.Sigmoid, Op.Tanh):
output_scale = 1 / 0x3000
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 8b3c88d9..8189793e 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -27,7 +27,6 @@ from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
-from typing import Union
from .api import NpuRoundingMode
from .errors import VelaError
@@ -247,8 +246,6 @@ class Op(Enum):
ReluN1To1 = OperatorInfo(indices=NNG_IFM_INDICES)
ReluN = OperatorInfo(indices=NNG_IFM_INDICES) # TOSA specific
Rescale = OperatorInfo(indices=NNG_IFM_INDICES) # TOSA specific
- RescaleAdd = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
- RescaleMul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=NNG_IFM_IFM2_INDICES)
Reshape = OperatorInfo(indices=NNG_IFM_INDICES)
# resize ops map to pooling operations unless explicitly converted to other operations in the graph optimiser
ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=NNG_IFM_INDICES)
@@ -535,9 +532,6 @@ class Operation:
self._kernel = None
self.ifm_shapes: List[Shape4D] = []
self.ofm_shapes: List[Shape4D] = []
- # If not none: contains rescale to be used as output scaling
- # (which overrides the ofm tensor's scale)
- self.rescale: Optional[Union[Tuple[int, int], ExplicitScaling]] = None
self.read_offsets: List[Optional[Shape4D]] = [None, None] # offset for [ifm, ifm2]
self.read_shapes: List[Optional[Shape4D]] = [None, None] # read shape for [ifm, ifm2]
self.rounding_mode: Optional[NpuRoundingMode] = None
@@ -586,7 +580,6 @@ class Operation:
res.rounding_mode = self.rounding_mode
res.explicit_scaling = self.explicit_scaling
res.low_precision_scaling = self.low_precision_scaling
- res.rescale = self.rescale
res.ifm_resampling_mode = self.ifm_resampling_mode
res.tile_base_offsets_ifm = [_ifm.copy() for _ifm in self.tile_base_offsets_ifm]
res.tile_base_offsets_ofm = self.tile_base_offsets_ofm.copy()
diff --git a/ethosu/vela/operation_util.py b/ethosu/vela/operation_util.py
index 36a8e592..aaabddbf 100644
--- a/ethosu/vela/operation_util.py
+++ b/ethosu/vela/operation_util.py
@@ -122,25 +122,6 @@ def create_add(
)
-def create_rescale_add(
- name: str,
- ifm: Tensor,
- ifm2: Tensor,
- rescale: Tuple[int, int],
- quantization: QuantizationParameters,
- activation: Optional[ActivationFunction] = None,
- dtype: Optional[DataType] = None,
- attrs: Optional[dict] = None,
- ifm_shape: Optional[Shape4D] = None,
- ifm2_shape: Optional[Shape4D] = None,
-) -> Operation:
- op = create_binary_elementwise(
- Op.RescaleAdd, name, ifm, ifm2, quantization, activation, dtype, attrs, ifm_shape, ifm2_shape
- )
- op.rescale = rescale
- return op
-
-
def create_clz(
name: str,
ifm: Tensor,
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 5680c96f..99ac32d5 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -707,6 +707,7 @@ def generate_ofm_scaling_for_pooling(emit: CommandStreamEmitter, pool_op: NpuPoo
shift = explicit_scaling.shift[0]
else:
# for ResizeBilinear/NearestNeighbor operations with rescale
+ # Note: this is not used, but part of the public API
rescale = pool_op.rescale
rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
scale, shift = scaling.quantise_pooling_scale(kernel.height * kernel.width, rescale_bits)
@@ -759,25 +760,30 @@ def generate_scaling_for_elementwise(emit: CommandStreamEmitter, npu_op: NpuElem
else:
ofm_scale, shift = scaling.elementwise_mul_scale(input_scale, input2_scale, output_scale)
else: # Add/Sub
- opa_scale: float
- opb_scale: float
+ # Default operand scaling is no scaling
+ opa_scale = opb_scale = 1
+ opa_shift = 0
bitdepth = npu_op.ifm.data_type.size_in_bits()
use_advanced_scaling = False
- if None in (input_scale, input2_scale, output_scale):
- opa_scale = opb_scale = ofm_scale = 1
- opa_shift = shift = 0
- if npu_op.rescale is not None:
- ofm_scale, shift = npu_op.rescale
+ if npu_op.rescale is not None:
+ # Explicit ofm scaling
+ ofm_scale, shift = npu_op.rescale
+ elif None in (input_scale, input2_scale, output_scale):
+ # No ofm scaling
+ ofm_scale = 1
+ shift = 0
elif input_scale == input2_scale and bitdepth == 16:
+ # int16 same scaling
opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
input_scale, input2_scale, output_scale
)
# align the double rounding with that of advanced scaling
- opa_scale /= 2
- opb_scale /= 2
+ opa_scale //= 2
+ opb_scale //= 2
shift -= 1
opa_shift = 0 # Unused for this case
elif input_scale == input2_scale:
+ # Same scaling
opa_scale, opb_scale, ofm_scale, shift = scaling.simplified_elementwise_add_sub_scale(
input_scale, input2_scale, output_scale
)
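
A side note on the int16 "same scaling" branch in the hunk above: halving the operand scales while reducing the shift by one (the existing alignment with advanced scaling's double rounding) leaves the effective operand multiplier essentially unchanged, and the switch to floor division just keeps opa_scale and opb_scale integral. A rough, self-contained arithmetic sketch with made-up numbers (the NPU's actual rounding is not modelled here):

    # The effective operand multiplier is roughly opa_scale * 2**-shift.
    def effective(opa_scale: int, shift: int) -> float:
        return opa_scale * 2.0 ** -shift

    opa_scale, shift = 20481, 15  # hypothetical values, not taken from a real network
    before = effective(opa_scale, shift)
    after = effective(opa_scale // 2, shift - 1)  # the halve-and-shift adjustment from the diff
    assert abs(before - after) <= 2.0 ** -shift   # dropping the bottom bit costs at most 2**-shift
    print(before, after)                          # 0.62503..., 0.625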
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 9565bc5c..1655427e 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -28,6 +28,7 @@ from .api import NpuRoundingMode
from .data_type import DataType
from .debug_database import DebugDatabase
from .operation import ActivationFunction
+from .operation import ExplicitScaling
from .operation import Op
from .operation import Operation
from .operation_util import create_add
@@ -35,7 +36,6 @@ from .operation_util import create_clz
from .operation_util import create_depthwise_maxpool
from .operation_util import create_mul
from .operation_util import create_reduce_sum
-from .operation_util import create_rescale_add
from .operation_util import create_shl
from .operation_util import create_shr
from .operation_util import create_sub
@@ -351,16 +351,15 @@ class SoftMax:
f0_one_const = create_const_tensor(
"F0_one_const", [1, 1, 1, 1], DataType.int32, [(1 << 31) - 1], np.int32, quantization=no_scale_quant
)
- half_denominator = add_op_get_ofm(
- create_rescale_add(
- f"{self.op.name}_add{pass_number}",
- f0_one_const,
- shifted_sum_minus_one,
- (1, 1), # Custom rescale
- one_scale_quant,
- activation,
- )
+ add_op = create_add(
+ f"{self.op.name}_add{pass_number}",
+ f0_one_const,
+ shifted_sum_minus_one,
+ one_scale_quant,
+ activation,
)
+ add_op.explicit_scaling = ExplicitScaling(False, shift=[1], multiplier=[1]) # Custom rescale
+ half_denominator = add_op_get_ofm(add_op)
# PASS 11 - Multiply
neg_32_over_17 = create_const_tensor(
diff --git a/ethosu/vela/tflite_graph_optimiser.py b/ethosu/vela/tflite_graph_optimiser.py
index 5d6d7071..f3ca1b63 100644
--- a/ethosu/vela/tflite_graph_optimiser.py
+++ b/ethosu/vela/tflite_graph_optimiser.py
@@ -871,7 +871,7 @@ def fixup_relus_with_differing_ifm_ofm_scaling(op, arch, nng):
# Add explicit rescaling
rescale = ifm.quantization.scale_f32 / ofm.quantization.scale_f32
multiplier, shift = scaling.quantise_scale(rescale)
- relu_fused_op.rescale = ExplicitScaling(False, [shift], [multiplier])
+ relu_fused_op.explicit_scaling = ExplicitScaling(False, [shift], [multiplier])
# Tidy up and assign the ifm and ofm to the new op
ifm.consumer_list.remove(op)
@@ -991,8 +991,8 @@ def convert_prelu(op, arch, nng):
DebugDatabase.add_optimised(op, relu_op)
# Add scaled and alpha multiplied values (without scaling)
- add_op = Operation(Op.RescaleAdd, op.name + "_add")
- add_op.rescale = (1, 0) # No scale or shift
+ add_op = Operation(Op.Add, op.name + "_add")
+ add_op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
add_op.add_input_tensor(fm_alpha)
add_op.add_input_tensor(fm_scaled)
add_op.set_output_tensor(ofm)
@@ -1180,8 +1180,8 @@ def convert_lrelu_to_mul_max(op, arch):
mul_ifm.dtype = DataType.int32
min_op.set_output_tensor(mul_ifm)
min_op.set_ifm_ofm_shapes()
- new_op = Op.RescaleAdd
- op.rescale = (1, 0) # No scale or shift
+ new_op = Op.Add
+ op.explicit_scaling = ExplicitScaling(False, shift=[0], multiplier=[1]) # No scaling
DebugDatabase.add_optimised(op, min_op)
# Add multiplication with alpha
@@ -1196,8 +1196,7 @@ def convert_lrelu_to_mul_max(op, arch):
if is_converted_prelu:
# The LeakyRelu was the result from convert_prelu and the scaling is provided
scalar, alpha_scale, alpha_shift = op.attrs["alpha_scaling"]
- mul_alpha.type = Op.RescaleMul
- mul_alpha.rescale = [alpha_scale, alpha_shift]
+ mul_alpha.explicit_scaling = ExplicitScaling(False, [alpha_shift], [alpha_scale])
elif alpha == 0 or np.isinf(1 / alpha):
# Handling of alpha near or at zero
quantization.scale_f32 = np.float32(1)
diff --git a/ethosu/vela/tosa_reader.py b/ethosu/vela/tosa_reader.py
index 2bec9cf1..cd18adb2 100644
--- a/ethosu/vela/tosa_reader.py
+++ b/ethosu/vela/tosa_reader.py
@@ -23,6 +23,7 @@ import numpy as np
from .nn_graph import Graph
from .nn_graph import Subgraph
+from .operation import ExplicitScaling
from .operation import Op
from .operation import Operation
from .reader_util import align_tensor_indices_to_nng
@@ -183,8 +184,7 @@ class TosaSubgraph:
if "shift" in op.attrs and op.type == Op.Mul:
shift = op.attrs["shift"]
if shift != 0:
- op.type = Op.RescaleMul
- op.rescale = [1, shift]
+ op.explicit_scaling = ExplicitScaling(False, [shift], [1])
if op.type.is_depthwise_conv2d_op():
op.attrs["depth_multiplier"] = op.weights.shape[3]
if op.type == Op.SplitSliceRead:
diff --git a/ethosu/vela/tosa_supported_operators.py b/ethosu/vela/tosa_supported_operators.py
index 192862ef..3f3a0025 100644
--- a/ethosu/vela/tosa_supported_operators.py
+++ b/ethosu/vela/tosa_supported_operators.py
@@ -50,7 +50,6 @@ class TosaSupportedOperators:
(
Op.Add,
Op.Mul,
- Op.RescaleMul,
Op.Sub,
)
)