Eșantionarea în jos pentru modelarea predictivă | R-BLOGGERS

URMĂREȘTE-NE
16,065FaniÎmi place
1,142CititoriConectați-vă

(Acest articol a fost publicat pentru prima dată pe Jason Bryerși a contribuit cu drag la R-Bloggers). (Puteți raporta problema despre conținutul de pe această pagină aici)


Doriți să vă împărtășiți conținutul pe R-Bloggers? Faceți clic aici dacă aveți un blog sau aici dacă nu.

Rețineți că aceasta este postată încrucișată cu o vinie în medley Pachet R. Pentru cea mai actualizată versiune, accesați aici: https://jbryer.github.io/medley/articles/downsampling.html Comentariile pot fi direcționate către mine pe Mastodon la @vis.Social@jbryer.

Pentru a instala versiunea de dezvoltare a medley pachet, utilizați următoarea comandă:

remotes::install_github('jbryer/medley')

Una dintre provocările modelării predictive apare atunci când variabila dependentă este dezechilibrată (adică raportul dintre o clasă la celălalt este mare, în general mai mare de 80 la 20). Au fost propuse mai multe strategii pentru a aborda dezechilibrul, inclusiv eșantionarea și eșantionarea în jos. Eșantionarea implică duplicarea datelor din clasa mai mică pentru a se potrivi mai bine cu numărul de observații din clasa mai mare. Dezavantajul eșantionării este că se creează noi date care ar putea provoca supraîncărcare. În plus, prin creșterea artificială a dimensiunii eșantionului, erorile standard vor fi, de asemenea, scăzute artificial. Eșantionarea în jos implică selectarea aleatorie din clasa mai mare pentru a obține un echilibru mai bun. Dezavantajul eșantionării în jos este că unele date, și uneori o mulțime de date, sunt excluse din model.

Această lucrare introduce o procedură care dobândește în timp ce utilizează toate datele disponibile prin instruirea mai multor modele. De exemplu, luați în considerare un set de date cu 1.000 de observații, 900 sunt din clasa A și 100 sunt din clasa B. Presupunând că dorim să avem un echilibru perfect între A și B, am atribui la întâmplare observațiile de clasa A 900 la unul dintre cele nouă modele. Putem apoi să reunim predicțiile pe cele nouă modele.

Exemplu de lucru

library(medley)
data('pisa', package="medley")
data('pisa_variables', package="medley")

Programul de evaluare internațională a studenților (PISA) este un studiu internațional realizat de Organizația pentru Cooperare și Dezvoltare Economică (OCDE) la fiecare trei ani. Evaluează elevii în vârstă de 15 ani în matematică, știință și lectură în timp ce colectează informații despre elevi și școlile lor. pisa set de date incluse în medley Pachetul provine de la administrația din 2009 și este utilizat pentru a demonstra prezicerea prezenței în școală privată sau publică. Există 5.233 de observații pe 44 de variabile cu 93,4% studenți ai școlii publice și 6,6% studenți ai școlii private.

Pentru început, vom împărți datele într -un set de instruire și validare folosind splitstackshape::stratified() Funcție pentru a se asigura că raportul dintre elevii de școală public-privat este același în ambele seturi de date.

pisa_formu <- Public ~ .
names(pisa) <- pisa_variables(names(pisa))
pisa_splits <- splitstackshape::stratified(
    pisa, group = "Public", size = 0.75, bothSets = TRUE)
pisa_train <- pisa_splits((1)) |> as.data.frame()
pisa_valid <- pisa_splits((2)) |> as.data.frame()
table(pisa$Public, useNA = 'ifany') |> print() |> prop.table()
     FALSE       TRUE 
0.06592777 0.93407223 
table(pisa_train$Public, useNA = 'ifany') |> print() |> prop.table()
     FALSE       TRUE 
0.06598726 0.93401274 
table(pisa_valid$Public, useNA = 'ifany') |> print() |> prop.table()
     FALSE       TRUE 
0.06574924 0.93425076 

Putem estima un model de regresie logistică și obținem probabilitățile prezise pentru setul de date de validare.

pisa_lr_out <- glm(pisa_formu, data = pisa_train, family = binomial(link = 'logit'))
pisa_predictions <- predict(pisa_lr_out, newdata = pisa_valid, type="response")

Figura de mai jos arată distribuția probabilităților prevăzute pentru setul de date de validare. Există o oarecare separare între elevii de școală publică și privați, dar densitățile sunt clar centrate în partea dreaptă a gamei.

ggplot(data.frame(Public = pisa_valid$Public, 
                  Prediction = pisa_predictions), 
       aes(x = Prediction, color = Public)) +
  geom_density()

Figura de mai jos oferă o curbă caracteristică a operatorului de receptor (ROC) împreună cu o diagramă a preciziei, sensibilității și specificității.

calculate_roc(predictions = pisa_predictions, 
              observed = pisa_valid$Public) |> plot()

Matricea de confuzie de mai jos, împărțind la 0,5, indică faptul că acest model nu este mai bun decât modelul nul (adică procentul de studenți ai școlii publice este de 93,4%). Desigur, am putea ajusta această valoare redusă pentru a optimiza fie specificitate sau sensibilitate.

confusion_matrix(observed = pisa_valid$Public, 
                 predicted = pisa_predictions > 0.5)
           predicted              
  observed     FALSE          TRUE
     FALSE 1 (0.08%)    85 (6.50%)
      TRUE 4 (0.31%) 1218 (93.12%)
Accuracy: 93.2%
Sensitivity: 1.16%
Specificity: 99.67%

Micșorarea valorilor montate

Se dovedește că gama de valori montate din regresia logistică se va micșora pe măsură ce cantitatea de dezechilibru în variabila dependentă crește. Am întâlnit pentru prima dată această problemă când am estimat scoruri de propensiune pentru disertația mea într -un studiu al Charter versus elevii tradiționali ai școlii publice. În studiul respectiv folosind evaluarea națională a progresului educațional (NAEP), aproximativ 3% dintre studenți au participat la o școală charter. În acel studiu, gama de scoruri de propensiune a fost grav restricționată. Pentru a explora de ce, funcția multilevel :: PSRANGE () a fost dezvoltată rezultatul acestei funcții este figura de mai jos. Începând cu partea de jos, 345 de studenți ai școlii publice au fost selectați la întâmplare, astfel încât regresia logistică ar putea fi estimată acolo unde există un echilibru perfect în variabila dependentă. Pe măsură ce urcăm în sus, creștem raportul de la 1: 1 la 1:13. Pentru fiecare raport, sunt desenate 20 de probe aleatorii, modelul de regresie logistică estimat și se înregistrează valorile minime și maxime montate (adică probabilități prezise) (sunt reprezentate de punctele negre și barele verzi). Distribuțiile pe toate modelele sunt de asemenea incluse.

Transmiterea doar a intervalelor împreună cu media valorilor montate pentru elevii de școală publici (albaștri) și privați (verzi) arată că, odată ce raportul este mai mare de 3-la-1, media valorilor montate pentru clasa zero (școlile private din acest exemplu) este mai mare de 0,5.

Eșantionare în jos

După cum s -a discutat mai sus, unul dintre dezavantajele cheie ale eșantionării în jos este că, în situațiile în care există un dezechilibru semnificativ, excludem o mulțime de date din analiză. downsample() Funcția va determina mai întâi câte modele trebuie estimate astfel încât fiecare observație din clasa mai mare este utilizată exact o dată. Pentru acest exemplu, utilizăm un raport al studenților public-privat de 2-la-1, astfel încât pentru fiecare model estimat că există 259 de observații private și 518 de studenți publici. Având în vedere că există 3925 de observații în setul nostru de antrenament, dowmsample() Funcția va estima 7 modele.

pisa_ds_out <- downsample(
  formu = pisa_formu,
  data = pisa_train,
  model_fun = glm,
  ratio = 2,
  family = binomial(link = 'logit'))
length(pisa_ds_out)

Putem folosi predict() Funcție pentru a obține un cadru de date de predicții. Fiecare coloană corespunde valorii prevăzute pentru fiecare din cele 7 modele.

pisa_predictions_ds <- predict(pisa_ds_out,
                               newdata = pisa_valid, 
                               type="response")
head(pisa_predictions_ds)
     model1    model2    model3    model4    model5    model6    model7
1 0.8511828 0.7437341 0.8921605 0.8424369 0.7347052 0.8697531 0.6928875
2 0.7393116 0.6822714 0.9466815 0.7959953 0.8118642 0.9441840 0.9580830
3 0.4944206 0.3813138 0.5575741 0.3586561 0.5023435 0.5805062 0.6281852
4 0.8525691 0.8268514 0.8293386 0.8372777 0.9464037 0.8843848 0.9016058
5 0.1823382 0.3670335 0.4556063 0.1408078 0.1899378 0.3578418 0.2657968
6 0.9216160 0.8192096 0.9040353 0.9213184 0.8080822 0.9076342 0.8768295

Putem media predicțiile pentru a obține un singur vector.

pisa_predictions_ds2 <- pisa_predictions_ds |> apply(1, mean)

Distribuțiile de densitate sunt prezentate mai jos. Aceste distribuții seamănă mai mult cu distribuțiile pe care le așteptăm atunci când avem date echilibrate, chiar dacă am folosit toate observațiile pentru a obține aceste probabilități prezise.

ggplot(data.frame(Public = pisa_valid$Public, 
                  Prediction = pisa_predictions_ds2), 
       aes(x = Prediction, color = Public)) +
  geom_density()

Deși downsample() Funcția pare să abordeze problema valorilor montate de micșorare și oprire centrată, valorile de performanță ale modelului furnizate mai jos sugerează că nu a îmbunătățit performanța generală a predicțiilor modelului.

roc <- calculate_roc(predictions = pisa_predictions_ds2, 
                     observed = pisa_valid$Public)
plot(roc)

confusion_matrix(observed = pisa_valid$Public, 
                 predicted = pisa_predictions_ds2 > 0.5)
              predicted              
  observed        FALSE          TRUE
     FALSE   45 (3.44%)    41 (3.13%)
      TRUE 204 (15.60%) 1018 (77.83%)
Accuracy: 81.27%
Sensitivity: 52.33%
Specificity: 83.31%

Apendicele: rezumate ale modelului

Mai sus am înregistrat o medie a valorilor prevăzute pe toate modelele pentru a obține o singură predicție pentru fiecare observație din setul nostru de date de validare. Cu toate acestea, este posibil să grupați modele folosind mice::pool() Funcție pentru a obține un singur set de coeficienți de regresie. Tabelul de mai jos oferă coeficienții de regresie combinați de la downsample Funcționați împreună cu coeficienții din modelul de regresie logistică folosind toate datele.

Reclamate de la fațete Date complete
(Intercepta) 5.516 * (2.376) 7.862 *** (1.657)
Sexmale -0.823 ** (0.244) -0.762 *** (0.149)
`Participă – Da, un an sau mai puțin 0,495 (0,286) 0,498 ** (0,190)
`Vârsta la ` 0,087 (0,186) 0,075 (0,104)
`Repetă – Da, o dată 0,693 (0,548) 0,645 (0,345)
`Acasă – Mama`true -0.888 (0,843) -0.921 (0.493)
`Acasă – tată ‘ -0.672 (0,414) -0.620 * (0.277)
`Acasă – Brothers`true 0,184 (0,313) 0,155 (0,146)
`Acasă – surori`true 0,563 * (0,237) 0,454 ** (0,146)
`Acasă – bunici`true -0.725 * (0.328) -0.648 ** (0,201)
`Acasă – alții ‘ -0.094 (0.332) -0.136 (0.221)
`Mamă ` 0,101 (0,631) 0,070 (0,388)
„Starea actuală a mamei actuale” -0.457 (0,609) -0.303 (0,364)
„Starea actuală a postului de lucru” cu normă întreagă -0.608 (0,537) -0.443 (0,339)
„Starea actuală a postului de lucru” -0.586 (0,650) -0.446 (0,369)
– Părinte ` 0,072 (1.219) 0,077 (0,832)
– Părinte ` -0.579 (1.121) -0.658 (0,754)
„Starea de muncă actuală a părintelui” -0.019 (0.737) -0.065 (0.394)
`Starea actuală a tatălui de muncă` Părintele de lucru cu normă întreagă 0,356 (0,604) 0,236 (0,324)
`Părintele statutului actual de muncă ‘prelucrare part-time 1,260 (0,892) 0,998 (0,529)
„Limba acasă” de testare 0,137 (0,489) -0.104 (0.263)
`Possesions Desk`true -0.583 (0.404) -0.531 * (0.265)
`Posesiuni propria cameră`true 0,521 (0,384) 0,600 * (0,238)
`Possesions Study Place`true -0.056 (0.488) -0.223 (0,303)
`Posessions Computer`true -0.077 (0.855) -0.038 (0,592)
`Software posesions`true 0,365 (0,332) 0,358 * (0,161)
`Posesiuni Internet`true -1.416 (0,917) -1.177 (0,602)
`Posesiuni literatură`true -0.619 * (0.295) -0.551 ** (0.175)
`Posesiuni Poetry`true -0.369 (0,308) -0.250 (0.176)
`Posesiuni art`true -0.402 (0,333) -0.273 (0.196)
`Posesions manuale `true -0.021 (0.356) -0.007 (0.214)
`Dicționar posesions`true 0,100 (0,583) -0.000 (0,422)
`Posesiuni Mașină de spălat vase ‘ -0.074 (0.336) -0.078 (0.234)
„Câte telefoane celulare” -0.851 (0,987) -0.906 (0,741)
„Câte telefoane celulare” -0.192 (0.985) -0.478 (0,771)
„Câte televizoare” sau mai multe 1.378 * (0,645) 0,995 *** (0,302)
„Câte televizoare” 0,816 (0,640) 0,519 (0,324)
„Câte computere” 0,589 (1.188) 0,343 (0,823)
„Câte calculatoare” sau mai multe 0,174 (1.168) -0.072 (0.838)
„Câte calculatoare” 0,079 (1.143) -0.120 (0,832)
„Câte mașini” sau mai multe 0,038 (0,458) -0.041 (0.295)
„Câte mașini” două -0.223 (0,427) -0.264 (0,291)
„Câte camere baie sau duș” sau mai multe -1.001 * (0.427) -0.703 ** (0.238)
„Câte camere baie sau duș” -0.223 (0,371) -0.107 (0.217)
„Câte cărți acasă ”101-200 cărți -0.255 (0,440) -0.367 (0,327)
„Câte cărți acasă ”11-25 cărți -0.158 (0.431) -0.158 (0,339)
`Câte cărți acasă`201-500 cărți -0.998 * (0.477) -0.985 ** (0.334)
„Câte cărți acasă ”26-100 cărți -0.498 (0,395) -0.489 (0,302)
„Câte cărți acasă” mai mult de 500 de cărți -1.082 (0,558) -1.042 ** (0,366)
`Citirea timpului de plăcere ’30 minute sau mai puțin pe zi -0.071 (0.477) -0.183 (0.251)
„Citirea plăcerii timpului” între 30 și 60 de minute 0,494 (0,457) 0,311 (0,283)
„Citirea plăcerii timpului” Nu citesc pentru plăcere -0.019 (0.466) -0.118 (0.259)
`Citirea timpului de plăcere ‘mai mult de 2 ore pe zi 0,245 (0,747) 0,208 (0,406)
` în `Adevărat -0.012 (0.671) 0,242 (0,413)
` în `Adevărat -0.236 (0,569) -0.276 (0,323)
` în `Adevărat 0,576 (0,634) 0,367 (0,410)
` în `Adevărat 0,311 (0,964) 0,041 (0,524)
` în `Adevărat -0.511 (0,685) -0.614 (0,384)
` în `Adevărat -0.197 (0.789) 0,245 (0,496)
„În afara lecțiilor școlare „Nu participați -0.279 (0,815) -0.106 (0.493)
„În afara lecțiilor școlare `Mai puțin de 2 ore pe săptămână -0.685 (0,712) -0.533 (0,465)
„În afara lecțiilor școlare `4 până la 6 ore pe săptămână 0,237 (0,902) 0,386 (0,541)
„În afara lecțiilor școlare „Nu participați -0.254 (0,798) -0.160 (0,410)
„În afara lecțiilor școlare `Mai puțin de 2 ore pe săptămână -0.117 (0.618) 0,122 (0,369)
„În afara lecțiilor școlare `4 până la 6 ore pe săptămână -1.091 (0,789) -0.837 (0,549)
„În afara lecțiilor școlare „Nu participați -0.565 (0,717) -0.456 (0,472)
„În afara lecțiilor școlare `Mai puțin de 2 ore pe săptămână -0.534 (0,751) -0.513 (0,464)
n 783 3925.000
*** P <0,001; ** P <0,01; * P <0.05.

Dominic Botezariu
Dominic Botezariuhttps://www.noobz.ro/
Creator de site și redactor-șef.

Cele mai noi știri

Pe același subiect

LĂSAȚI UN MESAJ

Vă rugăm să introduceți comentariul dvs.!
Introduceți aici numele dvs.