Skip to content

Commit

Permalink
formatted code
Browse files Browse the repository at this point in the history
  • Loading branch information
lmiq committed Apr 30, 2024
1 parent ff09995 commit 7c7b008
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 64 deletions.
54 changes: 27 additions & 27 deletions src/PeriodicSystems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,19 @@ function PeriodicSystem(;
autoswap::Bool=true
)
# Set xpositions if positions was set
if (isnothing(positions) && isnothing(xpositions)) || (!isnothing(positions) && !isnothing(xpositions))
if (isnothing(positions) && isnothing(xpositions)) || (!isnothing(positions) && !isnothing(xpositions))
throw(ArgumentError("Either `positions` OR `xpositions` must be defined."))
end
xpositions = isnothing(positions) ? xpositions : positions
# Check for simple input argument errors
for input_array in (xpositions, ypositions)
isnothing(input_array) && break
if input_array isa AbstractMatrix
if input_array isa AbstractMatrix
dim = size(input_array, 1)
if !(dim in (2,3))
if !(dim in (2, 3))
throw(DimensionMismatch("Matrix of coordinates must have 2 or 3 rows, one for each dimension, got size: $(size(input_array))"))
end
input_array = reinterpret(reshape, SVector{dim, eltype(input_array)}, input_array)
input_array = reinterpret(reshape, SVector{dim,eltype(input_array)}, input_array)
end
DIM = if eltype(input_array) isa SVector
length(eltype(input_array))
Expand All @@ -198,7 +198,7 @@ function PeriodicSystem(;
_output_threaded = [copy_output(output) for _ in 1:CellListMap.nbatches(_cell_list)]
output = _reset_all_output!(output, _output_threaded)
sys = PeriodicSystem1{output_name}(xpositions, output, _box, _cell_list, _output_threaded, _aux, parallel)
# Two sets of positions
# Two sets of positions
else
_box = CellListMap.Box(unitcell, cutoff, lcell=lcell)
_cell_list = CellListMap.CellList(xpositions, ypositions, _box; parallel=parallel, nbatches=nbatches, autoswap=autoswap)
Expand Down Expand Up @@ -271,13 +271,13 @@ CellListMap.unitcelltype(sys::AbstractPeriodicSystem) = unitcelltype(sys._box)

# test the construction with pathologically few particles
for x in [
SVector{3,Float64}[],
Vector{Float64}[],
Matrix{Float64}(undef, 3, 0),
[rand(SVector{3,Float64})],
[rand(3)],
rand(3,1)
]
SVector{3,Float64}[],
Vector{Float64}[],
Matrix{Float64}(undef, 3, 0),
[rand(SVector{3,Float64})],
[rand(3)],
rand(3, 1)
]
_sys = PeriodicSystem(
positions=x,
cutoff=0.1,
Expand All @@ -304,42 +304,42 @@ CellListMap.unitcelltype(sys::AbstractPeriodicSystem) = unitcelltype(sys._box)
cutoff=0.1, unitcell=[1, 1, 1], output=0.0,
)
@test_throws DimensionMismatch PeriodicSystem(
positions=rand(1,100),
positions=rand(1, 100),
cutoff=0.1, unitcell=[1, 1, 1], output=0.0,
)
@test_throws DimensionMismatch PeriodicSystem(
xpositions=rand(1,100),
xpositions=rand(1, 100),
cutoff=0.1, unitcell=[1, 1, 1], output=0.0,
)
@test_throws DimensionMismatch PeriodicSystem(
xpositions=rand(2,100),
ypositions=rand(1,100),
xpositions=rand(2, 100),
ypositions=rand(1, 100),
cutoff=0.1, unitcell=[1, 1, 1], output=0.0,
)
@test_throws DimensionMismatch PeriodicSystem(
positions=rand(2,100),
positions=rand(2, 100),
cutoff=0.1, unitcell=[1, 1, 1], output=0.0,
)
@test_throws DimensionMismatch PeriodicSystem(
xpositions=rand(2,100),
xpositions=rand(2, 100),
cutoff=0.1, unitcell=[1, 1, 1], output=0.0,
)
@test_throws DimensionMismatch PeriodicSystem(
xpositions=rand(2,100),
ypositions=rand(2,100),
xpositions=rand(2, 100),
ypositions=rand(2, 100),
cutoff=0.1, unitcell=[1, 1, 1], output=0.0,
)
@test_throws DimensionMismatch PeriodicSystem(
positions=rand(3,100),
positions=rand(3, 100),
cutoff=0.1, unitcell=[1, 1], output=0.0,
)
@test_throws DimensionMismatch PeriodicSystem(
xpositions=rand(3,100),
xpositions=rand(3, 100),
cutoff=0.1, unitcell=[1, 1], output=0.0,
)
@test_throws DimensionMismatch PeriodicSystem(
xpositions=rand(3,100),
ypositions=rand(3,100),
xpositions=rand(3, 100),
ypositions=rand(3, 100),
cutoff=0.1, unitcell=[1, 1], output=0.0,
)

Expand Down Expand Up @@ -435,9 +435,9 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::PeriodicSystem2{OutputNa
show(IOContext(io, :indent => indent + 4), mime, sys._box)
println(io)
show(io_sub, mime, sys._cell_list)
println(io,"\n Parallelization auxiliary data set for: ")
println(io, "\n Parallelization auxiliary data set for: ")
show(io_sub, mime, sys._cell_list.target.nbatches)
print(io,"\n Type of output variable ($OutputName): $(typeof(sys.output))")
print(io, "\n Type of output variable ($OutputName): $(typeof(sys.output))")
end

#
Expand Down Expand Up @@ -922,7 +922,7 @@ end
@test a == 0

# Update with matrices
x = rand(3,500)
x = rand(3, 500)
sys = PeriodicSystem(xpositions=x, unitcell=[1.0, 1.0, 1.0], cutoff=0.1, output=0.0, parallel=false)
a = @ballocated PeriodicSystems.UpdatePeriodicSystem!($sys) samples = 1 evals = 1
@test a == 0
Expand Down
74 changes: 37 additions & 37 deletions test/BasicForPeriodicSystems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
using StaticArrays
using CellListMap
using CellListMap.PeriodicSystems

N = 2000
x, y, sides, cutoff = CellListMap.pathological_coordinates(N)
mass = rand(N)

# Function to be evalulated for each pair: gravitational potential
function potential(i, j, d2, u, mass)
d = sqrt(d2)
u = u - 9.8 * mass[i] * mass[j] / d
return u
end

# Some simple disjoint set properties
system = PeriodicSystem(
xpositions=x,
Expand All @@ -30,7 +30,7 @@
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, u) -> potential(i, j, d2, u, mass), system) naive

# Use matrices as input coordinates
xmat = zeros(3, N);
xmat = zeros(3, N)
ymat = zeros(3, N)
for i in 1:N
xmat[:, i] .= x[i]
Expand All @@ -48,7 +48,7 @@
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, u) -> potential(i, j, d2, u, mass), system) naive
system.parallel = true
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, u) -> potential(i, j, d2, u, mass), system) naive

# Check different lcell
system = PeriodicSystem(
xpositions=x,
Expand All @@ -62,21 +62,21 @@
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, u) -> potential(i, j, d2, u, mass), system) naive
system.parallel = true
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, u) -> potential(i, j, d2, u, mass), system) naive

# Test updating of the data on disjoint sets works fine
for arrays in [
# static vectors, length(x) > length(y)
[rand(SVector{2,Float64}, 1000), rand(SVector{2,Float64}, 100)],
[rand(SVector{2,Float64}, 1000), rand(SVector{2,Float64}, 100)],
# static vectors, length(x) < length(y)
[rand(SVector{2,Float64}, 100), rand(SVector{2,Float64}, 1000)],
[rand(SVector{2,Float64}, 100), rand(SVector{2,Float64}, 1000)],
# standard vectors, length(x) > length(y)
[[rand(2) for _ in 1:1000], [rand(2) for _ in 1:100]], # with standard vectors
# standard vectors, length(x) < length(y)
[[rand(2) for _ in 1:100], [rand(2) for _ in 1:1000]], # with standard vectors
# matrices, length(x) > length(y)
[rand(2, 1000), rand(2, 100)],
[rand(2, 1000), rand(2, 100)],
# matrices, length(x) < length(y)
[rand(2,100), rand(2,1000)], # with standard vectors
[rand(2, 100), rand(2, 1000)], # with standard vectors
]
local x = arrays[1]
local y = arrays[2]
Expand All @@ -103,7 +103,7 @@
r = CellListMap.map_pairwise!((x, y, i, j, d2, r) -> r += d2, 0.0, box, cl)
system.xpositions .= x
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += d2, system)

if y isa AbstractVector
y = rand(SVector{2,Float64}, length(y) + 100)
resize!(system.ypositions, length(y))
Expand All @@ -114,7 +114,7 @@
r = CellListMap.map_pairwise!((x, y, i, j, d2, r) -> r += d2, 0.0, box, cl)
system.ypositions .= y
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += d2, system)

end

end
Expand All @@ -134,7 +134,7 @@ end
system = PeriodicSystem(xpositions=x, cutoff=0.1, output=0.0, unitcell=[1, 1, 1])
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += d2, system)
r = CellListMap.map_pairwise!((x, y, i, j, d2, r) -> r += sqrt(d2), 0.0, box, cl)
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += sqrt(d2), system; update_lists = false)
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += sqrt(d2), system; update_lists=false)

#
# two-set systems
Expand All @@ -157,15 +157,15 @@ end
cl = CellList(x, y, box)
r = CellListMap.map_pairwise!((x, y, i, j, d2, r) -> r += d2, 0.0, box, cl)
system.xpositions .= x
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += d2, system; update_lists = false)
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += d2, system; update_lists=false)

# increase x size
x = rand(SVector{3,Float64}, 200)
cl = CellList(x, y, box)
r = CellListMap.map_pairwise!((x, y, i, j, d2, r) -> r += d2, 0.0, box, cl)
resize!(system.xpositions, length(x))
system.xpositions .= x
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += d2, system; update_lists = false)
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += d2, system; update_lists=false)

#
# x is greater
Expand All @@ -184,15 +184,15 @@ end
cl = CellList(x, y, box)
r = CellListMap.map_pairwise!((x, y, i, j, d2, r) -> r += d2, 0.0, box, cl)
system.xpositions .= x
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += d2, system; update_lists = false)
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += d2, system; update_lists=false)

# increase x size
x = rand(SVector{3,Float64}, 1100)
cl = CellList(x, y, box)
r = CellListMap.map_pairwise!((x, y, i, j, d2, r) -> r += d2, 0.0, box, cl)
resize!(system.xpositions, length(x))
system.xpositions .= x
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += d2, system; update_lists = false)
@test r PeriodicSystems.map_pairwise!((x, y, i, j, d2, r) -> r += d2, system; update_lists=false)

end

Expand All @@ -203,30 +203,30 @@ end
using StaticArrays
using CellListMap
using CellListMap.PeriodicSystems

if Threads.nthreads() == 1
println("""
WARNING: Ideally, run a multi-threaded test to check the parallel versions.
""")
end

# Function to be evalulated for each pair: sum of displacements on x
f(x, y, avg_dx) = avg_dx + abs(x[1] - y[1])

x, y, sides, cutoff = CellListMap.pathological_coordinates(2000)
box = Box(sides, cutoff)
naive = CellListMap.map_naive!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), 0.0, x, box)

# Check if changing lcell breaks something
system = PeriodicSystem(xpositions=x, unitcell=sides, cutoff=cutoff, output=0.0, lcell=1)
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) naive
system = PeriodicSystem(xpositions=x, unitcell=sides, cutoff=cutoff, output=0.0, lcell=3)
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) naive
system = PeriodicSystem(xpositions=x, unitcell=sides, cutoff=cutoff, output=0.0, lcell=5)
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) naive

# Test if changing the number of batches breaks anything
system = PeriodicSystem(xpositions=x, unitcell=sides, cutoff=cutoff, output=0.0, nbatches=(3, 5))
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) naive
Expand All @@ -249,28 +249,28 @@ end
using StaticArrays
using CellListMap
using CellListMap.PeriodicSystems

N = 2000
x, y, sides, cutoff = CellListMap.pathological_coordinates(N)
box = Box(sides, cutoff)

# Initialize auxiliary linked lists
system = PeriodicSystem(
xpositions=x,
cutoff=cutoff,
unitcell=sides,
output=0.0,
)

# Function to be evalulated for each pair: sum of displacements on x
f(x, y, avg_dx) = avg_dx + abs(x[1] - y[1])

naive = CellListMap.map_naive!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), 0.0, x, box)
system.parallel = false
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) naive
system.parallel = true
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) naive

# Orthorhombic cell
new_x = copy(x) .+ [rand(SVector{3,Float64}) for _ in 1:N]
new_sides = sides + rand(SVector{3,Float64})
Expand All @@ -285,7 +285,7 @@ end
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) new_naive
system.parallel = true
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) new_naive

# If the number of particles and box change
new_x, new_box = CellListMap.xatomic(10^5)
new_cl = CellList(new_x, new_box)
Expand All @@ -298,7 +298,7 @@ end
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) new_val
system.parallel = true
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) new_val

#
# Triclinic cell
#
Expand All @@ -316,7 +316,7 @@ end
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) new_val
system.parallel = true
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, avg_dx) -> f(x, y, avg_dx), system) new_val

# If the number of particles and box change
cutoff = cutoff + rand()
new_x, new_box = CellListMap.xatomic(10^4)
Expand All @@ -341,12 +341,12 @@ end
using StaticArrays
using CellListMap
using CellListMap.PeriodicSystems

N = 2000
x, y, sides, cutoff = CellListMap.pathological_coordinates(N)
box = Box(sides, cutoff)
cl = CellList(x, box)

# Function to be evalulated for each pair: build distance histogram
function build_histogram!(d2, hist)
d = sqrt(d2)
Expand All @@ -357,7 +357,7 @@ end
naive = CellListMap.map_naive!((x, y, i, j, d2, hist) -> build_histogram!(d2, hist), zeros(Int, 10), x, box)
system = PeriodicSystem(xpositions=x, cutoff=cutoff, unitcell=sides, output=zeros(Int, 10))
@test naive == PeriodicSystems.map_pairwise!((x, y, i, j, d2, hist) -> build_histogram!(d2, hist), system)

# Function to be evalulated for each pair: gravitational potential
function potential(i, j, d2, u, mass)
d = sqrt(d2)
Expand All @@ -371,8 +371,8 @@ end

# Check the functionality of computing a different function from the same coordinates (new_coordinates=false)
naive = CellListMap.map_pairwise!((x, y, i, j, d2, u) -> u += d2, 0.0, box, cl)
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, u) -> u += d2, system; update_lists = false) naive
@test PeriodicSystems.map_pairwise!((x, y, i, j, d2, u) -> u += d2, system; update_lists=false) naive

# Function to be evalulated for each pair: gravitational force
function calc_forces!(x, y, i, j, d2, mass, forces)
G = 9.8 * mass[i] * mass[j] / d2
Expand Down

0 comments on commit 7c7b008

Please sign in to comment.