Magia învățării în context (ICL): când modelul dvs. vă cunoaște deja datele

URMĂREȘTE-NE
16,065FaniÎmi place
1,142CititoriConectați-vă


Te-ai uitat vreodată la un grafic de dispersie proaspăt trasat și te-ai gândit imediat „Ah, aceasta este în mod clar o curbă logaritmică cu ceva zgomot heteroschedastic,” fără a rula o singură linie de cod de modelare? Cum faci asta? Nu faci coborâre de gradient în capul tău. Îți folosești intuiţie!

În calitate de cercetător de date cu experiență, ați văzut mii de seturi de date în cariera dvs. Când se confruntă cu date noi, rețeaua ta neuronală naturală (alias creierul) pur și simplu se bazează pe această vastă bibliotecă de forme matematice din trecut și recunoaște imediat modelul. Dar dacă o rețea neuronală artificială ar putea face exact același lucru? Ce se întâmplă dacă ți-ar putea prezice datele fără a fi efectiv instruit cu privire la ele?

Bun venit în lumea uluitoare a Învățare în context (ICL) pentru date tabelare, aduse la R prin intermediul incredibilului nou TabPFN pachet (pe CRAN).

The Transformer: de la text la tabele

Pentru a înțelege ICL, trebuie să vorbim despre Modele de limbaj mari cum ar fi ChatGPT (vezi și Construirea propriului tău mini-ChatGPT cu R: De la lanțuri Markov la Transformers!). Atunci când îi dai unui chatbot o propoziție neterminată, acesta nu își reinstruiește greutățile pentru a ghici următorul cuvânt. Folosește a Transformator arhitectură echipată cu un mecanism de atenție (vezi și Atenție! Ce se află la baza ChatGPT? (De asemenea, ca videoclip!)). Citește cuvintele pe care le-ați furnizat, înțelege dependențele dintre ele (gramatica și contextul) și extrapolează instantaneu ceea ce urmează.

Geniul lui TabPFN ia această arhitectură exactă și o aplică foilor de calcul. În loc de o secvență de cuvinte, Transformerul citește o secvență de rânduri de date. Îți tratează trăsăturile (X) și ținta dvs. (Y) ca gramatica unei limbi. Comparând toate rândurile și coloanele simultan în „fereastra de context”, el descoperă dependențele din tabel la fel cum un model de limbă descoperă dependențele din text.

Modelul care apare este a model de bază pentru datele tabelaresau model de fundație tabelară pe scurt.

Acest proces este cunoscut oficial sub numele de Few-Shot Learning. Nu îi oferi modelului un creier gol de antrenat; „promează” unui creier pre-antrenat cu câteva zeci (sau câteva sute) de „împușcări” (rânduri) de date pentru a stabili tiparul!

Matricea de formare: Învățarea formei matematicii

S-ar putea să vă întrebați: Dacă nu se antrenează pe datele mele, despre ce anume a fost instruit?

Aici devine incredibil de cool. Cercetătorii care au construit TabPFN nu l-a instruit pe seturi de date din lumea reală, cum ar fi prețurile locuințelor sau dosarele medicale. În schimb, ei au scris algoritmi pentru a genera milioane de structuri de dependență matematică complet aleatorii, create artificial.

Ei au forțat rețeaua să exerseze pe seturi de date sintetice care conțineau toate ciudateniile statistice imaginabile: tendințe liniare, neliniarități severe, efecte de interacțiune bizare, mecanisme extreme de lipsă de date și zgomot absolut. Pentru că și-a petrecut întregul antrenament rezolvând miliarde de puzzle-uri matematice abstracte, modelul a învățat elementele fundamentale formă a dependenţelor matematice cauzale. Când vă vede datele din lumea reală, recunoaște doar un model pe care l-a rezolvat deja sintetic de o mie de ori înainte.

Să-l vedem în acțiune

Să folosim venerabilul iris set de date. Deoarece iris este mic și granițele matematice sunt foarte clare, este candidatul perfect pentru învățarea cu câteva lovituri. Observați cum codul arată exact ca învățarea automată tradițională, dar sub capotă, nu are loc de fapt niciun antrenament!

# Load the package
library(tabpfn)

# 1. Prepare the Data
set.seed(42)
train_indices <- sample(seq_len(nrow(iris)), size = 0.7 * nrow(iris))

iris_train <- iris(train_indices, )
iris_test  <- iris(-train_indices, )

# 2. Fit the Model
cat("Generating embeddings...n")
## Generating embeddings...
tab_fit <- tab_pfn(Species ~ ., data = iris_train)

# 3. Make Predictions
cat("Predicting...n")
## Predicting...
predictions <- predict(tab_fit, new_data = iris_test)

# 4. Check the accuracy
accuracy <- sum(predictions$.pred_class == iris_test$Species) / nrow(iris_test)
cat("nSuccess! Overall Accuracy:", round(accuracy * 100, 1), "%n")
## 
## Success! Overall Accuracy: 97.8 %

Când rulați acest lucru, veți vedea o precizie de 97,8%. Modelul a analizat cele câteva exemple din iris_traina recunoscut instantaneu formele multidimensionale care separă speciile folosind intuiția sa sintetică și a clasificat cu precizie noile date de testare fără o singură epocă de retropropagare tradițională.

Concluzie

TabPFNeste o schimbare de paradigmă. Pentru seturile de date tabulare mici până la mijlocii, nu mai trebuie să petrecem ore întregi regland hiperparametrii pentru Random Forests sau XGBoost. Putem pur și simplu să predăm datele unui transformator cu experiență, omniscient din punct de vedere matematic și să lăsăm Învățarea în context să facă treaba grea.

Încearcă-ți propriile date și povestește-ne despre experiența ta în comentariile de mai jos!

Dominic Botezariu
Dominic Botezariuhttps://www.noobz.ro/
Creator de site și redactor-șef.

Cele mai noi știri

Pe același subiect

LĂSAȚI UN MESAJ

Vă rugăm să introduceți comentariul dvs.!
Introduceți aici numele dvs.