aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-10-18 17:22:21 +0100
committerEric Kunze <eric.kunze@arm.com>2023-11-02 23:22:09 +0000
commitd1a08ce27ef8d0f6cf77e1b864610aade06edc5c (patch)
tree777992f45d240361f898b1d21902c2a46c58235f /reference_model/src/generate
parentb0b9e33c3500bd8dc9b12ef012d4234b1245247a (diff)
downloadreference_model-d1a08ce27ef8d0f6cf77e1b864610aade06edc5c.tar.gz
Compliance mode testing for CONV2D
Added CONV2D data generation. Updated verify dot product check to latest specification. Updated test generator and python datagenerator library to create const files during test generation. Add support for compliance test sets to conformance test_select. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I5be3b761a1e3ef259c058e493877cd5a89d5778b
Diffstat (limited to 'reference_model/src/generate')
-rw-r--r--reference_model/src/generate/generate_dot_product.cc115
-rw-r--r--reference_model/src/generate/generate_dot_product_states.cc2
-rw-r--r--reference_model/src/generate/generate_utils.cc1
-rw-r--r--reference_model/src/generate/generate_utils.h2
4 files changed, 118 insertions, 2 deletions
diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc
index cbfac4b..e6815ad 100644
--- a/reference_model/src/generate/generate_dot_product.cc
+++ b/reference_model/src/generate/generate_dot_product.cc
@@ -76,6 +76,119 @@ bool generateMatMul(const TosaReference::GenerateConfig& cfg,
return true;
}
+//---------------------------------------------------------------------------//
+// Conv2D //
+//---------------------------------------------------------------------------//
+
+bool generateConv2DInput(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ if (cfg.dotProductInfo.kernel.size() != 2 || cfg.dotProductInfo.kernel[0] <= 0 || cfg.dotProductInfo.kernel[1] <= 0)
+ {
+ WARNING("[Generator][DP][Conv2D][Input] Missing or incorrect kernel size information.");
+ return false;
+ }
+ if (cfg.shape.size() != 4)
+ {
+ WARNING("[Generator][DP][Conv2D][Input] Tensor shape expected 4 dimensions.");
+ return false;
+ }
+
+ float* input = reinterpret_cast<float*>(data);
+ const int64_t T = TosaReference::numElementsFromShape(cfg.shape);
+ const uint32_t IH = cfg.shape[1];
+ const uint32_t IW = cfg.shape[2];
+ const uint32_t IC = cfg.shape[3];
+ const uint32_t KH = cfg.dotProductInfo.kernel[0];
+ const uint32_t KW = cfg.dotProductInfo.kernel[1];
+
+ for (int64_t t = 0; t < T; ++t)
+ {
+ uint32_t ic = t % IC;
+ uint32_t ix = (t / IC) % IW;
+ uint32_t iy = ((t / IC) / IW) % IH;
+ uint32_t k = ((iy % KH) * KW + (ix % KW)) * IC + ic;
+
+ input[t] = generator(k);
+ }
+ return true;
+}
+
+bool generateConv2DWeight(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ if (cfg.shape.size() != 4)
+ {
+ WARNING("[Generator][DP][Conv2D][Weight] Tensor shape expected 4 dimensions.");
+ return false;
+ }
+
+ float* weight = reinterpret_cast<float*>(data);
+ const int64_t T = TosaReference::numElementsFromShape(cfg.shape);
+ const uint32_t KH = cfg.shape[1];
+ const uint32_t KW = cfg.shape[2];
+ const uint32_t IC = cfg.shape[3];
+
+ for (int64_t t = 0; t < T; ++t)
+ {
+ uint32_t ic = t % IC;
+ uint32_t kx = (t / IC) % KW;
+ uint32_t ky = ((t / IC) / KW) % KH;
+ uint32_t k = (ky + KW * kx) * IC + ic;
+
+ weight[t] = generator(k);
+ }
+ return true;
+}
+
+bool generateConv2DBias(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ if (cfg.shape.size() != 1)
+ {
+ WARNING("[Generator][DP][Conv2D][Bias] Tensor shape expected 1 dimension.");
+ return false;
+ }
+
+ float* bias = reinterpret_cast<float*>(data);
+ const uint32_t T = cfg.shape[0];
+
+ for (uint32_t t = 0; t < T; ++t)
+ {
+ bias[t] = generator(2);
+ }
+ return true;
+}
+
+bool generateConv2D(const TosaReference::GenerateConfig& cfg,
+ TosaReference::IDotProductGenerator& generator,
+ void* data,
+ size_t size)
+{
+ if (cfg.dataType != DType::DType_FP32)
+ {
+ WARNING("[Generator][DP][Conv2D] Only supports FP32.");
+ return false;
+ }
+ switch (cfg.inputPos)
+ {
+ case 0:
+ return generateConv2DInput(cfg, generator, data, size);
+ case 1:
+ return generateConv2DWeight(cfg, generator, data, size);
+ case 2:
+ return generateConv2DBias(cfg, generator, data, size);
+ default:
+ WARNING("[Generator][DP][Conv2D] Invalid input tensor slot position to operator.");
+ return false;
+ }
+}
} // namespace
namespace TosaReference
@@ -95,6 +208,8 @@ bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size)
{
case tosa::Op_MATMUL:
return generateMatMul(cfg, *generator, data, size);
+ case tosa::Op_CONV2D:
+ return generateConv2D(cfg, *generator, data, size);
default:
WARNING("[Generator][DP] Unsupported operator.");
return false;
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc
index 649e55e..53bef3a 100644
--- a/reference_model/src/generate/generate_dot_product_states.cc
+++ b/reference_model/src/generate/generate_dot_product_states.cc
@@ -242,7 +242,7 @@ public:
if (_p != P2)
return (_B / std::sqrt(_KS + 1)) * s;
else
- return (_B * _B / (_KS + 1)) * s;
+ return 0.f;
}
private:
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index bcbf9d7..d3bb076 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -41,6 +41,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op,
{ Op::Op_MATMUL, "MATMUL" },
{ Op::Op_MAX_POOL2D, "MAX_POOL2D" },
{ Op::Op_PAD, "PAD" },
+ { Op::Op_CONV2D, "CONV2D" },
})
} // namespace tosa
diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h
index 0239e98..7c55f1d 100644
--- a/reference_model/src/generate/generate_utils.h
+++ b/reference_model/src/generate/generate_utils.h
@@ -52,7 +52,7 @@ struct DotProductInfo
int32_t ks;
DType accType;
int32_t axis;
- std::array<int32_t, 2> kernel;
+ std::vector<int32_t> kernel;
};
/// \brief Pseudo random generator meta-data