Tuesday, 15 December 2020

LSTM categorical crossentropy validation accuracy remains constant

I've been working on an LSTM network lately, and I'd like to predict a one-hot encoded output (2 classes for now; I already tried binary cross-entropy and the problem stays the same) from 5 features and 14 lookback steps (batch size is 1). I am using TimeseriesGenerator:

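# assuming tf.keras (TensorFlow 2.x)
from tensorflow.keras.preprocessing.sequence import TimeseriesGenerator

LOOKBACK = 14   # 14 lookback steps per window
BATCH_SIZE = 1

generator = TimeseriesGenerator(scaled_train, train_targets, length=LOOKBACK, batch_size=BATCH_SIZE)
validation_generator = TimeseriesGenerator(scaled_test, test_targets, length=LOOKBACK, batch_size=BATCH_SIZE)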
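As far as I understand the Keras documentation, TimeseriesGenerator (with the defaults stride=1, sampling_rate=1, start_index=0) pairs each window of LOOKBACK rows with the target at the step immediately after the window, so the first LOOKBACK targets are never used. The NumPy sketch below mimics that indexing for a quick sanity check. `make_windows` is a hypothetical helper, not part of Keras, and it ignores batching, shuffling and the other generator options.

```python
import numpy as np


def make_windows(data, targets, length):
    """Mimic TimeseriesGenerator windowing (stride=1, sampling_rate=1, start_index=0).

    Sample j uses data[j : j + length] as input and targets[j + length]
    (the step right after the window) as its label.
    """
    n_samples = len(data) - length
    x = np.stack([data[j:j + length] for j in range(n_samples)])
    y = np.stack([targets[j + length] for j in range(n_samples)])
    return x, y


# Toy check: 20 time steps, 5 features, one-hot targets for 2 alternating classes.
data = np.arange(100, dtype=np.float32).reshape(20, 5)
targets = np.eye(2, dtype=np.float32)[np.arange(20) % 2]
x, y = make_windows(data, targets, length=14)
print(x.shape, y.shape)  # (6, 14, 5) (6, 2)
```

If the labels in train_targets were built to describe the last row of each window rather than the next step, they would need shifting by one; which alignment is intended depends on how the targets were created, which the post doesn't say.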

The train targets look like this (their class distribution is shown below):

array([[1., 0.],
       [0., 1.],
       [1., 0.],
       ...,
       [1., 0.],
       [1., 0.],
       [1., 0.]], dtype=float32)

0    1  
0.0  1.0    1619
1.0  0.0    1545
dtype: int64
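With these counts, a model that always predicts the more frequent class would already reach about 1619 / (1619 + 1545) ≈ 0.512 accuracy on the training set. Below is a minimal sketch for computing that baseline from one-hot targets (the helper name is hypothetical). If the reported accuracy sits near this kind of value and never moves, that is consistent with the network predicting the same class, or near-uniform probabilities, for every sample. The validation split may have a different class balance, so its baseline should be computed separately.

```python
import numpy as np


def majority_baseline(one_hot_targets):
    """Accuracy obtained by always predicting the most frequent class."""
    counts = one_hot_targets.sum(axis=0)  # number of samples per class column
    return counts.max() / counts.sum()


# Rebuild one-hot targets with the counts from the post: 1545 x [1, 0], 1619 x [0, 1].
labels = np.array([0] * 1545 + [1] * 1619)
one_hot = np.eye(2)[labels]
print(round(float(majority_baseline(one_hot)), 4))  # 0.5117
```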

and the features (scaled_train) look like this:

array([[0.22189629, 0.21121072, 0.21790398, 0.19933957, 0.41803716],
       [0.2106806 , 0.21783771, 0.2197984 , 0.2237905 , 0.18050205],
       [0.21885786, 0.21532436, 0.21933581, 0.19678948, 0.16397564],
       ...,
       [0.2104257 , 0.20155003, 0.22193512, 0.20173967, 0.12319585],
       [0.2070304 , 0.19911522, 0.21276043, 0.19141927, 0.11876491],
       [0.19873079, 0.18909128, 0.20785918, 0.19083925, 0.1407046 ]])
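The post doesn't say how scaled_train was produced. Assuming a min-max scaling, one thing worth checking is that the per-column min and max are computed on the training features only and then reused unchanged for the test features (in scikit-learn terms, fitting MinMaxScaler on the training array only). The NumPy sketch below illustrates that; `fit_min_max` and `apply_min_max` are hypothetical helpers, and the random arrays stand in for the real features.

```python
import numpy as np


def fit_min_max(train):
    """Per-column min and range, computed on the training features only."""
    col_min = train.min(axis=0)
    col_range = train.max(axis=0) - col_min
    col_range[col_range == 0] = 1.0  # avoid division by zero for constant columns
    return col_min, col_range


def apply_min_max(x, col_min, col_range):
    """Scale any split with the statistics fitted on the training split."""
    return (x - col_min) / col_range


rng = np.random.default_rng(0)
raw_train = rng.normal(size=(100, 5))  # stand-in for 5 training features
raw_test = rng.normal(size=(30, 5))    # stand-in for 5 test features
col_min, col_range = fit_min_max(raw_train)
scaled_train_demo = apply_min_max(raw_train, col_min, col_range)
scaled_test_demo = apply_min_max(raw_test, col_min, col_range)
print(scaled_train_demo.min(axis=0), scaled_train_demo.max(axis=0))  # per-column min 0.0, max 1.0
```

Test values can fall slightly outside [0, 1] with this approach; that is expected and avoids leaking test statistics into training.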

Test data are formatted the same way.

The model is also super simple:

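# assuming tf.keras (TensorFlow 2.x)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

# 100 LSTM units over windows of shape (LOOKBACK, n_features), then a 2-class softmax
model = Sequential()

model.add(LSTM(100, input_shape=(LOOKBACK, scaled_train.shape[1])))
model.add(Dense(2, activation='softmax'))

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])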
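For two classes, a 2-unit softmax with categorical_crossentropy and a 1-unit sigmoid with binary_crossentropy compute the same loss (softmax over [z0, z1] gives class 1 the probability sigmoid(z1 - z0)), which would explain why switching between them changed nothing. A related diagnostic: a model that outputs [0.5, 0.5] for every sample has a loss of ln 2 ≈ 0.693 whatever the labels are, so a loss stuck near 0.693 together with flat accuracy suggests the network has collapsed to near-uniform predictions. The NumPy sketch below checks both facts with hand-written loss functions; these are not the Keras implementations, which additionally clip probabilities.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))  # subtract row max for numerical stability
    return e / e.sum(axis=1, keepdims=True)


def categorical_crossentropy(y_true, probs):
    """Mean over samples of -sum_k y_k * log(p_k)."""
    return float(np.mean(-np.sum(y_true * np.log(probs), axis=1)))


def binary_crossentropy(y, p):
    """Mean over samples of -(y * log(p) + (1 - y) * log(1 - p))."""
    return float(np.mean(-(y * np.log(p) + (1 - y) * np.log(1 - p))))


rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 2))
labels = rng.integers(0, 2, size=8)
y_onehot = np.eye(2)[labels]

cce = categorical_crossentropy(y_onehot, softmax(logits))
# The sigmoid of the logit difference equals the softmax probability of class 1.
p1 = 1.0 / (1.0 + np.exp(-(logits[:, 1] - logits[:, 0])))
bce = binary_crossentropy(labels, p1)
print(np.isclose(cce, bce))  # True

# Constant [0.5, 0.5] predictions give ln 2 regardless of the labels.
uniform = np.full((8, 2), 0.5)
print(round(categorical_crossentropy(y_onehot, uniform), 4))  # 0.6931
```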

Here is the fit output:

[Image: fit output]

Basically, when I make predictions, the output is > 0.5 about half of the time and < 0.5 the other half. The fit metrics stay essentially constant from epoch to epoch. I also tried making the network more complex with multiple LSTM layers, without success, and I tried loading a different dataset just to check, but nothing changed. I may be doing something wrong with my generators.
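To tell whether the roughly 50/50 predictions come from a collapsed model or from comparing predictions against the wrong targets, one option is to compare the argmax of the predictions with the targets the generator actually pairs them with, which (given the windowing TimeseriesGenerator uses) start at index LOOKBACK. The sketch below assumes `pred_probs` is what `model.predict(validation_generator)` returns, one row per window in order (the generator's default is shuffle=False); `check_predictions` is a hypothetical helper.

```python
import numpy as np


def check_predictions(pred_probs, targets, lookback):
    """Compare predictions with the targets TimeseriesGenerator pairs them with.

    Assumes one prediction row per window, in order, so row j belongs to
    targets[j + lookback].
    """
    y_true = np.argmax(targets[lookback:], axis=1)
    if len(y_true) != len(pred_probs):
        raise ValueError(f"expected {len(y_true)} predictions, got {len(pred_probs)}")
    y_pred = np.argmax(pred_probs, axis=1)
    accuracy = float(np.mean(y_true == y_pred))
    # How often each class is predicted; one dominant count hints at a collapsed model.
    pred_counts = np.bincount(y_pred, minlength=targets.shape[1])
    return accuracy, pred_counts


# Toy example: 20 targets, lookback 14, a "model" that always predicts class 1.
targets = np.eye(2)[np.arange(20) % 2]
pred_probs = np.tile([0.45, 0.55], (6, 1))
acc, counts = check_predictions(pred_probs, targets, lookback=14)
print(acc, counts)  # 0.5 [0 6]
```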

What am I doing wrong?

Thanks!
