Implementation

The way ModuleMixins is implemented, is that we start out with something relatively simple, and build out from that. This means there will be some redudant code. Macros are hard to engineer, this takes you through the entire process.

Prelude

#| file: src/ModuleMixins.jl
module ModuleMixins

using MacroTools: @capture, postwalk, prewalk

export @compose, @for_each

<<spec>>
<<mixin>>
<<struct-data>>
<<compose>>
<<for-each>>

end

To facilitate testing, we need to be able to compare syntax. We use the clean function to remove source information from expressions.

#| file: test/runtests.jl
using Test
using ModuleMixins:
    @spec,
    @spec_mixin,
    @spec_using,
    @mixin,
    Struct,
    parse_struct,
    define_struct,
    Pass,
    @compose,
    @for_each
using MacroTools: prewalk, rmlines

clean(expr) = prewalk(rmlines, expr)

<<test-toplevel>>

@testset "ModuleMixins" begin
    <<test>>
end

@spec

The @spec macro creates a new module, and stores its own AST inside that module.

#| id: test-toplevel
@spec module MySpec
const msg = "hello"
end
#| id: test
@testset "@spec" begin
    @test clean.(MySpec.AST) == clean.([:(const msg = "hello")])
    @test MySpec.msg == "hello"
end

The @spec macro is used to specify the structs of a model component.

#| id: spec
"""
    @spec module *name*
        *body*...
    end

Create a spec. The `@spec` macro itself doesn't perform any operations other than creating a module and storing its own AST as `const *name*.AST`.

This macro is only here for teaching purposes.
"""
macro spec(mod)
    @assert @capture(mod, module name_
    body__
    end)

    esc(Expr(:toplevel, :(module $name
    $(body...)
    const AST = $body
    end)))
end

@spec_mixin

We now add the @mixin syntax. This still doesn't do anything, other than storing the names of parent modules.

#| id: test-toplevel
@spec_mixin module MyMixinSpecOne
@mixin A
end
@spec_mixin module MyMixinSpecMany
@mixin A, B, C
end
#| id: test
@testset "@spec_mixin" begin
    @test MyMixinSpecOne.PARENTS == [:A]
    @test MyMixinSpecMany.PARENTS == [:A, :B, :C]
end

Here's the @mixin macro:

#| id: mixin
macro mixin(deps)
    if @capture(deps, (multiple_deps__,))
        esc(:(const PARENTS = [$(QuoteNode.(multiple_deps)...)]))
    else
        esc(:(const PARENTS = [$(QuoteNode(deps))]))
    end
end

The QuoteNode calls prevent the symbols from being evaluated at macro expansion time. We need to make sure that the @mixin syntax is also available from within the module.

#| id: spec

macro spec_mixin(mod)
    @assert @capture(mod, module name_
    body__
    end)

    esc(Expr(:toplevel, :(module $name
    import ..@mixin

    $(body...)

    const AST = $body
    end)))
end

@spec_using

I can't think of any usecase where a @mixin A, doesn't also mean using ..A. By replacing the @mixin with a using statement, we also no longer need to import @mixin. In fact, that macro becomes redundant. Also, in @spec_using we're allowed multiple @mixin statements.

#| id: test-toplevel
@spec_using module SU_A
const X = :hello
export X
end

@spec_using module SU_B
@mixin SU_A
const Y = X
end

@spec_using module SU_C
const Z = :goodbye
end

@spec_using module SU_D
@mixin SU_A
@mixin SU_B, SU_C
end
#| id: test
@testset "@spec_using" begin
    @test SU_B.Y == SU_A.X
    @test SU_B.PARENTS == [:SU_A]
    @test SU_D.PARENTS == [:SU_A, :SU_B, :SU_C]
    @test SU_D.SU_C.Z == :goodbye
end

We now use the postwalk function (from MacroTools.jl) to transform expressions and collect information into searchable data structures. We make a little abstraction over the postwalk function, so we can compose multiple transformations in a single tree walk.

#| id: test-toplevel
struct EmptyPass <: Pass
    tag::Symbol
end
#| id: test
@testset "pass composition" begin
    a = EmptyPass(:a) + EmptyPass(:b)
    @test a.parts[1].tag == :a
    @test a.parts[2].tag == :b
end

A composite pass tries all of its parts in order, returning the value of the first pass that doesn't return nothing.

#| id: spec

abstract type Pass end

function pass(x::Pass, expr)
    error("Can't call `pass` on abstract `Pass`.")
end

struct CompositePass <: Pass
    parts::Vector{Pass}
end

Base.:+(a::CompositePass...) = CompositePass(splat(vcat)(getfield.(a, :parts)))
Base.convert(::Type{CompositePass}, a::Pass) = CompositePass([a])
Base.:+(a::Pass...) = splat(+)(convert.(CompositePass, a))

function pass(cp::CompositePass, expr)
    for p in cp.parts
        result = pass(p, expr)
        if result !== :nomatch
            return result
        end
    end
    return :nomatch
end

function walk(x::Pass, expr_list)
    function patch(expr)
        result = pass(x, expr)
        result === :nomatch ? expr : result
    end
    prewalk.(patch, expr_list)
end
#| id: spec
@kwdef struct MixinPass <: Pass
    items::Vector{Symbol}
end

function pass(m::MixinPass, expr)
    @capture(expr, @mixin deps_) || return :nomatch

    if @capture(deps, (multiple_deps__,))
        append!(m.items, multiple_deps)
        :(
            begin
                $([:(using ..$d) for d in multiple_deps]...)
            end
        )
    else
        push!(m.items, deps)
        :(using ..$deps)
    end
end

macro spec_using(mod)
    @assert @capture(mod, module name_ body__ end)

    parents = MixinPass([])
    clean_body = walk(parents, body)

    esc(Expr(:toplevel, :(module $name
        $(clean_body...)
        const AST = $body
        const PARENTS = [$(QuoteNode.(parents.items)...)]
    end)))
end

Structure of structs

We'll convert struct syntax into collectable data, then convert that back into structs again. We'll support several patterns:

#| id: test
cases = Dict(
    :(struct A x end) => Struct(false, false, :A, nothing, nothing, [:x]),
    :(mutable struct A x end) => Struct(false, true, :A, nothing, nothing, [:x]),
    :(@kwdef struct A x end) => Struct(true, false, :A, nothing, nothing, [:x]),
    :(@kwdef mutable struct A x end) => Struct(true, true, :A, nothing, nothing, [:x]),
    :(struct A{T} x::T end) => Struct(false, false, :A, [:T], nothing, [:(x::T)]),
)

for (k, v) in pairs(cases)
    @testset "Struct mangling: $(join(split(string(clean(k))), " "))" begin
        @test clean(define_struct(parse_struct(k))) == clean(k)
        @test clean(define_struct(v)) == clean(k)
    end
end

Each of these can have either just a Symbol for a name, or a A <: B expression. This is a bit cumbersome, but we'll have to deal with all of these cases.

#| id: test
@testset "Struct mangling abstracts" begin
    @test parse_struct(:(struct A <: B x end)).abstract_type == :B
    @test parse_struct(:(mutable struct A <: B x end)).abstract_type == :B
end

@testset "Mangling type arguments" begin
    using ModuleMixins: mangle_type_parameters!
    let s = parse_struct(:(struct S{T} x::T end))
        @test s.type_parameters == [:T]
        mangle_type_parameters!(s, :A)
        @test s.type_parameters == [:_T_A]
        @test clean(define_struct(s)) == clean(:(struct S{_T_A} x::_T_A end))
    end
    let s = parse_struct(:(struct S{T} x::Vector{T} = [] end))
        mangle_type_parameters!(s, :A)
        @test clean(define_struct(s)) == clean(:(struct S{_T_A} x::Vector{_T_A} = [] end))
    end
end
#| id: struct-data

mutable struct Struct
    use_kwdef::Bool
    is_mutable::Bool
    name::Symbol
    type_parameters::Union{Vector{Symbol},Nothing}
    abstract_type::Union{Symbol,Nothing}
    fields::Vector{Union{Expr,Symbol}}
end

function mangle_type_parameters!(s::Struct, suffix::Symbol)
    s.type_parameters === nothing && return

    d = IdDict{Symbol, Symbol}(
        (k => Symbol("_$(k)_$(suffix)") for k in s.type_parameters)...)

    replace_type_par(expr) =
        postwalk(x -> x isa Symbol ? get(d, x, x) : x, expr)
        
    s.fields = replace_type_par.(s.fields)
    s.type_parameters = collect(values(d))
    return s
end

function mappend(a::Union{Vector{T}, Nothing}, b::Union{Vector{T}, Nothing}) where T
    isnothing(a) && return b
    isnothing(b) && return a
    return vcat(a, b)
end

function extend_struct!(s1::Struct, s2::Struct)
    append!(s1.fields, s2.fields)
    s1.type_parameters = mappend(s1.type_parameters, s2.type_parameters)
    return s1
end

function parse_struct(expr)
    @capture(expr, (@kwdef kw_struct_expr_) | struct_expr_)
    uses_kwdef = kw_struct_expr !== nothing
    struct_expr = uses_kwdef ? kw_struct_expr : struct_expr

    @capture(struct_expr,
        (struct name_ fields__ end) |
        (mutable struct mut_name_ fields__ end)) || return

    is_mutable = mut_name !== nothing
    sname = is_mutable ? mut_name : name
    @capture(sname, (pname_ <: abst_) | pname_)
    @capture(pname, (name_{pars__}) | name_)


    return Struct(uses_kwdef, is_mutable, name, pars, abst, fields)
end

function define_struct(s::Struct)
    name = s.type_parameters !== nothing ? :($(s.name){$(s.type_parameters...)}) : s.name
    name = s.abstract_type !== nothing ? :($(name) <: $(s.abstract_type)) : name
    sdef = if s.is_mutable
        :(mutable struct $name
            $(s.fields...)
        end)
    else
        :(struct $name
            $(s.fields...)
        end)
    end
    s.use_kwdef ? :(@kwdef $sdef) : sdef
end

@compose

Unfortunately now comes a big leap. We'll merge all struct definitions inside the body of a module definition with that of its parents. We must also make sure that a struct definition still compiles, so we have to take along using and const statements.

#| id: test-toplevel
module ComposeTest1
using ModuleMixins

@compose module A
    struct S
        a::Int
    end
end

@compose module B
    struct S{T}
        b::T
    end
end

@compose module AB
    @mixin A, B
end
end
#| id: test
@testset "compose struct members" begin
    @test ComposeTest1.AB.PARENTS == [:A, :B]
    @test fieldnames(ComposeTest1.AB.S) == (:a, :b)
end
@testset "compose hierarchy" begin
    @test ComposeTest1.AB.MIXIN_TREE == IdDict(:AB => [:A, :B], :A => [], :B => [])
    @test WriterABC.MIXIN_TREE == IdDict(
        :WriterABC => [:WriterB, :WriterC],
        :WriterC => [],
        :WriterB => [:WriterA],
        :WriterA => [])
end
@testset "composed struct has type parameter" begin
    @test ComposeTest1.AB.S{Float64}(1, 2).b isa Float64
end
#| id: compose

struct CollectUsingPass <: Pass
    items::Vector{Expr}
end

function pass(p::CollectUsingPass, expr)
    @capture(expr, using x__ | using mod__: x__) || return :nomatch
    push!(p.items, expr)
    return nothing
end

struct CollectConstPass <: Pass
    items::Vector{Expr}
end

function pass(p::CollectConstPass, expr)
    @capture(expr, const x_ = y_) || return :nomatch
    push!(p.items, expr)
    return nothing
end

struct CollectStructPass <: Pass
    items::IdDict{Symbol,Struct}
    name::Symbol
end

function pass(p::CollectStructPass, expr)
    s = parse_struct(expr)
    s === nothing && return :nomatch
    mangle_type_parameters!(s, p.name)
    if s.name in keys(p.items)
        extend_struct!(p.items[s.name], s)
    else
        p.items[s.name] = s
    end
    return nothing
end

"""
    @compose module Name
        [@mixin Parents, ...]
        ...
    end

Creates a new composable module `Name`. Structs inside this module are
merged with those of the same name in `Parents`.
"""
macro compose(mod)
    @assert @capture(mod, module name_ body__ end)

    mixins = Symbol[]
    mixin_tree = IdDict{Symbol, Vector{Symbol}}()
    parents = MixinPass([])
    usings = CollectUsingPass([])
    consts = CollectConstPass([])
    struct_items = IdDict{Symbol, Struct}()

    function mixin(expr; name=name)
        structs = CollectStructPass(struct_items, name)
        parents = MixinPass([])
        pass1 = walk(parents, expr)
        mixin_tree[name] = parents.items
        for p in parents.items
            p in mixins && continue
            push!(mixins, p)
            parent_expr = Core.eval(__module__, :($(p).AST))
            mixin(parent_expr; name=p)
        end
        walk(usings + consts + structs, pass1)
    end

    fields = CollectStructPass(IdDict{Symbol,Struct}(), name)
    walk(fields, body)

    clean_body = mixin(body)

    esc(Expr(:toplevel, :(module $name
        const AST = $body
        const PARENTS = [$(QuoteNode.(mixins)...)]
        const MIXIN_TREE = $(mixin_tree)
        const FIELDS = $(IdDict((n => v.fields for (n, v) in pairs(fields.items))...))
        $(usings.items...)
        $(consts.items...)
        $(define_struct.(values(struct_items))...)
        $(clean_body...)
    end)))
end

For-each

The @for_each macro is meant for situations where you want to call a certain member function for each module that has it defined. Our use case: we have several components that need to write different bits of information to an output file. Each component defines a write(io, data) method. In our composed model, we can now call:

@for_each(P->P.write(io, data), PARENTS)
#| id: test-toplevel
module Common
    export AbstractData
    abstract type AbstractData end
end

@compose module WriterA
    using ..Common

    @kwdef struct Data <: AbstractData
        a::Int
    end

    function write(io::IO, data::AbstractData)
        println(io, data.a)
    end
end

@compose module WriterB
    using ..Common
    @mixin WriterA

    @kwdef struct Data <: AbstractData
        b::Int
    end

    function write(io::IO, data::AbstractData)
        println(io, data.b)
    end
end

@compose module WriterC
end

@compose module WriterABC
    using ModuleMixins
    @mixin WriterB, WriterC

    function write(io::IO, data::AbstractData)
        @for_each(P->P.write(io, data), PARENTS)
    end
end
#| id: test
@testset "for-each" begin
    io = IOBuffer(write=true)
    data = WriterABC.Data(a = 42, b = 23)
    WriterABC.write(io, data)
    @test String(take!(io)) == "23\n42\n"
end
#| id: for-each
"""
    substitute_top_level(var, val, mod, expr)

Takes a syntax object `expr` and substitutes every occurence of 
module `var` for `val`, only if the resulting object is actually
present in module `mod`. The `mod` module should correspond with
a lookup of `val` in the caller's namespace.
"""
function substitute_top_level(var, val, mod, expr)
    postwalk(function (x)
        @capture(x, gen_.item_) || return x
        if gen === var
            if item in names(mod, all=true)
                return Expr(:., val, QuoteNode(item))
            else
                return Returns(nothing)
            end
        end
        return x
    end, expr)
end

"""
    @for_each(M -> M.method(), lst::Vector{Symbol})

Calls `method()` for each module in `lst` that actually implements
that method. Here `lst` should be a vector of symbols that are all
in the current module's namespace.
"""
macro for_each(_fun, _lst)
    @assert @capture(_fun, var_ -> expr_)

    function replace_call_parent(p)
        mod = Core.eval(__module__, p)
        substitute_top_level(var, p, mod, expr)
    end

    lst = Core.eval(__module__, _lst)
    esc(:(begin
        $((replace_call_parent(p) for p in lst)...)
    end))
end