Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def local_kernels(self):
return tsfc_interface.compile_form(self._form, "form", diagonal=self.diagonal,
parameters=self._form_compiler_params)
elif isinstance(self._form, slate.TensorBase):
return slac.compile_expression(self._form, compiler_parameters=self._form_compiler_params)
return slac.compile_expression(self._form, compiler_parameters=self._form_compiler_params, diagonal=self.diagonal)
else:
raise AssertionError

Expand Down Expand Up @@ -872,7 +872,7 @@ def _as_global_kernel_arg_output(_, self):
if rank == 0:
return op2.GlobalKernelArg((1,))
elif rank == 1 or rank == 2 and self._diagonal:
V, = Vs
V = Vs[0]
if V.ufl_element().family() == "Real":
return op2.GlobalKernelArg((1,))
else:
Expand Down Expand Up @@ -1068,7 +1068,7 @@ def _as_parloop_arg_output(_, self):
if rank == 0:
return op2.GlobalParloopArg(tensor)
elif rank == 1 or rank == 2 and self._diagonal:
V, = Vs
V = Vs[0]
if V.ufl_element().family() == "Real":
return op2.GlobalParloopArg(tensor)
else:
Expand Down
19 changes: 12 additions & 7 deletions firedrake/slate/slac/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,24 @@

class SlateKernel(TSFCKernel):
@classmethod
def _cache_key(cls, expr, compiler_parameters, coffee):
def _cache_key(cls, expr, compiler_parameters, coffee, diagonal):
return md5((expr.expression_hash
+ str(sorted(compiler_parameters.items()))
+ str(coffee)).encode()).hexdigest(), expr.ufl_domains()[0].comm
+ str(coffee)
+ str(diagonal)).encode()).hexdigest(), expr.ufl_domains()[0].comm

def __init__(self, expr, compiler_parameters, coffee=False):
def __init__(self, expr, compiler_parameters, coffee=False, diagonal=False):
if self._initialized:
return
if coffee:
assert not diagonal, "Slate compiler cannot handle diagonal option in coffee mode. Use loopy backend instead."
self.split_kernel = generate_kernel(expr, compiler_parameters)
else:
self.split_kernel = generate_loopy_kernel(expr, compiler_parameters)
self.split_kernel = generate_loopy_kernel(expr, compiler_parameters, diagonal)
self._initialized = True


def compile_expression(slate_expr, compiler_parameters=None, coffee=False):
def compile_expression(slate_expr, compiler_parameters=None, coffee=False, diagonal=False):
"""Takes a Slate expression `slate_expr` and returns the appropriate
:class:`firedrake.op2.Kernel` object representing the Slate expression.

Expand Down Expand Up @@ -128,7 +130,7 @@ def compile_expression(slate_expr, compiler_parameters=None, coffee=False):
try:
return cache[key]
except KeyError:
kernel = SlateKernel(slate_expr, params, coffee).split_kernel
kernel = SlateKernel(slate_expr, params, coffee, diagonal).split_kernel
return cache.setdefault(key, kernel)


Expand All @@ -152,13 +154,16 @@ def get_temp_info(loopy_kernel):
return mem_total, num_temps, mems, shapes


def generate_loopy_kernel(slate_expr, compiler_parameters=None):
def generate_loopy_kernel(slate_expr, compiler_parameters=None, diagonal=False):
cpu_time = time.time()
if len(slate_expr.ufl_domains()) > 1:
raise NotImplementedError("Multiple domains not implemented.")

Citations().register("Gibson2018")

if diagonal:
slate_expr = slate.DiagonalTensor(slate_expr, vec=True)

orig_expr = slate_expr
# Optimise slate expr, e.g. push blocks as far inward as possible
if compiler_parameters["slate_compiler"]["optimise"]:
Expand Down
23 changes: 17 additions & 6 deletions firedrake/slate/slac/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,16 @@ def _push_block_transpose(expr, self, indices):

@_push_block.register(Add)
@_push_block.register(Negative)
@_push_block.register(DiagonalTensor)
@_push_block.register(Reciprocal)
def _push_block_distributive(expr, self, indices):
"""Distributes Blocks for these nodes"""
return type(expr)(*map(self, expr.children, repeat(indices))) if indices else expr
return type(expr)(*map(self, expr.children, repeat(indices)))


@_push_block.register(DiagonalTensor)
def _push_block_diag(expr, self, indices):
"""Distributes Blocks for these nodes"""
return type(expr)(*map(self, expr.children, repeat(indices)), expr.vec)


@_push_block.register(Factorization)
Expand Down Expand Up @@ -127,11 +132,12 @@ def push_diag(expression):
    on terminal tensors wherever possible.
"""
mapper = MemoizerArg(_push_diag)
mapper.vec = False
return mapper(expression, False)


@singledispatch
def _push_diag(expr, self, diag):
def _push_diag(expr, self, diag, vec):
raise AssertionError("Cannot handle terminal type: %s" % type(expr))


Expand All @@ -151,14 +157,14 @@ def _push_diag_distributive(expr, self, diag):
def _push_diag_stop(expr, self, diag):
"""Diagonal Tensors cannot be pushed further into this set of nodes."""
expr = type(expr)(*map(self, expr.children, repeat(False))) if not expr.terminal else expr
return DiagonalTensor(expr) if diag else expr
return DiagonalTensor(expr, self.vec) if diag else expr


@_push_diag.register(Block)
def _push_diag_block(expr, self, diag):
"""Diagonal Tensors cannot be pushed further into this set of nodes."""
expr = type(expr)(*map(self, expr.children, repeat(False)), expr._indices) if not expr.terminal else expr
return DiagonalTensor(expr) if diag else expr
return DiagonalTensor(expr, self.vec) if diag else expr


@_push_diag.register(AssembledVector)
Expand All @@ -174,6 +180,7 @@ def _push_diag_vectors(expr, self, diag):
@_push_diag.register(DiagonalTensor)
def _push_diag_diag(expr, self, diag):
"""DiagonalTensors are either pushed down or ignored when wrapped into another DiagonalTensor."""
self.vec = expr.vec
return self(*expr.children, not diag)


Expand Down Expand Up @@ -247,13 +254,17 @@ def _drop_double_transpose_transpose(expr, self):
@_drop_double_transpose.register(Mul)
@_drop_double_transpose.register(Solve)
@_drop_double_transpose.register(Inverse)
@_drop_double_transpose.register(DiagonalTensor)
@_drop_double_transpose.register(Reciprocal)
def _drop_double_transpose_distributive(expr, self):
"""Distribute into the children of the expression. """
return type(expr)(*map(self, expr.children))


@_drop_double_transpose.register(DiagonalTensor)
def _drop_double_transpose_diag(expr, self):
return type(expr)(*map(self, expr.children), expr.vec)


@singledispatch
def _push_mul(expr, self, state):
raise AssertionError("Cannot handle terminal type: %s" % type(expr))
Expand Down
4 changes: 3 additions & 1 deletion firedrake/slate/slac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ def _slate2gem_diagonal(expr, self):
A, = map(self, expr.children)
assert A.shape[0] == A.shape[1]
i, j = (Index(extent=s) for s in A.shape)
return ComponentTensor(Product(Indexed(A, (i, i)), Delta(i, j)), (i, j))
idx = (i, i) if expr.vec else (i, j)
jdx = (i,) if expr.vec else (i, j)
return ComponentTensor(Product(Indexed(A, (i, i)), Delta(*idx)), jdx)
else:
raise NotImplementedError("Diagonals on Slate expressions are \
not implemented in a matrix-free manner yet.")
Expand Down
17 changes: 13 additions & 4 deletions firedrake/slate/slate.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def expression_hash(self):
data = (type(op).__name__, op.decomposition, )
elif isinstance(op, Tensor):
data = (op.form.signature(), op.diagonal, )
elif isinstance(op, DiagonalTensor):
data = (type(op).__name__, op.vec, )
elif isinstance(op, (UnaryOp, BinaryOp)):
data = (type(op).__name__, )
else:
Expand Down Expand Up @@ -1268,33 +1270,40 @@ class DiagonalTensor(UnaryOp):
"""
diagonal = True

def __init__(self, A):
def __init__(self, A, vec=False):
"""Constructor for the Diagonal class."""
assert A.rank == 2, "The tensor must be rank 2."
assert A.rank == 2 or vec, "The tensor must be rank 2."
assert A.shape[0] == A.shape[1], (
"The diagonal can only be computed on square tensors."
)

super(DiagonalTensor, self).__init__(A)
self.vec = vec

@cached_property
def arg_function_spaces(self):
"""Returns a tuple of function spaces that the tensor
is defined on.
"""
tensor, = self.operands
return tuple(arg.function_space() for arg in tensor.arguments())
return (tuple(arg.function_space() for arg in [tensor.arguments()[0]])
if self.vec else tuple(arg.function_space() for arg in tensor.arguments()))

def arguments(self):
"""Returns a tuple of arguments associated with the tensor."""
tensor, = self.operands
return tensor.arguments()
return (tensor.arguments()[0],) if self.vec else tensor.arguments()

def _output_string(self, prec=None):
"""Creates a string representation of the diagonal of a tensor."""
tensor, = self.operands
return "(%s).diag" % tensor

@cached_property
def _key(self):
"""Returns a key for hash and equality."""
return ((type(self), *self.operands, self.vec))


def space_equivalence(A, B):
"""Checks that two function spaces are equivalent.
Expand Down
36 changes: 36 additions & 0 deletions tests/slate/test_slate_hybridization.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,39 @@ def test_mixed_poisson_approximated_schur_jacobi_prec():

assert sigma_err < 1e-8
assert u_err < 1e-8


def test_slate_hybridization_global_matfree_jacobi():
    """Check that a hybridized Poisson solve using a matrix-free,
    Jacobi-preconditioned CG trace solver matches a hybridized solve
    with a direct (LU) trace solver to tight tolerance."""
    a, L, W = setup_poisson()

    # Matrix-free hybridization: CG with Jacobi on the trace system.
    w_matfree = Function(W)
    jacobi_matfree_params = {'mat_type': 'matfree',
                             'ksp_type': 'cg',
                             'pc_type': 'python',
                             'pc_python_type': 'firedrake.HybridizationPC',
                             'hybridization': {'ksp_type': 'cg',
                                               'pc_type': 'jacobi',
                                               'mat_type': 'matfree',
                                               'ksp_rtol': 1e-8}}

    eq = a == L
    problem = LinearVariationalProblem(eq.lhs, eq.rhs, w_matfree)
    solver = LinearVariationalSolver(problem, solver_parameters=jacobi_matfree_params)
    solver.solve()
    sigma_h, u_h = w_matfree.split()

    # Reference solution: hybridization with a direct LU trace solve.
    w_direct = Function(W)
    solve(a == L, w_direct, solver_parameters={'ksp_type': 'preonly',
                                               'pc_type': 'python',
                                               'mat_type': 'matfree',
                                               'pc_python_type': 'firedrake.HybridizationPC',
                                               'hybridization': {'ksp_type': 'preonly',
                                                                 'pc_type': 'lu'}})
    nh_sigma, nh_u = w_direct.split()

    # Compare the two solutions in the L2 norm.
    sigma_err = errornorm(sigma_h, nh_sigma)
    u_err = errornorm(u_h, nh_u)

    assert sigma_err < 1e-8
    assert u_err < 1e-8