Keras에서 피드 포워드 NN의 마지막 단계에 변수를 포함하려고합니다. 단 하나가 아닌 2 개의 열을 포함 할 때만 작동 할 수있는 것 같습니다. 내 코드 예제는 다음과 같습니다.
먼저 기본 입력 데이터 세트를 준비합니다.
import pandas as pd
from keras.models import Model
from keras.layers import Dense, Input, Concatenate
from keras.optimizers import Adam
iris = pd.read_csv('https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv')
iris.loc[:, 'target'] = (iris.species == "setosa").map(int)
train_x = iris.drop(columns=['target', 'species'])
train_y = iris['target'].map(int)
그런 다음 train_x
두 개의 개별 데이터 프레임으로 분리하고 다른 위치의 네트워크에 입력합니다.
feature_x = train_x.drop(columns='petal_width')
single_feature_x = train_x[['petal_width']]
input_x = Input(shape=feature_x.shape, name='feature_input')
single_input_x = Input(shape=single_feature_x.shape, name='single_input')
x = Dense(4, activation='relu')(input_x)
concat_feat = Concatenate(axis=-1, name='concat_fc')([x, single_input_x])
outputs = Dense(1, activation='sigmoid')(concat_feat)
model = Model(inputs=[input_x, single_input_x], outputs=outputs)
model.compile(loss='binary_crossentropy',
optimizer=Adam(lr=0.001))
model.fit({'feature_input': feature_x,
'single_input': single_feature_x},
train_y,
epochs=100,
batch_size=512,
verbose=1)
이로 인해 오류가 발생합니다.
ValueError: Shape must be rank 2 but is rank 3 for '{{node model_5/concat_fc/concat}} = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32](model_5/dense_10/Relu, model_5/Cast_1, model_5/concat_fc/concat/axis)' with input shapes: [?,4], [?,1,1], [].
그러나이 한 줄을 추가하면 잘 실행됩니다.
feature_x = train_x.drop(columns='petal_width')
single_feature_x = train_x[['petal_width']]
# Add a constant column so the shape becomes (?,2)
single_feature_x.loc[:, 'constant'] = 0
두 개의 열에서는 작동하지만 하나에서는 작동하지 않는 이유는 무엇입니까?
입력 형태를 정확하게 지정하기 만하면됩니다. 2D 데이터의 경우 희미한 특성 만 전달하면됩니다. 샘플 치수는 필요하지 않습니다. 입력을 다음과 같이 수정하기 만하면됩니다.
input_x = Input(shape=feature_x.shape[1], name='feature_input')
single_input_x = Input(shape=single_feature_x.shape[1], name='single_input')
이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.
침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제
몇 마디 만하겠습니다