I am trying to train a seq-to-seq model on a simple sine wave. The goal is to take Nin consecutive data points as input and predict the next Nout data points. The task seems simple, and the model predicts well when the frequency freq is high (y = sin(freq * x)). For example, for freq=4 the loss is very low and the prediction is very close to the target. For low frequencies, however, the prediction is bad. Any thoughts on why the model fails?
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, RepeatVector, TimeDistributed, Dense
# Frequency multiplier of the sine wave y = sin(freq * x).
freq = 0.25
# Nin: number of points given to the model; Nout: number of points it predicts.
Nin = 14
Nout = 14
# Helper function to convert a time series to (input, target) window samples
def windowed_dataset(y, input_window=5, output_window=1, stride=1, num_features=1):
    """Slice a time series into sliding (input, target) window pairs.

    Sample ``ii`` takes ``input_window`` consecutive rows starting at
    ``stride * ii`` as its input and the ``output_window`` rows that
    immediately follow as its target.

    Args:
        y: Array of shape ``(L, num_features)``. A 1-D array of length ``L``
            is also accepted and treated as a single feature column.
        input_window: Number of time steps in each input window.
        output_window: Number of time steps in each target window.
        stride: Offset (in time steps) between the starts of two
            consecutive samples.
        num_features: Number of leading columns of ``y`` to copy.

    Returns:
        A tuple ``(X, Y)`` of float arrays with shapes
        ``(input_window, num_samples, num_features)`` and
        ``(output_window, num_samples, num_features)`` (time-major layout).

    Raises:
        ValueError: If the series is too short to produce a valid (possibly
            empty) sample array.
    """
    y = np.asarray(y)
    if y.ndim == 1:
        # Promote a plain 1-D series to a single-feature column so the
        # two-index slicing below works.
        y = y[:, None]

    L = y.shape[0]
    num_samples = (L - input_window - output_window) // stride + 1
    if num_samples < 0:
        # Fail with an explicit message instead of NumPy's
        # "negative dimensions are not allowed".
        raise ValueError(
            "series of length %d is too short for input_window=%d and "
            "output_window=%d" % (L, input_window, output_window)
        )

    X = np.zeros([input_window, num_samples, num_features])
    Y = np.zeros([output_window, num_samples, num_features])

    for ff in range(num_features):
        for ii in range(num_samples):
            # The input window starts at stride * ii ...
            start_x = stride * ii
            end_x = start_x + input_window
            X[:, ii, ff] = y[start_x:end_x, ff]

            # ... and the target window begins right where the input ends.
            start_y = end_x
            end_y = start_y + output_window
            Y[:, ii, ff] = y[start_y:end_y, ff]

    return X, Y
# Model input: a window of Nin time steps with one scalar value per step.
inputs = Input(shape=(Nin, 1))
# Encoder LSTM; return_sequences=False keeps only the final 128-dim output.
encoder = LSTM(units=128, return_sequences=False)(inputs)
# Repeat the encoding Nout times, one copy per predicted time step.
encoding_repeat = RepeatVector(n=Nout)(encoder)
# Decoder LSTM reads the (Nout, 128) repeated encoding and emits one
# 128-dim output per step.
decoder = LSTM(units=128, return_sequences=True)(encoding_repeat)
# Project every decoder step to a single value with a linear Dense layer.
sequence_prediction = TimeDistributed(Dense(units=1, activation='linear'))(decoder)
model = Model(inputs=inputs, outputs=sequence_prediction)
# Adam optimizer with mean-squared-error loss.
model.compile(optimizer='adam', loss='mse')
# numpy and matplotlib are used below but are not imported anywhere in the
# visible part of this file, so import them here before first use.
import numpy as np
import matplotlib.pyplot as plt

# Sample y = sin(freq * x) at 1000 evenly spaced x in [0, 10] as a column vector.
y = np.sin(freq * np.linspace(0, 10, 1000))[:, None]
# First 80% of the samples are used for training, the rest is held out.
Ntr = int(0.8 * y.shape[0])
y_train, y_test = y[:Ntr], y[Ntr:]
# NOTE(review): this star import rebinds every public name the module defines;
# if generate_dataset also defines windowed_dataset, it replaces the helper
# defined earlier in this file -- confirm which version is intended.
from generate_dataset import *
stride = 1
N_features = 1
Xtrain, Ytrain = windowed_dataset(y_train, input_window=Nin, output_window=Nout, stride=stride,
                                  num_features=N_features)
# summary() prints the table itself and returns None, so do not wrap it in print().
model.summary()
# windowed_dataset returns (time, sample, feature); Keras expects (sample, time, feature).
Xtrain, Ytrain = Xtrain.transpose(1, 0, 2), Ytrain.transpose(1, 0, 2)
print("Xtrain", Xtrain.shape)
model.fit(Xtrain, Ytrain, epochs=30)
plt.figure(); plt.plot(y, 'ro')
# Plot the prediction for a handful of training samples (indices 1, 5, 20, ...).
# The original called an undefined name `arr`; presumably an alias for numpy.array.
for Ns in np.array([10, 50, 200, 400, 800, 1500, 3000]) // 10:
    # Index with [[Ns]] to keep the batch dimension: the input has shape (1, Nin, 1).
    ypred = model.predict(Xtrain[[Ns]])
    print("ypred", ypred.shape)
    # Drop the batch dimension (the batch holds a single sample).
    ypred = ypred[-1]
    plt.figure()
    plt.plot(ypred, 'ro')
    plt.plot(Xtrain[Ns], 'm--')
    plt.plot(Ytrain[Ns], 'k.')
# NOTE(review): the source lost its indentation; the loop is assumed to end
# above, so all figures are shown together after the loop -- confirm.
plt.show()
exit()
from Poor performance of seq-to-seq LSTM on simple sin wave with low frequency
No comments:
Post a Comment