(Acest articol a fost publicat pentru prima dată pe R pe Statistici și Rși cu amabilitate a contribuit la R-bloggeri). (Puteți raporta problema legată de conținutul acestei pagini aici)
Doriți să vă distribuiți conținutul pe R-bloggeri? dați clic aici dacă aveți un blog, sau aici dacă nu aveți.

Această postare a fost scrisă în colaborare cu Joshua Marie.
Rețelele neuronale standard învață greutăți fixe în timpul antrenamentului și produc o estimare unică punctuală pentru fiecare intrare – fără a simți cât de sigur este modelul. Rețele neuronale bayesiene (BNN) înlocuiți acele greutăți fixe cu distribuții de probabilitate (Neal 2012). În loc să învețe o singură valoare pe greutate, rețeaua învață o medie și o varianță și eșantioane din acea distribuție la fiecare trecere înainte.
Formal, un BNN plasează un prior (p(mathbf{w})) peste greutăți și îl actualizează cu datele observate (mathcal{D}) prin teorema lui Bayes:
(p(mathbf{w} mid mathcal{D}) = frac{p(mathcal{D} mid mathbf{w}), p(mathbf{w})}{p(mathcal{D})})
Pentru rețelele mari și adânci, cea posterioară (p(mathbf{w} mid mathcal{D})) este insolubil. BNN-urile folosesc de obicei inferență variațională pentru a o aproxima cu o distribuție tratabilă (q_phi(mathbf{w})) prin minimizarea divergenţei KL între cele două (Blundell și colab. 2015).
De la estimări punctuale la distribuții
Un strat liniar standard calculează (mathbf{y} = mathbf{W}mathbf{x} + mathbf{b})unde (mathbf{W}) şi (mathbf{b}) sunt fixe. Într-un strat liniar bayesian, acestea sunt variabile aleatorii. Punem un anterior gaussian peste greutăți:
(mathbf{W} sim mathcal{N}(mu_{text{anterior}}, sigma_{text{anterior}}^2))
iar în timpul antrenamentului învățăm a posterior variaţional:
(q(mathbf{W}) = mathcal{N}(mu_W, sigma_W^2))
Prin urmare, fiecare greutate are doi scalari care pot fi învățați: o medie (mu_W) și un log-deviație standard (log sigma_W).
Trucul de reparametrizare
Prelevare direct de la (q(mathbf{W})) nu este diferențiabilă în raport cu (mu_W) şi (sigma_W). Rezolvăm asta cu truc de reparametrizare (Kingma și Welling 2013):
(mathbf{W} = mu_W + sigma_W odot varepsilon, quad varepsilon sim mathcal{N}(mathbf{0}, mathbf{I}))
Noi depozităm (log sigma_W) (nu (sigma_W) direct) pentru a asigura pozitivitatea în timpul optimizării neconstrânse:
(sigma_W = exp(texttt{greutate}_{logsigma}))
Astfel rezultă greutatea eșantionată utilizată în trecerea înainte:
(mathbf{W}_{text{sample}} = mu_W + exp(texttt{greutate}_{logsigma}) odot varepsilon_W)
și în mod similar pentru părtinire:
(mathbf{b}_{text{sample}} = mu_b + exp(texttt{bias}_{logsigma}) odot varepsilon_b)
În acest moment, {tidymodels} oferă API-uri limitate pentru arhitecturi de rețele neuronale, mai ales în jurul MLP-urilor standard. Câteva pachete fac {torch} mai ușor de utilizat la un nivel superior:
{brulee}poduri{torch}cu{tidymodels}. Este ergonomic, dar în prezent se concentrează pe MLP-uri pentru date tabelare.{cito}nu se integrează cu{tidymodels}și se înclină mai mult spre aplicații statistice.- The
{torch}echipa mentine{luz}. {tabnet}se integrează cu{tidymodels}dar este dedicat arhitecturii TabNet.
{kindling} nu înlocuiește aceste pachete. Cu toate acestea, ajută la eliminarea unora dintre aceste lacune prin sprijinirea arhitecturilor personalizate cu adâncime flexibilă. Ca rezultat, arhitecturile personalizate de rețele neuronale, inclusiv BNN-urile, pot fi utilizate în interior {tidymodels} fluxurilor de lucru. Pentru o imagine de ansamblu mai amplă a ce {kindling} activează în R, vezi această postare anterioară.
La 3 martie 2026, {kindling} v0.3.0 a fost lansat pe CRAN. Ține minte asta utils::install.packages() în R instalează întotdeauna cea mai recentă versiune.
install.packages("kindling")
pak::pak("joshuamarie/kindling")
Recomand folosirea {pak}: este rapid și foarte bun la rezolvarea dependențelor pachetelor din medii, inclusiv biblioteci de sistem și {renv}-proiecte gestionate.
Înainte de utilizare {kindling}instalați LibTorch — backend-ul C++ partajat de PyTorch și {torch} Pachetul R (Falbel și Luraschi 2023):
torch::install_torch()
Când predau R, recomand să folosești box::use() și calificarea numelor pe care le importați. Creați o bnn folderul din rădăcina proiectului, apoi clonează acest depozit pentru a utiliza BNN-uri în R (acest modul este adaptat din torchbnn (Kim 2020)): https://github.com/joshuamarie/RTorchBNN.
Puteți rula o anumită comandă:
git clone --branch main https://github.com/joshuamarie/RTorchBNN bnn
Apoi, încărcați următoarele:
box::use( bnn = . / bnn, kindling(train_nnsnip, nn_arch, act_funs), parsnip(fit, augment), yardstick(metrics) )
Apoi înregistrați modelele în {parsnip} cu:
loadNamespace("kindling")
În {kindling} v0.3.0, train_nnsnip() a fost introdus ca a {parsnip}-specificație de model compatibilă care face legătura între antrenorul de rețea neuronală generalizată cu {tidymodels} ecosistem. Spre deosebire de mlp_kindling()care este specific rețelelor feedforward, train_nnsnip() este agnostic de arhitectură: descrieți topologia stratului prin nn_arch()permițând introducerea în stilul BNN sau a altor tipuri de straturi personalizate fără a schimba logica de antrenament.
Antrenăm un model Bayesian Neural Network pentru a prezice speciile din iris set de date prin train_nnsnip()prin configurare arch cu nn_arch(). Interior nn_arch()setat nn_layer la bnn$BayesLinear în loc de stratul liniar implicit, apoi definiți argumentele pentru fiecare strat cu layer_arg_fn.
nn_model <-
train_nnsnip(
mode = "classification",
arch = nn_arch(
nn_layer = bnn$BayesLinear,
out_nn_layer = torch::nn_linear,
layer_arg_fn = ~ if (.is_output) {
list(.in, .out)
} else {
list(
in_features = .in,
out_features = .out,
prior_mu = 0,
prior_sigma = 0.1
)
}
),
hidden_neurons = c(64, 32),
activations = act_funs(relu, elu),
loss = "cross_entropy",
epochs = 50,
verbose = TRUE,
learn_rate = 1e-3,
optimizer_args = list(weight_decay = 0.01)
) |>
fit(Species ~ ., data = iris)
nn_model |>
augment(new_data = iris) |>
metrics(truth = Species, estimate = .pred_class)
## # A tibble: 2 × 3
## .metric .estimator .estimate
##
## 1 accuracy multiclass 0.667
## 2 kap multiclass 0.5
Media predicțiilor pentru mai multe treceri înainte oferă atât o estimare punctuală, cât și o măsură a răspândirii predictive – toate în cadrul unui standard {tidymodels} conductă.
Deocamdată, nu există nicio opțiune încorporată pentru a actualiza dinamic pierderile în timpul antrenamentului. Pentru a utiliza un obiectiv personalizat, cum ar fi Limită inferioară a dovezilor (ELBO), în prezent trebuie să îl definiți manual și să vă reantrenați cu acea pierdere.
Actualizarea modelului
Acestea fiind spuse, puteți aplica în continuare o pierdere personalizată setând loss parametru la a torch funcție, de exemplu torch::nnf_mse_loss(). Dacă doriți să antrenați modelul cu o pierdere în stil ELBO, faceți următoarele:
make_elbo_loss <- function(bnn_model, n_obs, kl_weight = 1.0) {
box::use(
torch(nnf_cross_entropy, torch_tensor)
)
function(input, target) {
ce <- nnf_cross_entropy(input, target)
device <- input$device
kl_val <- torch_tensor(0.0, device = device, requires_grad = FALSE)
for (m in bnn_model$modules) {
if (inherits(m, "BayesLinear") && !is.null(m$kl)) {
kl_val <- kl_val + m$kl
}
}
ce + kl_weight * kl_val / n_obs
}
}
elbo_loss <- make_elbo_loss(nn_model$fit$model, n_obs = nrow(iris), kl_weight = 1.0)
nn_model2 <-
train_nnsnip(
mode = "classification",
arch = nn_arch(
nn_layer = bnn$BayesLinear,
out_nn_layer = torch::nn_linear,
layer_arg_fn = ~ if (.is_output) {
list(.in, .out)
} else {
list(
in_features = .in,
out_features = .out,
prior_mu = 0,
prior_sigma = 0.1
)
}
),
hidden_neurons = c(64, 32),
activations = act_funs(relu, elu),
loss = elbo_loss,
epochs = 50,
verbose = TRUE,
learn_rate = 1e-3,
optimizer_args = list(weight_decay = 0.01)
) |>
fit(Species ~ ., data = iris)
nn_model2 |>
augment(new_data = iris) |>
metrics(truth = Species, estimate = .pred_class)
## # A tibble: 2 × 3
## .metric .estimator .estimate
##
## 1 accuracy multiclass 0.847
## 2 kap multiclass 0.77
Ca întotdeauna, dacă aveți întrebări legate de subiectul abordat în această postare, vă rugăm să o adăugați ca comentariu, astfel încât alți cititori să poată beneficia de discuție.
