aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/weight_compressor.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/weight_compressor.py')
-rw-r--r--ethosu/vela/weight_compressor.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index b291dce..bb7cd67 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -426,13 +426,13 @@ def calc_scales_and_pack_biases(tens, arch, ofm_depth_step, rescale_for_faf=Fals
first_consumer_op = tens.consumer_list[0]
ifm_dtype = first_consumer_op.inputs[0].dtype
- ifm_scale = first_consumer_op.inputs[0].quantization.scale_f32
+ ifm_scale = first_consumer_op.get_input_quantization().scale_f32
ofm_scale = first_consumer_op.get_output_quantization().scale_f32
weight_scales = first_consumer_op.inputs[1].quantization.scale_f32
# biases can have multiple consumers for rnn cells. if so, then check that they are all the same
for op in tens.consumer_list[1:]:
- assert ifm_scale == op.inputs[0].quantization.scale_f32
+ assert ifm_scale == op.get_input_quantization().scale_f32
assert ofm_scale == op.get_output_quantization().scale_f32
assert weight_scales == op.inputs[1].quantization.scale_f32
@@ -445,7 +445,14 @@ def calc_scales_and_pack_biases(tens, arch, ofm_depth_step, rescale_for_faf=Fals
# TensorFlow Lite casts the scales slightly differently for uint8 and int8
if not rescale_for_faf:
if ifm_dtype == DataType.uint8:
- scales = [np.double(ifm_scale * weight_scale) / np.double(ofm_scale) for weight_scale in weight_scales]
+ # for some cases of the Mean operator, the scale must be calculated differently to match reference
+ if first_consumer_op.low_precision_scaling:
+ scales = [
+ np.double(np.single(ifm_scale) / (np.single(weight_scale) * np.single(ofm_scale)))
+ for weight_scale in weight_scales
+ ]
+ else:
+ scales = [np.double(ifm_scale * weight_scale) / np.double(ofm_scale) for weight_scale in weight_scales]
elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
scales = [
(np.double(ifm_scale) * np.double(weight_scale)) / np.double(ofm_scale)