In [1]:
%config InlineBackend.figure_format = 'retina'
In [2]:
import numpy as np
import matplotlib.pyplot as plt
import math as mt
from numba.decorators import jit
from numba import f8,u8
import scipy.fftpack as sf
from IPython.display import display, clear_output

なぜかNumbaのかけ方で計算時間が全然違う。
なんでだろう?
下のサンプルプログラムはESのPICコード。

In [3]:
@jit
def acce(vt,nop,dx,nx,dt,nt):
    np.seterr(divide='ignore', invalid='ignore')
    lx=dx*nx
    xx=dx*np.arange(nx)
    xp=np.linspace(0,lx-lx/nop,nop)
    vp=np.random.normal(0,vt,nop)#randn(nop)
    kk=2*np.pi/lx*np.r_[np.arange(nx/2),np.arange(-nx/2,0)]

    for it in range(nt):
        xp=xp+dt*vp
#        
        xp[xp>lx]=xp[xp>lx]-lx
        xp[xp<0.0 ]=xp[xp<0.0 ]+lx
        ds=np.zeros(nx)
        
        for ip in range(nop):
            ixm=mt.floor(xp[ip]/dx); ixp=ixm+1
            wxp=xp[ip]/dx-ixm; wxm=1-wxp

            if ixp>nx-1: ixp=ixp-nx#; print(ixp)

            ds[ixm]=ds[ixm]+wxm
            ds[ixp]=ds[ixp]+wxp    
        ds=ds/nop*nx
            
        exfft=1j/kk*sf.fft(ds)
        exfft[0]=0
        ex=np.real(sf.ifft(exfft))        
#        
        for ip in range(nop):
            ixm=mt.floor(xp[ip]/dx); ixp=ixm+1
            wxp=xp[ip]/dx-ixm; wxm=1-wxp

            if ixp>nx-1: ixp=ixp-nx#; print(ixp)
            vp[ip]=vp[ip]-dt*(wxm*ex[ixm]+wxp*ex[ixp])

#        
#    plt.subplot(3,1,1); plt.plot(xp,vp,'.')
#    plt.subplot(3,1,2); plt.plot(xx,ds,'-k')
#    plt.subplot(3,1,3); plt.plot(xx,ex,'-k')
#    plt.show()
In [4]:
%timeit  acce(1.0,10**5,1.0,2**8,1.0,256)
1 loop, best of 3: 1min 20s per loop

functionを分けて作って、Numbaをそれぞれにかませる。 速度が爆速になる。

In [5]:
@jit('(f8,u8,f8,u8,f8,u8)')
def acce_numba(vt,nop,dx,nx,dt,nt):
    np.seterr(divide='ignore', invalid='ignore')
    lx=dx*nx
    xx=dx*np.arange(nx)
    xp=np.linspace(0,lx-lx/nop,nop)
    vp=np.random.normal(0,vt,nop)#randn(nop)
    ds=np.zeros(nx)
    kk=2*mt.pi/lx*np.r_[np.arange(nx/2),np.arange(-nx/2,0)]
    esave=np.zeros((nt,nx))

    for it in range(nt):

        push(nop,xp,vp,lx,dt)
        dens(dx,nx,nop,xp,ds)
        
        exfft=1j/kk*sf.fft(ds/nop*nx)
        exfft[0]=0
        ex=np.real(sf.ifft(exfft))        
        
        acc(dx,nx,nop,xp,vp,ex,dt)
        esave[it,:]=ex[:]
        
@jit('f8[:](u8,f8[:],f8[:],f8,f8)')
def push(nop,xp,vp,lx,dt):
    for ip in range(nop):
        xp[ip]=xp[ip]+dt*vp[ip]
        if xp[ip]>lx: xp[ip]=xp[ip]-lx
        if xp[ip]<0:  xp[ip]=xp[ip]+lx
    return xp

@jit('f8[:](f8,u8,u8,f8[:],f8[:],f8[:],f8)')
def acc(dx,nx,nop,xp,vp,ex,dt):
    for ip in range(nop):
        ixm=mt.floor(xp[ip]/dx); ixp=ixm+1
        wxp=xp[ip]/dx-ixm; wxm=1-wxp

        if ixp>nx-1: ixp=ixp-nx#; print(ixp)
        vp[ip]=vp[ip]-dt*(wxm*ex[ixm]+wxp*ex[ixp])
        
    return vp

@jit('f8[:](f8,u8,u8,f8[:],f8[:])')
def dens(dx,nx,nop,xp,ds):
    ds=np.zeros(nx)
    for ip in range(nop):
        ixm=mt.floor(xp[ip]/dx); ixp=ixm+1
        wxp=xp[ip]/dx-ixm; wxm=1-wxp

        if ixp>nx-1: ixp=ixp-nx#; print(ixp)

        ds[ixm]=ds[ixm]+wxm
        ds[ixp]=ds[ixp]+wxp 
        
    return ds
In [6]:
%timeit acce_numba(1.0,10**5,1.0,2**8,1.0,256)
1 loop, best of 3: 401 ms per loop

計算時間のパラメータ依存性を調べる。
変えるパラメータは粒子数(nop)、空間グリッド(nx)、空間ステップ(nt)。

In [7]:
import time
npr=50
elapsed_time=np.zeros(npr)
#changing the number of particles
nop=np.linspace(10,10**6,npr).astype(int)
for iop in range(npr):
    start = time.time()
    acce_numba(1.0,nop[iop],1.0,2**8,1.0,2**8)
    elapsed_time[iop] = time.time() - start
    #print(iop,elapsed_time[iop])
In [8]:
plt.plot(nop,elapsed_time,'-o');plt.show()
In [9]:
import time
npr=50
elapsed_time=np.zeros(npr)
nt=np.linspace(10,10**4,npr).astype(int)
for ipr in range(npr):
    start = time.time()
    acce_numba(1.0,10**4,1.0,2**8,1.0,nt[ipr])
    elapsed_time[ipr] = time.time() - start
    #print(ipr,elapsed_time[ipr])
In [10]:
plt.plot(nt,elapsed_time,'-o');plt.show()
In [11]:
import time
npr=10
elapsed_time=np.zeros(npr)
nx=2**(np.arange(npr)+3)
#print(nx)
for ipr in range(npr):
    start = time.time()
    acce_numba(1.0,10**4,1.0,nx[ipr],1.0,1024)
    elapsed_time[ipr] = time.time() - start
    #print(ipr,elapsed_time[ipr])
In [12]:
plt.plot(nx,elapsed_time,'-o');plt.show()