我只有一个3D数值数据文件,该文件是按块读取的(因为按块读取比单个索引快)。例如,说“文件”中有一个MxNx30数组,我将创建一个RDD,如下所示:
def read(ind):
f = customFileOpener(file)
return f['data'][:,:,ind[0]:ind[-1]+1]
indices = [[0,9],[10,19],[20,29]]
rdd = sc.parallelize(indices,3).map(lambda v:read(v))
rdd.count()
因此,这3个分区中的每个分区都有一个大小为MxNx10的numpy.ndarray元素。
现在,我想拆分每个元素,以便在每个分区中有10个元素,每个元素都是一个MxN数组。我为此目的尝试使用flatMap(),但收到“ NoneType对象不可迭代”的错误:
def splitArr(arr):
Nmid = arr.shape[-1]
out = []
for i in range(0,Nmid):
out.append(arr[...,i])
return out
rdd2 = rdd.flatMap(lambda v: splitArr(v))
rdd2.count()
正确的方法是什么?关键点是(a)我需要从文件中分块读取数据,并且(b)拆分数据,使元素的大小为MxN(最好保持分区结构)。
据我了解您的描述,这样的事情应该可以解决:
rdd.flatMap(lambda arr: (x for x in np.rollaxis(arr, 2)))
或者,如果您希望使用单独的功能:
def splitArr(arr):
for x in np.rollaxis(arr, 2):
yield x
rdd.flatMap(splitArr)
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句