TorchTPU: Running PyTorch Natively on TPUs at Google Scale


TLDR

  • Google’s new TorchTPU stack lets developers run existing PyTorch workloads on TPUs by changing a single initialization line, using XLA and a PrivateUse1 integration.
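
    The summary doesn’t show the actual initialization call, so the sketch below assumes a hypothetical "tpu" device name exposed through the PrivateUse1 mechanism; the point is only that nothing beyond the device target changes.

    ```python
    import torch
    import torch.nn as nn

    # Hypothetical device name: the real TorchTPU init call isn't given in this
    # summary. PrivateUse1 backends expose a named device, so the "one line"
    # change amounts to swapping the device target.
    device = torch.device("tpu")  # previously e.g. torch.device("cuda")

    model = nn.Linear(1024, 1024).to(device)
    x = torch.randn(8, 1024, device=device)
    loss = model(x).sum()
    loss.backward()  # the rest of the training loop is unchanged
    ```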

Key Takeaways

  • Three eager execution modes: Debug (fully synchronous, for step-by-step troubleshooting), Strict (asynchronous, mirrors standard PyTorch eager semantics), and Fused Eager (auto-fuses ops on the fly for a 50-100%+ speedup over Strict with zero user setup).
  • Uses PyTorch’s PrivateUse1 backend interface at a deep level – no tensor subclasses, no wrapper tensors – so existing distributed APIs (DDP, FSDPv2, DTensor) work unchanged; a minimal registration sketch follows this list.
  • XLA replaces Torch Inductor as the compiler backend; PyTorch FX graphs are lowered to StableHLO IR, enabling TPU-optimized binaries with collective overlap that is aware of the inter-chip interconnect (ICI). The FX-to-StableHLO step is sketched after this list.
  • MPMD (multiple-program, multiple-data) support is a direct fix over PyTorch/XLA: divergent per-rank code (e.g. rank-0 logging, illustrated after this list) is handled without forcing a pure SPMD (single-program, multiple-data) model, removing a major porting burden.
  • Hardware note: current TPUs reach peak efficiency at attention head dimensions of 128 or 256, not the 64 commonly hardcoded in model configs – a silent efficiency leak for new adopters (see the head-dim check below).
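
    On the PrivateUse1 point: below is a minimal sketch of the Python-side registration an out-of-tree backend typically does. The names ("my_tpu", the stub device module) are illustrative, not TorchTPU’s; the actual compute kernels live in a compiled extension registered against the PrivateUse1 dispatch key.

    ```python
    import torch

    # Rename the reserved PrivateUse1 dispatch key to a user-visible backend name.
    # "my_tpu" is a stand-in; TorchTPU's real backend name isn't given in the summary.
    torch.utils.rename_privateuse1_backend("my_tpu")

    class _MyTpuDeviceModule:
        """Minimal device-module stub; real backends also expose streams, memory stats, etc."""
        @staticmethod
        def is_available() -> bool:
            return True

        @staticmethod
        def device_count() -> int:
            return 1

    # Mirrors torch.cuda for the custom device (exposed as torch.my_tpu).
    torch._register_device_module("my_tpu", _MyTpuDeviceModule)

    # Autogenerates Tensor.my_tpu(), Tensor.is_my_tpu, etc., so the device looks
    # first-class to user code: no tensor subclasses or wrapper tensors needed.
    torch.utils.generate_methods_for_privateuse1_backend()
    ```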
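
    On the compiler path: TorchTPU’s own lowering entry point isn’t shown here, but the predecessor PyTorch/XLA documents the same FX-to-StableHLO step, which gives a feel for what the new stack does internally. A sketch assuming torch_xla is installed:

    ```python
    import torch
    import torch.nn as nn
    # torch_xla is the predecessor stack; TorchTPU presumably performs this
    # lowering internally rather than exposing it like this.
    from torch_xla.stablehlo import exported_program_to_stablehlo

    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU())
    example_args = (torch.randn(4, 64),)

    exported = torch.export.export(model, example_args)  # FX-based ExportedProgram
    shlo = exported_program_to_stablehlo(exported)        # lower to StableHLO IR
    print(shlo.get_stablehlo_text("forward"))             # readable MLIR for inspection
    ```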
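
    On MPMD: the fragment below is the kind of rank-conditional code the bullet refers to, written with plain torch.distributed and assuming a process group is already initialized (e.g. via torchrun). Under a strict SPMD tracing model every rank must execute the identical program, so a branch like this was a common porting obstacle.

    ```python
    import torch
    import torch.distributed as dist

    def training_step(model, batch, step):
        # Identical math on every rank...
        loss = model(batch).mean()
        loss.backward()

        # ...but divergent control flow: only rank 0 logs and checkpoints.
        # This is the per-rank divergence a pure-SPMD toolchain cannot express.
        if dist.get_rank() == 0 and step % 100 == 0:
            print(f"step {step}: loss = {loss.item():.4f}")
            torch.save(model.state_dict(), "ckpt.pt")
        return loss
    ```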
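
    On the head-dimension note: head_dim is just hidden_size divided by num_heads, so the fix is usually a config change. The numbers below are illustrative; the 128/256 figure is the post’s.

    ```python
    # head_dim = hidden_size // num_heads; many configs hardcode head_dim = 64.
    hidden_size = 4096  # illustrative transformer width

    for num_heads in (64, 32, 16):
        head_dim = hidden_size // num_heads
        tpu_friendly = head_dim in (128, 256)  # per the post's guidance
        status = "matches the TPU-efficient sizes" if tpu_friendly else "likely leaves TPU efficiency on the table"
        print(f"num_heads={num_heads:>2} -> head_dim={head_dim:>3}  ({status})")
    ```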

Hacker News Comment Review

  • The PyTorch/XLA predecessor had a poor reputation in practice: commenters report training runs that hang silently eight hours in, plus undocumented behavior, giving context for why a ground-up rebuild was necessary.
  • Skepticism centers on whether the “change one line” promise holds at 100k-chip scale; commenters treat it as aspirational marketing rather than a confirmed result until it is validated in production.
  • Architecture question raised but not answered in the post: whether TorchTPU is a fork or a first-class PyTorch backend like MPS; the PrivateUse1 path suggests the latter but Google hasn’t clarified the upstreaming plan.

Notable Comments

  • @sergiopreira: frames TorchTPU as Google acknowledging that PyTorch/XLA failed – “It’s hard to run production ML on a toolchain engineers can’t trust.”
  • @in-silico: first-hand account of PyTorch/XLA silent hangs after 8 hours; published a workaround training pipeline at github.com/aklein4/easy-torch-tpu.
  • @yu3zhou4: confirms PrivateUse1 is a real plugin slot used independently for WebGPU backends, not a hack.
