feat: Add cross_encoder serving and fix text_classification token_type_ids #444
Previously, text_classification set return_token_type_ids: false, which broke sentence-pair inputs such as the query-document pairs used by cross-encoder rerankers. token_type_ids are now included, so rerankers produce correct scores. Closes elixir-nx#251
Pull request overview
This PR adds cross-encoder support and fixes token_type_ids handling for sentence-pair classification tasks. Cross-encoder models like cross-encoder/ms-marco-MiniLM-L-6-v2 require token_type_ids to distinguish query tokens from document tokens, which were previously disabled in text_classification.
Changes:
- Fixed text_classification to include token_type_ids for sentence-pair inputs
- Added new cross_encoder serving with a dedicated API for reranking use cases
- Added comprehensive test coverage for both changes
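To make the token_type_ids point concrete, here is a minimal sketch of what the tokenizer emits for a query-document pair, with and without return_token_type_ids. It assumes Bumblebee is installed (the version requirement below is an assumption), that the tokenizer for this checkpoint can be fetched from the Hugging Face Hub, and the example query and document strings are made up for illustration. It only inspects tokenizer output and is not the code from this PR.

```elixir
# Assumption: a recent Bumblebee release; adjust the version requirement as needed.
Mix.install([{:bumblebee, "~> 0.6"}])

# Load only the tokenizer of the cross-encoder checkpoint named in this PR.
{:ok, tokenizer} =
  Bumblebee.load_tokenizer({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})

# A sentence pair is passed as a two-element tuple (illustrative texts).
pair =
  {"how many people live in Berlin?",
   "Berlin has a population of about 3.7 million."}

# Default configuration: the encoded batch carries a "token_type_ids" tensor
# that marks which tokens belong to the first and which to the second segment.
inputs = Bumblebee.apply_tokenizer(tokenizer, [pair])
IO.inspect(inputs["token_type_ids"], label: "token_type_ids")

# With return_token_type_ids: false the segment information is dropped,
# which is the configuration the PR description says text_classification used.
without_types = Bumblebee.configure(tokenizer, return_token_type_ids: false)
inputs = Bumblebee.apply_tokenizer(without_types, [pair])
IO.inspect(Map.has_key?(inputs, "token_type_ids"), label: "has token_type_ids?")
```

Without the segment tensor, a BERT-style model has no input signal separating the query from the document, which is consistent with the score mismatch described above.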
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| lib/bumblebee/text/text_classification.ex | Removed return_token_type_ids: false configuration and added token_type_ids to compile template |
| lib/bumblebee/text/cross_encoder.ex | New module implementing cross-encoder serving with pair validation and score extraction |
| lib/bumblebee/text.ex | Added cross_encoder function documentation, type specs, and public API delegation |
| test/bumblebee/text/text_classification_test.exs | Added test verifying correct scoring for cross-encoder sentence pairs |
| test/bumblebee/text/cross_encoder_test.exs | New test file with comprehensive coverage for single and batch pair scoring |
jonatanklosko
left a comment
Just one note about naming and looks good to me!
lib/bumblebee/text.ex
Outdated
defdelegate cross_encoder(model_info, tokenizer, opts \\ []),
  to: Bumblebee.Text.CrossEncoder
Nit: let's rename to "cross encoding" to match all the other functions we already have (embedding, classification, etc.):
Suggested change
-defdelegate cross_encoder(model_info, tokenizer, opts \\ []),
-  to: Bumblebee.Text.CrossEncoder
+defdelegate cross_encoding(model_info, tokenizer, opts \\ []),
+  to: Bumblebee.Text.CrossEncoding
done
Force-pushed from f1cdc18 to 15cb657
I was experimenting with rerankers for georgeguimaraes/arcana and found that cross-encoder models like cross-encoder/ms-marco-MiniLM-L-6-v2 weren't producing correct scores.
The issue: text_classification was setting return_token_type_ids: false, which breaks sentence-pair inputs. Cross-encoders need token_type_ids to distinguish query tokens from document tokens. Without them, scores don't match Python's sentence-transformers.
Changes:
- Fix text_classification to include token_type_ids (also added it to the compile template)
- Add cross_encoder serving with a cleaner API for the reranking use case
The token_type_ids fix also benefits other sentence-pair tasks like NLI and entailment. If you don't want that change in text_classification, let me know.
Closes #251