
Conversation

@AkashKarnatak

I came across this code snippet while reading the pipeline code:

prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
    [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)

It took me some time to understand what it was doing, and once I did, I noticed that it can be simplified to the following:

for i, seq_len in enumerate(seq_lens):
    prompt_embeds[i, seq_len:] = 0

This achieves the same result while being more readable and more efficient: it works in place, with fewer allocations and copies.
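
For completeness, here is a minimal standalone check of the equivalence. The shapes and names (batch, max_sequence_length, dim, seq_lens) are made up for illustration; the only assumption is the same one the snippet above relies on, namely that prompt_embeds is already a tensor padded to max_sequence_length along the sequence dimension:

import torch

# Hypothetical shapes for illustration only.
batch, max_sequence_length, dim = 2, 8, 4
seq_lens = [3, 5]

# Simulate encoder output that is already padded to max_sequence_length,
# with non-zero values in the padding positions.
prompt_embeds = torch.randn(batch, max_sequence_length, dim)

# Current approach: trim each sequence, then re-pad with zeros and stack.
trimmed = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
stacked = torch.stack(
    [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in trimmed], dim=0
)

# Proposed approach: zero out the padding positions in place.
in_place = prompt_embeds.clone()
for i, seq_len in enumerate(seq_lens):
    in_place[i, seq_len:] = 0

assert torch.equal(stacked, in_place)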

I realize the performance gain may not be significant in most cases, but the simpler approach improves clarity and reduces unnecessary tensor operations. There are likely other similar instances in the pipelines that could benefit from the same refactor.

I'd be interested to hear your thoughts on this change and whether it would be worth applying more broadly.
