Skip to content

Conversation

@georgeguimaraes
Copy link
Contributor

I was experimenting with rerankers for georgeguimaraes/arcana and found that cross-encoder models like cross-encoder/ms-marco-MiniLM-L-6-v2 weren't producing correct scores.

The issue: text_classification was setting return_token_type_ids: false, which breaks sentence-pair inputs. Cross-encoders need token_type_ids to distinguish query tokens from document tokens. Without them, scores don't match Python's sentence-transformers.

Changes:

  • Fixed text_classification to include token_type_ids (also added it to the compile template)
  • Added a new cross_encoder serving with a cleaner API for the reranking use case
serving = Bumblebee.Text.cross_encoder(model_info, tokenizer)
Nx.Serving.run(serving, {"query", "document"})
#=> %{score: 8.76}

The token_type_ids fix also benefits other sentence-pair tasks like NLI and entailment. If you don't want that change in text_classification, let me know.

Closes #251

Previously, text_classification set return_token_type_ids: false which broke
sentence-pair inputs like query-document pairs used by cross-encoder rerankers.
Now token_type_ids are included, making rerankers produce correct scores.

Closes elixir-nx#251
Copilot AI review requested due to automatic review settings January 11, 2026 13:51
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds cross-encoder support and fixes token_type_ids handling for sentence-pair classification tasks. Cross-encoder models like cross-encoder/ms-marco-MiniLM-L-6-v2 require token_type_ids to distinguish query tokens from document tokens, which were previously disabled in text_classification.

Changes:

  • Fixed text_classification to include token_type_ids for sentence-pair inputs
  • Added new cross_encoder serving with a dedicated API for reranking use cases
  • Added comprehensive test coverage for both changes

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

Show a summary per file
File Description
lib/bumblebee/text/text_classification.ex Removed return_token_type_ids: false configuration and added token_type_ids to compile template
lib/bumblebee/text/cross_encoder.ex New module implementing cross-encoder serving with pair validation and score extraction
lib/bumblebee/text.ex Added cross_encoder function documentation, type specs, and public API delegation
test/bumblebee/text/text_classification_test.exs Added test verifying correct scoring for cross-encoder sentence pairs
test/bumblebee/text/cross_encoder_test.exs New test file with comprehensive coverage for single and batch pair scoring

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Copy link
Member

@jonatanklosko jonatanklosko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just one note about naming and looks good to me!

Comment on lines 423 to 424
defdelegate cross_encoder(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.CrossEncoder
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: let's rename to "cross encoding" to match all the other functions we already have (embeding, classifiction, etc):

Suggested change
defdelegate cross_encoder(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.CrossEncoder
defdelegate cross_encoding(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.CrossEncoding

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@georgeguimaraes georgeguimaraes force-pushed the fix-text-classification-token-type-ids branch from f1cdc18 to 15cb657 Compare January 12, 2026 14:01
@jonatanklosko jonatanklosko merged commit 281abfc into elixir-nx:main Jan 12, 2026
2 checks passed
@georgeguimaraes georgeguimaraes deleted the fix-text-classification-token-type-ids branch January 12, 2026 14:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Cross Encoder support

2 participants