Some simple Rules to save Computation Time in Python

A simple toy example demonstrating a few performance rules. In short: avoid expensive function or library calls and expensive but cacheable computations in heavily looped operations!

Image by Typhaine Therry from Pixabay

Introduction

Let's discuss the following Python code. It is a simple piece of code one could come up with to solve a minor daily task. The script was written to conduct a simple grid search over a 4D parameter space: it computes the values and stores the parameter combinations of interest. It is not wrong to write it that way, and it works; it is classic throwaway code. The story changes when you want to scan big parameter spaces. You will realize that your program needs more and more time to run. What to do now? A simple solution would be to parallelize it. But the first step before parallelization should always be an optimization of your program!


Lines 20-26 define the ranges and some arbitrary parameters for our computation. They also define an empty data structure to store some of the computed values. The following quadruple loop (starting at line 28) scans the parameter space: the integers from $b = -200$ up to, but not including, $t = 200$, resulting in $(t - b)^{4} = 400^{4} = 256 \cdot 10^{8}$ scanned points. Embedded in it is a call to the computation function and a subsequent test that conditionally saves the results of interest into the aforementioned data structure (lines 33-37). The computation routine in lines 4 to 18 relates the grid parameters to each other. It then checks the computed values and either returns early or does further computation and returns those results.

 

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import numpy as np
import pandas as pd

def mismatch(m1, m2, n1, n2):
    x, y = (m1*n1 + m2*n2 + np.cos(np.pi/3)*(m1*n2 + m2*n1)), \
                           (np.sin(np.pi/3)*(m1*n2 - m2*n1))
    Na = n1**2 + n2**2 + n1*n2
    Nb = m1**2 + m2**2 + m1*m2
    A = np.array([[x, -y], [y, x]])
    detA = np.linalg.det(A)
    if detA <= 0.0 :
        return -1.0, -1.0, Na, Nb;
    rho = Nb/np.sqrt(detA)
    temp = (1/np.sqrt(detA))*A[0][0]
    if np.abs(temp) > 1.0 :
        return -1.0, -1.0, Na, Nb;
    theta = np.arccos(temp)
    return rho, theta, Na, Nb

la = 2.479
lb = 2.971
b = -200
t = 200

cn = ['theta', 'n1', 'n2', 'm1', 'm2', 'delta', 'NNi', 'NNiOOH']
df = pd.DataFrame(columns=cn)

for m1 in range(b, t):
    for m2 in range(b, t):
        for n1 in range(b, t):
            for n2 in range(b, t):
                rho, theta, Na, Nb = mismatch(m1, m2, n1, n2)
                if np.abs(rho - la/lb) < 1e-7:
                    data = [np.rad2deg(theta), n1, n2, m1, m2,
                            (np.abs(rho - la/lb)), Na, Nb]
                    entry = pd.DataFrame([dict(zip(cn, data))], columns=cn)
                    df = pd.concat([df, entry], ignore_index=True)
df
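The number of grid points visited by the four nested loops can be checked with a few lines of Python. This is only a sketch: the helper names `grid_points` and `estimated_hours` are made up for illustration, and the cost of 10 µs per point in the usage part is an arbitrary assumption, not a measurement from this article.

```python
def grid_points(b: int, t: int, dims: int = 4) -> int:
    """Number of points visited by `dims` nested loops over range(b, t)."""
    # range(b, t) excludes t, so it yields t - b integers.
    return len(range(b, t)) ** dims


def estimated_hours(points: int, seconds_per_point: float) -> float:
    """Total runtime in hours for an assumed cost per grid point."""
    return points * seconds_per_point / 3600.0


if __name__ == "__main__":
    n = grid_points(-200, 200)
    print(n)  # 25600000000, i.e. 256 * 10**8
    # 1e-5 s per point is a made-up figure, used for illustration only.
    print(f"{estimated_hours(n, 1e-5):.1f} h")
```

Because the point count grows with the fourth power of the range, even a few microseconds saved per point add up quickly.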

 

The simple Rules

We will adhere to a couple of simple rules to optimize the shown script.

  • Avoid re-computing values if their computation is expensive.

    Typical expensive operations are trigonometric functions (e.g. sin(x), cos(x), acos(x)), the exponential exp(x), the square root sqrt(x), the power x**a, divisions x/y, and more.

  • Avoid expensive operations if you can use simpler counterparts.

    An example is the power operation x**2, which can end up in a more expensive general-purpose power algorithm instead of the single multiplication x*x.

  • Avoid calls to functions or libraries if the operation itself is very basic.

    Function calls in scripting languages have a considerable cost, which can easily outweigh the simple operations they perform (prominent examples of such simple functions are npm's isEven [2] and isOdd [3]).

These rules hold for this example and for Python and other scripting languages in general. There are exceptions to these rules, and other languages can behave differently.
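A rough way to see the three rules at work is to time a slow and a cheaper variant of each with the standard `timeit` module. The sketch below is an assumption-laden illustration: it uses `math` instead of NumPy to stay self-contained, and the names `is_even`, `ENV`, `CASES` and `seconds_per_call` are hypothetical. The absolute numbers depend on the machine and the Python version, so none are quoted here.

```python
import math
import timeit


def is_even(n: int) -> bool:
    """Trivial helper, used only to illustrate function-call overhead."""
    return n % 2 == 0


# Names visible to the timed statements.
ENV = {"math": math, "is_even": is_even, "x": 7, "half": math.cos(math.pi / 3)}

# Triples of (rule, slow variant, cheaper variant).
CASES = [
    ("re-computation", "math.cos(math.pi / 3) * x", "half * x"),
    ("power vs. product", "x ** 2", "x * x"),
    ("function call", "is_even(x)", "x % 2 == 0"),
]


def seconds_per_call(stmt: str, number: int = 200_000) -> float:
    """Average wall-clock time of one execution of `stmt` in seconds."""
    return timeit.timeit(stmt, globals=dict(ENV), number=number) / number


if __name__ == "__main__":
    for rule, slow, fast in CASES:
        print(f"{rule:18s} {seconds_per_call(slow):.2e} s  vs  "
              f"{seconds_per_call(fast):.2e} s")
```

Each pair computes the same value, so only the cost differs.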

 

Our Optimizations

So let's have a look at the example. Lines 5 and 6 contain trigonometric computations with static values.

4
5
6
def mismatch(m1, m2, n1, n2):
    x, y = (m1*n1 + m2*n2 + np.cos(np.pi/3)*(m1*n2 + m2*n1)), \
                           (np.sin(np.pi/3)*(m1*n2 - m2*n1))

We compute them by hand, or pre-compute them once if they have an unwieldy representation: cos(π/3) = 1/2 and sin(π/3) = √3/2.

4
5
6
7
sqrt3by2 = np.sqrt(3)*0.5

def mismatch(m1, m2, n1, n2):
    x, y = (m1*n1 + m2*n2 + 0.5*(m1*n2 + m2*n1)), (sqrt3by2*(m1*n2 - m2*n1))
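The replacement relies on the closed forms cos(π/3) = 1/2 and sin(π/3) = √3/2. A minimal check, assuming the standard `math` module in place of NumPy (the names `HALF`, `SQRT3_BY_2` and `constants_match` are illustrative), could look like this:

```python
import math

HALF = 0.5                       # closed form of cos(pi/3)
SQRT3_BY_2 = math.sqrt(3) * 0.5  # closed form of sin(pi/3), evaluated once


def constants_match() -> bool:
    """Compare the hand-derived constants with the library results."""
    # The library calls carry rounding error, so compare with a tolerance.
    return (math.isclose(math.cos(math.pi / 3), HALF)
            and math.isclose(math.sin(math.pi / 3), SQRT3_BY_2))


if __name__ == "__main__":
    print(constants_match())  # True
```

The constants are computed once at import time instead of once per grid point.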

The two subsequent lines contain the power of two operation

8
9
    Na = n1**2 + n2**2 + n1*n2
    Nb = m1**2 + m2**2 + m1*m2

which is easily replaceable with its cheaper multiplication counterpart.

8
9
    Na = n1*n1 + n2*n2 + n1*n2
    Nb = m1*m1 + m2*m2 + m1*m2
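For integer inputs the two spellings give identical results, so the swap is safe here. The following sketch checks this over the small test range; the function names `norm_pow` and `norm_mul` are hypothetical and only serve the comparison.

```python
def norm_pow(a: int, b: int) -> int:
    """a^2 + b^2 + a*b written with the power operator."""
    return a**2 + b**2 + a*b


def norm_mul(a: int, b: int) -> int:
    """The same expression written with plain multiplications."""
    return a*a + b*b + a*b


if __name__ == "__main__":
    grid = range(-20, 20)
    same = all(norm_pow(a, b) == norm_mul(a, b) for a in grid for b in grid)
    print(same)  # True: Python integer arithmetic is exact
```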

In lines 14 and 15 we divide by the square root of our determinant, which involves two avoidable operations (/ and sqrt())

14
15
    rho = Nb/np.sqrt(detA)
    temp = (1/np.sqrt(detA))*A[0][0]

replaceable by 'remembering' that value (the same can be done with la/lb further down).

14
15
16
    bysqrtdetA = 1/np.sqrt(detA)
    rho = Nb*bysqrtdetA
    temp = bysqrtdetA*A[0][0]
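The effect of remembering the reciprocal square root can be sketched in isolation. The example assumes `math.sqrt` instead of `np.sqrt`, and the function names are made up for illustration. Note that `Nb/sqrt(detA)` and `Nb*(1/sqrt(detA))` may differ in the last floating-point digit, which is far below the 1e-7 tolerance used in the search.

```python
import math


def rho_temp_original(Nb: int, detA: float, x: float) -> tuple:
    """Two square roots and two divisions, as in the first version."""
    rho = Nb / math.sqrt(detA)
    temp = (1 / math.sqrt(detA)) * x
    return rho, temp


def rho_temp_cached(Nb: int, detA: float, x: float) -> tuple:
    """One square root and one division; the result is reused twice."""
    by_sqrt_detA = 1 / math.sqrt(detA)
    return Nb * by_sqrt_detA, by_sqrt_detA * x


if __name__ == "__main__":
    a = rho_temp_original(7, 21.0, 3.5)
    b = rho_temp_cached(7, 21.0, 3.5)
    # Equal up to rounding in the last digits.
    print(all(math.isclose(u, v, rel_tol=1e-12) for u, v in zip(a, b)))
```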

The last change will be the biggest one. Our function mismatch computes the determinant of a 2×2 matrix and calls a library function for that. This is not necessary given the simple nature of that computation.

 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
def mismatch(m1, m2, n1, n2):
    x, y = (m1*n1 + m2*n2 + 0.5*(m1*n2 + m2*n1)), (sqrt3by2*(m1*n2 - m2*n1))
    Na = n1*n1 + n2*n2 + n1*n2
    Nb = m1*m1 + m2*m2 + m1*m2
    A = np.array([[x, -y], [y, x]])
    detA = np.linalg.det(A)
    if detA <= 0.0 :
        return -1.0, -1.0, Na, Nb;
    bysqrtdetA = 1/np.sqrt(detA)
    rho = Nb*bysqrtdetA
    temp = bysqrtdetA*A[0][0]
    if np.abs(temp) > 1.0 :
        return -1.0, -1.0, Na, Nb;
    theta = np.arccos(temp)
    return rho, theta, Na, Nb

We replace the call to the det() function and the helper data structure created for it with the corresponding definition of the determinant. This also enables us to reuse a variable (reusing variables can be dangerous and should generally be avoided for the sake of code clarity).

 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
def mismatch(m1, m2, n1, n2):
    x, y = (m1*n1 + m2*n2 + 0.5*(m1*n2 + m2*n1)), (sqrt3by2*(m1*n2 - m2*n1))
    Na = n1*n1 + n2*n2 + n1*n2
    Nb = m1*m1 + m2*m2 + m1*m2
    A = x*x + y*y
    if A <= 0.0 :
        return -1.0, -1.0, Na, Nb;
    A = 1/np.sqrt(A)
    rho = Nb*A
    A *= x
    if np.abs(A) > 1.0 :
        return -1.0, -1.0, Na, Nb;
    theta = np.arccos(A)
    return rho, theta, Na, Nb
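Why the matrix and the library call can be dropped follows directly from the 2×2 determinant formula. Written out, with $N_b$ standing for the variable Nb and $A_{00}$ for A[0][0] (notation chosen here for illustration):

```latex
\begin{aligned}
A &= \begin{pmatrix} x & -y \\ y & x \end{pmatrix}, \\
\det A &= x \cdot x - (-y) \cdot y = x^{2} + y^{2}, \\
\rho &= \frac{N_b}{\sqrt{\det A}} = \frac{N_b}{\sqrt{x^{2} + y^{2}}}, \\
\cos\theta &= \frac{A_{00}}{\sqrt{\det A}} = \frac{x}{\sqrt{x^{2} + y^{2}}}.
\end{aligned}
```

Since $x^{2} + y^{2} \geq 0$, the early return for a non-positive determinant only triggers for $x = y = 0$, and $|\cos\theta|$ can exceed 1 only through floating-point rounding.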

 

Did it help?

Every optimization should be accompanied by time measurements to verify improvements or to undo bad changes. The measurements for the proposed changes are displayed in the following diagram. The displayed bars represent the average time needed for one data point in our 4D grid.

All of our changes had a positive effect on the runtime. The most significant drop in runtime was accomplished by circumventing the matrix construction and the calculation of its determinant. Check the estimated run times to compute all $400^{4}$ data points (shown when hovering over the respective bars). You will see that those small changes can have an effect of hours! The main conclusions you should take with you are:

  • It is fine to just code something up.

  • But you should clean it up as soon as you want to scale it (or as soon as there is a possibility that someone else wants to scale it), as those small numbers can hurt you and the planet in the long run.

  • And lastly, avoid library functions in scripting languages for very simple calculations. The overhead could be much costlier than the computation they do.
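The average time per grid point mentioned above can be measured roughly as follows. This is a sketch under stated assumptions, not the benchmark used for the diagram: it uses `math` instead of NumPy, times only the innermost arithmetic, and the names `kernel_trig`, `kernel_const` and `seconds_per_point` are hypothetical.

```python
import math
import timeit

SQRT3_BY_2 = math.sqrt(3) * 0.5


def kernel_trig(m1, m2, n1, n2):
    """Re-evaluates the trigonometric constants on every call."""
    x = m1*n1 + m2*n2 + math.cos(math.pi/3)*(m1*n2 + m2*n1)
    y = math.sin(math.pi/3)*(m1*n2 - m2*n1)
    return x*x + y*y


def kernel_const(m1, m2, n1, n2):
    """Uses the pre-computed constants instead."""
    x = m1*n1 + m2*n2 + 0.5*(m1*n2 + m2*n1)
    y = SQRT3_BY_2*(m1*n2 - m2*n1)
    return x*x + y*y


def seconds_per_point(func, b=-5, t=5, repeat=3):
    """Best-of-`repeat` average time per grid point of a 4D scan."""
    def scan():
        for m1 in range(b, t):
            for m2 in range(b, t):
                for n1 in range(b, t):
                    for n2 in range(b, t):
                        func(m1, m2, n1, n2)
    points = (t - b) ** 4
    return min(timeit.repeat(scan, number=1, repeat=repeat)) / points


if __name__ == "__main__":
    for kernel in (kernel_trig, kernel_const):
        print(f"{kernel.__name__:13s} {seconds_per_point(kernel):.2e} s per point")
```

Taking the minimum over several repetitions reduces the influence of other processes on the measurement.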

I included a Julia solution to show what is possible with other languages (and because of my poor Python skills). The intent is not an unfair comparison but to demonstrate that there is much room for improvement. Two possible changes would be to avoid the pandas data structure or to use a different intermediate data structure, such as a linked list. Another change would be to compile the essential pieces to avoid the overhead of the dynamic variable types and of the scripting language itself.
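One way to avoid the pandas structure inside the loop is to collect the hits in a plain Python list and serialize them once at the end. Growing a DataFrame with pd.concat in every iteration copies all previous rows each time, while list.append does not. The sketch below is an illustration only: it uses `math` instead of NumPy, writes CSV via the standard library, and the names `scan` and `to_csv` as well as the target value in the usage part are assumptions.

```python
import csv
import io
import math

SQRT3_BY_2 = math.sqrt(3) * 0.5
COLUMNS = ["theta", "n1", "n2", "m1", "m2", "delta", "NNi", "NNiOOH"]


def mismatch(m1, m2, n1, n2):
    """Pure-Python variant of the final mismatch() shown earlier."""
    x = m1*n1 + m2*n2 + 0.5*(m1*n2 + m2*n1)
    y = SQRT3_BY_2*(m1*n2 - m2*n1)
    Na = n1*n1 + n2*n2 + n1*n2
    Nb = m1*m1 + m2*m2 + m1*m2
    A = x*x + y*y
    if A <= 0.0:
        return -1.0, -1.0, Na, Nb
    A = 1/math.sqrt(A)
    rho = Nb*A
    A *= x
    if abs(A) > 1.0:
        return -1.0, -1.0, Na, Nb
    return rho, math.acos(A), Na, Nb


def scan(b, t, target, tol=1e-7):
    """Collect matching rows in a list instead of a growing DataFrame."""
    rows = []
    for m1 in range(b, t):
        for m2 in range(b, t):
            for n1 in range(b, t):
                for n2 in range(b, t):
                    rho, theta, Na, Nb = mismatch(m1, m2, n1, n2)
                    delta = abs(rho - target)
                    if delta < tol:
                        rows.append((math.degrees(theta), n1, n2, m1, m2,
                                     delta, Na, Nb))
    return rows


def to_csv(rows):
    """Serialize all rows once, after the scan has finished."""
    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(COLUMNS)
    writer.writerows(rows)
    return buffer.getvalue()


if __name__ == "__main__":
    # target=1.0 is an arbitrary demo value, not la/lb from the article.
    hits = scan(-3, 4, target=1.0)
    print(len(hits), "rows collected")
    print(to_csv(hits[:3]))
```

If a DataFrame is needed afterwards, it can be built from the finished list in a single step.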


Here are the six respective sources of the measured versions. The Julia version uses a linked list for the data accumulation and some helper functions (a possible call to run it: julia --threads 1 --eval 'include("tests/parameter_search.jl"); test_csv(100);').

 

original

 

import numpy as np
import pandas as pd

def mismatch(m1, m2, n1, n2):
    x, y = (m1*n1 + m2*n2 + np.cos(np.pi/3)*(m1*n2 + m2*n1)), \
                           (np.sin(np.pi/3)*(m1*n2 - m2*n1))
    Na = n1**2 + n2**2 + n1*n2
    Nb = m1**2 + m2**2 + m1*m2
    A = np.array([[x, -y], [y, x]])
    detA = np.linalg.det(A)
    if detA <= 0.0 :
        return -1.0, -1.0, Na, Nb;
    rho = Nb/np.sqrt(detA)
    temp = (1/np.sqrt(detA))*A[0][0]
    if np.abs(temp) > 1.0 :
        return -1.0, -1.0, Na, Nb;
    theta = np.arccos(temp)
    return rho, theta, Na, Nb

la = 2.479
lb = 2.971
b = -20
t = 20

cn = ['theta', 'n1', 'n2', 'm1', 'm2', 'delta', 'NNi', 'NNiOOH']
df = pd.DataFrame(columns=cn)

for m1 in range(b, t):
    for m2 in range(b, t):
        for n1 in range(b, t):
            for n2 in range(b, t):
                rho, theta, Na, Nb = mismatch(m1, m2, n1, n2)
                if np.abs(rho - la/lb) < 1e-7:
                    data = [np.rad2deg(theta), n1, n2, m1, m2, \
                            (np.abs(rho - la/lb)), Na, Nb]
                    entry = pd.DataFrame([dict(zip(cn, data))], columns=cn)
                    df = pd.concat([df, entry], ignore_index=True)

df 

 

remove trig. fun.

 

import numpy as np
import pandas as pd

sqrt3by2 = np.sqrt(3)*0.5

def mismatch(m1, m2, n1, n2):
    x, y = (m1*n1 + m2*n2 + 0.5*(m1*n2 + m2*n1)), (sqrt3by2*(m1*n2 - m2*n1))
    Na = n1**2 + n2**2 + n1*n2
    Nb = m1**2 + m2**2 + m1*m2
    A = np.array([[x, -y], [y, x]])
    detA = np.linalg.det(A)
    if detA <= 0.0 :
        return -1.0, -1.0, Na, Nb;
    rho = Nb/np.sqrt(detA)
    temp = (1/np.sqrt(detA))*A[0][0]
    if np.abs(temp) > 1.0 :
        return -1.0, -1.0, Na, Nb;
    theta = np.arccos(temp)
    return rho, theta, Na, Nb

la = 2.479
lb = 2.971
b = -20
t = 20

cn = ['theta', 'n1', 'n2', 'm1', 'm2', 'delta', 'NNi', 'NNiOOH']
df = pd.DataFrame(columns=cn)

for m1 in range(b, t):
    for m2 in range(b, t):
        for n1 in range(b, t):
            for n2 in range(b, t):
                rho, theta, Na, Nb = mismatch(m1, m2, n1, n2)
                if np.abs(rho - la/lb) < 1e-7:
                    data = [np.rad2deg(theta), n1, n2, m1, m2, \
                            (np.abs(rho - la/lb)), Na, Nb]
                    entry = pd.DataFrame([dict(zip(cn, data))], columns=cn)
                    df = pd.concat([df, entry], ignore_index=True)

df 

 

replace **2

 

import numpy as np
import pandas as pd

sqrt3by2 = np.sqrt(3)*0.5

def mismatch(m1, m2, n1, n2):
    x, y = (m1*n1 + m2*n2 + 0.5*(m1*n2 + m2*n1)), (sqrt3by2*(m1*n2 - m2*n1))
    Na = n1*n1 + n2*n2 + n1*n2
    Nb = m1*m1 + m2*m2 + m1*m2
    A = np.array([[x, -y], [y, x]])
    detA = np.linalg.det(A)
    if detA <= 0.0 :
        return -1.0, -1.0, Na, Nb;
    rho = Nb/np.sqrt(detA)
    temp = (1/np.sqrt(detA))*A[0][0]
    if np.abs(temp) > 1.0 :
        return -1.0, -1.0, Na, Nb;
    theta = np.arccos(temp)
    return rho, theta, Na, Nb

la = 2.479
lb = 2.971
b = -20
t = 20

cn = ['theta', 'n1', 'n2', 'm1', 'm2', 'delta', 'NNi', 'NNiOOH']
df = pd.DataFrame(columns=cn)

for m1 in range(b, t):
    for m2 in range(b, t):
        for n1 in range(b, t):
            for n2 in range(b, t):
                rho, theta, Na, Nb = mismatch(m1, m2, n1, n2)
                if np.abs(rho - la/lb) < 1e-7:
                    data = [np.rad2deg(theta), n1, n2, m1, m2, \
                            (np.abs(rho - la/lb)), Na, Nb]
                    entry = pd.DataFrame([dict(zip(cn, data))], columns=cn)
                    df = pd.concat([df, entry], ignore_index=True)

df 

 

reuse / and sqrt()

 

import numpy as np
import pandas as pd

sqrt3by2 = np.sqrt(3)*0.5

def mismatch(m1, m2, n1, n2):
    x, y = (m1*n1 + m2*n2 + 0.5*(m1*n2 + m2*n1)), (sqrt3by2*(m1*n2 - m2*n1))
    Na = n1*n1 + n2*n2 + n1*n2
    Nb = m1*m1 + m2*m2 + m1*m2
    A = np.array([[x, -y], [y, x]])
    detA = np.linalg.det(A)
    if detA <= 0.0 :
        return -1.0, -1.0, Na, Nb;
    bysqrtdetA = 1/np.sqrt(detA)
    rho = Nb*bysqrtdetA
    temp = bysqrtdetA*A[0][0]
    if np.abs(temp) > 1.0 :
        return -1.0, -1.0, Na, Nb;
    theta = np.arccos(temp)
    return rho, theta, Na, Nb

la = 2.479
lb = 2.971
b = -20
t = 20

cn = ['theta', 'n1', 'n2', 'm1', 'm2', 'delta', 'NNi', 'NNiOOH']
df = pd.DataFrame(columns=cn)

for m1 in range(b, t):
    for m2 in range(b, t):
        for n1 in range(b, t):
            for n2 in range(b, t):
                rho, theta, Na, Nb = mismatch(m1, m2, n1, n2)
                if np.abs(rho - la/lb) < 1e-7:
                    data = [np.rad2deg(theta), n1, n2, m1, m2, \
                            (np.abs(rho - la/lb)), Na, Nb]
                    entry = pd.DataFrame([dict(zip(cn, data))], columns=cn)
                    df = pd.concat([df, entry], ignore_index=True)

df 

 

avoid A and det()

 

import numpy as np
import pandas as pd

sqrt3by2 = np.sqrt(3)*0.5

def mismatch(m1, m2, n1, n2):
    x, y = (m1*n1 + m2*n2 + 0.5*(m1*n2 + m2*n1)), (sqrt3by2*(m1*n2 - m2*n1))
    Na = n1*n1 + n2*n2 + n1*n2
    Nb = m1*m1 + m2*m2 + m1*m2
    A = x*x + y*y
    if A <= 0.0 :
        return -1.0, -1.0, Na, Nb;
    A = 1/np.sqrt(A)
    rho = Nb*A
    A *= x
    if np.abs(A) > 1.0 :
        return -1.0, -1.0, Na, Nb;
    theta = np.arccos(A)
    return rho, theta, Na, Nb

la = 2.479
lb = 2.971
b = -20
t = 20

cn = ['theta', 'n1', 'n2', 'm1', 'm2', 'delta', 'NNi', 'NNiOOH']
df = pd.DataFrame(columns=cn)

labylb = la/lb

for m1 in range(b, t):
    for m2 in range(b, t):
        for n1 in range(b, t):
            for n2 in range(b, t):
                rho, theta, Na, Nb = mismatch(m1, m2, n1, n2)
                delta = np.abs(rho - labylb)
                if delta < 1e-7:
                    data = [np.rad2deg(theta), n1, n2, m1, m2, delta, Na, Nb]
                    entry = pd.DataFrame([dict(zip(cn, data))], columns=cn)
                    df = pd.concat([df, entry], ignore_index=True)

df 

 

Julia sequential

 

using Base.Threads;
using Printf;

abstract type AbstractList{T} end
abstract type AbstractNode{T} end

mutable struct LLNode{T} <: AbstractNode{T}
  val::T;
  next::Union{LLNode{T},Nothing};
end

mutable struct LList{T} <: AbstractList{T}
  head::Union{LLNode{T}, Nothing};
  N::Integer;
  function LList{T}() where T
    return new{T}(nothing, 0);
  end
end

function Base.push!(L::LList, value)
  L.head = LLNode(value, L.head);
  L.N += 1;
  return nothing;
end

function Base.length(L::LList)::Integer
  return L.N;
end

function toVector(L::LList{T}) where T
  n = length(L);
  vector = Vector{T}(undef, length(L));
  current_node = L.head;
  while n > 0
    vector[n] = current_node.val;
    n -= 1;
    current_node = current_node.next;
  end
  return vector;
end

struct Data{F<:AbstractFloat,I<:Integer}
  theta::F;
  n1::I; n2::I; m1::I; m2::I;
  delta::F; NNi::I; NNiOOH::I;
end

function Base.show(io::IO, d::Data)::Nothing
  print(io, "θ=$(d.theta), " *
            "n₁=$(d.n1), n₂=$(d.n2), m₁=$(d.m1), m₂=$(d.m2), " *
            "Δ=$(d.delta), NNi=$(d.NNi), NNiOOH=$(d.NNiOOH)");
end

function mismatch(
    m1::I, m2::I, n1::I, n2::I
  ) where {I<:Integer}
  x = m1 * n1 + m2 * n2 + cos(π/3) * (m1 * n2 + m2 * n1);
  y = sin(π/3) * (m1 * n2 - m2 * n1);
  Na = n1 * n1 + n2 * n2 + n1 * n2;
  Nb = m1 * m1 + m2 * m2 + m1 * m2;
  A = x * x + y * y;
  if A <= 0.0
    return NaN, NaN, Na, Nb;
  end
  A = 1.0/sqrt(A);
  rho = Nb*A;
  A *= x;
  if abs(A) > 1
    return rho, NaN, Na, Nb;
  end
  return rho, acos(A), Na, Nb;
end

function check_append!(
    list::LList{Data{F,I}},
    m1::I, m2::I, n1::I, n2::I,
    labylb::F
  ) where {F<:AbstractFloat,I<:Integer}
  rho, theta, Na, Nb = mismatch(m1, m2, n1, n2);
  if !isnan(rho) && !isnan(theta)
    delta = abs(rho - labylb);
    if delta < 1e-7
      push!(list, Data(rad2deg(theta), n1, n2, m1, m2, delta, Na, Nb));
      return nothing;
    end
  end
  return nothing;
end

function test_sub(
    m1::I, r::UnitRange{I},
    labylb::F
  ) where {I<:Integer,F<:AbstractFloat}
  list = LList{Data{typeof(labylb),typeof(m1)}}();
  for m2 in r, n1 in r, n2 in r
    check_append!(list, m1, m2, n1, n2, labylb);
  end
  return toVector(list);
end

function test(nn::I) where {I<:Integer}
  print_status = false
  @debug print_status = true
  la = 2.479;
  lb = 2.971;
  labylb = la/lb;
  r = -nn:nn
  n = length(r);
  if print_status
    N = 1.0/n^4 * 100;
    nnn = n^3;
    s = Atomic{typeof(nn)}(0);
    d = Atomic{typeof(nn)}(0);
  end
  list = Vector{Vector{Data{typeof(labylb),typeof(nn)}}}(undef, n);
  @threads for i in 1:n
    list[i] = test_sub(r[i], r, labylb);
    if print_status
      atomic_add!(s, nnn);
      atomic_add!(d, length(list[i]));
      @printf("[%3i] m1=%6i searched (%6.2f%%), %8i found\n",
              threadid(), r[i], s[]*N, d[]);
    end
  end
  return reduce(vcat, list);
end

function test_csv(nn::I; filename::String="output.csv") where {I<:Integer}
  t = @elapsed list = test(nn);
  @info "Calculation done in $t seconds with $(nthreads()) threads.\n" *
        "combinations scanned: $((2 * nn + 1)^4)\n" *
        "  combinations found: $(length(list))\n" *
        "output to: $filename";
  open(filename, "w") do f
    println(f, "theta,n1,n2,m1,m2,delta,NNi,NNiOOH");
    for d in list
      println(f, "$(d.theta)," *
                 "$(d.n1),$(d.n2),$(d.m1),$(d.m2)," *
                 "$(d.delta),$(d.NNi),$(d.NNiOOH)");
    end
  end
end

Additional Information