-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluation.rmd
151 lines (132 loc) · 6.03 KB
/
evaluation.rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
---
title: "Evaluation"
geometry: a4paper, margin=2cm
output:
  pdf_document: default
  html_document:
    df_print: paged
params:
  input_file: /home/user/evaluation/hps/deap/hps-deap.csv
  output_dir: /home/user/evaluation/hps/deap/
  prefix: hps_deap_
fig_height: 8
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = FALSE)
library(ggplot2)
library(ggthemes)
# library(viridis)
library(tibble)
# Build the ggplot2 theme used for every figure in this report.
#
# Starts from ggthemes' theme_foundation() and layers a fixed set of element
# overrides on top. Colours are looked up by name ("White", "Dark Gray",
# "Medium Gray") in the "fivethirtyeight" entry of ggthemes::ggthemes_data.
# NOTE(review): this looks like a local variant of
# ggthemes::theme_fivethirtyeight() -- confirm against the ggthemes source.
#
# Args:
#   base_size:   base font size forwarded to theme_foundation().
#   base_family: base font family forwarded to theme_foundation().
#
# Returns:
#   The foundation theme with the overrides added.
my_theme_fivethirtyeight <- function(base_size = 12, base_family = "sans") {
  # deframe() turns the palette table into a vector that can be indexed by
  # colour name.
  palette <- deframe(ggthemes::ggthemes_data[["fivethirtyeight"]])

  foundation <- theme_foundation(
    base_size = base_size,
    base_family = base_family
  )

  overrides <- theme(
    # Root elements.
    line = element_line(colour = "black"),
    rect = element_rect(fill = palette["White"], linetype = 0, colour = NA),
    text = element_text(colour = palette["Dark Gray"]),
    # Axes: keep titles and tick labels, drop tick marks and axis lines.
    axis.title = element_text(),
    axis.text = element_text(),
    axis.ticks = element_blank(),
    axis.line = element_blank(),
    # Legend: untitled, below the plot, keys laid out horizontally.
    legend.background = element_rect(),
    legend.position = "bottom",
    legend.direction = "horizontal",
    legend.box = "vertical",
    legend.title = element_blank(),
    # Plot and panel backgrounds.
    plot.background = element_rect(colour = "white"),
    panel.background = element_rect(colour = "white"),
    legend.spacing.y = unit(0.5, "mm"),
    # Grid: major lines only.
    panel.grid = element_line(colour = NULL),
    panel.grid.major = element_line(colour = palette["Medium Gray"]),
    panel.grid.minor = element_blank(),
    # Title, outer margin and facet strips.
    plot.title = element_text(hjust = 0, size = rel(1.5), face = "bold"),
    plot.margin = unit(rep(1, 4), "lines"),
    strip.background = element_rect()
  )

  foundation + overrides
}
# Keep a handle on the chunk hook that is currently installed so the
# replacement below can delegate the actual rendering to it.
def.chunk.hook <- knitr::knit_hooks$get("chunk")

# Replacement chunk hook: render the chunk with the previous hook, then, when
# the chunk option `size` names something other than "normalsize", wrap the
# rendered text in that LaTeX size command and switch back to \normalsize
# afterwards.
#
# Args:
#   x:       chunk output handed over by knitr.
#   options: chunk options; only `options$size` is read here.
#
# Returns:
#   The rendered chunk, wrapped in size commands when required.
knitr::knit_hooks$set(chunk = function(x, options) {
  x <- def.chunk.hook(x, options)
  size <- options$size
  # A plain `if` is used because the condition is a single value; ifelse()
  # would return a zero-length result (dropping the chunk) if `size` were
  # missing from the options.
  if (is.null(size) || identical(size, "normalsize")) {
    return(x)
  }
  paste0("\n \\", size, "\n\n", x, "\n\n \\normalsize")
})
# Make the customised theme the default for every plot drawn below.
theme_set(my_theme_fivethirtyeight(base_size = 14))
# theme_update(axis.title = element_text())

# Metric names that do_plot_opp() is meant to be mapped over.
metrics_opportunity <- c("f1-weighted", "f1-micro", "f1-macro")
```
```{r, results="asis", echo=FALSE, message=FALSE}
# Plot and save the evaluation curves for a single metric.
#
# Reads params$input_file with read.csv(), keeps the rows whose `Metric`
# column equals `metric`, and draws three figures:
#   1. "ML_Both" label channel only, all epochs, loess smoother per model;
#   2. every label channel, one facet per channel;
#   3. "ML_Both" rows with Epoch >= 40, linear fit per model.
# Each figure is printed into the document and written to params$output_dir
# as <prefix>{ml_,all_,last_}<metric>.png.
#
# Args:
#   metric: value of the `Metric` column to select; also used as the y-axis
#           label and as part of the output file names.
#
# Called for its side effects (printed plots and PNG files).
do_plot_opp <- function(metric) {
  data <- read.csv(params$input_file)
  data <- data[data$Metric == metric, ]
  ml_data <- data[data$Label.channel == "ML_Both", ]

  # Figure 1: ML_Both only. level = 0 asks stat_smooth for a 0% confidence
  # interval, so effectively only the fitted line is drawn.
  p <- ggplot(
    data = ml_data,
    aes(x = Epoch, y = Metric.value, col = Model.name, shape = Model.name)
  ) +
    geom_point() +
    stat_smooth(method = "loess", level = 0, size = 0.7) +
    labs(y = metric, col = "Model", shape = "Model")
  print(p)

  # Figure 2: one panel per label channel. The facet is given as a formula
  # naming the column; passing the column's values (data$Label.channel) is
  # not a valid faceting specification.
  p2 <- ggplot(
    data = data,
    aes(x = Epoch, y = Metric.value, col = Model.name, shape = Model.name)
  ) +
    facet_wrap(~Label.channel) +
    geom_point() +
    stat_smooth(method = "loess", level = 0, size = 0.7) +
    labs(y = metric, col = "Model", shape = "Model")
  print(p2)

  # Figure 3: zoom on epochs 40 and later with a straight-line fit.
  last_data <- ml_data[ml_data$Epoch >= 40, ]
  p3 <- ggplot(
    last_data,
    aes(x = Epoch, y = Metric.value, col = Model.name, shape = Model.name)
  ) +
    geom_point(size = 1.7) +
    stat_smooth(method = "lm", size = 0.7, level = 0) +
    scale_x_continuous(breaks = seq(40, 50)) +
    labs(y = metric, col = "Model", shape = "Model", fill = "Model")
  print(p3)

  ggsave(
    filename = paste0(params$prefix, "ml_", metric, ".png"),
    plot = p, path = params$output_dir, dpi = "print"
  )
  ggsave(
    filename = paste0(params$prefix, "all_", metric, ".png"),
    plot = p2, path = params$output_dir, dpi = "print"
  )
  ggsave(
    filename = paste0(params$prefix, "last_", metric, ".png"),
    plot = p3, path = params$output_dir, dpi = "print"
  )
}
# Plot and save the accuracy curves for the "Valence" and "Arousal" label
# channels.
#
# Reads params$input_file as a semicolon-separated table with a header row
# and draws four figures:
#   1. Valence rows only, loess smoother per model;
#   2. Arousal rows only, loess smoother per model;
#   3. all rows, label channel encoded as point shape and line type;
#   4. a zoomed view with a linear fit per model and label channel.
# Each figure is printed into the document and written to params$output_dir
# as <prefix>{v_acc,a_acc,both_acc,last_all_va}.png.
#
# Args:
#   rev:   FALSE selects the zoom window by `Epoch` (Epoch >= max - count);
#          TRUE selects it by `Rev.Epoch` (Rev.Epoch <= count).
#   max:   last epoch of the zoom window when rev is FALSE.
#   count: width of the zoom window in epochs.
#
# Called for its side effects (printed plots and PNG files).
do_plot_deap <- function(rev = FALSE, max = 200, count = 25) {
  data <- read.table(params$input_file, header = TRUE, sep = ";")
  v_data <- data[data$Label.channel == "Valence", ]
  a_data <- data[data$Label.channel == "Arousal", ]

  # Single-channel plot; `t` is the channel name shown in the y-axis label.
  plot_pv <- function(d, t) {
    ggplot(
      data = d,
      aes(x = Epoch, y = Metric.value, col = Model.name, shape = Model.name)
    ) +
      geom_point() +
      stat_smooth(method = "loess", level = 0, size = 0.7) +
      labs(y = paste(t, "Accuracy"), col = "Model", shape = "Model")
  }
  pv <- plot_pv(v_data, "Valence")
  pa <- plot_pv(a_data, "Arousal")

  # Combined plot: covers both channels, so the label carries no channel
  # prefix and does not depend on any variable outside this function.
  pboth <- ggplot(
    data = data,
    aes(
      x = Epoch, y = Metric.value, col = Model.name,
      shape = Label.channel, linetype = Label.channel
    )
  ) +
    geom_point() +
    scale_shape_manual(values = c(1, 4)) +
    scale_linetype_manual(values = c(1, 3)) +
    stat_smooth(method = "loess", level = 0, size = 0.7) +
    labs(
      y = "Accuracy", col = "Model",
      shape = "Label channel", linetype = "Label channel"
    )

  # cat() writes the blank lines themselves; print() would emit the quoted
  # string `[1] "\n\n"` into the report.
  print(pv)
  cat("\n\n")
  print(pa)
  cat("\n\n")
  print(pboth)

  ggsave(
    filename = paste0(params$prefix, "v_acc", ".png"),
    plot = pv, path = params$output_dir, dpi = "print"
  )
  ggsave(
    filename = paste0(params$prefix, "a_acc", ".png"),
    plot = pa, path = params$output_dir, dpi = "print"
  )
  ggsave(
    filename = paste0(params$prefix, "both_acc", ".png"),
    plot = pboth, path = params$output_dir, dpi = "print"
  )

  # Select the rows and axis breaks for the zoomed plot.
  if (!rev) {
    first_ep <- max - count
    last_ep <- max
    last_data <- data[data$Epoch >= first_ep, ]
    breaks <- seq(first_ep, last_ep, 5)
  } else {
    # Honour `count` here as well (the default of 25 gives the same window
    # as before).
    # NOTE(review): the x axis below is still `Epoch` while these breaks are
    # in `Rev.Epoch` units -- confirm which axis is intended for rev = TRUE.
    last_data <- data[data$Rev.Epoch <= count, ]
    breaks <- seq(1, count, 5)
  }

  p3 <- ggplot(
    last_data,
    aes(
      x = Epoch, y = Metric.value, col = Model.name,
      shape = Label.channel, linetype = Label.channel
    )
  ) +
    geom_point(size = 1.7) +
    scale_shape_manual(values = c(1, 4)) +
    scale_linetype_manual(values = c(1, 3)) +
    stat_smooth(method = "lm", size = 0.7, level = 0) +
    scale_x_continuous(breaks = breaks) +
    labs(
      y = "Accuracy", col = "Model",
      shape = "Label channel", linetype = "Label channel"
    )
  print(p3)

  ggsave(
    filename = paste0(params$prefix, "last_all_va", ".png"),
    plot = p3, path = params$output_dir, dpi = "print"
  )
}
```
```{r}
# Disabled: run do_plot_opp() once per entry of metrics_opportunity.
# n <- lapply(metrics_opportunity, do_plot_opp)

# Valence/Arousal plots; the zoomed figure selects rows by Epoch and spans
# the last 25 epochs before the default `max`.
do_plot_deap(rev = FALSE, count = 25)
```