author      Fredrik Svedberg <fredrik.svedberg@arm.com>   2023-03-09 13:22:40 +0100
committer   Fredrik Svedberg <fredrik.svedberg@arm.com>   2023-03-13 15:44:32 +0000
commit      bb9885190f5f7ea959f171b38ee1dd44d3e1e75e (patch)
tree        ad87c79350f14e56760903f6da2dc1ca107928b3 /ethosu/vela/weight_compressor.py
parent      6e281afe19ea0cd9dba2cecfb73050c18f29d242 (diff)
download    ethos-u-vela-bb9885190f5f7ea959f171b38ee1dd44d3e1e75e.tar.gz
MLBEDSW-7427 Fix scale calculations for FullyConnected
Fixed scale calculations for FullyConnected to match the reference.
Also removed unused low_precision_scaling.

Change-Id: I4b766febff4a0010acd3de708bb49be458d22bf3
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
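The gist of the fix, as a minimal standalone sketch (not part of the patch; the scale values are made up for illustration and the variable names simply mirror those in the diff below): for uint8 and for FullyConnected, the ifm and weight scales are multiplied in float32 and only the product is widened to double, whereas for int8/int16 each factor is widened to double before the multiplication.

import numpy as np

# Hypothetical example values; in Vela the scales come from the model's
# quantization parameters and arrive as np.float32.
ifm_scale = np.float32(0.023529412)
weight_scale = np.float32(0.007874016)
ofm_scale = np.float32(0.091764517)

# uint8 / FullyConnected path (after this patch): multiply in float32,
# then widen the product to double, matching the TensorFlow Lite reference.
scale_fc = np.double(ifm_scale * weight_scale) / np.double(ofm_scale)

# int8 / int16 path: widen each factor to double before multiplying.
scale_int8 = (np.double(ifm_scale) * np.double(weight_scale)) / np.double(ofm_scale)

# The two results can differ in the last bits because the float32 product
# is rounded before the division.
print(scale_fc, scale_int8, scale_fc == scale_int8)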
Diffstat (limited to 'ethosu/vela/weight_compressor.py')
-rw-r--r--  ethosu/vela/weight_compressor.py  14
1 file changed, 4 insertions, 10 deletions
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index e56cc5e5..ab22e94f 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -266,17 +266,11 @@ def _prepare_scale_and_bias(arch, tens, rescale_for_faf, explicit_scaling):
     # Convert scales to np.double (from np.float32) to conform to TensorFlow Lite which
     # uses double during scaling calculations
-    # TensorFlow Lite casts the scales slightly differently for uint8 and int8
+    # TensorFlow Lite casts the scales slightly differently for uint8 and int8 as well as
+    # for FullyConnected operators
     if not rescale_for_faf:
-        if ifm_dtype == DataType.uint8:
-            # for some cases of the Mean operator, the scale must be calculated differently to match reference
-            if first_consumer_op.low_precision_scaling:
-                scales = [
-                    np.double(np.single(ifm_scale) / (np.single(weight_scale) * np.single(ofm_scale)))
-                    for weight_scale in weight_scales
-                ]
-            else:
-                scales = [np.double(ifm_scale * weight_scale) / np.double(ofm_scale) for weight_scale in weight_scales]
+        if ifm_dtype == DataType.uint8 or first_consumer_op.type == Op.FullyConnected:
+            scales = [np.double(ifm_scale * weight_scale) / np.double(ofm_scale) for weight_scale in weight_scales]
         elif ifm_dtype == DataType.int8 or ifm_dtype == DataType.int16:
             scales = [
                 (np.double(ifm_scale) * np.double(weight_scale)) / np.double(ofm_scale)