【深度学习入门:基于Python的理论与实现图书电子版及各章代码】解决手写数字识别MNIST数据集无法访问问题

仓库:

deep_learning_from_scratch_python

仓库中有各章代码和电子版图书。

问题描述:

在学习《深度学习入门:基于Python的理论与实现》高清中文版时,参考了 GitHub 仓库 https://github.com/ZhangXinNan/deep_learning_from_scratch 中的代码。

但代码需要在线访问官方的 MNIST handwritten digit database 数据集,而 yann.lecun.com 上的该数据集已无法直接下载,执行时会报出连接/HTTP 错误,所以直接运行代码也会失败。

解决方法:

找到了备份数据集(mnist 数据集镜像),从中下载了以下四个文件:train-images-idx3-ubyte.gz、train-labels-idx1-ubyte.gz、t10k-images-idx3-ubyte.gz、t10k-labels-idx1-ubyte.gz。

将四个文件放在了https://github.com/ZhangXinNan/deep_learning_from_scratch的dataset/下。

现在有了离线的数据集文件,但是https://github.com/ZhangXinNan/deep_learning_from_scratch/dataset/mnist.py文件中是在线的下载方案,mnist.py的代码如下:

# coding: utf-8
try:
    import urllib.request
except ImportError:
    raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as np


# The original yann.lecun.com host now rejects anonymous downloads (the HTTP
# error described above).  This S3 mirror hosts the same four files and is the
# mirror torchvision itself falls back to, so downloading works again.
url_base = 'https://ossci-datasets.s3.amazonaws.com/mnist/'

# Archive names for the four pieces of the dataset.
key_file = {
    'train_img':'train-images-idx3-ubyte.gz',
    'train_label':'train-labels-idx1-ubyte.gz',
    'test_img':'t10k-images-idx3-ubyte.gz',
    'test_label':'t10k-labels-idx1-ubyte.gz'
}

# Cache everything next to this file, independent of the caller's CWD.
dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"

train_num = 60000       # number of training examples
test_num = 10000        # number of test examples
img_dim = (1, 28, 28)   # (channels, height, width)
img_size = 784          # 28 * 28, flattened length of one image


def _download(file_name):
    """Download a single MNIST archive into dataset_dir unless it already exists."""
    destination = dataset_dir + "/" + file_name

    if not os.path.exists(destination):
        print("Downloading " + file_name + " ... ")
        # urlretrieve writes the remote file straight to the destination path.
        urllib.request.urlretrieve(url_base + file_name, destination)
        print("Done")
    
def download_mnist():
    """Fetch every archive listed in key_file (skips files already on disk)."""
    for archive in key_file.values():
        _download(archive)
        
def _load_label(file_name):
    """Parse a gzipped MNIST label file into a 1-D uint8 array.

    The 8-byte IDX header (magic number + item count) is skipped.
    """
    path = dataset_dir + "/" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(path, 'rb') as fh:
        raw = fh.read()
    labels = np.frombuffer(raw, np.uint8, offset=8)
    print("Done")

    return labels

def _load_img(file_name):
    """Parse a gzipped MNIST image file into an (N, 784) uint8 array.

    The 16-byte IDX header (magic number + dims) is skipped; each row of the
    result is one flattened 28x28 image.
    """
    path = dataset_dir + "/" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(path, 'rb') as fh:
        pixels = np.frombuffer(fh.read(), np.uint8, offset=16)
    reshaped = pixels.reshape(-1, img_size)
    print("Done")

    return reshaped
    
def _convert_numpy():
    """Load all four MNIST files and bundle them into one dict of NumPy arrays."""
    # Dict-literal evaluation order matches the original assignment order,
    # so the "Converting ..." messages print in the same sequence.
    return {
        'train_img': _load_img(key_file['train_img']),
        'train_label': _load_label(key_file['train_label']),
        'test_img': _load_img(key_file['test_img']),
        'test_label': _load_label(key_file['test_label']),
    }

def init_mnist():
    """Download the archives, convert them, and cache everything in mnist.pkl."""
    download_mnist()
    data = _convert_numpy()
    print("Creating pickle file ...")
    with open(save_file, 'wb') as out:
        # Protocol -1 selects the highest pickle protocol available.
        pickle.dump(data, out, -1)
    print("Done!")

def _change_one_hot_label(X):
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1
        
    return T
    

def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """Load the MNIST dataset, creating the pickle cache on first use.

    Parameters
    ----------
    normalize : if True, scale pixel values to float32 in [0.0, 1.0]
    flatten : if True, keep each image as a flat 784-vector;
        otherwise reshape images to (1, 28, 28)
    one_hot_label : if True, return labels as one-hot rows such as
        [0,0,1,0,0,0,0,0,0,0]

    Returns
    -------
    (train_images, train_labels), (test_images, test_labels)
    """
    if not os.path.exists(save_file):
        init_mnist()

    with open(save_file, 'rb') as cache:
        dataset = pickle.load(cache)

    image_keys = ('train_img', 'test_img')

    if normalize:
        for key in image_keys:
            dataset[key] = dataset[key].astype(np.float32) / 255.0

    if one_hot_label:
        for key in ('train_label', 'test_label'):
            dataset[key] = _change_one_hot_label(dataset[key])

    if not flatten:
        for key in image_keys:
            dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

    train = (dataset['train_img'], dataset['train_label'])
    test = (dataset['test_img'], dataset['test_label'])
    return train, test


# Build (or rebuild) the mnist.pkl cache when this file is run as a script.
if __name__ == '__main__':
    init_mnist()

修改mnist.py中的代码:

# coding: utf-8
import os.path
import gzip
import pickle
import os
import numpy as np

# Resolve the dataset directory relative to this file, not the current working
# directory: a CWD-relative './dataset' breaks as soon as the module is
# imported from another directory (e.g. running ch03/mnist_show.py from ch03/).
dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = os.path.join(dataset_dir, "mnist.pkl")

# Local gzip archive names for the four pieces of the dataset.
key_file = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'test_img': 't10k-images-idx3-ubyte.gz',
    'test_label': 't10k-labels-idx1-ubyte.gz'
}

img_size = 784  # each image is 28x28 = 784 pixels, stored flattened

def _load_label(file_name):
    """Read a gzipped MNIST label file and return its labels as a uint8 vector.

    Skips the 8-byte IDX header (magic number + item count).
    """
    full_path = os.path.join(dataset_dir, file_name)
    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(full_path, 'rb') as gz:
        content = gz.read()
    result = np.frombuffer(content, np.uint8, offset=8)
    print("Done")
    return result

def _load_img(file_name):
    """Read a gzipped MNIST image file and return an (N, 784) uint8 matrix.

    Skips the 16-byte IDX header; each row is one flattened 28x28 image.
    """
    full_path = os.path.join(dataset_dir, file_name)
    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(full_path, 'rb') as gz:
        flat = np.frombuffer(gz.read(), np.uint8, offset=16)
    reshaped = flat.reshape(-1, img_size)
    print("Done")
    return reshaped

def _convert_numpy():
    """Read all four local MNIST files into a single dict of NumPy arrays."""
    dataset = {}
    # Keep the train/test load order so console messages match the original.
    dataset['train_img'] = _load_img(key_file['train_img'])
    dataset['train_label'] = _load_label(key_file['train_label'])
    dataset['test_img'] = _load_img(key_file['test_img'])
    dataset['test_label'] = _load_label(key_file['test_label'])
    return dataset

def _change_one_hot_label(X):
    """Convert labels to one-hot encoding"""
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1
    return T

def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """Load MNIST from the local pickle cache, building the cache on first use.

    Parameters
    ----------
    normalize : if True, convert pixels to float32 scaled into [0.0, 1.0]
    flatten : if True, images stay flat (784,); otherwise (1, 28, 28)
    one_hot_label : if True, labels are returned as one-hot rows

    Returns
    -------
    (train_images, train_labels), (test_images, test_labels)
    """
    if not os.path.exists(save_file):
        init_mnist()

    with open(save_file, 'rb') as cache:
        dataset = pickle.load(cache)

    image_keys = ('train_img', 'test_img')

    if normalize:
        for key in image_keys:
            dataset[key] = dataset[key].astype(np.float32) / 255.0

    if one_hot_label:
        for key in ('train_label', 'test_label'):
            dataset[key] = _change_one_hot_label(dataset[key])

    if not flatten:
        for key in image_keys:
            dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

    train = (dataset['train_img'], dataset['train_label'])
    test = (dataset['test_img'], dataset['test_label'])
    return train, test

def init_mnist():
    """Build the mnist.pkl cache from the local gzip files (no downloading)."""
    converted = _convert_numpy()
    print("Creating pickle file ...")
    with open(save_file, 'wb') as out:
        # Protocol -1 selects the highest pickle protocol available.
        pickle.dump(converted, out, -1)
    print("Done!")

# Running this module directly (re)builds the mnist.pkl cache from the
# local gzip archives.
if __name__ == '__main__':
    init_mnist()

为了测试修改后的代码能否对mnist数据集进行正确的处理,运行dataset/mnist.py

PS D:\Workspace\LearningDeepLearning\deep_learning_from_scratch\dataset> python .\mnist.py 
Converting train-images-idx3-ubyte.gz to NumPy Array ...
Done
Converting train-labels-idx1-ubyte.gz to NumPy Array ...
Done
Converting t10k-images-idx3-ubyte.gz to NumPy Array ...
Done
Converting t10k-labels-idx1-ubyte.gz to NumPy Array ...
Done
Creating pickle file ...
Done!

运行成功!

为了进一步验证,运行其他代码中需要访问dataset并调用dataset/mnist.py中函数的程序,例如,运行ch03/mnist_show.py

PS D:\Workspace\LearningDeepLearning\deep_learning_from_scratch\ch03> python .\mnist_show.py
5
(784,)
(28, 28)

运行成功!并显示了MNIST数据集第一张照片:

Reference:

books/ai/《深度学习入门:基于Python的理论与实现》高清中文版.pdf at master · chapin666/books · GitHub

GitHub – ZhangXinNan/deep_learning_from_scratch: 《深度学习入门——基于Python的理论与实现》作者:斋藤康毅 译者:陆宇杰

MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

GitHub – cvdfoundation/mnist: The MNIST database of handwritten digits is one of the most popular image recognition datasets. It contains 60k examples for training and 10k examples for testing.

作者:RL^2

物联沃分享整理
物联沃-IOTWORD物联网 » 【深度学习入门:基于Python的理论与实现图书电子版及各章代码】解决手写数字识别MNIST数据集无法访问问题

发表回复