@@ -237,8 +237,8 @@ LogicalResult OpenFhePkeEmitter::translate(Operation& op) {
237237 // Arith ops
238238 .Case <arith::ConstantOp, arith::ExtSIOp, arith::ExtUIOp,
239239 arith::IndexCastOp, arith::ExtFOp, arith::RemSIOp,
240- arith::AddIOp, arith::AndIOp , arith::SubIOp , arith::MulIOp ,
241- arith::DivSIOp, arith::CmpIOp, arith::SelectOp>(
240+ arith::AddIOp, arith::AddFOp , arith::AndIOp , arith::SubIOp ,
241+ arith::MulIOp, arith:: DivSIOp, arith::CmpIOp, arith::SelectOp>(
242242 [&](auto op) { return printOperation (op); })
243243 // SCF ops
244244 .Case <scf::IfOp, scf::ForOp, scf::YieldOp>(
@@ -248,7 +248,7 @@ LogicalResult OpenFhePkeEmitter::translate(Operation& op) {
248248 tensor::InsertSliceOp, tensor::ExtractOp,
249249 tensor::ExtractSliceOp, tensor::SplatOp,
250250 tensor::CollapseShapeOp, tensor::ExpandShapeOp,
251- tensor::FromElementsOp>(
251+ tensor::FromElementsOp, tensor::ConcatOp >(
252252 [&](auto op) { return printOperation (op); })
253253 // LWE ops
254254 .Case <lwe::RLWEDecodeOp, lwe::ReinterpretApplicationDataOp>(
@@ -940,6 +940,10 @@ LogicalResult OpenFhePkeEmitter::printOperation(arith::AddIOp op) {
940940 return printBinaryOp (op, op.getLhs (), op.getRhs (), " +" );
941941}
942942
// Emits an arith::AddFOp by forwarding the op, its lhs and rhs operands, and
// the addition operator string to printBinaryOp; returns whatever
// printBinaryOp returns. This is the same call shape used for arith::AddIOp.
943+ LogicalResult OpenFhePkeEmitter::printOperation (arith::AddFOp op) {
944+ return printBinaryOp (op, op.getLhs (), op.getRhs (), " +" );
945+ }
946+
// Emits an arith::AndIOp by forwarding the op, its lhs and rhs operands, and
// the operator string to printBinaryOp; returns whatever printBinaryOp returns.
// NOTE(review): the operator passed here is the logical `&&`, whereas
// arith.andi is a bitwise AND. The two agree only for i1 (boolean) operands --
// confirm that wider integer types never reach this emitter.
943947LogicalResult OpenFhePkeEmitter::printOperation (arith::AndIOp op) {
944948 return printBinaryOp (op, op.getLhs (), op.getRhs (), " &&" );
945949}
@@ -993,21 +997,20 @@ LogicalResult OpenFhePkeEmitter::printOperation(arith::CmpIOp op) {
993997}
994998
995999LogicalResult OpenFhePkeEmitter::printOperation (tensor::ConcatOp op) {
996- // concat dim(0) %foo, %foo, ...
997- // lower to a loop
998- auto operandType = cast<RankedTensorType>(op.getOperands ()[0 ].getType ());
1000+ // concat dim(0) %value1, %value2, ...
9991001 auto resultType = op.getResult ().getType ();
10001002 std::string varName = variableNames->getNameForValue (op.getResult ());
1001- if (resultType.getRank () != 1 || operandType.getRank () != 1 ) {
1002- return failure ();
1003- }
1003+
10041004 // std::vector<8192> result;
10051005 if (failed (emitType (resultType, op->getLoc ()))) {
10061006 return failure ();
10071007 }
10081008 os << " " << varName << " ;\n " ;
10091009
1010+ // If all the operands are the same, we can just repeat the operand
1011+ // insertion in a loop to minimize code size.
10101012 if (llvm::all_equal (op.getOperands ())) {
1013+ auto operandType = cast<RankedTensorType>(op.getOperands ()[0 ].getType ());
10111014 std::string operandName =
10121015 variableNames->getNameForValue (op.getOperands ()[0 ]);
10131016 int64_t numRepeats =
@@ -1022,12 +1025,23 @@ LogicalResult OpenFhePkeEmitter::printOperation(tensor::ConcatOp op) {
10221025
10231026 os.unindent ();
10241027 os << " }\n " ;
1028+ }
1029+
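1030+ // If we are concatenating on dimension 0, insert the operands one by one into the result vector.
1031+ // NOTE(review): the all_equal branch above no longer returns, so for identical operands on dim 0 this appears to insert them a second time (and for dim != 0 we fall through to failure() after emitting the loop) -- confirm that a `return success();` was not dropped from that branch.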
1032+ if (op.getDim () == 0 ) {
1033+ for (auto operand : op.getOperands ()) {
1034+ // result.insert(result.end(), foo.begin(), foo.end());
1035+ std::string operandName = variableNames->getNameForValue (operand);
1036+ os << varName << " .insert(" << varName << " .end(), " << operandName
1037+ << " .begin(), " << operandName << " .end());\n " ;
1038+ }
10251039 return success ();
10261040 }
10271041
10281042 // More complicated concat ops are not supported yet. The earlier lowerings
1029- // should just produce concat for lack of a "repeat" op. Maybe we should make
1030- // a tensor_ext.repeat op?
1043+ // should just produce concat for lack of a "repeat" op. Maybe we should
1044+ // make a tensor_ext.repeat op?
10311045 return failure ();
10321046}
10331047
@@ -1065,8 +1079,8 @@ LogicalResult OpenFhePkeEmitter::printOperation(tensor::ExtractOp op) {
10651079
10661080LogicalResult OpenFhePkeEmitter::printOperation (
10671081 ::mlir::tensor::CollapseShapeOp op) {
1068- // A rank-reduced type will have the same number of elements so collapsing is
1069- // a no-op on a flattened tensor.
1082+ // A rank-reduced type will have the same number of elements so collapsing
1083+ // is a no-op on a flattened tensor.
10701084 SliceVerificationResult res =
10711085 isRankReducedType (op.getSrcType (), op.getResultType ());
10721086 if (res != SliceVerificationResult::Success) {
@@ -1329,8 +1343,9 @@ LogicalResult OpenFhePkeEmitter::printOperation(
13291343 if (failed (resultCC)) return resultCC;
13301344 std::string cc = variableNames->getNameForValue (resultCC.value ());
13311345
1332- // In certain conditions, we might end up with the input being tensor<..xi64>
1333- // which isn't a valid input type for MakeCKKSPackedPlaintext, so we convert
1346+ // In certain conditions, we might end up with the input being
1347+ // tensor<..xi64> which isn't a valid input type for
1348+ // MakeCKKSPackedPlaintext, so we convert it to a std::vector<double> first.
13341349 if (getElementTypeOrSelf (op.getValue ().getType ()).isInteger ()) {
13351350 // This means we will have created a std::vector<int64_t>
13361351 // but we need a std::vector<double>
@@ -1393,10 +1408,10 @@ FailureOr<std::pair<unsigned, int64_t>> getNonUnitDimension(
13931408}
13941409
13951410LogicalResult OpenFhePkeEmitter::printOperation (lwe::RLWEDecodeOp op) {
1396- // In OpenFHE a plaintext is already decoded by decrypt. The internal OpenFHE
1397- // implementation is simple enough (and dependent on currently-hard-coded
1398- // encoding choices) that we will eventually need to work at a lower level of
1399- // the API to support this operation properly.
1411+ // In OpenFHE a plaintext is already decoded by decrypt. The internal
1412+ // OpenFHE implementation is simple enough (and dependent on
1413+ // currently-hard-coded encoding choices) that we will eventually need to
1414+ // work at a lower level of the API to support this operation properly.
14001415 bool isCKKS = llvm::isa<lwe::InverseCanonicalEncodingAttr>(op.getEncoding ());
14011416 auto tensorTy = dyn_cast<RankedTensorType>(op.getResult ().getType ());
14021417 if (tensorTy) {
0 commit comments