Skip to content

Improved parsing of pytorch models using torch.FX - Clean

Javier Duarte requested to merge github/fork/JanFSchulte/torchFXClean into main

Created by: JanFSchulte

Refreshed version of https://github.com/fastmachinelearning/hls4ml/pull/723 to leave behind messy git history

Current parsing of pytorch models uses a loop of the named_modules of the model (https://github.com/fastmachinelearning/hls4ml/blob/main/hls4ml/converters/pytorch_to_hls.py#L163). This has several disadvantages:

  • Captures only layers that are defined as members of the model class
  • Can't infer the correct order of models
  • Ignores other operations that are part of the forward() method of the model

In this PR, we propose to fix this by first created a graph representation of the model's forward() function using the symbolic tracing functionality of https://pytorch.org/docs/stable/fx.html. Each operation in the forward() is represented by a node in the graph. Nodes can be of these types: image

For example, for this model

class MyModuleConvRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3,3,3)
        
    def forward(self, x):
        y1 = self.conv(x)
        y = torch.relu(y1)
        y = y + y1
        y = torch.relu(y)
        return y

the resulting graph representation is

graph():
    %x : [#users=1] = placeholder[target=x]
    %conv : [#users=2] = call_module[target=conv](args = (%x,), kwargs = {})
    %relu : [#users=1] = call_function[target=torch.relu](args = (%conv,), kwargs = {})
    %add : [#users=1] = call_function[target=operator.add](args = (%relu, %conv), kwargs = {})
    %relu_1 : [#users=1] = call_function[target=torch.relu](args = (%add,), kwargs = {})
    return relu_1

As the nodes in the graph follow the order of operations of the forward() function, we can then simply loop over them and parse each node into one node in the hls4ml model representation. For the parsing of the individual layers, existing code is used where available without significant changes. Functionality for more types of layers is also added by this PR.

The types of layers currently understood by the parser are

  • Linear
  • Softmax
  • Relu
  • LeakyReLU
  • ThresholdedReLU
  • ELU
  • PReLU
  • Sigmoid
  • BatchNorm2d
  • BatchNorm1d'
  • Batch_norm
  • MaxPool1d
  • MaxPool2d
  • AvgPool1d
  • AvgPool2d
  • Add
  • Subtract
  • Multiply
  • Average
  • Maximum
  • Minimum
  • Concatenate
  • Dot
  • Conv1d
  • Conv2d
  • View
  • Dropout
  • Flatten
  • Sequential

This PR also fixes https://github.com/fastmachinelearning/hls4ml/issues/409

Changes are mostly confined to the frontend, but small changes are made to the backend to the templates for pooling layers to add the option that zero-padded entries are included in average pooling operations.

One big difference between pytorch and keras is the data format of the input tensors, which is channels_first by default, instead of the channels_last used by keras. The built-in tools in pytorch to convert a model to channels_last don't work for all dimensions of the input. Therefore the functionality has been added to transpose the inputs within hls4ml so the existing channels_last implementations of layers can be used. By default the inputs are transposed for io_parrallel but not io_stream since we don't have transpose layers for all dimensions in io_stream. The outputs are not transposed by default, but this can be switched on, again only for io_parallel.

Limitations:

  • Many types of layers not supported yet
  • The same functionality is available in pytorch either as torch.nn classes or torch.functional functions in many cases. These have to be parsed differently, which I have implemented only sporadically for the functionals so far.

Type of change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change which adds functionality)

Tests

The new parsing was tested using 5-6 different pytorch model examples from around the web. In addition, I verified that the two example models for pytroch included with hls4ml get parsed successfully. A test for the API was added in the test/pytest folder, in analogy to the test for the keras parser. All tests pass successfully.

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

Merge request reports

Loading