Cadre de învățare automată în R

URMĂREȘTE-NE
16,065FaniÎmi place
1,142CititoriConectați-vă

(Acest articol a fost publicat pentru prima dată pe R’ticchokeși cu amabilitate a contribuit la R-bloggeri). (Puteți raporta problema legată de conținutul acestei pagini aici)


Doriți să vă distribuiți conținutul pe R-bloggeri? dați clic aici dacă aveți un blog, sau aici dacă nu aveți.

Ecosistemul lui R oferă o selecție bogată de cadre de învățare automată, fiecare cu filozofii de design și puncte forte distincte. Această postare este o comparație alăturată a cinci cadre ML în R care oferă interfețe unificate pe mai mulți algoritmi, cu exemple de cod rulabile pe același set de date, astfel încât să puteți compara direct API-urile. Accentul se pune pe pachetele care vă permit să schimbați algoritmi fără a vă rescrie codul.

Cadre dintr-o privire

Tuning încorporat ✅ (tune) ✅ (mlr3tuning) ✅ (AutoML) ✅ (qeFT())
Conducta de preprocesare ✅ (recipes) ✅ (preProcess) ✅ (mlr3pipelines)
Varietate de modele Peste 200 de motoare Peste 230 de metode Peste 100 de cursanți GBM, GLM, DL, DRF Peste 20 de ambalaje
Viteza relativă Moderat Moderat Moderat Rapid (distribuit) Depinde de backend
Curba de învățare Mediu Scăzut Ridicat Scăzut Foarte scăzut
Dezvoltare activă ⚠️ Modul de întreținere
Cel mai bun pentru Conducte de producție Prototipare rapidă Benchmarking AutoML și scalare Predare și explorare

Configurare și date

Toate exemplele de mai jos folosesc iris sarcina de clasificare: prezice Species din cele patru măsurători numerice. O singură împărțire tren/test este creată în față, astfel încât rezultatele să fie direct comparabile.

library(dplyr)

# Reproducible train/test split (framework-agnostic)
set.seed(42)
n <- nrow(iris)
train_idx <- sample(seq_len(n), size = floor(0.7 * n))

train_data <- iris(train_idx, )
test_data  <- iris(-train_idx, )

# Store accuracy results for final comparison
results <- data.frame(
  Framework = character(),
  Model = character(),
  Accuracy = numeric(),
  stringsAsFactors = FALSE
)

cat("Training set:", nrow(train_data), "observationsn")
Training set: 105 observations
cat("Test set:    ", nrow(test_data), "observationsn")
Test set:     45 observations

1. modele ordonate

Ecosistemul tidymodels este abordarea nativă modernă și diversă a modelării în R. Acesta oferă o gramatică consistentă pentru specificarea modelelor (parsnip), preprocesare (recipes), compunerea fluxurilor de lucru (workflows), și hiperparametrii de reglare (tune).

library(tidymodels)

# Define a recipe (preprocessing)
rec <- recipe(Species ~ ., data = train_data)

# Define a model specification
rf_spec <- rand_forest(trees = 500) %>%
  set_engine("ranger") %>%
  set_mode("classification")

# Combine into a workflow
rf_wf <- workflow() %>%
  add_recipe(rec) %>%
  add_model(rf_spec)

# Fit the workflow
rf_fit <- rf_wf %>% fit(data = train_data)

# Predict on test set
preds_tidy <- predict(rf_fit, test_data) %>%
  bind_cols(test_data %>% select(Species))

# Evaluate
acc_tidy <- accuracy(preds_tidy, truth = Species, estimate = .pred_class)
acc_tidy
# A tibble: 1 × 3
  .metric  .estimator .estimate
                
1 accuracy multiclass     0.978
results <- rbind(results, data.frame(
  Framework = "tidymodels",
  Model = "Random Forest (ranger)",
  Accuracy = acc_tidy$.estimate
))
  • Conductă componabilă: recipe + model + workflow este ușor de extins
  • Schimbați motoarele cu o singură linie (set_engine("xgboost"))
  • Validare încrucișată fără probleme și reglare hiperparametrică prin tune_grid() / tune_bayes()
  • Integrare profundă cu tidyverse

2. caret

The caret pachetul (Classification And REgression Training) a fost standardul de facto pentru ML în R timp de peste un deceniu. Include peste 230 de modele în spatele unui singur train() interfata. În timp ce acum se află în modul de întreținere (creatorul său, Max Kuhn, conduce tidymodels), rămâne încă utilizat pe scară largă.

library(caret)

# Train a random forest with 5-fold CV
ctrl <- trainControl(method = "cv", number = 5)

rf_caret <- train(
  Species ~ .,
  data = train_data,
  method = "rf",
  trControl = ctrl,
  tuneLength = 3  # Try 3 values of mtry
)

# Best tuning parameter
rf_caret$bestTune
# Predict on test set
preds_caret <- predict(rf_caret, test_data)

# Evaluate
cm_caret <- confusionMatrix(preds_caret, test_data$Species)
cm_caret$overall("Accuracy")
results <- rbind(results, data.frame(
  Framework = "caret",
  Model = "Random Forest (rf)",
  Accuracy = as.numeric(cm_caret$overall("Accuracy"))
))
  • API minimă: un singur train() apelul se ocupă de preprocesare, reglare și potrivire
  • Peste 230 de metode de model disponibile din cutie
  • Încorporat confusionMatrix() cu diagnostice extinse
  • Baza masivă de cunoștințe ale comunității și acoperire Stack Overflow

3. mlr3

mlr3 este un cadru ML modern, orientat pe obiecte, construit pe clase R6. Excelează la analiza comparativă sistematică, conducte composabile și experimente reproductibile. Curba de învățare este mai abruptă, dar rezultatul este o arhitectură puternică, extensibilă.

library(mlr3)
library(mlr3learners)

# Define the task
task <- TaskClassif$new(
  id = "iris",
  backend = train_data,
  target = "Species"
)

# Define the learner
learner <- lrn("classif.ranger", num.trees = 500)

# Train
learner$train(task)

# Predict on test data — create a test task to avoid backend storage issues
test_task <- TaskClassif$new(
  id = "iris_test",
  backend = test_data,
  target = "Species"
)
pred_mlr3 <- learner$predict(test_task)

# Evaluate
acc_mlr3 <- pred_mlr3$score(msr("classif.acc"))
acc_mlr3
results <- rbind(results, data.frame(
  Framework = "mlr3",
  Model = "Random Forest (ranger)",
  Accuracy = as.numeric(acc_mlr3)
))
  • Design orientat pe obiecte R6 — totul este un obiect cu metode
  • Evaluare comparativă de primă clasă: comparați mai mulți cursanți la mai multe sarcini cu benchmark()
  • Conducte composabile prin mlr3pipelines (stivuire, asamblare, inginerie de caracteristici)
  • Strategii de reeșantionare și măsuri de performanță încorporate

4. h2o (AutoML)

h2o este o platformă de învățare automată distribuită cu o interfață R puternică. Caracteristica sa remarcabilă este h2o.automl() selectarea automată a modelului, reglarea hiperparametrului și crearea unui ansamblu stivuit cu un singur apel de funcție. Se rulează pe un JVM local, deci Java trebuie instalat.

Această secțiune necesită instalarea Java (JDK 8+). h2o pornește un server local bazat pe JVM. Dacă nu aveți Java, treceți la compararea rezultatelor – celelalte patru cadre acoperă același teren fără această dependență.

library(h2o)

# Start a local h2o cluster (uses available cores)
h2o.init(nthreads = -1, max_mem_size = "2G")
H2O is not running yet, starting it now...

Note:  In case of errors look at the following log files:
    C:UsersRIDDHI~1AppDataLocalTempRtmp8G4u9Cfilec0caad7dcc/h2o_Riddhiman_Roy_started_from_r.out
    C:UsersRIDDHI~1AppDataLocalTempRtmp8G4u9Cfilec0c1660783f/h2o_Riddhiman_Roy_started_from_r.err


Starting H2O JVM and connecting:  Connection successful!

R is connected to the H2O cluster: 
    H2O cluster uptime:         2 seconds 276 milliseconds 
    H2O cluster timezone:       Asia/Kolkata 
    H2O data parsing timezone:  UTC 
    H2O cluster version:        3.44.0.3 
    H2O cluster version age:    2 years, 3 months and 23 days 
    H2O cluster name:           H2O_started_from_R_Riddhiman_Roy_axb153 
    H2O cluster total nodes:    1 
    H2O cluster total memory:   1.98 GB 
    H2O cluster total cores:    24 
    H2O cluster allowed cores:  24 
    H2O cluster healthy:        TRUE 
    H2O Connection ip:          localhost 
    H2O Connection port:        54321 
    H2O Connection proxy:       NA 
    H2O Internal Security:      FALSE 
    R Version:                  R version 4.5.3 (2026-03-11 ucrt) 
h2o.no_progress()  # Suppress progress bars

# Convert data to h2o frames
train_h2o <- as.h2o(train_data)
test_h2o  <- as.h2o(test_data)

# Run AutoML — automatic model selection and stacking
aml <- h2o.automl(
  x = c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"),
  y = "Species",
  training_frame = train_h2o,
  max_models = 10,
  seed = 42
)
22:11:38.3: AutoML: XGBoost is not available; skipping it.
22:11:39.171: _min_rows param, The dataset size is too small to split for min_rows=100.0: must have at least 200.0 (weighted) rows, but have only 105.0.
# Leaderboard — best models ranked by cross-validated performance
h2o.get_leaderboard(aml) |> as.data.frame() |> head(5)
                                                 model_id mean_per_class_error
1    DeepLearning_grid_1_AutoML_1_20260412_221137_model_1           0.03988095
2                          GBM_2_AutoML_1_20260412_221137           0.05029762
3                          GLM_1_AutoML_1_20260412_221137           0.05029762
4    StackedEnsemble_AllModels_1_AutoML_1_20260412_221137           0.05982143
5 StackedEnsemble_BestOfFamily_1_AutoML_1_20260412_221137           0.05982143
     logloss      rmse        mse
1 0.10262590 0.1802887 0.03250400
2 0.13688347 0.1981121 0.03924839
3 0.09073184 0.1736624 0.03015862
4 0.12933104 0.2002635 0.04010548
5 0.11828660 0.1921656 0.03692762
# Predict with the best model
preds_h2o <- h2o.predict(aml@leader, test_h2o)
acc_h2o <- mean(as.vector(preds_h2o$predict) == as.vector(test_h2o$Species))
cat("Accuracy:", acc_h2o, "n")
results <- rbind(results, data.frame(
  Framework = "h2o",
  Model = paste0("AutoML (", aml@leader@algorithm, ")"),
  Accuracy = acc_h2o
))

# Shutdown h2o
h2o.shutdown(prompt = FALSE)
  • h2o.automl() — selecția modelului, reglarea și ansamblurile stivuite complet automate
  • Antrenează GBM, XGBoost, GLM, DRF și modele de învățare profundă într-un singur apel
  • Calcul distribuit — se extinde la seturi de date mai mari decât memoria
  • Clasament încorporat pentru compararea modelelor
  • Implementarea producției prin exportul modelului MOJO/POJO

5. qeML

qeML (Învățare automată rapidă și ușoară) adoptă o abordare diferită de minimizare a standardelor. Fiecare algoritm – pădure aleatoare, creșterea gradientului, SVM, KNN, LASSO, rețele neuronale și multe altele este înfășurat în spatele unei singure linii qe*() funcţionează cu o consecventă (data, targetName) semnătură. Fără obiecte formulă, fără conversii matrice, fără apeluri de predicție separate, doar rezultate. Este ideal pentru predare, explorare și comparații rapide de modele.

library(qeML)

# qeML convention: pass full data + target name (string)
# It handles train/test splitting internally via holdout
# But to match our split, we'll train on train_data and predict on test_data
# predict() expects new data WITHOUT the target column
test_features <- test_data(, -which(names(test_data) == "Species"))

# Random Forest (wraps randomForest)
rf_qe <- qeRF(train_data, "Species")
preds_rf_qe <- predict(rf_qe, test_features)
acc_rf_qe <- mean(preds_rf_qe$predClasses == test_data$Species)
cat("Random Forest accuracy:", acc_rf_qe, "n")
Random Forest accuracy: 0.9777778 
# Gradient Boosting (wraps gbm)
gb_qe <- qeGBoost(train_data, "Species")
preds_gb_qe <- predict(gb_qe, test_features)
acc_gb_qe <- mean(preds_gb_qe$predClasses == test_data$Species)
cat("Gradient Boosting accuracy:", acc_gb_qe, "n")
Gradient Boosting accuracy: 0.9555556 
# SVM (wraps e1071)
svm_qe <- qeSVM(train_data, "Species")
preds_svm_qe <- predict(svm_qe, test_features)
acc_svm_qe <- mean(preds_svm_qe$predClasses == test_data$Species)
cat("SVM accuracy:", acc_svm_qe, "n")
# Use the best-performing qeML model for the results table
best_acc_qe <- max(acc_rf_qe, acc_gb_qe, acc_svm_qe)
best_model_qe <- c("Random Forest", "Gradient Boosting", "SVM")(
  which.max(c(acc_rf_qe, acc_gb_qe, acc_svm_qe))
)

results <- rbind(results, data.frame(
  Framework = "qeML",
  Model = paste0(best_model_qe, " (qe wrapper)"),
  Accuracy = best_acc_qe
))
  • Potrivire model cu o linie: qeRF(data, "target") — fără formulă, fără matrice, fără rețetă
  • Peste 20 de algoritmi în spatele unei uniforme qe*() interfață (RF, GBM, SVM, KNN, LASSO, rețele neuronale și multe altele)
  • qeCompare() vă permite să comparați mai multe metode într-un singur apel
  • Evaluare de rezistență încorporată
  • Cea mai scăzută curbă de învățare dintre orice cadru prezentat aici

Compararea rezultatelor

Toate cele cinci cadre au fost instruite pe aceeași împărțire 70/30 a iris set de date. Iată cum se strâng:

library(knitr)

results <- results %>% arrange(desc(Accuracy))
kable(results, digits = 4, caption = "Test Set Accuracy by Framework")
Testați acuratețea setului în funcție de cadru
modele ordonate Pădurea aleatorie (ranger) 0,9778
h2o AutoML (învățare profundă) 0,9778
qeML Pădurea aleatorie (wrapper qe) 0,9778
semn de omisiune Pădure aleatorie (rf) 0,9556
mlr3 Pădurea aleatorie (ranger) 0,9556

Pe un set de date curat și mic, cum ar fi irisdiferențele de precizie sunt minime. Adevăratul diferențiere este API-ul și fluxul de lucru oferit de fiecare cadru. Pe seturile de date din lumea reală, alegerea cadrului contează mai mult pentru modul în care vă structurați codul decât pentru acuratețea brută.

Gânduri de închidere

Nu există un singur cadru ML „cel mai bun” în R și alegerea corectă depinde de sarcina la îndemână:

  • Începe cu tidymodels pentru o conductă modernă, componabilă, pregătită pentru producție.
  • Încerca qeML pentru cea mai rapidă cale de la date la rezultate.
  • Utilizare h2o pentru selectarea automată a modelului și stivuirea cu efort minim.
  • Luați în considerare mlr3 pentru benchmarking riguros și compoziția avansată a conductelor.
  • Stai cu caret dacă menține codul existent sau preferă simplitatea lui testată în luptă.

Dominic Botezariu
Dominic Botezariuhttps://www.noobz.ro/
Creator de site și redactor-șef.

Cele mai noi știri

Pe același subiect

LĂSAȚI UN MESAJ

Vă rugăm să introduceți comentariul dvs.!
Introduceți aici numele dvs.