Skip to content

Commit ff8ddd3

Browse files
committed
Add ufl.imag
1 parent 16c9e2b commit ff8ddd3

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

ffcx/codegeneration/expression_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def generate_partition(self, symbol, F, mode):
365365
vexpr = L.ufl_to_lnodes(v, *vops)
366366

367367
is_cond = isinstance(v, ufl.classes.Condition)
368-
is_real = isinstance(v, ufl.classes.Real)
368+
is_real = isinstance(v, (ufl.classes.Real, ufl.classes.Imag))
369369
if is_cond:
370370
dtype = L.DataType.BOOL
371371
elif is_real:

test/test_jit_forms.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import os
88
import sys
9+
import typing
910

1011
import basix.ufl
1112
import numpy as np
@@ -2166,6 +2167,7 @@ def test_multiple_integrands_same_quadrature(compile_args, dtype):
21662167
np.testing.assert_allclose(A_mixed, A_ref)
21672168

21682169

2170+
@pytest.mark.parametrize("operator", (ufl.real, ufl.imag))
21692171
@pytest.mark.parametrize(
21702172
"dtype",
21712173
[
@@ -2189,7 +2191,14 @@ def test_multiple_integrands_same_quadrature(compile_args, dtype):
21892191
),
21902192
],
21912193
)
2192-
def test_ufl_real(compile_args: list[str], dtype: npt.DTypeLike) -> None:
2194+
def test_ufl_complex_extraction(
2195+
compile_args: list[str],
2196+
dtype: npt.DTypeLike,
2197+
operator: typing.Callable[[typing.Any], typing.Any],
2198+
) -> None:
2199+
if "float" in dtype and operator == ufl.imag:
2200+
pytest.xfail("Cannot have imag in real form")
2201+
21932202
xdtype = dtype_to_scalar_dtype(dtype)
21942203
c_el = basix.ufl.element("Lagrange", "interval", 1, shape=(1,), dtype=xdtype)
21952204
mesh = ufl.Mesh(c_el)
@@ -2198,8 +2207,8 @@ def test_ufl_real(compile_args: list[str], dtype: npt.DTypeLike) -> None:
21982207
u = ufl.Coefficient(V)
21992208

22002209
dx = ufl.Measure("dx")
2201-
b = ufl.conditional(ufl.gt(ufl.real(u), 0), u, -u) * dx
2202-
val = 5 - 5j if "complex" in dtype else 5
2210+
b = ufl.conditional(ufl.gt(operator(u), 0), u, -u) * dx
2211+
val = 5 + 6j if "complex" in dtype else 5
22032212
w = np.array([val], dtype=dtype)
22042213
forms = [b]
22052214
compiled_forms, module, _code = ffcx.codegeneration.jit.compile_forms(

0 commit comments

Comments
 (0)