| | #include "data_loader.hpp" |
| | #include <fstream> |
| | #include <stdexcept> |
| | #include <iostream> |
| | #include "optical_model.hpp" |
| |
|
| | FashionMNISTSet load_fashion_mnist_data(const std::string& data_dir, bool is_train) { |
| | FashionMNISTSet set; |
| | const std::string prefix = is_train ? "train" : "test"; |
| | const std::string images_path = data_dir + "/" + prefix + "-images.bin"; |
| | const std::string labels_path = data_dir + "/" + prefix + "-labels.bin"; |
| |
|
| | |
| | std::ifstream f_images(images_path, std::ios::binary); |
| | if (!f_images) throw std::runtime_error("Cannot open: " + images_path); |
| | f_images.seekg(0, std::ios::end); |
| | size_t num_bytes = f_images.tellg(); |
| | f_images.seekg(0, std::ios::beg); |
| | set.N = num_bytes / (IMG_SIZE * sizeof(float)); |
| | set.images.resize(set.N * IMG_SIZE); |
| | f_images.read(reinterpret_cast<char*>(set.images.data()), num_bytes); |
| |
|
| | |
| | std::ifstream f_labels(labels_path, std::ios::binary); |
| | if (!f_labels) throw std::runtime_error("Cannot open: " + labels_path); |
| | f_labels.seekg(0, std::ios::end); |
| | num_bytes = f_labels.tellg(); |
| | f_labels.seekg(0, std::ios::beg); |
| | if (set.N != num_bytes) throw std::runtime_error("Image and label count mismatch!"); |
| | set.labels.resize(set.N); |
| | f_labels.read(reinterpret_cast<char*>(set.labels.data()), num_bytes); |
| |
|
| | std::cout << "[INFO] Loaded " << set.N << " " << prefix << " samples.\n"; |
| | return set; |
| | } |