
Alternatively, do you want faster training runs (and thus lower training costs)? Then JAX is a good choice for you.


The current SOTA models (GPT-4, DALL-E, Sora) were trained on GPUs. The next one (GPT-5) will be, too, and so will the one after that. Besides, very few people train models that need more than a few hundred H100s at a time, and PyTorch works well at that scale. And when you do train at large scale, the scaling problems are demonstrably surmountable, unlike, say, the capacity problems you'll run into if you need a lot of modern TPU quota, because Google itself is pretty compute-starved at the moment. Also, gone are the days when TPUs were significantly faster: modern GPUs have "TPUs" inside them too, in the form of tensor cores.


No, I am saying that with JAX you train on GPUs, with a G, and your training runs are >2x faster, so your training costs are cut by more than half, which matters whether your training spend is $1k or $100M. You're not interested in that? That's ok, but most people are.
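To make that concrete, here's a minimal sketch of the kind of training step I mean (the toy linear model, the parameter names, and the learning rate are all illustrative, not from any real workload). With a CUDA build of jax installed, this runs on the GPU with no code changes; the speed comes from XLA compiling the whole step into a handful of fused kernels:

    import jax
    import jax.numpy as jnp

    def loss_fn(params, x, y):
        # Toy linear-regression loss; stands in for a real model.
        pred = x @ params["w"] + params["b"]
        return jnp.mean((pred - y) ** 2)

    @jax.jit  # compile the whole update (forward, backward, SGD) with XLA
    def train_step(params, x, y, lr=1e-2):
        grads = jax.grad(loss_fn)(params, x, y)
        return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

    key = jax.random.PRNGKey(0)
    kx, ky = jax.random.split(key)
    x = jax.random.normal(kx, (1024, 64))
    y = jax.random.normal(ky, (1024,))
    params = {"w": jnp.zeros((64,)), "b": jnp.zeros(())}
    for _ in range(100):
        params = train_step(params, x, y)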


Have you actually tried that, or are you just regurgitating Google's marketing? I've seen JAX perform _slower_ than PyTorch on practical GPU workloads on the exact same machine, and not by a little: by something like 20%. I too thought I'd be getting great performance and "saving money", but reality turned out to be more complicated than that: you have to benchmark and tune.
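For what it's worth, here's roughly how you have to time these things (shapes and iteration counts here are illustrative). JAX dispatches work asynchronously, so wrapping a call in time.time() measures dispatch, not compute; you need block_until_ready(), plus a warm-up call so you don't count compilation. The PyTorch side of the comparison needs the analogous torch.cuda.synchronize():

    import time
    import jax
    import jax.numpy as jnp

    @jax.jit
    def matmul(a, b):
        return a @ b

    a = jnp.ones((4096, 4096))
    b = jnp.ones((4096, 4096))

    matmul(a, b).block_until_ready()  # warm-up: exclude compile time

    start = time.perf_counter()
    for _ in range(10):
        out = matmul(a, b)
    out.block_until_ready()  # wait until the GPU actually finishes
    print(f"{(time.perf_counter() - start) / 10 * 1e3:.2f} ms per matmul")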



