Tuesday, 26 March 2019

Plot a bivariate gaussian using Matplotlib

The code below plots a bivariate Gaussian distribution. I've calculated this distribution by adjusting the covariance (COV) matrix to account for specific variables. Specifically, each coordinate is assigned a radius ([_Radius]). The COV matrix is then adjusted by a scaling factor ([_Scaling]) that expands the radius in the x-direction and contracts it in the y-direction. The direction of this stretch is set by the rotation angle ([_Rotation]).

The output is expressed as a probability function, which represents the influence of each group's coordinates over a certain space. This is shown using the colorbar.

My question is about how I'm calling the probability function. Specifically, the coordinates should only cover a set area that is determined by the radius. This radius can be shaped by the scaling factor and rotation. But the area not covered by these coordinates should not change in probability. I tried to fix these areas at 0.5, but that didn't solve the issue. I'm hoping to alter the code so that the probability is only influenced by the coordinates and their related variables.

Note: The provided df only contains 10 frames. The animation attached beneath the code shows 50 frames. I only included 10 due to space constraints, but can upload the rest if necessary.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as sts
from matplotlib.animation import FuncAnimation

# Fixed [min, max] extent shared by the x and y axes and by the PDF grid.
DATA_LIMITS = [-100, 100]

def datalimits(*data):
    """Return the [min, max] limits to use for plotting *data*.

    The positional ``data`` arguments are currently ignored: every call
    returns the fixed ``DATA_LIMITS`` so that all frames share the same axes
    and evaluation grid.

    A new list is returned on each call, so a caller that mutates the result
    cannot silently alter ``DATA_LIMITS`` for every other user.
    """
    return list(DATA_LIMITS)

def rot(theta):
    """Return the 2x2 counter-clockwise rotation matrix for *theta* degrees."""
    radians = np.deg2rad(theta)
    c, s = np.cos(radians), np.sin(radians)
    return np.array([[c, -s],
                     [s, c]])

def getcov(radius=1, scale=1, theta=0):
    """Build a 2x2 covariance matrix rotated by *theta* degrees.

    Before rotation the matrix is diagonal, with x-variance
    ``radius * (scale + 1)`` and y-variance ``radius / (scale + 1)``.  A
    positive ``scale`` therefore stretches along x and shrinks along y, and
    the product of the two variances stays ``radius ** 2``.  The diagonal
    matrix ``C`` is then rotated as ``R @ C @ R.T`` with ``R = rot(theta)``.
    """
    stretch = scale + 1
    principal = np.array([[radius * stretch, 0],
                          [0, radius / stretch]])
    R = rot(theta)
    return R @ principal @ R.T

def mvpdf(x, y, xlim, ylim, radius=1, velocity=0, scale=0, theta=0):
    """Evaluate a single bivariate normal PDF on a regular grid.

    The grid spans ``xlim`` by ``ylim`` using ``np.linspace``'s default
    sample count per axis.  The mean is the point ``(x, y)`` moved forward by
    ``velocity / 2`` along the heading ``theta`` (degrees).  The covariance
    is ``getcov(radius, scale, theta)``.

    Returns ``(X, Y, PDF)``, where ``X``/``Y`` are the meshgrid coordinates
    and ``PDF`` is the density at each grid point (same shape as ``X``).
    """
    grid_x, grid_y = np.meshgrid(np.linspace(*xlim), np.linspace(*ylim))
    points = np.stack((grid_x, grid_y), axis=-1)

    # Half the velocity along the rotated x-axis shifts the mean ahead.
    offset = rot(theta) @ (velocity / 2, 0)
    mean = [x + offset[0], y + offset[1]]

    covariance = getcov(radius=radius, scale=scale, theta=theta)
    density = sts.multivariate_normal(mean, covariance).pdf(points)
    return grid_x, grid_y, density

def mvpdfs(xs, ys, xlim, ylim, radius=None, velocity=None, scale=None, theta=None):
    """Sum the bivariate normal PDFs of several points on one shared grid.

    Parameters
    ----------
    xs, ys : sequences of float
        Point coordinates, paired element-wise with ``zip``.
    xlim, ylim : pair of float
        Grid extent, forwarded to ``mvpdf``.
    radius, velocity, scale, theta : indexable or None
        Optional per-point parameters indexed like ``xs``.  ``None`` falls
        back to radius 1, velocity 0, scale 0 and theta 0.

    Returns
    -------
    (X, Y, PDF)
        Meshgrid coordinates and the element-wise sum of every point's PDF.
        With no points at all, ``PDF`` is all zeros on the same grid that
        ``mvpdf`` builds (previously this case raised because ``X``/``Y``
        were never assigned).
    """
    PDFs = []
    for i, (x, y) in enumerate(zip(xs, ys)):
        kwargs = {
            'radius': radius[i] if radius is not None else 1,
            'velocity': velocity[i] if velocity is not None else 0,
            'scale': scale[i] if scale is not None else 0,
            'theta': theta[i] if theta is not None else 0,
            'xlim': xlim,
            'ylim': ylim,
        }
        X, Y, PDF = mvpdf(x, y, **kwargs)
        PDFs.append(PDF)

    if not PDFs:
        # Nothing to sum: build the grid mvpdf would have used (np.linspace's
        # default samples per axis) and report zero density on it.
        X, Y = np.meshgrid(np.linspace(*xlim), np.linspace(*ylim))
        return X, Y, np.zeros_like(X)

    return X, Y, np.sum(PDFs, axis=0)

fig, ax = plt.subplots(figsize=(10, 4))
ax.set(xlim=DATA_LIMITS, ylim=DATA_LIMITS)

# One initially empty marker line per group, red for the first group and blue
# for the second.  They are marked animated because the animation is blitted.
line_a, line_b = (
    ax.plot([], [], '.', c=colour, alpha=0.5, markersize=5, animated=True)[0]
    for colour in ('red', 'blue')
)

# Filled-contour set drawn for the previous frame; plotmvs() replaces it.
cfs = None

def _contour_artists(cs):
    """Return the Artist(s) that make up the filled-contour set *cs*.

    Matplotlib 3.8 made ``ContourSet`` a single Artist and deprecated its
    ``.collections`` list, which was later removed in 3.10.  Older releases
    only expose the per-level collections.  Removing and blitting the
    contours both need whichever objects the running version draws.
    """
    from matplotlib.artist import Artist

    if isinstance(cs, Artist):
        return [cs]
    return list(cs.collections)


def plotmvs(tdf, xlim=None, ylim=None, fig=fig, ax=ax):
    """Render one animation frame and return the artists it changed.

    Parameters
    ----------
    tdf : tuple
        One ``(time, frame_df)`` item from ``df.groupby('time')``.
        ``frame_df`` needs a ``'group'`` index level and ``'X'``/``'Y'``
        columns.  ``'Radius'``, ``'Velocity'``, ``'Scaling'`` and
        ``'Rotation'`` are forwarded to ``mvpdfs`` when present.
    xlim, ylim : pair of float, optional
        Grid extent; defaults to ``datalimits(...)``.
    fig, ax : Figure, Axes
        Drawing targets.  ``fig`` is accepted for compatibility but unused.

    Groups are paired, in groupby (sorted) order, with ``(line_a, line_b)``.
    NOTE(review): the frame assumes exactly two groups are present.

    The difference ``PDF_first - PDF_second`` is rescaled symmetrically:

    * zero difference (space no point influences) is always exactly 0.5;
    * the first group's strongest influence maps to 1;
    * the second group's strongest influence maps to 0.

    The contour levels are fixed on [0, 1], so a value keeps the same colour
    in every frame.  Min-max normalisation, used previously, moved the
    neutral value, and hence the background colour, from frame to frame.

    Returns
    -------
    list
        The contour artist(s) plus ``line_a`` and ``line_b``, as required by
        ``FuncAnimation(..., blit=True)``.
    """
    global cfs
    # Remove the previous frame's contours before drawing new ones.
    if cfs is not None:
        for artist in _contour_artists(cfs):
            artist.remove()

    df = tdf[1]

    if xlim is None:
        xlim = datalimits(df['X'])
    if ylim is None:
        ylim = datalimits(df['Y'])

    PDFs = []

    for (_group, gdf), group_line in zip(df.groupby('group'), (line_a, line_b)):

        # Update the scatter line data
        group_line.set_data(*gdf[['X', 'Y']].values.T)

        kwargs = {
            'radius': gdf['Radius'].values if 'Radius' in gdf else None,
            'velocity': gdf['Velocity'].values if 'Velocity' in gdf else None,
            'scale': gdf['Scaling'].values if 'Scaling' in gdf else None,
            'theta': gdf['Rotation'].values if 'Rotation' in gdf else None,
            'xlim': xlim,
            'ylim': ylim,
        }
        X, Y, PDF = mvpdfs(gdf['X'].values, gdf['Y'].values, **kwargs)
        PDFs.append(PDF)

    PDF = PDFs[0] - PDFs[1]

    # Symmetric rescale: 0 -> 0.5, +max|PDF| -> 1, -max|PDF| -> 0.
    peak = np.abs(PDF).max()
    if peak > 0:
        normPDF = 0.5 + 0.5 * PDF / peak
    else:
        # No influence anywhere on the grid: everything is neutral.
        normPDF = np.full_like(PDF, 0.5)

    # 11 equal bands on [0, 1].  0.5 lies in the middle of the central band
    # rather than on a boundary, so tiny positive or negative residues from
    # the Gaussian tails all share the neutral colour.
    cfs = ax.contourf(X, Y, normPDF, cmap='viridis', alpha=0.8,
                      levels=np.linspace(0, 1, 12))

    return _contour_artists(cfs) + [line_a, line_b]

# Sample data: n time steps of wide-format per-object measurements.
n = 10
# Time-step labels 0..n-1 (this name shadows the stdlib `time` module).
time = range(n) 
# Keys are '<group letter><id>_<field>', e.g. 'A12_Radius'; each list holds
# one value per time step.  Fields: X/Y position; Radius and Scaling, which
# plotmvs/mvpdfs pass on to getcov; and Rotation, in degrees (rot() applies
# np.deg2rad to it).
d = ({
     'A1_X' :    [13.3,13.16,12.99,12.9,12.79,12.56,12.32,12.15,11.93,11.72],
     'A1_Y' :    [26.12,26.44,26.81,27.18,27.48,27.82,28.13,28.37,28.63,28.93],
     'A2_X' :    [6.97,6.96,7.03,6.98,6.86,6.76,6.55,6.26,6.09,5.9],
     'A2_Y' :    [10.92,10.83,10.71,10.52,10.22,10.02,9.86,9.7,9.54,9.37],
     'A3_X' :    [-31.72,-31.93,-32.18,-32.43,-32.7,-32.89,-33.15,-33.51,-33.84,-34.17],
     'A3_Y' :    [21.25,21.52,21.7,21.98,22.25,22.47,22.7,22.95,23.2,23.4],
     'A4_X' :    [37.54,37.42,37.3,37.14,36.97,36.77,36.56,36.37,36.13,35.89],
     'A4_Y' :    [7.31,7.35,7.38,7.43,7.5,7.58,7.65,7.68,7.69,7.69],
     'A5_X' :    [-5.37,-5.31,-5.28,-5.34,-5.41,-5.42,-5.68,-5.84,-6.1,-6.31],
     'A5_Y' :    [-5.42,-5.7,-6,-6.15,-6.41,-6.67,-6.88,-7.11,-7.33,-7.49],
     'A6_X' :    [-3.33,-3.15,-2.97,-2.94,-2.88,-2.79,-2.69,-2.66,-2.54,-2.67],
     'A6_Y' :    [13.69,13.86,14.09,14.34,14.73,15.01,15.38,15.83,16.15,16.73],
     'A7_X' :    [-4.4,-4.56,-4.83,-5.02,-5.18,-5.51,-5.81,-6.03,-6.31,-6.7],
     'A7_Y' :    [21.34,21.53,21.69,21.89,22.03,22.35,22.63,22.91,23.14,23.34],
     'A8_X' :    [-14.89,-15.12,-15.26,-15.52,-15.96,-16.37,-16.7,-17.08,-17.55,-17.95],
     'A8_Y' :    [3.7,3.41,3.14,2.84,2.58,2.26,2.07,1.78,1.45,1.23],
     'A9_X' :    [-51.92,-52.04,-52.15,-52.26,-52.36,-52.54,-52.76,-52.98,-53.17,-53.4],
     'A9_Y' :    [16.45,16.44,16.5,16.61,16.59,16.52,16.52,16.43,16.45,16.49],
     'A10_X' :   [-15.18,-15.18,-15.18,-15.18,-15.18,-15.18,-15.18,-15.18,-15.18,-15.18],
     'A10_Y' :   [26.02,26.02,26.02,26.02,26.02,26.02,26.02,26.02,26.02,26.02],
     'A11_X' :   [15.5,15.22,14.9,14.59,14.36,14.08,13.74,13.43,13.13,12.82],
     'A11_Y' :   [7.25,7.36,7.51,7.61,7.72,7.88,8.05,8.18,8.5,8.8],
     'A12_X' :   [-5.36,-5.35,-5.33,-5.28,-5.18,-5.12,-4.99,-4.83,-4.8,-4.71],
     'A12_Y' :   [19.02,18.77,18.56,18.41,18.22,18.03,17.9,17.72,17.69,17.58],
     'A13_X' :   [-45.76,-45.91,-46.13,-46.41,-46.62,-46.82,-47.07,-47.35,-47.61,-47.87],
     'A13_Y' :   [18.9,18.96,19.03,19.12,19.12,19.18,19.31,19.42,19.45,19.53],
     'A14_X' :   [-10.28,-10.3,-10.23,-10.36,-10.53,-10.69,-10.84,-10.95,-11.17,-11.37],
     'A14_Y' :   [18.25,18.42,18.56,18.73,18.86,18.98,19.02,19.19,19.3,19.46],
     'A15_X' :   [29.77,29.6,29.45,29.24,28.9,28.68,28.42,28.06,27.75,27.49],
     'A15_Y' :   [11.59,11.38,11.19,11.02,10.85,10.71,10.58,10.39,10.18,9.98],
     'B1_X' :    [38.35,38.1,37.78,37.55,37.36,37.02,36.78,36.46,36.21,35.79],
     'B1_Y' :    [12.55,12.58,12.58,12.55,12.5,12.47,12.43,12.48,12.44,12.44],
     'B2_X' :    [14.6,14.38,14.16,13.8,13.45,13.11,12.71,12.3,12.06,11.61],
     'B2_Y' :    [4.66,4.44,4.24,4.1,4.01,3.84,3.67,3.56,3.44,3.47],
     'B3_X' :    [-12.16,-12.35,-12.53,-12.73,-12.91,-13.01,-13.24,-13.44,-13.68,-13.93],
     'B3_Y' :    [20.07,20.26,20.34,20.5,20.62,20.69,20.72,20.73,20.63,20.58],
     'B4_X' :    [-3.27,-3.1,-2.83,-2.49,-2.34,-2.13,-1.97,-1.8,-1.67,-1.59],
     'B4_Y' :    [-6.25,-6.37,-6.52,-6.61,-6.76,-6.89,-7.01,-7.1,-7.13,-7.33],
     'B5_X' :    [-21.47,-21.63,-21.84,-22.03,-22.28,-22.53,-22.77,-22.99,-23.27,-23.52],
     'B5_Y' :    [8.94,8.87,8.79,8.68,8.61,8.56,8.48,8.35,8.22,8.12],
     'B6_X' :    [-13.81,-13.83,-13.91,-14.02,-14.15,-14.31,-14.54,-14.77,-14.96,-15.24],
     'B6_Y' :    [25.45,25.81,25.94,26.26,26.56,26.75,26.92,27.07,27.22,27.25],
     'B7_X' :    [-6.28,-6.33,-6.43,-6.44,-6.61,-6.8,-7.02,-7.22,-7.46,-7.7],
     'B7_Y' :    [13.82,13.6,13.43,13.26,13.12,13.09,13.07,13.14,13.19,13.32],
     'B8_X' :    [28.39,28.09,27.91,27.76,27.4,27.14,26.91,26.69,26.34,26.1],
     'B8_Y' :    [8.36,8.2,8.13,8.1,8.01,7.94,7.84,7.76,7.8,7.84],
     'B9_X' :    [-7.55,-7.54,-7.57,-7.65,-7.77,-7.87,-8.01,-8.06,-8.06,-8.06],
     'B9_Y' :    [17.98,17.94,17.97,18.02,18.05,18.09,18.07,18.02,17.97,17.92],
     'B10_X' :   [-32.36,-32.63,-32.92,-33.25,-33.54,-33.78,-34.13,-34.37,-34.69,-35.01],
     'B10_Y' :   [13.27,13.48,13.67,13.9,14.14,14.48,14.76,15.05,15.31,15.62],
     'B11_X' :   [-44.08,-44.19,-44.33,-44.47,-44.64,-44.78,-44.92,-45.16,-45.36,-45.56],
     'B11_Y' :   [15.9,16.09,16.22,16.38,16.49,16.63,16.7,16.79,16.85,16.94],
     'B12_X' :   [-16.47,-16.67,-16.76,-16.86,-16.99,-17.24,-17.48,-17.76,-17.98,-18.29],
     'B12_Y' :   [29.76,29.96,30.07,30.3,30.45,30.59,30.61,30.67,30.62,30.66],
     'B13_X' :   [-50.27,-50.38,-50.55,-50.74,-50.92,-51.02,-51.13,-51.3,-51.46,-51.65],
     'B13_Y' :   [16.31,16.3,16.31,16.33,16.36,16.28,16.25,16.22,16.21,16.27],
     'B14_X' :   [-15.55,-15.81,-16.05,-16.35,-16.67,-16.96,-17.35,-17.76,-18.09,-18.6],
     'B14_Y' :   [8.56,8.53,8.54,8.57,8.62,8.6,8.58,8.49,8.44,8.55],
     'B15_X' :   [9.79,9.47,9.2,8.77,8.41,8.07,7.65,7.19,6.76,6.42],
     'B15_Y' :   [27.61,27.79,27.99,28.16,28.37,28.53,28.68,28.82,28.9,29.06],
     'A1_Radius' :  [10.33,10.34,10.34,10.37,10.38,10.37,10.36,10.36,10.35,10.35],
     'A2_Radius' :  [9.05,9.06,9.07,9.08,9.09,9.09,9.08,9.06,9.05,9.04],
     'A3_Radius' :  [13.04,13.15,13.29,13.44,13.6,13.72,13.88,14.1,14.31,14.52],
     'A4_Radius' :  [25,25,25,25,25,25,25,25,25,24.81],
     'A5_Radius' :  [11.24,11.33,11.44,11.49,11.59,11.68,11.77,11.86,11.95,12.02],
     'A6_Radius' :  [8.19,8.19,8.18,8.17,8.15,8.14,8.13,8.11,8.11,8.09],
     'A7_Radius' :  [8.18,8.19,8.2,8.21,8.22,8.25,8.27,8.29,8.31,8.33],
     'A8_Radius' :  [9.71,9.79,9.85,9.94,10.05,10.17,10.26,10.38,10.53,10.65],
     'A9_Radius' :  [25,25,25,25,25,25,25,25,25,25],
     'A10_Radius' : [9.08,9.08,9.08,9.08,9.08,9.08,9.08,9.08,9.08,9.08],
     'A11_Radius' : [11.12,11.03,10.91,10.81,10.74,10.64,10.54,10.45,10.34,10.23],
     'A12_Radius' : [8.07,8.06,8.05,8.05,8.04,8.03,8.02,8.01,8.01,8.01],
     'A13_Radius' : [24.16,24.36,24.62,24.98,25,25,25,25,25,25],
     'A14_Radius' : [8.3,8.3,8.3,8.31,8.32,8.33,8.34,8.35,8.37,8.39],
     'A15_Radius' : [17.86,17.75,17.66,17.52,17.28,17.14,16.96,16.73,16.53,16.38],
     'B1_Radius' :  [25,25,25,25,25,25,24.84,24.45,24.15,23.65],
     'B2_Radius' :  [11.3,11.28,11.26,11.19,11.11,11.06,10.99,10.91,10.88,10.77],
     'B3_Radius' :  [8.46,8.48,8.5,8.52,8.54,8.55,8.57,8.59,8.61,8.63],
     'B4_Radius' :  [11.54,11.59,11.65,11.69,11.75,11.81,11.86,11.9,11.92,12],
     'B5_Radius' :  [10.08,10.13,10.18,10.24,10.3,10.37,10.43,10.5,10.59,10.66],
     'B6_Radius' :  [8.89,8.92,8.94,8.98,9.02,9.06,9.1,9.14,9.18,9.21],
     'B7_Radius' :  [8.19,8.2,8.22,8.23,8.24,8.24,8.25,8.25,8.26,8.26],
     'B8_Radius' :  [17.34,17.15,17.03,16.93,16.69,16.52,16.38,16.25,16.01,15.84],
     'B9_Radius' :  [8.13,8.13,8.14,8.14,8.15,8.15,8.16,8.16,8.16,8.16],
     'B10_Radius' : [13.38,13.5,13.64,13.8,13.94,14.05,14.23,14.35,14.53,14.7],
     'B11_Radius' : [22.24,22.35,22.5,22.65,22.84,23,23.16,23.44,23.67,23.9],
     'B12_Radius' : [9.68,9.74,9.77,9.82,9.86,9.92,9.96,10.02,10.05,10.1],
     'B13_Radius' : [25,25,25,25,25,25,25,25,25,25],
     'B14_Radius' : [9.17,9.2,9.23,9.26,9.3,9.34,9.4,9.47,9.52,9.59],
     'B15_Radius' : [9.8,9.76,9.74,9.7,9.67,9.63,9.59,9.55,9.5,9.48],
     'A1_Scaling' : [0,0.07,0.1,0.09,0.06,0.1,0.09,0.05,0.07,0.08],
     'A2_Scaling' : [0,0.01,0.01,0.02,0.06,0.03,0.04,0.07,0.03,0.04],
     'A3_Scaling' : [0,0.07,0.06,0.08,0.09,0.05,0.07,0.11,0.1,0.09],
     'A4_Scaling' : [0,0.01,0.01,0.02,0.02,0.03,0.03,0.02,0.03,0.04],
     'A5_Scaling' : [0,0.05,0.05,0.02,0.04,0.04,0.07,0.05,0.07,0.04],
     'A6_Scaling' : [0,0.04,0.05,0.04,0.09,0.05,0.09,0.12,0.07,0.21],
     'A7_Scaling' : [0,0.04,0.06,0.05,0.03,0.13,0.1,0.07,0.08,0.11],
     'A8_Scaling' : [0,0.08,0.06,0.09,0.16,0.16,0.08,0.14,0.19,0.12],
     'A9_Scaling' : [0,0.01,0.01,0.02,0.01,0.02,0.03,0.03,0.02,0.03],
     'A10_Scaling' :    [0,0,0,0,0,0,0,0,0,0],
     'A11_Scaling' :    [0,0.05,0.07,0.06,0.04,0.06,0.09,0.06,0.11,0.11],
     'A12_Scaling' :    [0,0.04,0.03,0.02,0.02,0.02,0.02,0.03,0,0.01],
     'A13_Scaling' :    [0,0.02,0.03,0.05,0.03,0.03,0.05,0.05,0.04,0.04],
     'A14_Scaling' :    [0,0.02,0.02,0.03,0.03,0.02,0.01,0.03,0.03,0.04],
     'A15_Scaling' :    [0,0.04,0.03,0.04,0.08,0.04,0.05,0.1,0.08,0.06],
     'B1_Scaling' : [0,0.04,0.06,0.03,0.02,0.07,0.04,0.06,0.04,0.11],
     'B2_Scaling' : [0,0.06,0.05,0.09,0.08,0.08,0.11,0.1,0.04,0.12],
     'B3_Scaling' : [0,0.04,0.02,0.04,0.03,0.01,0.03,0.02,0.04,0.04],
     'B4_Scaling' : [0,0.03,0.06,0.07,0.03,0.03,0.02,0.02,0.01,0.03],
     'B5_Scaling' : [0,0.02,0.03,0.03,0.04,0.04,0.04,0.04,0.06,0.04],
     'B6_Scaling' : [0,0.08,0.01,0.07,0.06,0.04,0.05,0.05,0.04,0.05],
     'B7_Scaling' : [0,0.03,0.02,0.02,0.03,0.02,0.03,0.02,0.04,0.04],
     'B8_Scaling' : [0,0.07,0.02,0.01,0.08,0.04,0.04,0.03,0.07,0.04],
     'B9_Scaling' : [0,0,0,0.01,0.01,0.01,0.01,0,0,0],
     'B10_Scaling' :    [0,0.07,0.07,0.09,0.08,0.11,0.12,0.09,0.1,0.11],
     'B11_Scaling' :    [0,0.03,0.02,0.03,0.03,0.02,0.02,0.04,0.03,0.03],
     'B12_Scaling' :    [0,0.05,0.01,0.04,0.02,0.05,0.04,0.05,0.03,0.05],
     'B13_Scaling' :    [0,0.01,0.02,0.02,0.02,0.01,0.01,0.02,0.01,0.02],
     'B14_Scaling' :    [0,0.04,0.04,0.05,0.06,0.05,0.09,0.1,0.07,0.16],
     'B15_Scaling' :    [0,0.08,0.06,0.13,0.1,0.08,0.11,0.14,0.12,0.08],           
     'A1_Rotation' :    [0,112.81,114.01,110.56,110.6,113.37,116.02,116.99,118.62,119.38],
     'A2_Rotation' :    [0,-94.27,-73.02,-87.73,-98.57,-102.94,-111.2,-119.95,-122.41,-124.66],
     'A3_Rotation' :    [0,128.47,135.84,134.06,134.75,133.94,134.69,136.59,137.5,138.85],
     'A4_Rotation' :    [0,164.64,165.45,163.48,161.42,160.63,161.31,162.54,164.99,167.17],
     'A5_Rotation' :    [0,-78.78,-81.19,-87.73,-92.23,-92.33,-101.96,-105.48,-111.09,-114.38],
     'A6_Rotation' :    [0,43.84,48.03,59.34,66.72,67.83,69.48,72.62,72.2,77.81],
     'A7_Rotation' :    [0,131.01,141.7,138.59,138.53,137.8,137.64,136.12,136.75,139.03],
     'A8_Rotation' :    [0,-127.32,-123.16,-125.78,-133.66,-135.66,-137.83,-138.76,-139.64,-141.01],
     'A9_Rotation' :    [0,-173.6,166.75,154.39,162.16,173.26,175.63,-178.69,-179.76,178.52],
     'A10_Rotation' :   [0,0,0,0,0,0,0,0,0,0],
     'A11_Rotation' :   [0,159.85,156.67,158.48,157.77,156.11,155.54,155.77,152.28,149.97],
    'A12_Rotation' :    [0,-87.73,-86.34,-82.59,-77.52,-76.38,-71.91,-67.87,-66.98,-65.81],
     'A13_Rotation' :   [0,157.86,160.03,161.3,165.63,165.1,162.84,162.02,163.42,163.4],
     'A14_Rotation' :   [0,96.06,80.08,98.93,112.07,119.3,125.62,125.42,130.18,131.8],
     'A15_Rotation' :   [0,-128.97,-128.41,-133.2,-139.65,-141,-143.05,-144.79,-145.08,-144.84],
     'B1_Rotation' :    [0,172.04,176.94,179.74,-177.22,-176.59,-175.81,-178,-177.26,-177.53],
     'B2_Rotation' :    [0,-135.52,-136.05,-144.97,-150.41,-151.25,-152.35,-154.52,-154.34,-158.27],
     'B3_Rotation' :    [0,136.51,144.7,143.29,144.27,144.21,149.19,152.95,159.91,164.19],
     'B4_Rotation' :    [0,-34.26,-30.8,-24.57,-28.69,-29.15,-30.36,-29.83,-28.55,-32.58],
     'B5_Rotation' :    [0,-157.4,-157.25,-155.34,-157.7,-160.03,-160.33,-158.82,-158.14,-158.21],
     'B6_Rotation' :    [0,92.27,101.22,104.5,106.63,110.85,116.25,120.35,122.84,128.35],
     'B7_Rotation' :    [0,-105.22,-111.46,-106.51,-115.8,-125.96,-134.82,-144.07,-152.07,-160.76],
     'B8_Rotation' :    [0,-151.19,-154.33,-157.13,-160.12,-161.43,-160.74,-160.39,-164.53,-167.14],
     'B9_Rotation' :    [0,-73.61,-155.92,158.51,161.67,160.64,169.2,175.38,-178.71,-172.8],                         
     'B10_Rotation' :   [0,142.69,144.41,144.93,143.57,139.61,139.91,138.54,138.86,138.47],
     'B11_Rotation' :   [0,119.54,127.13,128.87,133.1,133.59,136.35,140.45,143.34,144.72],
     'B12_Rotation' :   [0,134.03,132.68,125.29,126.67,132.59,139.67,144.84,150.2,153.44],
     'B13_Rotation' :   [0,-177.72,178.64,176.87,175.25,-177.73,-175.96,-175.29,-175.61,-178.46],
     'B14_Rotation' :   [0,-173.78,-177.74,179.11,176.83,178.31,179.19,-178.19,-177.33,-179.89],
     'B15_Rotation' :   [0,152.39,147.79,151.7,151.38,151.96,153.65,155.22,157.01,156.86],
     })

def _long_records(data, steps):
    """Yield ``((time, group, id, field), value)`` pairs from wide-format *data*.

    For each key such as ``'A12_Radius'``:

    * the first character of the part before ``'_'`` is the group (``'A'``);
    * the rest of that part is the integer id (``12``);
    * the part after ``'_'`` is the field name (``'Radius'``).

    ``steps`` supplies the time label for each position in the value list.
    """
    for key, values in data.items():
        parts = key.split('_')
        head = parts[0]
        group = head[0]
        ident = int(head[1:])
        field = parts[1]
        for i, t in enumerate(steps):
            yield (t, group, ident, field), values[i]


tuples = list(_long_records(d, time))

# A Series keyed by 4-tuples gets a 4-level MultiIndex; unstacking the last
# level (the field name) turns the fields into columns.
df = pd.Series(dict(tuples)).unstack(-1)
df.index.names = ['time', 'group', 'id']

# Animation timing, in milliseconds.
interval_ms = 200   # delay between frames
delay_ms = 1000     # pause before the animation repeats

# One frame per time step.  `ani` keeps a reference to the animation object
# for as long as the window is shown.
ani = FuncAnimation(
    fig,
    plotmvs,
    frames=df.groupby('time'),
    interval=interval_ms,
    repeat_delay=delay_ms,
    blit=True,
)

plt.show()

Animated Description



from Plot a bivariate gaussian using Matplotlib

No comments:

Post a Comment