@@ -226,21 +226,20 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
   switch (VectorVT.SimpleTy) {
   default:
     return std::nullopt;
+
   case MVT::v4i64:
   case MVT::v4f64:
-  case MVT::v8i32:
-    // This is a "native" vector type iff the address space is global
-    // and the target supports 256-bit loads/stores
+    // This is a "native" vector type iff the address space is global and the
+    // target supports 256-bit loads/stores
     if (!CanLowerTo256Bit)
       return std::nullopt;
     LLVM_FALLTHROUGH;
   case MVT::v2i8:
-  case MVT::v2i32:
   case MVT::v2i64:
   case MVT::v2f64:
-  case MVT::v4i32:
     // This is a "native" vector type
     return std::pair(NumElts, EltVT);
+
   case MVT::v16f16:  // <8 x f16x2>
   case MVT::v16bf16: // <8 x bf16x2>
   case MVT::v16i16:  // <8 x i16x2>
@@ -264,12 +263,18 @@ getVectorLoweringShape(EVT VectorEVT, const NVPTXSubtarget &STI,
   case MVT::v16i8: // <4 x i8x4>
     PackRegSize = 32;
     break;
+
   case MVT::v8f32: // <4 x f32x2>
+  case MVT::v8i32: // <4 x i32x2>
+    // This is a "native" vector type iff the address space is global and the
+    // target supports 256-bit loads/stores
     if (!CanLowerTo256Bit)
       return std::nullopt;
     LLVM_FALLTHROUGH;
   case MVT::v2f32: // <1 x f32x2>
   case MVT::v4f32: // <2 x f32x2>
+  case MVT::v2i32: // <1 x i32x2>
+  case MVT::v4i32: // <2 x i32x2>
     if (!STI.hasF32x2Instructions())
       return std::pair(NumElts, EltVT);
     PackRegSize = 64;
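The `// <N x TTxM>` annotations above encode the shape the lowering will use: scalar lanes are paired into 64-bit packed elements, so a v8i32 becomes four i32x2 registers. A minimal standalone sketch of that arithmetic (toy types and a hypothetical `packShape` helper, not the LLVM API):

```cpp
#include <cstdio>

struct LoweringShape {
  unsigned NumPackedElts; // how many packed registers the vector occupies
  unsigned RegBits;       // width of each packed register
};

// Hypothetical helper: TotalElts scalar lanes of EltBits each, packed
// PackRegBits/EltBits lanes to a register.
LoweringShape packShape(unsigned TotalElts, unsigned EltBits,
                        unsigned PackRegBits) {
  unsigned EltsPerReg = PackRegBits / EltBits;
  return {TotalElts / EltsPerReg, PackRegBits};
}

int main() {
  // v8i32 -> <4 x i32x2>: four 64-bit registers, matching the comment above.
  LoweringShape S = packShape(/*TotalElts=*/8, /*EltBits=*/32,
                              /*PackRegBits=*/64);
  std::printf("v8i32 -> %u x b%u\n", S.NumPackedElts, S.RegBits);
}
```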
@@ -590,8 +595,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   addRegisterClass(MVT::bf16, &NVPTX::B16RegClass);
   addRegisterClass(MVT::v2bf16, &NVPTX::B32RegClass);
 
-  if (STI.hasF32x2Instructions())
+  if (STI.hasF32x2Instructions()) {
     addRegisterClass(MVT::v2f32, &NVPTX::B64RegClass);
+    addRegisterClass(MVT::v2i32, &NVPTX::B64RegClass);
+  }
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
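Placing v2i32 alongside v2f32 in `NVPTX::B64RegClass` means each two-lane vector lives in a single untyped 64-bit register. A rough sketch of that storage view, assuming the low lane sits in the low bits (a layout assumption for illustration only, not taken from the patch):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical model of one B64 register holding a v2i32 value.
uint64_t packV2I32(uint32_t Lane0, uint32_t Lane1) {
  // Assumption: lane 0 occupies the low 32 bits, lane 1 the high 32 bits.
  return uint64_t(Lane0) | (uint64_t(Lane1) << 32);
}

int main() {
  std::printf("0x%016llx\n",
              (unsigned long long)packV2I32(0xDEADBEEFu, 0xCAFEF00Du));
}
```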
@@ -628,12 +635,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
 
-  // No support for these operations with v2f32.
-  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f32, Expand);
-  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2f32, Expand);
+  // No support for these operations with v2f32/v2i32
+  setOperationAction(ISD::INSERT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32}, Expand);
+  setOperationAction(ISD::VECTOR_SHUFFLE, {MVT::v2f32, MVT::v2i32}, Expand);
   // Need custom lowering in case the index is dynamic.
   if (STI.hasF32x2Instructions())
-    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, {MVT::v2f32, MVT::v2i32},
+                       Custom);
 
   // Custom conversions to/from v2i8.
   setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
@@ -661,14 +669,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // Operations not directly supported by NVPTX.
   for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
                  MVT::v2f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16,
-                 MVT::v4i8, MVT::i32, MVT::i64}) {
+                 MVT::v4i8, MVT::i32, MVT::v2i32, MVT::i64}) {
     setOperationAction(ISD::SELECT_CC, VT, Expand);
     setOperationAction(ISD::BR_CC, VT, Expand);
   }
 
-  // Not directly supported. TLI would attempt to expand operations like
-  // FMINIMUM(v2f32) using invalid SETCC and VSELECT nodes.
-  setOperationAction(ISD::VSELECT, MVT::v2f32, Expand);
+  // We don't want ops like FMINIMUM or UMAX to be lowered to SETCC+VSELECT.
+  setOperationAction(ISD::VSELECT, {MVT::v2f32, MVT::v2i32}, Expand);
 
   // Some SIGN_EXTEND_INREG can be done using cvt instruction.
   // For others we will expand to a SHL/SRA pair.
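The rewritten VSELECT comment above is terser than the old one, but the mechanism is unchanged: without the Expand marking, TLI would try to lower min/max-style ops into a lane-wise compare feeding a vector select, neither of which NVPTX can select for these packed types. A sketch of the rejected per-lane shape (plain C++ stand-ins, not SelectionDAG code):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

using v2u32 = std::array<uint32_t, 2>;

// The SETCC+VSELECT shape that marking VSELECT as Expand forbids for
// v2f32/v2i32: a lane-wise compare producing a mask, then a mask-driven
// select.
v2u32 umaxViaSetccVselect(v2u32 A, v2u32 B) {
  std::array<bool, 2> Mask{A[0] > B[0], A[1] > B[1]}; // SETCC
  return {Mask[0] ? A[0] : B[0],                      // VSELECT
          Mask[1] ? A[1] : B[1]};
}

int main() {
  v2u32 R = umaxViaSetccVselect({1, 7}, {5, 3});
  std::printf("{%u, %u}\n", R[0], R[1]); // {5, 7}
}
```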
@@ -815,7 +822,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction({ISD::SDIV, ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
                       ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
                       ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::SETCC},
-                     MVT::v2i16, Expand);
+                     {MVT::v2i16, MVT::v2i32}, Expand);
+
+  // v2i32 is not supported for any arithmetic operations
+  setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
+                      ISD::CTPOP, ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
+                      ISD::SHL, ISD::SRA, ISD::SRL, ISD::OR, ISD::AND, ISD::XOR,
+                      ISD::SREM, ISD::UREM},
+                     MVT::v2i32, Expand);
 
   setOperationAction(ISD::ADDC, MVT::i32, Legal);
   setOperationAction(ISD::ADDE, MVT::i32, Legal);
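Marking the whole arithmetic list Expand means no packed i32x2 ALU instruction is ever selected: v2i32 is only a storage format, and the legalizer unrolls each operation into two scalar i32 ops. A hedged sketch of the net effect for ISD::ADD (plain C++ stand-ins for the lanes, not SelectionDAG code):

```cpp
#include <array>
#include <cstdio>

using v2i32 = std::array<int, 2>; // stand-in for an MVT::v2i32 value

// What Expand amounts to here: one scalar ADD per lane.
v2i32 expandedAdd(v2i32 A, v2i32 B) {
  return {A[0] + B[0], A[1] + B[1]};
}

int main() {
  v2i32 R = expandedAdd({1, 2}, {3, 4});
  std::printf("{%d, %d}\n", R[0], R[1]); // {4, 6}
}
```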
@@ -829,7 +843,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   }
 
   setOperationAction(ISD::CTTZ, MVT::i16, Expand);
-  setOperationAction(ISD::CTTZ, MVT::v2i16, Expand);
+  setOperationAction(ISD::CTTZ, {MVT::v2i16, MVT::v2i32}, Expand);
   setOperationAction(ISD::CTTZ, MVT::i32, Expand);
   setOperationAction(ISD::CTTZ, MVT::i64, Expand);
 
@@ -1071,7 +1085,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // Custom lowering for tcgen05.st vector operands
   setOperationAction(ISD::INTRINSIC_VOID,
                      {MVT::v2i32, MVT::v4i32, MVT::v8i32, MVT::v16i32,
-                      MVT::v32i32, MVT::v64i32, MVT::v128i32},
+                      MVT::v32i32, MVT::v64i32, MVT::v128i32, MVT::Other},
                      Custom);
 
   // Enable custom lowering for the following:
@@ -2604,7 +2618,7 @@ static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
     return V;
 }
 
-static SDValue LowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
+static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
   SDNode *N = Op.getNode();
   SDLoc DL(N);
   SmallVector<SDValue, 32> Ops;
@@ -2719,7 +2733,52 @@ static SDValue LowerTcgen05MMADisableOutputLane(SDValue Op, SelectionDAG &DAG) {
   return Tcgen05MMANode;
 }
 
-static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
+// Lower vector return type of tcgen05.ld intrinsics
+static std::optional<std::pair<SDValue, SDValue>>
+lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
+  SDLoc DL(N);
+  EVT ResVT = N->getValueType(0);
+  if (!ResVT.isVector())
+    return {}; // already legalized.
+
+  const unsigned NumElts = ResVT.getVectorNumElements();
+
+  // Create the return type of the instructions
+  SmallVector<EVT, 5> ListVTs;
+  for (unsigned i = 0; i < NumElts; ++i)
+    ListVTs.push_back(MVT::i32);
+
+  ListVTs.push_back(N->getValueType(1)); // Chain
+
+  SDVTList ResVTs = DAG.getVTList(ListVTs);
+
+  SmallVector<SDValue, 8> Ops{N->getOperand(0), N->getOperand(1),
+                              N->getOperand(2)};
+
+  if (HasOffset) {
+    Ops.push_back(N->getOperand(3)); // offset
+    Ops.push_back(N->getOperand(4)); // Pack flag
+  } else
+    Ops.push_back(N->getOperand(3)); // Pack flag
+
+  MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
+  SDValue NewNode =
+      DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, ResVTs, Ops,
+                              MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+  // split the vector result
+  SmallVector<SDValue, 4> ScalarRes;
+  for (unsigned i = 0; i < NumElts; ++i) {
+    SDValue Res = NewNode.getValue(i);
+    ScalarRes.push_back(Res);
+  }
+
+  SDValue Chain = NewNode.getValue(NumElts);
+  SDValue BuildVector = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
+  return {{BuildVector, Chain}};
+}
+
+static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
   SDNode *N = Op.getNode();
   SDValue Intrin = N->getOperand(1);
 
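Returning `std::optional<std::pair<SDValue, SDValue>>` is what lets one helper feed both legalization entry points: LowerOperation() wants the two values merged into a single node, while ReplaceNodeResults() wants them appended to a results list, and nullopt tells either caller the node needs no rewriting. A toy sketch of that shape (hypothetical names and stand-in types, not the LLVM signatures):

```cpp
#include <optional>
#include <utility>
#include <vector>

struct Val {}; // stand-in for SDValue
using Lowered = std::optional<std::pair<Val, Val>>; // {vector result, chain}

// Producer: returns nullopt when the node is already legal.
Lowered lowerIt(bool AlreadyLegal) {
  if (AlreadyLegal)
    return std::nullopt;
  return {{Val{}, Val{}}};
}

// Consumer in the ReplaceNodeResults() style: push both values.
void replaceStyle(std::vector<Val> &Results, bool AlreadyLegal) {
  if (auto Res = lowerIt(AlreadyLegal)) {
    Results.push_back(Res->first);
    Results.push_back(Res->second);
  }
}

int main() {
  std::vector<Val> Results;
  replaceStyle(Results, /*AlreadyLegal=*/false); // Results gets two entries
}
```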
@@ -2765,7 +2824,7 @@ static SDValue LowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
   case Intrinsic::nvvm_tcgen05_st_16x64b_x64:
   case Intrinsic::nvvm_tcgen05_st_32x32b_x64:
   case Intrinsic::nvvm_tcgen05_st_32x32b_x128:
-    return LowerTcgen05St(Op, DAG);
+    return lowerTcgen05St(Op, DAG);
   case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
   case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg2:
   case Intrinsic::nvvm_tcgen05_mma_shared_scale_d_disable_output_lane_cg1:
@@ -2867,6 +2926,28 @@ static SDValue lowerPrmtIntrinsic(SDValue Op, SelectionDAG &DAG) {
   SDValue Selector = (Op->op_end() - 1)->get();
   return getPRMT(A, B, Selector, DL, DAG, Mode);
 }
+
+static SDValue lowerIntrinsicWChain(SDValue Op, SelectionDAG &DAG) {
+  switch (Op->getConstantOperandVal(1)) {
+  default:
+    return Op;
+
+  // These tcgen05 intrinsics return a v2i32, which is legal, so we have to
+  // lower them through LowerOperation() instead of ReplaceNodeResults().
+  case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
+  case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
+  case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
+    if (auto Res = lowerTcgen05Ld(Op.getNode(), DAG))
+      return DAG.getMergeValues({Res->first, Res->second}, SDLoc(Op));
+    return SDValue();
+
+  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
+    if (auto Res = lowerTcgen05Ld(Op.getNode(), DAG, /*HasOffset=*/true))
+      return DAG.getMergeValues({Res->first, Res->second}, SDLoc(Op));
+    return SDValue();
+  }
+}
+
 static SDValue lowerIntrinsicWOChain(SDValue Op, SelectionDAG &DAG) {
   switch (Op->getConstantOperandVal(0)) {
   default:
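The comment inside lowerIntrinsicWChain() is the crux of the patch's plumbing: ReplaceNodeResults() only fires for nodes whose result type is illegal, and v2i32 stops being illegal once it has a register class, so the narrow tcgen05.ld variants must be intercepted in LowerOperation() instead. A toy model of that dispatch rule (hypothetical names, not the LLVM API):

```cpp
#include <cstdio>

// Which legalization hook gets to rewrite a node, based on whether the
// node's result type is legal for the target.
const char *hookFor(bool ResultTypeLegal) {
  return ResultTypeLegal ? "LowerOperation" : "ReplaceNodeResults";
}

int main() {
  // tcgen05.ld _x2/_x1 variants return v2i32: legal after this patch.
  std::printf("v2i32 result -> %s\n", hookFor(true));
  // Wider variants return v4i32, v8i32, ...: still illegal types.
  std::printf("v8i32 result -> %s\n", hookFor(false));
}
```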
@@ -3029,11 +3110,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::ADDRSPACECAST:
     return LowerADDRSPACECAST(Op, DAG);
   case ISD::INTRINSIC_W_CHAIN:
-    return Op;
+    return lowerIntrinsicWChain(Op, DAG);
   case ISD::INTRINSIC_WO_CHAIN:
     return lowerIntrinsicWOChain(Op, DAG);
   case ISD::INTRINSIC_VOID:
-    return LowerIntrinsicVoid(Op, DAG);
+    return lowerIntrinsicVoid(Op, DAG);
   case ISD::BUILD_VECTOR:
     return LowerBUILD_VECTOR(Op, DAG);
   case ISD::BITCAST:
@@ -5920,7 +6001,7 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
       IsPTXVectorType(VectorVT.getSimpleVT()))
     return SDValue(); // Native vector loads already combine nicely w/
                       // extract_vector_elt.
-  // Don't mess with singletons or packed types (v2f32, v2*16, v4i8 and v8i8),
+  // Don't mess with singletons or packed types (v2*32, v2*16, v4i8 and v8i8),
   // we already handle them OK.
   if (VectorVT.getVectorNumElements() == 1 ||
       NVPTX::isPackedVectorTy(VectorVT) || VectorVT == MVT::v8i8)
@@ -6300,53 +6381,6 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
       DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
 }
 
-// Lower vector return type of tcgen05.ld intrinsics
-static void ReplaceTcgen05Ld(SDNode *N, SelectionDAG &DAG,
-                             SmallVectorImpl<SDValue> &Results,
-                             bool hasOffset = false) {
-  SDLoc DL(N);
-  EVT ResVT = N->getValueType(0);
-  if (!ResVT.isVector())
-    return; // already legalized.
-
-  const unsigned NumElts = ResVT.getVectorNumElements();
-
-  // Create the return type of the instructions
-  SmallVector<EVT, 5> ListVTs;
-  for (unsigned i = 0; i < NumElts; ++i)
-    ListVTs.push_back(MVT::i32);
-
-  ListVTs.push_back(N->getValueType(1)); // Chain
-
-  SDVTList ResVTs = DAG.getVTList(ListVTs);
-
-  SmallVector<SDValue, 8> Ops{N->getOperand(0), N->getOperand(1),
-                              N->getOperand(2)};
-
-  if (hasOffset) {
-    Ops.push_back(N->getOperand(3)); // offset
-    Ops.push_back(N->getOperand(4)); // Pack flag
-  } else
-    Ops.push_back(N->getOperand(3)); // Pack flag
-
-  MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
-  SDValue NewNode =
-      DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, ResVTs, Ops,
-                              MemSD->getMemoryVT(), MemSD->getMemOperand());
-
-  // split the vector result
-  SmallVector<SDValue, 4> ScalarRes;
-  for (unsigned i = 0; i < NumElts; ++i) {
-    SDValue Res = NewNode.getValue(i);
-    ScalarRes.push_back(Res);
-  }
-
-  SDValue Chain = NewNode.getValue(NumElts);
-  SDValue BuildVector = DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, ScalarRes);
-  Results.push_back(BuildVector); // Build Vector
-  Results.push_back(Chain);       // Chain
-}
-
 static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
                                      SmallVectorImpl<SDValue> &Results) {
   SDValue Chain = N->getOperand(0);
@@ -6455,21 +6489,18 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
     return;
   }
 
-  case Intrinsic::nvvm_tcgen05_ld_16x64b_x2:
   case Intrinsic::nvvm_tcgen05_ld_16x64b_x4:
   case Intrinsic::nvvm_tcgen05_ld_16x64b_x8:
   case Intrinsic::nvvm_tcgen05_ld_16x64b_x16:
   case Intrinsic::nvvm_tcgen05_ld_16x64b_x32:
   case Intrinsic::nvvm_tcgen05_ld_16x64b_x64:
   case Intrinsic::nvvm_tcgen05_ld_16x64b_x128:
-  case Intrinsic::nvvm_tcgen05_ld_32x32b_x2:
   case Intrinsic::nvvm_tcgen05_ld_32x32b_x4:
   case Intrinsic::nvvm_tcgen05_ld_32x32b_x8:
   case Intrinsic::nvvm_tcgen05_ld_32x32b_x16:
   case Intrinsic::nvvm_tcgen05_ld_32x32b_x32:
   case Intrinsic::nvvm_tcgen05_ld_32x32b_x64:
   case Intrinsic::nvvm_tcgen05_ld_32x32b_x128:
-  case Intrinsic::nvvm_tcgen05_ld_16x128b_x1:
   case Intrinsic::nvvm_tcgen05_ld_16x128b_x2:
   case Intrinsic::nvvm_tcgen05_ld_16x128b_x4:
   case Intrinsic::nvvm_tcgen05_ld_16x128b_x8:
@@ -6482,16 +6513,23 @@ static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
   case Intrinsic::nvvm_tcgen05_ld_16x256b_x8:
   case Intrinsic::nvvm_tcgen05_ld_16x256b_x16:
   case Intrinsic::nvvm_tcgen05_ld_16x256b_x32:
-    return ReplaceTcgen05Ld(N, DAG, Results);
+    if (auto Res = lowerTcgen05Ld(N, DAG)) {
+      Results.push_back(Res->first);
+      Results.push_back(Res->second);
+    }
+    return;
 
-  case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x2:
   case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x4:
   case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x8:
   case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x16:
   case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x32:
   case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x64:
   case Intrinsic::nvvm_tcgen05_ld_16x32bx2_x128:
-    return ReplaceTcgen05Ld(N, DAG, Results, /* Offset */ true);
+    if (auto Res = lowerTcgen05Ld(N, DAG, /*HasOffset=*/true)) {
+      Results.push_back(Res->first);
+      Results.push_back(Res->second);
+    }
+    return;
   }
 }