#include "ggml.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iterator>
#include <map>
#include <random>
#include <regex>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

// default hparams (GPT-2 774M)
//
// The model dimensions (n_vocab, n_ctx, n_embd, n_head, n_layer, ftype) are
// overwritten by gpt_model_load() with the values stored in the model file.
// NOTE(review): the default dimensions below look like the GPT-2 medium
// configuration rather than 774M - confirm which one is intended.
struct gpt_hparams {
    int32_t n_vocab = 50257; // Vocabulary size remains the same
    int32_t n_embd  = 1024;  // Embedding dimensionality
    int32_t n_head  = 16;    // Number of attention heads
    int32_t n_layer = 24;    // Number of transformer layers
    int32_t ftype   = 1;     // Set to 1 for FP16 precision (optional)
    float   eps     = 1e-5f; // Small constant for numerical stability

    int32_t seed         = -1;  // RNG seed
    int32_t n_threads    = std::min(4, (int32_t) std::thread::hardware_concurrency());
    int32_t n_predict    = 200; // new tokens to predict
    int32_t n_parallel   = 1;   // number of parallel streams
    int32_t n_batch      = 32;  // batch size for prompt processing
    int32_t n_ctx        = 1024; // context size (this is the KV cache max size)
    int32_t n_gpu_layers = 0;   // number of layers to offload to the GPU

    bool ignore_eos = false; // ignore EOS token when generating text

    // sampling parameters
    int32_t top_k          = 40;
    float   top_p          = 0.9f;
    float   temp           = 0.9f;
    int32_t repeat_last_n  = 64;
    float   repeat_penalty = 1.00f;

    std::string model      = "ggml-model-gpt-2-774M.bin"; // model path
    std::string prompt     = "";
    std::string token_test = "";

    bool    interactive      = false;
    int32_t interactive_port = -1;
};

// Bidirectional token <-> id mapping plus an optional list of special tokens
// that gpt_tokenize() matches verbatim before regular word splitting.
struct gpt_vocab {
    using id    = int32_t;
    using token = std::string;

    std::map<token, id> token_to_id;
    std::map<id, token> id_to_token;
    std::vector<std::string> special_tokens;

    void add_special_token(const std::string & token);
};

// Weights of a single transformer block.
struct gpt_layer {
    // normalization
    struct ggml_tensor * ln_1_g;
    struct ggml_tensor * ln_1_b;

    struct ggml_tensor * ln_2_g;
    struct ggml_tensor * ln_2_b;

    // attention
    struct ggml_tensor * c_attn_attn_w;
    struct ggml_tensor * c_attn_attn_b;

    struct ggml_tensor * c_attn_proj_w;
    struct ggml_tensor * c_attn_proj_b;

    // mlp
    struct ggml_tensor * c_mlp_fc_w;
    struct ggml_tensor * c_mlp_fc_b;

    struct ggml_tensor * c_mlp_proj_w;
    struct ggml_tensor * c_mlp_proj_b;
};

// All model weights, the KV cache and the ggml context that owns them.
struct gpt_model {
    gpt_hparams hparams;

    // normalization
    struct ggml_tensor * ln_f_g;
    struct ggml_tensor * ln_f_b;

    struct ggml_tensor * wte;     // token embedding (indexed by token id in gpt_eval)
    struct ggml_tensor * wpe;     // position embedding (indexed by position in gpt_eval)
    struct ggml_tensor * lm_head; // language model head

    std::vector<gpt_layer> layers;

    // key + value memory
    struct ggml_tensor * memory_k;
    struct ggml_tensor * memory_v;

    // context that owns every tensor above
    struct ggml_context * ctx_w;

    // lookup of the tensors above by the name used in the model file
    std::map<std::string, struct ggml_tensor *> tensors;
};

// load the model's weights from a file
//
// - fname: path of the ggml model file
// - model: receives the hyper-parameters, the weight tensors and the KV cache
// - vocab: receives the token <-> id tables stored in the file
//
// Returns false (after printing a diagnostic to stderr) when the file cannot
// be opened or its contents are inconsistent.
bool gpt_model_load(const std::string & fname, gpt_model & model, gpt_vocab & vocab) {
    printf("%s: loading model from '%s'\n", __func__, fname.c_str());

    auto fin = std::ifstream(fname, std::ios::binary);
    if (!fin) {
        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
        return false;
    }

    // verify magic
    {
        uint32_t magic;
        fin.read((char *) &magic, sizeof(magic));
        if (magic != GGML_FILE_MAGIC) {
            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
            return false;
        }
    }

    // load hparams
    {
        auto & hparams = model.hparams;

        fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
        fin.read((char *) &hparams.n_ctx,   sizeof(hparams.n_ctx));
        fin.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
        fin.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
        fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
        fin.read((char *) &hparams.ftype,   sizeof(hparams.ftype));

        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;

        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
        printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx);
        printf("%s: n_embd = %d\n", __func__, hparams.n_embd);
        printf("%s: n_head = %d\n", __func__, hparams.n_head);
        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
        printf("%s: ftype = %d\n", __func__, hparams.ftype);
        printf("%s: qntvr = %d\n", __func__, qntvr);

        // strip the quantization version, keeping only the plain ftype
        hparams.ftype %= GGML_QNT_VERSION_FACTOR;
    }

    // load vocab
    {
        int32_t n_vocab = 0;
        fin.read((char *) &n_vocab, sizeof(n_vocab));

        if (n_vocab != model.hparams.n_vocab) {
            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
            return false;
        }

        std::string word;
        std::vector<char> buf(128);

        for (int i = 0; i < n_vocab; i++) {
            uint32_t len;
            fin.read((char *) &len, sizeof(len));

            buf.resize(len);
            fin.read((char *) buf.data(), len);
            word.assign(buf.data(), len);

            vocab.token_to_id[word] = i;
            vocab.id_to_token[i] = word;
        }
    }

    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
    // in order to save memory and also to speed up the computation
    ggml_type wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
    if (wtype == GGML_TYPE_COUNT) {
        fprintf(stderr, "%s: invalid model file '%s' (bad ftype value %d)\n",
                __func__, fname.c_str(), model.hparams.ftype);
        return false;
    }

    auto & ctx = model.ctx_w;

    size_t ctx_size = 0;

    // estimate how much memory the ggml context needs for all tensors
    {
        const auto & hparams = model.hparams;

        const int n_embd  = hparams.n_embd;
        const int n_layer = hparams.n_layer;
        const int n_ctx   = hparams.n_ctx;
        const int n_vocab = hparams.n_vocab;

        ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_g
        ctx_size += ggml_row_size(GGML_TYPE_F32, n_embd); // ln_f_b

        ctx_size += ggml_row_size(wtype,         n_vocab*n_embd); // wte
        ctx_size += ggml_row_size(GGML_TYPE_F32, n_ctx*n_embd);   // wpe
        ctx_size += ggml_row_size(wtype,         n_vocab*n_embd); // lm_head

        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_g
        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_1_b

        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_g
        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd)); // ln_2_b

        ctx_size += n_layer*(ggml_row_size(wtype,         3*n_embd*n_embd)); // c_attn_attn_w
        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 3*n_embd));        // c_attn_attn_b

        ctx_size += n_layer*(ggml_row_size(wtype,         n_embd*n_embd)); // c_attn_proj_w
        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, n_embd));        // c_attn_proj_b

        ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_fc_w
        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd));        // c_mlp_fc_b

        ctx_size += n_layer*(ggml_row_size(wtype,         4*n_embd*n_embd)); // c_mlp_proj_w
        ctx_size += n_layer*(ggml_row_size(GGML_TYPE_F32, 4*n_embd));        // c_mlp_proj_b (over-estimate: the tensor has n_embd elements)

        ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_k
        ctx_size += n_ctx*n_layer*ggml_row_size(GGML_TYPE_F32, n_embd); // memory_v

        ctx_size += (6 + 12*n_layer)*512; // object overhead

        printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
    }

    // create the ggml context
    {
        struct ggml_init_params params = {
            /*.mem_size   =*/ ctx_size,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };

        model.ctx_w = ggml_init(params);
        if (!model.ctx_w) {
            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
            return false;
        }
    }

    // prepare memory for the weights
    {
        const auto & hparams = model.hparams;

        const int n_embd  = hparams.n_embd;
        const int n_layer = hparams.n_layer;
        const int n_ctx   = hparams.n_ctx;
        const int n_vocab = hparams.n_vocab;

        model.layers.resize(n_layer);

        model.ln_f_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
        model.ln_f_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

        model.wte     = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);
        model.wpe     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);
        model.lm_head = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_vocab);

        // map by name
        model.tensors["model/ln_f/g"] = model.ln_f_g;
        model.tensors["model/ln_f/b"] = model.ln_f_b;

        model.tensors["model/wte"]     = model.wte;
        model.tensors["model/wpe"]     = model.wpe;
        model.tensors["model/lm_head"] = model.lm_head;

        for (int i = 0; i < n_layer; ++i) {
            auto & layer = model.layers[i];

            layer.ln_1_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
            layer.ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

            layer.ln_2_g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
            layer.ln_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

            layer.c_attn_attn_w = ggml_new_tensor_2d(ctx, wtype,         n_embd, 3*n_embd);
            layer.c_attn_attn_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3*n_embd);

            layer.c_attn_proj_w = ggml_new_tensor_2d(ctx, wtype,         n_embd, n_embd);
            layer.c_attn_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

            layer.c_mlp_fc_w = ggml_new_tensor_2d(ctx, wtype,         n_embd, 4*n_embd);
            layer.c_mlp_fc_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_embd);

            layer.c_mlp_proj_w = ggml_new_tensor_2d(ctx, wtype,         4*n_embd, n_embd);
            layer.c_mlp_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

            // map by name
            model.tensors["model/h" + std::to_string(i) + "/ln_1/g"] = layer.ln_1_g;
            model.tensors["model/h" + std::to_string(i) + "/ln_1/b"] = layer.ln_1_b;

            model.tensors["model/h" + std::to_string(i) + "/ln_2/g"] = layer.ln_2_g;
            model.tensors["model/h" + std::to_string(i) + "/ln_2/b"] = layer.ln_2_b;

            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/w"] = layer.c_attn_attn_w;
            model.tensors["model/h" + std::to_string(i) + "/attn/c_attn/b"] = layer.c_attn_attn_b;

            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/w"] = layer.c_attn_proj_w;
            model.tensors["model/h" + std::to_string(i) + "/attn/c_proj/b"] = layer.c_attn_proj_b;

            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/w"] = layer.c_mlp_fc_w;
            model.tensors["model/h" + std::to_string(i) + "/mlp/c_fc/b"] = layer.c_mlp_fc_b;

            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/w"] = layer.c_mlp_proj_w;
            model.tensors["model/h" + std::to_string(i) + "/mlp/c_proj/b"] = layer.c_mlp_proj_b;
        }
    }

    // key + value memory
    {
        const auto & hparams = model.hparams;

        const int n_embd  = hparams.n_embd;
        const int n_layer = hparams.n_layer;
        const int n_ctx   = hparams.n_ctx;

        const int n_mem      = n_layer*n_ctx;
        const int n_elements = n_embd*n_mem;

        model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
        model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);

        const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);

        printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
    }

    // load weights
    {
        size_t total_size = 0;

        bool has_lm_head = false;

        while (true) {
            int32_t n_dims;
            int32_t length;
            int32_t ttype;

            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));

            if (fin.eof()) {
                break;
            }

            // ne[] below holds exactly two extents; a larger n_dims read from the
            // file would write past the end of the array, and a negative name
            // length cannot be used to size the name buffer
            if (n_dims < 0 || n_dims > 2 || length < 0) {
                fprintf(stderr, "%s: invalid tensor header in model file (n_dims = %d, name length = %d)\n",
                        __func__, n_dims, length);
                return false;
            }

            int32_t nelements = 1;
            int32_t ne[2] = { 1, 1 };
            for (int i = 0; i < n_dims; ++i) {
                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
                nelements *= ne[i];
            }

            std::string name(length, 0);
            fin.read(&name[0], length);

            const auto entry = model.tensors.find(name);
            if (entry == model.tensors.end()) {
                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.c_str());
                return false;
            }

            auto tensor = entry->second;
            if (ggml_nelements(tensor) != nelements) {
                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
                return false;
            }

            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
                // "got" is what the file declares, "expected" is what the model allocated
                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
                        __func__, name.c_str(), ne[0], ne[1], (int) tensor->ne[0], (int) tensor->ne[1]);
                return false;
            }

            // for debugging
            if (0) {
                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.c_str(), ne[0], ne[1],
                       ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
            }

            const size_t bpe        = ggml_type_size(ggml_type(ttype));
            const size_t file_bytes = (nelements*bpe)/ggml_blck_size(tensor->type);

            if (file_bytes != ggml_nbytes(tensor)) {
                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                        __func__, name.c_str(), file_bytes, ggml_nbytes(tensor));
                return false;
            }

            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
            if (!fin) {
                fprintf(stderr, "%s: unexpected end of file while reading tensor '%s'\n", __func__, name.c_str());
                return false;
            }

            // GPT-2 models share the WTE tensor as the LM head
            if (name == "model/wte" && has_lm_head == false) {
                memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
            }

            if (name == "model/lm_head") {
                has_lm_head = true;
            }

            total_size += ggml_nbytes(tensor);
        }

        printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
    }

    fin.close();

    return true;
}

// Split `str` into pre-tokenization "words" using the GPT-2 style regex and
// append every match to `words`.
void gpt_split_words(std::string str, std::vector<std::string>& words) {
    const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
    const std::regex re(pattern);
    std::smatch m;

    while (std::regex_search(str, m, re)) {
        for (auto x : m) {
            words.push_back(x);
        }
        str = m.suffix();
    }
}

// Convert `text` into token ids.
//
// The text is first cut at special tokens (if the vocab has any), the pieces
// in between are split into words, and every word is then greedily covered
// with the longest vocabulary entries, left to right. Characters that match
// no vocabulary entry are reported on stderr and skipped.
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
    std::vector<std::string> words;

    // first split the text into words
    {
        std::string str = text;

        // Generate the subpattern from the special_tokens vector if it's not empty
        if (!vocab.special_tokens.empty()) {
            // escape regex meta characters so that special tokens match literally
            const std::regex escape(R"([\[\\\^\$\.\|\?\*\+\(\)\{\}])");
            std::string special_tokens_subpattern;
            for (const auto & token : vocab.special_tokens) {
                if (!special_tokens_subpattern.empty()) {
                    special_tokens_subpattern += "|";
                }
                special_tokens_subpattern += std::regex_replace(token, escape, R"(\$&)");
            }

            std::regex re(special_tokens_subpattern);
            std::smatch m;
            // Split the text by special tokens.
            while (std::regex_search(str, m, re)) {
                // Split the substrings in-between special tokens into words.
                gpt_split_words(m.prefix(), words);
                // Add matched special tokens as words.
                for (auto x : m) {
                    words.push_back(x);
                }
                str = m.suffix();
            }
            // Remaining text without special tokens will be handled below.
        }

        gpt_split_words(str, words);
    }

    // find the longest token that forms each word in words:
    std::vector<gpt_vocab::id> tokens;
    for (const auto & word : words) {
        for (int i = 0; i < (int) word.size(); ) {
            for (int j = word.size() - 1; j >= i; j--) {
                auto cand = word.substr(i, j-i+1);
                auto it = vocab.token_to_id.find(cand);
                if (it != vocab.token_to_id.end()) { // word.substr(i, j-i+1) in vocab
                    tokens.push_back(it->second);
                    i = j + 1;
                    break;
                }
                else if (j == i) { // word.substr(i, 1) has no matching
                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data());
                    i++;
                }
            }
        }
    }

    return tokens;
}

// Parse a `delimiter`-separated list of integers into token ids.
// std::stoi throws on fields that are not valid integers.
static std::vector<gpt_vocab::id> parse_tokens_from_string(const std::string& input, char delimiter) {
    std::vector<gpt_vocab::id> output;
    std::stringstream ss(input);
    std::string token;

    while (std::getline(ss, token, delimiter)) {
        output.push_back(std::stoi(token));
    }

    return output;
}

// Read tokenizer test cases from `fpath_test`. Every line of the form
// "<text> => <id>,<id>,..." yields one entry; other lines are ignored.
// Returns an empty map when no path is given.
static std::map<std::string, std::vector<gpt_vocab::id>> extract_tests_from_file(const std::string & fpath_test) {
    if (fpath_test.empty()) {
        fprintf(stderr, "%s : No test file found.\n", __func__);
        return std::map<std::string, std::vector<gpt_vocab::id>>();
    }

    std::map<std::string, std::vector<gpt_vocab::id>> tests;

    auto fin = std::ifstream(fpath_test, std::ios_base::in);
    const char * delimeter = " => ";
    const char del_tok = ',';
    std::string line;
    while (std::getline(fin, line)) {
        size_t delimiterPos = line.find(delimeter);
        if (delimiterPos != std::string::npos) {
            std::string text = line.substr(0, delimiterPos);
            std::string s_tokens = line.substr(delimiterPos + std::strlen(delimeter));
            tests[text] = parse_tokens_from_string(s_tokens, del_tok);
        }
    }
    return tests;
}

// Run the tokenizer over every test case found in `fpath_test` and report the
// mismatches (expected ids vs. ids produced by gpt_tokenize) on stderr.
void test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test) {
    std::map<std::string, std::vector<gpt_vocab::id>> tests = extract_tests_from_file(fpath_test);

    // Look ids up without operator[]: indexing the map would insert an empty
    // entry for every unknown id and thereby grow the vocabulary, which the
    // sampler uses to determine the number of logits.
    const auto token_text = [&vocab](gpt_vocab::id t) -> const char * {
        const auto it = vocab.id_to_token.find(t);
        return it != vocab.id_to_token.end() ? it->second.c_str() : "";
    };

    size_t n_fails = 0;

    for (const auto & test : tests) {
        std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, test.first);

        if (tokens != test.second) {
            n_fails++;

            // print out failure cases
            fprintf(stderr, "%s : failed test: '%s'\n", __func__, test.first.c_str());
            fprintf(stderr, "%s : tokens in hf: ", __func__);
            for (const auto & t : test.second) {
                fprintf(stderr, "%s(%d), ", token_text(t), t);
            }
            fprintf(stderr, "\n");
            fprintf(stderr, "%s : tokens in ggml: ", __func__);
            for (const auto & t : tokens) {
                fprintf(stderr, "%s(%d), ", token_text(t), t);
            }
            fprintf(stderr, "\n");
        }
    }

    fprintf(stderr, "%s : %zu tests failed out of %zu tests.\n", __func__, n_fails, tests.size());
}

// Sample a token id from `logits` using temperature, top-k and top-p filtering.
//
// - vocab:  its id_to_token size determines how many logits are read
// - logits: one value per vocabulary entry
// - top_k:  number of highest-scoring candidates to keep; values larger than
//           the vocabulary are clamped and values <= 0 keep every candidate
// - top_p:  cumulative probability cut-off applied when < 1.0
// - temp:   temperature the logits are divided by
// - rng:    random number generator used for the final draw
//
// The vocabulary must not be empty.
gpt_vocab::id gpt_sample_top_k_top_p(
        const gpt_vocab & vocab,
        const float * logits,
        int    top_k,
        double top_p,
        double temp,
        std::mt19937 & rng) {
    int n_logits = vocab.id_to_token.size();

    // keep top_k inside [1, n_logits]; an out-of-range value would make the
    // partial_sort range below invalid
    if (top_k <= 0 || top_k > n_logits) {
        top_k = n_logits;
    }

    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
    logits_id.reserve(n_logits);

    {
        const double scale = 1.0/temp;
        for (int i = 0; i < n_logits; ++i) {
            logits_id.push_back(std::make_pair(logits[i]*scale, i));
        }
    }

    // find the top K tokens
    std::partial_sort(
            logits_id.begin(),
            logits_id.begin() + top_k, logits_id.end(),
            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
        return a.first > b.first;
    });

    logits_id.resize(top_k);

    double maxl = -INFINITY;
    for (const auto & kv : logits_id) {
        maxl = std::max(maxl, kv.first);
    }

    // compute probs for the top K tokens
    std::vector<double> probs;
    probs.reserve(logits_id.size());

    double sum = 0.0;
    for (const auto & kv : logits_id) {
        double p = exp(kv.first - maxl);
        probs.push_back(p);
        sum += p;
    }

    // normalize the probs
    for (auto & p : probs) {
        p /= sum;
    }

    if (top_p < 1.0f) {
        double cumsum = 0.0f;
        for (int i = 0; i < top_k; i++) {
            cumsum += probs[i];
            if (cumsum >= top_p) {
                top_k = i + 1;
                probs.resize(top_k);
                logits_id.resize(top_k);
                break;
            }
        }

        // renormalize the candidates that survived the cut-off
        cumsum = 1.0/cumsum;
        for (int i = 0; i < (int) probs.size(); i++) {
            probs[i] *= cumsum;
        }
    }

    //printf("\n");
    //for (int i = 0; i < (int) probs.size(); i++) {
    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
    //}
    //exit(0);

    std::discrete_distribution<> dist(probs.begin(), probs.end());
    int idx = dist(rng);

    return logits_id[idx].second;
}

// evaluate the transformer
//
//   - model:         the model
//   - n_threads:     number of threads to use
//   - n_past:        the context size so far
//   - embd_inp:      the embeddings of the tokens in the context
//   - embd_w:        the predicted logits for the next token
//   - mem_per_token: when 0 on entry it is set to the measured ggml memory use
//                    per token; otherwise it is used to size the scratch buffer
//
// Returns false when the input is empty, does not fit into the KV cache, or
// the scratch memory cannot be set up.
//
// The shape comments below use the GPT-2 small dimensions (n_embd = 768,
// n_head = 12) as an example.
bool gpt_eval(
        const gpt_model & model,
        const int n_threads,
        const int n_past,
        const std::vector<gpt_vocab::id> & embd_inp,
              std::vector<float>         & embd_w,
              size_t                     & mem_per_token) {
    const int N = embd_inp.size();

    const auto & hparams = model.hparams;

    const int n_embd  = hparams.n_embd;
    const int n_layer = hparams.n_layer;
    const int n_ctx   = hparams.n_ctx;
    const int n_head  = hparams.n_head;
    const int n_vocab = hparams.n_vocab;

    // with no tokens the logits offset below would be negative and the
    // per-token memory estimate would divide by zero
    if (N <= 0) {
        fprintf(stderr, "%s: no input tokens\n", __func__);
        return false;
    }

    // the KV cache holds n_ctx positions per layer
    if (n_past < 0 || n_past + N > n_ctx) {
        fprintf(stderr, "%s: %d past + %d new tokens do not fit into the context of size %d\n",
                __func__, n_past, N, n_ctx);
        return false;
    }

    // scratch buffer reused across calls
    static size_t buf_size = 256u*1024*1024;
    static void * buf = malloc(buf_size);

    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);

        // reallocate; keep the old buffer when realloc fails so that it is
        // neither leaked nor replaced by a null pointer
        void * buf_new = realloc(buf, buf_size_new);
        if (buf_new == nullptr) {
            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size_new);
            return false;
        }
        buf      = buf_new;
        buf_size = buf_size_new;
    }

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ buf,
        /*.no_alloc   =*/ false,
    };

    struct ggml_context * ctx0 = ggml_init(params);
    if (ctx0 == nullptr) {
        fprintf(stderr, "%s: ggml_init() failed\n", __func__);
        return false;
    }

    struct ggml_cgraph * gf = ggml_new_graph(ctx0);

    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));

    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    for (int i = 0; i < N; ++i) {
        ((int32_t *) position->data)[i] = n_past + i;
    }

    // wte + wpe
    struct ggml_tensor * inpL =
        ggml_add(ctx0,
                ggml_get_rows(ctx0, model.wte, embd),
                ggml_get_rows(ctx0, model.wpe, position));

    for (int il = 0; il < n_layer; ++il) {
        struct ggml_tensor * cur;

        // norm
        {
            // [ 768, N]
            cur = ggml_norm(ctx0, inpL, hparams.eps);

            // cur = ln_1_g*cur + ln_1_b
            // [ 768, N]
            cur = ggml_add(ctx0,
                    ggml_mul(ctx0,
                        ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
                        cur),
                    ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
        }

        // attn
        // [2304, 768] - model.layers[il].c_attn_attn_w
        // [2304,   1] - model.layers[il].c_attn_attn_b
        // [ 768,   N] - cur (in)
        // [2304,   N] - cur (out)
        //
        // cur = attn_w*cur + attn_b
        // [2304, N]
        {
            cur = ggml_mul_mat(ctx0,
                    model.layers[il].c_attn_attn_w,
                    cur);

            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0, model.layers[il].c_attn_attn_b, cur),
                    cur);
        }

        // self-attention
        {
            // the fused projection holds Q, K and V back to back in each row
            struct ggml_tensor * Qcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
            struct ggml_tensor * Kcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
            struct ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);

            // store key and value to memory
            if (N >= 1) {
                struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
                struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));

                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
            }

            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
            // [64, N, 12]
            struct ggml_tensor * Q =
                ggml_permute(ctx0,
                        ggml_cpy(ctx0,
                            Qcur,
                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
                        0, 2, 1, 3);

            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
            // [64, n_past + N, 12]
            struct ggml_tensor * K =
                ggml_permute(ctx0,
                        ggml_reshape_3d(ctx0,
                            ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
                            n_embd/n_head, n_head, n_past + N),
                        0, 2, 1, 3);

            // GG: flash attention
            //struct ggml_tensor * V =
            //    ggml_cpy(ctx0,
            //            ggml_permute(ctx0,
            //                ggml_reshape_3d(ctx0,
            //                    ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
            //                    n_embd/n_head, n_head, n_past + N),
            //                1, 2, 0, 3),
            //            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));

            //struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);

            // K * Q
            // [n_past + N, N, 12]
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

            // KQ_scaled = KQ / sqrt(n_embd/n_head)
            // [n_past + N, N, 12]
            struct ggml_tensor * KQ_scaled =
                ggml_scale_inplace(ctx0,
                        KQ,
                        1.0f/sqrt(float(n_embd)/n_head));

            // KQ_masked = mask_past(KQ_scaled)
            // [n_past + N, N, 12]
            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);

            // KQ = soft_max(KQ_masked)
            // [n_past + N, N, 12]
            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);

            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
            // [n_past + N, 64, 12]
            struct ggml_tensor * V_trans =
                ggml_cpy(ctx0,
                        ggml_permute(ctx0,
                            ggml_reshape_3d(ctx0,
                                ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
                                n_embd/n_head, n_head, n_past + N),
                            1, 2, 0, 3),
                        ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));

            // KQV = transpose(V) * KQ_soft_max
            // [64, N, 12]
            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

            // KQV_merged = KQV.permute(0, 2, 1, 3)
            // [64, 12, N]
            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

            // cur = KQV_merged.contiguous().view(n_embd, N)
            // [768, N]
            cur = ggml_cpy(ctx0,
                    KQV_merged,
                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
        }

        // projection
        // [ 768, 768] - model.layers[il].c_attn_proj_w
        // [ 768,   1] - model.layers[il].c_attn_proj_b
        // [ 768,   N] - cur (in)
        // [ 768,   N] - cur (out)
        //
        // cur = proj_w*cur + proj_b
        // [768, N]
        {
            cur = ggml_mul_mat(ctx0,
                    model.layers[il].c_attn_proj_w,
                    cur);

            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0, model.layers[il].c_attn_proj_b, cur),
                    cur);
        }

        // add the input
        cur = ggml_add(ctx0, cur, inpL);

        struct ggml_tensor * inpFF = cur;

        // feed-forward network
        {
            // norm
            {
                cur = ggml_norm(ctx0, inpFF, hparams.eps);

                // cur = ln_2_g*cur + ln_2_b
                // [ 768, N]
                cur = ggml_add(ctx0,
                        ggml_mul(ctx0,
                            ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
                            cur),
                        ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
            }

            // fully connected
            // [3072, 768] - model.layers[il].c_mlp_fc_w
            // [3072,   1] - model.layers[il].c_mlp_fc_b
            // [ 768,   N] - cur (in)
            // [3072,   N] - cur (out)
            //
            // cur = fc_w*cur + fc_b
            // [3072, N]
            cur = ggml_mul_mat(ctx0,
                    model.layers[il].c_mlp_fc_w,
                    cur);

            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0, model.layers[il].c_mlp_fc_b, cur),
                    cur);

            // GELU activation
            // [3072, N]
            cur = ggml_gelu(ctx0, cur);

            // projection
            // [ 768, 3072] - model.layers[il].c_mlp_proj_w
            // [ 768,    1] - model.layers[il].c_mlp_proj_b
            // [3072,    N] - cur (in)
            // [ 768,    N] - cur (out)
            //
            // cur = proj_w*cur + proj_b
            // [768, N]
            cur = ggml_mul_mat(ctx0,
                    model.layers[il].c_mlp_proj_w,
                    cur);

            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0, model.layers[il].c_mlp_proj_b, cur),
                    cur);
        }

        // input for next layer
        inpL = ggml_add(ctx0, cur, inpFF);
    }

    // norm
    {
        // [ 768, N]
        inpL = ggml_norm(ctx0, inpL, hparams.eps);

        // inpL = ln_f_g*inpL + ln_f_b
        // [ 768, N]
        inpL = ggml_add(ctx0,
                ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.ln_f_g, inpL),
                    inpL),
                ggml_repeat(ctx0, model.ln_f_b, inpL));
    }

    // inpL = WTE * inpL
    // [ 768, 50257] - model.lm_head
    // [ 768, N]     - inpL
    inpL = ggml_mul_mat(ctx0, model.lm_head, inpL);

    // logits -> probs
    //inpL = ggml_soft_max_inplace(ctx0, inpL);

    // run the computation
    ggml_build_forward_expand(gf, inpL);
    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);

    //if (n_past%100 == 0) {
    //    ggml_graph_print   (&gf);
    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
    //}

    //embd_w.resize(n_vocab*N);
    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

    // return result just for the last token
    embd_w.resize(n_vocab);
    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);

    if (mem_per_token == 0) {
        mem_per_token = ggml_used_mem(ctx0)/N;
    }
    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));

    ggml_free(ctx0);

    return true;
}

// Print the command line help to stderr, using `params` for the defaults.
void gpt_print_usage(int argc, char ** argv, const gpt_hparams & params) {
    (void) argc; // only argv[0] is needed

    fprintf(stderr, "usage: %s [options]\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, " -h, --help show this help message and exit\n");
    fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
    fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
    fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
    fprintf(stderr, " prompt to start generation with (default: random)\n");
    fprintf(stderr, " -f FNAME, --file FNAME\n");
    fprintf(stderr, " load prompt from a file\n");
    fprintf(stderr, " -tt TOKEN_TEST, --token_test TOKEN_TEST\n");
    fprintf(stderr, " test tokenization\n");
    fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
    fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
    fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
    fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
    fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
    fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
    fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
    fprintf(stderr, " -c N, --context N context / KV cache size (default: %d)\n", params.n_ctx);
    fprintf(stderr, " --ignore-eos ignore EOS token during generation\n");
    fprintf(stderr, " -ngl N, --gpu-layers N number of layers to offload to GPU on supported models (default: %d)\n", params.n_gpu_layers);
    fprintf(stderr, " -m FNAME, --model FNAME\n");
    fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
    fprintf(stderr, "\n");
}

// Function to check if the next argument exists
//
// Returns the value following `flag` and advances `i` past it. A missing
// value (end of argv, or a following argument that starts with '-') prints
// the usage and terminates the process with a failure status.
static std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_hparams& params) {
    if (i + 1 < argc && argv[i + 1][0] != '-') {
        return argv[++i];
    } else {
        fprintf(stderr, "error: %s requires one argument.\n", flag.c_str());
        gpt_print_usage(argc, argv, params);
        exit(1); // a usage error must not report success
    }
}

// Parse the command line into `params`.
//
// Returns false when the prompt file given with -f/--file cannot be opened.
// Unknown options and missing values print the usage and exit with a failure
// status; -h/--help prints the usage and exits successfully.
bool gpt_params_parse(int argc, char ** argv, gpt_hparams & params) {
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];

        if (arg == "-s" || arg == "--seed") {
            params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "-t" || arg == "--threads") {
            params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "-p" || arg == "--prompt") {
            params.prompt = get_next_arg(i, argc, argv, arg, params);
        } else if (arg == "-n" || arg == "--n_predict") {
            params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "-np" || arg == "--n_parallel") {
            params.n_parallel = std::stoi(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "--top_k") {
            params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "--top_p") {
            params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "--temp") {
            params.temp = std::stof(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "--repeat-last-n") {
            params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "--repeat-penalty") {
            params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "-b" || arg == "--batch_size") {
            params.n_batch = std::stoi(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "-c" || arg == "--context") {
            params.n_ctx = std::stoi(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
            params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "--ignore-eos") {
            params.ignore_eos = true;
        } else if (arg == "-m" || arg == "--model") {
            params.model = get_next_arg(i, argc, argv, arg, params);
        } else if (arg == "-i" || arg == "--interactive") {
            params.interactive = true;
        } else if (arg == "-ip" || arg == "--interactive-port") {
            params.interactive = true;
            params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "-h" || arg == "--help") {
            gpt_print_usage(argc, argv, params);
            exit(0);
        } else if (arg == "-f" || arg == "--file") {
            const std::string fname = get_next_arg(i, argc, argv, arg, params);
            std::ifstream file(fname);
            if (!file) {
                fprintf(stderr, "error: failed to open file '%s'\n", fname.c_str());
                return false; // do not silently continue without the requested prompt
            }
            std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
            // drop one trailing newline; the emptiness check guards back() on an empty file
            if (!params.prompt.empty() && params.prompt.back() == '\n') {
                params.prompt.pop_back();
            }
        } else if (arg == "-tt" || arg == "--token_test") {
            params.token_test = get_next_arg(i, argc, argv, arg, params);
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            gpt_print_usage(argc, argv, params);
            exit(1); // a usage error must not report success
        }
    }

    return true;
}

// Pick one of ten fixed prompt openers using `rng`.
std::string gpt_random_prompt(std::mt19937 & rng) {
    const int r = rng() % 10;
    switch (r) {
        case 0: return "So";
        case 1: return "Once upon a time";
        case 2: return "When";
        case 3: return "The";
        case 4: return "After";
        case 5: return "If";
        case 6: return "import";
        case 7: return "He";
        case 8: return "She";
        case 9: return "They";
    }

    return "The";
}

// Program entry point: parse the options, load the model, then feed the prompt
// through the model in batches and sample new tokens until n_predict tokens
// were produced or the end-of-text token is generated.
int main(int argc, char ** argv) {
    ggml_time_init();

    const int64_t t_main_start_us = ggml_time_us();

    gpt_hparams params;

    if (gpt_params_parse(argc, argv, params) == false) {
        return 1;
    }

    if (params.seed < 0) {
        params.seed = time(NULL);
    }

    printf("%s: seed = %d\n", __func__, params.seed);

    std::mt19937 rng(params.seed);
    if (params.prompt.empty()) {
        params.prompt = gpt_random_prompt(rng);
    }

    int64_t t_load_us = 0;

    gpt_vocab vocab;
    gpt_model model;

    // load the model
    {
        const int64_t t_start_us = ggml_time_us();

        if (!gpt_model_load(params.model, model, vocab)) {
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
            return 1;
        }

        t_load_us = ggml_time_us() - t_start_us;

        test_gpt_tokenizer(vocab, params.token_test);
    }

    // The generation pass used to sit inside an unconditional `while (true)`,
    // which never terminated and left the cleanup below unreachable. A single
    // pass is now performed; the repeat-forever behaviour is kept only when
    // interactive mode was requested on the command line.
    do {
        int n_past = 0;

        int64_t t_sample_us  = 0;
        int64_t t_predict_us = 0;

        std::vector<float> logits;

        // tokenize the prompt
        std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);

        params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());

        printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
        printf("%s: number of tokens in prompt = %zu, first 8 tokens: ", __func__, embd_inp.size());
        for (int i = 0; i < std::min(8, (int) embd_inp.size()); i++) {
            printf("%d ", embd_inp[i]);
        }
        printf("\n\n");

        // submit the input prompt token-by-token
        // this reduces the memory usage during inference, at the cost of a bit of speed at the beginning
        std::vector<gpt_vocab::id> embd;

        // determine the required inference memory per token:
        size_t mem_per_token = 0;
        if (!gpt_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token)) {
            fprintf(stderr, "%s: failed to measure the memory use per token\n", __func__);
            ggml_free(model.ctx_w);
            return 1;
        }

        for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
            // predict
            if (embd.size() > 0) {
                const int64_t t_start_us = ggml_time_us();

                if (!gpt_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
                    printf("Failed to predict\n");
                    ggml_free(model.ctx_w);
                    return 1;
                }

                t_predict_us += ggml_time_us() - t_start_us;
            }

            n_past += embd.size();
            embd.clear();

            if (i >= embd_inp.size()) {
                // sample next token
                const int   top_k = params.top_k;
                const float top_p = params.top_p;
                const float temp  = params.temp;

                const int n_vocab = model.hparams.n_vocab;

                gpt_vocab::id id = 0;

                {
                    const int64_t t_start_sample_us = ggml_time_us();

                    id = gpt_sample_top_k_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_k, top_p, temp, rng);

                    t_sample_us += ggml_time_us() - t_start_sample_us;
                }

                // add it to the context
                embd.push_back(id);
            } else {
                // if here, it means we are still processing the input prompt
                for (size_t k = i; k < embd_inp.size(); k++) {
                    embd.push_back(embd_inp[k]);
                    if (int32_t(embd.size()) >= params.n_batch) {
                        break;
                    }
                }
                // skip the tokens consumed by this batch (the loop adds the last one)
                i += embd.size() - 1;
            }

            // display text
            for (auto id : embd) {
                printf("%s", vocab.id_to_token[id].c_str());
            }
            fflush(stdout);

            // end of text token (50256), unless the user asked to ignore it
            if (!params.ignore_eos && embd.back() == 50256) {
                break;
            }
        }

        // report timing
        {
            const int64_t t_main_end_us = ggml_time_us();

            printf("\n\n");
            printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
            printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
            printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
            printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
            printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
        }
    } while (params.interactive);

    ggml_free(model.ctx_w);

    return 0;
}