diff --git a/README.md b/README.md
index 09c3347..78df497 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,9 @@ v3.0
**Note: Library needs HTML5 support!**
### Test image
-![image](https://github.com/user-attachments/assets/f11abf40-9163-49c1-8d0e-ece5c9d4e5ae)
+![image](https://user-images.githubusercontent.com/4648756/87104009-84671b80-c20b-11ea-995b-72bc47d43766.png)
+
+PNG - 635 KB
### Test data
```
@@ -24,20 +26,43 @@ BONJOUR LE MONDE!
ПРИВЕТ МИР!
```
+### Use model optimized for visual similarity
+
+![image](https://github.com/user-attachments/assets/d76de191-6580-438f-b6aa-658776d35368)
+
+JPG compressed to 120 KB (18.9%). This is the maximum compression that test data can be decoded correctly for this model.
+
+### Use model optimized for robustness
+
+![image](https://github.com/user-attachments/assets/2847eeb1-dec8-4a84-bb3d-716ab8ed3c79)
+
+JPG compressed to 28 KB (4.4%). This is the maximum compression that test data can be decoded correctly for this model.
+
+
+You can try to decode data yourself in the demo webpage. Use empty password.
+
## How to use
+
Download [cryptostego.zip](https://github.com/zeruniverse/CryptoStego/releases/latest/download/cryptostego.zip)
+
+Replace models in `dependencies/models` with the model version you want (different models can be downloaded from GitHub release). Please note, encoder and decoder must be used in pairs.
+
**Note: This JS library needs HTML5 support!**
put `stego.js` at same level of your `index.html` (together with `dependencies` folder) and import it in your HTML. If your file structure is different, you might need to modify `stego.js`
```html
-
+
```
## Features
-+ new in 3.0: robust to image resize (upsampling, downsampling), translation, jittering, rotation etc.
++ new in 3.0: robust to image resize (upsampling, downsampling), color jittering, photo editing etc.
+ robust to compression (JPEG, PNG, GIF).
-+ new in 3.0: encoded image is visually equivalent to original image.
++ new in 3.0: encoded image is more visually equivalent to original image.
## Usage
@@ -68,9 +93,9 @@ Read message from canvas `canvasid`. `password` is a string. If `password` is no
## Minimum Coding Example
-Refer to `example/` folder.
+Refer to `example/` folder. Or for demo page, refer to the `gh-pages` branch.
## Copyright
Jeffery Zhao
-License: GNU **A**GPL v3.0 or later
+License: GNU **A**GPL v3.0 or later for both code and model. For commercial use requiring another license, contact me.
diff --git a/dist/dependencies/codecs.wasm b/dist/dependencies/codecs.wasm
index 4dc9f1c..098f761 100755
Binary files a/dist/dependencies/codecs.wasm and b/dist/dependencies/codecs.wasm differ
diff --git a/dist/dependencies/models/decoder.onnx b/dist/dependencies/models/decoder.onnx
index 78ea086..93e33ba 100644
Binary files a/dist/dependencies/models/decoder.onnx and b/dist/dependencies/models/decoder.onnx differ
diff --git a/dist/dependencies/models/encoder.onnx b/dist/dependencies/models/encoder.onnx
index 6489d58..81ff61f 100644
Binary files a/dist/dependencies/models/encoder.onnx and b/dist/dependencies/models/encoder.onnx differ
diff --git a/src/cpp/codecs.cpp b/src/cpp/codecs.cpp
index ce691ab..e0d9515 100644
--- a/src/cpp/codecs.cpp
+++ b/src/cpp/codecs.cpp
@@ -17,6 +17,127 @@
#include
// Helper function to generate a deterministic permutation based on password using mt19937
+// LinkedList class for permute
+class LinkedList {
+private:
+ // Node structure
+ struct Node {
+ uint16_t value;
+ Node* next;
+ Node* prev;
+ Node(uint16_t val) : value(val), next(nullptr), prev(nullptr) {}
+ };
+
+ Node* head;
+ Node* tail;
+
+public:
+ // Constructor: initializes empty list
+ LinkedList() : head(nullptr), tail(nullptr) {}
+
+ // Constructor: builds list from vector
+ LinkedList(const std::vector& vec) : head(nullptr), tail(nullptr) {
+ for (uint16_t val : vec) {
+ append(val);
+ }
+ }
+
+ // Destructor: frees all nodes
+ ~LinkedList() {
+ Node* current = head;
+ while (current) {
+ Node* tmp = current;
+ current = current->next;
+ delete tmp;
+ }
+ }
+
+ // Append a value to the end of the list
+ void append(uint16_t val) {
+ Node* new_node = new Node(val);
+ if (!head) { // Empty list
+ head = tail = new_node;
+ } else { // Non-empty list
+ tail->next = new_node;
+ new_node->prev = tail;
+ tail = new_node;
+ }
+ }
+
+ // Iterator class
+ class Iterator {
+ private:
+ LinkedList* list;
+ Node* current;
+ public:
+ // Constructor
+ Iterator(LinkedList* lst, Node* node) : list(lst), current(node) {}
+
+ // Dereference operator
+ uint16_t& operator*() const {
+ return current->value;
+ }
+
+ // Pre-increment
+ Iterator& operator++() {
+ if (current) current = current->next;
+ return *this;
+ }
+
+ // Post-increment
+ Iterator operator++(int) {
+ Iterator tmp = *this;
+ ++(*this);
+ return tmp;
+ }
+
+ // Equality comparison
+ bool operator==(const Iterator& other) const {
+ return current == other.current;
+ }
+
+ // Inequality comparison
+ bool operator!=(const Iterator& other) const {
+ return current != other.current;
+ }
+
+ // Delete current node and move iterator to next node
+ Iterator& delete_current() {
+ if (!current) return *this; // Nothing to delete
+
+ Node* node_to_delete = current;
+ // Move current to next before deletion
+ current = current->next;
+
+ // Update links
+ if (node_to_delete->prev) {
+ node_to_delete->prev->next = node_to_delete->next;
+ } else { // Deleting head
+ list->head = node_to_delete->next;
+ }
+
+ if (node_to_delete->next) {
+ node_to_delete->next->prev = node_to_delete->prev;
+ } else { // Deleting tail
+ list->tail = node_to_delete->prev;
+ }
+
+ delete node_to_delete;
+ return *this;
+ }
+ };
+
+ // Begin iterator
+ Iterator begin() {
+ return Iterator(this, head);
+ }
+
+ // End iterator
+ Iterator end() {
+ return Iterator(this, nullptr);
+ }
+};
+
std::vector resize_linear(const std::vector& src, int src_width, int src_height,
int dst_width, int dst_height) {
std::vector dst(dst_width * dst_height, 0.0f);
@@ -111,8 +232,8 @@ double calculate_score(const std::vector& x, int width = 256, int height
}
// Helper function to generate a deterministic permutation based on password using mt19937
-std::vector generate_permutation(const std::string& password, const uint8_t shift, const size_t reduced_range) {
- std::vector O(65536 - reduced_range);
+std::vector generate_permutation(const std::string& password, const uint8_t shift) {
+ std::vector O(65536);
std::iota(O.begin(), O.end(), 0);
// Create a seed from the password using a hash function
@@ -158,10 +279,13 @@ std::vector encode(const std::vector& raw_data, const std::str
}
// Step 2: Generate permutation O based on password using mt19937
- std::vector O = generate_permutation(password, 0, 351);
+ std::vector O = generate_permutation(password, 0);
- // Step 3: Encode the length L as 10-bit unsigned integer
+ // calculate the repeats
uint16_t L = static_cast(raw_data.size());
+ uint16_t repeat = 65536 / (L * 8 + 90 + 27); // 9 bit repr for each repeat of count and shift bits.
+
+ // Step 3: Encode the length L as 10-bit unsigned integer
std::bitset<10> L_bits_set(L);
std::vector L_bit_array;
L_bit_array.reserve(10);
@@ -171,15 +295,14 @@ std::vector encode(const std::vector& raw_data, const std::str
// Step 4: Convert raw_data to bit array D
std::vector D = bytes_to_bits(raw_data);
- std::unordered_set Oindex;
+ std::unordered_set Oindex;
- // Step 5: Compose raw_bit_message as L + L + L
+ // Step 5: Compose raw_bit_message as L + L + L ... (repeat times)
std::vector raw_bit_message;
- raw_bit_message.reserve(30); // 3*10 + D.size()
- raw_bit_message.insert(raw_bit_message.end(), L_bit_array.begin(), L_bit_array.end());
- raw_bit_message.insert(raw_bit_message.end(), L_bit_array.begin(), L_bit_array.end());
- raw_bit_message.insert(raw_bit_message.end(), L_bit_array.begin(), L_bit_array.end());
- // raw_bit_message.insert(raw_bit_message.end(), D.begin(), D.end());
+ raw_bit_message.reserve(10 * repeat);
+ for (uint16_t i=0; i< repeat; i++){
+ raw_bit_message.insert(raw_bit_message.end(), L_bit_array.begin(), L_bit_array.end());
+ }
// Step 6: Initialize P as a vector of 65536 zeros
std::vector P(65536, 0);
@@ -197,7 +320,6 @@ std::vector encode(const std::vector& raw_data, const std::str
}
// calculate repeat times
- size_t repeat = (65536 - 351) / (L * 8); // 27*13 = 351
double max_score = 0;
std::vector best_P;
@@ -210,32 +332,40 @@ std::vector encode(const std::vector& raw_data, const std::str
shift_bit_array.push_back(shift[j]);
}
std::vector shift_bit_msg;
- shift_bit_msg.reserve(9);
- shift_bit_msg.insert(shift_bit_msg.end(), shift_bit_array.begin(), shift_bit_array.end());
- shift_bit_msg.insert(shift_bit_msg.end(), shift_bit_array.begin(), shift_bit_array.end());
- shift_bit_msg.insert(shift_bit_msg.end(), shift_bit_array.begin(), shift_bit_array.end());
+ shift_bit_msg.reserve(3 * repeat);
+ for (uint16_t i=0; i< repeat; i++){
+ shift_bit_msg.insert(shift_bit_msg.end(), shift_bit_array.begin(), shift_bit_array.end());
+ }
+
for (size_t i = 0; i < shift_bit_msg.size(); ++i) {
uint8_t bit = shift_bit_msg[i];
for (int k = 0; k < 9; ++k) {
- size_t pos = i * 9 + k + 270;
+ size_t pos = i * 9 + k + 90 * repeat;
if (pos >= 65536) break; // Prevent out-of-bounds
uint16_t index = O[pos];
Oindex.insert(index);
PRep[index] = bit;
}
}
+
uint32_t current_end_loc = 65535;
- std::vector O1 = generate_permutation(password, shift_idx + 1, 351);
+ std::vector O1 = generate_permutation(password, shift_idx + 1);
+ //generate orders
+ LinkedList O1_LL(O1);
+ for (auto it = O1_LL.begin(); it != O1_LL.end(); ) {
+ if (Oindex.find(*it) != Oindex.end()) {
+ // ocupied loc
+ it.delete_current();
+ } else {
+ ++it;
+ }
+ }
+ auto itO1_LL = O1_LL.begin();
for (size_t i = 0; i < D.size(); ++i) {
for (int k = 0; k < repeat; ++k) {
size_t pos = i * repeat + k;
- if (pos >= 65536) break; // Prevent out-of-bounds
- uint32_t index = O1[pos];
- if(Oindex.find(index) != Oindex.end()){
- index = current_end_loc;
- current_end_loc -= 1;
- }
- PRep[index] = D[i];
+ if (itO1_LL == O1_LL.end()) break; // Prevent out-of-bounds
+ PRep[*(itO1_LL++)] = D[i];
}
}
@@ -259,10 +389,12 @@ std::vector decode(const std::vector& coded_data, const std::str
}
// Step 2: Generate permutation O based on password using mt19937
- std::vector O = generate_permutation(password, 0, 351);
+ LinkedList OLL(generate_permutation(password, 0));
+ LinkedList* LL_ptr = &OLL;
+ LinkedList::Iterator OLLit = LL_ptr->begin();
+ LinkedList::Iterator* LLit_ptr = &OLLit;
- size_t current_pos = 0;
- std::unordered_set Oindex;
+ std::unordered_set Oindex;
uint32_t current_end_loc = 65535;
// Step 3: Define decode_bit as a lambda function
@@ -274,14 +406,11 @@ std::vector decode(const std::vector& coded_data, const std::str
// Read 9 bits
for (int i = 0; i < repeat; ++i) {
- if (current_pos >= 65536) break; // Prevent out-of-bounds
+ if (*LLit_ptr == LL_ptr->end()) break; // Prevent out-of-bounds
- uint32_t index = O[current_pos++];
+ uint16_t index = *((*LLit_ptr)++);
- if(Oindex.find(index) != Oindex.end()){
- index = current_end_loc;
- current_end_loc -= 1;
- } else if(set_oindex) Oindex.insert(index);
+ if(set_oindex) Oindex.insert(index);
float prob = coded_data[index];
uint8_t bit = (prob > 0.0f) ? 1 : 0;
@@ -318,54 +447,13 @@ std::vector decode(const std::vector& coded_data, const std::str
float mean_sigmoid = sum_sigmoid / bit_1_01.size();
return (mean_sigmoid > 0.5f) ? 1 : 0;
};
- // auto decode_bit = [&](std::vector& bits_out) -> uint8_t {
- // std::vector bit_1_01;
- // bit_1_01.reserve(9);
- // std::vector fc_prob_1_01;
- // fc_prob_1_01.reserve(9);
-
- // // Read 9 bits
- // for (int i = 0; i < 9; ++i) {
- // if (current_pos >= 65536) break; // Prevent out-of-bounds
- // float prob = coded_data[O[current_pos++]];
- // uint8_t bit = (prob > 0.0f) ? 1 : 0;
- // bit_1_01.push_back(bit);
- // fc_prob_1_01.push_back(prob);
- // }
-
- // // Compute sigmoid for each bit and calculate the mean
- // float sum_sigmoid = 0.0f;
- // for (auto bit : fc_prob_1_01) {
- // sum_sigmoid += 1.0f / (1.0f + std::exp(-bit));
- // }
- // float mean_sigmoid = sum_sigmoid / bit_1_01.size();
-
- // if (mean_sigmoid > 0.7f) {
- // return 1;
- // } else if (mean_sigmoid < 0.3f) {
- // return 0;
- // } else {
- // // Determine majority bit
- // std::unordered_map bit_count;
- // for (auto bit : bit_1_01) {
- // bit_count[bit]++;
- // }
-
- // uint8_t majority_bit = 0;
- // int max_count = 0;
- // for (const auto& pair : bit_count) {
- // if (pair.second > max_count) {
- // max_count = pair.second;
- // majority_bit = pair.first;
- // }
- // }
-
- // return majority_bit;
- // }
- // };
- // Step 4: Decode the length L by reading it three times
- std::vector decoded_len;
- for (int i = 0; i < 3; ++i) {
+
+ // Step 4: Decode the length L by reading it 7 times, then up to repeat times
+ uint16_t repeat = 7;
+ uint16_t dlen = 0;
+ std::unordered_map decoded_len_count;
+
+ for (int i = 0; i < repeat; ++i) {
std::vector int_bits;
int_bits.reserve(10);
for (int p = 0; p < 10; ++p) {
@@ -373,33 +461,24 @@ std::vector decode(const std::vector& coded_data, const std::str
int_bits.push_back(bit);
}
uint16_t uint_len = make_unsigned_int_10bit(int_bits);
- decoded_len.push_back(uint_len);
- }
-
- // Step 5: Determine the most common L
- std::unordered_map decoded_len_count;
- for (auto len : decoded_len) {
- decoded_len_count[len]++;
- }
-
- uint16_t dlen = 0;
- int max_len_count = 0;
- for (const auto& pair : decoded_len_count) {
- if (pair.second > max_len_count) {
- max_len_count = pair.second;
- dlen = pair.first;
+ decoded_len_count[uint_len]++;
+ if(i > 5){
+ // calculate updated repeat
+ uint16_t max_len_count = 0;
+ for (const auto& pair : decoded_len_count) {
+ if (pair.second > max_len_count) {
+ max_len_count = pair.second;
+ dlen = pair.first;
+ }
+ }
+ if(max_len_count == 1) repeat++; else repeat = 65536 / (dlen * 8 + 90 + 27);
+ if(repeat > 560 || repeat <= i) return std::vector(); // 65536 / (90+27) =560.1
}
}
- if (max_len_count < 2) { // No common value
- return std::vector(); // Decode failed
- }
-
- // std::cout <<" decoded len " << dlen << std::endl;
-
// Step 4/5 rep, decode shift
std::vector decoded_shift;
- for (int i = 0; i < 3; ++i) {
+ for (int i = 0; i < repeat; ++i) {
std::vector int_bits;
int_bits.reserve(3);
for (int p = 0; p < 3; ++p) {
@@ -430,14 +509,23 @@ std::vector decode(const std::vector& coded_data, const std::str
return std::vector(); // Decode failed
}
// std::cout << "sft " << dshift << std::endl;
+ LinkedList O1LL(generate_permutation(password, dshift + 1));
+ LL_ptr = &O1LL;
+ for (auto it = LL_ptr->begin(); it != LL_ptr->end(); ) {
+ if (Oindex.find(*it) != Oindex.end()) {
+ // ocupied loc
+ it.delete_current();
+ } else {
+ ++it;
+ }
+ }
- O = generate_permutation(password, dshift + 1, 351); // replace O
- current_pos = 0; // Reset position
+ LinkedList::Iterator O1LLit = LL_ptr->begin();
+ LLit_ptr = &O1LLit;
// Step 6: Decode the message based on dlen
std::vector msg;
msg.reserve(dlen);
- size_t repeat = (65536 - 351) / (dlen * 8);
for (size_t i = 0; i < dlen; ++i) {
std::vector byte_bits;
byte_bits.reserve(8);