Text classification with embeddings

I sometimes see people using an LLM to do basic text classification: for example, determining whether a piece of text has positive or negative sentiment, or classifying the type of user feedback.

While this can be a powerful approach that is easy to implement, it has the drawbacks of being slow, expensive, and possibly hard to control or explain. For example, it is pretty straightforward to submit a prompt that says "Tell me if the text below represents a complaint or a suggestion ...". The LLM may understand exactly what you mean by that concept, but it may also have a slightly different definition that can take some time to reshape. If that is the case, you may want to keep working on that approach, but I suggest you first consider a quick experiment using embeddings and n-nearest neighbors as a baseline.

An embedding can be thought of as a point in a multi-dimensional space, typically with a few hundred to a few thousand dimensions. How embeddings are created is beyond the scope of this piece, but there are functions that take text and return such a vector. The embeddings of similar concepts end up near each other in this space. There are many different kinds of embedding functions available. Some you can run on your own, such as Hugging Face's sentence transformers, and some, such as OpenAI's ada-002, are available through an easy-to-use API.
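To make that concrete, here is a minimal sketch of creating embeddings locally with the sentence-transformers library; the model name and example texts are just illustrative choices, not a recommendation.

```python
# Minimal sketch: turn a few texts into embedding vectors with
# sentence-transformers. "all-MiniLM-L6-v2" is one commonly used small model.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

texts = [
    "The app crashes every time I open the settings page.",
    "It would be great if you added a dark mode.",
]

# encode() returns one fixed-length vector per input text (a NumPy array).
embeddings = model.encode(texts)
print(embeddings.shape)  # e.g. (2, 384) for this model
```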

The idea behind the n-nearest-neighbor approach is to take some texts for which you have labels appropriate to your use case (your training data), create their embeddings, and save them for later use. The more examples you have the better, but you may find you don't need that many. Then, when classifying a new piece of text, embed that text and find the n (say 5 or 9) closest known embeddings. Inspect their labels and choose the (weighted) most popular one to assign to the incoming text.
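Here is a rough sketch of that lookup-and-vote step, assuming the labeled training embeddings are already in a NumPy array with a parallel list of labels (the names and the value of n are placeholders):

```python
# Sketch of n-nearest-neighbor classification over stored embeddings,
# using cosine similarity and a similarity-weighted vote.
from collections import defaultdict
import numpy as np

def classify(new_embedding, train_embeddings, train_labels, n=5):
    # Cosine similarity between the new text and every labeled example.
    a = train_embeddings / np.linalg.norm(train_embeddings, axis=1, keepdims=True)
    b = new_embedding / np.linalg.norm(new_embedding)
    sims = a @ b

    # Take the n closest neighbors and tally their labels, weighted by similarity.
    nearest = np.argsort(sims)[::-1][:n]
    votes = defaultdict(float)
    for i in nearest:
        votes[train_labels[i]] += sims[i]
    return max(votes, key=votes.get)
```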

Finding the closest ones is a simple cosine-similarity or dot-product calculation, and you may even be able to do it in your database. For example, Postgres has the pgvector extension, which will do that for you.
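If your labeled embeddings live in Postgres, a sketch of the same lookup with pgvector might look like the following; the connection string, table, and column names are hypothetical placeholders, and `<=>` is pgvector's cosine-distance operator.

```python
# Rough sketch: push the nearest-neighbor lookup into Postgres via pgvector.
import psycopg2

conn = psycopg2.connect("dbname=mydb")  # assumed connection details

# new_embedding is the query vector from the earlier sketch; pgvector accepts
# a "[x1,x2,...]" string literal cast to its vector type.
vec_literal = "[" + ",".join(str(x) for x in new_embedding) + "]"

with conn.cursor() as cur:
    cur.execute(
        "SELECT label FROM labeled_examples "
        "ORDER BY embedding <=> %s::vector "  # <=> is cosine distance
        "LIMIT 5",
        (vec_literal,),
    )
    neighbor_labels = [row[0] for row in cur.fetchall()]
```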

Odds are that this easy-to-implement, fast, low-resource approach will work well, and it can even be updated on the fly (just create new labeled embeddings whenever you get them). Even if you can improve performance with a fancier technique, this is likely to be a good baseline to compare other approaches against.

It has worked well for me on several projects. Let me know how it works for you or if you’d like to see some code.
