Skip to content

QKeras support for RNN layers

Javier Duarte requested to merge github/fork/laurilaatu/qrnn into main

Created by: laurilaatu

Description

📝 This pull request adds QKeras support for RNN layers

The quantizers for kernel, recurrent kernel and bias are retrieved from the QKeras model. State quantizer corresponds to the output precision of the hidden state which is set by altering the default precision (discussed in issue #825 (closed)). Currently the quantized activations for the RNN layers are not supported.

Type of change

  • New feature (non-breaking change which adds functionality)

Tests

📝 Please describe the tests that you ran to verify your changes.

  • Tested with several different network configurations
  • Pytests implemented
  • All layers produce expected output except for the Vivado implementation of the SimpleRNN (excluded from pytests)

Test Configuration: Quartus 21.1.0.169.pro

Example usage:

qmodel = Sequential()
qmodel.add(QSimpleRNN(4, input_shape=(5,1), name='rnn',
                      kernel_quantizer=quantized_bits(bits=9, integer=0, symmetric=False, alpha=1.0),
                      recurrent_quantizer=quantized_bits(bits=9, integer=0, symmetric=False, alpha=1.0),
                      bias_quantizer=quantized_bits(bits=9, integer=0, symmetric=False, alpha=1.0),
                      state_quantizer=quantized_bits(bits=9, integer=0, symmetric=False, alpha=1.0),
                      activation='relu'))
qmodel.summary()
qmodel.compile(loss='mse', optimizer='adam')


config = hls4ml.utils.config_from_keras_model(qmodel,
                                           default_precision="ap_fixed<9,1>",
                                           granularity='name')

hls_model = hls4ml.converters.convert_from_keras_model(qmodel,
                                           hls_config=config,
                                           output_dir='model_qrnn/hls4ml_prj',
                                           part='AGFB014R24A2E2VR0',
                                           backend='Quartus')
hls_model.compile()

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

Merge request reports