author     Louis Verhaard <louis.verhaard@arm.com>    2020-09-30 09:01:52 +0200
committer  Louis Verhaard <louis.verhaard@arm.com>    2020-10-08 16:29:29 +0200
commit     aee5d7537ff81ffda5ba222721b72f914ce50fb8 (patch)
tree       495b9dfff2a188c6916f8ca2e390ee88f7da8ccc /ethosu/vela/register_command_stream_generator.py
parent     36ad73a0fb46d3f844845c97c56d92de2a7a9b3d (diff)
download   ethos-u-vela-aee5d7537ff81ffda5ba222721b72f914ce50fb8.tar.gz
MLBEDSW-3148: Refactor Operation
- op.type is now an enum instead of a string
- Removed unused operator codes
- Refactored some attributes like npu_block_type, fused_activation_function
- Refactored operator index calculation
- Refactored a number of operator sets

Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8
Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/register_command_stream_generator.py')
-rw-r--r--  ethosu/vela/register_command_stream_generator.py  85
1 file changed, 40 insertions, 45 deletions
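The change is mechanical but wide: every place this file compared op.type or the fused activation/memory function against a raw string now compares against an Op enum member, and some attributes move off the generic attrs dict onto the Operation object (memory_function, activation, forced_output_quantization), as the hunks below show. The following is a minimal, hypothetical sketch, not the real Op enum in ethosu/vela/operation.py (which has many more members and helpers); it only illustrates the string-to-enum pattern this diff applies.

# Minimal, hypothetical sketch -- not the real ethosu/vela/operation.py definition.
# It only illustrates the string -> enum refactor this commit applies to op.type.
import enum


class Op(enum.Enum):
    Add = "Add"
    Sub = "Sub"
    Mul = "Mul"
    Minimum = "Minimum"
    Maximum = "Maximum"
    LeakyRelu = "LeakyRelu"
    Abs = "Abs"
    CLZ = "CLZ"
    SHR = "SHR"
    SHL = "SHL"
    AvgPool = "AvgPool"
    MaxPool = "MaxPool"
    ResizeBilinear = "ResizeBilinear"

    def is_maxpool_op(self) -> bool:
        # Hypothetical stand-in for the is_maxpool_op() predicate used in the
        # pooling hunk below; the real helper lives on Op in operation.py.
        return self is Op.MaxPool


# Dict keys are enum members rather than operator-name strings, as in the
# elementwise_mode_map hunk; the values here are placeholders for the
# elementwise_mode.*.value register constants.
elementwise_mode_map = {Op.Mul: 1, Op.Add: 0, Op.Sub: 2}
assert elementwise_mode_map[Op.Mul] == 1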
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index da9be668..073b50fb 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -50,6 +50,7 @@ from .numeric_util import quantise_float32
from .numeric_util import round_away_zero
from .numeric_util import round_up_to_int
from .operation import NpuBlockType
+from .operation import Op
from .tensor import MemType
from .tensor import TensorBlockTraversal
from .tensor import TensorFormat
@@ -357,16 +358,16 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
# Maps an elementwise op type to an elementwise_mode enum value used by NPU_OP_ELEMENTWISE
elementwise_mode_map = {
- "MulAct": elementwise_mode.MUL.value,
- "AddAct": elementwise_mode.ADD.value,
- "SubAct": elementwise_mode.SUB.value,
- "Minimum": elementwise_mode.MIN.value,
- "Maximum": elementwise_mode.MAX.value,
- "LeakyRelu": elementwise_mode.LRELU.value,
- "Abs": elementwise_mode.ABS.value,
- "CLZ": elementwise_mode.CLZ.value,
- "SHR": elementwise_mode.SHR.value,
- "SHL": elementwise_mode.SHL.value,
+ Op.Mul: elementwise_mode.MUL.value,
+ Op.Add: elementwise_mode.ADD.value,
+ Op.Sub: elementwise_mode.SUB.value,
+ Op.Minimum: elementwise_mode.MIN.value,
+ Op.Maximum: elementwise_mode.MAX.value,
+ Op.LeakyRelu: elementwise_mode.LRELU.value,
+ Op.Abs: elementwise_mode.ABS.value,
+ Op.CLZ: elementwise_mode.CLZ.value,
+ Op.SHR: elementwise_mode.SHR.value,
+ Op.SHL: elementwise_mode.SHL.value,
}
cmd_stream = []
@@ -439,15 +440,15 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
rounding_mode = (
rounding.NATURAL if primary_op.attrs.get("rounding_mode", "") == b"NATURAL" else rounding.TFL
)
- if primary_op.type == "ResizeBilinear":
+ if primary_op.type == Op.ResizeBilinear:
rounding_mode = rounding.TRUNCATE
- fmf = primary_op.attrs.get("fused_memory_function", None)
- faf = primary_op.attrs.get("fused_activation_function", None)
- fused_quantize = any(op.type == "Quantize" for op in ps.ops)
+ fmf = primary_op.memory_function
+ faf = primary_op.activation
+ fused_quantize = any(op.type == Op.Quantize for op in ps.ops)
# Force output scale, used in operations with fused LUT
# Note: with current LUT support, forced_ofm_quantization is always equal to cmd.ofm_tensor.quantization
# except when primary_op is AddAct + 0 (no-op) + LUT
- forced_ofm_quantization = primary_op.attrs.get("forced_output_quantization", None)
+ forced_ofm_quantization = primary_op.forced_output_quantization
ofm_quant = cmd.ofm_tensor.quantization
if forced_ofm_quantization is not None:
ofm_quant = forced_ofm_quantization
@@ -482,16 +483,16 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
ifm2_broadcast |= IFM2Broadcast.ReverseOperandOrder
# Calculate scales needed for arithmetic elementwise operators
- if primary_op.type in set(("AddAct", "MulAct", "SubAct",)):
+ if primary_op.type in set((Op.Add, Op.Mul, Op.Sub,)):
input_scale = cmd.ifm_tensor.quantization.scale_f32 if cmd.ifm_tensor.quantization else None
input2_scale = cmd.ifm2_tensor.quantization.scale_f32 if cmd.ifm2_tensor.quantization else None
output_scale = ofm_quant.scale_f32 if ofm_quant else None
use_global_scale = True
- if output_scale is not None and faf in ("Sigmoid", "Tanh"):
+ if output_scale is not None and faf in (Op.Sigmoid, Op.Tanh):
output_scale = 1 / 0x3000
- if primary_op.type == "MulAct":
+ if primary_op.type == Op.Mul:
if None in (input_scale, input2_scale, output_scale):
ofm_scale = 1
shift = 0
@@ -537,11 +538,11 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
emit.cmd1_with_offset(cmd1.NPU_SET_OPB_SCALE, opb_scale)
emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, ofm_scale, shift)
- elif primary_op.type in set(("LeakyRelu", "Abs",)):
+ elif primary_op.type in set((Op.LeakyRelu, Op.Abs,)):
output_scale = ofm_quant.scale_f32
use_global_scale = True
- if primary_op.type == "LeakyRelu":
+ if primary_op.type == Op.LeakyRelu:
output_scale = primary_op.attrs["alpha"]
ofm_scale, shift = scaling.quantise_scale(output_scale)
@@ -599,10 +600,10 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
emit.cmd0_with_param(cmd0.NPU_SET_ACC_FORMAT, acc_format_map[shared_buffer.use_accumulator_element])
- if primary_op.type == "ResizeBilinear":
+ if primary_op.type == Op.ResizeBilinear:
# perform nearest neighbor upscale
emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, resampling_mode.NEAREST)
- elif primary_op.type == "Conv2DBackpropInputSwitchedBias":
+ elif primary_op.type == Op.Conv2DBackpropInputSwitchedBias:
# perform insert zero upscale
emit.cmd0_with_param(cmd0.NPU_SET_IFM_UPSCALE, resampling_mode.TRANSPOSE)
else:
@@ -651,12 +652,9 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
valid_padding = sum(explicit_padding) == 0
- if (
- primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear", "ReduceSum"))
- and valid_padding
- ):
+ if primary_op.type in set((Op.AvgPool, Op.ResizeBilinear, Op.ReduceSum)) and valid_padding:
# For valid padding vela has to output scaling values
- if faf == "Sigmoid" or faf == "Tanh":
+ if faf == Op.Sigmoid or faf == Op.Tanh:
rescale = 0x3000 * cmd.ifm_tensor.quantization.scale_f32
if cmd.ifm_tensor.dtype == DataType.int16:
# Calculate scale and shift for the output scale of 1/(3*4096)
@@ -675,7 +673,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
ifm_scale_f64 = np.double(cmd.ifm_tensor.quantization.scale_f32)
ofm_scale_f64 = np.double(ofm_quant.scale_f32)
scale, shift = scaling.quantise_scale(ifm_scale_f64 / ofm_scale_f64)
- elif primary_op.type == "ResizeBilinear" and "rescale" in primary_op.attrs:
+ elif primary_op.type == Op.ResizeBilinear and "rescale" in primary_op.attrs:
rescale = primary_op.attrs["rescale"]
rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
scale, shift = scaling.quantise_pooling_scale(k_height * k_width, rescale_bits)
@@ -689,7 +687,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
rescale = cmd.ifm_tensor.quantization.scale_f32 / ofm_quant.scale_f32
rescale_bits = 0
if k_height == k_width == 1:
- if fmf == "ConcatSliceWrite":
+ if fmf == Op.ConcatSliceWrite:
rounding_mode = rounding.NATURAL
if rescale > 1:
rescale_bits = len(bin(round_up_to_int(rescale))) - 2 + 1
@@ -814,35 +812,35 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
# Even if no activation function, values need to be set to override previous values
faf_min = ofm_quant_qmin
faf_max = ofm_quant_qmax
- elif faf == "Relu":
+ elif faf == Op.Relu:
emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
faf_min = quantise_float32(0.0, ofm_quant.scale_f32, ofm_quant.zero_point)
faf_max = ofm_quant_qmax
- elif faf == "Relu6":
+ elif faf == Op.Relu6:
emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
faf_min = quantise_float32(0.0, ofm_quant.scale_f32, ofm_quant.zero_point)
faf_max = quantise_float32(6.0, ofm_quant.scale_f32, ofm_quant.zero_point)
- elif faf == "ReluN1To1":
+ elif faf == Op.ReluN1To1:
emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.NONE)
faf_min = quantise_float32(-1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
faf_max = quantise_float32(1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
- elif faf == "Tanh":
+ elif faf == Op.Tanh:
emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.TANH)
- if primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear")):
+ if primary_op.type in set((Op.AvgPool, Op.ResizeBilinear)):
faf_min = quantise_float32(-1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
faf_max = quantise_float32(1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
else:
faf_min = quantise_float32(clamp_tanh(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
faf_max = quantise_float32(clamp_tanh(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
- elif faf == "Sigmoid":
+ elif faf == Op.Sigmoid:
emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, activation.SIGMOID)
- if primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear")):
+ if primary_op.type in set((Op.AvgPool, Op.ResizeBilinear)):
faf_min = quantise_float32(0, ofm_quant.scale_f32, ofm_quant.zero_point)
faf_max = quantise_float32(1.0, ofm_quant.scale_f32, ofm_quant.zero_point)
else:
faf_min = quantise_float32(clamp_sigmoid(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point)
faf_max = quantise_float32(clamp_sigmoid(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point)
- elif faf == "LUT":
+ elif faf == Op.LUT:
lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", -1)
assert activation.LUT_START.value <= lut_index <= activation.LUT_END.value, "LUT index out of range."
if cmd.ofm_tensor.dtype == DataType.int32:
@@ -851,7 +849,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
faf_min = ofm_quant_qmin
faf_max = ofm_quant_qmax
else:
- raise Exception("Unsupported fused_activation_function = " + faf)
+ raise Exception("Unsupported fused_activation_function = " + faf.name)
# Activation range needs to be set based upon the quantisation range and the fused activation range
emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION_MIN, max(ofm_quant_qmin, faf_min))
@@ -911,14 +909,11 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
need_zero_point = (
(faf is not None and forced_ofm_quantization is None)
- or (fmf == "ConcatSliceWrite")
+ or (fmf == Op.ConcatSliceWrite)
or fused_quantize
)
if (
- (
- primary_op.type in set(("AvgPool", "AvgPoolAct", "ResizeBilinear", "CLZ", "SHL"))
- and not need_zero_point
- )
+ (primary_op.type in set((Op.AvgPool, Op.ResizeBilinear, Op.CLZ, Op.SHL)) and not need_zero_point)
or (
tens.dtype == DataType.int32
and zero_point_op in (cmd0.NPU_SET_IFM_ZERO_POINT, cmd0.NPU_SET_IFM2_ZERO_POINT)
@@ -933,7 +928,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
zero_point = forced_ofm_quantization.zero_point
elif (
"resizebilinear" in primary_op.attrs
- and primary_op.type == "AddAct"
+ and primary_op.type == Op.Add
and cmd0.NPU_SET_OFM_ZERO_POINT == zero_point_op
):
# Force output zero point same as the input zero point
@@ -1108,7 +1103,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
# Vector product is implemented using a 1x1 convolution
emit.cmd_do_operation(cmd0.NPU_OP_CONV)
elif npu_block_type == NpuBlockType.Pooling:
- param = pooling_mode.MAX.value if "Max" in primary_op.type else pooling_mode.AVERAGE.value
+ param = pooling_mode.MAX.value if primary_op.type.is_maxpool_op() else pooling_mode.AVERAGE.value
emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=param)
elif npu_block_type == NpuBlockType.ReduceSum:
emit.cmd_do_operation(cmd0.NPU_OP_POOL, param=pooling_mode.REDUCE_SUM.value)
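One behavioral point worth noting, assuming (as in the sketch above) that Op is a plain enum.Enum rather than a str mixin: enum members never compare equal to the old name strings, so any leftover string comparison would silently evaluate to False, which is why every comparison in this file is migrated in the same commit. The last hunk also swaps the substring test "Max" in primary_op.type for the explicit is_maxpool_op() predicate. A self-contained sketch of both points:

import enum


class Op(enum.Enum):  # hypothetical mini-enum mirroring the sketch above
    Relu = "Relu"
    MaxPool = "MaxPool"

    def is_maxpool_op(self) -> bool:
        return self is Op.MaxPool


# A plain Enum member is never equal to its old string spelling, so a stale
# comparison such as `faf == "Relu"` would quietly become False after this change.
assert Op.Relu != "Relu"
# Pooling-mode selection now asks the enum instead of substring-matching the name.
assert Op.MaxPool.is_maxpool_op()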