Skip to content

Commit 8c76721

Browse files
asraa authored and copybara-github committed
add openfhe ckks e2e convolution test
PiperOrigin-RevId: 827608103
1 parent 4a7fcd3 commit 8c76721

File tree

13 files changed

+282
-26
lines changed

13 files changed

+282
-26
lines changed

lib/Pipelines/ArithmeticPipelineRegistration.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include "lib/Transforms/ElementwiseToAffine/ElementwiseToAffine.h"
3636
#include "lib/Transforms/FoldConstantTensors/FoldConstantTensors.h"
3737
#include "lib/Transforms/FoldPlaintextMasks/FoldPlaintextMasks.h"
38+
#include "lib/Transforms/ForwardInsertSliceToExtractSlice/ForwardInsertSliceToExtractSlice.h"
39+
#include "lib/Transforms/ForwardInsertToExtract/ForwardInsertToExtract.h"
3840
#include "lib/Transforms/FullLoopUnroll/FullLoopUnroll.h"
3941
#include "lib/Transforms/GenerateParam/GenerateParam.h"
4042
#include "lib/Transforms/InlineActivations/InlineActivations.h"
@@ -113,6 +115,7 @@ void lowerAssignLayout(OpPassManager& pm, bool unroll = false) {
113115
pm.addNestedPass<func::FuncOp>(affine::createAffineExpandIndexOpsPass());
114116
pm.addNestedPass<func::FuncOp>(affine::createSimplifyAffineStructuresPass());
115117
pm.addNestedPass<func::FuncOp>(affine::createAffineLoopNormalizePass(true));
118+
pm.addNestedPass<func::FuncOp>(createForwardInsertSliceToExtractSlice());
116119

117120
// The lowered assign_layout ops involve plaintext operations that are still
118121
// inside secret.generic, and are not handled well by downstream noise models
@@ -169,6 +172,7 @@ void mlirToSecretArithmeticPipelineBuilder(
169172
pm.addPass(
170173
createConvertToCiphertextSemantics(convertToCiphertextSemanticsOptions));
171174

175+
pm.addPass(createApplyFolders());
172176
pm.addPass(createCanonicalizerPass());
173177
pm.addPass(tensor_ext::createImplementRotateAndReduce());
174178

@@ -188,6 +192,7 @@ void mlirToSecretArithmeticPipelineBuilder(
188192
pm.addPass(createAddClientInterface(addClientInterfaceOptions));
189193

190194
// Clean up after lowering assign_layout and various related packing code
195+
pm.addPass(createApplyFolders());
191196
pm.addPass(createFoldConstantTensors());
192197
pm.addPass(createCanonicalizerPass());
193198
pm.addPass(createCSEPass());
@@ -385,6 +390,10 @@ void mlirToRLWEPipeline(OpPassManager& pm,
385390
exit(EXIT_FAILURE);
386391
}
387392

393+
pm.addPass(createForwardInsertToExtract());
394+
pm.addPass(createCanonicalizerPass());
395+
pm.addPass(createCSEPass());
396+
388397
ElementwiseToAffineOptions elementwiseOptions;
389398
elementwiseOptions.convertDialects = {"ckks", "bgv", "lwe"};
390399
pm.addPass(createElementwiseToAffine(elementwiseOptions));

lib/Pipelines/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ cc_library(
134134
"@heir//lib/Transforms/ElementwiseToAffine",
135135
"@heir//lib/Transforms/FoldConstantTensors",
136136
"@heir//lib/Transforms/FoldPlaintextMasks",
137+
"@heir//lib/Transforms/ForwardInsertSliceToExtractSlice",
138+
"@heir//lib/Transforms/ForwardInsertToExtract",
137139
"@heir//lib/Transforms/FullLoopUnroll",
138140
"@heir//lib/Transforms/GenerateParam",
139141
"@heir//lib/Transforms/InlineActivations",

lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ LogicalResult OpenFhePkeEmitter::translate(Operation& op) {
237237
// Arith ops
238238
.Case<arith::ConstantOp, arith::ExtSIOp, arith::ExtUIOp,
239239
arith::IndexCastOp, arith::ExtFOp, arith::RemSIOp,
240-
arith::AddIOp, arith::AndIOp, arith::SubIOp, arith::MulIOp,
241-
arith::DivSIOp, arith::CmpIOp, arith::SelectOp>(
240+
arith::AddIOp, arith::AddFOp, arith::AndIOp, arith::SubIOp,
241+
arith::MulIOp, arith::DivSIOp, arith::CmpIOp, arith::SelectOp>(
242242
[&](auto op) { return printOperation(op); })
243243
// SCF ops
244244
.Case<scf::IfOp, scf::ForOp, scf::YieldOp>(
@@ -248,7 +248,7 @@ LogicalResult OpenFhePkeEmitter::translate(Operation& op) {
248248
tensor::InsertSliceOp, tensor::ExtractOp,
249249
tensor::ExtractSliceOp, tensor::SplatOp,
250250
tensor::CollapseShapeOp, tensor::ExpandShapeOp,
251-
tensor::FromElementsOp>(
251+
tensor::FromElementsOp, tensor::ConcatOp>(
252252
[&](auto op) { return printOperation(op); })
253253
// LWE ops
254254
.Case<lwe::RLWEDecodeOp, lwe::ReinterpretApplicationDataOp>(
@@ -940,6 +940,10 @@ LogicalResult OpenFhePkeEmitter::printOperation(arith::AddIOp op) {
940940
return printBinaryOp(op, op.getLhs(), op.getRhs(), "+");
941941
}
942942

943+
LogicalResult OpenFhePkeEmitter::printOperation(arith::AddFOp op) {
944+
return printBinaryOp(op, op.getLhs(), op.getRhs(), "+");
945+
}
946+
943947
LogicalResult OpenFhePkeEmitter::printOperation(arith::AndIOp op) {
944948
return printBinaryOp(op, op.getLhs(), op.getRhs(), "&&");
945949
}
@@ -993,21 +997,20 @@ LogicalResult OpenFhePkeEmitter::printOperation(arith::CmpIOp op) {
993997
}
994998

995999
LogicalResult OpenFhePkeEmitter::printOperation(tensor::ConcatOp op) {
996-
// concat dim(0) %foo, %foo, ...
997-
// lower to a loop
998-
auto operandType = cast<RankedTensorType>(op.getOperands()[0].getType());
1000+
// concat dim(0) %value1, %value2, ...
9991001
auto resultType = op.getResult().getType();
10001002
std::string varName = variableNames->getNameForValue(op.getResult());
1001-
if (resultType.getRank() != 1 || operandType.getRank() != 1) {
1002-
return failure();
1003-
}
1003+
10041004
// std::vector<8192> result;
10051005
if (failed(emitType(resultType, op->getLoc()))) {
10061006
return failure();
10071007
}
10081008
os << " " << varName << ";\n";
10091009

1010+
// If all the operands are the same, we can just repeat the operand
1011+
// insertion in a loop to minimize code size.
10101012
if (llvm::all_equal(op.getOperands())) {
1013+
auto operandType = cast<RankedTensorType>(op.getOperands()[0].getType());
10111014
std::string operandName =
10121015
variableNames->getNameForValue(op.getOperands()[0]);
10131016
int64_t numRepeats =
@@ -1022,12 +1025,23 @@ LogicalResult OpenFhePkeEmitter::printOperation(tensor::ConcatOp op) {
10221025

10231026
os.unindent();
10241027
os << "}\n";
1028+
}
1029+
1030+
// If we are concatenating on dimension 0, insert the operands
1031+
// one by one into the result vector.
1032+
if (op.getDim() == 0) {
1033+
for (auto operand : op.getOperands()) {
1034+
// result.insert(result.end(), foo.begin(), foo.end());
1035+
std::string operandName = variableNames->getNameForValue(operand);
1036+
os << varName << ".insert(" << varName << ".end(), " << operandName
1037+
<< ".begin(), " << operandName << ".end());\n";
1038+
}
10251039
return success();
10261040
}
10271041

10281042
// More complicated concat ops are not supported yet. The earlier lowerings
1029-
// should just produce concat for lack of a "repeat" op. Maybe we should make
1030-
// a tensor_ext.repeat op?
1043+
// should just produce concat for lack of a "repeat" op. Maybe we should
1044+
// make a tensor_ext.repeat op?
10311045
return failure();
10321046
}
10331047

@@ -1065,8 +1079,8 @@ LogicalResult OpenFhePkeEmitter::printOperation(tensor::ExtractOp op) {
10651079

10661080
LogicalResult OpenFhePkeEmitter::printOperation(
10671081
::mlir::tensor::CollapseShapeOp op) {
1068-
// A rank-reduced type will have the same number of elements so collapsing is
1069-
// a no-op on a flattened tensor.
1082+
// A rank-reduced type will have the same number of elements so collapsing
1083+
// is a no-op on a flattened tensor.
10701084
SliceVerificationResult res =
10711085
isRankReducedType(op.getSrcType(), op.getResultType());
10721086
if (res != SliceVerificationResult::Success) {
@@ -1329,8 +1343,9 @@ LogicalResult OpenFhePkeEmitter::printOperation(
13291343
if (failed(resultCC)) return resultCC;
13301344
std::string cc = variableNames->getNameForValue(resultCC.value());
13311345

1332-
// In certain conditions, we might end up with the input being tensor<..xi64>
1333-
// which isn't a valid input type for MakeCKKSPackedPlaintext, so we convert
1346+
// In certain conditions, we might end up with the input being
1347+
// tensor<..xi64> which isn't a valid input type for
1348+
// MakeCKKSPackedPlaintext, so we convert
13341349
if (getElementTypeOrSelf(op.getValue().getType()).isInteger()) {
13351350
// This means we will have created a std::vector<int64_t>
13361351
// but we need a std::vector<double>
@@ -1393,10 +1408,10 @@ FailureOr<std::pair<unsigned, int64_t>> getNonUnitDimension(
13931408
}
13941409

13951410
LogicalResult OpenFhePkeEmitter::printOperation(lwe::RLWEDecodeOp op) {
1396-
// In OpenFHE a plaintext is already decoded by decrypt. The internal OpenFHE
1397-
// implementation is simple enough (and dependent on currently-hard-coded
1398-
// encoding choices) that we will eventually need to work at a lower level of
1399-
// the API to support this operation properly.
1411+
// In OpenFHE a plaintext is already decoded by decrypt. The internal
1412+
// OpenFHE implementation is simple enough (and dependent on
1413+
// currently-hard-coded encoding choices) that we will eventually need to
1414+
// work at a lower level of the API to support this operation properly.
14001415
bool isCKKS = llvm::isa<lwe::InverseCanonicalEncodingAttr>(op.getEncoding());
14011416
auto tensorTy = dyn_cast<RankedTensorType>(op.getResult().getType());
14021417
if (tensorTy) {

lib/Target/OpenFhePke/OpenFhePkeEmitter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class OpenFhePkeEmitter {
9696
LogicalResult printOperation(::mlir::affine::AffineForOp op);
9797
LogicalResult printOperation(::mlir::affine::AffineYieldOp op);
9898
LogicalResult printOperation(::mlir::arith::AddIOp op);
99+
LogicalResult printOperation(::mlir::arith::AddFOp op);
99100
LogicalResult printOperation(::mlir::arith::AndIOp op);
100101
LogicalResult printOperation(::mlir::arith::CmpIOp op);
101102
LogicalResult printOperation(::mlir::arith::ConstantOp op);

lib/Transforms/ApplyFolders/ApplyFolders.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,27 @@
33
#include <utility>
44

55
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
6-
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
7-
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
6+
#include "mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h" // from @llvm-project
7+
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
8+
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
89
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
910

11+
// required for generated patterns
12+
#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project
13+
1014
namespace mlir {
1115
namespace heir {
1216

1317
#define GEN_PASS_DEF_APPLYFOLDERS
1418
#include "lib/Transforms/ApplyFolders/ApplyFolders.h.inc"
1519

20+
namespace {
21+
22+
// keep in anonymous namespace
23+
#include "lib/Transforms/ApplyFolders/Patterns.cpp.inc"
24+
25+
} // namespace
26+
1627
struct ApplyFolders : impl::ApplyFoldersBase<ApplyFolders> {
1728
using ApplyFoldersBase::ApplyFoldersBase;
1829

@@ -22,6 +33,12 @@ struct ApplyFolders : impl::ApplyFoldersBase<ApplyFolders> {
2233
tensor::ControlConstantExtractSliceFusionFn controlFn =
2334
[](tensor::ExtractSliceOp op) { return true; };
2435
tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn);
36+
tensor::populateFoldTensorSubsetOpPatterns(patterns);
37+
tensor::populateDecomposeTensorConcatPatterns(patterns);
38+
tensor::populateFoldTensorEmptyPatterns(patterns);
39+
tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
40+
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
41+
populateWithGenerated(patterns);
2542
// Use the greedy pattern driver to apply folders.
2643
// TODO (#1221): Investigate whether folding (default: on) can be skipped
2744
// here.

lib/Transforms/ApplyFolders/BUILD

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms")
2+
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
23
load("@rules_cc//cc:cc_library.bzl", "cc_library")
34

45
package(
@@ -14,15 +15,29 @@ cc_library(
1415
],
1516
deps = [
1617
":pass_inc_gen",
18+
":patterns_inc_gen",
1719
"@llvm-project//mlir:IR",
1820
"@llvm-project//mlir:Pass",
21+
"@llvm-project//mlir:Support",
1922
"@llvm-project//mlir:TensorDialect",
23+
"@llvm-project//mlir:TensorTransforms",
2024
"@llvm-project//mlir:TransformUtils",
21-
"@llvm-project//mlir:Transforms",
2225
],
2326
)
2427

2528
add_heir_transforms(
2629
generated_target_name = "pass_inc_gen",
2730
pass_name = "ApplyFolders",
2831
)
32+
33+
gentbl_cc_library(
34+
name = "patterns_inc_gen",
35+
tbl_outs = {"Patterns.cpp.inc": ["-gen-rewriters"]},
36+
tblgen = "@llvm-project//mlir:mlir-tblgen",
37+
td_file = "Patterns.td",
38+
deps = [
39+
"@heir//lib/Utils/DRR",
40+
"@llvm-project//mlir:ArithOpsTdFiles",
41+
"@llvm-project//mlir:TensorOpsTdFiles",
42+
],
43+
)
lib/Transforms/ApplyFolders/Patterns.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#ifndef LIB_TRANSFORMS_APPLYFOLDERS_PATTERNS_TD_
2+
#define LIB_TRANSFORMS_APPLYFOLDERS_PATTERNS_TD_
3+
4+
include "mlir/Dialect/Arith/IR/ArithOps.td"
5+
include "mlir/Dialect/Tensor/IR/TensorOps.td"
6+
include "mlir/IR/PatternBase.td"
7+
include "lib/Utils/DRR/Utils.td"
8+
9+
def AnyZero : AttrConstraint<
10+
CPred<"::mlir::matchPattern($_self, m_AnyZeroFloat())">,
11+
"is int or float zero">;
12+
13+
// a + 0.0 = a
14+
def AddFloatingPointZero : Pat<
15+
(Arith_AddFOp $a,
16+
(ConstantLikeMatcher AnyZero:$value), $anyAttr),
17+
(replaceWithValue $a)>;
18+
19+
// add(empty, a) = a
20+
def AddEmptyTensor : Pat<
21+
(Arith_AddFOp $a,
22+
(Tensor_EmptyOp $b), $anyAttr),
23+
(replaceWithValue $a)>;
24+
25+
#endif // LIB_TRANSFORMS_APPLYFOLDERS_PATTERNS_TD_

lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ struct ConvertLinalgMatvecLayout
540540
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
541541
IRMaterializingVisitor visitor(
542542
b, input.getType(),
543-
[&](Operation* createdOp) { setMaterializedAttr(op); });
543+
[&](Operation* createdOp) { setMaterializedAttr(createdOp); });
544544
Value finalOutput = implementedKernel->visit(visitor);
545545

546546
auto layoutAttr = cast<LayoutAttr>(op->getAttr(kLayoutAttrName));
@@ -656,7 +656,7 @@ struct ConvertLinalgConv2D
656656
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
657657
IRMaterializingVisitor visitor(
658658
b, data.getType(),
659-
[&](Operation* createdOp) { setMaterializedAttr(op); });
659+
[&](Operation* createdOp) { setMaterializedAttr(createdOp); });
660660
Value finalOutput = implementedKernel->visit(visitor);
661661

662662
auto layoutAttr = cast<LayoutAttr>(op->getAttr(kLayoutAttrName));
@@ -1834,6 +1834,13 @@ struct ConvertToCiphertextSemantics
18341834
return signalPassFailure();
18351835
}
18361836

1837+
// Walk the IR to validate that there are no remaining unrealized conversion
1838+
// cast ops.
1839+
module->walk([&](UnrealizedConversionCastOp op) {
1840+
op->emitError() << "unexpected unrealized conversion cast op found";
1841+
signalPassFailure();
1842+
});
1843+
18371844
clearAttrs(module, kLayoutAttrName);
18381845
clearAttrs(module, kMaterializedAttrName);
18391846
}

tests/Emitter/Openfhe/emit_openfhe_pke.mlir

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,8 @@ module attributes {scheme.bgv} {
293293
// -----
294294

295295
module attributes {scheme.bgv} {
296-
// CHECK: test_concat
297-
func.func @test_concat() -> tensor<64xi16> {
296+
// CHECK: test_concat_same
297+
func.func @test_concat_same() -> tensor<64xi16> {
298298
// CHECK: std::vector<int16_t> [[v0:.*]] =
299299
%cst = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<32xi16>
300300
// CHECK: std::vector<int16_t> [[v1:.*]];
@@ -309,6 +309,41 @@ module attributes {scheme.bgv} {
309309

310310
// -----
311311

312+
module attributes {scheme.bgv} {
313+
// CHECK: test_concat
314+
func.func @test_concat() -> tensor<64xi16> {
315+
// CHECK: std::vector<int16_t> [[c0:.*]] =
316+
// CHECK: std::vector<int16_t> [[c1:.*]] =
317+
%cst = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<32xi16>
318+
%cst0 = arith.constant dense<[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<32xi16>
319+
// CHECK: std::vector<int16_t> [[v1:.*]];
320+
// CHECK: [[v1]].insert([[v1]].end(), [[c0]].begin(), [[c0]].end());
321+
// CHECK: [[v1]].insert([[v1]].end(), [[c1]].begin(), [[c1]].end());
322+
%v = tensor.concat dim(0) %cst, %cst0 : (tensor<32xi16>, tensor<32xi16>) -> tensor<64xi16>
323+
324+
return %v : tensor<64xi16>
325+
}
326+
}
327+
328+
// -----
329+
330+
module attributes {scheme.bgv} {
331+
// CHECK: test_concat_multidim
332+
func.func @test_concat_multidim() -> tensor<4x16xi16> {
333+
// CHECK: std::vector<int16_t> [[c0:.*]] =
334+
// CHECK: std::vector<int16_t> [[c1:.*]] =
335+
%cst = arith.constant dense<[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]> : tensor<2x16xi16>
336+
%cst0 = arith.constant dense<[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]> : tensor<2x16xi16>
337+
// CHECK: std::vector<int16_t> [[v1:.*]];
338+
// CHECK: [[v1]].insert([[v1]].end(), [[c0]].begin(), [[c0]].end());
339+
// CHECK: [[v1]].insert([[v1]].end(), [[c1]].begin(), [[c1]].end());
340+
%v = tensor.concat dim(0) %cst, %cst0 : (tensor<2x16xi16>, tensor<2x16xi16>) -> tensor<4x16xi16>
341+
return %v : tensor<4x16xi16>
342+
}
343+
}
344+
345+
// -----
346+
312347
module attributes {scheme.ckks} {
313348
// CHECK: test_insert_slice_1d
314349
// CHECK: std::vector<float> [[v4:[^(]*]](8, 0.100000001);
tests/Examples/openfhe (new BUILD file for the ckks convolution e2e test; exact path not captured in this page scrape)

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# See README.md for setup required to run these tests
2+
3+
load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_end_to_end_test")
4+
5+
package(default_applicable_licenses = ["@heir//:license"])
6+
7+
openfhe_end_to_end_test(
8+
name = "convolution_test",
9+
generated_lib_header = "convolution_testlib.h",
10+
heir_opt_flags = [
11+
"--annotate-module=backend=openfhe scheme=ckks",
12+
"--torch-linalg-to-ckks=ciphertext-degree=1024",
13+
"--scheme-to-openfhe",
14+
],
15+
heir_translate_flags = [],
16+
mlir_src = "convolution.mlir",
17+
tags = ["notap"],
18+
test_src = "convolution_test.cpp",
19+
)

0 commit comments

Comments
 (0)