Revision | 9dbc6c80a640c7f2c27cde80898f2942940db6e6 (tree) |
---|---|
Time | 2014-12-15 04:41:39 |
Author | Lorenzo Isella <lorenzo.isella@gmai...> |
Commiter | Lorenzo Isella |
A code to use caret + randomForest. In particular, it allows to deal with unbalanced classes.
@@ -0,0 +1,249 @@ | ||
1 | +rm(list=ls()) | |
2 | + | |
3 | +#library(ada) | |
4 | +## library(ggplot2) | |
5 | +#library(glmnet) | |
6 | +#library(reshape2) | |
7 | +library(randomForest) | |
8 | +## require(gridExtra) | |
9 | +## library(corrplot) | |
10 | +## library(scales) | |
11 | +## library(digest) | |
12 | +library(caret) | |
13 | +## library(gbm) | |
14 | +#library(doMC) | |
15 | +## library(nnet) | |
16 | +## library(gbm) | |
17 | + | |
18 | + | |
19 | + | |
20 | +confusion.glm <- function(data, model) { | |
21 | +prediction <- ifelse(predict(model, data, type='response') > 0.5, TRUE, FALSE) | |
22 | +confusion <- table(prediction, as.logical(model$y)) | |
23 | +confusion <- cbind(confusion, c(1 - confusion[1,1]/(confusion[1,1]+confusion[2,1]), 1 - confusion[2,2]/(confusion[2,2]+confusion[1,2]))) | |
24 | +confusion <- as.data.frame(confusion) | |
25 | +names(confusion) <- c('FALSE', 'TRUE', 'class.error') | |
26 | +confusion | |
27 | +} | |
28 | + | |
29 | + | |
30 | + | |
31 | + | |
32 | + | |
33 | + | |
34 | + | |
35 | +########################################################################## | |
36 | +########################################################################### | |
37 | + | |
38 | + | |
39 | +analyze_bin <- 1 | |
40 | + | |
41 | +analyze_rev <- 1 | |
42 | + | |
43 | +if (analyze_bin==1){ | |
44 | + | |
45 | +set.seed(1234) | |
46 | + | |
47 | + | |
48 | + | |
49 | +mydata <- readRDS("train_mf_bin.RDS") | |
50 | + | |
51 | + | |
52 | +## see http://bit.ly/1uEEPze | |
53 | + | |
54 | + | |
55 | +nmin <- sum(mydata$Sale_MF == 1) | |
56 | + | |
57 | + ctrl <- trainControl(method = "cv", | |
58 | + classProbs = TRUE,summaryFunction = twoClassSummary) | |
59 | + | |
60 | + | |
61 | +set.seed(1234) | |
62 | + | |
63 | + | |
64 | +print("RF for mf") | |
65 | + | |
66 | +rf_mf <- train(Sale_MF ~ ., data =mydata, | |
67 | + method = "rf", | |
68 | + ntree = 1500, | |
69 | + tuneLength = 10, | |
70 | + metric = "ROC", | |
71 | + trControl = ctrl, | |
72 | + ## Tell randomForest to sample by strata. Here, | |
73 | + ## that means within each class | |
74 | + strata = mydata$Sale_MF, | |
75 | + do.trace=500, | |
76 | + ## Now specify that the number of samples selected | |
77 | + ## within each class should be the same | |
78 | + sampsize = rep(nmin, 2)) | |
79 | + | |
80 | +saveRDS(rf_mf, "rf_mf_bin.RDS") | |
81 | + | |
82 | +################################################ | |
83 | + | |
84 | + | |
85 | +mydata <- readRDS("train_cc_bin.RDS") | |
86 | + | |
87 | +nmin <- sum(mydata$Sale_CC == 1) | |
88 | + | |
89 | +print("RF for cc") | |
90 | + | |
91 | + | |
92 | +rf_cc <- train(Sale_CC ~ ., data =mydata, | |
93 | + method = "rf", | |
94 | + ntree = 1500, | |
95 | + tuneLength = 10, | |
96 | + metric = "ROC", | |
97 | + trControl = ctrl, | |
98 | + ## Tell randomForest to sample by strata. Here, | |
99 | + ## that means within each class | |
100 | + strata = mydata$Sale_CC, | |
101 | + do.trace=500, | |
102 | + ## Now specify that the number of samples selected | |
103 | + ## within each class should be the same | |
104 | + sampsize = rep(nmin, 2)) | |
105 | + | |
106 | +saveRDS(rf_mf, "rf_cc_bin.RDS") | |
107 | + | |
108 | + | |
109 | +################################################ | |
110 | + | |
111 | + | |
112 | +mydata <- readRDS("train_cl_bin.RDS") | |
113 | + | |
114 | +nmin <- sum(mydata$Sale_CL == 1) | |
115 | + | |
116 | +print("RF for cl") | |
117 | + | |
118 | + | |
119 | +rf_cl <- train(Sale_CL ~ ., data =mydata, | |
120 | + method = "rf", | |
121 | + ntree = 1500, | |
122 | + tuneLength = 10, | |
123 | + metric = "ROC", | |
124 | + trControl = ctrl, | |
125 | + ## Tell randomForest to sample by strata. Here, | |
126 | + ## that means within each class | |
127 | + strata = mydata$Sale_CL, | |
128 | + do.trace=500, | |
129 | + ## Now specify that the number of samples selected | |
130 | + ## within each class should be the same | |
131 | + sampsize = rep(nmin, 2)) | |
132 | + | |
133 | +saveRDS(rf_cl, "rf_cl_bin.RDS") | |
134 | + | |
135 | + | |
136 | + | |
137 | + | |
138 | + | |
139 | +####################################################################### | |
140 | +####################################################################### | |
141 | +####################################################################### | |
142 | +####################################################################### | |
143 | + | |
144 | + | |
145 | + | |
146 | +} | |
147 | + | |
148 | +if (analyze_rev==1){ | |
149 | + | |
150 | +set.seed(1234) | |
151 | + | |
152 | + | |
153 | + ctrl <- trainControl(method = "cv") | |
154 | + | |
155 | + | |
156 | +mydata <- readRDS("train_mf_amount.RDS") | |
157 | + | |
158 | + | |
159 | + | |
160 | +sel <- which(mydata$Revenue_MF > 0) | |
161 | + | |
162 | +mydata <- mydata[sel, ] | |
163 | + | |
164 | + | |
165 | + | |
166 | + | |
167 | + | |
168 | + | |
169 | +print("RF for mf") | |
170 | + | |
171 | +rf_mf <- train(Revenue_MF ~ ., data =mydata, | |
172 | + method = "rf", | |
173 | + ntree = 1500, | |
174 | + tuneLength = 10, | |
175 | + ## metric = "ROC", | |
176 | + trControl = ctrl, | |
177 | + do.trace=500) | |
178 | + | |
179 | + | |
180 | +saveRDS(rf_mf, "rf_mf_amount.RDS") | |
181 | + | |
182 | + | |
183 | + | |
184 | +##################################################################### | |
185 | + | |
186 | +mydata <- readRDS("train_cc_amount.RDS") | |
187 | + | |
188 | + | |
189 | + | |
190 | +sel <- which(mydata$Revenue_CC > 0) | |
191 | + | |
192 | +mydata <- mydata[sel, ] | |
193 | + | |
194 | + | |
195 | + | |
196 | + | |
197 | +print("RF for cc") | |
198 | + | |
199 | +rf_cc <- train(Revenue_CC ~ ., data =mydata, | |
200 | + method = "rf", | |
201 | + ntree = 1500, | |
202 | + tuneLength = 10, | |
203 | + ## metric = "ROC", | |
204 | + trControl = ctrl, | |
205 | + do.trace=500) | |
206 | + | |
207 | + | |
208 | +saveRDS(rf_cc, "rf_cc_amount.RDS") | |
209 | + | |
210 | + | |
211 | + | |
212 | +################################################################## | |
213 | + | |
214 | + | |
215 | +mydata <- readRDS("train_cl_amount.RDS") | |
216 | + | |
217 | + | |
218 | + | |
219 | +sel <- which(mydata$Revenue_CL > 0) | |
220 | + | |
221 | +mydata <- mydata[sel, ] | |
222 | + | |
223 | + | |
224 | + | |
225 | + | |
226 | +print("RF for cl") | |
227 | + | |
228 | +rf_cl <- train(Revenue_CL ~ ., data =mydata, | |
229 | + method = "rf", | |
230 | + ntree = 1500, | |
231 | + tuneLength = 10, | |
232 | + ## metric = "ROC", | |
233 | + trControl = ctrl, | |
234 | + do.trace=500) | |
235 | + | |
236 | + | |
237 | +saveRDS(rf_cl, "rf_cl_amount.RDS") | |
238 | + | |
239 | + | |
240 | + | |
241 | + | |
242 | + | |
243 | + | |
244 | + | |
245 | + | |
246 | + | |
247 | +} | |
248 | + | |
249 | +print("So far so good") |