Added an MNIST loader

This commit is contained in:
Davis King 2015-11-02 17:29:04 -05:00
parent 074d4b6e15
commit e98bade3ed
6 changed files with 195 additions and 1 deletions

View File

@ -132,7 +132,8 @@ if (NOT TARGET dlib)
md5/md5_kernel_1.cpp
tokenizer/tokenizer_kernel_1.cpp
unicode/unicode.cpp
data_io/image_dataset_metadata.cpp)
data_io/image_dataset_metadata.cpp
data_io/mnist.cpp)
if (DLIB_ISO_CPP_ONLY)
add_library(dlib STATIC ${source_files} )

View File

@ -16,6 +16,7 @@
#include "../tokenizer/tokenizer_kernel_1.cpp"
#include "../unicode/unicode.cpp"
#include "../data_io/image_dataset_metadata.cpp"
#include "../data_io/mnist.cpp"
#ifndef DLIB_ISO_CPP_ONLY
// Code that depends on OS specific APIs

View File

@ -5,6 +5,7 @@
#include "data_io/libsvm_io.h"
#include "data_io/image_dataset_metadata.h"
#include "data_io/mnist.h"
#ifndef DLIB_ISO_CPP_ONLY
#include "data_io/load_image_dataset.h"

113
dlib/data_io/mnist.cpp Normal file
View File

@ -0,0 +1,113 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MNIST_CPp_
#define DLIB_MNIST_CPp_
#include "mnist.h"
#include <fstream>
#include "../byte_orderer.h"
// ----------------------------------------------------------------------------------------
namespace dlib
{
void load_mnist_dataset (
const std::string& folder_name,
std::vector<matrix<unsigned char> >& training_images,
std::vector<int>& training_labels,
std::vector<matrix<unsigned char> >& testing_images,
std::vector<int>& testing_labels
)
{
using namespace std;
ifstream fin1((folder_name+"/train-images-idx3-ubyte").c_str(), ios::binary);
ifstream fin2((folder_name+"/train-labels-idx1-ubyte").c_str(), ios::binary);
ifstream fin3((folder_name+"/t10k-images-idx3-ubyte").c_str(), ios::binary);
ifstream fin4((folder_name+"/t10k-labels-idx1-ubyte").c_str(), ios::binary);
if (!fin1) throw error("Unable to open file train-images-idx3-ubyte");
if (!fin2) throw error("Unable to open file train-labels-idx1-ubyte");
if (!fin3) throw error("Unable to open file t10k-images-idx3-ubyte");
if (!fin4) throw error("Unable to open file t10k-labels-idx1-ubyte");
byte_orderer bo;
// make sure the files have the contents we expect.
uint32_t magic, num, nr, nc, num2, num3, num4;
fin1.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic);
fin1.read((char*)&num, sizeof(num)); bo.big_to_host(num);
fin1.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr);
fin1.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc);
if (magic != 2051 || num != 60000 || nr != 28 || nc != 28)
throw error("mndist dat files are corrupted.");
fin2.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic);
fin2.read((char*)&num2, sizeof(num2)); bo.big_to_host(num2);
if (magic != 2049 || num2 != 60000)
throw error("mndist dat files are corrupted.");
fin3.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic);
fin3.read((char*)&num3, sizeof(num3)); bo.big_to_host(num3);
fin3.read((char*)&nr, sizeof(nr)); bo.big_to_host(nr);
fin3.read((char*)&nc, sizeof(nc)); bo.big_to_host(nc);
if (magic != 2051 || num3 != 10000 || nr != 28 || nc != 28)
throw error("mndist dat files are corrupted.");
fin4.read((char*)&magic, sizeof(magic)); bo.big_to_host(magic);
fin4.read((char*)&num4, sizeof(num4)); bo.big_to_host(num4);
if (magic != 2049 || num4 != 10000)
throw error("mndist dat files are corrupted.");
if (!fin1) throw error("Unable to read train-images-idx3-ubyte");
if (!fin2) throw error("Unable to read train-labels-idx1-ubyte");
if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte");
if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte");
training_images.resize(60000);
training_labels.resize(60000);
testing_images.resize(10000);
testing_labels.resize(10000);
for (size_t i = 0; i < training_images.size(); ++i)
{
training_images[i].set_size(nr,nc);
fin1.read((char*)&training_images[i](0,0), nr*nc);
}
for (size_t i = 0; i < training_labels.size(); ++i)
{
char l;
fin2.read(&l, 1);
training_labels[i] = l;
}
for (size_t i = 0; i < testing_images.size(); ++i)
{
testing_images[i].set_size(nr,nc);
fin3.read((char*)&testing_images[i](0,0), nr*nc);
}
for (size_t i = 0; i < testing_labels.size(); ++i)
{
char l;
fin4.read(&l, 1);
testing_labels[i] = l;
}
if (!fin1) throw error("Unable to read train-images-idx3-ubyte");
if (!fin2) throw error("Unable to read train-labels-idx1-ubyte");
if (!fin3) throw error("Unable to read t10k-images-idx3-ubyte");
if (!fin4) throw error("Unable to read t10k-labels-idx1-ubyte");
if (fin1.get() != EOF) throw error("Unexpected bytes at end of train-images-idx3-ubyte");
if (fin2.get() != EOF) throw error("Unexpected bytes at end of train-labels-idx1-ubyte");
if (fin3.get() != EOF) throw error("Unexpected bytes at end of t10k-images-idx3-ubyte");
if (fin4.get() != EOF) throw error("Unexpected bytes at end of t10k-labels-idx1-ubyte");
}
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_MNIST_CPp_

32
dlib/data_io/mnist.h Normal file
View File

@ -0,0 +1,32 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_MNIST_Hh_
#define DLIB_MNIST_Hh_
#include "mnist_abstract.h"
#include <string>
#include <vector>
#include "../matrix.h"
// ----------------------------------------------------------------------------------------
namespace dlib
{
void load_mnist_dataset (
const std::string& folder_name,
std::vector<matrix<unsigned char> >& training_images,
std::vector<int>& training_labels,
std::vector<matrix<unsigned char> >& testing_images,
std::vector<int>& testing_labels
);
}
// ----------------------------------------------------------------------------------------
#ifdef NO_MAKEFILE
#include "mnist.cpp"
#endif
#endif // DLIB_MNIST_Hh_

View File

@ -0,0 +1,46 @@
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#undef DLIB_MNIST_ABSTRACT_Hh_
#ifdef DLIB_MNIST_ABSTRACT_Hh_
#include <string>
#include <vector>
#include "../matrix.h"
// ----------------------------------------------------------------------------------------
namespace dlib
{
void load_mnist_dataset (
const std::string& folder_name,
std::vector<matrix<unsigned char> >& training_images,
std::vector<int>& training_labels,
std::vector<matrix<unsigned char> >& testing_images,
std::vector<int>& testing_labels
);
/*!
ensures
- Attempts to load the MNIST dataset from the hard drive. This is the dataset
of handwritten digits available from http://yann.lecun.com/exdb/mnist/. In
particular, the 4 files comprising the MNIST dataset should be present in the
folder indicated by folder_name. These four files are:
- train-images-idx3-ubyte
- train-labels-idx1-ubyte
- t10k-images-idx3-ubyte
- t10k-labels-idx1-ubyte
- #training_images == The 60,000 training images from the dataset.
- #training_labels == The labels for the contents of #training_images.
I.e. #training_labels[i] is the label of #training_images[i].
- #testing_images == The 10,000 testing images from the dataset.
- #testing_labels == The labels for the contents of #testing_images.
I.e. #testing_labels[i] is the label of #testing_images[i].
throws
- dlib::error if some problem prevents us from loading the data or the files
can't be found.
!*/
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_MNIST_ABSTRACT_Hh_