-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtensor.h
56 lines (40 loc) · 1.09 KB
/
tensor.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#if !defined(TENSOR_H_)
#define TENSOR_H_

// Storage-layout codes. These are the values passed as the `store_type`
// argument of the Tensor constructors and Tensor::load, and kept in the
// Tensor::store_type member.
#define TENSOR_STORE_TYPE_NULL               0
#define TENSOR_STORE_TYPE_DENSE              1
#define TENSOR_STORE_TYPE_DENSE_DISTRIBUTED  2
#define TENSOR_STORE_TYPE_SPARSE             3
#define TENSOR_STORE_TYPE_LOW_RANK           4

// Provides Matlab_wrapper, which Tensor::load_view and Tensor::whiten take.
#include "matlab_wrapper.h"
// Tensor with several alternative storage layouts, selected at run time by
// `store_type` (one of the TENSOR_STORE_TYPE_* codes).
//
// NOTE(review): the per-layout groupings of the data members below are
// inferred from member names and the TENSOR_STORE_TYPE_* codes; the
// authoritative semantics live in the implementation file -- confirm there.
class Tensor {
public:
    // Presumably the element buffer for the dense layouts -- TODO confirm.
    double* A;
    // Size parameter of the tensor (whiten() treats W as "dim x rank").
    int dim;
    // Active storage layout; expected to hold a TENSOR_STORE_TYPE_* code.
    int store_type;

    // Sparse layout, presumably coordinate form: `nnz_count` stored entries in
    // `values`, with one index array per mode in `idx` (three modes).
    double* values;
    int* idx[3];
    int nnz_count;
    // Presumably the rate last applied by sparsify() -- confirm.
    double rate;

    // Low-rank layout: presumably `rank` weights in `Lambda` and the matching
    // factor vectors in `U`; buffer layout is not visible in this header.
    double* Lambda;
    double* U;
    int rank;

    Tensor();
    Tensor(int dim, int store_type);
    // Variant taking a rank, presumably for TENSOR_STORE_TYPE_LOW_RANK.
    Tensor(int dim, int rank, int store_type);
    ~Tensor();

    // Symmetry check. The meaning of the returned int is defined in the
    // implementation -- confirm before treating it as a boolean.
    int symmetric_check();

    // Read/write the tensor; `filename` is presumably a file path -- confirm.
    void load(char* filename, int store_type);
    void save(char* filename);
    // NOTE(review): the meaning of the int argument is not visible here.
    void load_view(int, Matlab_wrapper*);

    // return: rank x rank x rank tensor, W: real dim x rank
    // NOTE(review): ownership of the returned Tensor* is not visible here;
    // presumably the caller must delete it -- confirm.
    Tensor* whiten(Matlab_wrapper*, double* W);

    // Contractions with vectors (names suggest T(u,u,u), T(I,u,u), T(I,u,v)).
    // TIuu and TIuv write their result through `ret`. The effect of
    // `symeval` is defined in the implementation.
    double Tuuu(double* u, bool symeval = false);
    void TIuu(double* u, double* ret, bool symeval = false);
    void TIuv(double* u, double* v, double* ret);
    // Squared Frobenius norm, per the name -- confirm.
    double sqr_fnorm(bool symeval = false);

    void to_sparse_format();
    void sparsify(double rate);
    // Presumably adds lambda times the rank-one tensor built from u -- confirm.
    void add_rank_one_update(double lambda, double* u);

private:
    void clear();

    // Copying is disabled. These are declared but intentionally never
    // defined, which also works before C++11's `= delete`.
    // The class owns raw pointers (A, values, idx, Lambda, U) and declares a
    // destructor (Rule of Three). A compiler-generated member-wise copy would
    // leave two objects sharing the same buffers. If ~Tensor() releases them,
    // that is a double free.
    // Pass Tensor* or Tensor& instead.
    Tensor(const Tensor&);
    Tensor& operator=(const Tensor&);
};
#endif