diff --git a/README.md b/README.md index 09c3347..78df497 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,9 @@ v3.0 **Note: Library needs HTML5 support!** ### Test image -![image](https://github.com/user-attachments/assets/f11abf40-9163-49c1-8d0e-ece5c9d4e5ae) +![image](https://user-images.githubusercontent.com/4648756/87104009-84671b80-c20b-11ea-995b-72bc47d43766.png) + +PNG - 635 KB ### Test data ``` @@ -24,20 +26,43 @@ BONJOUR LE MONDE! ПРИВЕТ МИР! ``` +### Use model optimized for visual similarity + +![image](https://github.com/user-attachments/assets/d76de191-6580-438f-b6aa-658776d35368) + +JPG compressed to 120 KB (18.9%). This is the maximum compression that test data can be decoded correctly for this model. + +### Use model optimized for robustness + +![image](https://github.com/user-attachments/assets/2847eeb1-dec8-4a84-bb3d-716ab8ed3c79) + +JPG compressed to 28 KB (4.4%). This is the maximum compression that test data can be decoded correctly for this model. + + +You can try to decode data yourself in the demo webpage. Use empty password. + ## How to use + Download [cryptostego.zip](https://github.com/zeruniverse/CryptoStego/releases/latest/download/cryptostego.zip) + +Replace models in `dependencies/models` with the model version you want (different models can be downloaded from GitHub release). Please note, encoder and decoder must be used in pairs. + **Note: This JS library needs HTML5 support!** put `stego.js` at same level of your `index.html` (together with `dependencies` folder) and import it in your HTML. If your file structure is different, you might need to modify `stego.js` ```html - + ``` ## Features -+ new in 3.0: robust to image resize (upsampling, downsampling), translation, jittering, rotation etc. ++ new in 3.0: robust to image resize (upsampling, downsampling), color jittering, photo editing etc. + robust to compression (JPEG, PNG, GIF). -+ new in 3.0: encoded image is visually equivalent to original image. ++ new in 3.0: encoded image is more visually equivalent to original image. ## Usage @@ -68,9 +93,9 @@ Read message from canvas `canvasid`. `password` is a string. If `password` is no ## Minimum Coding Example -Refer to `example/` folder. +Refer to `example/` folder. Or for demo page, refer to the `gh-pages` branch. ## Copyright Jeffery Zhao -License: GNU **A**GPL v3.0 or later +License: GNU **A**GPL v3.0 or later for both code and model. For commercial use requiring another license, contact me. diff --git a/dist/dependencies/codecs.wasm b/dist/dependencies/codecs.wasm index 4dc9f1c..098f761 100755 Binary files a/dist/dependencies/codecs.wasm and b/dist/dependencies/codecs.wasm differ diff --git a/dist/dependencies/models/decoder.onnx b/dist/dependencies/models/decoder.onnx index 78ea086..93e33ba 100644 Binary files a/dist/dependencies/models/decoder.onnx and b/dist/dependencies/models/decoder.onnx differ diff --git a/dist/dependencies/models/encoder.onnx b/dist/dependencies/models/encoder.onnx index 6489d58..81ff61f 100644 Binary files a/dist/dependencies/models/encoder.onnx and b/dist/dependencies/models/encoder.onnx differ diff --git a/src/cpp/codecs.cpp b/src/cpp/codecs.cpp index ce691ab..e0d9515 100644 --- a/src/cpp/codecs.cpp +++ b/src/cpp/codecs.cpp @@ -17,6 +17,127 @@ #include // Helper function to generate a deterministic permutation based on password using mt19937 +// LinkedList class for permute +class LinkedList { +private: + // Node structure + struct Node { + uint16_t value; + Node* next; + Node* prev; + Node(uint16_t val) : value(val), next(nullptr), prev(nullptr) {} + }; + + Node* head; + Node* tail; + +public: + // Constructor: initializes empty list + LinkedList() : head(nullptr), tail(nullptr) {} + + // Constructor: builds list from vector + LinkedList(const std::vector& vec) : head(nullptr), tail(nullptr) { + for (uint16_t val : vec) { + append(val); + } + } + + // Destructor: frees all nodes + ~LinkedList() { + Node* current = head; + while (current) { + Node* tmp = current; + current = current->next; + delete tmp; + } + } + + // Append a value to the end of the list + void append(uint16_t val) { + Node* new_node = new Node(val); + if (!head) { // Empty list + head = tail = new_node; + } else { // Non-empty list + tail->next = new_node; + new_node->prev = tail; + tail = new_node; + } + } + + // Iterator class + class Iterator { + private: + LinkedList* list; + Node* current; + public: + // Constructor + Iterator(LinkedList* lst, Node* node) : list(lst), current(node) {} + + // Dereference operator + uint16_t& operator*() const { + return current->value; + } + + // Pre-increment + Iterator& operator++() { + if (current) current = current->next; + return *this; + } + + // Post-increment + Iterator operator++(int) { + Iterator tmp = *this; + ++(*this); + return tmp; + } + + // Equality comparison + bool operator==(const Iterator& other) const { + return current == other.current; + } + + // Inequality comparison + bool operator!=(const Iterator& other) const { + return current != other.current; + } + + // Delete current node and move iterator to next node + Iterator& delete_current() { + if (!current) return *this; // Nothing to delete + + Node* node_to_delete = current; + // Move current to next before deletion + current = current->next; + + // Update links + if (node_to_delete->prev) { + node_to_delete->prev->next = node_to_delete->next; + } else { // Deleting head + list->head = node_to_delete->next; + } + + if (node_to_delete->next) { + node_to_delete->next->prev = node_to_delete->prev; + } else { // Deleting tail + list->tail = node_to_delete->prev; + } + + delete node_to_delete; + return *this; + } + }; + + // Begin iterator + Iterator begin() { + return Iterator(this, head); + } + + // End iterator + Iterator end() { + return Iterator(this, nullptr); + } +}; + std::vector resize_linear(const std::vector& src, int src_width, int src_height, int dst_width, int dst_height) { std::vector dst(dst_width * dst_height, 0.0f); @@ -111,8 +232,8 @@ double calculate_score(const std::vector& x, int width = 256, int height } // Helper function to generate a deterministic permutation based on password using mt19937 -std::vector generate_permutation(const std::string& password, const uint8_t shift, const size_t reduced_range) { - std::vector O(65536 - reduced_range); +std::vector generate_permutation(const std::string& password, const uint8_t shift) { + std::vector O(65536); std::iota(O.begin(), O.end(), 0); // Create a seed from the password using a hash function @@ -158,10 +279,13 @@ std::vector encode(const std::vector& raw_data, const std::str } // Step 2: Generate permutation O based on password using mt19937 - std::vector O = generate_permutation(password, 0, 351); + std::vector O = generate_permutation(password, 0); - // Step 3: Encode the length L as 10-bit unsigned integer + // calculate the repeats uint16_t L = static_cast(raw_data.size()); + uint16_t repeat = 65536 / (L * 8 + 90 + 27); // 9 bit repr for each repeat of count and shift bits. + + // Step 3: Encode the length L as 10-bit unsigned integer std::bitset<10> L_bits_set(L); std::vector L_bit_array; L_bit_array.reserve(10); @@ -171,15 +295,14 @@ std::vector encode(const std::vector& raw_data, const std::str // Step 4: Convert raw_data to bit array D std::vector D = bytes_to_bits(raw_data); - std::unordered_set Oindex; + std::unordered_set Oindex; - // Step 5: Compose raw_bit_message as L + L + L + // Step 5: Compose raw_bit_message as L + L + L ... (repeat times) std::vector raw_bit_message; - raw_bit_message.reserve(30); // 3*10 + D.size() - raw_bit_message.insert(raw_bit_message.end(), L_bit_array.begin(), L_bit_array.end()); - raw_bit_message.insert(raw_bit_message.end(), L_bit_array.begin(), L_bit_array.end()); - raw_bit_message.insert(raw_bit_message.end(), L_bit_array.begin(), L_bit_array.end()); - // raw_bit_message.insert(raw_bit_message.end(), D.begin(), D.end()); + raw_bit_message.reserve(10 * repeat); + for (uint16_t i=0; i< repeat; i++){ + raw_bit_message.insert(raw_bit_message.end(), L_bit_array.begin(), L_bit_array.end()); + } // Step 6: Initialize P as a vector of 65536 zeros std::vector P(65536, 0); @@ -197,7 +320,6 @@ std::vector encode(const std::vector& raw_data, const std::str } // calculate repeat times - size_t repeat = (65536 - 351) / (L * 8); // 27*13 = 351 double max_score = 0; std::vector best_P; @@ -210,32 +332,40 @@ std::vector encode(const std::vector& raw_data, const std::str shift_bit_array.push_back(shift[j]); } std::vector shift_bit_msg; - shift_bit_msg.reserve(9); - shift_bit_msg.insert(shift_bit_msg.end(), shift_bit_array.begin(), shift_bit_array.end()); - shift_bit_msg.insert(shift_bit_msg.end(), shift_bit_array.begin(), shift_bit_array.end()); - shift_bit_msg.insert(shift_bit_msg.end(), shift_bit_array.begin(), shift_bit_array.end()); + shift_bit_msg.reserve(3 * repeat); + for (uint16_t i=0; i< repeat; i++){ + shift_bit_msg.insert(shift_bit_msg.end(), shift_bit_array.begin(), shift_bit_array.end()); + } + for (size_t i = 0; i < shift_bit_msg.size(); ++i) { uint8_t bit = shift_bit_msg[i]; for (int k = 0; k < 9; ++k) { - size_t pos = i * 9 + k + 270; + size_t pos = i * 9 + k + 90 * repeat; if (pos >= 65536) break; // Prevent out-of-bounds uint16_t index = O[pos]; Oindex.insert(index); PRep[index] = bit; } } + uint32_t current_end_loc = 65535; - std::vector O1 = generate_permutation(password, shift_idx + 1, 351); + std::vector O1 = generate_permutation(password, shift_idx + 1); + //generate orders + LinkedList O1_LL(O1); + for (auto it = O1_LL.begin(); it != O1_LL.end(); ) { + if (Oindex.find(*it) != Oindex.end()) { + // ocupied loc + it.delete_current(); + } else { + ++it; + } + } + auto itO1_LL = O1_LL.begin(); for (size_t i = 0; i < D.size(); ++i) { for (int k = 0; k < repeat; ++k) { size_t pos = i * repeat + k; - if (pos >= 65536) break; // Prevent out-of-bounds - uint32_t index = O1[pos]; - if(Oindex.find(index) != Oindex.end()){ - index = current_end_loc; - current_end_loc -= 1; - } - PRep[index] = D[i]; + if (itO1_LL == O1_LL.end()) break; // Prevent out-of-bounds + PRep[*(itO1_LL++)] = D[i]; } } @@ -259,10 +389,12 @@ std::vector decode(const std::vector& coded_data, const std::str } // Step 2: Generate permutation O based on password using mt19937 - std::vector O = generate_permutation(password, 0, 351); + LinkedList OLL(generate_permutation(password, 0)); + LinkedList* LL_ptr = &OLL; + LinkedList::Iterator OLLit = LL_ptr->begin(); + LinkedList::Iterator* LLit_ptr = &OLLit; - size_t current_pos = 0; - std::unordered_set Oindex; + std::unordered_set Oindex; uint32_t current_end_loc = 65535; // Step 3: Define decode_bit as a lambda function @@ -274,14 +406,11 @@ std::vector decode(const std::vector& coded_data, const std::str // Read 9 bits for (int i = 0; i < repeat; ++i) { - if (current_pos >= 65536) break; // Prevent out-of-bounds + if (*LLit_ptr == LL_ptr->end()) break; // Prevent out-of-bounds - uint32_t index = O[current_pos++]; + uint16_t index = *((*LLit_ptr)++); - if(Oindex.find(index) != Oindex.end()){ - index = current_end_loc; - current_end_loc -= 1; - } else if(set_oindex) Oindex.insert(index); + if(set_oindex) Oindex.insert(index); float prob = coded_data[index]; uint8_t bit = (prob > 0.0f) ? 1 : 0; @@ -318,54 +447,13 @@ std::vector decode(const std::vector& coded_data, const std::str float mean_sigmoid = sum_sigmoid / bit_1_01.size(); return (mean_sigmoid > 0.5f) ? 1 : 0; }; - // auto decode_bit = [&](std::vector& bits_out) -> uint8_t { - // std::vector bit_1_01; - // bit_1_01.reserve(9); - // std::vector fc_prob_1_01; - // fc_prob_1_01.reserve(9); - - // // Read 9 bits - // for (int i = 0; i < 9; ++i) { - // if (current_pos >= 65536) break; // Prevent out-of-bounds - // float prob = coded_data[O[current_pos++]]; - // uint8_t bit = (prob > 0.0f) ? 1 : 0; - // bit_1_01.push_back(bit); - // fc_prob_1_01.push_back(prob); - // } - - // // Compute sigmoid for each bit and calculate the mean - // float sum_sigmoid = 0.0f; - // for (auto bit : fc_prob_1_01) { - // sum_sigmoid += 1.0f / (1.0f + std::exp(-bit)); - // } - // float mean_sigmoid = sum_sigmoid / bit_1_01.size(); - - // if (mean_sigmoid > 0.7f) { - // return 1; - // } else if (mean_sigmoid < 0.3f) { - // return 0; - // } else { - // // Determine majority bit - // std::unordered_map bit_count; - // for (auto bit : bit_1_01) { - // bit_count[bit]++; - // } - - // uint8_t majority_bit = 0; - // int max_count = 0; - // for (const auto& pair : bit_count) { - // if (pair.second > max_count) { - // max_count = pair.second; - // majority_bit = pair.first; - // } - // } - - // return majority_bit; - // } - // }; - // Step 4: Decode the length L by reading it three times - std::vector decoded_len; - for (int i = 0; i < 3; ++i) { + + // Step 4: Decode the length L by reading it 7 times, then up to repeat times + uint16_t repeat = 7; + uint16_t dlen = 0; + std::unordered_map decoded_len_count; + + for (int i = 0; i < repeat; ++i) { std::vector int_bits; int_bits.reserve(10); for (int p = 0; p < 10; ++p) { @@ -373,33 +461,24 @@ std::vector decode(const std::vector& coded_data, const std::str int_bits.push_back(bit); } uint16_t uint_len = make_unsigned_int_10bit(int_bits); - decoded_len.push_back(uint_len); - } - - // Step 5: Determine the most common L - std::unordered_map decoded_len_count; - for (auto len : decoded_len) { - decoded_len_count[len]++; - } - - uint16_t dlen = 0; - int max_len_count = 0; - for (const auto& pair : decoded_len_count) { - if (pair.second > max_len_count) { - max_len_count = pair.second; - dlen = pair.first; + decoded_len_count[uint_len]++; + if(i > 5){ + // calculate updated repeat + uint16_t max_len_count = 0; + for (const auto& pair : decoded_len_count) { + if (pair.second > max_len_count) { + max_len_count = pair.second; + dlen = pair.first; + } + } + if(max_len_count == 1) repeat++; else repeat = 65536 / (dlen * 8 + 90 + 27); + if(repeat > 560 || repeat <= i) return std::vector(); // 65536 / (90+27) =560.1 } } - if (max_len_count < 2) { // No common value - return std::vector(); // Decode failed - } - - // std::cout <<" decoded len " << dlen << std::endl; - // Step 4/5 rep, decode shift std::vector decoded_shift; - for (int i = 0; i < 3; ++i) { + for (int i = 0; i < repeat; ++i) { std::vector int_bits; int_bits.reserve(3); for (int p = 0; p < 3; ++p) { @@ -430,14 +509,23 @@ std::vector decode(const std::vector& coded_data, const std::str return std::vector(); // Decode failed } // std::cout << "sft " << dshift << std::endl; + LinkedList O1LL(generate_permutation(password, dshift + 1)); + LL_ptr = &O1LL; + for (auto it = LL_ptr->begin(); it != LL_ptr->end(); ) { + if (Oindex.find(*it) != Oindex.end()) { + // ocupied loc + it.delete_current(); + } else { + ++it; + } + } - O = generate_permutation(password, dshift + 1, 351); // replace O - current_pos = 0; // Reset position + LinkedList::Iterator O1LLit = LL_ptr->begin(); + LLit_ptr = &O1LLit; // Step 6: Decode the message based on dlen std::vector msg; msg.reserve(dlen); - size_t repeat = (65536 - 351) / (dlen * 8); for (size_t i = 0; i < dlen; ++i) { std::vector byte_bits; byte_bits.reserve(8);