// Part 2 in the "Simple introduction to ggml" series.
At the end of Part 1, we learnt how to keep the model weights separate from temporary computation-only tensor variables. This allowed the model weights to stay in memory across multiple predictions (which is the usual behavior of machine learning programs during inference).
Now let's modify that code to build a simple neural network model using ggml. If you're new to ggml, I recommend reading Part 1 first.
Model and Training
Our model will behave like a logic gate (AND, OR, or XOR, depending on its training). We'll use a simple model with 2 fully-connected layers (2 inputs, 16 hidden nodes, 1 output). This model's design (and its training code) is based on Omkar Prabu's excellent intro to ggml.
We'll train the model by running python train_logic_gate.py --print-weights, which will train an XOR gate and print the trained weights (and also write them to a model.sft file). You can also ask the program to train an AND or OR gate instead, by passing in a --gate-type argument.
The trained weights printed by the program will look something like this:
fc1_weight = { 0.22488207, -0.39456311, ..., 0.07894109, -0.41966945 }
fc1_bias = { -0.35652003, -0.67564911, ..., 1.17234588, 0.77097332 }
fc2_weight = { 0.13858399, -0.20547047, ..., -1.64424217, -0.63815284 }
fc2_bias = { -0.55232018 }
Inference in ggml
Now let's implement this model using ggml, in order to run inference on it.
Define the model
First, we'll define the model as a struct (for convenience). This model will contain 4 tensors, i.e. a pair of weights and biases for each of the two fully-connected layers, along with the params context from Part 1 that owns them.
struct logic_gate_model {
    ggml_tensor* fc1_weight;
    ggml_tensor* fc1_bias;
    ggml_tensor* fc2_weight;
    ggml_tensor* fc2_bias;

    ggml_context* params_ctx;

    struct model_config {
        int32_t n_input = 2;
        int32_t n_hidden = 16;
        int32_t n_output = 1;
    } config;
};
Define the tensor variables required for model weights
Then we'll modify the load_weights() function to create tensors for the model weights.
// note: ggml lists dimensions from innermost to outermost, so fc1_weight (n_input x n_hidden)
// has the same memory layout as PyTorch's [n_hidden, n_input] Linear weight
model.fc1_weight = ggml_new_tensor_2d(model.params_ctx, GGML_TYPE_F32, model.config.n_input, model.config.n_hidden);
model.fc1_bias = ggml_new_tensor_1d(model.params_ctx, GGML_TYPE_F32, model.config.n_hidden);
model.fc2_weight = ggml_new_tensor_2d(model.params_ctx, GGML_TYPE_F32, model.config.n_hidden, model.config.n_output);
model.fc2_bias = ggml_new_tensor_1d(model.params_ctx, GGML_TYPE_F32, model.config.n_output);
Allocate memory for the model weight tensors, and assign the model data
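Before we can copy any data into these tensors, they need actual memory in the backend. As in Part 1, a single call allocates buffers for every tensor created in the params context (backend here is the backend handle we set up in Part 1):
ggml_backend_alloc_ctx_tensors(model.params_ctx, backend);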
Next, we'll take the weights printed by the training code and load them into these tensors.
std::vector<float> fc1_weight = { 0.22488207, -0.39456311, ..., 0.07894109, -0.41966945 };
std::vector<float> fc1_bias = { -0.35652003, -0.67564911, ..., 1.17234588, 0.77097332 };
std::vector<float> fc2_weight = { 0.13858399, -0.20547047, ..., -1.64424217, -0.63815284 };
std::vector<float> fc2_bias = { -0.55232018 };
ggml_backend_tensor_set(model.fc1_weight, fc1_weight.data(), 0, ggml_nbytes(model.fc1_weight));
ggml_backend_tensor_set(model.fc1_bias, fc1_bias.data(), 0, ggml_nbytes(model.fc1_bias));
ggml_backend_tensor_set(model.fc2_weight, fc2_weight.data(), 0, ggml_nbytes(model.fc2_weight));
ggml_backend_tensor_set(model.fc2_bias, fc2_bias.data(), 0, ggml_nbytes(model.fc2_bias));
Update the computation graph
We'll modify the predict() function to define an input tensor, and write the series of math operations (mirroring the forward() function in the corresponding PyTorch model).
struct ggml_tensor* x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.config.n_input);
struct ggml_tensor* fc1 = ggml_add(ctx, ggml_mul_mat(ctx, model.fc1_weight, x), model.fc1_bias); // multiply the weights, and add the bias
struct ggml_tensor* fc1_relu = ggml_relu(ctx, fc1);
struct ggml_tensor* fc2 = ggml_add(ctx, ggml_mul_mat(ctx, model.fc2_weight, fc1_relu), model.fc2_bias);
struct ggml_tensor* result = ggml_hardsigmoid(ctx, fc2);
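The rest of predict() works just like in Part 1: build the graph that ends at result, allocate memory for the graph's tensors, and run it on the backend. A minimal sketch, assuming the graph allocator (allocr) and backend handle from Part 1:
struct ggml_cgraph* gf = ggml_new_graph(ctx);
ggml_build_forward_expand(gf, result); // record every operation leading up to 'result'
ggml_gallocr_alloc_graph(allocr, gf); // allocate memory for the graph's tensors
// (copy the input values into 'x' at this point, see the next section)
ggml_backend_graph_compute(backend, gf); // run the computation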
Load the input data for prediction
This loop runs a prediction for each input combination in the gate's truth table: (0, 0), (0, 1), (1, 0), (1, 1).
for (int i = 0; i < 2; i++) {
    for (int j = 0; j < 2; j++) {
        std::vector<float> input = {float(i), float(j)};
        predict(model, input);
    }
}
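Inside predict(), the input vector gets copied into the x tensor (after the graph's memory has been allocated), and the prediction is read back once the graph has run. A sketch, using the tensor names from the previous section:
// copy the two input values into the graph's input tensor
ggml_backend_tensor_set(x, input.data(), 0, ggml_nbytes(x));

// ... compute the graph, as shown earlier ...

// read the single output value back from the backend
float prediction;
ggml_backend_tensor_get(result, &prediction, 0, sizeof(prediction));
std::cout<<input[0]<<", "<<input[1]<<" -> "<<prediction<<std::endl;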
A complete working example for this is at logic_gate.cpp, with instructions for compiling it at the top of the file.
Minor refactoring
The code in logic_gate.cpp is getting pretty messy for a fairly simple model, and that will make it challenging to write larger models in the future.
So let's clean up the implementation slightly. We'll move the code related to model weights and model computation into the model struct. This separates the model's logic from the code required for actually running the model.
The model struct now looks like this:
struct logic_gate_model {
    ggml_tensor* fc1_weight;
    ggml_tensor* fc1_bias;
    ggml_tensor* fc2_weight;
    ggml_tensor* fc2_bias;

    ggml_context* params_ctx;

    struct model_config {
        int32_t n_input = 2;
        int32_t n_hidden = 16;
        int32_t n_output = 1;
    } config;

    logic_gate_model() {
        // create a context (for weights)
        int num_weight_tensors = 4; // since we store four tensors in the model
        params_ctx = ggml_init({
            /*.mem_size   =*/ ggml_tensor_overhead() * num_weight_tensors,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,
        });

        // define the tensor variables required for model weights
        fc1_weight = ggml_new_tensor_2d(params_ctx, GGML_TYPE_F32, config.n_input, config.n_hidden);
        fc1_bias = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, config.n_hidden);
        fc2_weight = ggml_new_tensor_2d(params_ctx, GGML_TYPE_F32, config.n_hidden, config.n_output);
        fc2_bias = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, config.n_output);

        // allocate memory (in the backend) for the weight tensors
        ggml_backend_alloc_ctx_tensors(params_ctx, backend);
    }

    ~logic_gate_model() {
        ggml_free(params_ctx);
    }

    void load_weights() {
        std::vector<float> fc1_weight_data = { 0.22488207, -0.39456311, 0.32581645, -0.56285965, 2.41329503, -2.41322660, -0.37499088, 0.08395171, 0.21755114, 0.80772698, 0.25437704, 1.57216692, -0.43496752, 0.22240390, 0.46247596, -0.02229351, 0.32341745, 0.25361675, -0.20483392, 0.26918083, -0.91469419, 1.23764634, 0.15310341, -0.67303509, 1.77088165, 1.77059495, -0.11867817, -0.37374884, 0.79170924, -1.17232382, 0.07894109, -0.41966945 };
        std::vector<float> fc1_bias_data = { -0.35652003, -0.67564911, 0.00009615, -0.62946773, 0.27859268, 0.01491952, 0.52390707, -0.47604990, -0.25365347, 0.21269353, 0.00003640, -0.44338676, -1.77084744, 0.82772928, 1.17234588, 0.77097332 };
        std::vector<float> fc2_weight_data = { 0.13858399, -0.20547047, 3.41583562, 0.15011564, 0.56532770, 1.40391135, 0.00871399, 0.24152395, -0.39389160, 0.16984159, 1.34791148, -0.12602532, -3.02119160, -0.68023020, -1.64424217, -0.63815284 };
        std::vector<float> fc2_bias_data = { -0.55232018 };

        ggml_backend_tensor_set(fc1_weight, fc1_weight_data.data(), 0, ggml_nbytes(fc1_weight));
        ggml_backend_tensor_set(fc1_bias, fc1_bias_data.data(), 0, ggml_nbytes(fc1_bias));
        ggml_backend_tensor_set(fc2_weight, fc2_weight_data.data(), 0, ggml_nbytes(fc2_weight));
        ggml_backend_tensor_set(fc2_bias, fc2_bias_data.data(), 0, ggml_nbytes(fc2_bias));
    }

    ggml_tensor* forward(ggml_context *ctx, ggml_tensor *x) {
        ggml_tensor* fc1 = ggml_add(ctx, ggml_mul_mat(ctx, fc1_weight, x), fc1_bias); // multiply the weights, and add the bias
        ggml_tensor* fc1_relu = ggml_relu(ctx, fc1);
        ggml_tensor* fc2 = ggml_add(ctx, ggml_mul_mat(ctx, fc2_weight, fc1_relu), fc2_bias);
        return ggml_hardsigmoid(ctx, fc2);
    }
};
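With this refactoring, predict() becomes much simpler. It only needs to create the input tensor and let the model build its own part of the graph. A sketch of the relevant lines:
// inside predict(): the model now builds its portion of the graph
ggml_tensor* x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.config.n_input);
ggml_tensor* result = model.forward(ctx, x);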
A complete working example for this is at logic_gate_refactored.cpp, with instructions for compiling it at the top of the file.
A note about model weights
As you've noticed, we hardcoded the trained weights in the inference code. This isn't ideal. So we need to write a utility function that loads the weights from the model.sft file (safetensors format).
I've implemented a very basic safetensors loader at safetensors.hpp. This implementation isn't very efficient for very large models, but it's sufficient for our purposes right now, and is easy to understand.
Let's modify the load_weights() function. First, we'll remove the hardcoded weights. Next, we'll call safetensors::load_from_file() and assign the tensor data to the corresponding ggml_tensor in the callback function.
std::unordered_map<std::string, struct ggml_tensor*> tensor_map;

...

// names of the parameters as written by the training code
tensor_map["fc1.weight"] = fc1_weight;
tensor_map["fc1.bias"] = fc1_bias;
tensor_map["fc2.weight"] = fc2_weight;
tensor_map["fc2.bias"] = fc2_bias;

...

auto tensors = tensor_map;
safetensors::load_from_file("model.sft", [&tensors](const std::string& key, const std::string& dtype, const std::vector<uint64_t>& shape, const std::vector<uint8_t>& tensor_data) {
    std::cout<<"Read tensor: "<<key<<", size: "<<tensor_data.size()<<" bytes"<<std::endl;

    auto it = tensors.find(key);
    if (it != tensors.end()) {
        ggml_tensor* tensor = it->second;
        ggml_backend_tensor_set(tensor, tensor_data.data(), 0, ggml_nbytes(tensor));
    } else {
        std::cout<<"Unknown key: "<<key<<std::endl;
    }
});
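One small safety net worth adding here (my suggestion, it isn't in the original example): before copying, check that the number of bytes read from the file matches what ggml expects for that tensor, so a shape mismatch fails loudly instead of silently corrupting the weights.
if (tensor_data.size() != ggml_nbytes(tensor)) {
    std::cout<<"Size mismatch for "<<key<<": got "<<tensor_data.size()<<" bytes, expected "<<ggml_nbytes(tensor)<<std::endl;
    return; // skip this tensor instead of copying bad data
}
ggml_backend_tensor_set(tensor, tensor_data.data(), 0, ggml_nbytes(tensor));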
A complete working example for this is at logic_gate_with_weights_file.cpp, with instructions for compiling it at the top of the file.