Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Magic Dots #4

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

Magic Dots #4

wants to merge 1 commit into from

Conversation

mcabbott
Copy link
Owner

@mcabbott mcabbott commented May 9, 2020

Sometimes it would be neat if the same function could allow for several dimensionalities. Perhaps this is written:

f(A, B) = @tullio C[i, j, ..] := A[i, k, ..] * B[k, j, ..]

which accepts f(Matrix, Matrix)::Matrix as usual, but also f(Array3, Array3)::Array3 with one more dimension. And ideally also f(Array4, Matrix)::Array4 with varying numbers of dimensions, obeying broadcasting rules.

This is an implementation which adds extra loops over one CartesianIndex{N}, using clever things from Base.Broadcast to work out the appropriate ranges. It adds a bit more complication than I pictured before starting!

First working version. Multi-threading is disabled, no idea whether gradients will work.

@mcabbott
Copy link
Owner Author

mcabbott commented May 13, 2020

Another approach would be, if there are any .., to write a generated function which calls the macro with appropriate numbers of entries. Replace A[i,j,..] with A[i,j,b1,b2] when ndims(A)==4, etc:

@generated function f_outer(A, B)
    DA, DB = ndims(A)-1, ndims(B)-1
    DC = max(DA, DB)
    bs = [Symbol(:b, i) for i in 1:DC]
    :(@tullio C[i,j,$(bs...)] := A[i,k,$(bs[1:DA]...)] * B[k,j,$(bs[1:DB]...)])
end
f_outer(rand(2,2), rand(2,2))
# ERROR: LoadError: eval cannot be used in a generated function

So I can't do this until I find a way to avoid eval. (Done, #5)

An even simpler approach would be to allow A[i,j,k] to work with a matrix, reshaping to have a trivial 3rd dimension & correcting afterwards. That's a step further from broadcasting as it will demand that all k indices have the same range.

@mcabbott
Copy link
Owner Author

Note to self about @generated:

struct Eval2{A,B}
    make::A
    act::B
end
(e::Eval2)(args...) = e.make(args..., e.act)

@generated function gen1(arr)
    N = ndims(arr)
    quote
        # This is a variant of what the macro expands to:
        f(x, act!) = act!(similar(x), x) # but f isn't a closure over g
        g(y, x) = y .= $N .* x
        Eval2(f,g)(arr)
    end
end

gen1(1:3) # ERROR: The function body AST defined by this @generated function is not pure ...

# It seems can't define a function in the generated body at all, as it defines a new type.

@generated function gen4(arr)
    N = ndims(arr)
    # However, you can define functions during generation:
    f(x, act!) = act!(similar(x), x) 
    g(y, x) = y .= N .* x
    quote
        Eval2($f, $g)(arr)
    end
end

gen4(1:3)
gen4(ones(2,3,1))

@mcabbott
Copy link
Owner Author

mcabbott commented Aug 24, 2020

Further note: that won't work, because f needs to @eval some expression which is constructed using N.

What might work is something more like this. Write the body of the function in one generated function, and call that outside the quote of a second:

julia> @generated function gen7(x, ::Val{N}) where N
           ind = [Symbol(:i_, i) for i in 1:N]
           quote
               @einsum y[$(ind[1:end-1]...)] := x[$(ind...)]^2 + $N
           end
       end;

julia> @generated function gen8(arr)
           g(x) = gen7(x, Val(ndims(arr)))
           quote
               $g(arr)
           end
       end;

julia> gen8(rand(2,3,1,1,10)) # 5 dims, sum over 1:10
2×3×1×1 Array{Float64,4}:
[:, :, 1, 1] =
 51.392   54.0073  53.4575
 52.5122  52.8876  52.9701

Of course right now @tullio defines quite a few functions, and they depend on each other, which I'm not sure is possible. But the simplest case, in-place operations, some @tullio_body could work like @einsum above, and produce g() which works like act!() function, and can be passed to threader.

The low-tech way, however, is just to allow trivial indices. For A[i,j,k] enforce ndims(A) <= 3 and since size(A,4)==1 the loop over k may become trivial.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant