32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/TypeSwitch.h"
35 #define DEBUG_TYPE "vector-to-gpu"
36 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37 #define DBGSNL() (llvm::dbgs() << "\n")
40 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
41 #include "mlir/Conversion/Passes.h.inc"
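/// For a vector TransferOpType xferOp, an empty indices vector, and an
/// AffineMap representing offsets to the original indices, compute the
/// updated indices of the transfer (getXferIndices).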
52 template <typename TransferOpType>
56 indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
58 unsigned offsetsIdx = 0;
59 for (auto expr : xferOp.getPermutationMap().getResults()) {
60 if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
61 Value prevIdx = indices[dim.getPosition()];
63 dims.push_back(prevIdx);
66 rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
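/// Return true if the given vector.contract can be converted to MMA matrix
/// operations (contractSupportsMMAMatrixType).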
76 auto infer = [&](MapList m) {
81 auto iteratorTypes = contract.getIteratorTypes().getValue();
90 contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
93 contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
116 return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
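/// Return the stride of the second-to-last memref dimension if it is
/// statically known (getStaticallyKnownRowStride).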
123 auto memrefType = dyn_cast<MemRefType>(type);
127 if (memrefType.getRank() < 2)
131 if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
134 int64_t stride = strides[strides.size() - 2];
135 if (stride == ShapedType::kDynamic)
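/// Return true if this vector.transfer_read can be lowered to an MMA matrix
/// load (transferReadSupportsMMAMatrixType).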
142 if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
143 readOp.getVectorType().getRank() != 2)
149 if (readOp.getVectorType().getElementType().isInteger(8))
150 if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
151 !isa<arith::ExtUIOp>(*readOp->user_begin())))
158 auto broadcastInnerDim =
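/// Return true if this vector.transfer_write can be lowered to an MMA matrix
/// store (transferWriteSupportsMMAMatrixType).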
168 if (writeOp.getTransferRank() == 0)
171 if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
172 writeOp.getVectorType().getRank() != 2)
177 if (!writeOp.getPermutationMap().isMinorIdentity())
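/// Return true if the constant is a splat of a 2-D vector so that it can be
/// converted to an MMA constant matrix op (constantSupportsMMAMatrixType).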
185 auto vecType = dyn_cast<VectorType>(constantOp.getType());
186 if (!vecType || vecType.getRank() != 2)
188 return isa<SplatElementsAttr>(constantOp.getValue());
193 return broadcastOp.getResultVectorType().getRank() == 2;
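/// Return true if this integer extend op can be folded into a contract op
/// (integerExtendSupportsMMAMatrixType).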
197 template <typename ExtOpTy>
199 auto transferReadOp =
200 extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
203 return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
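/// Return the MMA elementwise enum associated with op if it is supported
/// (convertElementwiseOpToMMA).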
210 static std::optional<gpu::MMAElementwiseOp>
212 if (isa<arith::AddFOp>(op))
213 return gpu::MMAElementwiseOp::ADDF;
214 if (isa<arith::MulFOp>(op))
215 return gpu::MMAElementwiseOp::MULF;
216 if (isa<arith::SubFOp>(op))
217 return gpu::MMAElementwiseOp::SUBF;
218 if (isa<arith::MaximumFOp>(op))
219 return gpu::MMAElementwiseOp::MAXF;
220 if (isa<arith::MinimumFOp>(op))
221 return gpu::MMAElementwiseOp::MINF;
222 if (isa<arith::DivFOp>(op))
223 return gpu::MMAElementwiseOp::DIVF;
224 if (isa<arith::AddIOp>(op))
226 if (isa<arith::MulIOp>(op))
228 if (isa<arith::SubIOp>(op))
230 if (isa<arith::DivSIOp>(op))
231 return gpu::MMAElementwiseOp::DIVS;
232 if (isa<arith::DivUIOp>(op))
233 return gpu::MMAElementwiseOp::DIVU;
234 if (isa<arith::NegFOp>(op))
235 return gpu::MMAElementwiseOp::NEGATEF;
236 if (isa<arith::ExtFOp>(op))
237 return gpu::MMAElementwiseOp::EXTF;
251 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
253 if (failed(warpMatrixInfo))
257 if (failed(contractOp))
264 return (cast<VectorType>(op->getResult(0).getType()) ==
265 cast<VectorType>((*contractOp).getRhs().getType()));
267 return (cast<VectorType>(op->getResult(0).getType()) ==
268 cast<VectorType>((*contractOp).getAcc().getType()));
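/// Return true if the op can be converted to operations on MMA matrix types
/// (supportsMMaMatrixType).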
274 if (isa<scf::ForOp, scf::YieldOp>(op))
276 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
279 if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
282 if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
285 if (auto contract = dyn_cast<vector::ContractionOp>(op))
287 if (auto constant = dyn_cast<arith::ConstantOp>(op))
289 if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
291 if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
292 return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
293 if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
294 return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
295 if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
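/// Return an unsorted slice handling scf.for regions differently than
/// getSlice (getSliceContract).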
309 unsigned currentIndex = 0;
312 while (currentIndex != slice.size()) {
313 auto *currentOp = (slice)[currentIndex];
315 backwardSlice.clear();
316 LogicalResult result =
318 assert(result.succeeded() && "expected a backward slice");
320 slice.insert_range(backwardSlice);
323 forwardSlice.clear();
328 if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
329 for (Value forOpResult : forOp.getResults())
336 slice.insert_range(forwardSlice);
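/// Walk the vector.contract ops and gather the set of dependent operations
/// that can all be converted to MMA operations (getOpToConvert).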
347 return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
350 backwardSliceOptions.filter = hasVectorDest;
356 forwardSliceOptions.filter = hasVectorSrc;
360 if (opToConvert.contains(contract.getOperation()))
367 if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
369 LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
376 opToConvert.insert_range(dependentOps);
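/// Pattern that prepares a vector.contract for conversion to GPU MMA by
/// transposing operands into the canonical row-major matmul form.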
385 struct PrepareContractToGPUMMA
389 LogicalResult matchAndRewrite(vector::ContractionOp op,
392 Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
396 auto infer = [&](MapList m) {
401 static constexpr std::array<int64_t, 2> perm = {1, 0};
402 auto iteratorTypes = op.getIteratorTypes().getValue();
412 if (maps == infer({{m, k}, {k, n}, {m, n}}))
414 if (maps == infer({{m, k}, {n, k}, {m, n}})) {
415 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
416 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
417 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
418 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
419 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
420 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
421 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
423 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
424 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
425 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
427 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
428 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
430 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
431 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
440 op.getIteratorTypes());
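/// Pattern that folds a vector.transpose into the preceding
/// vector.transfer_read by composing the permutation maps.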
449 struct CombineTransferReadOpTranspose final
453 LogicalResult matchAndRewrite(vector::TransposeOp op,
456 Value source = op.getVector();
457 Type resultType = op.getType();
465 cast<VectorType>(source.getType()).getElementType());
468 auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
473 if (transferReadOp.getTransferRank() == 0)
476 if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
482 permutationMap.compose(transferReadOp.getPermutationMap());
484 auto loc = op.getLoc();
485 Value result = vector::TransferReadOp::create(
486 rewriter, loc, resultType, transferReadOp.getBase(),
488 transferReadOp.getPadding(), transferReadOp.getMask(),
489 transferReadOp.getInBoundsAttr())
494 if (isa<arith::ExtSIOp>(extOp))
495 result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result)
497 else if (isa<arith::ExtUIOp>(extOp))
498 result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result)
501 result = arith::ExtFOp::create(rewriter, loc, op.getType(), result)
526 auto contract = dyn_cast<vector::ContractionOp>(users);
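/// Convert a vector.transfer_read to a gpu.subgroup_mma_load_matrix op
/// (convertTransferReadOp).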
544 assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
546 "expected convertible operation");
548 std::optional<int64_t> stride =
550 if (!stride.has_value()) {
551 LLVM_DEBUG(DBGS() << "no stride\n");
559 if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
560 assert(cstExpr.getValue() == 0);
564 Value mappingResult = op.getResult();
565 auto elType = op.getVectorType().getElementType();
567 if (op->hasOneUse()) {
568 auto *user = *op->user_begin();
570 if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
572 op.getContext(), cast<IntegerType>(elType).getWidth(),
574 : IntegerType::Unsigned);
575 mappingResult = user->getResult(0);
580 Value load = gpu::SubgroupMmaLoadMatrixOp::create(
581 rewriter, op.getLoc(), type, op.getBase(), op.getIndices(),
583 isTranspose ? rewriter.getUnitAttr() : UnitAttr());
584 valueMapping[mappingResult] = load;
586 LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
597 std::optional<int64_t> stride =
599 if (!stride.has_value()) {
600 LLVM_DEBUG(DBGS() << "no stride\n");
604 auto it = valueMapping.find(op.getVector());
605 if (it == valueMapping.end()) {
606 LLVM_DEBUG(DBGS() << "no mapping\n");
610 Value matrix = it->second;
611 auto store = gpu::SubgroupMmaStoreMatrixOp::create(
612 rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(),
616 LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
618 LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
629 if (auto vecType = dyn_cast<VectorType>(elType))
630 elType = vecType.getElementType();
641 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
643 if (failed(warpMatrixInfo)) {
644 LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
648 FailureOr<nvgpu::FragmentElementInfo> regInfo =
650 if (failed(regInfo)) {
651 LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
656 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
658 LLVM_DEBUG(DBGS() << "not a splat\n");
662 Value result = arith::ConstantOp::create(
663 rewriter, op.getLoc(), vectorType,
665 valueMapping[op.getResult()] = result;
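/// Check whether the loaded matrix operand requires a transposed load
/// (isTransposed).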
680 LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
681 "is not a 2d operand\n");
690 auto exprM = dyn_cast<AffineDimExpr>(dM);
691 auto exprN = dyn_cast<AffineDimExpr>(dN);
693 if (!exprM || !exprN) {
694 LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
695 "expressions, then transpose cannot be determined.\n");
699 return exprM.getPosition() > exprN.getPosition();
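/// Create an nvgpu.ldmatrix load for a compatible vector.transfer_read
/// (creatLdMatrixCompatibleLoads).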
709 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
711 if (failed(warpMatrixInfo)) {
712 LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
716 FailureOr<nvgpu::FragmentElementInfo> regInfo =
718 if (failed(regInfo)) {
719 LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
724 if (failed(transpose)) {
725 LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
727 op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
730 FailureOr<nvgpu::LdMatrixParams> params =
733 if (failed(params)) {
736 << "failed to convert vector.transfer_read to ldmatrix. "
737 << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
739 op, "failed to convert vector.transfer_read to ldmatrix; this op "
740 "likely should not be converted to a nvgpu.ldmatrix call.");
744 auto laneId = gpu::LaneIdOp::create(rewriter, loc, nullptr);
745 FailureOr<AffineMap> offsets =
747 if (failed(offsets)) {
748 LLVM_DEBUG(DBGS() << "no offsets\n");
755 getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
758 nvgpu::LdMatrixOp newOp =
759 nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(),
760 indices, *transpose, params->numTiles);
761 valueMapping[op] = newOp->getResult(0);
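/// Distributed-load fallback when nvgpu.ldmatrix cannot be used: each lane
/// loads its fragment elements individually (createNonLdMatrixLoads).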
772 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
774 if (failed(warpMatrixInfo))
776 FailureOr<nvgpu::FragmentElementInfo> regInfo =
778 if (failed(regInfo)) {
780 op, "Failed to deduce register fragment type during "
781 "conversion to distributed non-ldmatrix compatible load");
784 Value laneId = gpu::LaneIdOp::create(rewriter, loc, nullptr);
787 Type loadedElType = regInfo->registerLLVMType;
790 Value fill = arith::ConstantOp::create(
791 rewriter, op.getLoc(), vectorType.getElementType(),
792 rewriter.getZeroAttr(vectorType.getElementType()));
794 vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill);
796 bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
800 if (!isTransposeLoad) {
801 if (!isa<VectorType>(loadedElType)) {
805 for (int i = 0; i < vectorType.getShape()[0]; i++) {
807 rewriter, op.getLoc(), *warpMatrixInfo);
811 Value logicalValueId = arith::ConstantOp::create(
813 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
815 getXferIndices<vector::TransferReadOp>(
816 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
818 Value el = vector::LoadOp::create(rewriter, loc, loadedElType,
819 op.getBase(), newIndices);
820 result = vector::InsertOp::create(rewriter, loc, el, result, i);
823 if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
824 loadedElType = vecType.getElementType();
826 for (int i = 0; i < vectorType.getShape()[0]; i++) {
827 for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
830 Value logicalValueId = arith::ConstantOp::create(
832 rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
834 rewriter, op.getLoc(), *warpMatrixInfo);
839 getXferIndices<vector::TransferReadOp>(
840 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
841 Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType,
842 op.getBase(), newIndices);
843 result = vector::InsertOp::create(rewriter, op.getLoc(), el, result,
849 valueMapping[op.getResult()] = result;
856 dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
857 return addressSpace &&
858 addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
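/// Converts a vector.transfer_read directly to either a vector.load or an
/// nvgpu.ldmatrix op (convertTransferReadToLoads).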
870 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
872 if (failed(warpMatrixInfo))
875 bool isLdMatrixCompatible =
879 VectorType vecTy = op.getVectorType();
880 int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
885 if (!op.getPermutationMap().isMinorIdentity() &&
886 (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
887 vecTy.getDimSize(0) * bitWidth < 128))
888 isLdMatrixCompatible = false;
890 if (!isLdMatrixCompatible)
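/// Convert a vector.transfer_write to distributed vector.store ops on the
/// mma.sync path (convertTransferWriteToStores).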
903 auto it = valueMapping.find(op.getVector());
904 if (it == valueMapping.end())
906 Value matrix = it->second;
908 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
910 if (failed(warpMatrixInfo))
912 FailureOr<nvgpu::FragmentElementInfo> regInfo =
918 Value laneId = gpu::LaneIdOp::create(rewriter, loc, nullptr);
920 for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
921 Value logicalValueId = arith::ConstantOp::create(
923 rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
925 rewriter, op.getLoc(), *warpMatrixInfo);
932 getXferIndices<vector::TransferWriteOp>(
933 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
934 vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
937 LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
944 for (auto attr : arrayAttr)
945 results.push_back(cast<IntegerAttr>(attr).getInt());
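/// Convert vector.extract_strided_slice on a distributed mma.sync register
/// fragment (convertExtractStridedSlice).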
950 vector::ExtractStridedSliceOp op,
957 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
959 if (failed(warpMatrixInfo))
962 FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
964 if (failed(mmaSyncFragmentInfo))
968 auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
973 if (failed(warpMatrixInfo))
976 FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
978 if (failed(ldFragmentInfo))
982 (mmaSyncFragmentInfo->elementsPerRegister ==
983 ldFragmentInfo->elementsPerRegister) &&
984 "Number of elements per register should be same for load and mma.sync");
987 std::array<int64_t, 2> strides = {1,
989 std::array<int64_t, 2> sliceShape = {
990 mmaSyncFragmentInfo->numRegistersPerFragment,
991 mmaSyncFragmentInfo->elementsPerRegister};
992 auto it = valueMapping.find(transferReadOp);
993 if (it == valueMapping.end())
995 auto sourceVector = it->second;
1008 std::array<int64_t, 2> sliceOffset = {0, 0};
1010 if (offsets[0] && offsets[1])
1011 return op->emitError() << "Slicing fragments in 2D is not supported. ";
1013 sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1014 else if (offsets[1])
1015 sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1017 Value newOp = vector::ExtractStridedSliceOp::create(
1018 rewriter, loc, sourceVector, sliceOffset, sliceShape, strides);
1020 valueMapping[op] = newOp;
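/// Convert a vector.contract to a gpu.subgroup_mma_compute op
/// (convertContractOp).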
1024 static LogicalResult
1030 auto itA = valueMapping.find(op.getLhs());
1031 auto itB = valueMapping.find(op.getRhs());
1032 auto itC = valueMapping.find(op.getAcc());
1033 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1034 itC == valueMapping.end())
1036 Value opA = itA->second, opB = itB->second, opC = itC->second;
1037 Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(),
1041 valueMapping[op.getResult()] = matmul;
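/// Convert a vector.contract to an nvgpu.mma.sync op
/// (convertContractOpToMmaSync).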
1045 static LogicalResult
1051 auto itA = valueMapping.find(op.getLhs());
1052 auto itB = valueMapping.find(op.getRhs());
1053 auto itC = valueMapping.find(op.getAcc());
1054 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1055 itC == valueMapping.end())
1057 Value opA = itA->second, opB = itB->second, opC = itC->second;
1058 int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1059 int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1060 int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1061 Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
1063 valueMapping[op.getResult()] = matmul;
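/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op
/// (convertConstantOp).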
1068 static LogicalResult
1077 cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1078 auto scalarConstant =
1079 arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat);
1081 auto vecType = cast<VectorType>(op.getType());
1083 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1084 auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1085 type, scalarConstant);
1086 valueMapping[op.getResult()] = matrix;
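/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op
/// (convertBroadcastOp).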
1091 static LogicalResult
1100 auto vecType = op.getResultVectorType();
1102 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1103 auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1104 type, op.getSource());
1105 valueMapping[op.getResult()] = matrix;
1119 auto operands = llvm::to_vector<4>(loop.getInitArgs());
1120 llvm::append_range(operands, newInitArgs);
1121 scf::ForOp newLoop =
1122 scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(),
1123 loop.getUpperBound(), loop.getStep(), operands);
1126 newLoop.getRegion().getBlocks().splice(
1127 newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1128 for (Value operand : newInitArgs)
1129 newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1131 for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1132 loop.getNumResults())))
1135 LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1136 LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1137 LLVM_DEBUG(DBGS() << "erase: " << loop);
1151 auto it = valueMapping.find(operand.value());
1152 if (it == valueMapping.end()) {
1153 LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1156 argMapping.push_back(std::make_pair(
1157 operand.index(), op.getInitArgs().size() + newOperands.size()));
1158 newOperands.push_back(it->second);
1162 Block &loopBody = *newForOp.getBody();
1163 for (auto mapping : argMapping) {
1164 valueMapping[newForOp.getResult(mapping.first)] =
1165 newForOp.getResult(mapping.second);
1167 newForOp.getNumInductionVars())] =
1168 loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1171 LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1175 static LogicalResult
1181 auto loop = cast<scf::ForOp>(op->getParentOp());
1182 auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1184 auto it = valueMapping.find(operand.value());
1185 if (it == valueMapping.end())
1189 yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1190 yieldOperands.push_back(it->second);
1192 scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
1194 LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1200 static LogicalResult
1202 gpu::MMAElementwiseOp opType,
1209 auto it = valueMapping.find(operand);
1210 if (it == valueMapping.end())
1212 matrixOperands.push_back(it->second);
1214 auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1215 if (opType == gpu::MMAElementwiseOp::EXTF) {
1219 vectorType.getElementType(),
1220 resultType.getOperand());
1223 Value newOp = gpu::SubgroupMmaElementwiseOp::create(
1224 rewriter, op->getLoc(), resultType, matrixOperands, opType);
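/// Patterns to transform vector ops into a canonical form to convert to MMA
/// matrix operations (populatePrepareVectorToMMAPatterns).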
1232 patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
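/// Convert vector ops to MMA matrix operations nested under rootOp
/// (convertVectorToMMAOps).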
1245 auto globalRes = LogicalResult::success();
1247 LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1249 auto res = LogicalResult::success();
1250 if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1252 } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1254 } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1256 } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1258 } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1260 } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1262 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1268 globalRes = failure();
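/// Convert vector ops nested under rootOp to vector and GPU operations
/// compatible with the nvvm.mma.sync lowering path
/// (convertVectorToNVVMCompatibleMMASync).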
1279 .Case([&](vector::TransferReadOp transferReadOp) {
1283 .Case([&](vector::TransferWriteOp transferWriteOp) {
1287 .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1291 .Case([&](vector::ContractionOp contractionOp) {
1295 .Case([&](scf::ForOp forOp) {
1298 .Case([&](scf::YieldOp yieldOp) {
1301 .Case([&](arith::ConstantOp constOp) {
1305 return op->emitError() << "unhandled vector to mma type: " << *op;
1309 << "failed to convert op during vector-to-nvgpu conversion";
1317 struct ConvertVectorToGPUPass
1318 : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1320 explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1321 useNvGpu.setValue(useNvGpu_);
1324 void runOnOperation() override {
1328 return signalPassFailure();
1334 return signalPassFailure();
1344 return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
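// A minimal usage sketch (illustrative only, not part of this file): the
// factory above can be used to schedule the conversion from client code.
// Assumes the factory is declared in
// "mlir/Conversion/VectorToGPU/VectorToGPU.h", as in upstream MLIR.
//
//   #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
//   #include "mlir/Pass/PassManager.h"
//
//   void addVectorToGPU(mlir::PassManager &pm, bool useNvGpu) {
//     // Lower vector ops to GPU MMA ops; enable the nvgpu.mma.sync path
//     // when useNvGpu is true.
//     pm.addPass(mlir::createConvertVectorToGPUPass(useNvGpu));
//   }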