aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKevin Cheng <kevin.cheng@arm.com>2021-04-28 16:29:44 -0700
committerKevin Cheng <kevin.cheng@arm.com>2021-04-28 17:37:03 -0700
commit989cb050228b47085189b1c5cb0d9b705e1060e7 (patch)
tree317bfa75a3d3b79341f2d0e3c994bcef834ff179
parent550ccc52de231621c0bf0c05ae2a398eec37ff51 (diff)
downloadreference_model-989cb050228b47085189b1c5cb0d9b705e1060e7.tar.gz
Support mixed-precision input tensors for TOSA unit test.
Bring CONV2D/DEPTHWISE_CONV2D/TRANSPOSE_CONV2D/FULLY_CONNECTED up running. Other minor fixes: - reference model should bail out if shape is invalid, along with "goto done" cleanup. - cleanup typos/duplicate in tosa_test_gen.py/tosa_serializer.py. - wrong input_zp/output_zp being generated for RESCALE. Signed-off-by: Kevin Cheng <kevin.cheng@arm.com> Change-Id: Ic1f3fe0090482bdee8a61508be7c738714191e19
-rw-r--r--reference_model/src/tensor.h45
-rw-r--r--verif/tosa_serializer.py3
-rw-r--r--verif/tosa_test_gen.py171
3 files changed, 113 insertions, 106 deletions
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index d39cc7c..6c0622e 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -33,9 +33,7 @@ class GraphNode;
class Tensor
{
public:
- Tensor(std::string tensorName_,
- DType tensorDtype__,
- std::vector<int> shape_);
+ Tensor(std::string tensorName_, DType tensorDtype__, std::vector<int> shape_);
virtual ~Tensor();
@@ -240,9 +238,7 @@ template <class T>
class TensorTemplate : public Tensor
{
public:
- TensorTemplate(std::string tensorName_,
- DType tensorDtype_,
- std::vector<int> shape_)
+ TensorTemplate(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_)
: Tensor(tensorName_, tensorDtype_, shape_)
{
tensor = nullptr;
@@ -606,11 +602,15 @@ int Tensor6<bool>::dumpTensor(FILE* out) const;
class TensorFactory
{
public:
- static Tensor* newTensor(std::string tensorName_,
- DType tensorDtype_,
- std::vector<int> shape_,
- const uint32_t rank)
+ static Tensor* newTensor(std::string tensorName_, DType tensorDtype_, std::vector<int> shape_, const uint32_t rank)
{
+ // Bail out if any dimension is invalid.
+ for (auto& dim : shape_)
+ {
+ if (dim <= 0)
+ goto done;
+ }
+
switch (tensorDtype_)
{
case DType_FLOAT:
@@ -630,9 +630,8 @@ public:
return new Tensor5<float>(tensorName_, tensorDtype_, shape_);
case 6:
return new Tensor6<float>(tensorName_, tensorDtype_, shape_);
- default:
- goto done;
}
+ break;
case DType_INT32:
case DType_UINT8:
case DType_INT4:
@@ -654,9 +653,8 @@ public:
return new Tensor5<int32_t>(tensorName_, tensorDtype_, shape_);
case 6:
return new Tensor6<int32_t>(tensorName_, tensorDtype_, shape_);
- default:
- goto done;
}
+ break;
case DType_INT48:
switch (rank)
{
@@ -674,9 +672,8 @@ public:
return new Tensor5<int64_t>(tensorName_, tensorDtype_, shape_);
case 6:
return new Tensor6<int64_t>(tensorName_, tensorDtype_, shape_);
- default:
- goto done;
}
+ break;
case DType_BOOL:
switch (rank)
{
@@ -694,16 +691,22 @@ public:
return new Tensor5<bool>(tensorName_, tensorDtype_, shape_);
case 6:
return new Tensor6<bool>(tensorName_, tensorDtype_, shape_);
- default:
- goto done;
}
+ break;
default:
- goto done;
+ break;
}
done:
- FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d", tensorName_.c_str(), EnumNamesDType()[tensorDtype_],
- rank);
+ std::string shape_str("[");
+ for (auto& dim : shape_)
+ {
+ shape_str += (std::to_string(dim) + ", ");
+ }
+ shape_str.append("]");
+
+ FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d, shape=%s", tensorName_.c_str(),
+ EnumNamesDType()[tensorDtype_], rank, shape_str.c_str());
}
static Tensor* newTensor(DType type, const std::vector<int> shape);
diff --git a/verif/tosa_serializer.py b/verif/tosa_serializer.py
index fa1fdcb..726ffc4 100644
--- a/verif/tosa_serializer.py
+++ b/verif/tosa_serializer.py
@@ -616,10 +616,7 @@ class TosaSerializer:
)
def setExpectedFailure(self, desc="", val=True):
- self.expectedFailure = val
- self.expectedFailureDesc = desc
- def setExpectedFailure(self, desc="", val=True):
self.expectedFailure = val
self.expectedFailureDesc = desc
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index b059ef5..134f569 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -259,7 +259,10 @@ class TosaTensorGen:
# The filter dimensions are OHWI
filter_shape = np.asarray([ofm_depth, filter_hw[0], filter_hw[1], ifm_shape[3]])
- return [ifm_shape, filter_shape]
+ # The bias is OC
+ bias_shape = np.asarray([ofm_depth])
+
+ return [ifm_shape, filter_shape, bias_shape]
@staticmethod
def tgDepthwiseConv2D(testGen, op, rank):
@@ -298,7 +301,6 @@ class TosaTensorGen:
pl, const = op["operands"]
assert rank == 2
- assert pl == 2 and const == 0
input_shape = testGen.makeShape(rank)
filter_oc = testGen.makeShape(1)[0]
@@ -914,21 +916,25 @@ class TosaTestGen:
else:
raise Exception("Unrecognized Dtype: {}".format(dtype))
- def buildPlaceholderTensors(self, shape_list, dtype):
+ def buildPlaceholderTensors(self, shape_list, dtype_list):
placeholders = []
- for shape in shape_list:
- arr = self.getRandTensor(shape, dtype)
- placeholders.append(self.ser.addPlaceholder(shape, dtype, arr))
+ assert len(shape_list) == len(dtype_list)
+
+ for idx, shape in enumerate(shape_list):
+ arr = self.getRandTensor(shape, dtype_list[idx])
+ placeholders.append(self.ser.addPlaceholder(shape, dtype_list[idx], arr))
return placeholders
- def buildConstTensors(self, shape_list, dtype):
+ def buildConstTensors(self, shape_list, dtype_list):
consts = []
- for shape in shape_list:
- arr = self.getRandTensor(shape, dtype)
- consts.append(self.ser.addConst(shape, dtype, arr))
+ assert len(shape_list) == len(dtype_list)
+
+ for idx, shape in enumerate(shape_list):
+ arr = self.getRandTensor(shape, dtype_list[idx])
+ consts.append(self.ser.addConst(shape, dtype_list[idx], arr))
return consts
@@ -981,24 +987,28 @@ class TosaTestGen:
return "x".join(sStr)
def typeStr(self, t):
- if t == DType.BOOL:
- return "b"
- elif t == DType.INT4:
- return "i4"
- elif t == DType.INT8:
- return "i8"
- elif t == DType.UINT8:
- return "u8"
- elif t == DType.INT16:
- return "i16"
- elif t == DType.INT32:
- return "i32"
- elif t == DType.INT48:
- return "i48"
- elif t == DType.FLOAT:
- return "float"
+ if isinstance(t, list):
+ assert len(t) >= 2
+ return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
else:
- raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
+ if t == DType.BOOL:
+ return "b"
+ elif t == DType.INT4:
+ return "i4"
+ elif t == DType.INT8:
+ return "i8"
+ elif t == DType.UINT8:
+ return "u8"
+ elif t == DType.INT16:
+ return "i16"
+ elif t == DType.INT32:
+ return "i32"
+ elif t == DType.INT48:
+ return "i48"
+ elif t == DType.FLOAT:
+ return "float"
+ else:
+ raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
def typeWidth(self, t):
""" Get the datatype width for integer types"""
@@ -1075,7 +1085,7 @@ class TosaTestGen:
# Replace the cond tensor with a boolean tensor since it probably
# has the wrong dtype
- t = self.buildPlaceholderTensors([cond.shape], DType.BOOL)
+ t = self.buildPlaceholderTensors([cond.shape], [DType.BOOL])
cond = t[0]
result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
@@ -1121,7 +1131,7 @@ class TosaTestGen:
return result_tens
def build_transpose_conv2d(
- self, op, ifm, filter, stride, outpad, dilation, output_shape, qinfo
+ self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
):
assert len(outpad) == 2
result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
@@ -1129,24 +1139,8 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.TransposeConv2DAttribute(outpad, stride, dilation, output_shape)
- # Create bias here since the acc_t depends on (but isn't the same as) the input dtype
- # The bias is OC
- if ifm.dtype == DType.INT8:
- bias_type = DType.INT32
- elif ifm.dtype == DType.INT16:
- bias_type = DType.INT48
- elif ifm.dtype == DType.FLOAT:
- bias_type = DType.FLOAT
- else:
- raise Exception(
- "Unsupported dtype for transpose_conv2d: {}".format(ifm.dtype)
- )
-
- bias_arr = self.getRandTensor([filter.shape[0]], bias_type)
- bias_tens = self.ser.addConst([filter.shape[0]], bias_type, bias_arr)
-
self.ser.addOperator(
- op, [ifm.name, filter.name, bias_tens.name], [result_tens.name], attr, qinfo
+ op, [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
)
return result_tens
@@ -1414,13 +1408,13 @@ class TosaTestGen:
out_type_width = self.typeWidth(out_dtype)
if val.dtype == DType.INT8:
- input_zp = self.randInt()
+ input_zp = self.randInt(-128, 127)
in_type_width = in_type_width + 1
else:
input_zp = 0
if out_dtype == DType.INT8:
- output_zp = self.randInt()
+ output_zp = self.randInt(-128, 127)
out_type_width = out_type_width + 1
else:
output_zp = 0
@@ -1661,7 +1655,7 @@ class TosaTestGen:
return testList
- def serializeTest(self, opName, testStr, dtype, shapeList, testArgs):
+ def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
try:
op = self.TOSA_OP_LIST[opName]
except KeyError as e:
@@ -1672,6 +1666,23 @@ class TosaTestGen:
build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+
+ if isinstance(dtype_or_dtypeList, list):
+ dtypeList = dtype_or_dtypeList
+ else:
+ dtypeList = [dtype_or_dtypeList] * (num_operands)
+
+ assert (
+ len(shapeList) == num_operands
+ ), "shapeList length {} must match number of operands {}".format(
+ len(shapeList), num_operands
+ )
+ assert (
+ len(dtypeList) == num_operands
+ ), "dtypeList length {} must match number of operands {}".format(
+ len(dtypeList), num_operands
+ )
try:
qgen = op["qgen"]
@@ -1690,25 +1701,27 @@ class TosaTestGen:
placeholders = []
for idx, shape in enumerate(shapeList[:]):
if idx == 1:
- if dtype == DType.INT8:
+ if dtypeList[idx] == DType.INT8:
arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
- elif dtype == DType.INT16:
+ elif dtypeList[idx] == DType.INT16:
arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
- elif dtype == DType.INT32:
+ elif dtypeList[idx] == DType.INT32:
arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
else:
raise Exception("OpArithmeticRightShift: invalid input dtype")
else:
- arr = self.getRandTensor(shapeList[0], dtype)
- placeholders.append(self.ser.addPlaceholder(shape, dtype, arr))
+ arr = self.getRandTensor(shapeList[0], dtypeList[idx])
+ placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
tens.extend(placeholders)
else:
- tens.extend(self.buildPlaceholderTensors(shapeList[0:pCount], dtype))
- tens.extend(self.buildConstTensors(shapeList[pCount:], dtype))
+ tens.extend(
+ self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])
+ )
+ tens.extend(self.buildConstTensors(shapeList[pCount:], dtypeList[pCount:]))
if qgen is not None:
- qinfo = qgen(self, op, dtype)
+ qinfo = qgen(self, op, dtypeList[0])
else:
qinfo = None
@@ -1829,6 +1842,12 @@ class TosaTestGen:
TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FLOAT]
+ TYPE_CONV2D = [
+ [DType.INT8, DType.INT8, DType.INT32],
+ [DType.INT16, DType.INT8, DType.INT48],
+ DType.FLOAT,
+ ]
+
DEFAULT_RANK_RANGE = (1, 4)
TOSA_OP_LIST = {
@@ -1949,7 +1968,7 @@ class TosaTestGen:
"rank": (4, 4),
"build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
"qgen": TosaQuantGen.qgConv,
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_CONV2D,
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists
@@ -1964,13 +1983,13 @@ class TosaTestGen:
TosaArgGen.agConv2D,
),
"qgen": TosaQuantGen.qgConv,
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_CONV2D,
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists
"transpose_conv2d_TEMPLATE": {
"op": Op.TRANSPOSE_CONV2D,
- "operands": (1, 1),
+ "operands": (1, 2),
"rank": (4, 4),
"build_fcn": (
build_transpose_conv2d,
@@ -1978,16 +1997,16 @@ class TosaTestGen:
TosaArgGen.agTransposeConv2D,
),
"qgen": TosaQuantGen.qgConv,
- "types": TYPE_FP,
+ "types": TYPE_CONV2D,
"template": True,
},
"fully_connected": {
"op": Op.FULLY_CONNECTED,
- "operands": (2, 0),
+ "operands": (1, 2),
"rank": (2, 2),
"build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
"qgen": TosaQuantGen.qgConv,
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_CONV2D,
},
"matmul": {
"op": Op.MATMUL,
@@ -2453,9 +2472,6 @@ class OutputShaper:
else:
raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
- if ifm.dtype == DType.INT16:
- ser.setExpectedFailure(True, "INT16 support is in progress")
-
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
@@ -2496,9 +2512,6 @@ class OutputShaper:
else:
raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
- if ifm.dtype == DType.INT16:
- ser.setExpectedFailure(True, "INT16 support is in progress")
-
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
@@ -2533,9 +2546,6 @@ class OutputShaper:
else:
raise Exception("Unsupported input dtype: {}".format(input.dtype))
- if input.dtype == DType.INT16:
- ser.setExpectedFailure(True, "INT16 support is in progress")
-
return ser.addOutput(output_shape, out_dtype)
@staticmethod
@@ -2681,12 +2691,12 @@ class OutputShaper:
ser.setExpectedFailure(True, "Invalid output data type")
elif input_dtype == DType.INT16:
if output_dtype != DType.INT48:
- ser.setexpectedfailure(true, "Invalid output data type")
+ ser.setExpectedFailure(true, "Invalid output data type")
elif input_dtype == DType.FLOAT:
if output_dtype != DType.FLOAT:
- ser.setexpectedfailure(true, "Invalid output data type")
+ ser.setExpectedFailure(true, "Invalid output data type")
else:
- ser.setexpectedfailure(true, "Invalid input data type")
+ ser.setExpectedFailure(true, "Invalid input data type")
elif mode == ResizeMode.NEAREST:
if input_dtype == DType.INT8:
@@ -2694,15 +2704,15 @@ class OutputShaper:
ser.setExpectedFailure(True, "Invalid output data type")
elif input_dtype == DType.INT16:
if output_dtype != DType.INT16:
- ser.setexpectedfailure(true, "Invalid output data type")
+ ser.setExpectedFailure(true, "Invalid output data type")
elif input_dtype == DType.FLOAT:
if output_dtype != DType.FLOAT:
- ser.setexpectedfailure(true, "Invalid output data type")
+ ser.setExpectedFailure(true, "Invalid output data type")
else:
- ser.setexpectedfailure(true, "Invalid input data type")
+ ser.setExpectedFailure(true, "Invalid input data type")
else:
- ser.setexpectedfailure(true, "Invalid resize mode")
+ ser.setExpectedFailure(true, "Invalid resize mode")
return ser.addOutput(output_dims, output_dtype)
@@ -2724,7 +2734,4 @@ class OutputShaper:
if output_shape[1] <= 0 or output_shape[2] <= 0:
ser.setExpectedFailure(True, "Negative output shape")
- if ifm.dtype == DType.INT16:
- ser.setExpectedFailure(True, "INT16 support is in progress")
-
return ser.addOutput(output_shape, out_dtype)