aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-08 22:19:41 +0000
committerTai Ly <tai.ly@arm.com>2024-03-17 19:56:21 -0700
commit60dc48c4ddf30f2a76d4cfcf1b40ca57b6f3bf95 (patch)
treee3d229a2d596e1a0788dfd75d77b996263055496
parente67115ef82bcba0718dcbd75cc8411985001b7cc (diff)
downloadreference_model-60dc48c4ddf30f2a76d4cfcf1b40ca57b6f3bf95.tar.gz
[ref model] Change Clamp and Pad attribute fields
This implements changes due to ClampAttribute and PadAttribute field changes. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ide01e2a27fe3c1ea7794e7a4b6780b7eae436caf
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosabin1504 -> 1484 bytes
-rw-r--r--reference_model/src/ops/activation_funcs.cc29
-rw-r--r--reference_model/src/ops/data_layout.cc16
m---------thirdparty/serialization_lib0
-rw-r--r--verif/generator/tosa_arg_gen.py16
-rw-r--r--verif/generator/tosa_test_gen.py24
-rw-r--r--verif/generator/tosa_utils.py5
7 files changed, 56 insertions, 34 deletions
diff --git a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
index 87bafd1..01d8375 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 1f4c3b3..de7d8be 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -17,6 +17,7 @@
#include "arith_util.h"
#include "quant_util.h"
#include "template_types.h"
+#include "tosa_serialization_handler.h"
#include <cmath>
using namespace TosaReference;
@@ -35,8 +36,11 @@ int OpClamp<Rank, Dtype>::register_fcn()
case TOSA_REF_TYPE_FP16:
case TOSA_REF_TYPE_BF16:
case TOSA_REF_TYPE_FP32: {
- InEigenType min = (InEigenType)attribute->min_fp();
- InEigenType max = (InEigenType)attribute->max_fp();
+ std::vector<float> min_float_data, max_float_data;
+ TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data);
+ TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data);
+ InEigenType min = (InEigenType)min_float_data[0];
+ InEigenType max = (InEigenType)max_float_data[0];
ERROR_IF(max < min, "OpClamp: max smaller than min");
this->fcn = [min, max](InEigenType a) -> OutEigenType {
@@ -45,23 +49,32 @@ int OpClamp<Rank, Dtype>::register_fcn()
}
break;
case TOSA_REF_TYPE_FP64: {
- InEigenType min = (InEigenType)attribute->min_fp();
- InEigenType max = (InEigenType)attribute->max_fp();
+ std::vector<float> min_float_data, max_float_data;
+ TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data);
+ TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data);
+ InEigenType min = (InEigenType)min_float_data[0];
+ InEigenType max = (InEigenType)max_float_data[0];
ERROR_IF(max < min, "OpClamp: max smaller than min");
this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); };
}
break;
case TOSA_REF_TYPE_INT8: {
- int8_t min = (int8_t)attribute->min_int();
- int8_t max = (int8_t)attribute->max_int();
+ std::vector<int32_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data);
+ TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data);
+ int8_t min = (int8_t)min_int_data[0];
+ int8_t max = (int8_t)max_int_data[0];
ERROR_IF(max < min, "OpClamp: max smaller than min");
this->fcn = [min, max](int8_t a) -> int8_t { return a <= min ? min : a >= max ? max : a; };
}
case TOSA_REF_TYPE_INT16: {
- int16_t min = (int16_t)attribute->min_int();
- int16_t max = (int16_t)attribute->max_int();
+ std::vector<int32_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data);
+ TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data);
+ int16_t min = (int16_t)min_int_data[0];
+ int16_t max = (int16_t)max_int_data[0];
ERROR_IF(max < min, "OpClamp: max smaller than min");
this->fcn = [min, max](int16_t a) -> int16_t { return a <= min ? min : a >= max ? max : a; };
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index b6ad704..4c17e78 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -176,17 +176,25 @@ int OpPad<Rank, Dtype>::eval()
case TOSA_REF_TYPE_BOOL:
case TOSA_REF_TYPE_INT8:
case TOSA_REF_TYPE_INT16:
- case TOSA_REF_TYPE_INT32:
- pad_value = (InEigenType)attribute->pad_const_int();
+ case TOSA_REF_TYPE_INT32: {
+ std::vector<int32_t> int32_data;
+ TosaSerializationHandler::ConvertU8toI32(attribute->pad_const(),
+ /* size = */ 1, int32_data);
+ pad_value = (InEigenType)int32_data[0];
break;
+ }
case TOSA_REF_TYPE_FP16:
case TOSA_REF_TYPE_BF16:
case TOSA_REF_TYPE_FP32:
case TOSA_REF_TYPE_FP64:
case TOSA_REF_TYPE_FP8E4M3:
- case TOSA_REF_TYPE_FP8E5M2:
- pad_value = (InEigenType)attribute->pad_const_fp();
+ case TOSA_REF_TYPE_FP8E5M2: {
+ std::vector<float> float_data;
+ TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(),
+ /* size = */ 1, float_data);
+ pad_value = (InEigenType)float_data[0];
break;
+ }
default:
ASSERT_MSG(false, "TOSA_REF_TYPE %s is not supported.", EnumNameTOSAREFTYPE(Dtype));
break;
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject 758e73e117c5cef17f8f0b1c543efc1df953b2f
+Subproject 0b6d7c271af1e6593e6a2cf14b32acea765f4b6
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 20572e8..a2ef5bf 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1813,13 +1813,7 @@ class TosaArgGen:
and "data_gen" in testGen.TOSA_OP_LIST[opName]
and gtu.dtypeIsSupportedByCompliance(dtype)
):
- if dtype in [
- DType.FP16,
- DType.FP32,
- DType.BF16,
- DType.FP8E4M3,
- DType.FP8E5M2,
- ]:
+ if gtu.dtypeIsFloat(dtype):
dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
else:
dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
@@ -2462,13 +2456,7 @@ class TosaArgGen:
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
- elif dtype in (
- DType.FP16,
- DType.BF16,
- DType.FP32,
- DType.FP8E4M3,
- DType.FP8E5M2,
- ):
+ elif gtu.dtypeIsFloat(dtype):
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index e7704f1..3173906 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -3,6 +3,7 @@
import json
import logging
import os
+import struct
from copy import deepcopy
from datetime import datetime
from pathlib import Path
@@ -1428,13 +1429,17 @@ class TosaTestGen:
# Non-tensor fp16 ops take fp16 values as fp32 in reference_model
min_val = min_val.astype(np.float32)
max_val = max_val.astype(np.float32)
-
- attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
+ min_val_as_bytes = struct.pack("<f", min_val)
+ max_val_as_bytes = struct.pack("<f", max_val)
elif a.dtype in (DType.INT8, DType.INT16):
- attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
+ min_val_as_bytes = struct.pack("<i", min_val)
+ max_val_as_bytes = struct.pack("<i", max_val)
else:
# to avoid internal error for incorrect input types
- attr.ClampAttribute(self.ser.builder, 0, 0, 0, 0)
+ min_val_as_bytes = struct.pack("<i", 0)
+ max_val_as_bytes = struct.pack("<i", 0)
+
+ attr.ClampAttribute(self.ser.builder, min_val_as_bytes, max_val_as_bytes)
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -1578,9 +1583,14 @@ class TosaTestGen:
result_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
- # write empty padding into PadAttribute to ensure inputs[1] is used
+ # get pad_const_val_as_bytes from either pad_const_float or pad_const_int
+ if gtu.dtypeIsFloat(a.dtype):
+ pad_const_val_as_bytes = struct.pack("<f", pad_const_float)
+ else:
+ pad_const_val_as_bytes = struct.pack("<i", pad_const_int)
+
attr = ts.TosaSerializerAttribute()
- attr.PadAttribute(self.ser.builder, [], pad_const_int, pad_const_float)
+ attr.PadAttribute(self.ser.builder, pad_const_val_as_bytes)
# Invalidate Input/Output list for error if checks.
input_list = [a.name, pad_input.name]
@@ -2271,8 +2281,6 @@ class TosaTestGen:
attr.RescaleAttribute(
input_zp,
output_zp,
- [],
- [],
scale32,
double_round,
per_channel,
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 6558bf8..cfe7cc6 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -64,6 +64,11 @@ def dtypeWidth(dtype):
raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
+def dtypeIsFloat(dtype):
+ """Is floating point data type"""
+ return dtype in (DType.BF16, DType.FP16, DType.FP32, DType.FP8E4M3, DType.FP8E5M2)
+
+
def dtypeIsSupportedByCompliance(dtype):
"""Types supported by the new data generation and compliance flow."""
if isinstance(dtype, list) or isinstance(dtype, tuple):