diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index 0ad8bb15b27fb..465ee3ce020e3 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -1,6 +1,6 @@
 #if defined(_WIN32)
-#    include <windows.h>
 #    include <io.h>
+#    include <windows.h>
 #else
 #    include <sys/file.h>
 #    include <sys/ioctl.h>
@@ -12,12 +12,14 @@
 #endif
 
 #include <signal.h>
+#include <sys/stat.h>
 
 #include <climits>
 #include <cstdarg>
 #include <cstdio>
 #include <cstring>
 #include <filesystem>
+#include <fstream>
 #include <iostream>
 #include <sstream>
 #include <string>
@@ -35,13 +37,14 @@
 #endif
 
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static std::string fmt(const char * fmt, ...) {
     va_list ap;
     va_list ap2;
     va_start(ap, fmt);
     va_copy(ap2, ap);
     const int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX);  // NOLINT
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
     std::string buf;
     buf.resize(size);
     const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
@@ -53,6 +56,7 @@ static std::string fmt(const char * fmt, ...) {
 }
 
 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static int printe(const char * fmt, ...) {
     va_list args;
     va_start(args, fmt);
@@ -101,7 +105,8 @@ class Opt {
     llama_context_params ctx_params;
     llama_model_params model_params;
-    std::string model_;
+    std::string model_;
+    std::string chat_template_;
     std::string user;
     int context_size = -1, ngl = -1;
     float temperature = -1;
 
@@ -137,7 +142,7 @@ class Opt {
     int parse(int argc, const char ** argv) {
-        bool options_parsing = true;
+        bool options_parsing   = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
                 if (handle_option_with_value(argc, argv, i, context_size) == 1) {
                     return 1;
                 }
@@ -166,6 +171,11 @@ class Opt {
                 ++positional_args_i;
                 model_ = argv[i];
+            } else if (options_parsing && strcmp(argv[i], "--chat-template") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+                chat_template_ = argv[++i];
             } else if (positional_args_i == 1) {
                 ++positional_args_i;
                 user = argv[i];
             } else {
@@ -475,7 +485,9 @@ class HttpClient {
         return (now_downloaded_plus_file_size * 100) / total_to_download;
     }
 
-    static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast<long int>(percentage)); }
+    static std::string generate_progress_prefix(curl_off_t percentage) {
+        return fmt("%3ld%% |", static_cast<long int>(percentage));
+    }
 
     static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
         const auto now = std::chrono::steady_clock::now();
@@ -515,6 +527,7 @@ class HttpClient {
         printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
                progress_suffix.c_str());
     }
 
+    // Function to write data to a file
     static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
         FILE * out = static_cast<FILE *>(stream);
@@ -538,6 +551,7 @@ class LlamaData {
     std::vector<llama_chat_message> messages;
     std::vector<std::string> msg_strs;
     std::vector<char> fmtted;
+    std::string chat_template;
 
     int init(Opt & opt) {
         model = initialize_model(opt);
@@ -545,12 +559,15 @@ class LlamaData {
             return 1;
         }
 
+        chat_template = initialize_chat_template(model, opt);
+
         context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }
 
         sampler = initialize_sampler(opt);
+
         return 0;
     }
 
@@ -573,21 +590,76 @@ class LlamaData {
     }
 #endif
 
-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int huggingface_dl_tmpl(const std::string & hfr, const std::vector<std::string> headers, const std::string & tn) {
+        // if template already exists, don't download it
+        struct stat info;
+        if (stat(tn.c_str(), &info) == 0) {
+            return 0;
+        }
+
+        const std::string config_url = "https://huggingface.co/" + hfr + "/resolve/main/tokenizer_config.json";
+        std::string tokenizer_config_str;
+        download(config_url, headers, "", true, &tokenizer_config_str);
+        if (tokenizer_config_str.empty()) {
+            // still return success since tokenizer_config is optional
+            return 0;
+        }
+
+        nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
+        std::string tmpl = config["chat_template"];
+
+        FILE * tmpl_file = fopen(tn.c_str(), "w");
+        if (tmpl_file == NULL) {
+            return 1;
+        }
+        fprintf(tmpl_file, "%s", tmpl.c_str());
+        fclose(tmpl_file);
+
+        return 0;
+    }
+
+    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                       const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos = model.find('/', pos + 1);
         if (pos == std::string::npos) {
             return 1;
         }
-
         const std::string hfr = model.substr(0, pos);
         const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+
+        if (!chat_tmpl_exists) {
+            const int ret = huggingface_dl_tmpl(hfr, headers, tn);
+            if (ret) {
+                return ret;
+            }
+        }
+
+        if (!model_exists) {
+            const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+            const int ret = download(url, headers, bn, true);
+            if (ret) {
+                return ret;
+            }
+        }
+        return 0;
     }
 
-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                  const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
@@ -607,16 +679,34 @@ class LlamaData {
         }
 
         nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        std::string sha_model;
+        std::string sha_template;
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
-                layer = l["digest"];
-                break;
+                sha_model = l["digest"];
+            }
+            if (l["mediaType"] == "application/vnd.ollama.image.template") {
+                sha_template = l["digest"];
+            }
+        }
+
+        if (!chat_tmpl_exists && !sha_template.empty()) {
+            std::string tmpl_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_template;
+            const int tmpl_ret = download(tmpl_blob_url, headers, tn, true);
+            if (tmpl_ret) {
+                return tmpl_ret;
+            }
+        }
+
+        if (!model_exists) {
+            std::string model_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_model;
+            const int model_ret = download(model_blob_url, headers, bn, true);
+            if (model_ret) {
+                return model_ret;
             }
         }
 
-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        return 0;
     }
 
     std::string basename(const std::string & path) {
@@ -628,6 +718,15 @@ class LlamaData {
         return path.substr(pos + 1);
     }
 
+    std::string get_proto(const std::string & model_) {
+        const std::string::size_type pos = model_.find("://");
+        if (pos == std::string::npos) {
+            return "";
+        }
+
+        return model_.substr(0, pos + 3);  // Include "://"
+    }
+
     int remove_proto(std::string & model_) {
         const std::string::size_type pos = model_.find("://");
         if (pos == std::string::npos) {
@@ -638,30 +737,32 @@ class LlamaData {
         return 0;
     }
 
-    int resolve_model(std::string & model_) {
-        int ret = 0;
-        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
+    int resolve_model(std::string & model_, std::string & chat_template_) {
+        int ret = 0;
+        if (string_starts_with(model_, "file://")) {
             remove_proto(model_);
-            return ret;
         }
 
+        std::string proto = get_proto(model_);
+        remove_proto(model_);
+
         const std::string bn = basename(model_);
+        const std::string tn = chat_template_.empty() ? bn + ".template" : chat_template_;
         const std::vector<std::string> headers = { "--header",
                                                    "Accept: application/vnd.docker.distribution.manifest.v2+json" };
-        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            remove_proto(model_);
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
-            remove_proto(model_);
-            ret = ollama_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "https://")) {
+        if (string_starts_with(proto, "hf://") || string_starts_with(proto, "huggingface://")) {
+            ret = huggingface_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "ollama://")) {
+            ret = ollama_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "https://")) {
            download(model_, headers, bn, true);
         } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
         }
 
-        model_ = bn;
+        model_ = bn;
+        chat_template_ = tn;
         return ret;
     }
@@ -669,7 +770,7 @@ class LlamaData {
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        resolve_model(opt.model_);
+        resolve_model(opt.model_, opt.chat_template_);
         printe(
             "\r%*s"
             "\rLoading model",
@@ -702,6 +803,27 @@ class LlamaData {
 
         return sampler;
     }
+
+    std::string initialize_chat_template(const llama_model_ptr & model, const Opt & opt) {
+        // if the template file doesn't exist, fall back to the model's built-in chat template
+        struct stat info;
+        if (stat(opt.chat_template_.c_str(), &info) != 0) {
+            return common_get_builtin_chat_template(model.get());
+        }
+
+        std::ifstream tmpl_file;
+        tmpl_file.open(opt.chat_template_);
+        if (tmpl_file.fail()) {
+            printe("failed to open chat template: '%s'\n", opt.chat_template_.c_str());
+            return "";
+        }
+
+        std::stringstream stream;
+        stream << tmpl_file.rdbuf();
+        tmpl_file.close();
+
+        return stream.str();
+    }
 };
 
 // Add a message to `messages` and store its content in `msg_strs`
@@ -713,11 +835,11 @@ static void add_message(const char * role, const std::string & text, LlamaData
 
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_data.chat_template.c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_data.chat_template.c_str(), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(), llama_data.fmtted.size());
     }
@@ -730,8 +852,8 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
                            std::vector<llama_token> & prompt_tokens) {
     const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
     prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
-        true) < 0) {
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
+        0) {
         printe("failed to tokenize the prompt\n");
         return -1;
     }
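
Note on the apply_chat_template change above: llama_chat_apply_template returns the number of bytes the rendered prompt needs, which can exceed the buffer passed in, so the patch calls it once, resizes fmtted, and renders again. A minimal sketch of that size-then-fill pattern, assuming llama.h and a populated message vector (the helper name render_prompt is illustrative, not from the patch):

    // Sketch: size-then-fill usage of llama_chat_apply_template.
    // Assumes "llama.h" and messages whose content pointers remain valid.
    #include <string>
    #include <vector>
    #include "llama.h"

    static std::string render_prompt(const std::string & tmpl, const std::vector<llama_chat_message> & msgs) {
        std::vector<char> buf(1024);
        // First call: returns the required size, which may be larger than buf.
        int32_t n = llama_chat_apply_template(tmpl.c_str(), msgs.data(), msgs.size(),
                                              /*add_ass=*/true, buf.data(), (int32_t) buf.size());
        if (n > (int32_t) buf.size()) {
            buf.resize(n);  // grow to the reported size and render again
            n = llama_chat_apply_template(tmpl.c_str(), msgs.data(), msgs.size(),
                                          /*add_ass=*/true, buf.data(), (int32_t) buf.size());
        }
        if (n < 0) {
            return "";  // template could not be applied
        }
        return std::string(buf.data(), n);
    }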
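
A second note, on huggingface_dl_tmpl: the line std::string tmpl = config["chat_template"]; throws if the key is missing from tokenizer_config.json, and some configs publish chat_template as an array of named templates rather than a single string. A more defensive extraction could look like this sketch (hypothetical helper, not part of the patch):

    // Sketch: guarded lookup of "chat_template" in tokenizer_config.json.
    #include <string>
    #include <nlohmann/json.hpp>

    static std::string extract_chat_template(const std::string & tokenizer_config_str) {
        // Parse without exceptions; a failed parse yields a "discarded" value.
        nlohmann::json config = nlohmann::json::parse(tokenizer_config_str, nullptr, /*allow_exceptions=*/false);
        if (config.is_discarded() || !config.contains("chat_template")) {
            return "";  // malformed JSON, or the model ships no template
        }
        const nlohmann::json & tmpl = config["chat_template"];
        // Only accept the plain-string form; anything else falls through.
        return tmpl.is_string() ? tmpl.get<std::string>() : "";
    }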