#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "vector-to-gpu"

#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap `offsetMap` representing offsets into those indices, compose the
/// offsets (fed by `dimValues`) into the transfer indices.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
    }
  }
}
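// Hypothetical usage sketch (not part of the original file): shift an index
// Value by a constant using the same composed affine.apply idiom as above.
// `rewriter`, `loc`, and `prevIdx` stand in for values from a real rewrite.
static Value shiftIndexByFour(RewriterBase &rewriter, Location loc,
                              Value prevIdx) {
  AffineExpr d0 = rewriter.getAffineDimExpr(0);
  AffineMap plusFour =
      AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 + 4);
  // makeComposedAffineApply folds the addition into any producing
  // affine.apply ops instead of stacking a new one on top.
  return affine::makeComposedAffineApply(rewriter, loc, plusFour,
                                         {OpFoldResult(prevIdx)});
}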
 
/// Return true if the given vector.contract can be converted to MMA matrix
/// operations.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, contract.getContext());
  };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
        vector::isParallelIterator(iteratorTypes[1]) &&
        vector::isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
  if (!useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

  return true;
}

/// Return true if the given permutation map loads a transposed 2-D matrix.
static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
  MLIRContext *ctx = permutationMap.getContext();
  OpBuilder b(ctx);
  auto nDim = permutationMap.getNumDims();
  AffineExpr zero = b.getAffineConstantExpr(0);
  auto innerDim = b.getAffineDimExpr(nDim - 1);
  auto outerDim = b.getAffineDimExpr(nDim - 2);
  // Support both the transposed and the transposed+broadcasted cases.
  return permutationMap ==
             AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
}
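// Standalone sketch (hypothetical helper, not in the original file): the
// `infer` lambda above boils down to this use of AffineMap::inferFromExprList.
static SmallVector<AffineMap, 4> rowMajorMatmulMaps(MLIRContext *ctx) {
  AffineExpr m, n, k;
  bindDims(ctx, m, n, k);
  // (m, k) x (k, n) -> (m, n): the canonical row-major matmul indexing maps.
  return AffineMap::inferFromExprList(
      ArrayRef<ArrayRef<AffineExpr>>{{m, k}, {k, n}, {m, n}}, ctx);
}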
 
/// Return the stride (in number of elements) of the second-to-last dimension
/// if it is statically known; std::nullopt otherwise.
static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
  auto memrefType = dyn_cast<MemRefType>(type);
  if (!memrefType)
    return std::nullopt;
  if (memrefType.getRank() < 2)
    return std::nullopt;
  int64_t offset = 0;
  SmallVector<int64_t> strides;
  if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
      strides.back() != 1)
    return std::nullopt;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamic)
    return std::nullopt;
  return stride;
}
 
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  // ...
  // Only allow 8-bit integer loads if the signedness can be inferred from a
  // single extending user.
  if (readOp.getVectorType().getElementType().isInteger(8))
    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))
      return false;

  AffineMap map = readOp.getPermutationMap();
  MLIRContext *ctx = readOp.getContext();
  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  auto broadcastInnerDim =
      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
  return map.isMinorIdentity() || map == broadcastInnerDim ||
         isTransposeMatrixLoadMap(map);
}
 
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  if (writeOp.getTransferRank() == 0)
    return false;
  // ...
  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  // ...
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  return true;
}
 
/// Return true if the constant is a splat to a 2-D vector so that it can be
/// converted to an MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = dyn_cast<VectorType>(constantOp.getType());
  if (!vecType || vecType.getRank() != 2)
    return false;
  return isa<SplatElementsAttr>(constantOp.getValue());
}

/// Return true if this is a broadcast from scalar to a 2-D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getResultVectorType().getRank() == 2;
}
 
/// Return true if this integer extend op can be folded into a contract op.
template <typename ExtOpTy>
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
  auto transferReadOp =
      extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return false;
  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
}
 
/// Return the MMA elementwise enum associated with `op` if it is supported.
static std::optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::SubFOp>(op))
    return gpu::MMAElementwiseOp::SUBF;
  if (isa<arith::MaximumFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinimumFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  if (isa<arith::AddIOp>(op))
    return gpu::MMAElementwiseOp::ADDI;
  if (isa<arith::MulIOp>(op))
    return gpu::MMAElementwiseOp::MULI;
  if (isa<arith::SubIOp>(op))
    return gpu::MMAElementwiseOp::SUBI;
  if (isa<arith::DivSIOp>(op))
    return gpu::MMAElementwiseOp::DIVS;
  if (isa<arith::DivUIOp>(op))
    return gpu::MMAElementwiseOp::DIVU;
  if (isa<arith::NegFOp>(op))
    return gpu::MMAElementwiseOp::NEGATEF;
  if (isa<arith::ExtFOp>(op))
    return gpu::MMAElementwiseOp::EXTF;
  return std::nullopt;
}
 
/// Returns true if the extract strided slice op is supported with the
/// mma.sync path.
static bool
extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return false;

  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
  if (failed(contractOp))
    return false;

  // The sliced fragment must keep the shape of either the RHS or the
  // accumulator operand of the consuming contraction.
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getRhs().getType()));
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getAcc().getType()));
  return false;
}
 
/// Return true if the op can be lowered to an MMA matrix operation.
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
                    : transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
    return useNvGpu &&
           extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
    return fpExtendSupportsMMAMatrixType(fpExtend);
  return elementwiseSupportsMMAMatrixType(op);
}
 
/// Return an unsorted slice handling scf.for regions differently than
/// getSlice.
static SetVector<Operation *>
getSliceContract(Operation *op,
                 const BackwardSliceOptions &backwardSliceOptions,
                 const ForwardSliceOptions &forwardSliceOptions) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = slice[currentIndex];
    // Compute and insert the backward slice starting from currentOp.
    backwardSlice.clear();
    LogicalResult result =
        getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
    assert(result.succeeded() && "expected a backward slice");
    (void)result;
    slice.insert_range(backwardSlice);

    // Compute and insert the forward slice starting from currentOp. For
    // scf.for, only follow the loop results rather than the whole region.
    forwardSlice.clear();
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
      // ...
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
    }
    slice.insert_range(forwardSlice);
    ++currentIndex;
  }
  return slice;
}
 
/// Return the set of ops to convert, seeded from vector.contract ops and
/// expanded through their backward and forward slices.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
  };
  BackwardSliceOptions backwardSliceOptions;
  backwardSliceOptions.filter = hasVectorDest;

  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
  };
  ForwardSliceOptions forwardSliceOptions;
  forwardSliceOptions.filter = hasVectorSrc;

  SetVector<Operation *> opToConvert;
  op->walk([&](Operation *nestedOp) {
    if (!isa<vector::ContractionOp>(nestedOp) &&
        // ...
      return;
    if (opToConvert.contains(nestedOp))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(nestedOp, backwardSliceOptions, forwardSliceOptions);
    // If any op in the slice cannot be converted, drop the whole chain: MMA
    // matrices are stored in an opaque type and cannot feed arbitrary ops.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          if (!supportsMMaMatrixType(op, useNvGpu)) {
            LDBG() << "cannot convert op: " << *op;
            return true;
          }
          return false;
        }))
      return;

    opToConvert.insert_range(dependentOps);
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}
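// Sketch of the expected driver flow (assumed helper name, not in the
// original file): gather the convertible slice rooted at contractions and
// visit each op; the real drivers appear below as convertVectorToMMAOps and
// convertVectorToNVVMCompatibleMMASync.
static void forEachOpToConvert(Operation *rootOp, bool useNvGpu,
                               llvm::function_ref<void(Operation *)> fn) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, useNvGpu);
  for (Operation *op : ops)
    fn(op);
}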
 
/// Prepare a vector.contract for MMA lowering by transposing its operands
/// into the canonical row-major matmul form.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    // ...
    // The row-major matmul case is already in canonical form.
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
      return failure();
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // ...
      rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      // ...
      rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      // ...
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      // ...
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res, /*...*/
        op.getIteratorTypes());
    return success();
  }
};
 
/// Merge a vector.transpose into its producing vector.transfer_read by
/// composing the permutation maps, so the transpose happens as part of the
/// load.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    // Look through integer extend ops.
    Value source = op.getVector();
    Type resultType = op.getType();
    Operation *extOp;
    if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtFOp>())) {
      source = extOp->getOperand(0);
      resultType =
          VectorType::get(cast<VectorType>(resultType).getShape(),
                          cast<VectorType>(source.getType()).getElementType());
    }

    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return failure();
    // ...
    if (transferReadOp.getTransferRank() == 0)
      return failure();
    // ...
    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return failure();

    // ...
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());

    auto loc = op.getLoc();
    Value result = vector::TransferReadOp::create(
                       rewriter, loc, resultType, transferReadOp.getBase(),
                       transferReadOp.getIndices(), AffineMapAttr::get(newMap),
                       transferReadOp.getPadding(), transferReadOp.getMask(),
                       transferReadOp.getInBoundsAttr())
                       .getResult();

    // Rebuild the extend op on top of the new transfer read.
    if (extOp) {
      if (isa<arith::ExtSIOp>(extOp))
        result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result)
                     .getResult();
      else if (isa<arith::ExtUIOp>(extOp))
        result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result)
                     .getResult();
      else
        result = arith::ExtFOp::create(rewriter, loc, op.getType(), result)
                     .getResult();
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
 
/// Infer the fragment kind ("AOp", "BOp" or "COp") a value corresponds to by
/// looking at its consuming vector.contract.
static const char *inferFragType(Operation *op) {
  // ...
    auto contract = dyn_cast<vector::ContractionOp>(users);
  // ...
}

static LogicalResult
convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
                      llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op) &&
         "expected convertible operation");

  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LDBG() << "no stride";
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  AffineMap map = op.getPermutationMap();
  bool isTranspose = isTransposeMatrixLoadMap(map);
  // ...
  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
    assert(cstExpr.getValue() == 0);
    // ...
  }

  Value mappingResult = op.getResult();
  auto elType = op.getVectorType().getElementType();
  const char *fragType = inferFragType(op);
  if (op->hasOneUse()) {
    auto *user = *op->user_begin();
    // Infer the signedness of the mma type from the integer extend user.
    if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
      elType = IntegerType::get(
          op.getContext(), cast<IntegerType>(elType).getWidth(),
          isa<arith::ExtSIOp>(user) ? IntegerType::Signed
                                    : IntegerType::Unsigned);
      mappingResult = user->getResult(0);
      fragType = inferFragType(user);
    }
  }
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
  Value load = gpu::SubgroupMmaLoadMatrixOp::create(
      rewriter, op.getLoc(), type, op.getBase(), op.getIndices(),
      rewriter.getIndexAttr(*stride),
      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
  valueMapping[mappingResult] = load;

  LDBG() << "transfer read to: " << load;
  return success();
}
 
static LogicalResult
convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LDBG() << "no stride";
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end()) {
    LDBG() << "no mapping";
    return rewriter.notifyMatchFailure(op, "no mapping");
  }

  Value matrix = it->second;
  auto store = gpu::SubgroupMmaStoreMatrixOp::create(
      rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(),
      rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());

  LDBG() << "transfer write to: " << store;

  LDBG() << "erase: " << op;
  rewriter.eraseOp(op);
  return success();
}
 
/// Returns the vector type which represents a matrix fragment.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
                             regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = dyn_cast<VectorType>(elType))
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
}
 
/// Convert a 2-D splat arith.constant to a constant mma.sync fragment.
static LogicalResult
convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LDBG() << "no warpMatrixInfo";
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LDBG() << "not mma sync reg info";
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
  if (!dense) {
    LDBG() << "not a splat";
    return rewriter.notifyMatchFailure(op, "not a splat");
  }

  Value result = arith::ConstantOp::create(
      rewriter, op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}
 
/// Check whether the loaded matrix operand needs to be transposed.
static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
  AffineMap map = op.getPermutationMap();
  if (map.getNumResults() != 2) {
    LDBG() << "Failed because the result of `vector.transfer_read` "
              "is not a 2d operand";
    return failure();
  }

  // Check the two map results are dimension expressions.
  AffineExpr dM = map.getResult(0);
  AffineExpr dN = map.getResult(1);
  auto exprM = dyn_cast<AffineDimExpr>(dM);
  auto exprN = dyn_cast<AffineDimExpr>(dN);

  if (!exprM || !exprN) {
    LDBG() << "Failed because expressions are not affine dim "
              "expressions, then transpose cannot be determined.";
    return failure();
  }

  return exprM.getPosition() > exprN.getPosition();
}
 
static LogicalResult
creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op.getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LDBG() << "no warpMatrixInfo";
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LDBG() << "not mma sync reg info";
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  FailureOr<bool> transpose = isTransposed(op);
  if (failed(transpose)) {
    LDBG() << "failed to determine the transpose";
    return rewriter.notifyMatchFailure(
        op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
  }

  FailureOr<nvgpu::LdMatrixParams> params =
      nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
  if (failed(params)) {
    LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
           << "Op should likely not be converted to a nvgpu.ldmatrix call.";
    return rewriter.notifyMatchFailure(
        op, "failed to convert vector.transfer_read to ldmatrix; this op "
            "likely should not be converted to a nvgpu.ldmatrix call.");
  }

  // Adjust the load offset by the lane id.
  auto laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
  if (failed(offsets)) {
    LDBG() << "no offsets";
    return rewriter.notifyMatchFailure(op, "no offsets");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
                                         indices);

  nvgpu::LdMatrixOp newOp =
      nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(),
                                indices, *transpose, params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}
 
static LogicalResult
createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op.getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return rewriter.notifyMatchFailure(
        op, "Failed to deduce register fragment type during "
            "conversion to distributed non-ldmatrix compatible load");

  Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  Value fill = arith::ConstantOp::create(
      rewriter, op.getLoc(), vectorType.getElementType(),
      rewriter.getZeroAttr(vectorType.getElementType()));
  Value result =
      vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // If we are not transposing, we can use vectorized loads; otherwise each
  // element has to be loaded individually.
  if (!isTransposeLoad) {
    if (!isa<VectorType>(loadedElType)) {
      loadedElType =
          VectorType::get({regInfo->elementsPerRegister}, loadedElType);
    }

    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          rewriter, op.getLoc(), *warpMatrixInfo);
      if (failed(coords))
        return failure();

      Value logicalValueId = arith::ConstantOp::create(
          rewriter, loc, rewriter.getIndexType(),
          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = vector::LoadOp::create(rewriter, loc, loadedElType,
                                        op.getBase(), newIndices);
      result = vector::InsertOp::create(rewriter, loc, el, result, i);
    }
  } else {
    if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
      loadedElType = vecType.getElementType();
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {
        Value logicalValueId = arith::ConstantOp::create(
            rewriter, loc, rewriter.getIndexType(),
            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            rewriter, op.getLoc(), *warpMatrixInfo);
        if (failed(coords))
          return failure();

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType,
                                          op.getBase(), newIndices);
        result = vector::InsertOp::create(rewriter, op.getLoc(), el, result,
                                          ArrayRef<int64_t>{i, innerIdx});
      }
    }
  }

  valueMapping[op.getResult()] = result;
  return success();
}
 
/// Return true if this is a shared memory memref type.
static bool isSharedMemory(MemRefType type) {
  auto addressSpace =
      dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return addressSpace &&
         addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}
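// Usage sketch (assumed context, not in the original file): nvgpu.ldmatrix
// reads from workgroup (shared) memory, so the check above gates the
// ldmatrix-compatible path in convertTransferReadToLoads below.
static bool baseIsInSharedMemory(vector::TransferReadOp readOp) {
  return isSharedMemory(cast<MemRefType>(readOp.getBase().getType()));
}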
 
/// Converts a `vector.transfer_read` operation directly to either a
/// `vector.load` or an `nvgpu.ldmatrix` operation.
static LogicalResult
convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  bool isLdMatrixCompatible =
      isSharedMemory(cast<MemRefType>(op.getBase().getType())) &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When transposing the B operand, ldmatrix only works if there are at
  // least 8 rows to read and the width to read for the transpose is 128 bits.
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(rewriter, op, valueMapping);

  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
}
 
static LogicalResult
convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op.getLoc();
  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end())
    return failure();
  Value matrix = it->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return failure();

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIndexType(),
        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        rewriter, op.getLoc(), *warpMatrixInfo);
    if (failed(coords))
      return failure();

    Value el = vector::ExtractOp::create(rewriter, loc, matrix, i);
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
    vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
  }

  LDBG() << "erase: " << op;
  rewriter.eraseOp(op);
  return success();
}
 
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(cast<IntegerAttr>(attr).getInt());
}
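// Usage sketch (hypothetical helper, not in the original file): recover the
// static offsets of a vector.extract_strided_slice as plain integers.
static SmallVector<int64_t, 2>
getSliceOffsets(vector::ExtractStridedSliceOp op) {
  SmallVector<int64_t, 2> offsets;
  populateFromInt64AttrArray(op.getOffsets(), offsets);
  return offsets;
}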
 
static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
                           vector::ExtractStridedSliceOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op.getLoc();
  // ...
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(mmaSyncFragmentInfo))
    return failure();

  // Find the vector.transfer_read whose fragment feeds this slice.
  auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return failure();

  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(ldFragmentInfo))
    return failure();

  assert(
      (mmaSyncFragmentInfo->elementsPerRegister ==
       ldFragmentInfo->elementsPerRegister) &&
      "Number of elements per register should be same for load and mma.sync");

  // Create vector.extract_strided_slice op for the thread-owned fragment.
  std::array<int64_t, 2> strides = {1, 1}; // Slice strides are always 1.
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
    return failure();
  auto sourceVector = it->second;

  SmallVector<int64_t, 2> offsets;
  populateFromInt64AttrArray(op.getOffsets(), offsets);
  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
  // ...
  std::array<int64_t, 2> sliceOffset = {0, 0};

  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported. ";
  if (offsets[0])
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);

  Value newOp = vector::ExtractStridedSliceOp::create(
      rewriter, loc, sourceVector, sliceOffset, sliceShape, strides);

  valueMapping[op] = newOp;
  return success();
}
 
static LogicalResult
convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return failure();
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(),
                                                   opC.getType(), opA, opB,
                                                   opC, /*...*/);
  valueMapping[op.getResult()] = matmul;
  return success();
}
 
static LogicalResult
convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return failure();
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  // The RHS is in N x K (column-major) form, so n is its leading dimension.
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
  Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
                                          rewriter.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}
 
/// Convert a 2-D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  TypedAttr splat =
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = cast<VectorType>(op.getType());
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
                                                         type, scalarConstant);
  valueMapping[op.getResult()] = matrix;
  return success();
}
 
/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
                   llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  const char *fragType = inferFragType(op);
  auto vecType = op.getResultVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
                                                         type, op.getSource());
  valueMapping[op.getResult()] = matrix;
  return success();
}
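// Sketch (hypothetical shape and element type, not in the original file):
// the kind of MMA matrix type the two conversions above build for an
// accumulator fragment.
static gpu::MMAMatrixType make16x16F16AccType(Builder &b) {
  return gpu::MMAMatrixType::get({16, 16}, b.getF16Type(), "COp");
}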
 
/// Clone the scf.for with `newInitArgs` appended to its loop-carried values,
/// redirect uses of the old results, and erase the old loop.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                                               scf::ForOp loop,
                                               ValueRange newInitArgs) {
  // ...
  // Create a new loop before the existing one, with the extra operands.
  auto operands = llvm::to_vector<4>(loop.getInitArgs());
  llvm::append_range(operands, newInitArgs);
  scf::ForOp newLoop =
      scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(),
                         loop.getUpperBound(), loop.getStep(), operands);
  rewriter.eraseBlock(newLoop.getBody());

  newLoop.getRegion().getBlocks().splice(
      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
  for (Value operand : newInitArgs)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));

  LDBG() << "newLoop now: " << newLoop;
  LDBG() << "stripped scf.for: " << loop;
  LDBG() << "erase: " << loop;

  rewriter.eraseOp(loop);
  return newLoop;
}
 
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end()) {
      LDBG() << "no value mapping for: " << operand.value();
      continue;
    }
    argMapping.push_back(std::make_pair(
        operand.index(), op.getInitArgs().size() + newOperands.size()));
    newOperands.push_back(it->second);
  }

  scf::ForOp newForOp =
      replaceForOpWithNewSignature(rewriter, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }

  LDBG() << "scf.for to: " << newForOp;
  return success();
}
 
static LogicalResult
convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
               llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);

  LDBG() << "erase: " << op;
  rewriter.eraseOp(op);
  return success();
}
 
/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static LogicalResult
convertElementwiseOp(RewriterBase &rewriter, Operation *op,
                     gpu::MMAElementwiseOp opType,
                     llvm::DenseMap<Value, Value> &valueMapping) {
  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands()) {
    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
      return failure();
    matrixOperands.push_back(it->second);
  }
  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
  if (opType == gpu::MMAElementwiseOp::EXTF) {
    // The floating-point extension case has a wider result element type.
    auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                         vectorType.getElementType(),
                                         resultType.getOperand());
  }

  Value newOp = gpu::SubgroupMmaElementwiseOp::create(
      rewriter, op->getLoc(), resultType, matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
  return success();
}
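// Illustrative glue (assumed context, not in the original file): combine the
// classification step (convertElementwiseOpToMMA) with the conversion above.
static LogicalResult
tryConvertElementwise(RewriterBase &rewriter, Operation *op,
                      llvm::DenseMap<Value, Value> &valueMapping) {
  if (std::optional<gpu::MMAElementwiseOp> kind = convertElementwiseOpToMMA(op))
    return convertElementwiseOp(rewriter, op, *kind, valueMapping);
  return failure();
}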
 
void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
}
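// Sketch of a typical driver (assumed root op, not in the original file):
// apply the preparation patterns greedily before the conversion itself.
static LogicalResult prepareForMMALowering(Operation *rootOp, bool useNvGpu) {
  RewritePatternSet patterns(rootOp->getContext());
  populatePrepareVectorToMMAPatterns(patterns, useNvGpu);
  return applyPatternsGreedily(rootOp, std::move(patterns));
}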
 
LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
                                          Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;

  // Keep converting even if one op fails; report failure at the end.
  auto globalRes = LogicalResult::success();
  for (Operation *op : ops) {
    LDBG() << "Process op: " << *op;
    auto res = LogicalResult::success();
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      res = convertTransferReadOp(rewriter, transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      res = convertContractOp(rewriter, contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      res = convertConstantOp(rewriter, constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      res = convertForOp(rewriter, forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      res = convertYieldOp(rewriter, yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
    }
    if (failed(res))
      globalRes = failure();
  }
  return globalRes;
}

LogicalResult
mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
                                           Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(rewriter, transferReadOp,
                                                valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(rewriter, transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
              return convertExtractStridedSlice(
                  rewriter, extractStridedSliceOp, valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(rewriter, contractionOp,
                                                valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              return convertForOp(rewriter, forOp, valueMapping);
            })
            .Case([&](scf::YieldOp yieldOp) {
              return convertYieldOp(rewriter, yieldOp, valueMapping);
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              return op->emitError() << "unhandled vector to mma type: " << *op;
            })
            .failed()) {
      return op->emitOpError()
             << "failed to convert op during vector-to-nvgpu conversion";
    }
  }
  return success();
}
 
namespace {

struct ConvertVectorToGPUPass
    : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    IRRewriter rewriter(&getContext());
    if (useNvGpu) {
      if (failed(
              convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
        return signalPassFailure();
      return;
    }
    (void)convertVectorToMMAOps(rewriter, getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}
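// Hypothetical pipeline registration (not in the original file): how a client
// might schedule this conversion in a pass pipeline.
static void buildVectorToGPUPipeline(OpPassManager &pm, bool useNvGpu) {
  pm.addPass(createConvertVectorToGPUPass(useNvGpu));
}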
 