#include <type_traits>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "vector-to-gpu"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply, compute the final transfer
/// indices.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  // ...
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      Value prevIdx = indices[dim.getPosition()];
      // ...
      dims.push_back(prevIdx);
      // ...
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
    }
  }
}
// In contractSupportsMMAMatrixType(contract, useNvGpu):
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, contract.getContext());
  };
  // ...
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  // ...
  // The contract needs to represent a matmul to be convertible to MMA matmul.
  if (!useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

// In isTransposeMatrixLoadMap(permutationMap):
  return permutationMap ==
             AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
// In getStaticallyKnownRowStride(ShapedType type):
  auto memrefType = dyn_cast<MemRefType>(type);
  // ...
  if (memrefType.getRank() < 2)
    return std::nullopt;
  // ...
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamic)
    return std::nullopt;
// In transferReadSupportsMMAMatrixType(readOp):
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  // ...
  // Only allow integer types if the signedness can be inferred.
  if (readOp.getVectorType().getElementType().isInteger(8))
    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))
      return false;
  // ...
  auto broadcastInnerDim =
      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
// In transferWriteSupportsMMAMatrixType(writeOp):
  if (writeOp.getTransferRank() == 0)
    return false;
  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  // ...
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to an MMA constant matrix op.
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = dyn_cast<VectorType>(constantOp.getType());
  if (!vecType || vecType.getRank() != 2)
    return false;
  return isa<SplatElementsAttr>(constantOp.getValue());
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getResultVectorType().getRank() == 2;
}
/// Return true if this integer extend op can be folded into a contract op.
template <typename ExtOpTy>
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
  auto transferReadOp =
      extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return false;
  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
}
/// Return the MMA elementwise enum associated with `op` if it is supported.
static std::optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::SubFOp>(op))
    return gpu::MMAElementwiseOp::SUBF;
  if (isa<arith::MaximumFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinimumFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  if (isa<arith::AddIOp>(op))
    return gpu::MMAElementwiseOp::ADDI;
  if (isa<arith::MulIOp>(op))
    return gpu::MMAElementwiseOp::MULI;
  if (isa<arith::SubIOp>(op))
    return gpu::MMAElementwiseOp::SUBI;
  if (isa<arith::DivSIOp>(op))
    return gpu::MMAElementwiseOp::DIVS;
  if (isa<arith::DivUIOp>(op))
    return gpu::MMAElementwiseOp::DIVU;
  if (isa<arith::NegFOp>(op))
    return gpu::MMAElementwiseOp::NEGATEF;
  if (isa<arith::ExtFOp>(op))
    return gpu::MMAElementwiseOp::EXTF;
  return std::nullopt;
}
// In extractStridedSliceSupportsMMAMatrixType(op):
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return false;
  // ...
  if (failed(contractOp))
    return false;
  // For a B (rhs) fragment, the sliced type must match the contraction rhs:
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getRhs().getType()));
  // For a C (acc) fragment, it must match the accumulator:
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getAcc().getType()));
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
                    : transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
                    : transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
    return useNvGpu &&
           extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
    return fpExtendSupportsMMAMatrixType(fpExtend);
  return elementwiseSupportsMMAMatrixType(op);
}
// In getSliceContract(op, backwardSliceOptions, forwardSliceOptions), which
// returns an unsorted slice that handles scf.for regions differently than
// getSlice:
  unsigned currentIndex = 0;
  // ...
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backward slice starting from currentOp.
    backwardSlice.clear();
    // ...
    slice.insert(backwardSlice.begin(), backwardSlice.end());
    // Compute and insert the forward slice starting from currentOp.
    forwardSlice.clear();
    // scf.for is special-cased: only the values flowing out of the loop are
    // followed, not the whole region.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
      // ...
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    // ...
  }
// In getOpToConvert(op, useNvGpu):
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
  };
  // ...
  backwardSliceOptions.filter = hasVectorDest;
  // ...
  forwardSliceOptions.filter = hasVectorSrc;
  // ...
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    // ...
    // If any op in the dependent chain cannot use the MMA matrix type, drop
    // the whole chain.
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          if (!supportsMMaMatrixType(op, useNvGpu)) {
            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
            return true;
          }
          return false;
        }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
// Prepare vector.contract into the canonical (m, k) x (k, n) -> (m, n) form
// expected by the MMA matmul lowering.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  // ...
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    // ...
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
    // ...
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    // ...
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    // ...
    // The classical row-major matmul: already in canonical form.
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
      return failure();
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // ... (operands are swapped first)
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      // ...
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      // ...
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      // ... (swap only)
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
    return success();
  }
};
// Fold a vector.transpose into the producing vector.transfer_read by composing
// the permutation maps.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  // ...
  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    // Look through integer extend ops.
    Value source = op.getVector();
    Type resultType = op.getType();
    // ... (if the source comes from an arith.ext op, peek through it)
      resultType =
          VectorType::get(cast<VectorType>(resultType).getShape(),
                          cast<VectorType>(source.getType()).getElementType());
    // ...
    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
    // ...
    // TODO: support the 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return failure();
    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return failure();
    // ...
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());

    auto loc = op.getLoc();
    Value result =
        rewriter
            .create<vector::TransferReadOp>(
                loc, resultType, transferReadOp.getSource(),
                // ...
                transferReadOp.getPadding(), transferReadOp.getMask(),
                transferReadOp.getInBoundsAttr())
            .getResult();

    // Fuse through the integer extend op, when present:
    if (isa<arith::ExtSIOp>(extOp))
      result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
                   .getResult();
    else if (isa<arith::ExtUIOp>(extOp))
      result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
                   .getResult();
    else
      result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
                   .getResult();

    rewriter.replaceOp(op, result);
    return success();
  }
};
// In inferFragType(op), scanning op's users for the vector.contract that
// consumes it:
    auto contract = dyn_cast<vector::ContractionOp>(users);
    // ...
static LogicalResult
convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
                      llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op) &&
         "expected convertible operation");

  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }
  // ...
  // Handle broadcast by setting the stride to 0.
  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
    assert(cstExpr.getValue() == 0);
    stride = 0;
  }

  Value mappingResult = op.getResult();
  auto elType = op.getVectorType().getElementType();
  // ...
  if (op->hasOneUse()) {
    auto *user = *op->user_begin();
    // Infer the signedness of the MMA type from the integer extend user.
    if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
      elType = IntegerType::get(
          op.getContext(), cast<IntegerType>(elType).getWidth(),
          isa<arith::ExtSIOp>(user) ? IntegerType::Signed
                                    : IntegerType::Unsigned);
      mappingResult = user->getResult(0);
      // ...
    }
  }
  // ... (build the gpu::MMAMatrixType `type` using inferFragType)
  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride),
      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
  valueMapping[mappingResult] = load;

  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
  return success();
}
static LogicalResult
convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end()) {
    LLVM_DEBUG(DBGS() << "no mapping\n");
    return rewriter.notifyMatchFailure(op, "no mapping");
  }

  Value matrix = it->second;
  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());

  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}
// In getMmaSyncVectorOperandType(regInfo), which returns the vector type that
// represents a matrix fragment: the register type may itself be a vector, so
// peel it down to its element type.
  if (auto vecType = dyn_cast<VectorType>(elType))
    elType = vecType.getElementType();
/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op
/// (mma.sync path).
static LogicalResult
convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }
  // ...
  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
  if (!dense) {
    LLVM_DEBUG(DBGS() << "not a splat\n");
    return rewriter.notifyMatchFailure(op, "not a splat");
  }
  // ...
  Value result = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}
// In isTransposed(vector::TransferReadOp op), which checks whether the loaded
// matrix operand requires a transposed load:
    LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
                         "is not a 2d operand\n");
  // ...
  auto exprM = dyn_cast<AffineDimExpr>(dM);
  auto exprN = dyn_cast<AffineDimExpr>(dN);

  if (!exprM || !exprN) {
    LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
                         "expressions, then transpose cannot be determined.\n");
    return failure();
  }

  return exprM.getPosition() > exprN.getPosition();
static LogicalResult
creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  Location loc = op->getLoc();
  // ...
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }
  // ...
  if (failed(transpose)) {
    LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
    return rewriter.notifyMatchFailure(
        op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
  }

  FailureOr<nvgpu::LdMatrixParams> params =
      nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
  if (failed(params)) {
    LLVM_DEBUG(
        DBGS() << "failed to convert vector.transfer_read to ldmatrix. "
               << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
    return rewriter.notifyMatchFailure(
        op, "failed to convert vector.transfer_read to ldmatrix; this op "
            "likely should not be converted to a nvgpu.ldmatrix call.");
  }

  // Adjust the load offset.
  auto laneId = rewriter.create<gpu::LaneIdOp>(loc, nullptr);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
  if (failed(offsets)) {
    LLVM_DEBUG(DBGS() << "no offsets\n");
    return rewriter.notifyMatchFailure(op, "no offsets");
  }
  // ...
  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
                                         indices);

  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}
/// Distributed-load fallback for transfer reads that cannot use nvgpu.ldmatrix.
static LogicalResult
createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    return rewriter.notifyMatchFailure(
        op, "Failed to deduce register fragment type during "
            "conversion to distributed non-ldmatrix compatible load");
  }

  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, nullptr);

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  // ...
  Value fill = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      rewriter.getZeroAttr(vectorType.getElementType()));
  Value result =
      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // If we are not transposing, vectorized loads can be used; otherwise each
  // element must be loaded individually.
  if (!isTransposeLoad) {
    if (!isa<VectorType>(loadedElType)) {
      loadedElType = VectorType::get({1}, loadedElType);
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          rewriter, op.getLoc(), *warpMatrixInfo);
      // ...
      Value logicalValueId = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
                                                 op.getSource(), newIndices);
      result = rewriter.create<vector::InsertOp>(loc, el, result, i);
    }
  } else {
    if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
      loadedElType = vecType.getElementType();
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {
        Value logicalValueId = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIndexType(),
            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            rewriter, op.getLoc(), *warpMatrixInfo);
        // ...
        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                   op.getSource(), newIndices);
        result = rewriter.create<vector::InsertOp>(
            op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
      }
    }
  }

  valueMapping[op.getResult()] = result;
  return success();
}
/// Return true if this is a shared memory memref type.
static bool isSharedMemory(MemRefType type) {
  auto addressSpace =
      dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return addressSpace &&
         addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}
/// Convert a vector.transfer_read directly to either a vector.load or an
/// nvgpu.ldmatrix operation (mma.sync path).
static LogicalResult
convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  bool isLdMatrixCompatible =
      isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When transposing the operand, ldmatrix only works with 16-bit elements, at
  // least 8 rows to read, and a 128-bit-wide transpose read.
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  // If this is not an ldmatrix-compatible case, fall back to distributed loads.
  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(rewriter, op, valueMapping);

  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
}
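// Worked example of the guard above (illustration only, not from the original
// source): a transposed f16 operand (bitWidth == 16) of shape 16x16 satisfies
// getDimSize(1) = 16 >= 8 and getDimSize(0) * 16 = 256 >= 128, so the ldmatrix
// path stays enabled; an i8 operand (bitWidth == 8) always falls back to
// createNonLdMatrixLoads, since the transposed ldmatrix form requires 16-bit
// elements.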
static LogicalResult
convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value matrix = it->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  // ...
  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, nullptr);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        rewriter, op.getLoc(), *warpMatrixInfo);
    // ... (extract fragment `i` of `matrix` into `el`)
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
    rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(cast<IntegerAttr>(attr).getInt());
}
static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
                           vector::ExtractStridedSliceOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(mmaSyncFragmentInfo))
    return failure();

  // Find the vector.transfer_read whose result vector is being sliced.
  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return failure();
  // ...
  if (failed(warpMatrixInfo))
    return failure();

  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(ldFragmentInfo))
    return failure();

  assert(
      (mmaSyncFragmentInfo->elementsPerRegister ==
       ldFragmentInfo->elementsPerRegister) &&
      "Number of elements per register should be same for load and mma.sync");

  // Create a vector.extract_strided_slice op for the thread-owned fragments.
  std::array<int64_t, 2> strides = {1, 1};
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  auto sourceVector = it->second;

  // ... (offsets and warpVectorShape are taken from the op's offsets attribute
  // and the source vector type)
  std::array<int64_t, 2> sliceOffset = {0, 0};

  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported. ";
  if (offsets[0])
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);

  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
      loc, sourceVector, sliceOffset, sliceShape, strides);

  valueMapping[op] = newOp;
  return success();
}
static LogicalResult
convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, UnitAttr(),
      UnitAttr());
  valueMapping[op.getResult()] = matmul;
  return success();
}
static LogicalResult
convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}
/// Convert a 2D splat arith.constant to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  auto splat =
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = cast<VectorType>(op.getType());
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, scalarConstant);
  valueMapping[op.getResult()] = matrix;
  return success();
}
/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
                   llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  const char *fragType = inferFragType(op);
  auto vecType = op.getResultVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, op.getSource());
  valueMapping[op.getResult()] = matrix;
  return success();
}
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                                               scf::ForOp loop,
                                               ValueRange newInitArgs) {
  // ...
  // Create a new loop before the existing one, with the extra operands.
  auto operands = llvm::to_vector<4>(loop.getInitArgs());
  llvm::append_range(operands, newInitArgs);
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands);
  rewriter.eraseBlock(newLoop.getBody());

  newLoop.getRegion().getBlocks().splice(
      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
  for (Value operand : newInitArgs)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));

  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
  LLVM_DEBUG(DBGS() << "erase: " << loop);
  rewriter.eraseOp(loop);
  return newLoop;
}
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end()) {
      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
      continue;
    }
    argMapping.push_back(std::make_pair(
        operand.index(), op.getInitArgs().size() + newOperands.size()));
    newOperands.push_back(it->second);
  }

  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }

  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
  return success();
}
static LogicalResult
convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
               llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}
/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static LogicalResult
convertElementwiseOp(RewriterBase &rewriter, Operation *op,
                     gpu::MMAElementwiseOp opType,
                     llvm::DenseMap<Value, Value> &valueMapping) {
  // ...
  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands()) {
    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
      return rewriter.notifyMatchFailure(op, "no mapping");
    matrixOperands.push_back(it->second);
  }
  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
  if (opType == gpu::MMAElementwiseOp::EXTF) {
    // The floating-point extension case has a different result element type.
    // ...
    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                         vectorType.getElementType(),
                                         resultType.getOperand());
  }

  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), resultType, matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
  return success();
}
void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
}
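// A minimal sketch (illustration only, not part of this file) of how the two
// public entry points are typically combined when driving the wmma path by
// hand. `convertFuncToWmma` and `funcOp` are hypothetical names; the includes
// are assumed to be available in-tree.
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static LogicalResult convertFuncToWmma(RewriterBase &rewriter,
                                       Operation *funcOp) {
  // First canonicalize contractions and fold transposes into transfer reads.
  RewritePatternSet patterns(funcOp->getContext());
  populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
    return failure();
  // Then map the remaining vector ops onto gpu.subgroup_mma_* operations.
  return convertVectorToMMAOps(rewriter, funcOp);
}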
LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
                                          Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;

  auto globalRes = LogicalResult::success();
  for (Operation *op : ops) {
    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
    auto res = LogicalResult::success();
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      res = convertTransferReadOp(rewriter, transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      res = convertContractOp(rewriter, contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      res = convertConstantOp(rewriter, constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      res = convertForOp(rewriter, forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      res = convertYieldOp(rewriter, yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
    }
    if (failed(res))
      globalRes = failure();
  }
  return globalRes;
}
// In convertVectorToNVVMCompatibleMMASync(rewriter, rootOp), dispatching each
// collected op to its mma.sync conversion:
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(rewriter, transferReadOp,
                                                valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(rewriter, transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
              return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
                                                valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(rewriter, contractionOp,
                                                valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              return convertForOp(rewriter, forOp, valueMapping);
            })
            .Case([&](scf::YieldOp yieldOp) {
              return convertYieldOp(rewriter, yieldOp, valueMapping);
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              return op->emitError() << "unhandled vector to mma type: " << *op;
            })
            .failed()) {
      return op->emitOpError()
             << "failed to convert op during vector-to-nvgpu conversion";
    }
struct ConvertVectorToGPUPass
    : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    IRRewriter rewriter(&getContext());
    if (useNvGpu) {
      if (failed(
              convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
        return signalPassFailure();
      return;
    }
    (void)convertVectorToMMAOps(rewriter, getOperation());
  }
};

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}
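// A minimal usage sketch (assumption, not from the original file): registering
// the pass on a module-level pipeline and opting into the nvgpu/mma.sync path.
// `runVectorToGPU` is a hypothetical driver name.
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

static LogicalResult runVectorToGPU(ModuleOp module) {
  PassManager pm(module.getContext());
  pm.addPass(createConvertVectorToGPUPass(/*useNvGpu=*/true));
  return pm.run(module);
}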