-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmatch_loss.py
43 lines (31 loc) · 1.22 KB
/
match_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Bruno Iochins Grisci
# June 28th, 2022
import os
import sys
import numpy as np
import pandas as pd
def _stacked_output_path(silh_file):
    """Return the path of the merged CSV derived from *silh_file*.

    Only a trailing ``.csv`` is swapped for ``Xloss.csv``; any other
    ``.csv`` occurrence in the path (e.g. in a directory name) is left
    alone.  When *silh_file* does not end in ``.csv`` the suffix is appended
    instead, so the result can never be the input path itself.
    """
    stem = silh_file[:-4] if silh_file.endswith('.csv') else silh_file
    return stem + 'Xloss.csv'


def main(argv=None):
    """Merge a loss CSV with a silhouette CSV and print their correlations.

    The loss table's ``epoch`` column is incremented by one, optionally
    filtered to the epochs that are multiples of ``freq``, and then placed
    side by side (row by row, by position) with the silhouette table.  The
    merged table is written next to the silhouette file with the suffix
    ``Xloss.csv`` and the Pearson correlation of selected column pairs is
    printed.

    Args:
        argv: Optional sequence ``[loss_file, silh_file, freq]``.  When
            omitted, ``sys.argv[1:]`` is used, which is how the script is
            normally invoked.  Extra trailing arguments are ignored.

    Raises:
        SystemExit: If fewer than three arguments are supplied.
        ValueError: If ``freq`` is not an integer literal.
        KeyError: If an expected column is missing from the input files.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 3:
        sys.exit('usage: match_loss.py <loss.csv> <silhouette.csv> <freq>')
    loss_file, silh_file = args[0], args[1]
    freq = int(args[2])

    df_loss = pd.read_csv(loss_file, delimiter=',', header=0, index_col=None)
    # The first column of the silhouette file is consumed as the index and
    # discarded below when the index is reset.
    df_silh = pd.read_csv(silh_file, delimiter=',', header=0, index_col=0)

    # NOTE(review): presumably the loss file stores zero-based epochs and
    # this shifts them to one-based -- confirm against the file's producer.
    df_loss['epoch'] += 1
    print(df_loss)
    print(df_silh)

    if freq > 1:
        # Keep only the epochs that are multiples of freq.
        df_loss = df_loss[df_loss['epoch'] % freq == 0]

    # Give both frames a fresh 0..n-1 index so the concat below pairs rows
    # by position.  The non-inplace form avoids mutating a filtered view.
    df_loss = df_loss.reset_index(drop=True)
    df_silh = df_silh.reset_index(drop=True)

    if len(df_loss) != len(df_silh):
        # concat(axis=1) pads the shorter frame with NaN instead of failing,
        # so make the mismatch visible without changing the output.
        print('Warning: loss table has', len(df_loss),
              'rows but silhouette table has', len(df_silh),
              'rows; unmatched rows are padded with NaN.', file=sys.stderr)

    df_stack = pd.concat([df_loss, df_silh], axis=1)
    print(df_stack)
    df_stack.to_csv(_stacked_output_path(silh_file))

    # https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.corr.html
    #met = ['loss', 'accuracy', 'KL divergence']
    met = ['loss', 'KL divergence']
    targets = ["Embedding silhouette", 'Weighted silhouette', 'KL divergence']
    for col1 in met:
        for col2 in targets:
            corr = df_stack[col1].corr(df_stack[col2], method='pearson')
            print ("Correlation between ", col1, " and ", col2, "is: ", round(corr, 4))
# Run the merge only when executed as a script, not when imported.
if __name__ == "__main__":
    main()