From 8002669e07013188e6675fb2b49429bafaaf6203 Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Wed, 22 Dec 2021 18:55:11 -0600
Subject: [PATCH 1/3] first attempt at implementing scaling and unscaling

---
 hls4ml/model/hls_layers.py                 |  22 +-
 hls4ml/model/optimizer/__init__.py         |   8 +-
 hls4ml/model/optimizer/passes/qkeras.py    |  22 --
 hls4ml/model/optimizer/passes/quant_opt.py | 297 +++++++++++++++------
 hls4ml/templates/vivado_template.py        |   1 +
 5 files changed, 238 insertions(+), 112 deletions(-)

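The core transformation of this series: a QONNX Quant node whose scale is not 1 or whose zero
point is not 0 is decomposed into an ApplyAlpha (pre-scale), a linear Activation carrying the
fixed-point quantizer, and an ApplyAlpha (rescale). A minimal numpy sketch of the algebra, where
fake_quantize only stands in for the rounding and saturation that QuantNodeQuantizer performs:

import numpy as np

def fake_quantize(x, bits=8):
    # stand-in for the ap_fixed rounding/saturation applied by the Activation
    return np.clip(np.round(x), -2**(bits - 1), 2**(bits - 1) - 1)

x = np.array([0.3, -1.7, 2.5])
scale, zeropt = 0.5, 2.0

# direct Quant semantics: quantize x/scale + zeropt, then undo the affine map
direct = scale * fake_quantize(x / scale + zeropt) - zeropt * scale

# the QuantToAlphaActivationAlpha decomposition implemented below
pre = (1 / scale) * x + zeropt          # ApplyAlpha: weights 1/scale, bias zeropt
act = fake_quantize(pre)                # linear Activation with the quantizer
post = scale * act + (-zeropt * scale)  # ApplyAlpha: weights scale, bias -zeropt*scale

assert np.allclose(direct, post)
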
diff --git a/hls4ml/model/hls_layers.py b/hls4ml/model/hls_layers.py
index bcc606104..969f680c5 100644
--- a/hls4ml/model/hls_layers.py
+++ b/hls4ml/model/hls_layers.py
@@ -573,7 +573,7 @@ class Constant(Layer):
             shape = (1,)
             self.value = np.array([self.value])
         dims = [f'{self.name}_{i}' for i in range(len(shape))]
-        self.add_output_variable(shape, dims, var_name=self.name, precision=self.get_attr("quant_precision"))
+        self.add_output_variable(shape, dims, var_name=self.name, precision=self.get_attr("precision"))
 
     def function_cpp(self):
         return None
@@ -1524,6 +1524,21 @@ class BatchNormalization(Layer):
 
         return self._config_template.format(**params)
 
+class ApplyAlpha(BatchNormalization):
+    ''' A custom layer to scale the output of a QDense layer which used 'alpha != 1'
+        Inference computation uses BatchNormalization methods'''
+
+    def initialize(self):
+        inp = self.get_input_variable()
+        shape = inp.shape
+        dims = inp.dim_names
+        self.add_output_variable(shape, dims)
+
+    def add_weights(self, scale, quantizer=None):
+        self.add_weights_variable(name='scale', var_name='s{index}', data=scale, quantizer=quantizer)
+
+    def add_bias(self, bias, quantizer=None):
+        self.add_weights_variable(name='bias', var_name='b{index}', data=bias, quantizer=quantizer)
 
 class Merge(Layer):
     def initialize(self):
@@ -2046,8 +2061,9 @@ layer_map = {
     'GarNetStack'            : GarNetStack,
     # TensorFlow-specific layers:
     'BiasAdd'                : BiasAdd,
-    # QONNX quantization layter
-    'Quant'                  : Quant
+    # QONNX quantization layer
+    'Quant'                  : Quant,
+    'ApplyAlpha'             : ApplyAlpha
 }
 
 def register_layer(name, clazz):
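
Because ApplyAlpha reuses the BatchNormalization methods and templates, its inference is a
per-channel affine map. A short numpy sketch of the effective computation (the values are
illustrative; scale and bias are whatever was passed to add_weights and add_bias):

import numpy as np

x = np.array([0.3, -1.2, 0.8])
scale = np.array([0.5, 0.5, 0.5])  # set via add_weights(...)
bias = np.array([0.1, 0.1, 0.1])   # set via add_bias(...)
y = scale * x + bias               # what the reused BatchNormalization kernel computes
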
diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py
index 377078386..54ad72576 100644
--- a/hls4ml/model/optimizer/__init__.py
+++ b/hls4ml/model/optimizer/__init__.py
@@ -16,7 +16,8 @@ from hls4ml.model.optimizer.passes.repack_stream import ReshapeStream, Broadcast
 from hls4ml.model.optimizer.passes.transpose_opt import RemoveUselessTranspose
 from hls4ml.model.optimizer.passes.multi_dense import ReplaceMultidimensionalDenseWithConv
 from hls4ml.model.optimizer.passes.reshape_const import ReshapeConstant
-from hls4ml.model.optimizer.passes.quant_opt import QuantConstantParameters, QuantFactorizeScale, QuantToActivation, QuantToConstant
+from hls4ml.model.optimizer.passes.quant_opt import (
+    QuantConstantParameters, QuantToActivation, FuseQuantWithConstant, QuantToAlphaActivationAlpha, ConstQuantToConstAlpha)
 from hls4ml.model.optimizer.passes.batchnorm_opt import BatchNormConstantParameters, ConstantBatchNormMerging, FuseConsecutiveBatchNormalization
 from hls4ml.model.optimizer.passes.merge_const import MergeTwoConstant, MergeToBatchNormalization, MergeToBatchNormalizationDiv
 from hls4ml.model.optimizer.passes.matmul_const_to_dense import MatmulConstToDense
@@ -40,9 +41,10 @@ if __qkeras_optimizers__:
 
 register_pass('reshape_constant', ReshapeConstant)
 register_pass('quant_constant_params', QuantConstantParameters)
-register_pass('quant_factorize_scale', QuantFactorizeScale)
 register_pass('quant_to_activation', QuantToActivation)
-register_pass('quant_to_constant', QuantToConstant)
+register_pass('fuse_quant_with_constant', FuseQuantWithConstant)
+register_pass('quant_to_alpha_activation_alpha', QuantToAlphaActivationAlpha)
+register_pass('const_quant_to_const_alpha', ConstQuantToConstAlpha)
 register_pass('batch_norm_constant_parameters', BatchNormConstantParameters)
 register_pass('fuse_consecutive_base_batch_normalizations', FuseConsecutiveBatchNormalization)
 register_pass('constant_batch_norm_fusion', ConstantBatchNormMerging)
diff --git a/hls4ml/model/optimizer/passes/qkeras.py b/hls4ml/model/optimizer/passes/qkeras.py
index 0abb5303d..b9ad01b8c 100644
--- a/hls4ml/model/optimizer/passes/qkeras.py
+++ b/hls4ml/model/optimizer/passes/qkeras.py
@@ -77,28 +77,6 @@ class OutputRoundingSaturationMode(OptimizerPass):
         pstr = pstr.replace('>', mode)
         return pstr
 
-class ApplyAlpha(BatchNormalization):
-    ''' A custom layer to scale the output of a QDense layer which used 'alpha != 1'
-        Inference computation uses BatchNormalization methods'''
-
-    def initialize(self):
-        inp = self.get_input_variable()
-        shape = inp.shape
-        dims = inp.dim_names
-        self.add_output_variable(shape, dims)
-
-    def add_weights(self, scale, quantizer=None):
-        self.add_weights_variable(name='scale', var_name='s{index}', data=scale, quantizer=quantizer)
-
-    def add_bias(self, bias, quantizer=None):
-        self.add_weights_variable(name='bias', var_name='b{index}', data=bias, quantizer=quantizer)
-
-# register the layer and its templates
-register_layer('ApplyAlpha', ApplyAlpha)
-# TODO ideally: for backend in backends
-for backend in ['Vivado', 'VivadoAccelerator']:
-    temps = templates.get_backend(backend)
-    temps.register_templates('ApplyAlpha', temps.get_function_template('BatchNormalization'), temps.get_config_template('BatchNormalization'), temps.get_include_list('BatchNormalization'))
 
 class QKerasFactorizeAlpha(OptimizerPass):
     '''OptimizerPass for extracting alpha "scale" from QKeras quantized layer.
diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py
index d48c52915..7e8b4ec1e 100644
--- a/hls4ml/model/optimizer/passes/quant_opt.py
+++ b/hls4ml/model/optimizer/passes/quant_opt.py
@@ -1,8 +1,15 @@
+'''
+This file includes optimizations related to Quant nodes.
+
+As a first step, QuantConstantParameters converts the extra inputs to attributes; it always runs first.
+
+The remaining passes check whether scale is 1 and zeropt is 0: if so, a Quant becomes a linear Activation (or is fused into a constant input); otherwise it becomes an ApplyAlpha, Activation, ApplyAlpha sequence.
+'''
+from copy import deepcopy
 import numpy as np
 from hls4ml.model.hls_layers import FixedPrecisionType, Constant
 from hls4ml.converters.onnx.quantizer import QuantNodeQuantizer
 from hls4ml.model.optimizer import OptimizerPass
-from hls4ml.model.optimizer.passes.qkeras import ApplyAlpha
 
 class QuantConstantParameters(OptimizerPass):
     """ Remove Constant from the Qaunt node parameters (but not input[0]) """
@@ -46,77 +53,77 @@ class QuantConstantParameters(OptimizerPass):
 
         return True
 
-class QuantFactorizeScale(OptimizerPass):
+
+class QuantToActivation(OptimizerPass):
     '''
-    Extract scale and zero-point from Quant Node
+    This is for the case when scale is 1 and zeropt is 0. It is a 1:1 transformation of
+    a Quant to an Activation.
+
+    As an optimization, this is not called when the input is constant.
     '''
     def match(self, node):
         # only matches after the other inputs are already folded
-
         is_match = (node.__class__.__name__ == 'Quant'
+                    and not isinstance(node.get_input_node(), Constant)
                     and not node.get_input_node(node.inputs[1])
                     and not node.get_input_node(node.inputs[2])
                     and not node.get_input_node(node.inputs[3]))
-        
-        # Only match if the scale is not 1s and the zero-point is not 0s
-        if is_match and node.get_input_variable() is not None: # to make sure this is a quant node with inputs
+
+        # Only match if the scale is 1s and the zero-point is 0s
+        if is_match: # to make sure this is a quant node with inputs
             input_shape = node.get_input_variable().shape
-            scale = np.broadcast_to(1/node.get_attr("scale"), input_shape)
+            scale = np.broadcast_to(node.get_attr("scale"), input_shape)
             bias = np.broadcast_to(node.get_attr("zeropt"), input_shape)
-            is_match = is_match and (scale != np.ones_like(scale)).any()
-            is_match = is_match and (bias != np.zeros_like(bias)).any()
+            is_match = is_match and (scale == np.ones_like(scale)).all()
+            is_match = is_match and (bias == np.zeros_like(bias)).all()
         return is_match
 
     def transform(self, model, node):
         '''
-        Insert an ApplyAlpha layer to factorize the scales
+        Change quant node to Activation
         '''
         input_shape = node.get_input_variable().shape
 
-        scale = np.broadcast_to(1/node.get_attr('scale'), input_shape)
-        bias = np.broadcast_to(node.get_attr('zeropt'), input_shape)
-        # Unset the scale and zero-point so we don't try to factorize again
-        node.set_attr('scale', 1)
-        node.set_attr('zeropt', 0)
-
-        # TODO derive these
-        scale_precision = FixedPrecisionType()
-        scale_quantizer = QuantNodeQuantizer(scale_precision)
-        bias_precision = FixedPrecisionType()
-
-        attrs = {
-            'name' : node.get_attr('name') + '_alpha',
-            'class_name' : 'Alpha',
-            'inputs' : node.outputs,
-            'n_in' : node.get_attr('n_out'),
-            'n_filt' : node.get_attr('n_filt', -1),
-            'reuse_factor' : node.get_attr('reuse_factor'),
-            'bias_t' : bias_precision, 
-            'scale_t' : scale_precision,
-            'Trace' : node.get_attr('Trace', False) 
+        n_in = np.prod(input_shape)
+
+        rounding_mode = node.get_attr("rounding_mode")
+        narrow = node.get_attr("narrow")
+        signed = node.get_attr("signed")
+        bitwidth = node.get_attr("bitwidth")
+
+        precision, quantizer = _calculate_precision_quantizer(bitwidth, signed, narrow, rounding_mode)
+
+        attributes = {
+            'activation' : 'linear',
+            'precision'  : precision,
+            'quantizer'  : quantizer,
+            'n_in'       : n_in
         }
-        alpha_layer = model.make_node('ApplyAlpha', node.name + '_alpha', attrs, node.outputs)
 
-        alpha_layer.add_weights(scale, quantizer=scale_quantizer)
-        alpha_layer.add_bias(bias, quantizer=None)
-        model.insert_node(alpha_layer)
- 
+        new_node = model.make_node('Activation', f'{node.name}_act',
+                                   attributes, [node.inputs[0]], node.outputs)
+        new_node.get_output_variable().type.precision = precision
+        model.replace_node(node, new_node)
+
         return True
 
-class QuantToActivation(OptimizerPass):
-    ''' Change Quant node to Activation input[0]'''
+
+class FuseQuantWithConstant(OptimizerPass):
+    '''
+    This is for the case when scale is 1 and zeropt is 0. It directly applies the quantization to a constant.
+    '''
     def match(self, node):
         # only matches after the other inputs are already folded
         is_match = (node.__class__.__name__ == 'Quant'
-                    and not isinstance(node.get_input_node(), Constant)
+                    and isinstance(node.get_input_node(), Constant)
                     and not node.get_input_node(node.inputs[1])
                     and not node.get_input_node(node.inputs[2])
                     and not node.get_input_node(node.inputs[3]))
-        
+
         # Only match if the scale is 1s and the zero-point is 0s
         if is_match: # to make sure this is a quant node with inputs
             input_shape = node.get_input_variable().shape
-            scale = np.broadcast_to(1/node.get_attr("scale"), input_shape)
+            scale = np.broadcast_to(node.get_attr("scale"), input_shape)
             bias = np.broadcast_to(node.get_attr("zeropt"), input_shape)
             is_match = is_match and (scale == np.ones_like(scale)).all()
             is_match = is_match and (bias == np.zeros_like(bias)).all()
@@ -124,42 +131,74 @@ class QuantToActivation(OptimizerPass):
 
     def transform(self, model, node):
         '''
-        Change quant node to Activation
+        Fuse Quant with Constant.
         '''
-        input_shape = node.get_input_variable().shape
-
-        n_in = np.prod(input_shape)
 
         rounding_mode = node.get_attr("rounding_mode")
-        if rounding_mode == "ROUND":
-            bn_round = "AP_RND_CONV"
-        elif rounding_mode == "FLOOR":
-            bn_round =  "AP_TRN"
-        else:
-            raise NotImplementedError(f"Rounding mode {rounding_mode} not supported in Quant node. Only ROUND and FLOOR supported.")
+        narrow = node.get_attr("narrow")
+        signed = node.get_attr("signed")
+        bitwidth = node.get_attr("bitwidth")
+
+        precision, quantizer = _calculate_precision_quantizer(bitwidth, signed, narrow, rounding_mode)
+
+        const_node = node.get_input_node(node.inputs[0])
+        const_node.set_attr("precsion", precision)
+        const_node.set_attr("quantizer", quantizer)
+
+        # reinitialize (which also runs quantization if quantizer exists)
+        const_node.initialize()
+
+        # remove the Quant node
+        model.remove_node(node, rewire=True)
+
+        return True
+
+
+class QuantToAlphaActivationAlpha(OptimizerPass):
+    '''
+    This is for the case when scale is not 1 or zeropt is not 0. It is a 1:3 transformation of
+    a Quant to an ApplyAlpha (to scale), Activation, ApplyAlpha (to rescale).
+
+    As an optimization, this is not called when the input is constant.
+    '''
+    def match(self, node):
+        # only matches after the other inputs are already folded
+        is_match = (node.__class__.__name__ == 'Quant'
+                    and not isinstance(node.get_input_node(), Constant)
+                    and not node.get_input_node(node.inputs[1])
+                    and not node.get_input_node(node.inputs[2])
+                    and not node.get_input_node(node.inputs[3]))
 
-        if node.get_attr("narrow") and not node.get_attr("signed"):
-            raise NotImplementedError("Narrow mode is only supported for singed numbers.")
+        if is_match: # to make sure this is a quant node with inputs
+            input_shape = node.get_input_variable().shape
+            scale = np.broadcast_to(node.get_attr("scale"), input_shape)
+            bias = np.broadcast_to(node.get_attr("zeropt"), input_shape)
+            is_match = is_match and ((scale != np.ones_like(scale)).any() or (bias != np.zeros_like(bias)).any())
+        return is_match
 
-        if node.get_attr("narrow"):
-            bn_sat = "AP_SAT_SYM"
-        else:
-            bn_sat = "AP_SAT"
+    def transform(self, model, node):
+        '''
+        Change quant node to ApplyAlpha, Activation, ApplyAlpha
+        '''
+
+        # Do the Activation as in the simple case
+
+        input_shape = node.get_input_variable().shape
 
+        n_in = np.prod(input_shape)
+
+        rounding_mode = node.get_attr("rounding_mode")
+        narrow = node.get_attr("narrow")
+        signed = node.get_attr("signed")
         bitwidth = node.get_attr("bitwidth")
-        if np.squeeze(bitwidth).shape:
-            raise RuntimeError("Only scalar bitwidth values are supporeted by the Quant node")
-        bitwidth = int(bitwidth)
 
-        precision = FixedPrecisionType(bitwidth, bitwidth, node.get_attr("signed"), bn_round, bn_sat)
-        quantizer = QuantNodeQuantizer(precision)
+        precision, quantizer = _calculate_precision_quantizer(bitwidth, signed, narrow, rounding_mode)
 
         attributes = {
             'activation' : 'linear',
             'precision'  : precision,
-            'n_in'       : n_in,
-            'n_out'      : n_in,
-            'n_filt'     : -1
+            'quantizer'  : quantizer,
+            'n_in'       : n_in
         }
 
         new_node = model.make_node('Activation', f'{node.name}_act',
@@ -167,36 +206,126 @@ class QuantToActivation(OptimizerPass):
         new_node.get_output_variable().type.precision = precision
         model.replace_node(node, new_node)
 
+        # but now add the ApplyAlphas before and after
+
+        scale = np.broadcast_to(node.get_attr("scale"), input_shape)
+        bias = np.broadcast_to(node.get_attr("zeropt"), input_shape)
+
+        attributes_scale = {
+            'n_in': n_in,
+            'n_out': n_in,
+            'n_filt': -1,
+            'reuse_factor': node.get_attr("reuse_factor"),
+            'target_cycles': node.get_attr("target_cycles"),
+            'Trace'      : False
+        }
+
+        attributes_rescale = deepcopy(attributes_scale)
+
+        scale_node = model.make_node('ApplyAlpha', node.name + '_scale', attributes_scale, node.inputs)
+        scale_node.add_weights(1/scale)
+        scale_node.add_bias(bias)
+        model.insert_node(scale_node)
+
+        rescale_node = model.make_node('ApplyAlpha', node.name + '_rescale', attributes_rescale, new_node.outputs)
+        rescale_node.add_weights(scale)
+        rescale_node.add_bias(-bias*scale)
+        model.insert_node(rescale_node)
+
         return True
 
-class QuantToConstant(OptimizerPass):
+
+class ConstQuantToConstAlpha(OptimizerPass):
     '''
-    Remove a Quant node that is quantizing a constant.
-    Update the attributes of the constant according to the quantization.
+    This is for the case when scale is not 1 or zeropt is not 0. It is a 1:3 transformation of
+    a Quant to an ApplyAlpha (to scale), Activation, ApplyAlpha (to unscale), but a constant
+    input allows for optimization, so the ApplyAlpha (to scale) and Activation are
+    optimized away immediately.
     '''
-
     def match(self, node):
+        # only matches after the other inputs are already folded
         is_match = (node.__class__.__name__ == 'Quant'
-                    and isinstance(node.get_input_node(node.inputs[0]), Constant))
+                    and isinstance(node.get_input_node(), Constant)
+                    and not node.get_input_node(node.inputs[1])
+                    and not node.get_input_node(node.inputs[2])
+                    and not node.get_input_node(node.inputs[3]))
+
+        if is_match: # to make sure this is a quant node with inputs
+            input_shape = node.get_input_variable().shape
+            scale = np.broadcast_to(node.get_attr("scale"), input_shape)
+            bias = np.broadcast_to(node.get_attr("zeropt"), input_shape)
+            is_match = is_match and ((scale != np.ones_like(scale)).any() or (bias != np.zeros_like(bias)).any())
         return is_match
 
     def transform(self, model, node):
+        '''
+        Change Constant + Quant node to Constant, ApplyAlpha
+        '''
+
+        # Do the Activation as in the simple case
+
+        input_shape = node.get_input_variable().shape
+
+        n_in = np.prod(input_shape)
+
+        rounding_mode = node.get_attr("rounding_mode")
+        narrow = node.get_attr("narrow")
+        signed = node.get_attr("signed")
+        bitwidth = node.get_attr("bitwidth")
+
+        precision, quantizer = _calculate_precision_quantizer(bitwidth, signed, narrow, rounding_mode)
+
         const_node = node.get_input_node(node.inputs[0])
 
-        new_val = const_node.value * node.get_attr('scale') + node.get_attr('zeropt')
-        quantizer = node.get_attr('quantizer')  # None if not defined
-        if quantizer:
-            const_node.set_attr('quantizer', quantizer)
+        scale = np.broadcast_to(node.get_attr("scale"), input_shape)
+        bias = np.broadcast_to(node.get_attr("zeropt"), input_shape)
+
+        # calculate the new value
+        new_val = const_node.value / scale + bias
         const_node.set_attr('value', new_val)
+        const_node.set_attr("precsion", precision)
+        const_node.set_attr("quantizer", quantizer)
 
-        quant_precision = node.get_attr('quant_precision')
-        if quant_precision:
-            const_node.set_attr('quant_precision', quant_precision)
+        attributes_rescale = {
+            'n_in': n_in,
+            'n_out': n_in,
+            'n_filt': -1,
+            'reuse_factor': node.get_attr("reuse_factor"),
+            'target_cycles': node.get_attr("target_cycles"),
+            'Trace'      : False
+        }
 
-        # reinitialize (which also runs quantization if quantizer exists)
-        const_node.initialize()
+        rescale_node = model.make_node('ApplyAlpha', node.name + '_rescale', attributes_rescale, node.inputs)
+        rescale_node.add_weights(scale)
+        rescale_node.add_bias(-bias*scale)
+        model.replace_node(node, rescale_node)
 
-        # remove the Quant node
-        model.remove_node(node, rewire=True)
-       
-        return True
\ No newline at end of file
+        return True
+
+
+def _calculate_precision_quantizer(bitwidth, signed, narrow, rounding_mode):
+    '''
+    A function to determine the precision and quantizer
+    '''
+    if rounding_mode == "ROUND":
+        bn_round = "AP_RND_CONV"
+    elif rounding_mode == "FLOOR":
+        bn_round = "AP_TRN"
+    else:
+        raise NotImplementedError(f"Rounding mode {rounding_mode} not supported in Quant node. Only ROUND and FLOOR supported.")
+
+    if narrow and not signed:
+        raise NotImplementedError("Narrow mode is only supported for singed numbers.")
+
+    if narrow:
+        bn_sat = "AP_SAT_SYM"
+    else:
+        bn_sat = "AP_SAT"
+
+    if np.squeeze(bitwidth).shape:
+        raise RuntimeError("Only scalar bitwidth values are supporeted by the Quant node")
+    bitwidth = int(bitwidth)
+
+    precision = FixedPrecisionType(bitwidth, bitwidth, signed, bn_round, bn_sat)
+    quantizer = QuantNodeQuantizer(precision)
+    return (precision, quantizer)
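
Illustrative use of the helper above (assuming this module's context; note the precision is
integer-only, with width equal to the number of integer bits, since a Quant node emits integers):

from hls4ml.model.optimizer.passes.quant_opt import _calculate_precision_quantizer

precision, quantizer = _calculate_precision_quantizer(
    bitwidth=8, signed=True, narrow=True, rounding_mode="ROUND")
# precision -> FixedPrecisionType(8, 8, True, "AP_RND_CONV", "AP_SAT_SYM")
# quantizer -> QuantNodeQuantizer(precision), used to quantize constants and
#              set activation output types
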
diff --git a/hls4ml/templates/vivado_template.py b/hls4ml/templates/vivado_template.py
index 149b52f1d..6bc449b31 100644
--- a/hls4ml/templates/vivado_template.py
+++ b/hls4ml/templates/vivado_template.py
@@ -387,6 +387,7 @@ class VivadoBackend(Backend):
         self.register_templates('Dense', dense_function_template, dense_config_template, dense_include_list)
         self.register_templates('BinaryDense'            , dense_function_template,       dense_config_template, dense_include_list)
         self.register_templates('BatchNormalization'     , batchnorm_function_template,   batchnorm_config_template, batchnorm_include_list)
+        self.register_templates('ApplyAlpha'             , batchnorm_function_template,   batchnorm_config_template, batchnorm_include_list)
         self.register_templates('Conv1D'                 , conv1d_function_template,      [conv1d_config_template, conv_mult_config_template], conv1d_include_list)
         self.register_templates('Conv2D'                 , conv2d_function_template,      [conv2d_config_template, conv_mult_config_template], conv2d_include_list)
         self.register_templates('Conv2DBatchnorm'        , conv2d_function_template,      [conv2d_config_template, conv_mult_config_template], conv2d_include_list)
-- 
GitLab


From bde26297495b94bc6cfde6225abe80741d759c2b Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Thu, 23 Dec 2021 17:38:53 -0600
Subject: [PATCH 2/3] Checkpoint, TFC working, but CNV failing

---
 hls4ml/model/hls_layers.py                    |  5 ++
 hls4ml/model/hls_model.py                     |  9 ++--
 hls4ml/model/optimizer/__init__.py            |  3 +-
 .../model/optimizer/passes/batchnorm_opt.py   | 46 +++++++++----------
 hls4ml/model/optimizer/passes/nop.py          | 23 ++++++++++
 hls4ml/model/optimizer/passes/quant_opt.py    | 28 +++++------
 6 files changed, 71 insertions(+), 43 deletions(-)

diff --git a/hls4ml/model/hls_layers.py b/hls4ml/model/hls_layers.py
index 969f680c5..0ceed192e 100644
--- a/hls4ml/model/hls_layers.py
+++ b/hls4ml/model/hls_layers.py
@@ -418,6 +418,11 @@ class Layer(object):
 
         self.precision[out.type.name] = out.type
 
+    def update_output_precision(self, precision, output_name=None):
+        if output_name is None:
+            output_name = self.outputs[0]
+        self.variables[output_name].type.precision = precision
+
     def make_array_variable(self, shape, dim_names, var_name='layer{index}_out', type_name='layer{index}_t', precision=None, pragma='auto'):
         if pragma == 'auto':
             if self.model.config.get_config_value('IOType') == 'io_serial':
diff --git a/hls4ml/model/hls_model.py b/hls4ml/model/hls_model.py
index d957095a5..60e9bd95e 100644
--- a/hls4ml/model/hls_model.py
+++ b/hls4ml/model/hls_model.py
@@ -384,13 +384,14 @@ class HLSModel(object):
                 `before` does not specify a correct node in sequence.
 
         """
-        if len(node.inputs) > 1:
+        # the list comprehension filters out empty input names
+        if len([x for x in node.inputs if x]) > 1:
             raise Exception('Cannot insert a node with more than one input (for now).')
 
         prev_node = node.get_input_node(node.inputs[0])
-        next_nodes = [x for x in self.graph.values() if x.inputs[0] in prev_node.outputs]
+        next_nodes = [x for x in self.graph.values() if len(x.inputs) > 0 and x.inputs[0] in prev_node.outputs]
         if before is None:
-            next_node = next((x for x in self.graph.values() if x.inputs[0] in prev_node.outputs), None)
+            next_node = next((x for x in self.graph.values() if len(x.inputs) and x.inputs[0] in prev_node.outputs), None)
         else:
             if before not in next_nodes:
                 raise Exception('Cannot insert a node {} before {} (candidates: {}).'.format(node.name, before.name, ','.join([n.name for n in next_nodes])))
@@ -556,7 +557,7 @@ class HLSModel(object):
             xlist = [x]
         else:
             xlist = x
-        
+
         for xi in xlist:
             if not isinstance(xi, np.ndarray):
                 raise Exception('Expected numpy.ndarray, but got {}'.format(type(x)))
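
The filtering above matters because QuantConstantParameters blanks the scale, zeropt, and
bitwidth entries of node.inputs instead of deleting them, so a plain length check over-counts.
A toy illustration:

inputs = ['quant_in', '', '', '']           # a Quant node after its parameters are folded
assert len(inputs) > 1                      # old check: wrongly rejects the insert
assert len([x for x in inputs if x]) == 1   # new check: exactly one real input
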
diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py
index 54ad72576..4da802815 100644
--- a/hls4ml/model/optimizer/__init__.py
+++ b/hls4ml/model/optimizer/__init__.py
@@ -3,7 +3,7 @@ from __future__ import absolute_import
 from hls4ml.model.optimizer.optimizer import OptimizerPass, register_pass, get_optimizer, optimize_model, get_available_passes
 
 
-from hls4ml.model.optimizer.passes.nop import EliminateLinearActivation
+from hls4ml.model.optimizer.passes.nop import EliminateLinearActivation, EliminateLinearActivationQuant
 from hls4ml.model.optimizer.passes.bn_quant import MergeBatchNormAndQuantizedTanh
 from hls4ml.model.optimizer.passes.bn_quant import QuantizeDenseOutput
 from hls4ml.model.optimizer.passes.bn_fuse import FuseBatchNormalization
@@ -68,3 +68,4 @@ register_pass('reshape_stream', ReshapeStream)
 register_pass('remove_useless_transpose', RemoveUselessTranspose)
 register_pass('replace_multidense_conv', ReplaceMultidimensionalDenseWithConv)
 register_pass('broadcast_stream', BroadcastStream)
+register_pass('eliminate_linear_activation_quant', EliminateLinearActivationQuant)
diff --git a/hls4ml/model/optimizer/passes/batchnorm_opt.py b/hls4ml/model/optimizer/passes/batchnorm_opt.py
index e6f627a39..6c9085f70 100644
--- a/hls4ml/model/optimizer/passes/batchnorm_opt.py
+++ b/hls4ml/model/optimizer/passes/batchnorm_opt.py
@@ -1,10 +1,11 @@
 import numpy as np
 from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.hls_layers import BatchNormalization, Constant
 
 class BatchNormConstantParameters(OptimizerPass):
     """ Remove Constant from the BatchNormalization node parameters (but not input[0]) """
     def match(self, node):
-        is_match = (node.__class__.__name__ == 'BatchNormalization'
+        is_match = (isinstance(node, BatchNormalization)
                     and any(node.inputs[1:]))
 
         return is_match
@@ -16,9 +17,9 @@ class BatchNormConstantParameters(OptimizerPass):
 
         if not (len(node.inputs) == 5 and all(node.inputs)):
             raise ValueError(f"All {len.node.inputs} BatchNormalization inputs need to be defined")
-        
+
         gamma_node = node.get_input_node(node.inputs[1])
-        if gamma_node.__class__.__name__ != 'Constant':
+        if not isinstance(gamma_node, Constant):
             raise TypeError("Only consant gammas supported")
         gamma = gamma_node.value
         node.set_attr('gamma', gamma)
@@ -34,7 +35,7 @@ class BatchNormConstantParameters(OptimizerPass):
         model.remove_node(beta_node, rewire=False)
 
         moving_mean_node = node.get_input_node(node.inputs[3])
-        if moving_mean_node.__class__.__name__ != 'Constant':
+        if not isinstance(moving_mean_node, Constant):
             raise TypeError("Only consant moving_means supported")
         moving_mean = moving_mean_node.value
         node.set_attr('moving_mean', moving_mean)
@@ -42,7 +43,7 @@ class BatchNormConstantParameters(OptimizerPass):
         model.remove_node(moving_mean_node, rewire=False)
 
         moving_variance_node = node.get_input_node(node.inputs[4])
-        if moving_variance_node.__class__.__name__ != 'Constant':
+        if not isinstance(moving_variance_node, Constant):
             raise TypeError("Only consant moving_variances supported")
         moving_variance = moving_variance_node.value
         node.set_attr('moving_variance', moving_variance)
@@ -52,8 +53,8 @@ class BatchNormConstantParameters(OptimizerPass):
         scale = gamma / np.sqrt(moving_variance + node.get_attr('epsilon'))
         bias = beta - gamma * moving_mean / np.sqrt(moving_variance + node.get_attr('epsilon'))
 
-        node.add_weights_variable("scale", data=scale, precision=node.get_attr("weight_precision"), quantizer=node.get_attr("weight_quantizer"))
-        node.add_weights_variable("bias", data=bias, precision=node.get_attr("weight_precision"), quantizer=node.get_attr("weight_quantizer"))
+        node.add_weights_variable("scale", data=scale, precision=node.get_attr("scale_precision"), quantizer=node.get_attr("bias_quantizer"))
+        node.add_weights_variable("bias", data=bias, precision=node.get_attr("bias_precision"), quantizer=node.get_attr("bias_quantizer"))
 
         return True
 
@@ -63,12 +64,12 @@ class ConstantBatchNormMerging(OptimizerPass):
     Merge BatchNorm into Const (after parameters have already been merged in BatchNormalization)
     """
     def match(self, node):
-        is_match = (node.__class__.__name__ == 'BatchNormalization'
+        is_match = (isinstance(node, BatchNormalization)
                     and not any(node.inputs[1:])
-                    and node.get_input_node(node.inputs[0]).__class__.__name__ == 'Constant')
-
+                    and isinstance(node.get_input_node(node.inputs[0]), Constant)
+                    and not node.get_input_node(node.inputs[0]).get_attr("quant_precision"))
         return is_match
-    
+
     def transform(self, model, node):
         """
         Remove the batch norm
@@ -76,21 +77,16 @@ class ConstantBatchNormMerging(OptimizerPass):
         const_node = node.get_input_node(node.inputs[0])
 
         new_val = const_node.value * node.weights["scale"].data_unquantized + node.weights["bias"].data_unquantized
-        quantizer = node.get_attr("quantizer")  # None if not defined
-        if quantizer:
-            const_node.set_attr("quantizer", quantizer)
         const_node.set_attr("value", new_val)
-
-        quant_precision = node.get_attr("quant_precision")
-        if quant_precision:
-            const_node.set_attr("quant_precision", quant_precision)
+        const_node.set_attr("quantizer", node.get_attr("quantizer"))  # None if not defined
+        const_node.set_attr("quant_precision",  node.get_attr("quant_precision"))
 
         # reinitialize (which also runs quantization if quantizer exists)
         const_node.initialize()
 
         # remove the batch norm node
         model.remove_node(node, rewire=True)
-       
+
         return True
 
 
@@ -101,10 +97,10 @@ class FuseConsecutiveBatchNormalization(OptimizerPass):
     '''
 
     def match(self, node):
-        return (node.__class__.__name__ == 'BatchNormalization'
-                and node.get_input_node(node.inputs[0]).__class__.__name__ == 'BatchNormalization'
+        return (isinstance(node, BatchNormalization)
+                and isinstance(node.get_input_node(node.inputs[0]), BatchNormalization)
                 and not node.get_input_node(node.inputs[0]).get_attr("quant_precision"))
- 
+
 
     def transform(self, model, node):
         prev_node = node.get_input_node(node.inputs[0])
@@ -118,8 +114,8 @@ class FuseConsecutiveBatchNormalization(OptimizerPass):
         bias_new = s1 * b0 + b1
 
         # call function so that quantizer would be called if needed
-        node.add_weights_variable(name='scale', data=scale_new, precision=node.get_attr("weight_precision"), quantizer=node.get_attr("weight_quantizer"))
-        node.add_weights_variable(name='bias', data=bias_new, precision=node.get_attr("weight_precision"), quantizer=node.get_attr("weight_quantizer"))
- 
+        node.add_weights_variable(name='scale', data=scale_new, precision=node.get_attr("scale_precision"), quantizer=node.get_attr("scale_quantizer"))
+        node.add_weights_variable(name='bias', data=bias_new, precision=node.get_attr("bias_precision"), quantizer=node.get_attr("bias_quantizer"))
+
         model.remove_node(prev_node, rewire=True)
         return True
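
A quick numeric check of the fusion algebra used by FuseConsecutiveBatchNormalization above
(s0, b0 and s1, b1 are the scale/bias pairs of the first and second normalization):

import numpy as np

x = np.linspace(-2.0, 2.0, 5)
s0, b0 = 1.5, 0.2
s1, b1 = 0.8, -0.1
assert np.allclose(s1 * (s0 * x + b0) + b1, (s0 * s1) * x + (s1 * b0 + b1))
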
diff --git a/hls4ml/model/optimizer/passes/nop.py b/hls4ml/model/optimizer/passes/nop.py
index c727a1179..05104b1e2 100644
--- a/hls4ml/model/optimizer/passes/nop.py
+++ b/hls4ml/model/optimizer/passes/nop.py
@@ -10,3 +10,26 @@ class EliminateLinearActivation(OptimizerPass):
     def transform(self, model, node):
         model.remove_node(node)
         return True
+
+class EliminateLinearActivationQuant(OptimizerPass):
+    '''
+    This is to optimize away lots of linear quantizations in QONNX. May have to restrict it
+    more if it causes problems.
+    '''
+    def match(self, node):
+        '''
+        Only match if this activation comes from a Quant node and the previous node's precision has not already been set by a Quant node.
+        '''
+        is_match = (node.__class__.__name__ == 'Activation' and node.get_attr('activation') == 'linear'
+                    and node.get_attr("quant_precision")
+                    and not node.get_input_node(node.inputs[0]).get_attr("quant_precision"))
+        return is_match
+
+    def transform(self, model, node):
+        prev_node = node.get_input_node(node.inputs[0])
+        quant_precision = node.get_attr("quant_precision")
+        prev_node.set_attr("quant_precision", quant_precision)
+        prev_node.set_attr("quantizer", node.get_attr("quantizer"))
+        prev_node.update_output_precision(quant_precision)
+        model.remove_node(node)
+        return True
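
Schematically, EliminateLinearActivationQuant hands the quantization attributes back to the
producing node and then drops the linear Activation. A toy sketch with stand-in objects (not
the hls4ml API):

from types import SimpleNamespace

prev_node = SimpleNamespace(attrs={})                       # e.g. a preceding Dense layer
act_node = SimpleNamespace(attrs={'activation': 'linear',
                                  'quant_precision': 'ap_fixed<8,8>',
                                  'quantizer': 'QuantNodeQuantizer'})

# the hand-off performed before the Activation node is removed
prev_node.attrs['quant_precision'] = act_node.attrs['quant_precision']
prev_node.attrs['quantizer'] = act_node.attrs['quantizer']
assert prev_node.attrs['quant_precision'] == 'ap_fixed<8,8>'
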
diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py
index 7e8b4ec1e..d92408a45 100644
--- a/hls4ml/model/optimizer/passes/quant_opt.py
+++ b/hls4ml/model/optimizer/passes/quant_opt.py
@@ -30,23 +30,21 @@ class QuantConstantParameters(OptimizerPass):
         """
         if node.get_input_node(node.inputs[1]):
             scale_node = node.get_input_node(node.inputs[1])
-            if scale_node.__class__.__name__ == 'Constant':
+            if isinstance(scale_node, Constant):
                 node.set_attr('scale', scale_node.value)
                 node.inputs[1] = ''
-                node.attributes["scale_precision"] = scale_node.get_attr("quant_precision")
                 model.remove_node(scale_node, rewire=False)
 
         if node.get_input_node(node.inputs[2]):
             zeropt_node = node.get_input_node(node.inputs[2])
-            if zeropt_node.__class__.__name__ == 'Constant':
+            if isinstance(zeropt_node, Constant):
                 node.set_attr('zeropt', zeropt_node.value)
                 node.inputs[2] = ''
-                node.attributes["bias_precision"] = zeropt_node.get_attr("quant_precision")
                 model.remove_node(zeropt_node, rewire=False)
 
         if node.get_input_node(node.inputs[3]):
             bitwidth_node = node.get_input_node(node.inputs[3])
-            if bitwidth_node.__class__.__name__ == 'Constant':
+            if isinstance(bitwidth_node, Constant):
                 node.set_attr('bitwidth', bitwidth_node.value)
                 node.inputs[3] = ''
                 model.remove_node(bitwidth_node, rewire=False)
@@ -64,7 +62,7 @@ class QuantToActivation(OptimizerPass):
     def match(self, node):
         # only matches after the other inputs are already folded
         is_match = (node.__class__.__name__ == 'Quant'
-                    and not isinstance(node.get_input_node(), Constant)
+                    and not isinstance(node.get_input_node(node.inputs[0]), Constant)
                     and not node.get_input_node(node.inputs[1])
                     and not node.get_input_node(node.inputs[2])
                     and not node.get_input_node(node.inputs[3]))
@@ -95,7 +93,7 @@ class QuantToActivation(OptimizerPass):
 
         attributes = {
             'activation' : 'linear',
-            'precision'  : precision,
+            'quant_precision' : precision,
             'quantizer'  : quantizer,
             'n_in'       : n_in
         }
@@ -115,7 +113,7 @@ class FuseQuantWithConstant(OptimizerPass):
     def match(self, node):
         # only matches after the other inputs are already folded
         is_match = (node.__class__.__name__ == 'Quant'
-                    and isinstance(node.get_input_node(), Constant)
+                    and isinstance(node.get_input_node(node.inputs[0]), Constant)
                     and not node.get_input_node(node.inputs[1])
                     and not node.get_input_node(node.inputs[2])
                     and not node.get_input_node(node.inputs[3]))
@@ -142,7 +140,7 @@ class FuseQuantWithConstant(OptimizerPass):
         precision, quantizer = _calculate_precision_quantizer(bitwidth, signed, narrow, rounding_mode)
 
         const_node = node.get_input_node(node.inputs[0])
-        const_node.set_attr("precsion", precision)
+        const_node.set_attr("quant_precision", precision)
         const_node.set_attr("quantizer", quantizer)
 
         # reinitialize (which also runs quantization if quantizer exists)
@@ -164,7 +162,7 @@ class QuantToAlphaActivationAlpha(OptimizerPass):
     def match(self, node):
         # only matches after the other inputs are already folded
         is_match = (node.__class__.__name__ == 'Quant'
-                    and not isinstance(node.get_input_node(), Constant)
+                    and not isinstance(node.get_input_node(node.inputs[0]), Constant)
                     and not node.get_input_node(node.inputs[1])
                     and not node.get_input_node(node.inputs[2])
                     and not node.get_input_node(node.inputs[3]))
@@ -196,7 +194,7 @@ class QuantToAlphaActivationAlpha(OptimizerPass):
 
         attributes = {
             'activation' : 'linear',
-            'precision'  : precision,
+            'quant_precision' : precision,
             'quantizer'  : quantizer,
             'n_in'       : n_in
         }
@@ -245,7 +243,7 @@ class ConstQuantToConstAlpha(OptimizerPass):
     def match(self, node):
         # only matches after the other inputs are already folded
         is_match = (node.__class__.__name__ == 'Quant'
-                    and isinstance(node.get_input_node(), Constant)
+                    and isinstance(node.get_input_node(node.inputs[0]), Constant)
                     and not node.get_input_node(node.inputs[1])
                     and not node.get_input_node(node.inputs[2])
                     and not node.get_input_node(node.inputs[3]))
@@ -283,9 +281,12 @@ class ConstQuantToConstAlpha(OptimizerPass):
         # calculate the new value
         new_val = const_node.value / scale + bias
         const_node.set_attr('value', new_val)
-        const_node.set_attr("precsion", precision)
+        const_node.set_attr("quant_precision", precision)
         const_node.set_attr("quantizer", quantizer)
 
+        # reinitialize (which also runs quantization if quantizer exists)
+        const_node.initialize()
+
         attributes_rescale = {
             'n_in': n_in,
             'n_out': n_in,
@@ -329,3 +330,4 @@ def _calculate_precision_quantizer(bitwidth, signed, narrow, rounding_mode):
     precision = FixedPrecisionType(bitwidth, bitwidth, signed, bn_round, bn_sat)
     quantizer = QuantNodeQuantizer(precision)
     return (precision, quantizer)
+
-- 
GitLab


From 026cdcb1aed37eb275850dc44479ca1075573fbf Mon Sep 17 00:00:00 2001
From: Jovan Mitrevski <jmitrevs@fnal.gov>
Date: Thu, 23 Dec 2021 18:45:19 -0600
Subject: [PATCH 3/3] fix sharing of lists

---
 hls4ml/model/optimizer/passes/quant_opt.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py
index d92408a45..6cb2e4778 100644
--- a/hls4ml/model/optimizer/passes/quant_opt.py
+++ b/hls4ml/model/optimizer/passes/quant_opt.py
@@ -220,12 +220,12 @@ class QuantToAlphaActivationAlpha(OptimizerPass):
 
         attributes_rescale = deepcopy(attributes_scale)
 
-        scale_node = model.make_node('ApplyAlpha', node.name + '_scale', attributes_scale, node.inputs)
+        scale_node = model.make_node('ApplyAlpha', node.name + '_scale', attributes_scale, list(node.inputs))
         scale_node.add_weights(1/scale)
         scale_node.add_bias(bias)
         model.insert_node(scale_node)
 
-        rescale_node = model.make_node('ApplyAlpha', node.name + '_rescale', attributes_rescale, new_node.outputs)
+        rescale_node = model.make_node('ApplyAlpha', node.name + '_rescale', attributes_rescale, list(new_node.outputs))
         rescale_node.add_weights(scale)
         rescale_node.add_bias(-bias*scale)
         model.insert_node(rescale_node)
@@ -296,7 +296,7 @@ class ConstQuantToConstAlpha(OptimizerPass):
             'Trace'      : False
         }
 
-        rescale_node = model.make_node('ApplyAlpha', node.name + '_rescale', attributes_rescale, node.inputs)
+        rescale_node = model.make_node('ApplyAlpha', node.name + '_rescale', attributes_rescale, list(node.inputs))
         rescale_node.add_weights(scale)
         rescale_node.add_bias(-bias*scale)
         model.replace_node(node, rescale_node)
-- 
GitLab
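
The aliasing this last patch avoids, in miniature: passing node.inputs directly to make_node
makes both nodes share one list object, so rewiring one node during insert_node silently
rewires the other as well.

quant_inputs = ['x', '', '', '']
shared = quant_inputs              # old code: both nodes hold the same list object
copied = list(quant_inputs)        # new code: an independent copy
shared[0] = 'scale_out'
assert quant_inputs[0] == 'scale_out'  # the original node was changed too
assert copied[0] == 'x'                # the copy is unaffected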