Monday 9 November 2020

Problem with KerasRegressor & multiple output

My model has 3 inputs and 3 outputs. I am trying to use KerasRegressor together with scikit-learn's cross_val_score to get a cross-validated prediction score.

My code is:

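import numpy as np
from sklearn.model_selection import KFold, cross_val_score
# Assumes an older tf.keras release that still ships the scikit-learn wrappers
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor

# Function to create the model, required by KerasRegressor
def create_model():

    # Start by defining the input tensor (3 features: AOA, x, y)
    input_data = layers.Input(shape=(3,))

    # Create the hidden layers and pass them the input tensor to get the output tensor
    layer = [2, 2]
    hidden1Out = Dense(units=layer[0], activation='relu')(input_data)
    finalOut = Dense(units=layer[1], activation='relu')(hidden1Out)

    # One linear output head per target variable
    u_out = Dense(1, activation='linear', name='u')(finalOut)
    v_out = Dense(1, activation='linear', name='v')(finalOut)
    p_out = Dense(1, activation='linear', name='p')(finalOut)

    # Define the model's start and end points
    model = Model(inputs=input_data, outputs=[u_out, v_out, p_out])

    model.compile(loss='mean_squared_error', optimizer='adam')

    return model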

#load data
...

input_var = np.vstack((AOA, x, y)).T
output_var = np.vstack((u,v,p)).T

# evaluate model
estimator = KerasRegressor(build_fn=create_model, epochs=num_epochs, batch_size=batch_size, verbose=0)
kfold = KFold(n_splits=10)
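To make the shapes involved concrete, here is a small numpy sketch. The random arrays are hypothetical stand-ins for the elided loading step, and the sample count of 5 is arbitrary. Both input_var and output_var come out as (n_samples, 3) arrays, whereas a functional Model with three output heads expects its targets as a list of three arrays with n_samples rows each.

```python
import numpy as np

# Hypothetical stand-ins for the loaded data (the real arrays come from "#load data")
n_samples = 5
rng = np.random.default_rng(0)
AOA, x, y = (rng.random(n_samples) for _ in range(3))
u, v, p = (rng.random(n_samples) for _ in range(3))

input_var = np.vstack((AOA, x, y)).T   # shape (n_samples, 3): one row per sample
output_var = np.vstack((u, v, p)).T    # shape (n_samples, 3): columns are u, v, p

# A model with three output heads wants one target per head,
# each with n_samples rows, passed as a list
targets = [output_var[:, i:i + 1] for i in range(output_var.shape[1])]

print(input_var.shape, output_var.shape)  # (5, 3) (5, 3)
print([t.shape for t in targets])         # [(5, 1), (5, 1), (5, 1)]
```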

I tried:

results = cross_val_score(estimator, input_var, [output_var[:,0], output_var[:,1], output_var[:,2]], cv=kfold)

and

results = cross_val_score(estimator, input_var, [output_var[:,0:1], output_var[:,1:2], output_var[:,2:3]], cv=kfold)

and

results = cross_val_score(estimator, input_var, output_var, cv=kfold)

I got error messages like:

Details: ValueError: Error when checking model target: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 3 array(s), but instead got the following list of 1 arrays: [array([[ 0.69945297, 0.13296847, 0.06292328],

or

ValueError: Found input variables with inconsistent numbers of samples: [72963, 3]
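The second error appears to come from scikit-learn's sample-count check, which cross_val_score runs on X and y before any splitting. It compares the first dimension of each. A plain Python list of three column arrays has length 3, so it is counted as 3 samples. The third attempt passes the 2-D output_var, which scikit-learn can split by rows; presumably the KerasRegressor wrapper then hands that single array to a model with three outputs, which matches the first error. The shape side can be reproduced with numpy alone, using zero-filled placeholders and the sample count from the error message:

```python
import numpy as np

n_samples = 72963  # sample count taken from the error message
input_var = np.zeros((n_samples, 3))   # placeholder for the real inputs
output_var = np.zeros((n_samples, 3))  # placeholder for the real outputs (u, v, p)

# Second attempt: a plain Python list of three (n_samples, 1) column slices
y_list = [output_var[:, 0:1], output_var[:, 1:2], output_var[:, 2:3]]

# scikit-learn counts samples as the length of the first dimension
print(len(input_var), len(y_list))      # 72963 3

# Passing output_var itself keeps the sample counts consistent
print(len(input_var), len(output_var))  # 72963 72963
```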

So how do I solve this problem?

Thanks.
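One possible workaround, sketched under assumptions rather than as a confirmed fix: wrap the model in a small scikit-learn-compatible estimator that accepts the 2-D output_var, lets cross_val_score split it by rows, and only inside fit splits it into one target per output head. The class name MultiOutputKerasRegressor is hypothetical (it is not part of Keras or scikit-learn), and the sketch assumes build_fn returns a compiled multi-output Keras model such as create_model.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin


class MultiOutputKerasRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical scikit-learn wrapper for a Keras model with several output heads.

    cross_val_score slices X and a 2-D y row-wise for each fold; fit() then splits
    y column-wise into one target per output head before training.
    """

    def __init__(self, build_fn=None, epochs=1, batch_size=32, verbose=0):
        # scikit-learn convention: store constructor arguments unchanged
        self.build_fn = build_fn
        self.epochs = epochs
        self.batch_size = batch_size
        self.verbose = verbose

    def fit(self, X, y):
        y = np.asarray(y)
        # Build a fresh, untrained model on every fit (i.e. every CV fold)
        self.model_ = self.build_fn()
        # One (n_samples, 1) target per output head, in column order (u, v, p)
        targets = [y[:, i:i + 1] for i in range(y.shape[1])]
        self.model_.fit(X, targets, epochs=self.epochs,
                        batch_size=self.batch_size, verbose=self.verbose)
        return self

    def predict(self, X):
        # A multi-output Keras model returns a list of (n_samples, 1) arrays;
        # stack them back into a single (n_samples, n_outputs) array
        return np.hstack(self.model_.predict(X))
```

Usage would then follow the original setup: estimator = MultiOutputKerasRegressor(build_fn=create_model, epochs=num_epochs, batch_size=batch_size), then results = cross_val_score(estimator, input_var, output_var, cv=kfold). Because the sketch inherits RegressorMixin, the default score is R² averaged over the three outputs; passing scoring='neg_mean_squared_error' to cross_val_score scores by mean squared error instead.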


