diff --git a/firedrake/assemble.py b/firedrake/assemble.py
index 29b2be61d4..65dccd9a1a 100644
--- a/firedrake/assemble.py
+++ b/firedrake/assemble.py
@@ -421,7 +421,7 @@ def local_kernels(self):
             return tsfc_interface.compile_form(self._form, "form", diagonal=self.diagonal,
                                                parameters=self._form_compiler_params)
         elif isinstance(self._form, slate.TensorBase):
-            return slac.compile_expression(self._form, compiler_parameters=self._form_compiler_params)
+            return slac.compile_expression(self._form, compiler_parameters=self._form_compiler_params, diagonal=self.diagonal)
         else:
             raise AssertionError
@@ -872,7 +872,7 @@ def _as_global_kernel_arg_output(_, self):
     if rank == 0:
         return op2.GlobalKernelArg((1,))
     elif rank == 1 or rank == 2 and self._diagonal:
-        V, = Vs
+        V = Vs[0]
         if V.ufl_element().family() == "Real":
             return op2.GlobalKernelArg((1,))
         else:
@@ -1068,7 +1068,7 @@ def _as_parloop_arg_output(_, self):
     if rank == 0:
         return op2.GlobalParloopArg(tensor)
     elif rank == 1 or rank == 2 and self._diagonal:
-        V, = Vs
+        V = Vs[0]
         if V.ufl_element().family() == "Real":
             return op2.GlobalParloopArg(tensor)
         else:
diff --git a/firedrake/slate/slac/compiler.py b/firedrake/slate/slac/compiler.py
index 5f2a8fafed..5825837e9c 100644
--- a/firedrake/slate/slac/compiler.py
+++ b/firedrake/slate/slac/compiler.py
@@ -83,22 +83,24 @@ class SlateKernel(TSFCKernel):
     @classmethod
-    def _cache_key(cls, expr, compiler_parameters, coffee):
+    def _cache_key(cls, expr, compiler_parameters, coffee, diagonal):
         return md5((expr.expression_hash
                     + str(sorted(compiler_parameters.items()))
-                    + str(coffee)).encode()).hexdigest(), expr.ufl_domains()[0].comm
+                    + str(coffee)
+                    + str(diagonal)).encode()).hexdigest(), expr.ufl_domains()[0].comm
 
-    def __init__(self, expr, compiler_parameters, coffee=False):
+    def __init__(self, expr, compiler_parameters, coffee=False, diagonal=False):
         if self._initialized:
             return
         if coffee:
+            assert not diagonal, "Slate compiler cannot handle diagonal option in coffee mode. Use loopy backend instead."
             self.split_kernel = generate_kernel(expr, compiler_parameters)
         else:
-            self.split_kernel = generate_loopy_kernel(expr, compiler_parameters)
+            self.split_kernel = generate_loopy_kernel(expr, compiler_parameters, diagonal)
         self._initialized = True
 
 
-def compile_expression(slate_expr, compiler_parameters=None, coffee=False):
+def compile_expression(slate_expr, compiler_parameters=None, coffee=False, diagonal=False):
     """Takes a Slate expression `slate_expr` and returns the appropriate
     :class:`firedrake.op2.Kernel` object representing the Slate expression.
@@ -128,7 +130,7 @@ def compile_expression(slate_expr, compiler_parameters=None, coffee=False):
     try:
         return cache[key]
     except KeyError:
-        kernel = SlateKernel(slate_expr, params, coffee).split_kernel
+        kernel = SlateKernel(slate_expr, params, coffee, diagonal).split_kernel
         return cache.setdefault(key, kernel)
 
 
@@ -152,13 +154,16 @@ def get_temp_info(loopy_kernel):
     return mem_total, num_temps, mems, shapes
 
 
-def generate_loopy_kernel(slate_expr, compiler_parameters=None):
+def generate_loopy_kernel(slate_expr, compiler_parameters=None, diagonal=False):
    cpu_time = time.time()
     if len(slate_expr.ufl_domains()) > 1:
         raise NotImplementedError("Multiple domains not implemented.")
 
     Citations().register("Gibson2018")
+    if diagonal:
+        slate_expr = slate.DiagonalTensor(slate_expr, vec=True)
+
     orig_expr = slate_expr
     # Optimise slate expr, e.g. push blocks as far inward as possible
     if compiler_parameters["slate_compiler"]["optimise"]:
diff --git a/firedrake/slate/slac/optimise.py b/firedrake/slate/slac/optimise.py
index 3e3a74cae4..c80ae0a714 100644
--- a/firedrake/slate/slac/optimise.py
+++ b/firedrake/slate/slac/optimise.py
@@ -74,11 +74,16 @@ def _push_block_transpose(expr, self, indices):
 
 @_push_block.register(Add)
 @_push_block.register(Negative)
-@_push_block.register(DiagonalTensor)
 @_push_block.register(Reciprocal)
 def _push_block_distributive(expr, self, indices):
     """Distributes Blocks for these nodes"""
-    return type(expr)(*map(self, expr.children, repeat(indices))) if indices else expr
+    return type(expr)(*map(self, expr.children, repeat(indices)))
+
+
+@_push_block.register(DiagonalTensor)
+def _push_block_diag(expr, self, indices):
+    """Distributes Blocks for these nodes"""
+    return type(expr)(*map(self, expr.children, repeat(indices)), expr.vec)
 
 
 @_push_block.register(Factorization)
@@ -127,11 +132,12 @@ def push_diag(expression):
     on terminal tensors whereever possible.
     """
     mapper = MemoizerArg(_push_diag)
+    mapper.vec = False
     return mapper(expression, False)
 
 
 @singledispatch
-def _push_diag(expr, self, diag):
+def _push_diag(expr, self, diag, vec):
     raise AssertionError("Cannot handle terminal type: %s" % type(expr))
 
 
@@ -151,14 +157,14 @@ def _push_diag_distributive(expr, self, diag):
 def _push_diag_stop(expr, self, diag):
     """Diagonal Tensors cannot be pushed further into this set of nodes."""
     expr = type(expr)(*map(self, expr.children, repeat(False))) if not expr.terminal else expr
-    return DiagonalTensor(expr) if diag else expr
+    return DiagonalTensor(expr, self.vec) if diag else expr
 
 
 @_push_diag.register(Block)
 def _push_diag_block(expr, self, diag):
     """Diagonal Tensors cannot be pushed further into this set of nodes."""
     expr = type(expr)(*map(self, expr.children, repeat(False)), expr._indices) if not expr.terminal else expr
-    return DiagonalTensor(expr) if diag else expr
+    return DiagonalTensor(expr, self.vec) if diag else expr
 
 
 @_push_diag.register(AssembledVector)
@@ -174,6 +180,7 @@ def _push_diag_vectors(expr, self, diag):
 
 @_push_diag.register(DiagonalTensor)
 def _push_diag_diag(expr, self, diag):
     """DiagonalTensors are either pushed down or ignored when wrapped into another DiagonalTensor."""
+    self.vec = expr.vec
     return self(*expr.children, not diag)
 
@@ -247,13 +254,17 @@ def _drop_double_transpose_transpose(expr, self):
 @_drop_double_transpose.register(Mul)
 @_drop_double_transpose.register(Solve)
 @_drop_double_transpose.register(Inverse)
-@_drop_double_transpose.register(DiagonalTensor)
 @_drop_double_transpose.register(Reciprocal)
 def _drop_double_transpose_distributive(expr, self):
     """Distribute into the children of the expression.
     """
     return type(expr)(*map(self, expr.children))
 
 
+@_drop_double_transpose.register(DiagonalTensor)
+def _drop_double_transpose_diag(expr, self):
+    return type(expr)(*map(self, expr.children), expr.vec)
+
+
 @singledispatch
 def _push_mul(expr, self, state):
     raise AssertionError("Cannot handle terminal type: %s" % type(expr))
diff --git a/firedrake/slate/slac/utils.py b/firedrake/slate/slac/utils.py
index 47c54b2cce..4bf97ec0fd 100644
--- a/firedrake/slate/slac/utils.py
+++ b/firedrake/slate/slac/utils.py
@@ -193,7 +193,9 @@ def _slate2gem_diagonal(expr, self):
         A, = map(self, expr.children)
         assert A.shape[0] == A.shape[1]
         i, j = (Index(extent=s) for s in A.shape)
-        return ComponentTensor(Product(Indexed(A, (i, i)), Delta(i, j)), (i, j))
+        idx = (i, i) if expr.vec else (i, j)
+        jdx = (i,) if expr.vec else (i, j)
+        return ComponentTensor(Product(Indexed(A, (i, i)), Delta(*idx)), jdx)
     else:
         raise NotImplementedError("Diagonals on Slate expressions are \
 not implemented in a matrix-free manner yet.")
diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py
index 3fd9b85ec9..e883b5e4c7 100644
--- a/firedrake/slate/slate.py
+++ b/firedrake/slate/slate.py
@@ -158,6 +158,8 @@ def expression_hash(self):
                 data = (type(op).__name__, op.decomposition, )
             elif isinstance(op, Tensor):
                 data = (op.form.signature(), op.diagonal, )
+            elif isinstance(op, DiagonalTensor):
+                data = (type(op).__name__, op.vec, )
             elif isinstance(op, (UnaryOp, BinaryOp)):
                 data = (type(op).__name__, )
             else:
@@ -1268,14 +1270,15 @@ class DiagonalTensor(UnaryOp):
     """
     diagonal = True
 
-    def __init__(self, A):
+    def __init__(self, A, vec=False):
         """Constructor for the Diagonal class."""
-        assert A.rank == 2, "The tensor must be rank 2."
+        assert A.rank == 2 or vec, "The tensor must be rank 2."
         assert A.shape[0] == A.shape[1], (
             "The diagonal can only be computed on square tensors."
         )
         super(DiagonalTensor, self).__init__(A)
+        self.vec = vec
 
     @cached_property
     def arg_function_spaces(self):
@@ -1283,18 +1286,24 @@ def arg_function_spaces(self):
         is defined on.
         """
         tensor, = self.operands
-        return tuple(arg.function_space() for arg in tensor.arguments())
+        return (tuple(arg.function_space() for arg in [tensor.arguments()[0]])
+                if self.vec else tuple(arg.function_space() for arg in tensor.arguments()))
 
     def arguments(self):
         """Returns a tuple of arguments associated with the tensor."""
         tensor, = self.operands
-        return tensor.arguments()
+        return (tensor.arguments()[0],) if self.vec else tensor.arguments()
 
     def _output_string(self, prec=None):
         """Creates a string representation of the diagonal of a tensor."""
         tensor, = self.operands
         return "(%s).diag" % tensor
 
+    @cached_property
+    def _key(self):
+        """Returns a key for hash and equality."""
+        return ((type(self), *self.operands, self.vec))
+
 
 def space_equivalence(A, B):
     """Checks that two function spaces are equivalent.
diff --git a/tests/slate/test_slate_hybridization.py b/tests/slate/test_slate_hybridization.py
index 6f14c8145d..49d90d9cca 100644
--- a/tests/slate/test_slate_hybridization.py
+++ b/tests/slate/test_slate_hybridization.py
@@ -505,3 +505,39 @@ def test_mixed_poisson_approximated_schur_jacobi_prec():
 
     assert sigma_err < 1e-8
     assert u_err < 1e-8
+
+
+def test_slate_hybridization_global_matfree_jacobi():
+    a, L, W = setup_poisson()
+
+    w = Function(W)
+    jacobi_matfree_params = {'mat_type': 'matfree',
+                             'ksp_type': 'cg',
+                             'pc_type': 'python',
+                             'pc_python_type': 'firedrake.HybridizationPC',
+                             'hybridization': {'ksp_type': 'cg',
+                                               'pc_type': 'jacobi',
+                                               'mat_type': 'matfree',
+                                               'ksp_rtol': 1e-8}}
+
+    eq = a == L
+    problem = LinearVariationalProblem(eq.lhs, eq.rhs, w)
+    solver = LinearVariationalSolver(problem, solver_parameters=jacobi_matfree_params)
+    solver.solve()
+    sigma_h, u_h = w.split()
+
+    w2 = Function(W)
+    solve(a == L, w2, solver_parameters={'ksp_type': 'preonly',
+                                         'pc_type': 'python',
+                                         'mat_type': 'matfree',
+                                         'pc_python_type': 'firedrake.HybridizationPC',
+                                         'hybridization': {'ksp_type': 'preonly',
+                                                           'pc_type': 'lu'}})
+    nh_sigma, nh_u = w2.split()
+
+    # Return the L2 error
+    sigma_err = errornorm(sigma_h, nh_sigma)
+    u_err = errornorm(u_h, nh_u)
+
+    assert sigma_err < 1e-8
+    assert u_err < 1e-8