Monday, 29 July 2019

PyTorch Binary Classification - same network structure, 'simpler' data, but worse performance?

To get to grips with PyTorch (and deep learning in general), I started by working through some basic classification examples. One such example was classifying a non-linear dataset created using sklearn (full code available as a notebook here):

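import torch
from sklearn import datasets

n_pts = 500
# Two concentric, noisy circles: X has shape (n_pts, 2), y holds 0/1 labels
X, y = datasets.make_circles(n_samples=n_pts, random_state=123, noise=0.1, factor=0.2)
x_data = torch.FloatTensor(X)
y_data = torch.FloatTensor(y.reshape(n_pts, 1))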


This is then accurately classified using a pretty basic neural net:

import torch
from torch import nn

class Model(nn.Module):
    def __init__(self, input_size, H1, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, H1)    # input -> hidden layer
        self.linear2 = nn.Linear(H1, output_size)  # hidden -> output layer

    def forward(self, x):
        x = torch.sigmoid(self.linear(x))
        x = torch.sigmoid(self.linear2(x))  # output in (0, 1)
        return x

    def predict(self, x):
        # Expects a single sample; a batch would make the comparison ambiguous
        pred = self.forward(x)
        if pred >= 0.5:
            return 1
        else:
            return 0
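
For context, a training loop for a network like this might look roughly like the sketch below. The loss function (nn.BCELoss), the optimizer (Adam with lr=0.1), the hidden size of 4 and the fixed seed are all assumptions made for illustration, since they are not shown in this post. The network is written with nn.Sequential only to keep the sketch self-contained; it has the same two linear layers and sigmoids as the Model class.

```python
import torch
from torch import nn
from sklearn import datasets

torch.manual_seed(0)  # assumption: fixed seed so the run is repeatable

# Same dataset as in the post
n_pts = 500
X, y = datasets.make_circles(n_samples=n_pts, random_state=123, noise=0.1, factor=0.2)
x_data = torch.FloatTensor(X)
y_data = torch.FloatTensor(y.reshape(n_pts, 1))

# Same structure as Model: Linear -> sigmoid -> Linear -> sigmoid.
# The hidden size of 4 is an assumption; the post does not state H1.
model = nn.Sequential(
    nn.Linear(2, 4),
    nn.Sigmoid(),
    nn.Linear(4, 1),
    nn.Sigmoid(),
)

criterion = nn.BCELoss()                                  # assumption
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # assumption

losses = []
for epoch in range(1000):
    y_pred = model(x_data)            # forward pass on the full dataset
    loss = criterion(y_pred, y_data)  # binary cross-entropy against 0/1 labels
    losses.append(loss.item())

    optimizer.zero_grad()  # clear gradients from the previous step
    loss.backward()        # backpropagate
    optimizer.step()       # update the weights
```

Plotting losses against the epoch number gives a training-loss curve of the kind referred to further down.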

As I have an interest in health data, I then decided to try using the same network structure to classify a basic real-world dataset. I took heart rate data for one patient from here, and altered it so that all values > 91 would be labelled as anomalies (i.e. a 1, with everything <= 91 labelled a 0). This is completely arbitrary, but I just wanted to see how the classification would work. The complete notebook for this example is here.
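
The labelling rule described above can be sketched as follows. The heart-rate values here are made up purely for illustration and are not taken from the actual dataset; the variable names and the assumption that the heart rate is the only input feature are mine as well.

```python
import torch

# Hypothetical heart-rate readings (beats per minute), for illustration only;
# these are NOT values from the dataset used in the post.
heart_rate = torch.tensor([72.0, 88.0, 91.0, 92.0, 105.0, 64.0])

THRESHOLD = 91.0  # the arbitrary cut-off described in the post

# 1.0 where the reading is strictly above the threshold, 0.0 otherwise
labels = (heart_rate > THRESHOLD).float()

# Shape both as (n_samples, 1), assuming a single input feature
x_data = heart_rate.reshape(-1, 1)
y_data = labels.reshape(-1, 1)
```

Note that a reading of exactly 91 ends up in class 0, matching the "<= 91" rule.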


What is not intuitive to me is why the first example reaches a loss of 0.0016 after 1,000 epochs, whereas the second example only reaches a loss of 0.4296 after 10,000 epochs.

Training Loss for Example 1

Training Loss for Heart Rate Example

Perhaps I am being naive in thinking that the heart rate example would be much easier to classify. Any insights to help me understand why this is not what I am seeing would be great!


