From 99e9be48783b4fcac0fe81bfe47956c43dde1af0 Mon Sep 17 00:00:00 2001 From: Jo Bovy Date: Thu, 4 Jun 2026 13:18:46 -0400 Subject: [PATCH 1/3] P2.3 ellipsoidal: namespace-swap to backend-agnostic compute Co-Authored-By: Claude Opus 4.8 (1M context) --- galpy/potential/EllipsoidalPotential.py | 340 +++++++++++-------- galpy/potential/TriaxialGaussianPotential.py | 10 +- galpy/potential/TwoPowerTriaxialPotential.py | 4 +- tests/test_backend_ellipsoidal.py | 257 ++++++++++++++ 4 files changed, 457 insertions(+), 154 deletions(-) create mode 100644 tests/test_backend_ellipsoidal.py diff --git a/galpy/potential/EllipsoidalPotential.py b/galpy/potential/EllipsoidalPotential.py index 3179b6997..76f24005e 100644 --- a/galpy/potential/EllipsoidalPotential.py +++ b/galpy/potential/EllipsoidalPotential.py @@ -9,11 +9,13 @@ # ############################################################################### import hashlib +import math import numpy from scipy import integrate -from ..util import _rotate_to_arbitrary_vector, conversion, coords +from ..backend import get_namespace +from ..util import _rotate_to_arbitrary_vector, conversion from .Potential import Potential @@ -132,228 +134,263 @@ def _setup_gl(self, glorder): self._glw *= 0.5 return None + def _rotate_to_aligned(self, x, y, z, xp): + """Rotate (x,y,z) into the aligned density frame using ``self._rot``. 
+ + Written as explicit (constant) component products so that it is pure + arithmetic and backend-agnostic (numpy values unchanged from the former + ``numpy.dot(self._rot, numpy.array([x, y, z]))``).""" + rot = self._rot + xp_ = rot[0, 0] * x + rot[0, 1] * y + rot[0, 2] * z + yp_ = rot[1, 0] * x + rot[1, 1] * y + rot[1, 2] * z + zp_ = rot[2, 0] * x + rot[2, 1] * y + rot[2, 2] * z + return xp_, yp_, zp_ + + def _rotate_force_back(self, Fx, Fy, Fz, xp): + """Rotate a force triple from the aligned frame back to the data frame + using ``self._rot.T`` (pure arithmetic, backend-agnostic).""" + rot = self._rot + Fx_ = rot[0, 0] * Fx + rot[1, 0] * Fy + rot[2, 0] * Fz + Fy_ = rot[0, 1] * Fx + rot[1, 1] * Fy + rot[2, 1] * Fz + Fz_ = rot[0, 2] * Fx + rot[1, 2] * Fy + rot[2, 2] * Fz + return Fx_, Fy_, Fz_ + def _evaluate(self, R, z, phi=0.0, t=0.0): + xp = get_namespace(R, z) if not self.isNonAxi: phi = 0.0 - x, y, z = coords.cyl_to_rect(R, phi, z) - if numpy.ndim(R) == 0: - if numpy.isinf(R): - y = 0.0 - else: - y = numpy.where(numpy.isinf(R), 0.0, y) + x, y = R * xp.cos(phi), R * xp.sin(phi) + # When R is infinite, y = R*sin(phi) is inf/nan; force it to 0 (the + # potential along the axis). Shape-polymorphic so it works on scalars + # and arrays in every backend. 
+ y = xp.where(xp.isinf(R), 0.0, y) if self._aligned: - return self._evaluate_xyz(x, y, z) + return self._evaluate_xyz(x, y, z, xp) else: - xyzp = numpy.dot(self._rot, numpy.array([x, y, z])) - return self._evaluate_xyz(xyzp[0], xyzp[1], xyzp[2]) + xp_, yp_, zp_ = self._rotate_to_aligned(x, y, z, xp) + return self._evaluate_xyz(xp_, yp_, zp_, xp) - def _evaluate_xyz(self, x, y, z): + def _evaluate_xyz(self, x, y, z, xp=None): """Evaluation of the potential as a function of (x,y,z) in the aligned coordinate frame""" + if xp is None: + xp = get_namespace(x, y, z) return ( 2.0 - * numpy.pi + * math.pi * self._b * self._c * _potInt( - x, y, z, self._psi, self._b2, self._c2, glx=self._glx, glw=self._glw + x, y, z, self._psi, self._b2, self._c2, xp, glx=self._glx, glw=self._glw ) ) - def _compute_forces(self, x, y, z): - """Compute and cache all three force components in the aligned frame""" - new_hash = hashlib.md5(numpy.array([x, y, z])).hexdigest() - if new_hash != self._force_hash: - if self._aligned: - xp, yp, zp = x, y, z - else: - xyzp = numpy.dot(self._rot, numpy.array([x, y, z])) - xp, yp, zp = xyzp[0], xyzp[1], xyzp[2] - prefac = -4.0 * numpy.pi * self._b * self._c - Fx, Fy, Fz = _forceInt_all( - xp, - yp, - zp, - lambda m: self._mdens(m), - self._b2, - self._c2, - glx=self._glx, - glw=self._glw, - ) - self._cached_Fx = prefac * Fx - self._cached_Fy = prefac * Fy - self._cached_Fz = prefac * Fz + def _compute_forces(self, x, y, z, xp): + """Compute all three force components in the aligned frame. + + Returns ``(Fx, Fy, Fz)`` already rotated back into the data frame. The + shared force integral is computed as a local; for the numpy backend the + result is also stored in a per-instance cache (keyed on the input hash) + so the three public force methods evaluated at the same point reuse a + single quadrature, exactly as before. 
The traced (jax/torch) path never + touches ``self``-state.""" + if xp is numpy: + new_hash = hashlib.md5(numpy.array([x, y, z])).hexdigest() + if new_hash == self._force_hash: + return self._cached_Fx, self._cached_Fy, self._cached_Fz + if self._aligned: + xa, ya, za = x, y, z + else: + xa, ya, za = self._rotate_to_aligned(x, y, z, xp) + prefac = -4.0 * math.pi * self._b * self._c + Fx, Fy, Fz = _forceInt_all( + xa, + ya, + za, + lambda m: self._mdens(m), + self._b2, + self._c2, + xp, + glx=self._glx, + glw=self._glw, + ) + Fx = prefac * Fx + Fy = prefac * Fy + Fz = prefac * Fz + if not self._aligned: + Fx, Fy, Fz = self._rotate_force_back(Fx, Fy, Fz, xp) + if xp is numpy: + self._cached_Fx, self._cached_Fy, self._cached_Fz = Fx, Fy, Fz self._force_hash = new_hash + return Fx, Fy, Fz def _Rforce(self, R, z, phi=0.0, t=0.0): + xp = get_namespace(R, z) if not self.isNonAxi: phi = 0.0 - x, y, z = coords.cyl_to_rect(R, phi, z) - self._compute_forces(x, y, z) - Fx = self._cached_Fx - Fy = self._cached_Fy - if not self._aligned: - Fxyz = numpy.dot(self._rot.T, numpy.array([Fx, Fy, self._cached_Fz])) - Fx, Fy = Fxyz[0], Fxyz[1] - return numpy.cos(phi) * Fx + numpy.sin(phi) * Fy + x, y = R * xp.cos(phi), R * xp.sin(phi) + Fx, Fy, _ = self._compute_forces(x, y, z, xp) + return xp.cos(phi) * Fx + xp.sin(phi) * Fy def _phitorque(self, R, z, phi=0.0, t=0.0): + xp = get_namespace(R, z) if not self.isNonAxi: phi = 0.0 - x, y, z = coords.cyl_to_rect(R, phi, z) - self._compute_forces(x, y, z) - Fx = self._cached_Fx - Fy = self._cached_Fy - if not self._aligned: - Fxyz = numpy.dot(self._rot.T, numpy.array([Fx, Fy, self._cached_Fz])) - Fx, Fy = Fxyz[0], Fxyz[1] - return R * (-numpy.sin(phi) * Fx + numpy.cos(phi) * Fy) + x, y = R * xp.cos(phi), R * xp.sin(phi) + Fx, Fy, _ = self._compute_forces(x, y, z, xp) + return R * (-xp.sin(phi) * Fx + xp.cos(phi) * Fy) def _zforce(self, R, z, phi=0.0, t=0.0): + xp = get_namespace(R, z) if not self.isNonAxi: phi = 0.0 - x, y, z = 
coords.cyl_to_rect(R, phi, z) - self._compute_forces(x, y, z) - Fz = self._cached_Fz - if not self._aligned: - Fxyz = numpy.dot( - self._rot.T, - numpy.array([self._cached_Fx, self._cached_Fy, Fz]), - ) - Fz = Fxyz[2] + x, y = R * xp.cos(phi), R * xp.sin(phi) + _, _, Fz = self._compute_forces(x, y, z, xp) return Fz - def _compute_2ndderivs(self, x, y, z): - """Compute and cache all six unique 2nd-derivative components in the - aligned frame""" - new_hash = hashlib.md5(numpy.array([x, y, z])).hexdigest() - if new_hash != self._2ndderiv_hash: - prefac = 4.0 * numpy.pi * self._b * self._c - xx, xy, xz, yy, yz, zz = _2ndDerivInt_all( - x, - y, - z, - lambda m: self._mdens(m), - lambda m: self._mdens_deriv(m), - self._b2, - self._c2, - glx=self._glx, - glw=self._glw, - ) - self._cached_2nd_xx = prefac * xx - self._cached_2nd_xy = prefac * xy - self._cached_2nd_xz = prefac * xz - self._cached_2nd_yy = prefac * yy - self._cached_2nd_yz = prefac * yz - self._cached_2nd_zz = prefac * zz + def _compute_2ndderivs(self, x, y, z, xp): + """Compute all six unique 2nd-derivative components in the aligned frame. + + Returns ``(xx, xy, xz, yy, yz, zz)``. The shared quadrature is computed + as a local; for the numpy backend the result is also cached on the + instance (keyed on the input hash) so methods sharing a point reuse it. + The traced (jax/torch) path never touches ``self``-state. 
Only used for + the aligned case (the public methods raise for rotated frames).""" + if xp is numpy: + new_hash = hashlib.md5(numpy.array([x, y, z])).hexdigest() + if new_hash == self._2ndderiv_hash: + return ( + self._cached_2nd_xx, + self._cached_2nd_xy, + self._cached_2nd_xz, + self._cached_2nd_yy, + self._cached_2nd_yz, + self._cached_2nd_zz, + ) + prefac = 4.0 * math.pi * self._b * self._c + xx, xy, xz, yy, yz, zz = _2ndDerivInt_all( + x, + y, + z, + lambda m: self._mdens(m), + lambda m: self._mdens_deriv(m), + self._b2, + self._c2, + xp, + glx=self._glx, + glw=self._glw, + ) + xx = prefac * xx + xy = prefac * xy + xz = prefac * xz + yy = prefac * yy + yz = prefac * yz + zz = prefac * zz + if xp is numpy: + self._cached_2nd_xx = xx + self._cached_2nd_xy = xy + self._cached_2nd_xz = xz + self._cached_2nd_yy = yy + self._cached_2nd_yz = yz + self._cached_2nd_zz = zz self._2ndderiv_hash = new_hash + return xx, xy, xz, yy, yz, zz def _R2deriv(self, R, z, phi=0.0, t=0.0): + xp = get_namespace(R, z) if not self.isNonAxi: phi = 0.0 - x, y, z = coords.cyl_to_rect(R, phi, z) + x, y = R * xp.cos(phi), R * xp.sin(phi) if not self._aligned: raise NotImplementedError( "2nd potential derivatives of TwoPowerTriaxialPotential not implemented for rotated coordinated frames (non-trivial zvec and pa); use RotateAndTiltWrapperPotential for this functionality instead" ) - self._compute_2ndderivs(x, y, z) - phixx = self._cached_2nd_xx - phixy = self._cached_2nd_xy - phiyy = self._cached_2nd_yy + phixx, phixy, _, phiyy, _, _ = self._compute_2ndderivs(x, y, z, xp) return ( - numpy.cos(phi) ** 2.0 * phixx - + numpy.sin(phi) ** 2.0 * phiyy - + 2.0 * numpy.cos(phi) * numpy.sin(phi) * phixy + xp.cos(phi) ** 2.0 * phixx + + xp.sin(phi) ** 2.0 * phiyy + + 2.0 * xp.cos(phi) * xp.sin(phi) * phixy ) def _Rzderiv(self, R, z, phi=0.0, t=0.0): + xp = get_namespace(R, z) if not self.isNonAxi: phi = 0.0 - x, y, z = coords.cyl_to_rect(R, phi, z) + x, y = R * xp.cos(phi), R * xp.sin(phi) if not 
self._aligned: raise NotImplementedError( "2nd potential derivatives of TwoPowerTriaxialPotential not implemented for rotated coordinated frames (non-trivial zvec and pa); use RotateAndTiltWrapperPotential for this functionality instead" ) - self._compute_2ndderivs(x, y, z) - phixz = self._cached_2nd_xz - phiyz = self._cached_2nd_yz - return numpy.cos(phi) * phixz + numpy.sin(phi) * phiyz + _, _, phixz, _, phiyz, _ = self._compute_2ndderivs(x, y, z, xp) + return xp.cos(phi) * phixz + xp.sin(phi) * phiyz def _z2deriv(self, R, z, phi=0.0, t=0.0): + xp = get_namespace(R, z) if not self.isNonAxi: phi = 0.0 - x, y, z = coords.cyl_to_rect(R, phi, z) + x, y = R * xp.cos(phi), R * xp.sin(phi) if not self._aligned: raise NotImplementedError( "2nd potential derivatives of TwoPowerTriaxialPotential not implemented for rotated coordinated frames (non-trivial zvec and pa); use RotateAndTiltWrapperPotential for this functionality instead" ) - self._compute_2ndderivs(x, y, z) - return self._cached_2nd_zz + _, _, _, _, _, phizz = self._compute_2ndderivs(x, y, z, xp) + return phizz def _phi2deriv(self, R, z, phi=0.0, t=0.0): + xp = get_namespace(R, z) if not self.isNonAxi: phi = 0.0 - x, y, z = coords.cyl_to_rect(R, phi, z) + x, y = R * xp.cos(phi), R * xp.sin(phi) if not self._aligned: raise NotImplementedError( "2nd potential derivatives of TwoPowerTriaxialPotential not implemented for rotated coordinated frames (non-trivial zvec and pa); use RotateAndTiltWrapperPotential for this functionality instead" ) - self._compute_forces(x, y, z) - Fx = self._cached_Fx - Fy = self._cached_Fy - self._compute_2ndderivs(x, y, z) - phixx = self._cached_2nd_xx - phixy = self._cached_2nd_xy - phiyy = self._cached_2nd_yy + Fx, Fy, _ = self._compute_forces(x, y, z, xp) + phixx, phixy, _, phiyy, _, _ = self._compute_2ndderivs(x, y, z, xp) return R**2.0 * ( - numpy.sin(phi) ** 2.0 * phixx - + numpy.cos(phi) ** 2.0 * phiyy - - 2.0 * numpy.cos(phi) * numpy.sin(phi) * phixy - ) + R * (numpy.cos(phi) * 
Fx + numpy.sin(phi) * Fy) + xp.sin(phi) ** 2.0 * phixx + + xp.cos(phi) ** 2.0 * phiyy + - 2.0 * xp.cos(phi) * xp.sin(phi) * phixy + ) + R * (xp.cos(phi) * Fx + xp.sin(phi) * Fy) def _Rphideriv(self, R, z, phi=0.0, t=0.0): + xp = get_namespace(R, z) if not self.isNonAxi: phi = 0.0 - x, y, z = coords.cyl_to_rect(R, phi, z) + x, y = R * xp.cos(phi), R * xp.sin(phi) if not self._aligned: raise NotImplementedError( "2nd potential derivatives of TwoPowerTriaxialPotential not implemented for rotated coordinated frames (non-trivial zvec and pa); use RotateAndTiltWrapperPotential for this functionality instead" ) - self._compute_forces(x, y, z) - Fx = self._cached_Fx - Fy = self._cached_Fy - self._compute_2ndderivs(x, y, z) - phixx = self._cached_2nd_xx - phixy = self._cached_2nd_xy - phiyy = self._cached_2nd_yy + Fx, Fy, _ = self._compute_forces(x, y, z, xp) + phixx, phixy, _, phiyy, _, _ = self._compute_2ndderivs(x, y, z, xp) return ( - R * numpy.cos(phi) * numpy.sin(phi) * (phiyy - phixx) - + R * numpy.cos(2.0 * phi) * phixy - + numpy.sin(phi) * Fx - - numpy.cos(phi) * Fy + R * xp.cos(phi) * xp.sin(phi) * (phiyy - phixx) + + R * xp.cos(2.0 * phi) * phixy + + xp.sin(phi) * Fx + - xp.cos(phi) * Fy ) def _phizderiv(self, R, z, phi=0.0, t=0.0): + xp = get_namespace(R, z) if not self.isNonAxi: phi = 0.0 - x, y, z = coords.cyl_to_rect(R, phi, z) + x, y = R * xp.cos(phi), R * xp.sin(phi) if not self._aligned: raise NotImplementedError( "2nd potential derivatives of TwoPowerTriaxialPotential not implemented for rotated coordinated frames (non-trivial zvec and pa); use RotateAndTiltWrapperPotential for this functionality instead" ) - self._compute_2ndderivs(x, y, z) - phixz = self._cached_2nd_xz - phiyz = self._cached_2nd_yz - return R * (numpy.cos(phi) * phiyz - numpy.sin(phi) * phixz) + _, _, phixz, _, phiyz, _ = self._compute_2ndderivs(x, y, z, xp) + return R * (xp.cos(phi) * phiyz - xp.sin(phi) * phixz) def _dens(self, R, z, phi=0.0, t=0.0): - x, y, z = coords.cyl_to_rect(R, 
phi, z) + xp = get_namespace(R, z) + x, y = R * xp.cos(phi), R * xp.sin(phi) if self._aligned: - xp, yp, zp = x, y, z + xa, ya, za = x, y, z else: - xyzp = numpy.dot(self._rot, numpy.array([x, y, z])) - xp, yp, zp = xyzp[0], xyzp[1], xyzp[2] - m = numpy.sqrt(xp**2.0 + yp**2.0 / self._b2 + zp**2.0 / self._c2) + xa, ya, za = self._rotate_to_aligned(x, y, z, xp) + m = xp.sqrt(xa**2.0 + ya**2.0 / self._b2 + za**2.0 / self._c2) return self._mdens(m) def _mass(self, R, z=None, t=0.0): @@ -371,7 +408,7 @@ def OmegaP(self): return 0.0 -def _potInt(x, y, z, psi, b2, c2, glx=None, glw=None): +def _potInt(x, y, z, psi, b2, c2, xp=numpy, glx=None, glw=None): r"""int_0^\infty [psi(m)-psi(\infy)]/sqrt([1+tau]x[b^2+tau]x[c^2+tau])dtau""" def integrand(s): @@ -381,6 +418,7 @@ def integrand(s): ) / numpy.sqrt((1.0 + (b2 - 1.0) * s**2.0) * (1.0 + (c2 - 1.0) * s**2.0)) if glx is None: + # scipy.integrate fallback (glorder=None): numpy-only, deferred to Pspecial return integrate.quad(integrand, 0.0, 1.0)[0] result = 0.0 x2 = x**2 @@ -390,8 +428,8 @@ def integrand(s): s = glx[k] t = 1.0 / s**2 - 1.0 denom = numpy.sqrt((1.0 + (b2 - 1.0) * s**2) * (1.0 + (c2 - 1.0) * s**2)) - m = numpy.sqrt(x2 / (1.0 + t) + y2 / (b2 + t) + z2 / (c2 + t)) - result += glw[k] * psi(m) / denom + m = xp.sqrt(x2 / (1.0 + t) + y2 / (b2 + t) + z2 / (c2 + t)) + result = result + glw[k] * psi(m) / denom return result @@ -414,9 +452,10 @@ def integrand(s): return integrate.quad(integrand, 0.0, 1.0)[0] -def _forceInt_all(x, y, z, dens, b2, c2, glx=None, glw=None): +def _forceInt_all(x, y, z, dens, b2, c2, xp=numpy, glx=None, glw=None): """Compute all three force integral components in a single pass.""" if glx is None: + # scipy.integrate fallback (glorder=None): numpy-only, deferred to Pspecial return ( _forceInt(x, y, z, dens, b2, c2, 0), _forceInt(x, y, z, dens, b2, c2, 1), @@ -433,12 +472,12 @@ def _forceInt_all(x, y, z, dens, b2, c2, glx=None, glw=None): inv1t = 1.0 / (1.0 + t) invb2t = 1.0 / (b2 + t) invc2t = 
1.0 / (c2 + t) - m = numpy.sqrt(x2 * inv1t + y2 * invb2t + z2 * invc2t) + m = xp.sqrt(x2 * inv1t + y2 * invb2t + z2 * invc2t) denom = numpy.sqrt((1.0 + (b2 - 1.0) * s**2) * (1.0 + (c2 - 1.0) * s**2)) common = w * dens(m) / denom - Fx += common * x * inv1t - Fy += common * y * invb2t - Fz += common * z * invc2t + Fx = Fx + common * x * inv1t + Fy = Fy + common * y * invb2t + Fz = Fz + common * z * invc2t return Fx, Fy, Fz @@ -474,10 +513,11 @@ def integrand(s): return integrate.quad(integrand, 0.0, 1.0)[0] -def _2ndDerivInt_all(x, y, z, dens, densDeriv, b2, c2, glx=None, glw=None): +def _2ndDerivInt_all(x, y, z, dens, densDeriv, b2, c2, xp=numpy, glx=None, glw=None): """Compute all six unique 2nd derivative integrals in a single pass. Returns (xx, xy, xz, yy, yz, zz).""" if glx is None: + # scipy.integrate fallback (glorder=None): numpy-only, deferred to Pspecial return ( _2ndDerivInt(x, y, z, dens, densDeriv, b2, c2, 0, 0), _2ndDerivInt(x, y, z, dens, densDeriv, b2, c2, 0, 1), @@ -497,7 +537,7 @@ def _2ndDerivInt_all(x, y, z, dens, densDeriv, b2, c2, glx=None, glw=None): inv1t = 1.0 / (1.0 + t) invb2t = 1.0 / (b2 + t) invc2t = 1.0 / (c2 + t) - m = numpy.sqrt(x2 * inv1t + y2 * invb2t + z2 * invc2t) + m = xp.sqrt(x2 * inv1t + y2 * invb2t + z2 * invc2t) denom = numpy.sqrt((1.0 + (b2 - 1.0) * s**2) * (1.0 + (c2 - 1.0) * s**2)) w_over_denom = w / denom dens_val = dens(m) @@ -508,10 +548,10 @@ def _2ndDerivInt_all(x, y, z, dens, densDeriv, b2, c2, glx=None, glw=None): dd_xi = w_over_denom * dderiv_over_m * xi dd_yi = w_over_denom * dderiv_over_m * yi dd_zi = w_over_denom * dderiv_over_m * zi - xx += dd_xi * xi + w_over_denom * dens_val * inv1t - xy += dd_xi * yi - xz += dd_xi * zi - yy += dd_yi * yi + w_over_denom * dens_val * invb2t - yz += dd_yi * zi - zz += dd_zi * zi + w_over_denom * dens_val * invc2t + xx = xx + dd_xi * xi + w_over_denom * dens_val * inv1t + xy = xy + dd_xi * yi + xz = xz + dd_xi * zi + yy = yy + dd_yi * yi + w_over_denom * dens_val * invb2t + yz = 
yz + dd_yi * zi + zz = zz + dd_zi * zi + w_over_denom * dens_val * invc2t return xx, xy, xz, yy, yz, zz diff --git a/galpy/potential/TriaxialGaussianPotential.py b/galpy/potential/TriaxialGaussianPotential.py index e31e0d0e8..e41fa1e37 100644 --- a/galpy/potential/TriaxialGaussianPotential.py +++ b/galpy/potential/TriaxialGaussianPotential.py @@ -10,6 +10,7 @@ import numpy from scipy import special +from ..backend import get_namespace from ..util import conversion from .EllipsoidalPotential import EllipsoidalPotential @@ -97,15 +98,18 @@ def __init__( def _psi(self, m): """\\psi(m) = -\\int_m^\\infty d m^2 \rho(m^2)""" - return -self._twosigma2 * numpy.exp(-(m**2.0) / self._twosigma2) + xp = get_namespace(m) + return -self._twosigma2 * xp.exp(-(m**2.0) / self._twosigma2) def _mdens(self, m): """Density as a function of m""" - return numpy.exp(-(m**2) / self._twosigma2) + xp = get_namespace(m) + return xp.exp(-(m**2) / self._twosigma2) def _mdens_deriv(self, m): """Derivative of the density as a function of m""" - return -2.0 * m * numpy.exp(-(m**2) / self._twosigma2) / self._twosigma2 + xp = get_namespace(m) + return -2.0 * m * xp.exp(-(m**2) / self._twosigma2) / self._twosigma2 def _mass(self, R, z=None, t=0.0): if not z is None: diff --git a/galpy/potential/TwoPowerTriaxialPotential.py b/galpy/potential/TwoPowerTriaxialPotential.py index 5646be327..3e2882fa4 100644 --- a/galpy/potential/TwoPowerTriaxialPotential.py +++ b/galpy/potential/TwoPowerTriaxialPotential.py @@ -13,6 +13,7 @@ import numpy from scipy import special +from ..backend import get_namespace from ..util import conversion from .EllipsoidalPotential import EllipsoidalPotential @@ -395,10 +396,11 @@ def __init__( def _psi(self, m): """\\psi(m) = -\\int_m^\\infty d m^2 \rho(m^2)""" + xp = get_namespace(m) return ( 2.0 * self.a2 - * (1.0 / (1.0 + m / self.a) + numpy.log(1.0 / (1.0 + self.a / m))) + * (1.0 / (1.0 + m / self.a) + xp.log(1.0 / (1.0 + self.a / m))) ) def _mdens(self, m): diff --git 
a/tests/test_backend_ellipsoidal.py b/tests/test_backend_ellipsoidal.py new file mode 100644 index 000000000..fff86a502 --- /dev/null +++ b/tests/test_backend_ellipsoidal.py @@ -0,0 +1,257 @@ +############################################################################### +# test_backend_ellipsoidal.py: per-family backend tests for the ellipsoidal / +# triaxial potentials (EllipsoidalPotential base + subclasses). +# +# Proves, for the migrated compute methods of each potential: +# 1. numpy / jax / torch produce identical values (rtol=1e-12, atol=1e-14), +# 2. autodiff (jax.grad / torch.autograd) on a migrated scalar potential +# (_evaluate or, where _evaluate is deferred, _Rforce) matches central +# finite differences (rtol=1e-5), +# 3. the per-instance numpy quadrature cache is never touched by the traced +# (jax/torch) path (so autodiff is pure and reentrant). +# +# Scope notes (see the module docstrings / the PR's deferred list): +# * The Gauss-Legendre quadrature path (glorder set, the default) is migrated; +# the scipy.integrate fallback (glorder=None) is deferred (Pspecial PR). +# * TwoPowerTriaxialPotential._evaluate uses scipy.special.hyp2f1 in _psi and +# is therefore NOT migrated (its forces/2nd-derivs/dens, which only use the +# pure-arithmetic _mdens, ARE migrated). _mass everywhere is out of scope. +# +# Backends that are not installed self-skip, so this is green on numpy alone. +############################################################################### +import numpy +import pytest + +from galpy.potential import ( + PerfectEllipsoidPotential, + PowerTriaxialPotential, + TriaxialGaussianPotential, + TriaxialHernquistPotential, + TriaxialJaffePotential, + TriaxialNFWPotential, + TwoPowerTriaxialPotential, +) + +# This module manages backends explicitly (parametrizes over them), so it is +# exempt from the global --backend force fixture. 
+pytestmark = pytest.mark.backend_managed + +# Discover available backends +BACKENDS = ["numpy"] +try: + import jax + + jax.config.update("jax_enable_x64", True) + import jax.numpy as jnp + + BACKENDS.append("jax") +except ImportError: # pragma: no cover + jax = None +try: + import torch + + BACKENDS.append("torch") +except ImportError: # pragma: no cover + torch = None + +AD_BACKENDS = [b for b in BACKENDS if b != "numpy"] + +# Compute methods migrated for every potential in this family (forces, 2nd +# derivatives, density). All use only the pure-arithmetic _mdens/_mdens_deriv. +COMMON_METHODS = [ + "_Rforce", + "_zforce", + "_phitorque", + "_R2deriv", + "_z2deriv", + "_Rzderiv", + "_phi2deriv", + "_Rphideriv", + "_phizderiv", + "_dens", +] +# _evaluate is migrated for every subclass whose _psi is namespace-clean; it is +# deferred for TwoPowerTriaxialPotential (psi uses scipy.special.hyp2f1). +EVAL = ["_evaluate"] + +# (name, instance, methods); aligned (default) instances exercise the migrated +# Gauss-Legendre quadrature path. +_CASES = [ + ("Perfect", PerfectEllipsoidPotential(amp=1.3, a=1.5, b=0.9, c=0.7), EVAL), + ("Gauss", TriaxialGaussianPotential(amp=1.3, sigma=1.5, b=0.9, c=0.7), EVAL), + ("Power", PowerTriaxialPotential(amp=1.3, alpha=1.2, b=0.9, c=0.7), EVAL), + ("Hernquist", TriaxialHernquistPotential(amp=1.3, a=1.5, b=0.9, c=0.7), EVAL), + ("Jaffe", TriaxialJaffePotential(amp=1.3, a=1.5, b=0.9, c=0.7), EVAL), + ("NFW", TriaxialNFWPotential(amp=1.3, a=1.5, b=0.9, c=0.7), EVAL), + # TwoPower: _evaluate deferred (hyp2f1), but forces/2nd-derivs/dens migrated. + ( + "TwoPower", + TwoPowerTriaxialPotential(amp=1.3, a=1.5, alpha=1.5, beta=3.5, b=0.9, c=0.7), + [], + ), +] + +# Flatten to (case_id, pot, method) for value-parity parametrization. 
+_VALUE_PARAMS = [] +for _name, _pot, _eval in _CASES: + for _m in _eval + COMMON_METHODS: + _VALUE_PARAMS.append(pytest.param(_pot, _m, id=f"{_name}-{_m}")) + +# Potentials whose _evaluate is migrated (used for the autodiff check). +_EVAL_POTS = [pytest.param(pot, id=name) for (name, pot, ev) in _CASES if ev == EVAL] +# Every potential supports a migrated _Rforce, used for the autodiff check on +# potentials whose _evaluate is deferred. +_ALL_POTS = [pytest.param(pot, id=name) for (name, pot, _ev) in _CASES] + +_RS = [0.5, 1.0, 2.0] +_ZS = [0.1, 0.2, 0.3] +_PHIS = [0.3, 0.6, 0.9] + + +def _asarray(backend_name, x, requires_grad=False): + if backend_name == "numpy": + return numpy.asarray(x, dtype=float) + if backend_name == "jax": + return jnp.asarray(x, dtype=jnp.float64) + if backend_name == "torch": + return torch.tensor(x, dtype=torch.float64, requires_grad=requires_grad) + + +def _tonumpy(x): + if torch is not None and isinstance(x, torch.Tensor): + return x.detach().numpy() + return numpy.asarray(x) + + +@pytest.mark.parametrize("pot,method", _VALUE_PARAMS) +@pytest.mark.parametrize("backend_name", BACKENDS) +def test_value_parity(backend_name, pot, method): + # Reference is always numpy. 
+ ref = numpy.asarray( + getattr(pot, method)( + numpy.asarray(_RS), numpy.asarray(_ZS), numpy.asarray(_PHIS) + ) + ) + got = _tonumpy( + getattr(pot, method)( + _asarray(backend_name, _RS), + _asarray(backend_name, _ZS), + _asarray(backend_name, _PHIS), + ) + ) + numpy.testing.assert_allclose(got, ref, rtol=1e-12, atol=1e-14) + + +@pytest.mark.parametrize("pot", _EVAL_POTS) +@pytest.mark.parametrize("backend_name", AD_BACKENDS) +def test_grad_evaluate_vs_finite_difference(backend_name, pot): + R0, z0, phi0 = 1.3, 0.4, 0.5 + eps = 1e-6 + + def phi_np(R): + return float( + pot._evaluate(numpy.asarray(R), numpy.asarray(z0), numpy.asarray(phi0)) + ) + + fd = (phi_np(R0 + eps) - phi_np(R0 - eps)) / (2 * eps) + if backend_name == "jax": + ad = float( + jax.grad(lambda R: pot._evaluate(R, jnp.asarray(z0), jnp.asarray(phi0)))( + jnp.asarray(R0) + ) + ) + else: + R = torch.tensor(R0, dtype=torch.float64, requires_grad=True) + y = pot._evaluate( + R, + torch.tensor(z0, dtype=torch.float64), + torch.tensor(phi0, dtype=torch.float64), + ) + y.backward() + ad = float(R.grad) + numpy.testing.assert_allclose(ad, fd, rtol=1e-5) + + +@pytest.mark.parametrize("pot", _ALL_POTS) +@pytest.mark.parametrize("backend_name", AD_BACKENDS) +def test_grad_rforce_vs_finite_difference(backend_name, pot): + # _Rforce is migrated for every potential (depends only on _mdens); check its + # gradient wrt R against central finite differences. This also covers the + # autodiff path for potentials whose _evaluate is deferred (TwoPower). 
+ R0, z0, phi0 = 1.3, 0.4, 0.5 + eps = 1e-6 + + def f_np(R): + return float( + pot._Rforce(numpy.asarray(R), numpy.asarray(z0), numpy.asarray(phi0)) + ) + + fd = (f_np(R0 + eps) - f_np(R0 - eps)) / (2 * eps) + if backend_name == "jax": + ad = float( + jax.grad(lambda R: pot._Rforce(R, jnp.asarray(z0), jnp.asarray(phi0)))( + jnp.asarray(R0) + ) + ) + else: + R = torch.tensor(R0, dtype=torch.float64, requires_grad=True) + y = pot._Rforce( + R, + torch.tensor(z0, dtype=torch.float64), + torch.tensor(phi0, dtype=torch.float64), + ) + y.backward() + ad = float(R.grad) + numpy.testing.assert_allclose(ad, fd, rtol=1e-5) + + +@pytest.mark.parametrize("backend_name", AD_BACKENDS) +def test_traced_path_does_not_touch_cache(backend_name): + # The refactored quadrature cache (_force_hash / _cached_F*) must be written + # ONLY by the numpy path; the traced path must leave self-state untouched so + # autodiff is pure and reentrant. + pot = PerfectEllipsoidPotential(amp=1.3, a=1.5, b=0.9, c=0.7) + assert pot._force_hash is None + R = _asarray(backend_name, _RS) + z = _asarray(backend_name, _ZS) + phi = _asarray(backend_name, _PHIS) + pot._Rforce(R, z, phi) + pot._R2deriv(R, z, phi) + assert pot._force_hash is None + assert pot._2ndderiv_hash is None + + +@pytest.mark.parametrize("backend_name", AD_BACKENDS) +def test_numpy_cache_unaffected_by_traced_call(backend_name): + # A numpy evaluation fills the cache; a subsequent traced call at the same + # point must not corrupt it, and a later numpy call must reuse the (correct) + # cached force. + pot = PerfectEllipsoidPotential(amp=1.3, a=1.5, b=0.9, c=0.7) + Rn, zn, phin = numpy.asarray(_RS), numpy.asarray(_ZS), numpy.asarray(_PHIS) + ref_z = numpy.asarray(pot._zforce(Rn, zn, phin)) # fills cache via zforce + h = pot._force_hash + # Traced call at the same point. 
+ pot._Rforce( + _asarray(backend_name, _RS), + _asarray(backend_name, _ZS), + _asarray(backend_name, _PHIS), + ) + assert pot._force_hash == h # numpy cache untouched by traced call + got_z = numpy.asarray(pot._zforce(Rn, zn, phin)) # reuse cache + numpy.testing.assert_allclose(got_z, ref_z, rtol=1e-14, atol=0.0) + + +def test_evaluate_xyz_namespace_fallback(): + # _evaluate_xyz infers the backend from its (x,y,z) arguments when called + # without an explicit ``xp`` (the public _evaluate always passes one, so this + # exercises the defensive get_namespace fallback). It must match _evaluate. + pot = PerfectEllipsoidPotential(amp=1.3, a=1.5, b=0.9, c=0.7) + R, z, phi = 0.7, 0.3, 0.0 # aligned, axisymmetric instance: x=R, y=0, z=z + x = numpy.asarray(R) + y = numpy.asarray(0.0) + zz = numpy.asarray(z) + got = numpy.asarray(pot._evaluate_xyz(x, y, zz)) # no xp -> get_namespace + ref = numpy.asarray( + pot._evaluate(numpy.asarray(R), numpy.asarray(z), numpy.asarray(phi)) + ) + numpy.testing.assert_allclose(got, ref, rtol=1e-14, atol=0.0) From e49241b542ce0a38d992ec25413f9ccc7c722922 Mon Sep 17 00:00:00 2001 From: Jo Bovy Date: Sat, 6 Jun 2026 11:22:10 -0400 Subject: [PATCH 2/3] =?UTF-8?q?P2.3=20ellipsoidal:=20finish=20file=20migra?= =?UTF-8?q?tion=20=E2=80=94=20closed-form=20=5Fmass=20+=20rotated=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Complete the backend-agnostic migration of the ellipsoidal/triaxial family beyond the main compute methods done in the prior commit: - Migrate the CLOSED-FORM _mass methods to the backend namespace: * PerfectEllipsoidPotential._mass: numpy.arctan -> xp.atan, numpy.pi -> math.pi * TriaxialNFWPotential._mass: numpy.log -> xp.log, numpy.pi -> math.pi * TriaxialHernquistPotential / TriaxialJaffePotential._mass: numpy.pi -> math.pi These now run and differentiate under numpy / jax / torch (values bit-identical to before: math.pi == numpy.pi, numpy.atan/log are the same ufuncs). 
- Leave (and explicitly annotate as Pspecial-blocked) the _mass methods with no backend-agnostic replacement: * TwoPowerTriaxialPotential._mass / _psi / __init__: scipy.special.hyp2f1, gamma * TriaxialGaussianPotential._mass: scipy.special.erf * EllipsoidalPotential._mass (generic, incl. PowerTriaxial): scipy.integrate.quad The Gauss-Legendre compute path was already migrated; the glorder=None scipy.integrate fallback stays deferred. The init-time amp/rotation setup and the numpy-gated quadrature hash cache remain numpy by design. - Tests (tests/test_backend_ellipsoidal.py): * numpy-vs-jax-vs-torch parity + grad-vs-FD for the four migrated closed-form _mass methods. * ROTATED (zvec/pa) value-parity coverage for the rotated compute path (forces / phitorque / dens / migrated _evaluate) across all subclasses, plus rotated _evaluate grad-vs-FD — the rotated path previously had no coverage. All 393 backend-ellipsoidal tests pass under numpy+jax+torch; numpy potential mass/triaxial sanity tests unchanged. Co-authored-by: Claude Opus 4.8 --- galpy/potential/EllipsoidalPotential.py | 3 + galpy/potential/PerfectEllipsoidPotential.py | 8 +- galpy/potential/TriaxialGaussianPotential.py | 2 + galpy/potential/TwoPowerTriaxialPotential.py | 15 +- tests/test_backend_ellipsoidal.py | 140 ++++++++++++++++++- 5 files changed, 159 insertions(+), 9 deletions(-) diff --git a/galpy/potential/EllipsoidalPotential.py b/galpy/potential/EllipsoidalPotential.py index 76f24005e..deff20730 100644 --- a/galpy/potential/EllipsoidalPotential.py +++ b/galpy/potential/EllipsoidalPotential.py @@ -396,6 +396,9 @@ def _dens(self, R, z, phi=0.0, t=0.0): def _mass(self, R, z=None, t=0.0): if not z is None: raise AttributeError # Hack to fall back to general + # Pspecial-blocked: the generic ellipsoidal mass uses an adaptive + # scipy.integrate.quad over the density, which has no backend-agnostic + # (jax/torch) replacement -> numpy only. 
return ( 4.0 * numpy.pi diff --git a/galpy/potential/PerfectEllipsoidPotential.py b/galpy/potential/PerfectEllipsoidPotential.py index dd4ac59f8..045b393f4 100644 --- a/galpy/potential/PerfectEllipsoidPotential.py +++ b/galpy/potential/PerfectEllipsoidPotential.py @@ -7,8 +7,11 @@ # with m^2 = x^2+y^2/b^2+z^2/c^2 # ############################################################################### +import math + import numpy +from ..backend import get_namespace from ..util import conversion from .EllipsoidalPotential import EllipsoidalPotential @@ -107,11 +110,12 @@ def _mdens_deriv(self, m): def _mass(self, R, z=None, t=0.0): if not z is None: raise AttributeError # Hack to fall back to general + xp = get_namespace(R) return ( 2.0 - * numpy.pi + * math.pi * self._b * self._c / self.a - * (numpy.arctan(R / self.a) - R * self.a / (1.0 + R**2.0)) + * (xp.atan(R / self.a) - R * self.a / (1.0 + R**2.0)) ) diff --git a/galpy/potential/TriaxialGaussianPotential.py b/galpy/potential/TriaxialGaussianPotential.py index e41fa1e37..998775f2a 100644 --- a/galpy/potential/TriaxialGaussianPotential.py +++ b/galpy/potential/TriaxialGaussianPotential.py @@ -114,6 +114,8 @@ def _mdens_deriv(self, m): def _mass(self, R, z=None, t=0.0): if not z is None: raise AttributeError # Hack to fall back to general + # Pspecial-blocked: closed-form mass requires scipy.special.erf, which has + # no backend-agnostic (jax/torch array-API) replacement -> numpy only. 
return ( numpy.pi * self._b diff --git a/galpy/potential/TwoPowerTriaxialPotential.py b/galpy/potential/TwoPowerTriaxialPotential.py index 3e2882fa4..5ca519323 100644 --- a/galpy/potential/TwoPowerTriaxialPotential.py +++ b/galpy/potential/TwoPowerTriaxialPotential.py @@ -10,6 +10,8 @@ # # m^2 = x^2 + y^2/b^2 + z^2/c^2 ############################################################################### +import math + import numpy from scipy import special @@ -174,6 +176,8 @@ def _mdens_deriv(self, m): def _mass(self, R, z=None, t=0.0): if not z is None: raise AttributeError # Hack to fall back to general + # Pspecial-blocked: closed-form mass requires scipy.special.hyp2f1, which + # has no backend-agnostic (jax/torch array-API) replacement -> numpy only. return ( 4.0 * numpy.pi @@ -295,7 +299,7 @@ def _mass(self, R, z=None, t=0.0): raise AttributeError # Hack to fall back to general return ( 4.0 - * numpy.pi + * math.pi * self.a4 / self.a / (1.0 + self.a / R) ** 2.0 @@ -414,9 +418,7 @@ def _mdens_deriv(self, m): def _mass(self, R, z=None, t=0.0): if not z is None: raise AttributeError # Hack to fall back to general - return ( - 4.0 * numpy.pi * self.a * self.a2 / (1.0 + self.a / R) * self._b * self._c - ) + return 4.0 * math.pi * self.a * self.a2 / (1.0 + self.a / R) * self._b * self._c class TriaxialNFWPotential(EllipsoidalPotential): @@ -559,11 +561,12 @@ def _mdens_deriv(self, m): def _mass(self, R, z=None, t=0.0): if not z is None: raise AttributeError # Hack to fall back to general + xp = get_namespace(R) return ( 4.0 - * numpy.pi + * math.pi * self.a3 * self._b * self._c - * (numpy.log(1 + R / self.a) - R / self.a / (1.0 + R / self.a)) + * (xp.log(1 + R / self.a) - R / self.a / (1.0 + R / self.a)) ) diff --git a/tests/test_backend_ellipsoidal.py b/tests/test_backend_ellipsoidal.py index fff86a502..8e799c4a6 100644 --- a/tests/test_backend_ellipsoidal.py +++ b/tests/test_backend_ellipsoidal.py @@ -15,7 +15,15 @@ # the scipy.integrate fallback (glorder=None) is 
deferred (Pspecial PR). # * TwoPowerTriaxialPotential._evaluate uses scipy.special.hyp2f1 in _psi and # is therefore NOT migrated (its forces/2nd-derivs/dens, which only use the -# pure-arithmetic _mdens, ARE migrated). _mass everywhere is out of scope. +# pure-arithmetic _mdens, ARE migrated). +# * The rotated (zvec/pa) compute path -- _rotate_to_aligned / +# _rotate_force_back applied to forces, density, and the potential -- is +# backend-agnostic and is exercised here with explicit rotated instances. +# * _mass is migrated for the CLOSED-FORM subclasses (PerfectEllipsoid, +# TriaxialHernquist, TriaxialJaffe, TriaxialNFW); it remains Pspecial-blocked +# for TwoPowerTriaxial (scipy.special.hyp2f1), TriaxialGaussian +# (scipy.special.erf), and the generic EllipsoidalPotential / PowerTriaxial +# base (scipy.integrate.quad), so those are not parametrized below. # # Backends that are not installed self-skip, so this is green on numpy alone. ############################################################################### @@ -91,11 +99,79 @@ ), ] +# Rotated (zvec + pa) instances. The rotated compute path (_rotate_to_aligned / +# _rotate_force_back) is backend-agnostic; a prior review flagged it had no +# coverage. Only the forces, density, and (where migrated) potential are defined +# for rotated frames -- the 2nd derivatives raise NotImplementedError -- so the +# rotated cases use a reduced method list. TwoPower's _evaluate stays deferred. 
+_ROT_KW = dict(zvec=[0.0, 1.0, 1.0], pa=0.3) +_ROT_METHODS = ["_Rforce", "_zforce", "_phitorque", "_dens"] +_ROT_CASES = [ + ( + "Perfect-rot", + PerfectEllipsoidPotential(amp=1.3, a=1.5, b=0.9, c=0.7, **_ROT_KW), + EVAL, + ), + ( + "Gauss-rot", + TriaxialGaussianPotential(amp=1.3, sigma=1.5, b=0.9, c=0.7, **_ROT_KW), + EVAL, + ), + ( + "Power-rot", + PowerTriaxialPotential(amp=1.3, alpha=1.2, b=0.9, c=0.7, **_ROT_KW), + EVAL, + ), + ( + "Hernquist-rot", + TriaxialHernquistPotential(amp=1.3, a=1.5, b=0.9, c=0.7, **_ROT_KW), + EVAL, + ), + ( + "Jaffe-rot", + TriaxialJaffePotential(amp=1.3, a=1.5, b=0.9, c=0.7, **_ROT_KW), + EVAL, + ), + ( + "NFW-rot", + TriaxialNFWPotential(amp=1.3, a=1.5, b=0.9, c=0.7, **_ROT_KW), + EVAL, + ), + ( + "TwoPower-rot", + TwoPowerTriaxialPotential( + amp=1.3, a=1.5, alpha=1.5, beta=3.5, b=0.9, c=0.7, **_ROT_KW + ), + [], + ), +] + +# Potentials whose closed-form _mass is migrated to the backend namespace +# (PerfectEllipsoid: atan; TriaxialNFW: log; Hernquist/Jaffe: pure arithmetic). +# The others keep a scipy.special / scipy.integrate _mass (Pspecial-blocked). +_MASS_POTS = [ + pytest.param(PerfectEllipsoidPotential(amp=1.3, a=1.5, b=0.9, c=0.7), id="Perfect"), + pytest.param( + TriaxialHernquistPotential(amp=1.3, a=1.5, b=0.9, c=0.7), id="Hernquist" + ), + pytest.param(TriaxialJaffePotential(amp=1.3, a=1.5, b=0.9, c=0.7), id="Jaffe"), + pytest.param(TriaxialNFWPotential(amp=1.3, a=1.5, b=0.9, c=0.7), id="NFW"), +] + # Flatten to (case_id, pot, method) for value-parity parametrization. _VALUE_PARAMS = [] for _name, _pot, _eval in _CASES: for _m in _eval + COMMON_METHODS: _VALUE_PARAMS.append(pytest.param(_pot, _m, id=f"{_name}-{_m}")) +# Rotated value-parity params (forces / dens / migrated potential only). +for _name, _pot, _eval in _ROT_CASES: + for _m in _eval + _ROT_METHODS: + _VALUE_PARAMS.append(pytest.param(_pot, _m, id=f"{_name}-{_m}")) + +# Rotated potentials whose _evaluate is migrated (rotated autodiff check). 
+_ROT_EVAL_POTS = [ + pytest.param(pot, id=name) for (name, pot, ev) in _ROT_CASES if ev == EVAL +] # Potentials whose _evaluate is migrated (used for the autodiff check). _EVAL_POTS = [pytest.param(pot, id=name) for (name, pot, ev) in _CASES if ev == EVAL] @@ -205,6 +281,68 @@ def f_np(R): numpy.testing.assert_allclose(ad, fd, rtol=1e-5) +@pytest.mark.parametrize("pot", _MASS_POTS) +@pytest.mark.parametrize("backend_name", BACKENDS) +def test_mass_value_parity(backend_name, pot): + # The closed-form _mass (via the public mass()) is migrated for these + # subclasses; numpy / jax / torch must agree. + Rs = numpy.asarray([0.5, 1.0, 2.0]) + ref = numpy.asarray(pot.mass(Rs)) + got = _tonumpy(pot.mass(_asarray(backend_name, [0.5, 1.0, 2.0]))) + numpy.testing.assert_allclose(got, ref, rtol=1e-12, atol=1e-14) + + +@pytest.mark.parametrize("pot", _MASS_POTS) +@pytest.mark.parametrize("backend_name", AD_BACKENDS) +def test_mass_grad_vs_finite_difference(backend_name, pot): + # The migrated closed-form _mass is differentiable; check d mass / dR. + R0, eps = 1.7, 1e-6 + + def m_np(R): + return float(pot._mass(numpy.asarray(R))) + + fd = (m_np(R0 + eps) - m_np(R0 - eps)) / (2 * eps) + if backend_name == "jax": + ad = float(jax.grad(lambda R: pot._mass(R))(jnp.asarray(R0))) + else: + R = torch.tensor(R0, dtype=torch.float64, requires_grad=True) + y = pot._mass(R) + y.backward() + ad = float(R.grad) + numpy.testing.assert_allclose(ad, fd, rtol=1e-5) + + +@pytest.mark.parametrize("pot", _ROT_EVAL_POTS) +@pytest.mark.parametrize("backend_name", AD_BACKENDS) +def test_grad_evaluate_rotated_vs_finite_difference(backend_name, pot): + # Autodiff through the rotated (zvec/pa) potential path vs central FD. 
+ R0, z0, phi0 = 1.3, 0.4, 0.5 + eps = 1e-6 + + def phi_np(R): + return float( + pot._evaluate(numpy.asarray(R), numpy.asarray(z0), numpy.asarray(phi0)) + ) + + fd = (phi_np(R0 + eps) - phi_np(R0 - eps)) / (2 * eps) + if backend_name == "jax": + ad = float( + jax.grad(lambda R: pot._evaluate(R, jnp.asarray(z0), jnp.asarray(phi0)))( + jnp.asarray(R0) + ) + ) + else: + R = torch.tensor(R0, dtype=torch.float64, requires_grad=True) + y = pot._evaluate( + R, + torch.tensor(z0, dtype=torch.float64), + torch.tensor(phi0, dtype=torch.float64), + ) + y.backward() + ad = float(R.grad) + numpy.testing.assert_allclose(ad, fd, rtol=1e-5) + + @pytest.mark.parametrize("backend_name", AD_BACKENDS) def test_traced_path_does_not_touch_cache(backend_name): # The refactored quadrature cache (_force_hash / _cached_F*) must be written From ee06712dd48abcccb5ff3517ecd9faf65ebd7223 Mon Sep 17 00:00:00 2001 From: Jo Bovy Date: Sat, 6 Jun 2026 21:45:17 -0400 Subject: [PATCH 3/3] Strengthen ellipsoidal backend gradient tests with exact analytic identities Replace the finite-difference autodiff checks (test_grad_evaluate / test_grad_rforce vs FD, rtol 1e-5) with exact analytic-identity checks at rtol=1e-9. Under galpy's sign conventions autodiff of a lower-order quantity equals the negative of the corresponding analytic higher-order quantity: AD(_evaluate wrt R/z/phi) == -_Rforce / -_zforce / -_phitorque AD(_Rforce wrt R/z/phi) == -_R2deriv / -_Rzderiv / -_Rphideriv AD(_zforce wrt z/phi) == -_z2deriv / -_phizderiv AD(_phitorque wrt phi) == -_phi2deriv These hold to round-off (lower and higher share the same Gauss-Legendre quadrature nodes), so they cross-validate the hand-coded triaxial forces and the phi-dependent Hessian, far more stringently than finite differences. 
The new parametrized test (potentials x AD_BACKENDS x identity pairs) gates each pair on both methods being namespace-migrated: the six force/2nd-derivative pairs run for every potential (they depend only on the pure-arithmetic _mdens), while the three _evaluate pairs are skipped for TwoPowerTriaxialPotential (its _psi uses scipy.special.hyp2f1 and is Pspecial-deferred). 60 identity tests per AD backend (54 for the six eval-migrated potentials + 6 force/Hessian pairs for TwoPower). The mass-gradient and rotated-frame FD checks are kept (no analytic counterpart). Green on numpy + jax + torch. Co-authored-by: Claude Opus 4.8 --- tests/test_backend_ellipsoidal.py | 155 ++++++++++++++++++------------ 1 file changed, 96 insertions(+), 59 deletions(-) diff --git a/tests/test_backend_ellipsoidal.py b/tests/test_backend_ellipsoidal.py index 8e799c4a6..5f074ebde 100644 --- a/tests/test_backend_ellipsoidal.py +++ b/tests/test_backend_ellipsoidal.py @@ -173,12 +173,6 @@ pytest.param(pot, id=name) for (name, pot, ev) in _ROT_CASES if ev == EVAL ] -# Potentials whose _evaluate is migrated (used for the autodiff check). -_EVAL_POTS = [pytest.param(pot, id=name) for (name, pot, ev) in _CASES if ev == EVAL] -# Every potential supports a migrated _Rforce, used for the autodiff check on -# potentials whose _evaluate is deferred. 
-_ALL_POTS = [pytest.param(pot, id=name) for (name, pot, _ev) in _CASES] - _RS = [0.5, 1.0, 2.0] _ZS = [0.1, 0.2, 0.3] _PHIS = [0.3, 0.6, 0.9] @@ -218,67 +212,110 @@ def test_value_parity(backend_name, pot, method): numpy.testing.assert_allclose(got, ref, rtol=1e-12, atol=1e-14) -@pytest.mark.parametrize("pot", _EVAL_POTS) -@pytest.mark.parametrize("backend_name", AD_BACKENDS) -def test_grad_evaluate_vs_finite_difference(backend_name, pot): - R0, z0, phi0 = 1.3, 0.4, 0.5 - eps = 1e-6 +# --- exact analytic-identity gradient checks ---------------------------------- +# galpy sign conventions: Rforce=-dPhi/dR, zforce=-dPhi/dz, phitorque=-dPhi/dphi; +# R2deriv=d2Phi/dR2, etc. Under autodiff this gives, for each (lower, var, higher) +# triple below, AD(lower wrt var) == -higher EXACTLY (not just to FD precision): +# +# AD(_evaluate wrt R) == -_Rforce AD(_Rforce wrt R) == -_R2deriv +# AD(_evaluate wrt z) == -_zforce AD(_Rforce wrt z) == -_Rzderiv +# AD(_evaluate wrt phi) == -_phitorque AD(_Rforce wrt phi) == -_Rphideriv +# AD(_zforce wrt z) == -_z2deriv AD(_zforce wrt phi) == -_phizderiv +# AD(_phitorque wrt phi) == -_phi2deriv +# +# This cross-validates the hand-coded analytic forces and the (phi-dependent) +# triaxial Hessian against autodiff, far more stringently than finite differences +# (these identities replace the now-subsumed FD _evaluate / _Rforce checks). +# Variable name -> positional index into (R, z, phi). +_VAR_IDX = {"R": 0, "z": 1, "phi": 2} + +# (lower_method, var, higher_method): AD(lower wrt var) == -higher. 
+_IDENTITY_PAIRS = [ + ("_evaluate", "R", "_Rforce"), + ("_evaluate", "z", "_zforce"), + ("_evaluate", "phi", "_phitorque"), + ("_Rforce", "R", "_R2deriv"), + ("_Rforce", "z", "_Rzderiv"), + ("_Rforce", "phi", "_Rphideriv"), + ("_zforce", "z", "_z2deriv"), + ("_zforce", "phi", "_phizderiv"), + ("_phitorque", "phi", "_phi2deriv"), +] - def phi_np(R): - return float( - pot._evaluate(numpy.asarray(R), numpy.asarray(z0), numpy.asarray(phi0)) - ) +# Off-axis, off-plane, non-zero-phi smooth point so every derivative (including +# the phi-direction ones, which vanish on axis) is exercised and nonzero. +_R0, _Z0, _PHI0 = 1.3, 0.4, 0.5 - fd = (phi_np(R0 + eps) - phi_np(R0 - eps)) / (2 * eps) + +def _ad_grad(backend_name, method, var, point): + """AD gradient of ``method`` (a scalar-returning bound potential method) with + respect to one of (R, z, phi) at ``point=(R0, z0, phi0)``, as a python float. + + Mirrors the canonical jax.grad / torch.autograd pattern: a fresh leaf tensor + per backward, scalar output.""" + idx = _VAR_IDX[var] if backend_name == "jax": - ad = float( - jax.grad(lambda R: pot._evaluate(R, jnp.asarray(z0), jnp.asarray(phi0)))( - jnp.asarray(R0) + + def f(v): + args = [jnp.asarray(point[0]), jnp.asarray(point[1]), jnp.asarray(point[2])] + args[idx] = v + return method(*args) + + return float(jax.grad(f)(jnp.asarray(point[idx]))) + # torch: a fresh leaf tensor that requires grad for the chosen variable. + args = [ + torch.tensor(point[0], dtype=torch.float64), + torch.tensor(point[1], dtype=torch.float64), + torch.tensor(point[2], dtype=torch.float64), + ] + leaf = torch.tensor(point[idx], dtype=torch.float64, requires_grad=True) + args[idx] = leaf + method(*args).backward() + return float(leaf.grad) + + +def _method_migrated(name, eval_migrated): + """Whether ``name`` is namespace-migrated (callable on a traced backend). 
+ + Forces, 2nd derivatives, and density depend only on the pure-arithmetic + _mdens/_mdens_deriv and are migrated for every potential here; _evaluate is + deferred for potentials whose _psi uses scipy.special (TwoPower).""" + if name == "_evaluate": + return eval_migrated + return name in COMMON_METHODS + + +# Build (pot, lower, var, higher) params for every identity pair whose BOTH +# methods are migrated for that (aligned) potential. TwoPower keeps the six +# force/2nd-deriv pairs but drops the three _evaluate pairs (psi uses hyp2f1). +_IDENTITY_PARAMS = [] +for _name, _pot, _eval in _CASES: + _eval_migrated = _eval == EVAL + for _lower, _var, _higher in _IDENTITY_PAIRS: + if _method_migrated(_lower, _eval_migrated) and _method_migrated( + _higher, _eval_migrated + ): + _IDENTITY_PARAMS.append( + pytest.param( + _pot, _lower, _var, _higher, id=f"{_name}-{_lower}-d{_var}" + ) ) - ) - else: - R = torch.tensor(R0, dtype=torch.float64, requires_grad=True) - y = pot._evaluate( - R, - torch.tensor(z0, dtype=torch.float64), - torch.tensor(phi0, dtype=torch.float64), - ) - y.backward() - ad = float(R.grad) - numpy.testing.assert_allclose(ad, fd, rtol=1e-5) -@pytest.mark.parametrize("pot", _ALL_POTS) +@pytest.mark.parametrize("pot,lower,var,higher", _IDENTITY_PARAMS) @pytest.mark.parametrize("backend_name", AD_BACKENDS) -def test_grad_rforce_vs_finite_difference(backend_name, pot): - # _Rforce is migrated for every potential (depends only on _mdens); check its - # gradient wrt R against central finite differences. This also covers the - # autodiff path for potentials whose _evaluate is deferred (TwoPower). - R0, z0, phi0 = 1.3, 0.4, 0.5 - eps = 1e-6 - - def f_np(R): - return float( - pot._Rforce(numpy.asarray(R), numpy.asarray(z0), numpy.asarray(phi0)) +def test_force_and_hessian_identities(backend_name, pot, lower, var, higher): + # AD(lower wrt var) == -higher, exactly (rtol=1e-9). 
Both methods share the + # same Gauss-Legendre quadrature nodes, so autodiff of the lower method + # reproduces the analytic higher method to round-off, not just FD precision. + point = (_R0, _Z0, _PHI0) + ad = _ad_grad(backend_name, getattr(pot, lower), var, point) + ref = -float( + getattr(pot, higher)( + numpy.asarray(_R0), numpy.asarray(_Z0), numpy.asarray(_PHI0) ) - - fd = (f_np(R0 + eps) - f_np(R0 - eps)) / (2 * eps) - if backend_name == "jax": - ad = float( - jax.grad(lambda R: pot._Rforce(R, jnp.asarray(z0), jnp.asarray(phi0)))( - jnp.asarray(R0) - ) - ) - else: - R = torch.tensor(R0, dtype=torch.float64, requires_grad=True) - y = pot._Rforce( - R, - torch.tensor(z0, dtype=torch.float64), - torch.tensor(phi0, dtype=torch.float64), - ) - y.backward() - ad = float(R.grad) - numpy.testing.assert_allclose(ad, fd, rtol=1e-5) + ) + numpy.testing.assert_allclose(ad, ref, rtol=1e-9) @pytest.mark.parametrize("pot", _MASS_POTS)