Speed up Keras profiling
Created by: AdrianAlan
Description
Changes to get_ymodel_keras that speed up (Q)Keras profiling. Instead of compiling a new model at every step, a single model is used whose output is an array of layer outputs, as suggested in the Keras FAQ.
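To illustrate the idea, here is a minimal sketch of the two approaches on a tiny stand-in network. It is not the hls4ml code: build_model, trace_per_layer and trace_single_pass are hypothetical names used only for this example, and the layer sizes are arbitrary. It assumes the intermediate-output pattern from the Keras FAQ, where a keras.Model is given a list of layer outputs.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


def build_model():
    # Small stand-in network; the layer choices are illustrative only.
    return keras.Sequential([
        keras.Input(shape=(8,)),
        layers.Dense(4, activation="relu", name="dense_a"),
        layers.Dense(2, activation="softmax", name="dense_b"),
    ])


def trace_per_layer(model, X):
    # Hypothetical "old style": one new sub-model and one predict() per layer.
    trace = {}
    for layer in model.layers:
        sub_model = keras.Model(inputs=model.inputs, outputs=layer.output)
        trace[layer.name] = sub_model.predict(X, verbose=0)
    return trace


def trace_single_pass(model, X):
    # Hypothetical "new style": one sub-model whose output is a list,
    # evaluated with a single predict() call.
    outputs = [layer.output for layer in model.layers]
    multi_model = keras.Model(inputs=model.inputs, outputs=outputs)
    results = multi_model.predict(X, verbose=0)
    return {layer.name: y for layer, y in zip(model.layers, results)}


X = np.random.random((16, 8)).astype("float32")
model = build_model()
slow = trace_per_layer(model, X)
fast = trace_single_pass(model, X)
```

Both functions return a dict keyed by layer name, so the results can be compared entry by entry. The single-pass version builds one model and runs the input through the network once, which is where the saving is expected to come from.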
Type of change
- Other: (non-breaking enhancement)
Tests
- I added new tests in test_trace.
- I tried it with ResNet and didn't find any issues.
- I tried it on a simple LeNet example on a T4 GPU on lxplus705:
# Imports are assumed; they were not shown in the original snippet.
import time

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

import hls4ml.model.profiling


def get_model():
    # LeNet-style network used only for this timing comparison
    model = keras.Sequential()
    model.add(layers.Conv2D(filters=6, kernel_size=(3, 3), activation='relu', input_shape=(32, 32, 1)))
    model.add(layers.AveragePooling2D())
    model.add(layers.Conv2D(filters=16, kernel_size=(3, 3), activation='relu'))
    model.add(layers.AveragePooling2D())
    model.add(layers.Flatten())
    model.add(layers.Dense(units=120))
    model.add(layers.Dense(units=84, activation='linear'))
    model.add(layers.Dense(units=10, activation='softmax'))
    model.compile()
    return model


# Warm-up runs so that one-off initialization is not included in the timings
warmup = np.random.random((1000, 32, 32, 1))
model = get_model()
for _ in range(10):
    _ = hls4ml.model.profiling.get_ymodel_keras_old(model, warmup)

new, old = [], []
for _ in range(10):
    X = np.random.random((1000, 32, 32, 1))

    # Time the new implementation
    start = time.time()
    _trace_new = hls4ml.model.profiling.get_ymodel_keras(model, X)
    end = time.time()
    new.append(end - start)

    # Time the old implementation (get_ymodel_keras_old, as named in this test)
    start = time.time()
    _trace_old = hls4ml.model.profiling.get_ymodel_keras_old(model, X)
    end = time.time()
    old.append(end - start)

    # Both implementations must trace the same layers with identical outputs
    assert _trace_old.keys() == _trace_new.keys()
    for name in _trace_old.keys():
        assert np.all(_trace_old[name] == _trace_new[name])

print("New implementation: {}".format(np.mean(new)))
print("Old implementation: {}".format(np.mean(old)))
and I got the following mean times per call, in seconds, over 10 runs:
New implementation: 0.7009074449539184
Old implementation: 1.334557008743286
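For anyone repeating the comparison, the timing loop can be factored into a small helper. The sketch below is only a suggestion and not part of this PR: time_calls and speedup are hypothetical names, and it uses time.perf_counter instead of time.time and warms up each function before timing it, which differs from the snippet above.

```python
import time
from statistics import mean


def time_calls(fn, repeats=10, warmup=2):
    # Run a few untimed calls first so one-off setup costs are excluded.
    for _ in range(warmup):
        fn()
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def speedup(old_durations, new_durations):
    # Ratio of mean runtimes; values above 1 mean the new version is faster.
    return mean(old_durations) / mean(new_durations)


# Applying the same ratio to the two means reported in this description.
reported = speedup([1.334557008743286], [0.7009074449539184])
print(f"Speedup: {reported:.2f}x")
```

With the two means reported above, the ratio comes out at about 1.9, i.e. the new implementation took roughly half the time in this test. To time the real functions, fn could be a lambda wrapping the profiling call with a fixed model and input.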
Checklist
- I have read the guidelines for contributing.
- I have commented my code, particularly in hard-to-understand areas.
- I have made corresponding changes to the documentation.
- My changes generate no new warnings.
- I have installed and run pre-commit on the files I edited or added.
- I have added tests that prove my fix is effective or that my feature works.