Added chat template support to llama-run #11215

Draft
wants to merge 1 commit into base: master
190 changes: 157 additions & 33 deletions examples/run/run.cpp
@@ -1,6 +1,6 @@
#if defined(_WIN32)
# include <windows.h>
# include <io.h>
# include <windows.h>
#else
# include <sys/file.h>
# include <sys/ioctl.h>
@@ -12,12 +12,14 @@
#endif

#include <signal.h>
#include <sys/stat.h>

#include <climits>
#include <cstdarg>
#include <cstdio>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
@@ -35,13 +37,14 @@
#endif

GGML_ATTRIBUTE_FORMAT(1, 2)

static std::string fmt(const char * fmt, ...) {
va_list ap;
va_list ap2;
va_start(ap, fmt);
va_copy(ap2, ap);
const int size = vsnprintf(NULL, 0, fmt, ap);
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
std::string buf;
buf.resize(size);
const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
@@ -53,6 +56,7 @@ static std::string fmt(const char * fmt, ...) {
}

GGML_ATTRIBUTE_FORMAT(1, 2)

static int printe(const char * fmt, ...) {
va_list args;
va_start(args, fmt);
@@ -101,7 +105,8 @@ class Opt {

llama_context_params ctx_params;
llama_model_params model_params;
std::string model_;
std::string chat_template_;
std::string user;
int context_size = -1, ngl = -1;
float temperature = -1;
@@ -137,7 +142,7 @@ class Opt {
}

int parse(int argc, const char ** argv) {
bool options_parsing = true;
for (int i = 1, positional_args_i = 0; i < argc; ++i) {
if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
if (handle_option_with_value(argc, argv, i, context_size) == 1) {
@@ -166,6 +171,11 @@

++positional_args_i;
model_ = argv[i];
} else if (options_parsing && strcmp(argv[i], "--chat-template") == 0) {
if (i + 1 >= argc) {
return 1;
}
chat_template_ = argv[++i];
} else if (positional_args_i == 1) {
++positional_args_i;
user = argv[i];
@@ -475,7 +485,9 @@ class HttpClient {
return (now_downloaded_plus_file_size * 100) / total_to_download;
}

static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast<long int>(percentage)); }
static std::string generate_progress_prefix(curl_off_t percentage) {
return fmt("%3ld%% |", static_cast<long int>(percentage));
}

static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
const auto now = std::chrono::steady_clock::now();
@@ -515,6 +527,7 @@ class HttpClient {
printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
progress_suffix.c_str());
}

// Function to write data to a file
static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
FILE * out = static_cast<FILE *>(stream);
@@ -538,19 +551,23 @@ class LlamaData {
std::vector<llama_chat_message> messages;
std::vector<std::string> msg_strs;
std::vector<char> fmtted;
std::string chat_template;

int init(Opt & opt) {
model = initialize_model(opt);
if (!model) {
return 1;
}

chat_template = initialize_chat_template(model, opt);

context = initialize_context(model, opt);
if (!context) {
return 1;
}

sampler = initialize_sampler(opt);

return 0;
}

@@ -573,21 +590,74 @@ class LlamaData {
}
#endif

int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
int huggingface_dl_tmpl(const std::string & hfr, const std::vector<std::string> headers, const std::string & tn) {
Collaborator: In most cases, the chat template is already stored inside the GGUF file and can be accessed using common_get_builtin_chat_template.
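A minimal sketch of the lookup being suggested (the helper name pick_chat_template and its fallback argument are illustrative, not part of this PR; it assumes common.h, which provides common_get_builtin_chat_template, is available to run.cpp):

    // Sketch: prefer the chat template embedded in the GGUF metadata; fall back to a
    // caller-supplied template (e.g. from --chat-template) when the model has none.
    static std::string pick_chat_template(const llama_model_ptr & model, const std::string & fallback) {
        const std::string builtin = common_get_builtin_chat_template(model.get());
        return builtin.empty() ? fallback : builtin;
    }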

Author: Thanks for the hint! I added this to the initialize_chat_template function, where the GGUF file is inspected for the chat template, as well as the separate chat template file.

Collaborator: Yeah, what we are trying to do here really only affects the Ollama case, where the template layer Ollama downloads seems to take precedence over the one inside the GGUF file.

ericcurtin (Collaborator), Jan 13, 2025: But maybe it's just my ignorance; maybe Hugging Face has similar functionality to Ollama, you discovered it, and I was just unaware :)

Collaborator: IMO this huggingface_dl_tmpl function is redundant, because convert_hf_to_gguf.py always copies the Jinja template from tokenizer_config.json into the GGUF.

And by the way, not sure if it's related, but for HF the application/vnd.ollama.image.template inside the manifest is also a Go template, not Jinja. We do have a compatibility layer to convert Jinja --> Go so that it can work on Ollama. This tool lets you debug what's inside: https://huggingface.co/spaces/ngxson/debug_ollama_manifest
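To make the mismatch concrete, here is a hypothetical side-by-side of the same chat turn in the two dialects (the tags and field names are made up, not taken from any real model). As of this PR's base, llama.cpp's llama_chat_apply_template matches known Jinja-style templates rather than executing arbitrary ones, so a Go template taken from an Ollama manifest would generally not be usable as-is:

    // Illustration only (hypothetical templates):
    // a Jinja-style chat template, as stored in GGUF metadata...
    const char * jinja_tmpl = "{% for m in messages %}<|{{ m['role'] }}|>\n{{ m['content'] }}\n{% endfor %}";
    // ...versus a Go text/template, as carried in an Ollama application/vnd.ollama.image.template layer.
    const char * go_tmpl = "{{ if .System }}<|system|>\n{{ .System }}\n{{ end }}<|user|>\n{{ .Prompt }}\n";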

Collaborator: I think any model that works with Ollama should be able to work with llama-run. There are gaps in functionality; let's fill those gaps. Maybe the model isn't exactly well-formed, etc. But Ollama can still run this model at the end of the day, so there's no reason llama-run can't do the same.

ngxson (Collaborator), Jan 13, 2025:

> I think any model that works with Ollama should be able to work with llama-run

I doubt this, as Ollama has its own fork (and not a clone) of llama.cpp, so they can implement changes that are not compatible with upstream llama.cpp (an example being phi-4, as explained above).

> There are gaps in functionality; let's fill those gaps. Maybe the model isn't exactly well-formed, etc.

To be completely honest, you can achieve that goal faster and more easily by turning llama-run into a full prod-ready product instead of letting it stay as an example. To properly support Jinja or Go there are third-party libraries that can handle it; the reason we don't have one here in llama.cpp is that they are too big.

> But Ollama can still run this model at the end of the day, so there's no reason llama-run can't do the same.

True, but don't forget that a big part of Ollama (i.e. the template part) is built in Go, not C++.

Collaborator: I hope that things like llama-server and llama-run do eventually become full prod-ready products.

slaren (Collaborator), Jan 13, 2025: Agree. Maybe we should separate the examples that are meant to show how to use llama.cpp from programs like llama-cli, llama-server, and llama-run, and possibly others like llama-perplexity and llama-bench, that are intended to be used by everybody and are not particularly useful as examples. We should only distribute these programs in the binary distributions, rather than the entire set of examples and tests. They could be moved to a tools directory.

ericcurtin (Collaborator), Jan 13, 2025: Part of me wishes we had a llama-client that interacted with llama-server (with a CLI interface like llama-run). Although rather than C++, Python possibly makes more sense, as one can use the openai client library.

if (std::filesystem::exists(tn)) {
return 0;
}

const std::string config_url = "https://huggingface.co/" + hfr + "/resolve/main/tokenizer_config.json";
std::string tokenizer_config_str;
download(config_url, headers, "", true, &tokenizer_config_str);
if (tokenizer_config_str.empty()) {
// still return success since tokenizer_config is optional
return 0;
}

nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
std::string tmpl = config["chat_template"];

FILE * tmpl_file = fopen(tn.c_str(), "w");
if (tmpl_file == NULL) {
return 1;
}
fprintf(tmpl_file, "%s", tmpl.c_str());
fclose(tmpl_file);

return 0;
}

int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn,
const std::string & tn) {
bool model_exists = std::filesystem::exists(bn);
bool chat_tmpl_exists = std::filesystem::exists(tn);
if (model_exists && chat_tmpl_exists) {
return 0;
}

// Find the second occurrence of '/' after protocol string
size_t pos = model.find('/');
pos = model.find('/', pos + 1);
if (pos == std::string::npos) {
return 1;
}

const std::string hfr = model.substr(0, pos);
const std::string hff = model.substr(pos + 1);
const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
return download(url, headers, bn, true);

if (!chat_tmpl_exists) {
const int ret = huggingface_dl_tmpl(hfr, headers, tn);
if (ret) {
return ret;
}
}

if (!model_exists) {
const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
const int ret = download(url, headers, bn, true);
if (ret) {
return ret;
}
}
return 0;
}

int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn,
const std::string & tn) {
bool model_exists = std::filesystem::exists(bn);
bool chat_tmpl_exists = std::filesystem::exists(tn);
if (model_exists && chat_tmpl_exists) {
return 0;
}

if (model.find('/') == std::string::npos) {
model = "library/" + model;
}
@@ -607,16 +677,34 @@ class LlamaData {
}

nlohmann::json manifest = nlohmann::json::parse(manifest_str);
std::string layer;
std::string sha_model;
std::string sha_template;
Collaborator: IMO this is a bad approach, because Ollama uses Go templates, not Jinja.

If you want to add this, be aware that most templates won't be detected.
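One possible guard, sketched under the assumption that llama_chat_apply_template (called with the same six arguments further down in this diff) returns a negative value for templates it does not recognize; the helper name is hypothetical and not part of the PR:

    // Sketch: probe whether llama.cpp recognizes a template string at all; a negative
    // return value means it was not recognized (the expected outcome for most Go-style
    // Ollama templates), so the caller could fall back to the GGUF's built-in template.
    static bool chat_template_is_supported(const std::string & tmpl) {
        const llama_chat_message probe{ "user", "hi" };
        return llama_chat_apply_template(tmpl.c_str(), &probe, 1, true, nullptr, 0) >= 0;
    }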

Author: I didn't know about this. Considering that, I agree this approach doesn't seem good. What do you think? @ericcurtin

Collaborator: I'm not an expert on this, but whatever works to make "vnd.ollama.image.template" compatible with llama-run.

Collaborator: You know the magic model to test this stuff: granite-code :)

Collaborator: No shame in reading the Ollama code either to figure out how it gets passed to llama.cpp.

for (const auto & l : manifest["layers"]) {
if (l["mediaType"] == "application/vnd.ollama.image.model") {
layer = l["digest"];
break;
sha_model = l["digest"];
}
if (l["mediaType"] == "application/vnd.ollama.image.template") {
sha_template = l["digest"];
}
}

if (!chat_tmpl_exists && !sha_template.empty()) {
std::string tmpl_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_template;
const int tmpl_ret = download(tmpl_blob_url, headers, tn, true);
if (tmpl_ret) {
return tmpl_ret;
}
}

if (!model_exists) {
std::string model_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_model;
const int model_ret = download(model_blob_url, headers, bn, true);
if (model_ret) {
return model_ret;
}
}

std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
return download(blob_url, headers, bn, true);
return 0;
}

std::string basename(const std::string & path) {
@@ -628,6 +716,15 @@ class LlamaData {
return path.substr(pos + 1);
}

std::string get_proto(const std::string & model_) {
const std::string::size_type pos = model_.find("://");
if (pos == std::string::npos) {
return "";
}

return model_.substr(0, pos + 3); // Include "://"
}

int remove_proto(std::string & model_) {
const std::string::size_type pos = model_.find("://");
if (pos == std::string::npos) {
@@ -638,38 +735,40 @@ class LlamaData {
return 0;
}

int resolve_model(std::string & model_) {
int ret = 0;
if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
int resolve_model(std::string & model_, std::string & chat_template_) {
int ret = 0;
if (string_starts_with(model_, "file://")) {
ericcurtin (Collaborator), Jan 13, 2025: Could we re-add "|| std::filesystem::exists(model_)"?

The logic is supposed to be: if the file exists, use it; otherwise, if we don't have a protocol specified, assume we need to pull from Ollama.
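That precedence could be sketched as a small standalone helper (hypothetical, not part of the PR; it assumes <string>, <filesystem>, and string_starts_with from common, all already used in this file):

    // Sketch of the precedence described above: an existing local file wins, an explicit
    // scheme is honored, and a bare name falls back to the Ollama registry.
    enum class model_source { local_file, explicit_scheme, ollama_registry };

    static model_source classify_model_arg(const std::string & model_) {
        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
            return model_source::local_file;       // existing local file: use as-is
        }
        if (model_.find("://") != std::string::npos) {
            return model_source::explicit_scheme;  // hf://, ollama://, https://, ...
        }
        return model_source::ollama_registry;      // no scheme: assume an Ollama model name
    }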

Author: I moved this into the download functions so that the model and template files are checked individually. If the model or template file is already there, we skip it.

remove_proto(model_);

return ret;
}

std::string proto = get_proto(model_);
remove_proto(model_);

const std::string bn = basename(model_);
const std::string tn = chat_template_.empty() ? bn + ".template" : chat_template_;
const std::vector<std::string> headers = { "--header",
"Accept: application/vnd.docker.distribution.manifest.v2+json" };
if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
remove_proto(model_);
ret = huggingface_dl(model_, headers, bn);
} else if (string_starts_with(model_, "ollama://")) {
remove_proto(model_);
ret = ollama_dl(model_, headers, bn);
} else if (string_starts_with(model_, "https://")) {
if (string_starts_with(proto, "hf://") || string_starts_with(proto, "huggingface://")) {
ret = huggingface_dl(model_, headers, bn, tn);
} else if (string_starts_with(proto, "ollama://")) {
ret = ollama_dl(model_, headers, bn, tn);
} else if (string_starts_with(proto, "https://")) {
download(model_, headers, bn, true);
} else {
ret = ollama_dl(model_, headers, bn);
ret = ollama_dl(model_, headers, bn, tn);
}

model_ = bn;
chat_template_ = tn;

return ret;
}

// Initializes the model and returns a unique pointer to it
llama_model_ptr initialize_model(Opt & opt) {
ggml_backend_load_all();
resolve_model(opt.model_);
resolve_model(opt.model_, opt.chat_template_);
printe(
"\r%*s"
"\rLoading model",
@@ -702,6 +801,31 @@ class LlamaData {

return sampler;
}

std::string initialize_chat_template(const llama_model_ptr & model, const Opt & opt) {
if (!std::filesystem::exists(opt.chat_template_)) {
return common_get_builtin_chat_template(model.get());
}

FILE * tmpl_file = ggml_fopen(opt.chat_template_.c_str(), "r");
if (!tmpl_file) {
std::cerr << "Error opening file '" << opt.chat_template_ << "': " << strerror(errno) << "\n";
return "";
}

fseek(tmpl_file, 0, SEEK_END);
size_t size = ftell(tmpl_file);
fseek(tmpl_file, 0, SEEK_SET);

std::vector<unsigned char> data(size);
size_t read_size = fread(data.data(), 1, size, tmpl_file);
fclose(tmpl_file);
if (read_size != size) {
std::cerr << "Error reading file '" << opt.chat_template_ << "': " << strerror(errno) << "\n";
return "";
}
return std::string(data.begin(), data.end());
}
};

// Add a message to `messages` and store its content in `msg_strs`
@@ -713,11 +837,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(LlamaData & llama_data, const bool append) {
int result = llama_chat_apply_template(
llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
llama_data.chat_template.c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result);
result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
result = llama_chat_apply_template(llama_data.chat_template.c_str(), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size());
}
@@ -730,8 +854,8 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
std::vector<llama_token> & prompt_tokens) {
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
true) < 0) {
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
0) {
printe("failed to tokenize the prompt\n");
return -1;
}