about | summary | refs | log | tree | commit | diff
path: root/ethosu/vela/operation.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- ethosu/vela/operation.py 15
1 file changed, 14 insertions, 1 deletion
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 967d30b..d2b08b5 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -196,7 +196,7 @@ class Op(Enum):
Max = OperatorInfo()
MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
- Mean = OperatorInfo()
+ Mean = OperatorInfo(indices=IFM_INDICES)
Min = OperatorInfo()
Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
MirrorPad = OperatorInfo()
@@ -414,6 +414,7 @@ class Operation:
"run_on_npu",
"activation",
"memory_function",
+ "forced_input_quantization",
"forced_output_quantization",
"activation_lut",
"_kernel",
@@ -422,6 +423,7 @@ class Operation:
"rescale",
"read_offsets",
"rounding_mode",
+ "low_precision_scaling",
)
def __init__(self, op_type: Op, name: str):
@@ -439,6 +441,7 @@ class Operation:
self.memory_function = None
# If not none: contains QuantizationParameters to be used as output quantization
# (which overrides the ofm tensor's quantization), used in LUT
+ self.forced_input_quantization = None
self.forced_output_quantization = None
self.scheduled_pass = None
self.op_index = None # input network operator index
@@ -451,6 +454,9 @@ class Operation:
self.rescale = None
self.read_offsets: List[Shape4D] = [None, None] # offset for [ifm, ifm2]
self.rounding_mode: Optional[NpuRoundingMode] = None
+ # The Mean operator (implemented as a depthwise convolution) requires scaling
+ # to be calculated differently in one case. In that case, this is set to True.
+ self.low_precision_scaling = False
def clone(self, suffix="_clone"):
res = Operation(self.type, self.name + suffix)
@@ -463,11 +469,13 @@ class Operation:
res.run_on_npu = self.run_on_npu
res.activation = None if self.activation is None else self.activation.clone()
res.memory_function = self.memory_function
+ res.forced_input_quantization = self.forced_input_quantization
res.forced_output_quantization = self.forced_output_quantization
res.scheduled_pass = self.scheduled_pass
res.op_index = None # not relevant as not part of input network
res.read_offsets = list(self.read_offsets)
res.rounding_mode = self.rounding_mode
+ res.low_precision_scaling = self.low_precision_scaling
return res
@@ -692,6 +700,11 @@ class Operation:
if self not in tens.consumer_list:
tens.consumer_list.append(self)
+ def get_input_quantization(self):
+ if self.forced_input_quantization is not None:
+ return self.forced_input_quantization
+ return self.ifm.quantization
+
def set_output_tensor(self, tens):
tens.ops = [self]
self.outputs = [tens]