#!/usr/bin/env python
#coding:utf-8

import wx,numpy

class Particle:
    """A point mass in 2-D with damped motion.

    Attributes:
        damp: fraction of the current velocity subtracted from the
            acceleration on every step.
        pos, vel, acc: length-2 float arrays holding position, velocity and
            the acceleration accumulated since the last ``update()``.
    """

    def __init__(self, x1, y1):
        """Create a particle at rest at ``(x1, y1)``."""
        self.damp = 0.2
        # All state is stored as float: the in-place arithmetic below mixes in
        # float values, which NumPy refuses to write back into integer arrays.
        self.pos = numpy.array([x1, y1], dtype=float)
        self.vel = numpy.zeros(2)
        self.acc = numpy.zeros(2)

    def update(self):
        """Advance one step, then clear the accumulated acceleration."""
        self.acc -= self.vel * self.damp
        self.vel += self.acc
        self.pos += self.vel
        self.acc = numpy.zeros(2)

    def mouse_push(self, x, y, radius=20.0, strength=100.0):
        """Push the particle away from the point ``(x, y)``.

        Nothing happens unless the point is closer than ``radius``. Inside
        the radius the added acceleration points from ``(x, y)`` toward the
        particle and falls off linearly, from ``strength`` at zero distance
        down to zero at ``radius``.

        Args:
            x, y: coordinates of the pushing point.
            radius: reach of the push.
            strength: acceleration magnitude at zero distance.
        """
        offset = self.pos - numpy.array([x, y], dtype=float)
        distance = numpy.linalg.norm(offset)
        # At zero distance the push direction is undefined (0/0), so skip it
        # rather than poisoning the state with NaN.
        if distance == 0 or distance >= radius:
            return
        falloff = 1 - distance / radius
        self.acc += (offset / distance) * falloff * strength

class Spring:
    """A spring between two particles that acts on the first one only.

    Attributes:
        parta: particle whose ``acc`` receives the spring force.
        partb: particle at the other end; it is read but never modified.
        len: rest length, in the same units as the particle positions.
        dist: distance between the particles at the last ``update()``.
        dir: force vector computed by the last ``update()``.
    """

    def __init__(self, a, b):
        """Connect particle ``a`` (the one that moves) to particle ``b``."""
        self.parta = a
        self.partb = b
        self.len = 100  # fixed rest length, not derived from the start positions
        self.dist = 0.0
        self.dir = numpy.zeros(2)

    def update(self):
        """Add the spring force to ``parta.acc``.

        The force points from ``parta`` toward ``partb`` and has magnitude
        ``(dist - len) / 2``, so it pulls when the spring is stretched and
        pushes when it is compressed.
        """
        offset = numpy.asarray(self.partb.pos - self.parta.pos, dtype=float)
        self.dist = numpy.linalg.norm(offset)
        if self.dist == 0:
            # Coincident endpoints have no direction; normalising would give
            # 0/0 = NaN and corrupt the particle, so apply no force.
            self.dir = numpy.zeros(2)
            return
        self.dir = (offset / self.dist) * ((self.dist - self.len) / 2)
        # Accumulate out of place so the sum is float even if the particle's
        # accumulator happens to be an integer array.
        self.parta.acc = self.parta.acc + self.dir

 
class MyWindow(wx.Frame):
    """Frame that animates particles joined by springs on a white panel.

    The first node stays where it was created; every other node is moved by
    the springs and pushed away from the mouse pointer on each timer tick.
    """

    def __init__(self, parent=None, id=-1, title=None):
        """Build the panel, the initial two-node model and a 50 ms timer."""
        # NOTE(review): wx.Frame expects a string title, so the ``None``
        # default is replaced by an empty title instead of being passed on.
        wx.Frame.__init__(self, parent, id, "" if title is None else title)
        self.panel = wx.Panel(self, size=(800, 600))
        self.panel.SetBackgroundColour('WHITE')
        self.Bind(wx.EVT_TIMER, self.OnTimer)
        self.Fit()
        self.rad_node = 20  # radius of nodes

        # Model: a fixed node at the panel centre plus one free node that
        # starts a random sub-unit offset away from it, joined by a spring.
        panel_w, panel_h = self.panel.GetSize()
        centre_x = panel_w / 2.0
        centre_y = panel_h / 2.0
        self.nodes = []
        self.nodes.append(Particle(centre_x, centre_y))
        self.nodes.append(Particle(centre_x + numpy.random.random_sample(),
                                   centre_y + numpy.random.random_sample()))
        self.arcs = []
        self.arcs.append(Spring(self.nodes[1], self.nodes[0]))

        # Start ticking only once everything OnTimer reads has been created.
        self.timer = wx.Timer(self)
        self.timer.Start(50)

    def OnTimer(self, event):
        """Advance the simulation by one step and repaint the panel."""
        for spring in self.arcs:
            spring.update()

        # The pointer position is the same for every node within one tick.
        mx, my = self.panel.ScreenToClient(wx.GetMousePosition())
        for node in self.nodes[1:]:  # nodes[0] is never moved
            node.mouse_push(mx, my)
            node.update()

        cdc = wx.ClientDC(self.panel)
        bmp_w, bmp_h = self.panel.GetSize()
        # Phoenix spells the sized-bitmap constructor wx.Bitmap(w, h) and
        # deprecates wx.EmptyBitmap; classic wxPython only has the latter.
        if 'phoenix' in wx.PlatformInfo:
            bmp = wx.Bitmap(bmp_w, bmp_h)
        else:
            bmp = wx.EmptyBitmap(bmp_w, bmp_h)
        bdc = wx.BufferedDC(cdc, bmp)
        bdc.Clear()
        bdc.SetPen(wx.Pen("black"))
        bdc.SetBrush(wx.Brush('white'))
        # Positions are float arrays; the DC drawing calls take integer
        # coordinates, so convert explicitly.
        for spring in self.arcs:
            start = spring.parta.pos
            end = spring.partb.pos
            bdc.DrawLine(int(start[0]), int(start[1]), int(end[0]), int(end[1]))
        for node in self.nodes:
            bdc.DrawCircle(int(node.pos[0]), int(node.pos[1]), self.rad_node)
        cdc.DrawBitmap(bmp, 0, 0)
 
if __name__ == '__main__':
    # wx.App(False) replaces the deprecated wx.PySimpleApp; False keeps
    # stdout/stderr on the console instead of redirecting them to a window.
    app = wx.App(False)
    w = MyWindow(title='APA visualizer')
    w.Center()
    w.Show()
    app.MainLoop()
