Skip to content

Commit

Permalink
add test program
Browse files Browse the repository at this point in the history
  • Loading branch information
bbawj committed Jan 22, 2024
1 parent a4c82eb commit b3bd80c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 7 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,5 @@ server:
client:
west build --build-dir="./zflclient/out" -b qemu_x86 -o="--jobs=8" "./zflclient" -- -G"Unix Makefiles" -DEXTRA_CONF_FILE=overlay-e1000.conf

test: test.c
$(CC) $(CFLAGS) "test.c" -o "test" -lm -g
12 changes: 5 additions & 7 deletions common.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ int deserialize_training_data(char *data, size_t size, Payload *p) {
} else if (strncmp(key->text, weights_key, strlen(weights_key)) == 0) {
p->weights = key->child->child;
} else {
printf("Invalid key %s\n", key->text);
return -1;
// printf("Invalid key %s\n", key->text);
// return -1;
}
key = key->next;
}
Expand All @@ -131,11 +131,9 @@ int parse_weights_json(Token *weights, float **initial_weights,

if (!accum) {
for (int i = 0; i < ARCH_COUNT - 1; ++i) {
initial_weights[i] =
calloc(1, sizeof(float) * ARCH[i] * ARCH[i + 1]);
assert(initial_weights[i] != NULL);
initial_bias[i] = calloc(1, sizeof(float) * ARCH[i + 1]);
assert(initial_bias[i] != NULL);
memset(initial_weights[i], 0,
sizeof(float) * ARCH[i] * ARCH[i + 1]);
memset(initial_bias[i], 0, sizeof(float) * ARCH[i + 1]);
}
}

Expand Down
16 changes: 16 additions & 0 deletions sb.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define SB_H

#include <assert.h>
#include <errno.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
Expand All @@ -22,6 +23,7 @@ int sb_append(StringBuilder *sb, const char *data, size_t len);
int sb_appendf(StringBuilder *sb, const char *format, ...);
char *sb_string(StringBuilder *sb);
void sb_free(StringBuilder *sb);
void sb_open_file(StringBuilder *sb, char *file_path);

#ifdef SB_IMPLEMENTATION

Expand Down Expand Up @@ -89,6 +91,20 @@ void sb_free(StringBuilder *sb) {
sb->cap = 0;
}

void sb_open_file(StringBuilder *sb, char *file_path) {
FILE *f = fopen(file_path, "r");
assert(f);
char buf[1024];
while (!feof(f)) {
int ret = fread(buf, sizeof(*buf), sizeof(buf), f);
sb_append(sb, buf, ret);
if (ferror(f) != 0) {
printf("ERROR: failed to read file %s because: %s\n", file_path,
strerror(errno));
}
}
}

#endif // SB_IMPLEMENTATION

#endif // ! SB_H
55 changes: 55 additions & 0 deletions test.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include "stdlib.h"
#include <time.h>
#define IS_SERVER
#define NN_BACKPROP_TRADITIONAL
#define NN_IMPLEMENTATION
#define CSON_IMPLEMENTATION
#define SB_IMPLEMENTATION
#define COMMON_IMPLEMENTATION
#define TRAIN_IMPLEMENTATION
#include "./zflclient/src/train.h"
#include "errno.h"

Trainer TRAINER = {0};
int NUM_EPOCHS = 100;
int BATCH_SIZE = 100;

int train_labels_size = 60000;
int train_images_size = 60000;

int main(void) {
srand(time(0));
TRAINER.n_images = 60000;
StringBuilder images = {0};
sb_init(&images, 1024);
printf("INFO: opening test image!\n");
sb_open_file(&images, "./data/test-images");
StringBuilder labels = {0};
sb_init(&labels, 1024);

printf("INFO: opening test label!\n");

sb_open_file(&labels, "./data/test-labels");

printf("INFO: image: %zu, labels: %zu\n", images.size, labels.size);
assert(images.size / IMG_SIZE == labels.size);

printf("INFO: opening train image!\n");
StringBuilder train_images = {0};
sb_init(&train_images, 1024);
sb_open_file(&train_images, "./data/train-images-main");
StringBuilder train_labels = {0};
sb_init(&train_labels, 1024);
printf("INFO: opening train label!\n");
sb_open_file(&train_labels, "./data/train-labels-main");

Mat test_set = init_train_set(images.data, labels.data, 10000);
TRAINER.samples =
init_train_set(train_images.data, train_labels.data, 60000);

init_nn(&TRAINER.model, NULL, NULL);
train(&TRAINER);

float acc = accuracy(TRAINER.model, test_set);
printf("INFO: final model accuracy against test set is %f\n", acc);
}

0 comments on commit b3bd80c

Please sign in to comment.