Abordarea Atlas-Learn a ipotezei varietății

URMĂREȘTE-NE
16,065 Fani · Îmi place
1,142 Cititori · Conectați-vă

Arată codul
#| label: atlas-functions

# ===========================================================================
# PART 1: Quadratic feature helpers
# ---------------------------------------------------------------------------
# These functions implement the d=2 specialization of the general quadratic
# feature map. For general d, phi(xi) would have choose(d+1, 2) components.
# For d=2 it has exactly 3: (xi1^2, xi1*xi2, xi2^2).
# ===========================================================================

# Quadratic feature map phi: R^2 -> R^3 (d = 2 specialization).
#
# Returns the three degree-2 monomials (xi1^2, xi1*xi2, xi2^2) used to model
# surface curvature along the normal. For general intrinsic dimension d the
# map would have choose(d + 1, 2) components; for d = 2 this is exactly 3.
#
# xi: numeric vector of length 2 (tangent-plane coordinates).
# Returns an unnamed numeric vector of length 3.
quad_features <- function(xi) {
  x1 <- xi[[1L]]
  x2 <- xi[[2L]]
  c(x1^2, x1 * x2, x2^2)
}

# Jacobian of quad_features() with respect to xi: a 3 x 2 matrix (d = 2).
#
# Row r is the gradient of the r-th monomial returned by quad_features():
#   d/dxi (xi1^2)   = (2*xi1, 0    )
#   d/dxi (xi1*xi2) = (xi2,   xi1  )
#   d/dxi (xi2^2)   = (0,     2*xi2)
#
# The entries below are written one matrix row per source line, which is why
# byrow = TRUE is used.
#
# xi: numeric vector of length 2.
# Returns a 3 x 2 numeric matrix.
quad_jacobian <- function(xi) {
  x1 <- xi[[1L]]
  x2 <- xi[[2L]]
  matrix(
    c(2 * x1, 0,
      x2,     x1,
      0,      2 * x2),
    nrow = 3, ncol = 2, byrow = TRUE
  )
}

# ===========================================================================
# PART 2: Core chart operations
# ---------------------------------------------------------------------------
# Each chart stores:
#   mean  : R^3 — chart center (centroid of the cluster)
#   L     : R^{3x2} — orthonormal tangent basis (columns = v1, v2 from SVD)
#   M     : R^{3x1} — unit surface normal (v3 from SVD; scalar normal because D-d=1)
#   K     : R^{1x3} — quadratic curvature coefficients (1 row because D-d=1)
#   ell_A : R^{2x2} — ellipsoidal domain matrix (2x2 because d=2)
# ===========================================================================

# Inverse chart map psi: R^2 -> R^3 for a single chart.
#
#   psi(xi) = mean + L %*% xi + M * (K %*% phi(xi))
#
# `mean + L %*% xi` moves within the chart's tangent plane; the second term
# lifts the point off that plane along the unit normal M by the fitted
# quadratic height K %*% phi(xi), which encodes the surface curvature.
#
# chart: list/"atlas_chart" with fields mean (3), L (3x2), M (3x1), K (1x3).
# xi   : numeric vector of length 2.
# Returns a plain numeric vector of length 3.
chart_eval <- function(chart, xi) {
  height <- as.numeric(chart$K %*% quad_features(xi))   # scalar normal offset
  planar <- chart$mean + chart$L %*% xi                 # 3 x 1 plane point
  as.vector(planar + chart$M * height)
}

# Jacobian of the inverse chart map psi at xi: a 3 x 2 matrix.
#
#   J(xi) = L + M %*% (K %*% dQ(xi)),   dQ(xi) = quad_jacobian(xi)  (3 x 2)
#
# K %*% dQ is the 1 x 2 gradient of the scalar normal height, and M (3 x 1)
# times that gradient lifts it along the normal into a 3 x 2 correction.
# J maps chart-coordinate velocities in R^2 to ambient velocities in R^3.
chart_jacobian <- function(chart, xi) {
  height_grad <- chart$K %*% quad_jacobian(xi)   # 1 x 2
  normal_lift <- chart$M %*% height_grad         # 3 x 2
  chart$L + normal_lift
}

# Linear projection of an ambient point x (length 3) onto a chart's
# tangent-plane coordinates: xi = L^T (x - mean), a length-2 vector.
# Exact for points lying on the tangent plane; for points on the curved
# surface it is only a first-order approximation of the chart inverse.
chart_project <- function(chart, x) {
  offset <- x - chart$mean
  as.vector(crossprod(chart$L, offset))
}

# TRUE when chart coordinates xi (length 2) fall inside the chart's
# ellipsoidal domain {xi : xi^T ell_A xi <= 1}; points exactly on the
# boundary count as inside. Costs one 2x2 quadratic form.
chart_in_domain <- function(chart, xi) {
  quad_form <- drop(crossprod(xi, chart$ell_A) %*% xi)
  quad_form <= 1
}

# ===========================================================================
# PART 3: atlas_learn() — the main learning function
# ---------------------------------------------------------------------------
# Input : X (N x 3 matrix of surface points), k (number of charts)
# Output: an S3 object of class "atlas" containing k atlas_chart objects
#
# The algorithm runs four steps per chart:
#   (a) k-medoids partitioning
#   (b) local PCA to find the tangent plane and normal
#   (c) quadratic regression for the curvature coefficients K
#   (d) ellipsoidal domain construction
#
# d=2, D=3 specializations that are hard-coded here:
#   - SVD takes nv=3 (the full 3x3 right-singular-vector matrix)
#   - L = V[, 1:2] is 3x2; M = V[, 3] is 3x1 (one normal vector, not a matrix)
#   - nu (normal offset) is a scalar per point, not a vector
#   - K is fitted as a 1x3 row vector (one row per normal direction)
#   - ell_A is 2x2 (domain is a 2D ellipse)
# ===========================================================================

# Fit a single chart to one cluster of surface points.
#
# Xj              : N_j x 3 numeric matrix of the cluster's points (N_j >= 3)
# idx             : row indices of those points in the full data matrix
# ellipsoid_scale : domain inflation factor (see atlas_learn())
# ridge           : Tikhonov penalty for the curvature regression
#
# Returns an object of class "atlas_chart" with fields mean, L, M, K, ell_A,
# idx and n_points.
fit_atlas_chart <- function(Xj, idx, ellipsoid_scale, ridge = 1e-10) {
  n_j <- nrow(Xj)
  if (n_j < 3L) {
    # cov() of fewer than 3 points in 2D is singular, so solve() below fails.
    stop(
      "A cluster has only ", n_j, " point(s); at least 3 are needed to fit ",
      "a 2D chart. Try a smaller k.",
      call. = FALSE
    )
  }

  # (a) Center the cluster on its centroid (the chart center).
  m  <- colMeans(Xj)
  Xc <- sweep(Xj, 2, m)

  # (b) Local PCA. The first two right singular vectors of the centered
  # cluster span the tangent plane; the third is the unit normal. With
  # D - d = 1 there is exactly one normal direction, so M is one 3 x 1 column.
  sv <- svd(Xc, nu = 0, nv = 3)
  L  <- sv$v[, 1:2, drop = FALSE]   # 3 x 2 orthonormal tangent basis
  M  <- sv$v[, 3, drop = FALSE]     # 3 x 1 unit normal

  tau <- Xc %*% L   # N_j x 2 tangent coordinates
  nu  <- Xc %*% M   # N_j x 1 normal offsets

  # (c) Quadratic curvature regression nu ~ K phi(tau), with K a 1 x 3 row.
  # Ridge least squares K^T = (Q^T Q + ridge I)^{-1} Q^T nu, solved directly
  # rather than forming the inverse. The tiny ridge only guards against a
  # singular Q^T Q (e.g. points nearly collinear in tangent coordinates).
  # NOTE(review): there is no intercept, so chart_eval(chart, c(0, 0)) is the
  # centroid m, which for a curved cluster lies slightly off the surface.
  Q <- t(apply(tau, 1, quad_features))   # N_j x 3 design matrix
  K <- t(solve(crossprod(Q) + ridge * diag(3), crossprod(Q, nu)))

  # (d) Ellipsoidal domain {xi : xi^T ell_A xi <= 1}. Start from the inverse
  # covariance of tau (so the ellipse follows the data spread), rescale so the
  # farthest training point has quadratic form 1, then divide by
  # ellipsoid_scale. The quadratic form scales with radius squared, so every
  # semi-axis grows by a factor sqrt(ellipsoid_scale).
  A_raw <- solve(cov(tau))
  qvals <- rowSums((tau %*% A_raw) * tau)   # tau_i^T A_raw tau_i per row
  ell_A <- A_raw / (ellipsoid_scale * max(qvals))

  structure(
    list(mean = m, L = L, M = M, K = K, ell_A = ell_A, idx = idx,
         n_points = length(idx)),
    class = "atlas_chart"
  )
}

# Learn an atlas of quadratic charts for a 2D surface sampled in R^3.
#
# X               : N x 3 numeric matrix (or data frame) of surface points
# k               : number of charts, 1 <= k < N
# ellipsoid_scale : positive factor applied to each chart's quadratic-form
#                   threshold so neighbouring domains overlap (semi-axes grow
#                   by sqrt(ellipsoid_scale))
#
# Steps: (a) k-medoids partition of X, then for every cluster (b) local PCA,
# (c) quadratic curvature regression, (d) ellipsoidal domain construction.
#
# Returns an S3 object of class "atlas" holding k "atlas_chart" objects.
# Errors on non-numeric / non-3-column input, invalid k, or any cluster with
# fewer than 3 points.
atlas_learn <- function(X, k, ellipsoid_scale = 1.1) {
  X <- as.matrix(X)
  stopifnot(
    is.numeric(X),
    ncol(X) == 3L,
    length(k) == 1L, k >= 1, k < nrow(X),
    length(ellipsoid_scale) == 1L, ellipsoid_scale > 0
  )

  message("Fitting k-medoids (k = ", k, ") ...")
  # PAM (Partitioning Around Medoids) instead of k-means: medoids are actual
  # data points, which makes the partition more robust to outliers.
  km <- cluster::pam(X, k, diss = FALSE)

  charts <- seq_len(k) |>
    purrr::map(\(j) {
      idx <- which(km$clustering == j)
      fit_atlas_chart(X[idx, , drop = FALSE], idx, ellipsoid_scale)
    })

  structure(
    list(
      charts        = charts,
      k             = k,
      clustering    = km$clustering,
      n_points      = nrow(X),
      ambient_dim   = ncol(X),   # D = 3
      intrinsic_dim = 2L         # d = 2 (hard-coded for surfaces in R^3)
    ),
    class = "atlas"
  )
}

# ===========================================================================
# PART 4: S3 print / summary methods
# ===========================================================================

# Print a one-line summary of a chart: number of points, chart center
# (rounded to 3 digits) and the condition-number estimate of its domain
# matrix from kappa(), which grows as the domain ellipse becomes elongated.
# Returns x invisibly.
print.atlas_chart <- function(x, ...) {
  cat(sprintf(
    "  %d points | mean (%s) | cond(ell_A) = %.1f\n",
    x$n_points,
    paste(round(x$mean, 3), collapse = ", "),
    kappa(x$ell_A)
  ))
  invisible(x)
}

# Per-chart summary of an atlas: a tibble with one row per chart.
#   chart      : chart index
#   n_points   : number of training points assigned to the chart
#   mean_norm  : Euclidean norm of the chart center (rounded to 4 digits)
#   cond_ell_A : kappa() condition-number estimate of ell_A (1 digit)
summary.atlas <- function(object, ...) {
  purrr::imap_dfr(object$charts, \(ch, i) {
    tibble::tibble(
      chart      = i,
      n_points   = ch$n_points,
      mean_norm  = round(sqrt(sum(ch$mean^2)), 4),
      cond_ell_A = round(kappa(ch$ell_A), 1)
    )
  })
}

# Print an atlas: a header line with the number of charts, ambient and
# intrinsic dimensions and point count, a blank line, then the per-chart
# table from summary(). Returns x invisibly.
print.atlas <- function(x, ...) {
  cat(sprintf(
    "  k = %d | ambient R^%d | intrinsic R^%d | %d points\n\n",
    x$k, x$ambient_dim, x$intrinsic_dim, x$n_points
  ))
  print(summary(x))
  invisible(x)
}

# ===========================================================================
# PART 5: Chart lookup
# ===========================================================================

# Return the index of a chart whose domain contains the ambient point x.
#
# Only the n_candidates charts with the nearest centers (squared Euclidean
# distance) are tested, nearest first; the first whose ellipsoidal domain
# contains the linear projection of x is returned. If none qualifies, the
# chart with the globally nearest center is returned as a fallback. Limiting
# the domain tests keeps the per-call cost low for large k.
#
# atlas        : object of class "atlas"
# x            : numeric vector of length 3
# n_candidates : number of nearest charts to test (default 6)
find_chart <- function(atlas, x, n_candidates = 6L) {
  dists <- vapply(atlas$charts, \(ch) sum((x - ch$mean)^2), numeric(1))
  candidates <- order(dists)[seq_len(min(n_candidates, length(dists)))]

  # Find() stops at the first candidate whose domain contains x.
  hit <- Find(
    \(i) {
      ch <- atlas$charts[[i]]
      chart_in_domain(ch, chart_project(ch, x))
    },
    candidates
  )

  if (is.null(hit)) which.min(dists) else hit
}

# ===========================================================================
# PART 6: Quasi-Euclidean retraction
# ---------------------------------------------------------------------------
# Advances one step from (chart_idx, xi) in the ambient direction tau_r3 (R^3).
#
# The pseudoinverse J^+ = (J^T J)^{-1} J^T is the key d=2, D=3 quantity:
#   - J is 3x2 (D x d)
#   - J^T J is 2x2 (d x d) — the pullback metric g; invertible for full-rank J
#   - J^+ is 2x3 (d x D) — pulls ambient vectors back to chart coordinates
#
# For any D > d with J of full column rank, J^+ is the Moore-Penrose
# pseudoinverse: a left inverse (J^+ J = I) giving least-squares preimages.
# ===========================================================================

# Take one retraction step on the atlas.
#
# Starting at chart coordinates xi in chart `chart_idx`, the ambient step
# tau_r3 (a 3-vector) is pulled back to chart coordinates through the
# pseudoinverse of the 3 x 2 chart Jacobian J:
#   dxi = (J^T J)^{-1} J^T tau_r3
# and the new point is psi(xi + dxi). If xi + dxi leaves the chart's
# ellipsoidal domain, the point is handed to the chart chosen by find_chart()
# and re-expressed in that chart's coordinates.
#
# Returns list(chart_idx, xi, x): host chart index, its 2D coordinates and
# the ambient position (length-3 vector).
retract_step <- function(atlas, chart_idx, xi, tau_r3) {
  chart <- atlas$charts[[chart_idx]]
  J <- chart_jacobian(chart, xi)   # 3 x 2
  # Solve the 2 x 2 normal equations instead of forming (J^T J)^{-1}.
  xi_new <- xi + as.vector(solve(crossprod(J), crossprod(J, tau_r3)))
  x_new <- chart_eval(chart, xi_new)

  if (chart_in_domain(chart, xi_new)) {
    return(list(chart_idx = chart_idx, xi = xi_new, x = x_new))
  }

  # Crossed the domain boundary: hand the point to a chart containing x_new.
  # chart_project() is only a first-order inverse, so the returned x is
  # re-evaluated on the new chart and can differ slightly from x_new.
  ci_new <- find_chart(atlas, x_new)
  new_chart <- atlas$charts[[ci_new]]
  xi2 <- chart_project(new_chart, x_new)
  list(chart_idx = ci_new, xi = xi2, x = chart_eval(new_chart, xi2))
}

# ===========================================================================
# PART 7: Geodesic path integration
# ---------------------------------------------------------------------------
# Traces a geodesic from x_start in direction direction_r3 (ambient R^3).
# Returns a tibble with columns x, y, z, step, chart_idx.
#
# At each step the ambient velocity vector is re-projected into the current
# chart's tangent plane and pushed back to ambient.  This keeps the direction
# of motion consistent across chart boundaries — without it, the fixed tau
# vector drifts after every chart transition and the path wanders off the
# intended geodesic.
# ===========================================================================

# Trace an approximate geodesic on the atlas with fixed-length steps.
#
# atlas        : object of class "atlas"
# x_start      : starting point in R^3 (projected onto its host chart)
# direction_r3 : initial direction in R^3; only its tangential component is
#                used, so it must not be (numerically) normal to the surface
# n_steps      : number of retraction steps
# step_size    : ambient length of each step
#
# Returns a tibble with n_steps + 1 rows and columns x, y, z (position),
# step (1-based row index) and chart_idx (host chart of each point).
atlas_geodesic <- function(atlas, x_start, direction_r3,
                           n_steps = 100L, step_size = 0.02) {
  # Project v onto the tangent plane of `chart` at xi, i.e. the column space
  # of the chart Jacobian J via J (J^T J)^{-1} J^T v, and rescale it to length
  # step_size. This is the identity-transport approximation: the direction is
  # held constant in the current chart's frame, without the Christoffel
  # correction that true parallel transport would apply.
  tangent_step <- function(chart, xi, v) {
    J <- chart_jacobian(chart, xi)
    v_tan <- as.vector(J %*% solve(crossprod(J), crossprod(J, v)))
    len <- sqrt(sum(v_tan^2))
    if (!is.finite(len) || len <= .Machine$double.eps) {
      # Normalizing a zero vector would silently turn the path into NaNs.
      stop("Direction has no tangential component at the current point.",
           call. = FALSE)
    }
    v_tan / len * step_size
  }

  ci <- find_chart(atlas, x_start)
  xi <- chart_project(atlas$charts[[ci]], x_start)
  tau_r3 <- tangent_step(atlas$charts[[ci]], xi, direction_r3)

  n_points  <- n_steps + 1L
  steps     <- vector("list", n_points)
  chart_ids <- integer(n_points)
  steps[[1L]]   <- chart_eval(atlas$charts[[ci]], xi)
  chart_ids[1L] <- ci

  for (i in seq_len(n_steps)) {
    res <- retract_step(atlas, ci, xi, tau_r3)
    ci  <- res$chart_idx
    xi  <- res$xi
    steps[[i + 1L]]   <- res$x
    chart_ids[i + 1L] <- ci

    # Re-project the velocity into the (possibly new) host chart's tangent
    # plane so the direction stays consistent across chart transitions.
    tau_r3 <- tangent_step(atlas$charts[[ci]], xi, tau_r3)
  }

  xyz <- do.call(rbind, steps)   # n_points x 3
  tibble::tibble(
    x         = xyz[, 1],
    y         = xyz[, 2],
    z         = xyz[, 3],
    step      = seq_len(nrow(xyz)),
    chart_idx = chart_ids
  )
}

# ===========================================================================
# PART 8: Sphere-specific helpers
# ---------------------------------------------------------------------------
# These are ground-truth functions for S^2 used only for validation;
# they play no role in the Atlas-Learn algorithm itself.
# ===========================================================================

# Great-circle distance (radians) between two unit vectors p and q on S^2:
#   d(p, q) = arccos(<p, q>)
# The inner product is clamped to the closed interval [-1, 1] so rounding
# error cannot push it outside acos()'s domain and produce NaN.
sphere_dist <- function(p, q) {
  cos_angle <- sum(p * q)
  acos(min(max(cos_angle, -1.0), 1.0))
}

# Unit tangent vector at p (a unit vector on S^2) pointing toward q along the
# great circle through both. q is projected onto the tangent plane at p by
# removing its component along p, then normalized. Returns NULL when that
# projection has norm below 1e-12, i.e. q is (numerically) a multiple of p.
geodesic_direction_sphere <- function(p, q) {
  tangent <- q - sum(p * q) * p
  len <- sqrt(sum(tangent^2))
  if (len >= 1e-12) tangent / len else NULL
}

# Endpoint error of an atlas geodesic on the unit sphere S^2.
#
# Runs atlas_geodesic() from p toward q for round(d / step_size) steps (at
# least 1), where d is the true great-circle distance, re-projects the final
# point onto S^2 and returns its great-circle distance to q. Re-projecting
# separates geodesic-direction error from off-manifold drift.
#
# p, q      : unit vectors in R^3
# step_size : ambient step length passed to atlas_geodesic()
#
# Returns a one-row tibble: true_dist, endpoint_error (radians), n_steps.
# When p and q (numerically) coincide no geodesic is run and the error is 0.
atlas_endpoint_error <- function(atlas, p, q, step_size = 0.02) {
  d <- sphere_dist(p, q)
  if (d < 1e-6) {
    return(tibble::tibble(true_dist = d, endpoint_error = 0, n_steps = 0L))
  }

  dir <- geodesic_direction_sphere(p, q)
  if (is.null(dir)) {
    # p and q are antipodal: every great circle through p reaches q after
    # distance pi, so any unit tangent direction at p is a valid geodesic.
    # Use the coordinate axis least aligned with p to get a well-conditioned
    # tangent vector instead of reporting a fake zero error.
    axis <- diag(3)[, which.min(abs(p))]
    dir  <- axis - sum(p * axis) * p
    dir  <- dir / sqrt(sum(dir^2))
  }

  n_steps <- max(1L, round(d / step_size))
  path <- atlas_geodesic(atlas, p, dir,
                         n_steps = n_steps, step_size = step_size)

  last <- nrow(path)
  endpoint <- c(path$x[last], path$y[last], path$z[last])
  endpoint <- endpoint / sqrt(sum(endpoint^2))

  tibble::tibble(
    true_dist      = d,
    endpoint_error = sphere_dist(endpoint, q),
    n_steps        = as.integer(n_steps)
  )
}
Dominic Botezariu
Dominic Botezariu — https://www.noobz.ro/
Creator de site și redactor-șef.

Cele mai noi știri

Pe același subiect

LĂSAȚI UN MESAJ

Vă rugăm să introduceți comentariul dvs.!
Introduceți aici numele dvs.