-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathRAISR.h
77 lines (69 loc) · 2.84 KB
/
RAISR.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#ifndef RAISR_RAISR_H
#define RAISR_RAISR_H
#include "HashBuckets.h"
#include "opencv2/opencv.hpp"
#include "opencv2/core.hpp"
/************************************************************
* Enumerations
* Rotation : which rotation is applied (none, 90, 180 or 270 degrees);
*            NO_ROTATION is -1 and the actual rotations count up from 0
* Mirror   : whether the image is mirrored (MIRROR) or not (NO_MIRROR)
*/
// Rotation selector. NO_ROTATION is a sentinel (-1); the real rotations are
// numbered consecutively from 0 so that operator++ can step through them.
enum Rotation {
    NO_ROTATION = -1, // leave the data unrotated
    ROTATE_90   = 0,  // rotate by 90 degrees
    ROTATE_180  = 1,  // rotate by 180 degrees
    ROTATE_270  = 2   // rotate by 270 degrees
};
// Mirror selector.
// NOTE(review): MIRROR is 0 and NO_MIRROR is 1, so a plain truthiness test on a
// Mirror value is true for NO_MIRROR. The values are kept for compatibility;
// always compare against the named enumerators.
enum Mirror {
    MIRROR    = 0, // image is mirrored
    NO_MIRROR = 1  // image is not mirrored
};
/************************************************************
* Class RAISR
* This class contains the implementation of RAISR (Rapid and
* Accurate Image Super-Resolution), which is used to enhance
* images during upscaling.
*
* Note:
* The basic idea is to learn a mapping from blurred-image pixels
* to their corresponding high-resolution (HR) pixels. Given pairs
* of HR and blurred images, a group of filters is trained to map
* each blurred-image pixel, together with certain neighboring
* pixels, to a new pixel that differs as little as possible from
* the true HR pixel. See Romano, Isidoro and Milanfar, "RAISR:
* Rapid and Accurate Image Super Resolution" for details.
*/
class RAISR {
public:
    /**
     * @param imageMatList   images used to train the model. The class stores a
     *                       reference member (imageMatList), presumably bound to
     *                       this argument, so the vector must outlive the object.
     * @param scale          factor by which images are upscaled
     * @param patchLength    side length of a square patch (patchLength x patchLength pixels)
     * @param gradientLength side length of the pixel segment used to compute a
     *                       patch's hash value
     */
    RAISR(std::vector<cv::Mat>& imageMatList, int scale, int patchLength, int gradientLength);

    // Presumably trains filterBuckets from imageMatList (definition not visible
    // here — confirm against RAISR.cpp).
    void train();

    // Presumably applies the trained filters to imageMatList and fills the
    // output lists; whether the inputs are downscaled first is controlled by
    // downScale. TODO(review): confirm the parameter semantics against the
    // definition.
    void test(
        bool downScale,
        std::vector<cv::Mat>& imageMatList,
        std::vector<cv::Mat>& downScaledImageList,
        std::vector<cv::Mat>& RAISRImageList,
        std::vector<cv::Mat>& cheapScaledImageList,
        std::string CTBlendingType
    );

    // Debug/self-test entry point for the module-level helper functions.
    void testPrivateModuleMethod();

    // Presumably serializes filterBuckets to outPath / loads it from inPath.
    void writeOutFilter(std::string& outPath);
    void readInFilter(std::string& inPath);

private:
    // Default member initializers guarantee a defined state even if a
    // constructor body forgets to set these; a constructor's mem-init list
    // still takes precedence.
    bool trained = false;    // flag indicating whether the model is trained
    int patchLength = 0;     // length of a patch (a patchLength x patchLength pixel segment)
    int gradientLength = 0;  // length of the pixel segment used to determine a patch's hash value
    int scale = 0;           // factor describing the extent to which the image is scaled
    std::vector<std::vector<cv::Mat>> filterBuckets; // contains the trained filters
    std::vector<cv::Mat>& imageMatList;              // images used to train the model (non-owning reference)
};
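/************************************************************
* Module-internal helper function declarations
* These are intended as private to the RAISR module, but because
* they are declared in this header they are visible to every file
* that includes it.
*/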
// Presumably solves A * x = b with the conjugate-gradient method and returns x
// (inferred from the name — confirm against the definition).
cv::Mat conjugateGradientSolver(cv::Mat A, cv::Mat b);

// Produces a degraded copy of `image` for the given `scale` (likely a
// down/up-scaling blur; TODO confirm against the definition).
cv::Mat downGrade(cv::Mat image, int scale);

// Accumulates data for bucket `hashValue` / `pixelType` into the ATA and ATb
// tables (both passed by non-const reference, so they are updated in place).
// NOTE(review): presumably these are the normal-equation terms A^T A and A^T b;
// confirm.
void fillBucketsMatrix(
    std::vector<std::vector<cv::Mat>>& ATA,
    std::vector<std::vector<cv::Mat>>& ATb,
    int hashValue,
    cv::Mat patch,
    double HRPixel,
    int pixelType
);

// Writes a flattened form of `patch` into `flattenPatch` (an output parameter).
void flattenPatchBoundary(cv::Mat patch, std::vector<double>& flattenPatch);

// Returns the hash value for position (r, c) under the given rotation and
// mirroring, using `buckets` (semantics defined in HashBuckets.h).
int getHashValue(
    HashBuckets& buckets,
    int r,
    int c,
    Rotation rotateFlag,
    Mirror mirror
);

// Returns a connected-component-based measure of `patch` (see definition).
int getLeastConnectedComponents(cv::Mat patch);

// Prefix (returns a reference) and postfix (returns the previous value)
// increment for Rotation; presumably steps to the next rotation.
Rotation& operator++(Rotation& rotation);
Rotation operator++(Rotation& rotation, int);
#endif //RAISR_RAISR_H