批量读取Cifar10数据集

解码

我正在尝试读取CIFAR10数据集,该数据集是从https://www.cs.toronto.edu/~kriz/cifar.html >分批给出的我正在尝试使用泡菜将其放在数据框中,并读取其中的“数据”部分。但是我遇到了这个错误。

KeyError                                  Traceback (most recent call last)
<ipython-input-24-8758b7a31925> in <module>()
----> 1 unpickle('datasets/cifar-10-batches-py/test_batch')

<ipython-input-23-04002b89d842> in unpickle(file)
      3     fo = open(file, 'rb')
      4     dict = pickle.load(fo, encoding ='bytes')
----> 5     X = dict['data']
      6     fo.close()
      7     return dict

KeyError:“数据”。

我正在使用ipython,这是我的代码:

def unpickle(file):

 fo = open(file, 'rb')
 dict = pickle.load(fo, encoding ='bytes')
 X = dict['data']
 fo.close()
 return dict

unpickle('datasets/cifar-10-batches-py/test_batch')
索海卜·安瓦尔(Sohaib Anwaar)

您可以通过下面给出的代码读取cifar 10数据集,但请确保已给批处理所在的写入目录

import tensorflow as tf
import pandas as pd
import numpy as np
import math
import timeit
import matplotlib.pyplot as plt
from six.moves import cPickle as pickle
import os
import platform
from subprocess import check_output
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

%matplotlib inline


img_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)
def load_pickle(f):
    version = platform.python_version_tuple()
    if version[0] == '2':
        return  pickle.load(f)
    elif version[0] == '3':
        return  pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))

def load_CIFAR_batch(filename):
    """ load single batch of cifar """
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000,3072)
        Y = np.array(Y)
        return X, Y

def load_CIFAR10(ROOT):
    """ load all of cifar """
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000):
    # Load the raw CIFAR-10 data
    cifar10_dir = '../input/cifar-10-batches-py/'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # Subsample the data
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]

    x_train = X_train.astype('float32')
    x_test = X_test.astype('float32')

    x_train /= 255
    x_test /= 255

    return x_train, y_train, X_val, y_val, x_test, y_test


# Invoke the above function to get our data.
x_train, y_train, x_val, y_val, x_test, y_test = get_CIFAR10_data()


print('Train data shape: ', x_train.shape)
print('Train labels shape: ', y_train.shape)
print('Validation data shape: ', x_val.shape)
print('Validation labels shape: ', y_val.shape)
print('Test data shape: ', x_test.shape)
print('Test labels shape: ', y_test.shape)

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

Cifar10数据集:从类中读取一定数量的图像

来自分类Dev

在PyTorch中使用预训练的ResNet50解决CIFAR10数据集的问题

来自分类Dev

尝试在计算机上使用“ torchvision.datasets”下载CIFAR10数据集

来自分类Dev

使用PyTorch将Traininng数据集分为CIFAR10的训练和验证集后,如何增加数据?

来自分类Dev

我使用CNN模型创建了CIFAR10数据集学习模型。为什么会出现错误?

来自分类Dev

加载CIFAR10训练数据时出现内存错误

来自分类Dev

CIFAR10数据加载器采样器拆分

来自分类Dev

导入自己的数据,例如MNIST或CIFAR10 load_data()

来自分类Dev

尝试使用我自己的数据而不是 cifar10 生成 Keras 模型

来自分类Dev

如何创建类似于cifar-10的数据集

来自分类Dev

在F#中导入CIFAR-10数据集

来自分类Dev

tensorflow cifar10从检查点文件恢复训练

来自分类Dev

在Cifar10上更改TensorFlow中的线程数

来自分类Dev

如何测试Tensorflow Cifar10 CNN教程模型

来自分类Dev

使用Tensrorflow和cifar10进行深度学习

来自分类Dev

目标的Tensorflow形状不匹配(cifar10)

来自分类Dev

Tensorflow CIFAR10代码分析

来自分类Dev

如何提高cifar-100数据集的准确性?我目前的准确度是10%

来自分类Dev

使用官方网站上的CIFAR-10数据集进行梯度爆炸

来自分类Dev

张量流中的多GPU CIFAR10示例:汇总损失

来自分类Dev

Cifar10中的粗标签和细标签是什么?

来自分类Dev

从具有大量数据的 SQL Server Sproc 批量读取结果集?

来自分类Dev

如何使用pytorch在cifar10或stl10中加载一种类型的图像

来自分类Dev

如何用我自己的图像喂入Cifar10训练过的模型并获得标签作为输出?

来自分类Dev

PyTorch-将CIFAR数据集转换为`TensorDataset`

来自分类Dev

如何重塑输入以在 CIFAR 数据集上进行训练?

来自分类Dev

oracle批量收集和读取数据

来自分类Dev

如何使用jtable读取批量数据?

来自分类Dev

如果CNN模型同时在keras的cifar10 / 100上训练,如何在一个图中绘制精度/损耗图?

Related 相关文章

  1. 1

    Cifar10数据集:从类中读取一定数量的图像

  2. 2

    在PyTorch中使用预训练的ResNet50解决CIFAR10数据集的问题

  3. 3

    尝试在计算机上使用“ torchvision.datasets”下载CIFAR10数据集

  4. 4

    使用PyTorch将Traininng数据集分为CIFAR10的训练和验证集后,如何增加数据?

  5. 5

    我使用CNN模型创建了CIFAR10数据集学习模型。为什么会出现错误?

  6. 6

    加载CIFAR10训练数据时出现内存错误

  7. 7

    CIFAR10数据加载器采样器拆分

  8. 8

    导入自己的数据,例如MNIST或CIFAR10 load_data()

  9. 9

    尝试使用我自己的数据而不是 cifar10 生成 Keras 模型

  10. 10

    如何创建类似于cifar-10的数据集

  11. 11

    在F#中导入CIFAR-10数据集

  12. 12

    tensorflow cifar10从检查点文件恢复训练

  13. 13

    在Cifar10上更改TensorFlow中的线程数

  14. 14

    如何测试Tensorflow Cifar10 CNN教程模型

  15. 15

    使用Tensrorflow和cifar10进行深度学习

  16. 16

    目标的Tensorflow形状不匹配(cifar10)

  17. 17

    Tensorflow CIFAR10代码分析

  18. 18

    如何提高cifar-100数据集的准确性?我目前的准确度是10%

  19. 19

    使用官方网站上的CIFAR-10数据集进行梯度爆炸

  20. 20

    张量流中的多GPU CIFAR10示例:汇总损失

  21. 21

    Cifar10中的粗标签和细标签是什么?

  22. 22

    从具有大量数据的 SQL Server Sproc 批量读取结果集?

  23. 23

    如何使用pytorch在cifar10或stl10中加载一种类型的图像

  24. 24

    如何用我自己的图像喂入Cifar10训练过的模型并获得标签作为输出?

  25. 25

    PyTorch-将CIFAR数据集转换为`TensorDataset`

  26. 26

    如何重塑输入以在 CIFAR 数据集上进行训练?

  27. 27

    oracle批量收集和读取数据

  28. 28

    如何使用jtable读取批量数据?

  29. 29

    如果CNN模型同时在keras的cifar10 / 100上训练,如何在一个图中绘制精度/损耗图?

热门标签

归档