Hmm, a 3% market share framework with barely any ecosystem and single-vendor accelerators (JAX on TPU) vs a 60% market share framework with an insanely rich ecosystem and the ability to debug code on your own workstation (PyTorch on GPU)? In my informed opinion most people should use the latter unless they like wasting time on shiny things.
JAX is used by almost every large genAI player (Anthropic, Cohere, DeepMind, Midjourney, Character.ai, xAI, Apple, etc.). Its actual market share in foundation model development is something like 80%.
Are there any resources going into detail about why the big players prefer JAX? I've heard this before but have never seen explanations of why/how this happened.
It's all about cost and performance. If you can train a foundation model 2x faster with JAX on the same hardware, you are effectively cutting your training costs in half, which is significant for a multimillion-dollar training run.
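To make the arithmetic concrete, here is a rough sketch in Python. It assumes cost scales linearly with accelerator time, and the GPU-hour count and hourly price are made-up numbers for illustration, not figures from any real training run.

```python
def training_cost(gpu_hours, price_per_gpu_hour, speedup=1.0):
    """Idealized training cost: accelerator time multiplied by hourly price.

    `speedup` is how many times faster the run finishes on the same hardware.
    Assumes cost is purely proportional to time (ignores storage, networking,
    engineering effort, and failed runs).
    """
    if speedup <= 0:
        raise ValueError("speedup must be positive")
    return gpu_hours * price_per_gpu_hour / speedup


# Hypothetical run: 100,000 GPU-hours at $2.00 per GPU-hour.
baseline = training_cost(100_000, 2.0)
twice_as_fast = training_cost(100_000, 2.0, speedup=2.0)
print(baseline, twice_as_fast)  # 200000.0 100000.0
```

The same function works in the other direction: a speedup below 1.0 models a framework that turns out slower on your workload, which raises the cost instead.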
The current SOTA models (GPT-4, DALL-E, Sora) were trained on GPUs. The next one (GPT-5) will be, too. And the one after that. Besides, only very few people train models that need more than a few hundred H100s at a time, and PyTorch works well at that scale. And when you train large-scale stuff, the scaling problems are demonstrably surmountable, unlike, say, the capacity problems you will run into if you need a ton of modern TPU quota, because Google itself is pretty compute-starved at the moment. Also, gone are the days when TPUs were significantly faster. GPUs have “TPUs” inside them, too, nowadays.
No, I am saying, with JAX you train on G.P.U., with a G, and your training runs are >2x faster, so your training costs are 2x lower, which matters whether your training spend is $1k or $100M. You're not interested in that? That's ok, but most people are.
Have you actually tried that or are you just regurgitating Google’s marketing? I’ve seen JAX perform _slower_ than PyTorch on practical GPU workloads on the exact same machine, and not by a little, by something like 20%. I too thought I’d be getting great performance and “saving money”, but reality turned out to be a bit more complicated than that: you have to benchmark and tune.
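For anyone who wants to check this on their own workload, a minimal framework-agnostic timing harness might look like the sketch below. The names `benchmark` and `cpu_stand_in` are hypothetical, and the stand-in workload is just a CPU loop so the snippet runs anywhere; swap in one real training step per framework.

```python
import statistics
import time


def benchmark(step, warmup=3, repeats=10):
    """Return the median wall-clock seconds per call of `step`.

    `step` must block until the work has actually finished. On GPU that
    means, for example, calling `.block_until_ready()` on a JAX result or
    `torch.cuda.synchronize()` after a PyTorch step; otherwise you mostly
    time the asynchronous launch, not the computation.
    """
    for _ in range(warmup):
        step()  # warmup runs absorb JIT compilation and cache effects
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        step()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def cpu_stand_in():
    # Hypothetical placeholder workload; replace with one training step.
    return sum(i * i for i in range(100_000))


seconds = benchmark(cpu_stand_in)
print(f"median step time: {seconds * 1e3:.2f} ms")
```

Run it with the same batch size, precision, and machine for both frameworks, and compare medians rather than single runs, since the first few iterations are usually dominated by compilation.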
The parts of your comment that have any truth in them could have been said of PyTorch when it came out. People wasting time on shiny things is how we get better tools.
Nope. When PyTorch came out it was the only option that was easy to use and debug. Your alternative was TF1, which sucked so bad people dropped it like it had syphilis, and Google later had to add eager mode in TF2, ruining performance in the process. I would know, I was one of those people. It really was a watershed moment in AI research productivity.
I've used PyTorch since it came out, and at the time it outperformed TF1 on computer vision workloads quite handily, in addition to being much easier to work with. I literally watched entire departments in one of the major research labs where I worked at the time switch to PyTorch in the span of a couple of months.
Yes, but its performance on GPU leaves much to be desired, and 20 times as much research comes out on PyTorch. Would you rather just build on that or laboriously port and debug the models and their weights, losses, dataset readers, training regimes, etc.?