Describe the bug
Dear NVIDIA experts,

When an attention mask is used, three lines in the attention path cause a very significant slowdown, because each of them loops over all the items in the batch. This is a performance bottleneck in TransformerLayer and Attention:
- the `get_indices` function
- an assertion on the type of the attention mask
- a second, similar assertion on the type of the attention mask
See the profiling trace below:

As the largest bars in the bottom rows show, the operation taking most of the time in each transformer layer is now `transformer_engine/pytorch/attention/dot_product_attention/utils.py(1518): get_indices`. The other two assertions similarly cause a significant slowdown.
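To illustrate why the per-batch-item loop is costly, here is a minimal sketch (not Transformer Engine's actual code; the function names are mine) of the loop pattern versus a vectorized alternative that extracts the same indices with a single call:

```python
import numpy as np

def get_indices_loop(mask):
    # Per-item Python loop: one nonzero() call per batch element,
    # plus interpreter overhead that scales with batch size.
    out = []
    for b in range(mask.shape[0]):
        idx = np.nonzero(mask[b])[0]
        out.append(idx + b * mask.shape[1])
    return np.concatenate(out)

def get_indices_vectorized(mask):
    # Single vectorized call: flatten the (batch, seq) mask and take
    # the nonzero positions in one shot, with no Python loop.
    return np.nonzero(mask.reshape(-1))[0]

mask = np.array([[1, 0, 1],
                 [0, 1, 1]], dtype=bool)
assert np.array_equal(get_indices_loop(mask), get_indices_vectorized(mask))
```

The two functions return the same flattened indices, but the vectorized form does the batch iteration inside one library call instead of in Python, which is the usual fix for this class of bottleneck.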
Steps/Code to reproduce bug
This problem should show up in any profiling run in which the input to forward() includes an attention mask.
Expected behavior
Calling forward() of TransformerLayer should not become multiple times slower when an attention mask is passed.