caret
- 1 o que é o
caret
- 2 avaliação dos dados
- 3 pré-processamento dos dados
- 4 treinamento e avaliação
- 5 regressão: rf - random forest
- 6 regressão: knn - kth nearest neighbor
- 7 classificação
- 8 classificação: SVM (Support Vector Machines)
- 9 classificação: árvore de decisão
Alguns pacotes do R úteis para ML:
Classification And Regression Training: um dos pacotes mais populares. Tem um conjunto de funções pré-definidas que agilizam a implementação de modelos de ML, como pré-processamento dos dados, seleção de atributos, ajuste de hiperparâmetros, etc, para um número muito grande de modelos.
- mlr3 - https://mlr3.mlr-org.com/
também provê os blocos essenciais para implementação de projetos de ML
classificação e regressão com códigos de árvore
classificação e regressão com randomForest
implementa vários algoritmos, como support vector machines
deep learning com TensorFlow
- e muitos outros!
caret
caret
: Classification And REgression Training
Este pacote do R contém várias funções que facilitam muito várias tarefas em aprendizado de máquina (ML).
Vamos ilustrar o
caret
com um problema de regressão: estimativa de redshifts fotométricos para galáxias do S-PLUS (Mendes de Oliveira et al. 2019; arXiv:1907.01567).
# tempo do início do processamento
t00 = Sys.time()
# vamos carregar vários pacotes
suppressMessages(library(caret))
suppressMessages(library(tidyverse))
suppressMessages(library(readxl))
suppressMessages(library(ggplot2))
suppressMessages(library(corrplot))
suppressMessages(library(ranger))
suppressMessages(library(dplyr))
suppressMessages(library(e1071))
suppressMessages(library(skimr))
suppressMessages(library(pROC))
suppressMessages(library("rpart"))
suppressMessages(library("rpart.plot"))
suppressMessages(library(KernSmooth))
# reprodutibilidade
set.seed(123)
a) leitura dos dados: vamos considerar uma amostra de galáxias do S-PLUS com fotometria nas 12 bandas
o arquivo splus-mag-z.dat contém, para cada objeto da amostra, as 12 magnitudes e o redshift espectroscópico (do SDSS)
vamos selecionar apenas galáxias com magnitudes no intervalo 15 \(<\) r_petro \(<\) 20.
dados = as.data.frame(read.table('splus-mag-z.dat', sep = "", header=TRUE))
# dimensão dos dados:
dim(dados)
## [1] 55803 13
# seleção em magnitudes
sel = rep(0, nrow(dados))
sel[dados$r_petro > 15 & dados$r_petro < 20] = 1
sum(sel)
## [1] 54313
dados = dados[sel == 1,]
dim(dados)
## [1] 54313 13
# topo do arquivo
head(dados)
## uJAVA_petro F378_petro F395_petro F410_petro F430_petro g_petro F515_petro
## 1 18.71 18.69 18.50 18.12 17.78 17.43 17.09
## 2 18.26 18.12 18.28 17.32 16.95 16.67 16.24
## 3 19.38 19.45 18.76 18.60 18.50 18.07 17.69
## 4 20.30 19.73 20.01 19.18 18.91 18.36 18.00
## 5 19.49 18.98 18.64 18.64 18.02 17.14 16.61
## 6 19.86 19.38 18.93 18.75 18.15 17.44 16.94
## r_petro F660_petro i_petro F861_petro z_petro z_SDSS
## 1 16.89 16.83 16.56 16.54 16.45 0.111
## 2 15.97 15.90 15.61 15.52 15.41 0.082
## 3 17.55 17.47 17.25 17.31 17.18 0.086
## 4 17.75 17.68 17.51 17.46 17.34 0.111
## 5 16.06 15.93 15.65 15.54 15.38 0.162
## 6 16.58 16.48 16.24 16.11 16.00 0.082
é sempre bom ver se há dados faltantantes ou outros problemas com os dados:
# a função skim() provê um sumário da estatística descritiva de cada variável
sumario <- skim(dados)
sumario
Name | dados |
Number of rows | 54313 |
Number of columns | 13 |
_______________________ | |
Column type frequency: | |
numeric | 13 |
________________________ | |
Group variables | None |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
uJAVA_petro | 0 | 1 | 20.10 | 1.21 | 14.95 | 19.36 | 20.05 | 20.76 | 31.02 | ▁▇▁▁▁ |
F378_petro | 0 | 1 | 19.85 | 1.17 | 12.94 | 19.15 | 19.81 | 20.50 | 28.69 | ▁▃▇▁▁ |
F395_petro | 0 | 1 | 19.69 | 1.22 | 12.72 | 18.96 | 19.66 | 20.37 | 29.16 | ▁▅▇▁▁ |
F410_petro | 0 | 1 | 19.50 | 1.20 | 11.89 | 18.79 | 19.48 | 20.19 | 28.37 | ▁▂▇▁▁ |
F430_petro | 0 | 1 | 19.25 | 1.19 | 11.66 | 18.54 | 19.25 | 19.97 | 29.13 | ▁▃▇▁▁ |
g_petro | 0 | 1 | 18.72 | 1.06 | 15.10 | 18.05 | 18.73 | 19.45 | 24.35 | ▁▇▇▁▁ |
F515_petro | 0 | 1 | 18.42 | 1.08 | 10.78 | 17.74 | 18.43 | 19.16 | 25.06 | ▁▁▇▂▁ |
r_petro | 0 | 1 | 17.97 | 1.02 | 15.01 | 17.30 | 18.01 | 18.76 | 19.99 | ▁▃▇▇▅ |
F660_petro | 0 | 1 | 17.84 | 1.03 | 10.70 | 17.16 | 17.89 | 18.65 | 21.11 | ▁▁▂▇▂ |
i_petro | 0 | 1 | 17.62 | 1.03 | 10.15 | 16.93 | 17.68 | 18.43 | 20.52 | ▁▁▂▇▃ |
F861_petro | 0 | 1 | 17.46 | 1.05 | 10.39 | 16.75 | 17.52 | 18.29 | 21.00 | ▁▁▃▇▁ |
z_petro | 0 | 1 | 17.42 | 1.06 | 9.83 | 16.70 | 17.48 | 18.26 | 20.96 | ▁▁▂▇▁ |
z_SDSS | 0 | 1 | 0.15 | 0.09 | 0.00 | 0.08 | 0.13 | 0.19 | 1.12 | ▇▂▁▁▁ |
Podemos ver que não há dados faltantes; se tivesse precisaríamos ou removê-los ou atribuir um valor para eles (imputação) pois os algoritmos que vamos usar precisam disso!
Como estamos interessados apenas em explorar o pacote e algumas de suas principais funções, vamos restringir o número de objetos:
nsample = 10000
# magnitudes
mags = dados[1:nsample,1:12]
# redshift
zspec = dados[1:nsample,13]
b) vamos dar uma olhada nos dados:
par(mfrow = c(1,2))
hist(mags$r_petro,xlab='r_petro',main='',col='red')
hist(zspec,xlab='z_SDSS',main='',col='blue')
c) vamos visualizar as correlações entre as várias magnitudes
as variáveis usadas na estimativa (as magnitudes) são denominadas atributos ou “features”
dados.cor = cor(mags)
corrplot(dados.cor, type = "upper", order = "original", tl.col = "black", tl.srt = 45)
redshift versus cada banda fotométrica:
visualização usando
featurePlot
:
featurePlot(x = mags, y = zspec, plot = "scatter")
Em geral o desempenho dos algoritmos melhora se os dados tiverem um intervalo de valores “razoável”. Assim, é conveniente pré-processar os dados, isto é, transformá-los para colocá-los num intervalo para tornar a análise mais eficiente.
Há muitos tipos de pré-processamento; os mais comuns são
- reescalonar as variáveis entre 0 e 1
- subtrair a média de cada variável e dividir por seu desvio padrão
vamos colocar as 12 magnitudes entre 0 e 1:
maxs <- apply(mags, 2, max)
mins <- apply(mags, 2, min)
magnorm <- as.data.frame(scale(mags, center = mins, scale = maxs - mins))
# exemplo de dados normalizados
summary(magnorm[,1])
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.0000 0.2420 0.2869 0.2889 0.3317 1.0000
#se quiséssemos subtrair a média e dividir pelo desvio padrão:
#magnorm <- as.data.frame(scale(mags, center = TRUE, scale = TRUE))
# não é necessário normalizar z nesta amostra pois ele está num intervalo "razoável":
summary(zspec)
## Min. 1st Qu. Median Mean 3rd Qu. Max.
## 0.0040 0.0770 0.1150 0.1435 0.1900 0.9420
Em ML o algoritmo aprende a partir dos dados.
Para avaliar quanto ele aprendeu, comparamos suas predições com resultados conhecidos (aprendizado supervisionado)
Para isso divide-se os dados em conjunto de treinamento, para determinar os parâmetros do modelo, e conjunto de teste, para se determinar a qualidade do resultado.
- criação de conjuntos de treinamento e teste:
Vamos considerar 75% dos objetos para treinamento e 25% para teste
# número das linhas dos objetos selecionados para o conjunto de treinamento
numlinhastreinamento = createDataPartition(zspec, p=0.75, list=FALSE)
# conjunto de treinamento:
xtrain=magnorm[numlinhastreinamento,]
#ytrain=as.vector(zspec[numlinhastreinamento])
ytrain=zspec[numlinhastreinamento]
# conjunto de teste
xtest=magnorm[-numlinhastreinamento,]
#ytest=as.vector(zspec[-numlinhastreinamento])
ytest=zspec[-numlinhastreinamento]
- seleção do algoritmo:
O
caret
contém um número enorme de algoritmos de ML. Para ver quais são:
paste(names(getModelInfo()), collapse=', ')
## [1] "ada, AdaBag, AdaBoost.M1, adaboost, amdai, ANFIS, avNNet, awnb, awtan, bag, bagEarth, bagEarthGCV, bagFDA, bagFDAGCV, bam, bartMachine, bayesglm, binda, blackboost, blasso, blassoAveraged, bridge, brnn, BstLm, bstSm, bstTree, C5.0, C5.0Cost, C5.0Rules, C5.0Tree, cforest, chaid, CSimca, ctree, ctree2, cubist, dda, deepboost, DENFIS, dnn, dwdLinear, dwdPoly, dwdRadial, earth, elm, enet, evtree, extraTrees, fda, FH.GBML, FIR.DM, foba, FRBCS.CHI, FRBCS.W, FS.HGD, gam, gamboost, gamLoess, gamSpline, gaussprLinear, gaussprPoly, gaussprRadial, gbm_h2o, gbm, gcvEarth, GFS.FR.MOGUL, GFS.LT.RS, GFS.THRIFT, glm.nb, glm, glmboost, glmnet_h2o, glmnet, glmStepAIC, gpls, hda, hdda, hdrda, HYFIS, icr, J48, JRip, kernelpls, kknn, knn, krlsPoly, krlsRadial, lars, lars2, lasso, lda, lda2, leapBackward, leapForward, leapSeq, Linda, lm, lmStepAIC, LMT, loclda, logicBag, LogitBoost, logreg, lssvmLinear, lssvmPoly, lssvmRadial, lvq, M5, M5Rules, manb, mda, Mlda, mlp, mlpKerasDecay, mlpKerasDecayCost, mlpKerasDropout, mlpKerasDropoutCost, mlpML, mlpSGD, mlpWeightDecay, mlpWeightDecayML, monmlp, msaenet, multinom, mxnet, mxnetAdam, naive_bayes, nb, nbDiscrete, nbSearch, neuralnet, nnet, nnls, nodeHarvest, null, OneR, ordinalNet, ordinalRF, ORFlog, ORFpls, ORFridge, ORFsvm, ownn, pam, parRF, PART, partDSA, pcaNNet, pcr, pda, pda2, penalized, PenalizedLDA, plr, pls, plsRglm, polr, ppr, PRIM, protoclass, qda, QdaCov, qrf, qrnn, randomGLM, ranger, rbf, rbfDDA, Rborist, rda, regLogistic, relaxo, rf, rFerns, RFlda, rfRules, ridge, rlda, rlm, rmda, rocc, rotationForest, rotationForestCp, rpart, rpart1SE, rpart2, rpartCost, rpartScore, rqlasso, rqnc, RRF, RRFglobal, rrlda, RSimca, rvmLinear, rvmPoly, rvmRadial, SBC, sda, sdwd, simpls, SLAVE, slda, smda, snn, sparseLDA, spikeslab, spls, stepLDA, stepQDA, superpc, svmBoundrangeString, svmExpoString, svmLinear, svmLinear2, svmLinear3, svmLinearWeights, svmLinearWeights2, svmPoly, svmRadial, svmRadialCost, svmRadialSigma, svmRadialWeights, svmSpectrumString, tan, tanSearch, treebag, vbmpRadial, vglmAdjCat, vglmContRatio, vglmCumulative, widekernelpls, WM, wsrf, xgbDART, xgbLinear, xgbTree, xyf"
# você pode saber um pouco mais sobre um modelo com o comando `modelLookup:
modelLookup('lars')
## model parameter label forReg forClass probModel
## 1 lars fraction Fraction TRUE FALSE FALSE
# lars é uma variante da regressão linear
# tem um parâmetro, fraction, pode ser usado para regressão mas não para classificação, e não dá os resultados em probabilidades
random forest é um algoritimo poderoso, usado tanto em classificação como regressão
rf consiste de um grande número de árvores de decisão que funciona como um ensamble
árvores de decisão funcionam por partições sucessivas de um conjunto de dados
numa rf, cada árvore de decisão usa apenas um subconjunto dos parâmetros de entrada (as 12 magnitudes, neste caso)
isso gera um número muito grande de modelos que, juntos, têm um valor preditivo maior que os de cada modelo individual
o resultado é uma combinação dos resultados dos modelos individuais
# modelo:
modelLookup('rf')
## model parameter label forReg forClass probModel
## 1 rf mtry #Randomly Selected Predictors TRUE TRUE TRUE
# treinando o modelo- vejam a sintaxe do comando
t0 = Sys.time()
model_rf<-train(xtrain,ytrain,method='rf')
Sys.time() - t0
## Time difference of 1.127818 hours
# sumário do modelo
print(model_rf)
## Random Forest
##
## 7502 samples
## 12 predictor
##
## No pre-processing
## Resampling: Bootstrapped (25 reps)
## Summary of sample sizes: 7502, 7502, 7502, 7502, 7502, 7502, ...
## Resampling results across tuning parameters:
##
## mtry RMSE Rsquared MAE
## 2 0.05244647 0.6846652 0.03675291
## 7 0.05158889 0.6923947 0.03563120
## 12 0.05216484 0.6852680 0.03582994
##
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was mtry = 7.
plot(model_rf)
# o melhor subconjunto tem 7 magnitudes
# a predição será feita com um conjunto de árvores com 7 magnitudes escolhidas aleatoriamente entre as 12
# predição com o conjunto de teste
pred_rf = predict(model_rf, xtest)
# estatística da performance do algoritmo
# sigma equivalente da gaussiana
relz = (ytest - pred_rf)
sig_G = 0.7413*(quantile(relz,0.75,names = FALSE) - quantile(relz,0.25,names = FALSE))
sig_G
## [1] 0.03548804
#visualização do resultado
my_data = as.data.frame(cbind(predicted = pred_rf,
observed = ytest))
ggplot(my_data,aes(predicted, observed)) +
geom_point(color = "darkred", alpha = 0.5) +
geom_smooth(method=lm)+
ggtitle("RF: redshift predito x observado") +
xlab("photo-z ") +
ylab("z spec") +
theme(plot.title = element_text(color="darkgreen",size=18,hjust = 0.5),
axis.text.y = element_text(size=12),
axis.text.x = element_text(size=12,hjust=.5),
axis.title.x = element_text(size=14),
axis.title.y = element_text(size=14))
knn, ou o k-ésimo vizinho mais próximo, faz uma inferência do redshift de uma galáxia no conjunto de teste a partir do redshift das galáxias com magnitudes mais “próximas” a ela no conjunto de treinamento
o parâmetro livre importante é o k, o número de vizinhos a considerar
o valor predito pode ser, por exemplo, o redshift médio desses k vizinhos mais próximos
no exemplo a seguir, k é determinado por validação cruzada- preste atenção na estrutura da função train():
# modelo:
modelLookup('knn')
## model parameter label forReg forClass probModel
## 1 knn k #Neighbors TRUE TRUE TRUE
t0 = Sys.time()
model_knn <- train(xtrain,ytrain,method='knn',
trControl = trainControl(method = "cv", number = 5),
tuneGrid = expand.grid(k = seq(1, 41, by = 2)))
Sys.time()-t0
## Time difference of 12.85312 secs
plot(model_knn)
# predição com o conjunto de teste
pred_knn = predict(model_knn, xtest)
relz = (ytest - pred_knn)
sig_G = 0.7413*(quantile(relz,0.75,names = FALSE) - quantile(relz,0.25,names = FALSE))
sig_G
## [1] 0.04037614
#
# o desempenho foi pior que o usando rf
#
#visualização do resultado
my_data = as.data.frame(cbind(predicted = pred_knn,
observed = ytest))
ggplot(my_data,aes(predicted, observed)) +
geom_point(color = "darkred", alpha = 0.5) +
geom_smooth(method=lm)+
ggtitle("knn: z_spec x z_phot") +
xlab("z_phot ") +
ylab("z_spec") +
theme(plot.title = element_text(color="darkgreen",size=18,hjust = 0.5),
axis.text.y = element_text(size=12),
axis.text.x = element_text(size=12,hjust=.5),
axis.title.x = element_text(size=14),
axis.title.y = element_text(size=14))
# comparação dos dois resultados:
my_data = as.data.frame(cbind(pred1 = pred_knn,
pred2 = pred_rf))
ggplot(my_data,aes(pred1,pred2)) +
geom_point(color = "darkred", alpha = 0.5) +
geom_smooth(method=lm)+
ggtitle("knn: predições knn & rf ") +
xlab("z_knn ") +
ylab("z_rf") +
theme(plot.title = element_text(color="darkgreen",size=18,hjust = 0.5),
axis.text.y = element_text(size=12),
axis.text.x = element_text(size=12,hjust=.5),
axis.title.x = element_text(size=14),
axis.title.y = element_text(size=14))
Vamos agora considerar problemas de classificação, considerando a classificação de um conjunto de dados fotométricos em estrelas ou galáxias.
Para ilustração, vamos usar 1000 objetos classificados como estrelas (classe 0) ou galáxias (classe 1):
tabela = as.data.frame(read.table(file="class_estr_gal.dat", header=TRUE))
# alguns detalhes dos dados
dim(tabela)
## [1] 1000 13
head(tabela)
## uJAVA_petro F378_petro F395_petro F410_petro F430_petro g_petro F515_petro
## 1 23.23 24.09 21.06 22.11 21.36 21.76 20.69
## 2 21.92 20.95 21.30 20.00 21.24 20.73 19.86
## 3 20.44 20.78 20.70 20.92 20.05 20.15 19.55
## 4 21.55 22.96 20.22 19.63 19.83 19.87 19.39
## 5 22.32 21.41 26.56 20.28 20.52 20.28 20.02
## 6 21.26 20.43 20.60 19.44 19.16 18.52 18.34
## r_petro F660_petro i_petro F861_petro z_petro classe
## 1 20.87 20.53 20.14 19.83 19.71 0
## 2 20.01 20.18 20.36 20.53 20.00 0
## 3 19.49 19.36 19.08 19.00 19.07 0
## 4 19.02 18.78 18.55 18.02 18.17 0
## 5 19.60 19.53 19.34 19.47 18.99 1
## 6 17.65 17.56 17.34 17.30 17.19 0
# número de objetos em cada classe:
length(tabela$classe[tabela$classe == 0])
## [1] 493
length(tabela$classe[tabela$classe == 1])
## [1] 507
Vamos examinar os dados em um diagrama cor-magnitude:
par(mfrow = c(1,3))
ca = tabela$g_petro-tabela$r_petro
cb = tabela$r_petro-tabela$i_petro
smoothScatter(ca,cb,nrpoints=0,add=FALSE,xlab="g-r",ylab="r-i",xlim=c(-1,3),ylim=c(-1,3),main='todos')
ca = tabela$g_petro[tabela$classe == 0]-tabela$r_petro[tabela$classe == 0]
cb = tabela$r_petro[tabela$classe == 0]-tabela$i_petro[tabela$classe == 0]
smoothScatter(ca,cb,nrpoints=0,add=FALSE,xlab="g-r",ylab="r-i",xlim=c(-1,3),ylim=c(-1,3),main='classe = 0')
ca = tabela$g_petro[tabela$classe == 1]-tabela$r_petro[tabela$classe == 1]
cb = tabela$r_petro[tabela$classe == 1]-tabela$i_petro[tabela$classe == 1]
smoothScatter(ca,cb,nrpoints=0,add=FALSE,xlab="g-r",ylab="r-i",xlim=c(-1,3),ylim=c(-1,3),main='classe = 1')
Vamos definir os ‘features’ (as magnitudes) e pré-processá-los:
# definindo os features
mag = tabela[,1:12]
# pré-processamento
maxs <- apply(mag, 2, max)
mins <- apply(mag, 2, min)
normalizacao <- as.data.frame(scale(mag, center = mins, scale = maxs - mins))
Vamos definir os conjuntos de treinamento e teste:
# conjuntos de treino e teste:
ntreino = round(0.75*nrow(tabela))
nteste = nrow(tabela)-ntreino
c(ntreino,nteste)
## [1] 750 250
# número das linhas dos objetos selecionados para o conjunto de treinamento
set.seed(123)
indice = createDataPartition(tabela$classe, p=0.75, list=FALSE)
xtrain = normalizacao[indice,]
# ATENÇÃO: para classificação a variável dependente deve ser tipo 'factor'
ytrain = as.factor(tabela[indice,13])
xtest = normalizacao[-indice,]
ytest = as.factor(tabela[-indice,13])
SVM, Support Vector Machines, é um algoritmo bastante poderoso e bastante usado em problemas de classificação.
Para ver como o algoritmo funciona, considere a figura abaixo: temos duas classes (vermelho e azul) que queremos separar. SVM ajusta um hiper-plano (uma reta nesse caso) que maximiza a margem entre as classes:
https://www.datacamp.com/community/tutorials/support-vector-machines-r
Assim, objetos de um lado da linha do melhor hiper-plano são classificados em uma classe e os do outro lado da linha na outra classe.
Neste exemplo, as classes são linearmente separáveis, isto é, uma linha (ou hiper-plano), provê uma separação adequada entre as classes.
Considere agora o espaço de dados ilustrado no lado esquerdo da figura abaixo: ele não é linearmente separável- não há como traçar uma reta neste espaço que separe os pontos azuis dos vermelhos.
Mas os pontos azuis e vermelhos estão claramente segregados! Um jeito de torná-los linearmente separáveis é introduzir uma terceira dimensão: \[z = x^2 + y^2.\]
No caso do nosso exemplo isso torna o espaço de dados linearmente separável, como ilustrado no lado direito da figura:
https://www.datacamp.com/community/tutorials/support-vector-machines-r
mapeando este hiper-plano de volta no espaço de dados original obtemos um círculo:
https://www.datacamp.com/community/tutorials/support-vector-machines-r
Mas como SVM faz isso? Pode-se mostrar que a equação que precisa ser minimizada (num espaço de n dimensões) não depende da posição dos pontos mas apenas de seu produto interno (\(\mathbf{x}_i . \mathbf{x}_j\)), que é a distância entre dois pontos no espaço de dados.
Assim, se quisermos transformar os dados para um espaço de maior dimensão, não precisamos calcular a transformação exata dos dados e só precisamos do produto interno dos dados nesse espaço.
kernels permitem isso: por exemplo, um kernel gaussiano mapeia os dados num espaço de dimensão infinita!
Isso é chamado de kernel trick: aumenta-se o espaço de dados para permitir uma separação das classes.
Aqui vamos usar um kernel chamado Radial Basis Function (RBF) \[ K(\mathbf{x}_i,\mathbf{x}_j) = \exp \Bigg[ -\gamma \sum_k^{dim} (x_{ik}-x_{jk})^2 \Bigg]\]
No caret temos o algoritmo svmRadial:
modelLookup('svmRadial')
## model parameter label forReg forClass probModel
## 1 svmRadial sigma Sigma TRUE TRUE TRUE
## 2 svmRadial C Cost TRUE TRUE TRUE
Este modelo tem dois parâmetros ajustáveis: sigma e C
- sigma controla a “suavidade” do limite de decisão do modelo
- C penaliza o modelo por classificações erradas: quanto maior C, menor a probabilidade de um erro.
O pacote otimiza esses parâmetros maximizando a acurácia, no caso por validação cruzada:
t0 = Sys.time()
model_svm <- train(xtrain,ytrain,method='svmRadial',
tuneLength = 10,
trControl = trainControl(method = "cv"))
Sys.time()-t0
## Time difference of 4.896056 secs
# vamos examinar o modelo ajustado:
print(model_svm)
## Support Vector Machines with Radial Basis Function Kernel
##
## 750 samples
## 12 predictor
## 2 classes: '0', '1'
##
## No pre-processing
## Resampling: Cross-Validated (10 fold)
## Summary of sample sizes: 674, 674, 675, 674, 675, 675, ...
## Resampling results across tuning parameters:
##
## C Accuracy Kappa
## 0.25 0.7265306 0.4518432
## 0.50 0.7371811 0.4730668
## 1.00 0.7532717 0.5049994
## 2.00 0.7518853 0.5020645
## 4.00 0.7665723 0.5319152
## 8.00 0.7719597 0.5429883
## 16.00 0.7626610 0.5244985
## 32.00 0.7586064 0.5167716
## 64.00 0.7440977 0.4879787
## 128.00 0.7294452 0.4591230
##
## Tuning parameter 'sigma' was held constant at a value of 0.2268265
## Accuracy was used to select the optimal model using the largest value.
## The final values used for the model were sigma = 0.2268265 and C = 8.
predição:
# predição com o conjunto de teste
pred = predict(model_svm, xtest)
head(pred)
## [1] 0 1 0 1 1 0
## Levels: 0 1
postResample(pred = pred ,obs= ytest)
## Accuracy Kappa
## 0.7080000 0.4164482
Em problemas de classificação duas estatísticas são muito usadas:
- acurácia: fração das classificações corretas
- kappa: similar à acurácia mas normalizada na classificação aleatória dos dados; útil quando há “desbalanço” (imbalance) entre as classes (ex.: se a razão entre as classes 0 e 1 é 70:30, tem-se 70% de acurácia prevendo qualquer objeto como classe 0)
Em classificação uma coisa importante é a “matriz de confusão”, onde se compara, no conjunto de teste, as classificações preditas com as “verdadeiras”:
# matriz de confusão
# TPR (True Positive Rate) = TP/(TP+FN)
# FPR (False Positive Rate) = FP/(TN+FP)
print(confusionMatrix(data = as.factor(pred),reference = as.factor(ytest),positive = '1'))
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 83 30
## 1 43 94
##
## Accuracy : 0.708
## 95% CI : (0.6474, 0.7636)
## No Information Rate : 0.504
## P-Value [Acc > NIR] : 4.419e-11
##
## Kappa : 0.4164
##
## Mcnemar's Test P-Value : 0.1602
##
## Sensitivity : 0.7581
## Specificity : 0.6587
## Pos Pred Value : 0.6861
## Neg Pred Value : 0.7345
## Prevalence : 0.4960
## Detection Rate : 0.3760
## Detection Prevalence : 0.5480
## Balanced Accuracy : 0.7084
##
## 'Positive' Class : 1
##
tab = table(pred, ytest)
# essa função calcula a acurácia
accuracy <- function(x){sum(diag(x)/(sum(rowSums(x)))) * 100}
accuracy(tab)
## [1] 70.8
Vamos usar o algoritmo ‘tree’ do pacote rpart.
Nesse caso, é preciso passar uma fórmula.
Vamos definir ‘f’ como a fórmula a ser passada:
n <- c(names(mag))
f = as.formula(paste("ytrain ~", paste(n[!n %in% "medv"], collapse = " + ")))
f
## ytrain ~ uJAVA_petro + F378_petro + F395_petro + F410_petro +
## F430_petro + g_petro + F515_petro + r_petro + F660_petro +
## i_petro + F861_petro + z_petro
# isto é, vamos fazer a classificação usando todas as magnitudes
t0 = Sys.time()
tree = rpart(f, data = xtrain, method = "class")
Sys.time()-t0
## Time difference of 0.02491951 secs
visualização da árvore:
rpart.plot(tree)
Nessa figura fica claro como a classificação é feita; note que apenas 6 bandas foram usadas!
predição e acurácia:
pr = predict(tree, xtest)
prl = rep(0,length(ytest))
prl[pr[,2] > pr[,1]] = 1
pr = prl
tab = table(pr, ytest)
length(pr[pr == 1 & ytest == 0])/length(pr)
## [1] 0.148
length(pr[pr == 0 & ytest == 1])/length(pr)
## [1] 0.156
print(confusionMatrix(data = as.factor(pr),reference = as.factor(ytest),positive = '1'))
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 89 39
## 1 37 85
##
## Accuracy : 0.696
## 95% CI : (0.6349, 0.7524)
## No Information Rate : 0.504
## P-Value [Acc > NIR] : 5.673e-10
##
## Kappa : 0.3919
##
## Mcnemar's Test P-Value : 0.9087
##
## Sensitivity : 0.6855
## Specificity : 0.7063
## Pos Pred Value : 0.6967
## Neg Pred Value : 0.6953
## Prevalence : 0.4960
## Detection Rate : 0.3400
## Detection Prevalence : 0.4880
## Balanced Accuracy : 0.6959
##
## 'Positive' Class : 1
##
accuracy(tab)
## [1] 69.6
tempo de processamento do script:
Sys.time() - t00
## Time difference of 1.136173 hours
Verifique o tempo de processamento dos vários algoritmos para ter uma ideia do custo computacional de cada um.
https://www.machinelearningplus.com/machine-learning/caret-package/
Mendes de Oliveira et al. (2019), arXiv:1907.01567