
  • Deep Learning with Theano
  • Christopher Bourez

The MNIST dataset

The Modified National Institute of Standards and Technology (MNIST) dataset is a very well-known dataset of handwritten digits {0,1,2,3,4,5,6,7,8,9} used to train and test classification models.

A classification model is a model that predicts the probabilities of observing a class, given an input.
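As a toy illustration of what "predicting the probabilities of observing a class" means, the sketch below turns a vector of ten raw scores (one per digit) into a probability distribution with a softmax. The scores here are made up for the example; they are not from the book's model.

```python
import numpy as np

def softmax(scores):
    # Subtract the max before exponentiating for numerical stability
    e = np.exp(scores - np.max(scores))
    return e / e.sum()

# Hypothetical raw scores for the 10 digit classes {0,...,9}
scores = np.array([1.0, 2.0, 0.5, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
probs = softmax(scores)

# The probabilities sum to 1; the predicted class is the most probable one
print(probs.sum(), probs.argmax())
```

The model's prediction for an input is then the class with the highest probability.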

Training is the task of learning the parameters to fit the model to the data as well as possible, so that the correct label is predicted for any input image. For this training task, the MNIST dataset contains 60,000 images with a target label (a number between 0 and 9) for each example.

To validate that the training is efficient and to decide when to stop the training, we usually split the training dataset into two datasets: 80% to 90% of the images are used for training, while the remaining 10-20% of images will not be presented to the algorithm for training but to validate that the model generalizes well on unobserved data.
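A minimal sketch of such a split, assuming we shuffle the indices before cutting them (the 85% ratio below is an arbitrary choice inside the 80-90% range mentioned above):

```python
import numpy as np

rng = np.random.RandomState(0)  # fixed seed so the split is reproducible
n = 60000                       # total labeled examples, as in MNIST
indices = rng.permutation(n)    # shuffle before splitting

n_train = int(0.85 * n)         # e.g. 85% for training, 15% for validation
train_idx = indices[:n_train]
valid_idx = indices[n_train:]
```

The validation examples are never shown to the training algorithm; monitoring the error on them tells us when the model stops generalizing and training should stop.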

There is a separate dataset that the algorithm should never see during training, named the test set, which consists of 10,000 images in the MNIST dataset.

In the MNIST dataset, each example consists of a 28x28 normalized monochrome image as input and a label, a simple integer between 0 and 9, as target. Let's display some of them:

  1. First, download a pre-packaged version of the dataset that makes it easier to load from Python:
    wget http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz -P /sharedfiles
  2. Then load the data into a Python session:
    import pickle, gzip
    with gzip.open("/sharedfiles/mnist.pkl.gz", 'rb') as f:
        train_set, valid_set, test_set = pickle.load(f)

    For Python 3, we need pickle.load(f, encoding='latin1'), because the file was serialized with Python 2.

    train_set[0].shape
    (50000, 784)

    train_set[1].shape
    (50000,)

    Note that train_set holds 50,000 of the 60,000 labeled images; the remaining 10,000 were set aside as valid_set, following the split described above.
    
    import numpy
    import matplotlib.pyplot as plt

    plt.rcParams['figure.figsize'] = (10, 10)
    plt.rcParams['image.cmap'] = 'gray'

    # Display the first nine images in a row, with the ground-truth label on top
    for i in range(9):
        plt.subplot(1, 9, i + 1)
        plt.imshow(train_set[0][i].reshape(28, 28))
        plt.axis('off')
        plt.title(str(train_set[1][i]))

    plt.show()

The first nine samples from the dataset are displayed with the corresponding label (the ground truth, that is, the correct answer expected by the classification algorithm) on top of them.

In order to avoid too many transfers to the GPU, and since the complete dataset is small enough to fit in the memory of the GPU, we usually place the full training set in shared variables:

import theano
train_set_x = theano.shared(numpy.asarray(train_set[0], dtype=theano.config.floatX))
train_set_y = theano.shared(numpy.asarray(train_set[1], dtype='int32'))

Even with recent GPUs and fast PCIe connections, per-batch host-to-device copies remain costly; keeping the dataset resident in GPU memory avoids these transfers and lets us train faster.
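With the full training set resident on the GPU, each training step only needs to select a minibatch slice by index. The sketch below shows the slicing arithmetic in plain NumPy (the array stands in for the shared variable; in Theano the same slice would be taken on the shared variable inside the compiled function, typically via a `givens` substitution, so no copy happens per batch). The batch size of 128 is an arbitrary example, not a value from the book.

```python
import numpy as np

batch_size = 128
# Stand-in for train_set_x: 50,000 flattened 28x28 images
train_x = np.zeros((50000, 784), dtype='float32')

n_batches = train_x.shape[0] // batch_size  # number of full minibatches

def minibatch(index):
    # Select minibatch `index` by slicing; on the GPU this selection
    # happens in device memory, with no host-to-device transfer
    return train_x[index * batch_size:(index + 1) * batch_size]
```

An epoch then iterates `index` from 0 to `n_batches - 1`, calling the compiled training function once per minibatch.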

More information on the dataset is available at http://yann.lecun.com/exdb/mnist/.
