diff --git a/src/main/cpp/src/parse_uri.cu b/src/main/cpp/src/parse_uri.cu index 54e79ab022..d0629cb71f 100644 --- a/src/main/cpp/src/parse_uri.cu +++ b/src/main/cpp/src/parse_uri.cu @@ -18,11 +18,13 @@ #include #include +#include #include #include #include #include #include +#include #include #include @@ -34,251 +36,728 @@ namespace spark_rapids_jni { using namespace cudf; namespace detail { + +struct uri_parts { + string_view scheme; + string_view host; + string_view authority; + string_view path; + string_view fragment; + string_view query; + string_view userinfo; + string_view port; + string_view opaque; + bool valid{false}; +}; + +enum class URI_chunks : int8_t { PROTOCOL, HOST, AUTHORITY, PATH, QUERY, USERINFO }; + +enum class chunk_validity : int8_t { VALID, INVALID, FATAL }; + namespace { -// utility to validate a character is valid in a URI -constexpr bool is_valid_character(char ch, bool alphanum_only) +// Some parsing errors are fatal and some parsing errors simply mean this +// thing doesn't exist or is invalid. For example, just because 280.0.1.16 is +// not a valid IPv4 address simply means if asking for the host the host is null +// but the authority is still 280.0.1.16 and the uri is not considered invalid. +// By contrast, the URI https://[15:6:g:invalid] will not return https for the +// scheme and is considered completely invalid. + +constexpr bool is_alpha(char c) { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); } + +constexpr bool is_numeric(char c) { return c >= '0' && c <= '9'; } + +constexpr bool is_alphanum(char c) { return is_alpha(c) || is_numeric(c); } + +constexpr bool is_hex(char c) { - if (alphanum_only) { - if (ch >= '-' && ch <= '9' && ch != '/') return true; // 0-9 and .- - if (ch >= 'A' && ch <= 'Z') return true; // A-Z - if (ch >= 'a' && ch <= 'z') return true; // a-z - } else { - if (ch >= '!' && ch <= ';' && ch != '"') return true; // 0-9 and !#%&'()*+,-./ - if (ch >= '=' && ch <= 'Z' && ch != '>') return true; // A-Z and =?@ - if (ch >= '_' && ch <= 'z' && ch != '`') return true; // a-z and _ + return is_numeric(c) || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F'); +} + +__device__ thrust::pair skip_and_validate_special( + string_view::const_iterator iter, + string_view::const_iterator end, + bool allow_invalid_escapes = false) +{ + while (iter != end) { + auto const c = *iter; + auto const num_bytes = cudf::strings::detail::bytes_in_char_utf8(*iter); + if (*iter == '%' && !allow_invalid_escapes) { + // verify following two characters are hexadecimal + for (int i = 0; i < 2; ++i) { + ++iter; + if (iter == end) { return {false, iter}; } + + if (!is_hex(*iter)) { return {false, iter}; } + } + } else if (num_bytes > 1) { + // UTF8 validation means it isn't whitespace and not a control character + // the normal validation will handle anything single byte, this checks for multiple byte + // whitespace + auto const c = *iter; + // There are multi-byte looking things like extended ASCII characters that are not valid UTF8. + // Check that here. + if ((c & 0xC0) != 0x80) { return {false, iter}; } + if (num_bytes > 2 && ((c & 0xC000) != 0x8000)) { return {false, iter}; } + if (num_bytes > 3 && ((c & 0xC00000) != 0x800000)) { return {false, iter}; } + + // Validate it isn't a whitespace or control unicode character. + if ((c >= 0xc280 && c <= 0xc2a0) || c == 0xe19a80 || (c >= 0xe28080 && c <= 0xe2808a) || + c == 0xe280af || c == 0xe280a8 || c == 0xe2819f || c == 0xe38080) { + return {false, iter}; + } + } else { + break; + } + ++iter; } - return false; + + return {true, iter}; } -/** - * @brief Count the number of characters of each string after parsing the protocol. - * - * @tparam num_warps_per_threadblock Number of warps in a threadblock. This template argument must - * match the launch configuration, i.e. the kernel must be launched with - * `num_warps_per_threadblock * cudf::detail::warp_size` threads per threadblock. - * @tparam char_block_size Number of characters which will be loaded into the shared memory at a - * time. - * - * @param in_strings Input string column - * @param out_counts Number of characters in each decode URL - * @param out_validity Bitmask of validity data, updated in funcion - */ -template -__global__ void parse_uri_protocol_char_counter(column_device_view const in_strings, - size_type* const out_counts, - bitmask_type* out_validity) +template +__device__ bool validate_chunk(string_view s, Predicate fn, bool allow_invalid_escapes = false) { - __shared__ char temporary_buffer[num_warps_per_threadblock][char_block_size]; - __shared__ typename cub::WarpScan::TempStorage cub_storage[num_warps_per_threadblock]; - __shared__ bool found_token[num_warps_per_threadblock]; - - auto const global_thread_id = cudf::detail::grid_1d::global_thread_id(); - auto const global_warp_id = static_cast(global_thread_id / cudf::detail::warp_size); - auto const local_warp_id = static_cast(threadIdx.x / cudf::detail::warp_size); - auto const warp_lane = static_cast(threadIdx.x % cudf::detail::warp_size); - auto const nwarps = static_cast(gridDim.x * blockDim.x / cudf::detail::warp_size); - char* in_chars_shared = temporary_buffer[local_warp_id]; - - // Loop through strings, and assign each string to a warp. - for (thread_index_type tidx = global_warp_id; tidx < in_strings.size(); tidx += nwarps) { - auto const row_idx = static_cast(tidx); - if (in_strings.is_null(row_idx)) { - if (warp_lane == 0) out_counts[row_idx] = 0; + auto iter = s.begin(); + { + auto [valid, iter_] = skip_and_validate_special(iter, s.end(), allow_invalid_escapes); + iter = std::move(iter_); + if (!valid) { return false; } + } + while (iter != s.end()) { + if (!fn(iter)) { return false; } + + iter++; + auto [valid, iter_] = skip_and_validate_special(iter, s.end(), allow_invalid_escapes); + iter = std::move(iter_); + if (!valid) { return false; } + } + return true; +} + +bool __device__ validate_scheme(string_view scheme) +{ + // A scheme simply needs to be an alpha character followed by alphanumeric + auto iter = scheme.begin(); + if (!is_alpha(*iter)) { return false; } + while (++iter != scheme.end()) { + auto const c = *iter; + if (!is_alphanum(c) && c != '+' && c != '-' && c != '.') { return false; } + } + return true; +} + +bool __device__ validate_ipv6(string_view s) +{ + constexpr auto max_colons{8}; + + if (s.size_bytes() < 2) { return false; } + + bool found_double_colon{false}; + int open_bracket_count{0}; + int close_bracket_count{0}; + int period_count{0}; + int colon_count{0}; + int percent_count{0}; + char previous_char{0}; + int address{0}; + int address_char_count{0}; + bool address_has_hex{false}; + + auto const leading_double_colon = [&]() { + auto iter = s.begin(); + if (*iter == '[') iter++; + return *iter++ == ':' && *iter == ':'; + }(); + + for (auto iter = s.begin(); iter < s.end(); ++iter) { + auto const c = *iter; + + switch (c) { + case '[': + open_bracket_count++; + if (open_bracket_count > 1) { return false; } + break; + case ']': + close_bracket_count++; + if (close_bracket_count > 1) { return false; } + if ((period_count > 0) && (address_has_hex || address > 255)) { return false; } + break; + case ':': + colon_count++; + if (previous_char == ':') { + if (found_double_colon) { return false; } + found_double_colon = true; + } + address = 0; + address_has_hex = false; + address_char_count = 0; + if (colon_count > max_colons || (colon_count == max_colons && !found_double_colon)) { + return false; + } + // Periods before a colon don't work, periods can be an IPv4 address after this IPv6 address + // like [1:2:3:4:5:6:d.d.d.d] + if (period_count > 0 || percent_count > 0) { return false; } + break; + case '.': + period_count++; + if (percent_count > 0) { return false; } + if (period_count > 3) { return false; } + if (address_has_hex) { return false; } + if (address > 255) { return false; } + if (colon_count != 6 && !found_double_colon) { return false; } + // Special case of ::1:2:3:4:5:d.d.d.d has 7 colons - but spark says this is invalid + // if (colon_count == max_colons && !leading_double_colon) { return false; } + if (colon_count >= max_colons) { return false; } + address = 0; + address_has_hex = false; + address_char_count = 0; + break; + case '%': + // IPv6 can define a device to use for the routing. This is expressed as '%eth0' at the end + // of the address. + percent_count++; + if (percent_count > 1) { return false; } + if ((period_count > 0) && (address_has_hex || address > 255)) { return false; } + address = 0; + address_has_hex = false; + address_char_count = 0; + break; + default: + // after % all bets are off as the device name can be nearly anything + if (percent_count == 0) { + if (address_char_count > 3) { return false; } + address_char_count++; + address *= 10; + if (c >= 'a' && c <= 'f') { + address += 10; + address += c - 'a'; + address_has_hex = true; + } else if (c >= 'A' && c <= 'Z') { + address += 10; + address += c - 'A'; + address_has_hex = true; + } else if (c >= '0' && c <= '9') { + address += c - '0'; + } else { + return false; + } + } + break; + } + previous_char = c; + } + + return true; +} + +bool __device__ validate_ipv4(string_view s) +{ + // dotted quad (0-255).(0-255).(0-255).(0-255) + int address = 0; + int address_char_count = 0; + int dot_count = 0; + for (auto iter = s.begin(); iter < s.end(); ++iter) { + auto const c = *iter; + + // can't lead with a . + if ((c < '0' || c > '9') && (iter == s.begin() || c != '.')) { return false; } + + if (c == '.') { + // verify we saw at least one character and reset values + if (address_char_count == 0) { return false; } + address = 0; + address_char_count = 0; + dot_count++; continue; } - auto const in_string = in_strings.element(row_idx); - auto const in_chars = in_string.data(); - auto const string_length = in_string.size_bytes(); - auto const nblocks = cudf::util::div_rounding_up_unsafe(string_length, char_block_size); - size_type output_string_size = 0; + address_char_count++; + address *= 10; + address += c - '0'; + + if (address > 255) { return false; } + } + + // can't end with a . + if (address_char_count == 0) { return false; } + + // must be 4 portions seperated by 3 dots. + if (dot_count != 3) { return false; } + + return true; +} + +bool __device__ validate_domain_name(string_view name) +{ + // domain name can be alphanum or -. + // slash can not be the first of last character of the domain name or around a . + bool last_was_slash = false; + bool last_was_period = false; + bool numeric_start = false; + for (auto iter = name.begin(); iter < name.end(); ++iter) { + auto const c = *iter; + if (!is_alphanum(c) && c != '-' && c != '.') { return false; } + + // the final section can't start with a digit + if (last_was_period && c >= '0' && c <= '9') { + numeric_start = true; + } else { + numeric_start = false; + } + + if (c == '-') { + if (last_was_period || iter == name.begin() || iter == --name.end()) { return false; } + last_was_slash = true; + last_was_period = false; + } else if (c == '.') { + if (last_was_slash) { return false; } + last_was_period = true; + last_was_slash = false; + } else { + last_was_period = false; + last_was_slash = false; + } + } + + // numeric start to last part of domain isn't allowed. + if (numeric_start) { return false; } + + return true; +} + +chunk_validity __device__ validate_host(string_view host) +{ + // This can be IPv4, IPv6, or a domain name. + if (*host.begin() == '[') { + // If last character is a ], this is IPv6 or invalid. + if (*(host.end() - 1) != ']') { + // invalid + return chunk_validity::FATAL; + } + if (!validate_ipv6(host)) { return chunk_validity::FATAL; } + + return chunk_validity::VALID; + } - // valid until proven otherwise - bool valid{true}; + // If there are more [ or ] characters this is invalid. + // Also need to find the last . + int last_open_bracket = -1; + int last_close_bracket = -1; + int last_period = -1; + + // The original plan on this loop was to get fancy and use a reverse iterator and exit when + // everything was found, but the expectation is there are no brackets in this string, so we have + // to traverse the entire thing anyway to verify that. The math is easier with a forward iterator, + // so we're back here. + for (auto iter = host.begin(); iter < host.end(); ++iter) { + auto const c = *iter; + if (c == '[') { + last_open_bracket = iter.position(); + } else if (c == ']') { + last_close_bracket = iter.position(); + } else if (c == '.') { + last_period = iter.position(); + } + } - // Use the last thread of the warp to initialize `found_token` to false. - if (warp_lane == cudf::detail::warp_size - 1) { found_token[local_warp_id] = false; } + if (last_open_bracket >= 0 || last_close_bracket >= 0) { return chunk_validity::FATAL; } - for (size_type block_idx = 0; block_idx < nblocks && valid; block_idx++) { - auto const string_length_block = - std::min(char_block_size, string_length - char_block_size * block_idx); + // If we didn't find a period or if the last character is a period or the character after the last + // period is non numeric + if (last_period < 0 || last_period == host.length() - 1 || host[last_period + 1] < '0' || + host[last_period + 1] > '9') { + // must be domain name or it is invalid + if (validate_domain_name(host)) { return chunk_validity::VALID; } - // Each warp collectively loads input characters of the current block to the shared memory. - for (auto char_idx = warp_lane; char_idx < string_length_block; - char_idx += cudf::detail::warp_size) { - auto const in_idx = block_idx * char_block_size + char_idx; - in_chars_shared[char_idx] = in_idx < string_length ? in_chars[in_idx] : 0; + // the only other option is that this is a IPv4 address + } else if (validate_ipv4(host)) { + return chunk_validity::VALID; + } + + return chunk_validity::INVALID; +} + +bool __device__ validate_query(string_view query) +{ + // query can be alphanum and _-!.~'()*,;:$&+=?/[]@" + return validate_chunk(query, [] __device__(string_view::const_iterator iter) { + auto const c = *iter; + if (c != '!' && c != '"' && c != '$' && !(c >= '&' && c <= ';') && c != '=' && + !(c >= '?' && c <= ']' && c != '\\') && !(c >= 'a' && c <= 'z') && c != '_' && c != '~') { + return false; + } + return true; + }); +} + +bool __device__ validate_authority(string_view authority, bool allow_invalid_escapes) +{ + // authority needs to be alphanum and @[]_-!.'()*,;:$&+= + return validate_chunk( + authority, + [allow_invalid_escapes] __device__(string_view::const_iterator iter) { + auto const c = *iter; + if (c != '!' && c != '$' && !(c >= '&' && c <= ';' && c != '/') && c != '=' && + !(c >= '@' && c <= '_' && c != '^' && c != '\\') && !(c >= 'a' && c <= 'z') && c != '~' && + (!allow_invalid_escapes || c != '%')) { + return false; } + return true; + }, + allow_invalid_escapes); +} - __syncwarp(); - - // `char_idx_start` represents the start character index of the current warp. - for (size_type char_idx_start = 0; char_idx_start < string_length_block; - char_idx_start += cudf::detail::warp_size) { - auto const char_idx = char_idx_start + warp_lane; - char const* const ch_ptr = in_chars_shared + char_idx; - - // need to know if the character we are validating is before or after the token - // as valid characters changes. Default to 1 to handle the case where we have - // alreayd found the token and do not search for it again. - int8_t out_tokens{1}; - if (!found_token[local_warp_id]) { - // Warp-wise prefix sum to establish tokens of string. - // All threads in the warp participate in the prefix sum, even if `char_idx` is beyond - // `string_length_block`. - int8_t const is_token = (char_idx < string_length_block && *ch_ptr == ':') ? 1 : 0; - cub::WarpScan(cub_storage[local_warp_id]).InclusiveSum(is_token, out_tokens); - } +bool __device__ validate_userinfo(string_view userinfo) +{ + // can't be ] or [ in here + return validate_chunk(userinfo, [] __device__(string_view::const_iterator iter) { + auto const c = *iter; + if (c == '[' || c == ']') { return false; } + return true; + }); +} - auto const before_token = out_tokens == 0; - valid = valid && __ballot_sync(0xffffffff, - (char_idx >= string_length_block || - is_valid_character(*ch_ptr, before_token)) - ? 0 - : 1) == 0; - if (!valid) { - // last thread in warp sets validity - if (warp_lane == cudf::detail::warp_size - 1) { - clear_bit(out_validity, row_idx); - out_counts[row_idx] = 0; - } +bool __device__ validate_port(string_view port) +{ + // port is positive numeric >=0 according to spark...shrug + return validate_chunk(port, [] __device__(string_view::const_iterator iter) { + auto const c = *iter; + if (c < '0' && c > '9') { return false; } + return true; + }); +} + +bool __device__ validate_path(string_view path) +{ + // path can be alphanum and @[]_-!.~'()*?/&,;:$+= + return validate_chunk(path, [] __device__(string_view::const_iterator iter) { + auto const c = *iter; + if (c != '!' && c != '$' && !(c >= '&' && c <= ';') && c != '=' && !(c >= '@' && c <= 'Z') && + c != '_' && !(c >= 'a' && c <= 'z') && c != '~') { + return false; + } + return true; + }); +} + +bool __device__ validate_opaque(string_view opaque) +{ + // opaque can be alphanum and @[]_-!.~'()*?/,;:$@+= + return validate_chunk(opaque, [] __device__(string_view::const_iterator iter) { + auto const c = *iter; + if (c != '!' && c != '$' && !(c >= '&' && c <= ';') && c != '=' && + !(c >= '?' && c <= ']' && c != '\\') && c != '_' && c != '~' && !(c >= 'a' && c <= 'z')) { + return false; + } + return true; + }); +} + +bool __device__ validate_fragment(string_view fragment) +{ + // fragment can be alphanum and @[]_-!.~'()*?/,;:$&+= + return validate_chunk(fragment, [] __device__(string_view::const_iterator iter) { + auto const c = *iter; + if (c != '!' && c != '$' && !(c >= '&' && c <= ';') && c != '=' && + !(c >= '?' && c <= ']' && c != '\\') && c != '_' && c != '~' && !(c >= 'a' && c <= 'z')) { + return false; + } + return true; + }); +} + +uri_parts __device__ validate_uri(const char* str, int len) +{ + uri_parts ret; + + // look for :/# characters. + int col = -1; + int slash = -1; + int hash = -1; + int question = -1; + for (const char* c = str; + c - str < len && (col == -1 || slash == -1 || hash == -1 || question == -1); + ++c) { + switch (*c) { + case ':': + if (col == -1) col = c - str; + break; + case '/': + if (slash == -1) slash = c - str; + break; + case '#': + if (hash == -1) hash = c - str; + break; + case '?': + if (question == -1) question = c - str; + break; + default: break; + } + } + + // anything after the hash is part of the fragment and ignored for this part + if (hash >= 0) { + ret.fragment = {str + hash + 1, len - hash - 1}; + if (!validate_fragment(ret.fragment)) { + ret.valid = false; + return ret; + } + + len = hash; + + if (col > hash) col = -1; + if (slash > hash) slash = -1; + if (question > hash) question = -1; + } + + // if the first ':' is after the other tokens, this doesn't have a scheme or it is invalid + if (col != -1 && (slash == -1 || col < slash) && (hash == -1 || col < hash)) { + // we have a scheme up to the : + ret.scheme = {str, col}; + if (!validate_scheme(ret.scheme)) { + ret.valid = false; + return ret; + } + + // skip over scheme + auto const skip = col + 1; + str += skip; + len -= skip; + question -= skip; + hash -= skip; + slash -= skip; + } + + // no more string to parse is an error + if (len <= 0) { + ret.valid = false; + return ret; + } + + // If we have a '/' as the next character, we have a heirarchical uri. If not it is opaque. + bool const heirarchical = str[0] == '/'; + if (heirarchical) { + // a '?' will break this into query and path/authority + if (question >= 0) { + ret.query = {str + question + 1, len - question - 1}; + if (!validate_query(ret.query)) { + ret.valid = false; + return ret; + } + } + auto const path_len = question >= 0 ? question : len; + + if (str[0] == '/' && str[1] == '/') { + // If we have a '/', we have //authority/path, otherwise we have //authority with no path. + int next_slash = -1; + for (int i = 2; i < path_len; ++i) { + if (str[i] == '/') { + next_slash = i; break; } + } + ret.authority = {&str[2], + next_slash == -1 ? question < 0 ? len - 2 : question - 2 : next_slash - 2}; + if (next_slash > 0) { ret.path = {str + next_slash, path_len - next_slash}; } + + if (next_slash == -1 && ret.authority.size_bytes() == 0 && ret.query.size_bytes() == 0 && + ret.fragment.size_bytes() == 0) { + // invalid! - but spark like to return things as long as you don't have illegal characters + // ret.valid = false; + ret.valid = true; + return ret; + } + + if (ret.authority.size_bytes() > 0) { + auto ipv6_address = ret.authority.size_bytes() > 2 && *ret.authority.begin() == '['; + if (!validate_authority(ret.authority, ipv6_address)) { + ret.valid = false; + return ret; + } - // if we have already found our token, no more string copy we only need to validate - // characters - if (!found_token[local_warp_id]) { - // If the current character is before the token we will output the character. - int8_t const out_size = (char_idx >= string_length_block || out_tokens > 0) ? 0 : 1; - - // Warp-wise prefix sum to establish output location of the current thread. - // All threads in the warp participate in the prefix sum, even if `char_idx` is beyond - // `string_length_block`. - int8_t out_offset; - cub::WarpScan(cub_storage[local_warp_id]).InclusiveSum(out_size, out_offset); - - // last thread of the warp updates offsets and token since it has the last offset and - // token value - if (warp_lane == cudf::detail::warp_size - 1) { - output_string_size += out_offset; - found_token[local_warp_id] = out_tokens > 0; + // Inspect the authority for userinfo, host, and port + const char* auth = ret.authority.data(); + auto auth_size = ret.authority.size_bytes(); + int amp = -1; + int closingbracket = -1; + int last_colon = -1; + for (int i = 0; i < auth_size; ++i) { + switch (auth[i]) { + case '@': + if (amp == -1) { + amp = i; + if (last_colon > 0) { last_colon = -1; } + if (closingbracket > 0) { closingbracket = -1; } + } + break; + case ':': last_colon = amp > 0 ? i - amp - 1 : i; break; + case ']': + if (closingbracket == -1) closingbracket = amp > 0 ? i - amp : i; + break; } } - __syncwarp(); - } - } + if (amp > 0) { + ret.userinfo = {auth, amp}; + if (!validate_userinfo(ret.userinfo)) { + ret.valid = false; + return ret; + } + // skip over the @ + amp++; - // last thread of the warp sets output size - if (warp_lane == cudf::detail::warp_size - 1) { - if (!found_token[local_warp_id]) { - clear_bit(out_validity, row_idx); - out_counts[row_idx] = 0; - } else if (valid) { - out_counts[row_idx] = output_string_size; + auth += amp; + auth_size -= amp; + } + if (last_colon > 0 && last_colon > closingbracket) { + // Found a port, attempt to parse it + ret.port = {auth + last_colon + 1, auth_size - last_colon - 1}; + if (!validate_port(ret.port)) { + ret.valid = false; + return ret; + } + ret.host = {auth, last_colon}; + } else { + ret.host = {auth, auth_size}; + } + auto host_ret = validate_host(ret.host); + switch (host_ret) { + case chunk_validity::FATAL: ret.valid = false; return ret; + case chunk_validity::INVALID: ret.host = {}; break; + } } + } else { + // path with no authority + ret.path = {str, len}; + } + if (!validate_path(ret.path)) { + ret.valid = false; + return ret; + } + } else { + ret.opaque = {str, len}; + if (!validate_opaque(ret.opaque)) { + ret.valid = false; + return ret; } } + + ret.valid = true; + return ret; } +// A URI is broken into parts or chunks. There are optional chunks and required chunks. A simple URI +// such as `https://www.nvidia.com` is easy to reason about, but it could also be written as +// `www.nvidia.com`, which is still valid. On top of that, there are characters which are allowed in +// certain chunks that are not allowed in others. There have been a multitude of methods attempted +// to get this correct, but at the end of the day, we have to validate the URI completely. This +// means even the simplest task of pulling off every character before the : still requires +// understanding how to validate an ipv6 address. This kernel was originally conceived as a two-pass +// kernel that ran the same code and either filled in offsets or filled in actual data. The problem +// is that to know what characters you need to copy, you need to have parsed the entire string as a +// 2 meg string could have `:/a` at the very end and everything up to that point is protocol or it +// could end in `.com` and now it is a hostname. To prevent the code from parsing it completely for +// length and then parsing it completely to copy the data, we will store off the offset of the +// string of question. The length is already stored in the offset column, so we then have a pointer +// and a number of bytes to copy and the second pass boils down to a series of memcpy calls. + /** - * @brief Parse protocol and copy from the input string column to the output char buffer. - * - * @tparam num_warps_per_threadblock Number of warps in a threadblock. This template argument must - * match the launch configuration, i.e. the kernel must be launched with - * `num_warps_per_threadblock * cudf::detail::warp_size` threads per threadblock. - * @tparam char_block_size Number of characters which will be loaded into the shared memory at a - * time. + * @brief Count the number of characters of each string after parsing the protocol. * * @param in_strings Input string column - * @param in_validity Validity vector of output column - * @param out_chars Character buffer for the output string column - * @param out_offsets Offset value of each string associated with `out_chars` + * @param chunk Chunk of URI to return + * @param out_lengths Number of characters in each decode URL + * @param out_offsets Offsets to the start of the chunks + * @param out_validity Bitmask of validity data, updated in function */ -template -__global__ void parse_uri_to_protocol(column_device_view const in_strings, - bitmask_type* in_validity, - char* const out_chars, - size_type const* const out_offsets) +__global__ void parse_uri_char_counter(column_device_view const in_strings, + URI_chunks chunk, + size_type* const out_lengths, + size_type* const out_offsets, + bitmask_type* out_validity) { - __shared__ char temporary_buffer[num_warps_per_threadblock][char_block_size]; - __shared__ typename cub::WarpScan::TempStorage cub_storage[num_warps_per_threadblock]; - __shared__ size_type out_idx[num_warps_per_threadblock]; - __shared__ bool found_token[num_warps_per_threadblock]; - - auto const global_thread_id = cudf::detail::grid_1d::global_thread_id(); - auto const global_warp_id = static_cast(global_thread_id / cudf::detail::warp_size); - auto const local_warp_id = static_cast(threadIdx.x / cudf::detail::warp_size); - auto const warp_lane = static_cast(threadIdx.x % cudf::detail::warp_size); - auto const nwarps = static_cast(gridDim.x * blockDim.x / cudf::detail::warp_size); - char* in_chars_shared = temporary_buffer[local_warp_id]; - - // Loop through strings, and assign each string to a warp - for (thread_index_type tidx = global_warp_id; tidx < in_strings.size(); tidx += nwarps) { + // thread per row + auto const tid = cudf::detail::grid_1d::global_thread_id(); + auto const base_ptr = in_strings.child(strings_column_view::chars_column_index).data(); + + for (thread_index_type tidx = tid; tidx < in_strings.size(); + tidx += cudf::detail::grid_1d::grid_stride()) { auto const row_idx = static_cast(tidx); - if (!bit_is_set(in_validity, row_idx)) { continue; } + if (in_strings.is_null(row_idx)) { + out_lengths[row_idx] = 0; + continue; + } auto const in_string = in_strings.element(row_idx); auto const in_chars = in_string.data(); auto const string_length = in_string.size_bytes(); - auto out_chars_string = out_chars + out_offsets[row_idx]; - auto const nblocks = cudf::util::div_rounding_up_unsafe(string_length, char_block_size); - - // Use the last thread of the warp to initialize `out_idx` to 0 and `found_token` to false. - if (warp_lane == cudf::detail::warp_size - 1) { - out_idx[local_warp_id] = 0; - found_token[local_warp_id] = false; - } - __syncwarp(); - - for (size_type block_idx = 0; block_idx < nblocks && !found_token[local_warp_id]; block_idx++) { - auto const string_length_block = - std::min(char_block_size, string_length - char_block_size * block_idx); + auto const uri = validate_uri(in_chars, string_length); + if (!uri.valid) { + out_lengths[row_idx] = 0; + clear_bit(out_validity, row_idx); + } else { + // stash output offsets and lengths for next kernel to do the copy + switch (chunk) { + case URI_chunks::PROTOCOL: + out_lengths[row_idx] = uri.scheme.size_bytes(); + out_offsets[row_idx] = uri.scheme.data() - base_ptr; + break; + case URI_chunks::HOST: + out_lengths[row_idx] = uri.host.size_bytes(); + out_offsets[row_idx] = uri.host.data() - base_ptr; + break; + case URI_chunks::AUTHORITY: + out_lengths[row_idx] = uri.authority.size_bytes(); + out_offsets[row_idx] = uri.authority.data() - base_ptr; + break; + case URI_chunks::PATH: + out_lengths[row_idx] = uri.path.size_bytes(); + out_offsets[row_idx] = uri.path.data() - base_ptr; + break; + case URI_chunks::QUERY: + out_lengths[row_idx] = uri.query.size_bytes(); + out_offsets[row_idx] = uri.query.data() - base_ptr; + break; + case URI_chunks::USERINFO: + out_lengths[row_idx] = uri.userinfo.size_bytes(); + out_offsets[row_idx] = uri.userinfo.data() - base_ptr; + break; + } - // Each warp collectively loads input characters of the current block to shared memory. - for (auto char_idx = warp_lane; char_idx < string_length_block; - char_idx += cudf::detail::warp_size) { - auto const in_idx = block_idx * char_block_size + char_idx; - in_chars_shared[char_idx] = in_idx >= 0 && in_idx < string_length ? in_chars[in_idx] : 0; + if (out_lengths[row_idx] == 0) { + // A URI can be valid, but still have no data for a specific chunk + clear_bit(out_validity, row_idx); } + } + } +} - __syncwarp(); - - // `char_idx_start` represents the start character index of the current warp. - for (size_type char_idx_start = 0; - char_idx_start < string_length_block && !found_token[local_warp_id]; - char_idx_start += cudf::detail::warp_size) { - auto const char_idx = char_idx_start + warp_lane; - char const* const ch_ptr = in_chars_shared + char_idx; - - // Warp-wise prefix sum to establish tokens of string. - // All threads in the warp participate in the prefix sum, even if `char_idx` is beyond - // `string_length_block`. - int8_t const is_token = (char_idx < string_length_block && *ch_ptr == ':') ? 1 : 0; - int8_t out_tokens; - cub::WarpScan(cub_storage[local_warp_id]).InclusiveSum(is_token, out_tokens); - - // If the current character is before the token we will output the character. - int8_t const out_size = (char_idx >= string_length_block || out_tokens > 0) ? 0 : 1; - - // Warp-wise prefix sum to establish output location of the current thread. - // All threads in the warp participate in the prefix sum, even if `char_idx` is beyond - // `string_length_block`. - int8_t out_offset; - cub::WarpScan(cub_storage[local_warp_id]).ExclusiveSum(out_size, out_offset); - - // out_size of 1 means this thread writes a byte - if (out_size == 1) { out_chars_string[out_idx[local_warp_id] + out_offset] = *ch_ptr; } - - // last thread of the warp updates the offset and the token - if (warp_lane == cudf::detail::warp_size - 1) { - out_idx[local_warp_id] += (out_offset + out_size); - found_token[local_warp_id] = out_tokens > 0; - } +/** + * @brief Parse protocol and copy from the input string column to the output char buffer. + * + * @param in_strings Input string column + * @param src_offsets Offset value of source strings in in_strings + * @param offsets Offset value of each string associated with `out_chars` + * @param out_chars Character buffer for the output string column + */ +__global__ void parse_uri(column_device_view const in_strings, + size_type const* const src_offsets, + size_type const* const offsets, + char* const out_chars) +{ + auto const tid = cudf::detail::grid_1d::global_thread_id(); + auto const base_ptr = in_strings.child(strings_column_view::chars_column_index).data(); + + for (thread_index_type tidx = tid; tidx < in_strings.size(); + tidx += cudf::detail::grid_1d::grid_stride()) { + auto const row_idx = static_cast(tidx); + auto const len = offsets[row_idx + 1] - offsets[row_idx]; - __syncwarp(); + if (len > 0) { + for (int i = 0; i < len; i++) { + out_chars[offsets[row_idx] + i] = base_ptr[src_offsets[row_idx] + i]; } } } @@ -286,16 +765,16 @@ __global__ void parse_uri_to_protocol(column_device_view const in_strings, } // namespace -std::unique_ptr parse_uri_to_protocol(strings_column_view const& input, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) +std::unique_ptr parse_uri(strings_column_view const& input, + URI_chunks chunk, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) { size_type strings_count = input.size(); - if (strings_count == 0) return make_empty_column(type_id::STRING); + if (strings_count == 0) { return make_empty_column(type_id::STRING); } constexpr size_type num_warps_per_threadblock = 4; constexpr size_type threadblock_size = num_warps_per_threadblock * cudf::detail::warp_size; - constexpr size_type char_block_size = 256; auto const num_threadblocks = std::min(65536, cudf::util::div_rounding_up_unsafe(strings_count, num_warps_per_threadblock)); @@ -306,6 +785,9 @@ std::unique_ptr parse_uri_to_protocol(strings_column_view const& input, auto offsets_column = make_numeric_column( data_type{type_to_id()}, offset_count, mask_state::UNALLOCATED, stream, mr); + // build src offsets buffer + auto src_offsets = rmm::device_uvector(strings_count, stream); + // copy null mask rmm::device_buffer null_mask = input.parent().nullable() @@ -315,11 +797,12 @@ std::unique_ptr parse_uri_to_protocol(strings_column_view const& input, // count number of bytes in each string after parsing and store it in offsets_column auto offsets_view = offsets_column->view(); auto offsets_mutable_view = offsets_column->mutable_view(); - parse_uri_protocol_char_counter - <<>>( - *d_strings, - offsets_mutable_view.begin(), - reinterpret_cast(null_mask.data())); + parse_uri_char_counter<<>>( + *d_strings, + chunk, + offsets_mutable_view.begin(), + reinterpret_cast(src_offsets.data()), + reinterpret_cast(null_mask.data())); // use scan to transform number of bytes into offsets thrust::exclusive_scan(rmm::exec_policy(stream), @@ -335,13 +818,12 @@ std::unique_ptr parse_uri_to_protocol(strings_column_view const& input, auto chars_column = cudf::strings::detail::create_chars_child_column(out_chars_bytes, stream, mr); auto d_out_chars = chars_column->mutable_view().data(); - // parse and copy the characters from the input column to the output column - parse_uri_to_protocol - <<>>( - *d_strings, - reinterpret_cast(null_mask.data()), - d_out_chars, - offsets_column->view().begin()); + // copy the characters from the input column to the output column + parse_uri<<>>( + *d_strings, + reinterpret_cast(src_offsets.data()), + offsets_column->view().begin(), + d_out_chars); auto null_count = cudf::null_count(reinterpret_cast(null_mask.data()), 0, strings_count); @@ -362,7 +844,7 @@ std::unique_ptr parse_uri_to_protocol(strings_column_view const& input, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::parse_uri_to_protocol(input, stream, mr); + return detail::parse_uri(input, detail::URI_chunks::PROTOCOL, stream, mr); } } // namespace spark_rapids_jni \ No newline at end of file diff --git a/src/main/cpp/tests/parse_uri.cpp b/src/main/cpp/tests/parse_uri.cpp index 3ff14a6075..6f522829b6 100644 --- a/src/main/cpp/tests/parse_uri.cpp +++ b/src/main/cpp/tests/parse_uri.cpp @@ -71,7 +71,33 @@ TEST_F(ParseURIProtocolTests, SparkEdges) "/absolute/path", "http://%77%77%77.%4EV%49%44%49%41.com", "https:://broken.url", - "https://www.nvidia.com/q/This%20is%20a%20query"}); + "https://www.nvidia.com/q/This%20is%20a%20query", + "https://www.nvidia.com/\x93path/path/to/file", + "http://?", + "http://??", + "http://\?\?/", + "http://#", + "http://user:pass@host/file;param?query;p2", + "http://[1:2:3:4:5:6:7::]", + "http://[::2:3:4:5:6:7:8]", + "http://[fe80::7:8%eth0]", + "http://[fe80::7:8%1]", + "http://foo.bar/abc/\\\\\\http://foo.bar/abc.gif\\\\\\", + "www.nvidia.com:8100/servlet/" + "impc.DisplayCredits?primekey_in=2000041100:05:14115240636", + "https://nvidia.com/2Ru15Ss ", + "http://www.nvidia.com/plugins//##", + "www.nvidia.com:81/Free.fr/L7D9qw9X4S-aC0&D4X0/Panels&solutionId=0X54a/" + "cCdyncharset=UTF-8&t=01wx58Tab&ps=solution/" + "ccmd=_help&locale0X1&countrycode=MA/", + "http://www.nvidia.com/tags.php?%2F88\323\351\300ึณวน\331\315\370%2F", + "http://www.nvidia.com//wp-admin/includes/index.html#9389#123", + "http://www.nvidia.com/" + "object.php?object=ะก-\320%9Fะฑ-ะฟ-ะก\321%82\321%80ะตะป\321%8Cะฝะฐ-\321%83ะป-\320%" + "97ะฐะฒะพะด\321%81ะบะฐ\321%8F.html&sid=5", + "http://www.nvidia.com/picshow.asp?id=106&mnid=5080&classname=\271\253ืฐฦช", + "http://-.~_!$&'()*+,;=:%40:80%2f::::::@example.com:443", + "http://userid:password@example.com:8080/"}); auto result = spark_rapids_jni::parse_uri_to_protocol(cudf::strings_column_view{col}); @@ -88,8 +114,87 @@ TEST_F(ParseURIProtocolTests, SparkEdges) "", "http", "https", - "https"}, - {1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1}); + "https", + "", + "http", + "http", + "http", + "http", + "http", + "http", + "http", + "http", + "http", + "", + "www.nvidia.com", + "", + "", + "www.nvidia.com", + "", + "", + "", + "", + "http", + "http"}, + {1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); -} \ No newline at end of file +} + +TEST_F(ParseURIProtocolTests, IP6) +{ + cudf::test::strings_column_wrapper col({ + "https://[fe80::]", + "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]", + "https://[2001:0DB8:85A3:0000:0000:8A2E:0370:7334]", + "https://[2001:db8::1:0]", + "http://[2001:db8::2:1]", + "https://[::1]", + "https://[2001:db8:85a3:8d3:1319:8a2e:370:7348]:443", + "https://[2001:db8:3333:4444:5555:6666:1.2.3.4]/path/to/file", + "https://[2001:db8:3333:4444:5555:6666:7777:8888:1.2.3.4]/path/to/file", + "https://[::db8:3333:4444:5555:6666:1.2.3.4]/path/to/file]", // this is valid, but spark + // doesn't think so + }); + auto result = spark_rapids_jni::parse_uri_to_protocol(cudf::strings_column_view{col}); + + cudf::test::strings_column_wrapper expected( + {"https", "https", "https", "https", "http", "https", "https", "https", "", ""}, + {1, 1, 1, 1, 1, 1, 1, 1, 0, 0}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); +} + +TEST_F(ParseURIProtocolTests, IP4) +{ + cudf::test::strings_column_wrapper col({ + "https://192.168.1.100/", + "https://192.168.1.100:8443/", + "https://192.168.1.100.5/", + "https://192.168.1/", + "https://280.100.1.1/", + "https://182.168..100/path/to/file", + }); + auto result = spark_rapids_jni::parse_uri_to_protocol(cudf::strings_column_view{col}); + + cudf::test::strings_column_wrapper expected( + {"https", "https", "https", "https", "https", "https"}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); +} + +TEST_F(ParseURIProtocolTests, UTF8) +{ + cudf::test::strings_column_wrapper col({ + "https://nvidia.com/%4EV%49%44%49%41", + "http://%77%77%77.%4EV%49%44%49%41.com", + "http://✪↩d⁚f„⁈.ws/123", + "https:// /path/to/file", + }); + auto result = spark_rapids_jni::parse_uri_to_protocol(cudf::strings_column_view{col}); + + cudf::test::strings_column_wrapper expected({"https", "http", "http", ""}, {1, 1, 1, 0}); + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(result->view(), expected); +} diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java index 7289d110b2..5e90111f21 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ParseURITest.java @@ -25,10 +25,33 @@ import ai.rapids.cudf.ColumnVector; public class ParseURITest { + void buildExpectedAndRun(String[] testData) { + String[] expectedStrings = new String[testData.length]; + for (int i=0; i