我想在预处理阶段使用“地图”并行旋转图像。
问题在于,每个图像都沿相同方向旋转(在生成一个随机数之后)。但是我希望每个图像都有不同的旋转度。
这是我的代码:
import tensorflow_addons as tfa
import math
import random
def rotate_tensor(image, label):
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
我试图在每次调用该函数时更改种子:
import tensorflow_addons as tfa
import math
import random
seed_num = 0
def rotate_tensor(image, label):
seed_num += 1
random.seed(seed_num)
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
但是我得到:
UnboundLocalError: local variable 'seed_num' referenced before assignment
我使用的是tf2,但我认为这没什么大不了的(除了旋转图像的代码)。
编辑:我尝试了@Mehraban建议,但似乎只有一次调用了rotation_tensor函数:
import tensorflow_addons as tfa
import math
import random
num_seed = 1
def rotate_tensor(image, label):
global num_seed
num_seed += 1
print(num_seed) #<---- print num_seed
random.seed(num_seed)
degree = random.random()*360
image = tfa.image.rotate(image, degree * math.pi / 180, interpolation='BILINEAR')
return image, label
rotated_test_set = rps_test_raw.map(rotate_tensor).batch(batch_size).prefetch(1)
但是它只打印一次“ 2”。所以我认为rotate_tensor被调用了一次。
编辑2-这是显示旋转图像的功能:
plt.figure(figsize=(12, 10))
for X_batch, y_batch in rotated_test_set.take(1):
for index in range(9):
plt.subplot(3, 3, index + 1)
plt.imshow(X_batch[index])
plt.title("Predict: {} | Actual: {}".format(class_names[y_test_proba_max_index[index]], class_names[y_batch[index]]))
plt.axis("off")
plt.show()
问题在于如何生成随机数。random
尽管tf.random
在处理tensorflow时应该使用模块,但您依赖模块。
这是当您从tf获得随机数时事物如何变化的演示:
import tensorflow as tf
import random
def gen():
for i in range(10):
yield [1.]
ds = tf.data.Dataset.from_generator(gen, (float))
def m1(d):
return d*random.random()
def m2(d):
return d*tf.random.normal([])
[d for d in ds.map(m2)]
[0.17368042,
1.5629852,
1.2372143,
1.8170034,
1.7040217,
-0.16738933,
-0.11567844,
-0.17949782,
-0.67811996,
-0.5391556]
[d for d in ds.map(m1)]
[0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798,
0.8369798]
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句