JAX for collaborative filtering

This notebook presents a basic first iteration of a music recommender system that uses vanilla JAX to learn latent features from simulated music ratings. Each user's preferences are then described in terms of these learned features.
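The first iteration described above can be sketched as low-rank matrix factorization: every user and every song gets a small vector of latent features, and their dot product predicts the rating. This is a minimal sketch, not the notebook's own code. It assumes a squared-error objective on observed ratings only, L2 regularization, and plain gradient descent. The sizes, hyperparameters, and helper names (`simulate_ratings`, `init_params`, `loss_fn`, `sgd_step`, `recommend`) are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for a toy simulation (not taken from the notebook).
N_USERS, N_SONGS, N_FEATURES = 50, 40, 3


def simulate_ratings(key):
    """Simulate a partially observed users x songs ratings matrix."""
    k_u, k_s, k_noise, k_mask = jax.random.split(key, 4)
    true_users = jax.random.normal(k_u, (N_USERS, N_FEATURES))
    true_songs = jax.random.normal(k_s, (N_SONGS, N_FEATURES))
    ratings = true_users @ true_songs.T
    ratings = ratings + 0.1 * jax.random.normal(k_noise, ratings.shape)
    # Assume only about 30% of user-song pairs have been rated.
    observed = jax.random.bernoulli(k_mask, 0.3, ratings.shape)
    return ratings, observed


def init_params(key):
    """Small random latent features for every user and song."""
    k_u, k_s = jax.random.split(key)
    return {
        "users": 0.1 * jax.random.normal(k_u, (N_USERS, N_FEATURES)),
        "songs": 0.1 * jax.random.normal(k_s, (N_SONGS, N_FEATURES)),
    }


def loss_fn(params, ratings, observed, l2=1e-3):
    """Mean squared error over observed ratings plus L2 regularization."""
    preds = params["users"] @ params["songs"].T
    # Unobserved entries contribute zero error and therefore zero gradient.
    sq_err = jnp.where(observed, (preds - ratings) ** 2, 0.0)
    mse = sq_err.sum() / observed.sum()
    reg = l2 * (jnp.sum(params["users"] ** 2) + jnp.sum(params["songs"] ** 2))
    return mse + reg


@jax.jit
def sgd_step(params, ratings, observed, lr=1.0):
    """One full-batch gradient descent step; returns the pre-step loss."""
    loss, grads = jax.value_and_grad(loss_fn)(params, ratings, observed)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


def recommend(params, observed, user, k=5):
    """Top-k songs for one user, skipping songs they already rated."""
    scores = params["songs"] @ params["users"][user]
    scores = jnp.where(observed[user], -jnp.inf, scores)
    return jnp.argsort(-scores)[:k]


key = jax.random.PRNGKey(0)
data_key, init_key = jax.random.split(key)
ratings, observed = simulate_ratings(data_key)
params = init_params(init_key)

initial_loss = float(loss_fn(params, ratings, observed))
for step in range(500):
    params, loss = sgd_step(params, ratings, observed)
final_loss = float(loss_fn(params, ratings, observed))

# Row 0 of params["users"] is user 0's learned preference vector.
top_songs = recommend(params, observed, user=0)
print(initial_loss, final_loss, top_songs)
```

Masking with `jnp.where` keeps the full matrix product vectorized while letting only rated pairs drive the updates. After training, each row of `params["users"]` is that user's preference vector in the learned feature space, and recommendations are the highest-scoring songs the user has not rated yet.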

In the future, I want to extend it to include features derived from the content itself. The recommendations will then be based on a combination of collaborative filtering and naive Bayes. I also want to explore the broader JAX ecosystem, specifically rlax. With rlax, I want to test how reinforcement learning can adjust recommendations in real time based on recent user activity.