ds_train = tf.data.experimental.make_csv_dataset(
file_pattern = "./df_profile_seq_fill_csv/*.csv",
batch_size=batch_size, column_names=use_cols, label_name='label',
select_columns= select_cols,
num_parallel_reads=30,
shuffle_buffer_size=10000)
csv에서 데이터를 읽었습니다. 여기서 label
열은 0, 1,2 ...와 같은 정수 레이블입니다.
model.fit( ds_train, validation_data=ds_test, steps_per_epoch=10000,
verbose=1,
epochs=1000000
)
label == 0
ds_train 및 ds_test에 대한 모든 샘플 을 필터링하고 싶습니다. 이것을 실현하는 방법은 무엇입니까? 감사.
이를 수행하는 한 가지 방법은 먼저 배치 1을 사용하여 csv에서 데이터 세트를 생성하는 것입니다 (일괄 처리는 필수 조치 임). 그런 다음 예제 인 "배치"를 필터링 한 다음 다시 배치합니다.
class_number_to_get_rid_of = 0
TRAIN_DATA_URL = "https://storage.googleapis.com/tf-datasets/titanic/train.csv"
train_file_path = tf.keras.utils.get_file("train.csv", TRAIN_DATA_URL)
dataset = tf.data.experimental.make_csv_dataset(train_file_path, batch_size=1)
dataset_filtered = dataset.filter(lambda p: tf.reduce_all(tf.not_equal(p['survived'], [class_number_to_get_rid_of])))
dataset_filtered = dataset_filtered.batch(5)
이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.
침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제
몇 마디 만하겠습니다