From 7a2f6d849ade2b96484cde1fdeb1a1a0d9ffa654 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Fri, 7 Jan 2022 17:00:03 -0600
Subject: [PATCH 1/9] Remove intermediate casting in BatchNormalization

---
 .../vivado/nnet_utils/nnet_batchnorm.h        | 10 ++---
 .../vivado/nnet_utils/nnet_batchnorm_stream.h |  4 +-
 .../templates/vivado/nnet_utils/nnet_mult.h   | 27 ++++++++++++
 hls4ml/templates/vivado_template.py           |  4 +-
 test/pytest/test_batchnorm.py                 | 44 +++++++++++++++++++
 5 files changed, 80 insertions(+), 9 deletions(-)
 create mode 100644 test/pytest/test_batchnorm.py

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
index 9a5cff0d3..ce1cfb315 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
@@ -43,8 +43,8 @@ struct batchnorm_config
     static const bool store_weights_in_bram = false;
     static const unsigned n_zeros = 0;
     // partitioning arrays cyclically to go with roll factors?
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::mult<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::mult_nocast<x_T, y_T>;
 };
 
 template<class data_T, class res_T, typename CONFIG_T>
@@ -71,7 +71,7 @@ void normalize(
         #pragma HLS ARRAY_PARTITION variable=bias complete
 
         int multiplier_limit  = ceil(float(CONFIG_T::n_in) / float(CONFIG_T::reuse_factor));
-        CONFIG_T::template product<data_T, typename CONFIG_T::scale_t, res_T>::limit(multiplier_limit);
+        CONFIG_T::template product<data_T, typename CONFIG_T::scale_t>::limit(multiplier_limit);
 
     } else if (CONFIG_T::io_type == io_serial) {
         #pragma HLS ARRAY_RESHAPE variable=scale complete dim=1
@@ -87,10 +87,10 @@ void normalize(
         }
         
         if (CONFIG_T::n_filt==-1) {
-            res[ires] = CONFIG_T::template product<data_T, typename CONFIG_T::scale_t, res_T>::product(data[ires], scale[ires]) + bias[ires];
+            res[ires] = CONFIG_T::template product<data_T, typename CONFIG_T::scale_t>::product(data[ires], scale[ires]) + bias[ires];
 	    } else {
             int norm_index = ires%CONFIG_T::n_filt;
-            res[ires] = CONFIG_T::template product<data_T, typename CONFIG_T::scale_t, res_T>::product(data[ires], scale[norm_index]) + bias[norm_index];
+            res[ires] = CONFIG_T::template product<data_T, typename CONFIG_T::scale_t>::product(data[ires], scale[norm_index]) + bias[norm_index];
         }
 	}
 }
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h
index 382887fed..826bdafe9 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm_stream.h
@@ -43,7 +43,7 @@ void normalize(
 
     constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor);
     constexpr unsigned ii = CONFIG_T::n_in / multiplier_limit;
-    CONFIG_T::template product<typename data_T::value_type, typename CONFIG_T::scale_t, typename res_T::value_type>::limit(multiplier_limit);
+    CONFIG_T::template product<typename data_T::value_type, typename CONFIG_T::scale_t>::limit(multiplier_limit);
 
     BatchNormLoop: for (int i = 0; i < CONFIG_T::n_in / data_T::size; i++) {
         #pragma HLS PIPELINE II=ii
@@ -60,7 +60,7 @@ void normalize(
             } else {
                 norm_index = j % CONFIG_T::n_filt;
             }
-            out_data[j] = CONFIG_T::template product<typename data_T::value_type, typename CONFIG_T::scale_t, typename res_T::value_type>::product(in_data[j], scale[norm_index]) + bias[norm_index];
+            out_data[j] = CONFIG_T::template product<typename data_T::value_type, typename CONFIG_T::scale_t>::product(in_data[j], scale[norm_index]) + bias[norm_index];
         }
 
         res.write(out_data);
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
index 3a597f038..29523862c 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
@@ -26,6 +26,18 @@ class Product{
     static void limit(unsigned multiplier_limit) {} // Nothing to do here
 };
 
+template<class x_T, class w_T>
+class Product_nocast{
+    public:
+    static auto product(x_T a, w_T w) -> decltype(a*w)
+    {
+        // 'Normal' product
+        #pragma HLS INLINE
+        return a * w;
+    }
+    static void limit(unsigned multiplier_limit) {} // Nothing to do here
+};
+
 template<class x_T, class w_T, class y_T>
 class both_binary : public Product<x_T, w_T, y_T>{
     public:
@@ -82,6 +94,21 @@ class mult : public Product<x_T, w_T, y_T>{
     }
 };
 
+template<class x_T, class w_T>
+class mult_nocast : public Product_nocast<x_T, w_T>{
+    public:
+    static auto product(x_T a, w_T w) -> decltype(a*w)
+    {
+        // 'Normal' product
+        #pragma HLS INLINE
+        return a * w;
+    }
+    static void limit(unsigned multiplier_limit){
+        #pragma HLS INLINE
+        #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation
+    }
+};
+
 template<class x_T, class w_T, class y_T>
 class weight_exponential : public Product<x_T, w_T, y_T>{
     public:
diff --git a/hls4ml/templates/vivado_template.py b/hls4ml/templates/vivado_template.py
index 149b52f1d..5811d7c12 100644
--- a/hls4ml/templates/vivado_template.py
+++ b/hls4ml/templates/vivado_template.py
@@ -31,8 +31,8 @@ batchnorm_config_template = """struct config{index} : nnet::batchnorm_config {{
     static const bool store_weights_in_bram = false;
     typedef {bias_t} bias_t;
     typedef {scale_t} scale_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}_nocast<x_T, y_T>;
 }};\n"""
 
 conv1d_config_template = """struct config{index} : nnet::conv1d_config {{
diff --git a/test/pytest/test_batchnorm.py b/test/pytest/test_batchnorm.py
new file mode 100644
index 000000000..25744e7f6
--- /dev/null
+++ b/test/pytest/test_batchnorm.py
@@ -0,0 +1,44 @@
+import pytest
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import BatchNormalization
+import numpy as np
+import hls4ml
+
+
+in_shape = 16
+atol = 5e-3
+
+@pytest.fixture(scope='module')
+def data():
+    np.random.seed(0)
+    X = np.random.rand(100, in_shape)
+    return X
+
+
+@pytest.fixture(scope='module')
+def model():
+    model = Sequential()
+    model.add(BatchNormalization(input_shape=(in_shape,)))
+    model.compile()
+    return model
+
+
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
+def test_batchnorm(model, data, io_type):
+
+    config = hls4ml.utils.config_from_keras_model(model, 
+                                                  default_precision='ap_fixed<32,1>',
+                                                  granularity='name')
+
+    hls_model = hls4ml.converters.convert_from_keras_model(model,
+                                                           hls_config=config,
+                                                           io_type=io_type,
+                                                           output_dir=f'hls4mlprj_batchnorm_{io_type}',
+                                                           part='xcvu9p-flgb2104-2-i')
+    hls_model.compile()
+
+
+    # Predict
+    y_keras = np.squeeze(model.predict(data))
+    y_hls = hls_model.predict(data)
+    np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True)
-- 
GitLab
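
The numerical point of this first patch: the three-parameter product cast each partial result to res_T before the bias was added, so low-order bits could be truncated term by term, while the two-parameter mult_nocast returns decltype(a*w) and defers quantization to the final assignment. A minimal standalone sketch of the difference, assuming Vivado HLS's ap_fixed.h is on the include path and using purely illustrative widths:

#include "ap_fixed.h"

int main() {
    // ap_fixed<8,3> has 5 fractional bits, but each product below needs 10,
    // so any early cast back to ap_fixed<8,3> throws bits away.
    ap_fixed<8,3> d[4] = {0.53125, 0.59375, 0.65625, 0.71875};
    ap_fixed<8,3> s = 1.03125;

    ap_fixed<8,3>  acc_old = 0;  // old behaviour: cast every term to res_T
    ap_fixed<16,6> acc_new = 0;  // new behaviour: keep the deduced width
    for (int i = 0; i < 4; i++) {
        acc_old += ap_fixed<8,3>(d[i] * s); // truncates each partial term
        acc_new += d[i] * s;                // exact: d[i]*s is ap_fixed<16,6>
    }
    ap_fixed<8,3> acc_new_q = acc_new;      // quantize once, at the end
    // acc_old lands at 2.5 while acc_new_q is 2.5625: each early cast
    // discarded sub-LSB information before the accumulation.
    return 0;
}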


From ffd4bfc060b68fb57a55ee4dc614e9ef6395041f Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Tue, 11 Jan 2022 18:13:39 -0600
Subject: [PATCH 2/9] Add _nocast options for other calculations

---
 .../templates/vivado/nnet_utils/nnet_mult.h   | 127 ++++++++++++++++--
 1 file changed, 118 insertions(+), 9 deletions(-)

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
index 29523862c..a590511cd 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
@@ -5,6 +5,7 @@
 #include "nnet_helpers.h"
 #include "hls_stream.h"
 #include <math.h>
+#include <iostream>
 
 namespace nnet {
 
@@ -12,7 +13,7 @@ namespace product{
 
 /* ---
  * 5 different methods to perform the product of input and weight, depending on the
- * types of each. 
+ * types of each.
  * --- */
 
 template<class x_T, class w_T, class y_T>
@@ -26,15 +27,8 @@ class Product{
     static void limit(unsigned multiplier_limit) {} // Nothing to do here
 };
 
-template<class x_T, class w_T>
 class Product_nocast{
     public:
-    static auto product(x_T a, w_T w) -> decltype(a*w)
-    {
-        // 'Normal' product
-        #pragma HLS INLINE
-        return a * w;
-    }
     static void limit(unsigned multiplier_limit) {} // Nothing to do here
 };
 
@@ -48,6 +42,16 @@ class both_binary : public Product<x_T, w_T, y_T>{
     }
 };
 
+template<class x_T, class w_T>
+class both_binary_nocast : public Product_nocast{
+    public:
+    static x_T product(x_T a, w_T w){
+        // specialisation for 1-bit weights and incoming data
+        #pragma HLS INLINE
+        return a == w;
+    }
+};
+
 template<class x_T, class w_T, class y_T>
 class weight_binary : public Product<x_T, w_T, y_T>{
     public:
@@ -58,6 +62,16 @@ class weight_binary : public Product<x_T, w_T, y_T>{
     }
 };
 
+template<class x_T, class w_T>
+class weight_binary_nocast : public Product_nocast{
+    public:
+    static x_T product(x_T a, w_T w){
+        // Specialisation for 1-bit weights, arbitrary data
+        #pragma HLS INLINE
+        return w == 0 ? (x_T) -a : a;
+    }
+};
+
 template<class x_T, class w_T, class y_T>
 class data_binary : public Product<x_T, w_T, y_T>{
     public:
@@ -68,6 +82,16 @@ class data_binary : public Product<x_T, w_T, y_T>{
     }
 };
 
+template<class x_T, class w_T>
+class data_binary_nocast : public Product_nocast{
+    public:
+    static w_T product(x_T a, w_T w){
+        // Specialisation for 1-bit data, arbitrary weight
+        #pragma HLS INLINE
+        return a == 0 ? (w_T) -w : w;
+    }
+};
+
 template<class x_T, class w_T, class y_T>
 class weight_ternary : public Product<x_T, w_T, y_T>{
     public:
@@ -80,6 +104,18 @@ class weight_ternary : public Product<x_T, w_T, y_T>{
     }
 };
 
+template<class x_T, class w_T>
+class weight_ternary_nocast : public Product_nocast{
+    public:
+    static x_T product(x_T a, w_T w){
+        // Specialisation for 2-bit weights, arbitrary data
+        #pragma HLS INLINE
+        if (w == 0) return (x_T) 0;
+        else if(w == -1) return (x_T) -a;
+        else return (x_T) a; // if(w == 1)
+    }
+};
+
 template<class x_T, class w_T, class y_T>
 class mult : public Product<x_T, w_T, y_T>{
     public:
@@ -95,7 +131,7 @@ class mult : public Product<x_T, w_T, y_T>{
 };
 
 template<class x_T, class w_T>
-class mult_nocast : public Product_nocast<x_T, w_T>{
+class mult_nocast : public Product_nocast{
     public:
     static auto product(x_T a, w_T w) -> decltype(a*w)
     {
@@ -122,6 +158,79 @@ class weight_exponential : public Product<x_T, w_T, y_T>{
     }
 };
 
+template<class x_T, class w_T>
+class weight_exponential_nocast : public Product_nocast{
+    public:
+    using rt = x_T;
+    static rt product(x_T a, w_T w){
+        std::cerr << "Should not match to this function" << std::endl;
+        // Shift product for exponential weights
+        #pragma HLS INLINE
+        // shift by the exponent. Negative weights shift right
+        rt y = static_cast<rt>(a) << w.weight;
+        // negate or not depending on weight sign
+        return w.sign == 1 ? y : static_cast<rt>(-y);
+    }
+};
+
+template<class w_T, int _AP_W>
+class weight_exponential_nocast<ap_int<_AP_W>, w_T> : public Product_nocast{
+    public:
+    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_W + decltype(w_T::weight)::width>;
+    static rt product(ap_int<_AP_W> a, w_T w){
+        // Shift product for exponential weights
+        #pragma HLS INLINE
+        // shift by the exponent. Negative weights shift right
+        rt y = static_cast<rt>(a) << w.weight;
+        // negate or not depending on weight sign
+        return w.sign == 1 ? y : static_cast<rt>(-y);
+    }
+};
+
+template<class w_T, int _AP_W>
+class weight_exponential_nocast<ap_uint<_AP_W>, w_T> : public Product_nocast{
+    public:
+    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_W + decltype(w_T::weight)::width + 1>;
+    static rt product(ap_uint<_AP_W> a, w_T w){
+        // Shift product for exponential weights
+        #pragma HLS INLINE
+        // shift by the exponent. Negative weights shift right
+        rt y = static_cast<rt>(a) << w.weight;
+        // negate or not depending on weight sign
+        return w.sign == 1 ? y : static_cast<rt>(-y);
+    }
+};
+
+template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
+class weight_exponential_nocast<ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product_nocast{
+    public:
+    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_I + decltype(w_T::weight)::width,
+                        _AP_Q, _AP_O, _AP_N>;
+    static rt product(ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N> a, w_T w){
+        // Shift product for exponential weights
+        #pragma HLS INLINE
+        // shift by the exponent. Negative weights shift right
+        rt y = static_cast<rt>(a) << w.weight;
+        // negate or not depending on weight sign
+        return w.sign == 1 ? y : static_cast<rt>(-y);
+    }
+};
+
+template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
+class weight_exponential_nocast<ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product_nocast{
+    public:
+    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_I + decltype(w_T::weight)::width + 1,
+                        _AP_Q, _AP_O, _AP_N>;
+    static rt product(ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N> a, w_T w){
+        // Shift product for exponential weights
+        #pragma HLS INLINE
+        // shift by the exponent. Negative weights shift right
+        rt y = static_cast<rt>(a) << w.weight;
+        // negate or not depending on weight sign
+        return w.sign == 1 ? y : static_cast<rt>(-y);
+    }
+};
+
 } // namespace product_type
 
 template<class data_T, class res_T, typename CONFIG_T>
-- 
GitLab
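
The _nocast exponential products deserve a note: a power-of-two weight is stored as a sign bit plus a small signed exponent, multiplication becomes a shift, and the return type is padded on both sides so the modelled shifts cannot drop bits. A hedged sketch of the idea, with a stand-in weight struct (the real exponent weight type is defined elsewhere in the hls4ml templates) and a fixed 4-bit exponent for concreteness:

#include "ap_fixed.h"

// Stand-in exponent weight: the represented value is (+/-) 2^weight.
struct exponent_t {
    ap_uint<1> sign;    // 1 -> positive weight, 0 -> negative
    ap_int<4>  weight;  // signed exponent; shifts span [-8, 7]
};

// Pad by the largest modelled shift in each direction; Vivado's ap_fixed
// shifts right when the shift amount is negative, as the patch relies on.
template<int W, int I>
ap_fixed<W + 15, I + 7> shift_product(ap_fixed<W, I> a, exponent_t w) {
    typedef ap_fixed<W + 15, I + 7> r_T;
    r_T y = r_T(a) << w.weight;
    return w.sign == 1 ? y : r_T(-y);
}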


From c2d15c6335c8e9d47d4cdc0d47ede856c4372593 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Mon, 17 Jan 2022 15:31:44 -0600
Subject: [PATCH 3/9] Make all the nnet::product not cast the results

---
 .../vivado/nnet_utils/nnet_batchnorm.h        |   2 +-
 .../templates/vivado/nnet_utils/nnet_dense.h  |   4 +-
 .../vivado/nnet_utils/nnet_dense_compressed.h |   2 +-
 .../vivado/nnet_utils/nnet_dense_latency.h    |   6 +-
 .../vivado/nnet_utils/nnet_dense_resource.h   |   6 +-
 .../templates/vivado/nnet_utils/nnet_merge.h  |   8 +-
 .../templates/vivado/nnet_utils/nnet_mult.h   | 102 ++----------------
 hls4ml/templates/vivado_template.py           |  14 +--
 8 files changed, 32 insertions(+), 112 deletions(-)

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
index ce1cfb315..edc6ff320 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_batchnorm.h
@@ -44,7 +44,7 @@ struct batchnorm_config
     static const unsigned n_zeros = 0;
     // partitioning arrays cyclically to go with roll factors?
     template<class x_T, class y_T>
-    using product = nnet::product::mult_nocast<x_T, y_T>;
+    using product = nnet::product::mult<x_T, y_T>;
 };
 
 template<class data_T, class res_T, typename CONFIG_T>
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h
index deb1c042d..c9785335a 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_dense.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense.h
@@ -30,8 +30,8 @@ struct dense_config
     static const unsigned n_zeros = 0;
     // partitioning arrays cyclically to go with roll factors?
     // Product function to use
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::mult<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::mult<x_T, y_T>;
 };
 
 template<class data_T, class res_T, typename CONFIG_T>
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h
index adfaa0e1b..dc803ff2b 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_compressed.h
@@ -86,7 +86,7 @@ void dense_compressed(
             auto weight_cache = weights[w].weight;
             data_T  data_cache = data[row];
             //mult[col] += weight_cache * data_cache;
-            typename CONFIG_T::accum_t prod = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::product(data_cache, weight_cache);
+            typename CONFIG_T::accum_t prod = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data_cache, weight_cache);
             fill_mult<CONFIG_T>(col, mult, prod);
         }
 
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h
index 4a04671fd..2bbab0496 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_latency.h
@@ -54,7 +54,7 @@ void dense_latency(
         #pragma HLS ARRAY_PARTITION variable=acc complete
 
         int multiplier_limit  = ceil(float(CONFIG_T::n_in*CONFIG_T::n_out) / float(CONFIG_T::reuse_factor)) - floor(float(CONFIG_T::n_zeros) / float(CONFIG_T::reuse_factor));
-        CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::limit(multiplier_limit);
+        CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::limit(multiplier_limit);
 
     } else if (CONFIG_T::io_type == io_serial){
         // Only reduce cycle_factor if n_out is evenly divisible by reuse_factor
@@ -90,10 +90,10 @@ void dense_latency(
         Product2: for(int jj = 0; jj < CONFIG_T::n_out; jj++) {
             if (CONFIG_T::io_type == io_serial) {
                 int multiplier_limit  = ceil(float(CONFIG_T::n_out) / float(CONFIG_T::reuse_factor));
-                CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::limit(multiplier_limit);
+                CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::limit(multiplier_limit);
             }
         int index = ii*CONFIG_T::n_out+jj;
-        mult[index] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::product(cache, weights[index]);
+        mult[index] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(cache, weights[index]);
         }
     }
 
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
index 756a62743..c77059aa7 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
@@ -73,7 +73,7 @@ void dense_resource_rf_leq_nin(
         for (int im = 0; im < block_factor; im++) {
             #pragma HLS UNROLL
 
-            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::product(data[in_index], weights[w_index]);
+            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]);
 
             // Increment w_index
             w_index += rufactor;
@@ -157,7 +157,7 @@ void dense_resource_rf_gt_nin_rem0(
         MultLoop:
         for (int im = 0; im < block_factor; im++) {
             #pragma HLS UNROLL
-            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::product(data[in_index], weights[w_index]);
+            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]);
 
             w_index += rufactor;
             if (w_index >= CONFIG_T::n_in * CONFIG_T::n_out) break; // check out of bounds
@@ -223,7 +223,7 @@ void dense_resource_rf_gt_nin(
             int w_index = ir + rufactor * im;
             int in_index = w_index % nin;
             if (w_index >= CONFIG_T::n_in*CONFIG_T::n_out) continue; // check out of bounds
-            tmpmult[im] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>::product(data[in_index], weights[w_index]);
+            tmpmult[im] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]);
         }
 
         typename CONFIG_T::accum_t mult[multiplier_limit];
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_merge.h b/hls4ml/templates/vivado/nnet_utils/nnet_merge.h
index 48a5e172d..b103533e0 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_merge.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_merge.h
@@ -38,8 +38,8 @@ struct dot_config {
     static const unsigned reuse_factor = 1;
     typedef float accum_t;
     // Product function to use
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::mult<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::mult<x_T, y_T>;
 };
 
 struct concat_config {
@@ -129,7 +129,7 @@ void dot1d(
     #pragma HLS PIPELINE II=CONFIG_T::reuse_factor
 
     constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor);
-    CONFIG_T::template product<input1_T, input2_T, typename CONFIG_T::accum_t>::limit(multiplier_limit);
+    CONFIG_T::template product<input1_T, input2_T>::limit(multiplier_limit);
 
     typename CONFIG_T::accum_t mult[CONFIG_T::n_in];
     #pragma HLS ARRAY_PARTITION variable=mult complete
@@ -137,7 +137,7 @@ void dot1d(
 
     Product: for(int i_mult=0; i_mult < CONFIG_T::n_in; i_mult++) {
         #pragma HLS UNROLL
-        mult[i_mult] = CONFIG_T::template product<input1_T, input2_T, typename CONFIG_T::accum_t>::product(data1[i_mult], data2[i_mult]);
+        mult[i_mult] = CONFIG_T::template product<input1_T, input2_T>::product(data1[i_mult], data2[i_mult]);
     }
 
     Accum: for(int i_acc = 0; i_acc < CONFIG_T::n_in; i_acc++) {
diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
index a590511cd..5eb7103a7 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
@@ -12,38 +12,17 @@ namespace nnet {
 namespace product{
 
 /* ---
- * 5 different methods to perform the product of input and weight, depending on the
+ * different methods to perform the product of input and weight, depending on the
  * types of each.
  * --- */
 
-template<class x_T, class w_T, class y_T>
 class Product{
-    public:
-    static y_T product(x_T a, w_T w){
-        // 'Normal' product
-        #pragma HLS INLINE
-        return a * w;
-    }
-    static void limit(unsigned multiplier_limit) {} // Nothing to do here
-};
-
-class Product_nocast{
     public:
     static void limit(unsigned multiplier_limit) {} // Nothing to do here
 };
 
-template<class x_T, class w_T, class y_T>
-class both_binary : public Product<x_T, w_T, y_T>{
-    public:
-    static y_T product(x_T a, w_T w){
-        // specialisation for 1-bit weights and incoming data
-        #pragma HLS INLINE
-        return a == w;
-    }
-};
-
 template<class x_T, class w_T>
-class both_binary_nocast : public Product_nocast{
+class both_binary : public Product{
     public:
     static x_T product(x_T a, w_T w){
         // specialisation for 1-bit weights and incoming data
@@ -52,18 +31,8 @@ class both_binary_nocast : public Product_nocast{
     }
 };
 
-template<class x_T, class w_T, class y_T>
-class weight_binary : public Product<x_T, w_T, y_T>{
-    public:
-    static y_T product(x_T a, w_T w){
-        // Specialisation for 1-bit weights, arbitrary data
-        #pragma HLS INLINE
-        return w == 0 ? (x_T) -a : a;
-    }
-};
-
 template<class x_T, class w_T>
-class weight_binary_nocast : public Product_nocast{
+class weight_binary : public Product{
     public:
     static x_T product(x_T a, w_T w){
         // Specialisation for 1-bit weights, arbitrary data
@@ -72,18 +41,8 @@ class weight_binary_nocast : public Product_nocast{
     }
 };
 
-template<class x_T, class w_T, class y_T>
-class data_binary : public Product<x_T, w_T, y_T>{
-    public:
-    static y_T product(x_T a, w_T w){
-        // Specialisation for 1-bit data, arbitrary weight
-        #pragma HLS INLINE
-        return a == 0 ? (w_T) -w : w;
-    }
-};
-
 template<class x_T, class w_T>
-class data_binary_nocast : public Product_nocast{
+class data_binary : public Product{
     public:
     static w_T product(x_T a, w_T w){
         // Specialisation for 1-bit data, arbitrary weight
@@ -92,20 +51,8 @@ class data_binary_nocast : public Product_nocast{
     }
 };
 
-template<class x_T, class w_T, class y_T>
-class weight_ternary : public Product<x_T, w_T, y_T>{
-    public:
-    static y_T product(x_T a, w_T w){
-        // Specialisation for 2-bit weights, arbitrary data
-        #pragma HLS INLINE
-        if (w == 0) return (x_T) 0;
-        else if(w == -1) return (x_T) -a;
-        else return (x_T) a; // if(w == 1)
-    }
-};
-
 template<class x_T, class w_T>
-class weight_ternary_nocast : public Product_nocast{
+class weight_ternary : public Product{
     public:
     static x_T product(x_T a, w_T w){
         // Specialisation for 2-bit weights, arbitrary data
@@ -116,22 +63,8 @@ class weight_ternary_nocast : public Product_nocast{
     }
 };
 
-template<class x_T, class w_T, class y_T>
-class mult : public Product<x_T, w_T, y_T>{
-    public:
-    static y_T product(x_T a, w_T w){
-        // 'Normal' product
-        #pragma HLS INLINE
-        return a * w;
-    }
-    static void limit(unsigned multiplier_limit){
-        #pragma HLS INLINE
-        #pragma HLS ALLOCATION instances=mul limit=multiplier_limit operation
-    }
-};
-
 template<class x_T, class w_T>
-class mult_nocast : public Product_nocast{
+class mult : public Product{
     public:
     static auto product(x_T a, w_T w) -> decltype(a*w)
     {
@@ -145,21 +78,8 @@ class mult_nocast : public Product_nocast{
     }
 };
 
-template<class x_T, class w_T, class y_T>
-class weight_exponential : public Product<x_T, w_T, y_T>{
-    public:
-    static y_T product(x_T a, w_T w){
-        // Shift product for exponential weights
-        #pragma HLS INLINE
-        // shift by the exponent. Negative weights shift right
-        y_T y = a << w.weight;
-        // negate or not depending on weight sign
-        return w.sign == 1 ? (y_T) y : (y_T) -y;
-    }
-};
-
 template<class x_T, class w_T>
-class weight_exponential_nocast : public Product_nocast{
+class weight_exponential : public Product{
     public:
     using rt = x_T;
     static rt product(x_T a, w_T w){
@@ -174,7 +94,7 @@ class weight_exponential_nocast : public Product_nocast{
 };
 
 template<class w_T, int _AP_W>
-class weight_exponential_nocast<ap_int<_AP_W>, w_T> : public Product_nocast{
+class weight_exponential<ap_int<_AP_W>, w_T> : public Product{
     public:
     using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_W + decltype(w_T::weight)::width>;
     static rt product(ap_int<_AP_W> a, w_T w){
@@ -188,7 +108,7 @@ class weight_exponential_nocast<ap_int<_AP_W>, w_T> : public Product_nocast{
 };
 
 template<class w_T, int _AP_W>
-class weight_exponential_nocast<ap_uint<_AP_W>, w_T> : public Product_nocast{
+class weight_exponential<ap_uint<_AP_W>, w_T> : public Product{
     public:
     using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_W + decltype(w_T::weight)::width + 1>;
     static rt product(ap_uint<_AP_W> a, w_T w){
@@ -202,7 +122,7 @@ class weight_exponential_nocast<ap_uint<_AP_W>, w_T> : public Product_nocast{
 };
 
 template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
-class weight_exponential_nocast<ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product_nocast{
+class weight_exponential<ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product{
     public:
     using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_I + decltype(w_T::weight)::width,
                         _AP_Q, _AP_O, _AP_N>;
@@ -217,7 +137,7 @@ class weight_exponential_nocast<ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T>
 };
 
 template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
-class weight_exponential_nocast<ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product_nocast{
+class weight_exponential<ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product{
     public:
     using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_I + decltype(w_T::weight)::width + 1,
                         _AP_Q, _AP_O, _AP_N>;
diff --git a/hls4ml/templates/vivado_template.py b/hls4ml/templates/vivado_template.py
index 5811d7c12..5040fda0e 100644
--- a/hls4ml/templates/vivado_template.py
+++ b/hls4ml/templates/vivado_template.py
@@ -19,8 +19,8 @@ dense_config_template = """struct config{index} : nnet::dense_config {{
     typedef {bias_t} bias_t;
     typedef {weight_t} weight_t;
     typedef {index_t} index_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""
 
 batchnorm_config_template = """struct config{index} : nnet::batchnorm_config {{
@@ -32,7 +32,7 @@ batchnorm_config_template = """struct config{index} : nnet::batchnorm_config {{
     typedef {bias_t} bias_t;
     typedef {scale_t} scale_t;
     template<class x_T, class y_T>
-    using product = nnet::product::{product_type}_nocast<x_T, y_T>;
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""
 
 conv1d_config_template = """struct config{index} : nnet::conv1d_config {{
@@ -68,8 +68,8 @@ conv_mult_config_template = """struct config{index}_mult : nnet::dense_config {{
     typedef {accum_t} accum_t;
     typedef {bias_t} bias_t;
     typedef {weight_t} weight_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""
 
 conv2d_config_template = """struct config{index} : nnet::conv2d_config {{
@@ -214,8 +214,8 @@ dot_config_template = """struct config{index} : nnet::dot_config {{
     static const unsigned n_out = {n_out};
     static const unsigned reuse_factor = {reuse};
     typedef {accum_t} accum_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""
 
 concat_config_template = """struct config{index} : nnet::concat_config {{
-- 
GitLab
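
After this merge every product's return type comes from its operands alone. For the plain multiplier that is decltype(a*w), and under Vivado's ap_fixed promotion rules the total widths and integer bits of the operands add, so the result is exact; a small sketch with illustrative types:

#include "ap_fixed.h"

template<class x_T, class w_T>
auto product(x_T a, w_T w) -> decltype(a * w) {
    return a * w;  // full-width result, no narrowing to a res_T
}

int main() {
    ap_fixed<8,3> a = 2.25;
    ap_fixed<8,2> w = 1.75;
    // decltype(a*w) is ap_fixed<16,5>: 8+8 total bits, 3+2 integer bits,
    // so the exact product 3.9375 reaches the caller untouched.
    ap_fixed<16,5> p = product(a, w);
    (void)p;
    return 0;
}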


From ed94850c0cf7a8f596c5bb3aa3551291cf062673 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Tue, 18 Jan 2022 11:06:45 -0600
Subject: [PATCH 4/9] update product based on Sioni's suggestions

---
 .../templates/vivado/nnet_utils/nnet_mult.h   | 91 ++++---------------
 1 file changed, 20 insertions(+), 71 deletions(-)

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
index 5eb7103a7..586bc65ae 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_mult.h
@@ -34,32 +34,37 @@ class both_binary : public Product{
 template<class x_T, class w_T>
 class weight_binary : public Product{
     public:
-    static x_T product(x_T a, w_T w){
+    static auto product(x_T a, w_T w) -> decltype(-a)
+    {
         // Specialisation for 1-bit weights, arbitrary data
         #pragma HLS INLINE
-        return w == 0 ? (x_T) -a : a;
+        if (w == 0) return -a;
+        else return a;
     }
 };
 
 template<class x_T, class w_T>
 class data_binary : public Product{
     public:
-    static w_T product(x_T a, w_T w){
+    static auto product(x_T a, w_T w) -> decltype(-w)
+    {
         // Specialisation for 1-bit data, arbitrary weight
         #pragma HLS INLINE
-        return a == 0 ? (w_T) -w : w;
+        if (a == 0) return -w;
+        else return w;
     }
 };
 
 template<class x_T, class w_T>
 class weight_ternary : public Product{
     public:
-    static x_T product(x_T a, w_T w){
+    static auto product(x_T a, w_T w) -> decltype(-a)
+    {
         // Specialisation for 2-bit weights, arbitrary data
         #pragma HLS INLINE
-        if (w == 0) return (x_T) 0;
-        else if(w == -1) return (x_T) -a;
-        else return (x_T) a; // if(w == 1)
+        if (w == 0) return 0;
+        else if(w == -1) return -a;
+        else return a; // if(w == 1)
     }
 };
 
@@ -81,73 +86,17 @@ class mult : public Product{
 template<class x_T, class w_T>
 class weight_exponential : public Product{
     public:
-    using rt = x_T;
-    static rt product(x_T a, w_T w){
-        std::cerr << "Should not match to this function" << std::endl;
-        // Shift product for exponential weights
-        #pragma HLS INLINE
-        // shift by the exponent. Negative weights shift right
-        rt y = static_cast<rt>(a) << w.weight;
-        // negate or not depending on weight sign
-        return w.sign == 1 ? y : static_cast<rt>(-y);
-    }
-};
-
-template<class w_T, int _AP_W>
-class weight_exponential<ap_int<_AP_W>, w_T> : public Product{
-    public:
-    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_W + decltype(w_T::weight)::width>;
-    static rt product(ap_int<_AP_W> a, w_T w){
-        // Shift product for exponential weights
-        #pragma HLS INLINE
-        // shift by the exponent. Negative weights shift right
-        rt y = static_cast<rt>(a) << w.weight;
-        // negate or not depending on weight sign
-        return w.sign == 1 ? y : static_cast<rt>(-y);
-    }
-};
-
-template<class w_T, int _AP_W>
-class weight_exponential<ap_uint<_AP_W>, w_T> : public Product{
-    public:
-    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_W + decltype(w_T::weight)::width + 1>;
-    static rt product(ap_uint<_AP_W> a, w_T w){
-        // Shift product for exponential weights
-        #pragma HLS INLINE
-        // shift by the exponent. Negative weights shift right
-        rt y = static_cast<rt>(a) << w.weight;
-        // negate or not depending on weight sign
-        return w.sign == 1 ? y : static_cast<rt>(-y);
-    }
-};
-
-template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
-class weight_exponential<ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product{
-    public:
-    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width, _AP_I + decltype(w_T::weight)::width,
-                        _AP_Q, _AP_O, _AP_N>;
-    static rt product(ap_fixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N> a, w_T w){
+    // Construct the return type from the multiplication equivalent to the largest shifts
+    // ap_int<pow2(decltype(w_T::weight)::width-1)-1> is the type of the multiplicand equivalent to the largest lshift <<
+    // ap_fixed<pow2(decltype(w_T::weight)::width-1)-1,0> is the type of the multiplicand equivalent to the largest rshift >>
+    using r_T = decltype(x_T(0) * (ap_int<pow2(decltype(w_T::weight)::width-1)-1>(1)+ap_fixed<pow2(decltype(w_T::weight)::width-1)-1,0>(1)));
+    static r_T product(x_T a, w_T w){
         // Shift product for exponential weights
         #pragma HLS INLINE
         // shift by the exponent. Negative weights shift right
-        rt y = static_cast<rt>(a) << w.weight;
-        // negate or not depending on weight sign
-        return w.sign == 1 ? y : static_cast<rt>(-y);
-    }
-};
-
-template<class w_T, int _AP_W, int _AP_I, ap_q_mode _AP_Q, ap_o_mode _AP_O, int _AP_N>
-class weight_exponential<ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N>, w_T> : public Product{
-    public:
-    using rt = ap_fixed<_AP_W + 2*decltype(w_T::weight)::width + 1, _AP_I + decltype(w_T::weight)::width + 1,
-                        _AP_Q, _AP_O, _AP_N>;
-    static rt product(ap_ufixed<_AP_W,_AP_I,_AP_Q, _AP_O, _AP_N> a, w_T w){
-        // Shift product for exponential weights
-        #pragma HLS INLINE
-        // shift by the exponent. Negative weights shift right
-        rt y = static_cast<rt>(a) << w.weight;
+        r_T y = static_cast<r_T>(a) << w.weight;
         // negate or not depending on weight sign
-        return w.sign == 1 ? y : static_cast<rt>(-y);
+        return w.sign == 1 ? y : static_cast<r_T>(-y);
     }
 };
 
-- 
GitLab
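
The r_T construction in this patch replaces five hand-written specialisations with type arithmetic: the widest modelled left shift becomes an ap_int whose width equals that shift, the widest modelled right shift an all-fractional ap_fixed, and multiplying x_T by their sum deduces a type wide enough for those shifts. A worked sketch, assuming pow2 is the constexpr power-of-two helper from hls4ml's common headers (operand values below are irrelevant; as in the decltype, only the deduced types matter):

#include "ap_fixed.h"

// Assumed equivalent of hls4ml's pow2 helper.
constexpr int pow2(int x) { return x == 0 ? 1 : 2 * pow2(x - 1); }

int main() {
    // For a 4-bit signed exponent field, pow2(4-1)-1 = 7:
    ap_int<pow2(4 - 1) - 1>      lshift_model(1); // ap_int<7>, models << 7
    ap_fixed<pow2(4 - 1) - 1, 0> rshift_model(1); // ap_fixed<7,0>, models >> 7
    // Their sum promotes to roughly ap_fixed<15,8>; multiplying an
    // ap_fixed<16,6> input by it then deduces about ap_fixed<31,14>.
    ap_fixed<16, 6> x(0);
    auto r = x * (lshift_model + rshift_model);
    (void)r;
    return 0;
}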


From 810618d90330c56e9b41ad7b726800b4d150cc00 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Mon, 24 Jan 2022 15:36:20 -0600
Subject: [PATCH 5/9] break up accumulations in nnet_dense_resource with a cast

---
 hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
index c77059aa7..c0e5d1759 100644
--- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
+++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h
@@ -73,7 +73,8 @@ void dense_resource_rf_leq_nin(
         for (int im = 0; im < block_factor; im++) {
             #pragma HLS UNROLL
 
-            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]);
+            acc[out_index] += static_cast<typename CONFIG_T::accum_t>(
+              CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]));
 
             // Increment w_index
             w_index += rufactor;
@@ -157,7 +158,8 @@ void dense_resource_rf_gt_nin_rem0(
         MultLoop:
         for (int im = 0; im < block_factor; im++) {
             #pragma HLS UNROLL
-            acc[out_index] += CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]);
+            acc[out_index] += static_cast<typename CONFIG_T::accum_t>(
+              CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[w_index]));
 
             w_index += rufactor;
             if (w_index >= CONFIG_T::n_in * CONFIG_T::n_out) break; // check out of bounds
-- 
GitLab
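
The explicit static_cast here quantizes each partial product to accum_t before it enters the running sum, instead of letting the addition happen at the product's full deduced width and truncating afterwards; that keeps the unrolled adder chain at the accumulator precision the user configured. A minimal sketch of the pattern, with hypothetical types:

#include "ap_fixed.h"

typedef ap_fixed<16, 6> accum_t;  // hypothetical accumulator precision

template<class data_T, class weight_T>
accum_t mac(accum_t acc, data_T d, weight_T w) {
    // d*w arrives at its full deduced width; the cast quantizes it to
    // accum_t before the add, mirroring the acc[out_index] += ... above.
    return acc + static_cast<accum_t>(d * w);
}

int main() {
    accum_t acc = 0;
    acc = mac(acc, ap_fixed<8,3>(1.03125), ap_fixed<8,3>(0.90625));
    return 0;
}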


From 5456eb758e692e1116091d972c6ff19a9665b54b Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Mon, 24 Jan 2022 15:58:27 -0600
Subject: [PATCH 6/9] add auto_po2 qkeras test that revealed problem earlier

---
 test/pytest/test_qkeras.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/pytest/test_qkeras.py b/test/pytest/test_qkeras.py
index cdfc312a8..be061f1ea 100644
--- a/test/pytest/test_qkeras.py
+++ b/test/pytest/test_qkeras.py
@@ -106,8 +106,8 @@ def randX_100_16():
 # Note 4-bit test can still fail sometimes depending on random seed
 # https://github.com/fastmachinelearning/hls4ml/issues/381
 #@pytest.mark.parametrize('bits', [4, 6, 8])
-@pytest.mark.parametrize('bits', [4])
-def test_single_dense_activation_exact(randX_100_16, bits):
+@pytest.mark.parametrize('bits,alpha', [(4, 1), (4, 'auto_po2')])
+def test_single_dense_activation_exact(randX_100_16, bits, alpha):
   '''
   Test a single Dense -> Activation layer topology for
   bit exactness with number of bits parameter
@@ -115,7 +115,7 @@ def test_single_dense_activation_exact(randX_100_16, bits):
   X = randX_100_16
   model = Sequential()
   model.add(QDense(16, input_shape=(16,), name='fc1',
-                  kernel_quantizer=quantized_bits(bits,0,alpha=1), bias_quantizer=quantized_bits(bits,0,alpha=1),
+                  kernel_quantizer=quantized_bits(bits,0,alpha=alpha), bias_quantizer=quantized_bits(bits,0,alpha=1),
                   kernel_initializer='lecun_uniform'))
   model.add(QActivation(activation=quantized_relu(bits,0), name='relu1'))
   model.compile()
-- 
GitLab


From 14120601d21ff472fec1d16f82e948d98446a0ff Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Thu, 10 Feb 2022 17:18:40 -0600
Subject: [PATCH 7/9] update vivado passes

---
 hls4ml/backends/vivado/passes/convolution_templates.py | 4 ++--
 hls4ml/backends/vivado/passes/core_templates.py        | 8 ++++----
 hls4ml/backends/vivado/passes/merge_templates.py       | 4 ++--
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py
index c3e35bf1d..22aa5837e 100644
--- a/hls4ml/backends/vivado/passes/convolution_templates.py
+++ b/hls4ml/backends/vivado/passes/convolution_templates.py
@@ -13,8 +13,8 @@ conv_mult_config_template = """struct config{index}_mult : nnet::dense_config {{
     typedef {accum_t.name} accum_t;
     typedef {bias_t.name} bias_t;
     typedef {weight_t.name} weight_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""
 
 # Conv1D templates
diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py
index 686956735..201562f7f 100644
--- a/hls4ml/backends/vivado/passes/core_templates.py
+++ b/hls4ml/backends/vivado/passes/core_templates.py
@@ -18,8 +18,8 @@ dense_config_template = """struct config{index} : nnet::dense_config {{
     typedef {bias_t.name} bias_t;
     typedef {weight_t.name} weight_t;
     typedef {index_t.name} index_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""
 
 dense_function_template = 'nnet::dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
@@ -62,8 +62,8 @@ batchnorm_config_template = """struct config{index} : nnet::batchnorm_config {{
     static const bool store_weights_in_bram = false;
     typedef {bias_t.name} bias_t;
     typedef {scale_t.name} scale_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""
 
 batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'
diff --git a/hls4ml/backends/vivado/passes/merge_templates.py b/hls4ml/backends/vivado/passes/merge_templates.py
index 219a7f4e2..863512c4c 100644
--- a/hls4ml/backends/vivado/passes/merge_templates.py
+++ b/hls4ml/backends/vivado/passes/merge_templates.py
@@ -50,8 +50,8 @@ dot_config_template = """struct config{index} : nnet::dot_config {{
     static const unsigned n_out = {n_out};
     static const unsigned reuse_factor = {reuse};
     typedef {accum_t.name} accum_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""
 
 class DotConfigTemplate(LayerConfigTemplate):
-- 
GitLab
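
For reference, this is roughly what the updated core pass now renders for a BatchNormalization layer (the index and precisions are made up, and the struct only compiles inside a generated project where the nnet_utils headers are on the include path):

#include "nnet_utils/nnet_batchnorm.h"

struct config4 : nnet::batchnorm_config {
    static const unsigned n_in = 16;
    static const unsigned n_filt = -1;
    static const unsigned io_type = nnet::io_parallel;
    static const unsigned reuse_factor = 1;
    static const bool store_weights_in_bram = false;
    typedef ap_fixed<16,6> bias_t;
    typedef ap_fixed<16,6> scale_t;
    template<class x_T, class y_T>
    using product = nnet::product::mult<x_T, y_T>;
};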


From cb210210a9d28329c83f14260c6c1f633235542d Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Mon, 21 Feb 2022 16:23:02 -0600
Subject: [PATCH 8/9] update product on quartus side

---
 .../backends/quartus/passes/core_templates.py |  2 -
 .../quartus/firmware/nnet_utils/nnet_dense.h  | 44 +++++++++----------
 2 files changed, 20 insertions(+), 26 deletions(-)

diff --git a/hls4ml/backends/quartus/passes/core_templates.py b/hls4ml/backends/quartus/passes/core_templates.py
index edbfcc56b..63c3693b0 100644
--- a/hls4ml/backends/quartus/passes/core_templates.py
+++ b/hls4ml/backends/quartus/passes/core_templates.py
@@ -71,8 +71,6 @@ batchnorm_config_template = """struct config{index} : nnet::batchnorm_config {{
     static const bool store_weights_in_bram = false;
     typedef {bias_t.name} bias_t;
     typedef {scale_t.name} scale_t;
-    template<class x_T, class y_T, class res_T>
-    using product = nnet::product::{product_type}<x_T, y_T, res_T>;
 }};\n"""
 
 batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'
diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h
index 18645042f..8075a9671 100644
--- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h
+++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_dense.h
@@ -51,38 +51,34 @@ struct dense_config
    // partitioning arrays cyclically to go with roll factors?
 };
 
-template<class data_T, class weight_T, class ret_T>
-inline typename std::enable_if<std::is_same<data_T, ac_int<1, false>>::value
-        and std::is_same<weight_T, ac_int<1, false>>::value, ac_int<1, false>>::type
-product(ac_int<1, false> a, ac_int<1, false> w){
+inline ac_int<1, false> product(ac_int<1, false> a, ac_int<1, false> w)
+{
     // specialisation for 1-bit weights and incoming data
-    return (ret_T) (a == w);
+    return (a == w);
 }
 
-template<class data_T, class weight_T, class ret_T>
-inline typename std::enable_if<(not std::is_same<data_T, ac_int<1, false>>::value)
-        and std::is_same<weight_T, ac_int<1, false>>::value, ret_T>::type
-product(data_T a, ac_int<1, false> w){
+template<class data_T>
+auto product(data_T a, ac_int<1, false> w) -> decltype(-a)
+{
     // Specialisation for 1-bit weights, arbitrary data
-    return w == 0 ? (ret_T) -a : a;
+    if (w == 0) return -a;
+    else return a;
 }
 
-template<class data_T, class weight_T, class ret_T>
-inline typename std::enable_if<(not std::is_same<data_T, ac_int<2, false>>::value)
-        and std::is_same<weight_T, ac_int<2, true>>::value, ret_T>::type
-product(data_T a, ac_int<2, true> w){
+template<class data_T>
+auto product(data_T a, ac_int<2, true> w) -> decltype(-a)
+{
     // Specialisation for 2-bit weights, arbitrary data
-    if (w == 0) return (ret_T) 0;
-    else if(w == -1) return (ret_T) -a;
-    else return (ret_T) a; // if(w == 1)
+    if (w == 0) return 0;
+    else if(w == -1) return -a;
+    else return a; // if(w == 1)
 }
 
-template<class data_T, class weight_T, class ret_T>
-inline typename std::enable_if<(not std::is_same<data_T, ac_int<1, false>>::value)
-        and (not std::is_same<weight_T, ac_int<1, false>>::value), ret_T>::type
-product(data_T a, weight_T w){
+template<class data_T, class weight_T>
+auto product(data_T a, weight_T w) -> decltype(a*w)
+{
     // 'Normal' product
-    return (ret_T)(a * w);
+    return a * w;
 }
 
 template<class data_T, class res_T, typename CONFIG_T>
@@ -138,7 +134,7 @@ void dense_rf_gt(
           uint32 w_index = ir + (CONFIG_T::reuse_factor_rounded) * im;
           if (w_index >= CONFIG_T::reuse_factor_rounded*CONFIG_T::block_factor_rounded) continue;
           int data_index = d_index[ir][im];
-          tmp_acc[im] = product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>(data[data_index], weights[w_index]);
+          tmp_acc[im] = product(data[data_index], weights[w_index]);
       }
       hls_register typename CONFIG_T::accum_t mult[CONFIG_T::multiplier_limit];
       ResetMult:
@@ -192,7 +188,7 @@ void dense_rf_lt(
        for (int im = 0, in_index = ir; im < CONFIG_T::block_factor; im++) {
             uint32 w_index = ir + (CONFIG_T::reuse_factor_rounded) * im;
             if (ir + CONFIG_T::reuse_factor * im >= CONFIG_T::n_in*CONFIG_T::n_out) continue;
-            mult[im] = product<data_T, typename CONFIG_T::weight_t, typename CONFIG_T::accum_t>(data[in_index], weights[w_index]);
+            mult[im] = product(data[in_index], weights[w_index]);
             in_index += CONFIG_T::reuse_factor;
             if (in_index >=  CONFIG_T::n_in) in_index = ir;
        }
-- 
GitLab
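
On the Quartus side the std::enable_if dispatch is gone: the special cases are now plain overloads keyed on the weight type, and the generic case deduces its result from decltype(a*w). A hedged sketch of how overload resolution picks the right product, assuming the AC datatypes headers:

#include "ac_int.h"
#include "ac_fixed.h"

// Exact, non-template match: 1-bit data and 1-bit weights (XNOR-style).
inline ac_int<1, false> product(ac_int<1, false> a, ac_int<1, false> w) {
    return a == w;
}

// More specialised template: arbitrary data, 1-bit weight.
template<class data_T>
auto product(data_T a, ac_int<1, false> w) -> decltype(-a) {
    if (w == 0) return -a;
    else return a;
}

// Generic fallback: full multiply at the deduced width.
template<class data_T, class weight_T>
auto product(data_T a, weight_T w) -> decltype(a * w) {
    return a * w;
}

int main() {
    ac_int<1, false> b1 = 1, b0 = 0;
    ac_fixed<8, 3, true> x = 1.5, y = 0.5;
    auto r1 = product(b1, b0); // non-template overload: exact match wins
    auto r2 = product(x, b1);  // 1-bit-weight template: more specialised
    auto r3 = product(x, y);   // generic multiply
    (void)r1; (void)r2; (void)r3;
    return 0;
}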


From aa8e4e29743f36ae34f477896eb102e66e845f38 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Mon, 21 Feb 2022 16:25:39 -0600
Subject: [PATCH 9/9] make different backends not overwrite projects

---
 test/pytest/test_keras_api.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/pytest/test_keras_api.py b/test/pytest/test_keras_api.py
index 0e424f64e..820cb431d 100644
--- a/test/pytest/test_keras_api.py
+++ b/test/pytest/test_keras_api.py
@@ -37,7 +37,7 @@ def test_dense(backend):
     keras_prediction = model.predict(X_input)
 
     config = hls4ml.utils.config_from_keras_model(model)
-    output_dir = str(test_root_path / 'hls4mlprj_keras_api_dense')
+    output_dir = str(test_root_path / f'hls4mlprj_keras_api_dense_{backend}')
 
     hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend)
 
@@ -78,7 +78,7 @@ def test_activations(activation_function, backend):
     X_input = np.random.rand(100,1)
     keras_prediction = model.predict(X_input)
     config = hls4ml.utils.config_from_keras_model(model)
-    output_dir = str(test_root_path / 'hls4mlprj_keras_api_activations_{}'.format(activation_function.__class__.__name__))
+    output_dir = str(test_root_path / 'hls4mlprj_keras_api_activations_{}_{}'.format(activation_function.__class__.__name__, backend))
     hls_model = hls4ml.converters.convert_from_keras_model(model, hls_config=config, output_dir=output_dir, backend=backend)
     hls_model.compile()
     hls_prediction = hls_model.predict(X_input)
-- 
GitLab