32#include "llvm/ADT/STLExtras.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/DebugLog.h"
36#define DEBUG_TYPE "vector-to-gpu"
39#define GEN_PASS_DEF_CONVERTVECTORTOGPU
40#include "mlir/Conversion/Passes.h.inc"
51template <
typename TransferOpType>
55 indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
57 unsigned offsetsIdx = 0;
58 for (
auto expr : xferOp.getPermutationMap().getResults()) {
59 if (
auto dim = dyn_cast<AffineDimExpr>(expr)) {
62 dims.push_back(prevIdx);
65 rewriter, loc, d0 + offsetMap.
getResult(offsetsIdx++), dims);
75 auto infer = [&](MapList m) {
80 auto iteratorTypes =
contract.getIteratorTypes().getValue();
89 contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
92 contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
106 const unsigned nDim = permutationMap.
getNumDims();
107 if (0 == nDim || permutationMap.
getResults().empty())
126static std::optional<int64_t>
128 auto memrefType = dyn_cast<MemRefType>(type);
132 if (memrefType.getRank() < 2)
136 if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
143 unsigned strideIndex = strides.size();
146 if (
auto cst = dyn_cast<AffineConstantExpr>(
result)) {
148 if (0 != cst.getValue())
153 auto dim = dyn_cast<AffineDimExpr>(
result);
157 strideIndex = std::min(strideIndex, dim.getPosition());
163 if (strideIndex + 1 >= strides.size())
166 const int64_t stride = strides[strideIndex];
167 if (stride == ShapedType::kDynamic)
174 if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
175 readOp.getVectorType().getRank() != 2)
183 if (readOp.getVectorType().getElementType().isInteger(8))
184 if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
185 !isa<arith::ExtUIOp>(*readOp->user_begin())))
190 return llvm::is_contained(permutationMap.
getResults(), innerDim);
197 if (writeOp.getTransferRank() == 0)
200 if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
201 writeOp.getVectorType().getRank() != 2)
205 std::optional<int64_t> stride =
208 if (!stride.has_value() || stride.value() == 0)
214 return permutationMap.
getResult(1) == innerDim;
220 auto vecType = dyn_cast<VectorType>(constantOp.getType());
221 if (!vecType || vecType.getRank() != 2)
223 return isa<SplatElementsAttr>(constantOp.getValue());
228 return broadcastOp.getResultVectorType().getRank() == 2;
232template <
typename ExtOpTy>
234 auto transferReadOp =
235 extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
238 return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
246static std::optional<gpu::MMAElementwiseOp>
248 using MMAEwO = gpu::MMAElementwiseOp;
250 .Case([](arith::AddFOp) {
return MMAEwO::ADDF; })
251 .Case([](arith::AddIOp) {
return MMAEwO::ADDI; })
252 .Case([](arith::DivFOp) {
return MMAEwO::DIVF; })
253 .Case([](arith::DivSIOp) {
return MMAEwO::DIVS; })
254 .Case([](arith::DivUIOp) {
return MMAEwO::DIVU; })
255 .Case([](arith::ExtFOp) {
return MMAEwO::EXTF; })
256 .Case([](arith::MaximumFOp) {
return MMAEwO::MAXF; })
257 .Case([](arith::MinimumFOp) {
return MMAEwO::MINF; })
258 .Case([](arith::MulFOp) {
return MMAEwO::MULF; })
259 .Case([](arith::MulIOp) {
return MMAEwO::MULI; })
260 .Case([](arith::NegFOp) {
return MMAEwO::NEGATEF; })
261 .Case([](arith::SubFOp) {
return MMAEwO::SUBF; })
262 .Case([](arith::SubIOp) {
return MMAEwO::SUBI; })
263 .Case([](arith::TruncFOp) {
return MMAEwO::TRUNCF; })
264 .Default(std::nullopt);
277 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
279 if (failed(warpMatrixInfo))
283 if (failed(contractOp))
290 return (cast<VectorType>(op->getResult(0).getType()) ==
291 cast<VectorType>((*contractOp).getRhs().getType()));
293 return (cast<VectorType>(op->getResult(0).getType()) ==
294 cast<VectorType>((*contractOp).getAcc().getType()));
300 if (isa<scf::ForOp, scf::YieldOp>(op))
302 if (
auto transferRead = dyn_cast<vector::TransferReadOp>(op))
305 if (
auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
308 if (
auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
311 if (
auto contract = dyn_cast<vector::ContractionOp>(op))
313 if (
auto constant = dyn_cast<arith::ConstantOp>(op))
315 if (
auto broadcast = dyn_cast<vector::BroadcastOp>(op))
317 if (
auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
319 if (
auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
321 if (
auto fpExtend = dyn_cast<arith::ExtFOp>(op))
323 if (
auto fpTrunc = dyn_cast<arith::TruncFOp>(op))
337 unsigned currentIndex = 0;
340 while (currentIndex != slice.size()) {
341 auto *currentOp = (slice)[currentIndex];
343 backwardSlice.clear();
346 assert(
result.succeeded() &&
"expected a backward slice");
348 slice.insert_range(backwardSlice);
351 forwardSlice.clear();
356 if (
auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
357 for (
Value forOpResult : forOp.getResults())
364 slice.insert_range(forwardSlice);
375 return llvm::any_of(op->
getResultTypes(), llvm::IsaPred<VectorType>);
378 backwardSliceOptions.
filter = hasVectorDest;
384 forwardSliceOptions.
filter = hasVectorSrc;
388 if (!isa<vector::ContractionOp>(nestedOp) &&
391 if (opToConvert.contains(nestedOp))
398 if (llvm::any_of(dependentOps, [useNvGpu](
Operation *op) {
400 LDBG() <<
"cannot convert op: " << *op;
407 opToConvert.insert_range(dependentOps);
416struct PrepareContractToGPUMMA
420 LogicalResult matchAndRewrite(vector::ContractionOp op,
421 PatternRewriter &rewriter)
const override {
422 Location loc = op.getLoc();
423 Value
lhs = op.getLhs(),
rhs = op.getRhs(), res = op.getAcc();
426 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
427 auto infer = [&](MapList m) {
432 static constexpr std::array<int64_t, 2> perm = {1, 0};
433 auto iteratorTypes = op.getIteratorTypes().getValue();
434 SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
443 if (maps == infer({{m, k}, {k, n}, {m, n}}))
445 if (maps == infer({{m, k}, {n, k}, {m, n}})) {
446 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
447 }
else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
448 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
449 }
else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
450 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
451 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
452 }
else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
454 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
455 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
456 }
else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
458 rhs = vector::TransposeOp::create(rewriter, loc,
rhs, perm);
459 }
else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
461 lhs = vector::TransposeOp::create(rewriter, loc,
lhs, perm);
462 }
else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
471 op.getIteratorTypes());
480struct CombineTransferReadOpTranspose final
484 LogicalResult matchAndRewrite(vector::TransposeOp op,
485 PatternRewriter &rewriter)
const override {
487 Value source = op.getVector();
488 Type resultType = op.getType();
495 VectorType::get(cast<VectorType>(resultType).
getShape(),
496 cast<VectorType>(source.
getType()).getElementType());
499 auto transferReadOp = source.
getDefiningOp<vector::TransferReadOp>();
504 if (transferReadOp.getTransferRank() == 0)
507 if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
510 AffineMap permutationMap =
513 permutationMap.
compose(transferReadOp.getPermutationMap());
515 auto loc = op.getLoc();
516 Value
result = vector::TransferReadOp::create(
517 rewriter, loc, resultType, transferReadOp.getBase(),
518 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
519 transferReadOp.getPadding(), transferReadOp.getMask(),
520 transferReadOp.getInBoundsAttr())
525 if (isa<arith::ExtSIOp>(extOp))
526 result = arith::ExtSIOp::create(rewriter, loc, op.getType(),
result)
528 else if (isa<arith::ExtUIOp>(extOp))
529 result = arith::ExtUIOp::create(rewriter, loc, op.getType(),
result)
532 result = arith::ExtFOp::create(rewriter, loc, op.getType(),
result)
557 auto contract = dyn_cast<vector::ContractionOp>(users);
575 assert(op.getTransferRank() > 0 &&
"unexpected 0-d transfer");
577 "expected convertible operation");
580 std::optional<int64_t> stride =
582 if (!stride.has_value()) {
583 LDBG() <<
"no stride";
593 Value mappingResult = op.getResult();
594 auto elType = op.getVectorType().getElementType();
596 if (op->hasOneUse()) {
597 auto *user = *op->user_begin();
599 if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
600 elType = IntegerType::get(
601 op.getContext(), cast<IntegerType>(elType).getWidth(),
602 isa<arith::ExtSIOp>(user) ? IntegerType::Signed
603 : IntegerType::Unsigned);
604 mappingResult = user->getResult(0);
609 Value load = gpu::SubgroupMmaLoadMatrixOp::create(
610 rewriter, op.getLoc(), type, op.getBase(), op.getIndices(),
612 isTranspose ? rewriter.
getUnitAttr() : UnitAttr());
613 valueMapping[mappingResult] =
load;
615 LDBG() <<
"transfer read to: " <<
load;
626 std::optional<int64_t> stride =
628 if (!stride.has_value()) {
629 LDBG() <<
"no stride";
633 auto it = valueMapping.find(op.getVector());
634 if (it == valueMapping.end()) {
635 LDBG() <<
"no mapping";
639 Value matrix = it->second;
640 auto store = gpu::SubgroupMmaStoreMatrixOp::create(
641 rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(),
645 LDBG() <<
"transfer write to: " << store;
647 LDBG() <<
"erase: " << op;
656 regInfo.elementsPerRegister};
657 Type elType = regInfo.registerLLVMType;
658 if (
auto vecType = dyn_cast<VectorType>(elType))
659 elType = vecType.getElementType();
660 return VectorType::get(
shape, elType);
670 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
672 if (failed(warpMatrixInfo)) {
673 LDBG() <<
"no warpMatrixInfo";
677 FailureOr<nvgpu::FragmentElementInfo> regInfo =
678 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
679 if (failed(regInfo)) {
680 LDBG() <<
"not mma sync reg info";
685 auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
687 LDBG() <<
"not a splat";
692 rewriter, op.getLoc(), vectorType,
694 valueMapping[op.getResult()] =
result;
709 LDBG() <<
"Failed because the result of `vector.transfer_read` "
710 "is not a 2d operand";
719 auto exprM = dyn_cast<AffineDimExpr>(dM);
720 auto exprN = dyn_cast<AffineDimExpr>(dN);
722 if (!exprM || !exprN) {
723 LDBG() <<
"Failed because expressions are not affine dim "
724 "expressions, then transpose cannot be determined.";
728 return exprM.getPosition() > exprN.getPosition();
738 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
740 if (failed(warpMatrixInfo)) {
741 LDBG() <<
"no warpMatrixInfo";
745 FailureOr<nvgpu::FragmentElementInfo> regInfo =
746 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
747 if (failed(regInfo)) {
748 LDBG() <<
"not mma sync reg info";
753 if (failed(transpose)) {
754 LDBG() <<
"failed to determine the transpose";
756 op,
"Op should likely not be converted to a nvgpu.ldmatrix call.");
759 FailureOr<nvgpu::LdMatrixParams> params =
760 nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
762 if (failed(params)) {
763 LDBG() <<
"failed to convert vector.transfer_read to ldmatrix. "
764 <<
"Op should likely not be converted to a nvgpu.ldmatrix call.";
766 op,
"failed to convert vector.transfer_read to ldmatrix; this op "
767 "likely should not be converted to a nvgpu.ldmatrix call.");
771 auto laneId = gpu::LaneIdOp::create(rewriter, loc,
nullptr);
772 FailureOr<AffineMap> offsets =
773 nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
774 if (failed(offsets)) {
775 LDBG() <<
"no offsets";
785 nvgpu::LdMatrixOp newOp =
786 nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(),
787 indices, *transpose, params->numTiles);
788 valueMapping[op] = newOp->getResult(0);
799 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
801 if (failed(warpMatrixInfo))
803 FailureOr<nvgpu::FragmentElementInfo> regInfo =
804 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
805 if (failed(regInfo)) {
807 op,
"Failed to deduce register fragment type during "
808 "conversion to distributed non-ldmatrix compatible load");
811 Value laneId = gpu::LaneIdOp::create(rewriter, loc,
nullptr);
814 Type loadedElType = regInfo->registerLLVMType;
817 Value fill = arith::ConstantOp::create(
818 rewriter, op.getLoc(), vectorType.getElementType(),
819 rewriter.
getZeroAttr(vectorType.getElementType()));
821 vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill);
823 bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
827 if (!isTransposeLoad) {
828 if (!isa<VectorType>(loadedElType)) {
829 loadedElType = VectorType::get({1}, loadedElType);
832 for (
int i = 0; i < vectorType.getShape()[0]; i++) {
833 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
834 rewriter, op.getLoc(), *warpMatrixInfo);
838 Value logicalValueId = arith::ConstantOp::create(
840 rewriter.
getIndexAttr(i * regInfo->elementsPerRegister));
843 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
845 Value el = vector::LoadOp::create(rewriter, loc, loadedElType,
846 op.getBase(), newIndices);
847 result = vector::InsertOp::create(rewriter, loc, el,
result, i);
850 if (
auto vecType = dyn_cast<VectorType>(loadedElType)) {
851 loadedElType = vecType.getElementType();
853 for (
int i = 0; i < vectorType.getShape()[0]; i++) {
854 for (
unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
857 Value logicalValueId = arith::ConstantOp::create(
859 rewriter.
getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
860 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
861 rewriter, op.getLoc(), *warpMatrixInfo);
867 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
868 Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType,
869 op.getBase(), newIndices);
870 result = vector::InsertOp::create(rewriter, op.getLoc(), el,
result,
876 valueMapping[op.getResult()] =
result;
883 dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
884 return addressSpace &&
885 addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
897 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
899 if (failed(warpMatrixInfo))
902 bool isLdMatrixCompatible =
904 nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
906 VectorType vecTy = op.getVectorType();
907 int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
912 if (!op.getPermutationMap().isMinorIdentity() &&
913 (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
914 vecTy.getDimSize(0) * bitWidth < 128))
915 isLdMatrixCompatible =
false;
917 if (!isLdMatrixCompatible)
930 auto it = valueMapping.find(op.getVector());
931 if (it == valueMapping.end())
933 Value matrix = it->second;
935 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
937 if (failed(warpMatrixInfo))
939 FailureOr<nvgpu::FragmentElementInfo> regInfo =
940 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
945 Value laneId = gpu::LaneIdOp::create(rewriter, loc,
nullptr);
947 for (
unsigned i = 0; i < vectorType.getShape()[0]; i++) {
948 Value logicalValueId = arith::ConstantOp::create(
950 rewriter.
getIndexAttr(i * regInfo->elementsPerRegister));
951 FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
952 rewriter, op.getLoc(), *warpMatrixInfo);
960 rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
961 vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
964 LDBG() <<
"erase: " << op;
971 for (
auto attr : arrayAttr)
972 results.push_back(cast<IntegerAttr>(attr).getInt());
977 vector::ExtractStridedSliceOp op,
984 FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
986 if (failed(warpMatrixInfo))
989 FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
990 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
991 if (failed(mmaSyncFragmentInfo))
995 auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
1000 if (failed(warpMatrixInfo))
1003 FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
1004 nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
1005 if (failed(ldFragmentInfo))
1009 (mmaSyncFragmentInfo->elementsPerRegister ==
1010 ldFragmentInfo->elementsPerRegister) &&
1011 "Number of elements per register should be same for load and mma.sync");
1014 std::array<int64_t, 2> strides = {1,
1016 std::array<int64_t, 2> sliceShape = {
1017 mmaSyncFragmentInfo->numRegistersPerFragment,
1018 mmaSyncFragmentInfo->elementsPerRegister};
1019 auto it = valueMapping.find(transferReadOp);
1020 if (it == valueMapping.end())
1022 auto sourceVector = it->second;
1035 std::array<int64_t, 2> sliceOffset = {0, 0};
1037 if (offsets[0] && offsets[1])
1038 return op->emitError() <<
"Slicing fragments in 2D is not supported. ";
1040 sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1041 else if (offsets[1])
1042 sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1044 Value newOp = vector::ExtractStridedSliceOp::create(
1045 rewriter, loc, sourceVector, sliceOffset, sliceShape, strides);
1047 valueMapping[op] = newOp;
1057 auto itA = valueMapping.find(op.getLhs());
1058 auto itB = valueMapping.find(op.getRhs());
1059 auto itC = valueMapping.find(op.getAcc());
1060 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1061 itC == valueMapping.end())
1063 Value opA = itA->second, opB = itB->second, opC = itC->second;
1064 Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(),
1068 valueMapping[op.getResult()] = matmul;
1078 auto itA = valueMapping.find(op.getLhs());
1079 auto itB = valueMapping.find(op.getRhs());
1080 auto itC = valueMapping.find(op.getAcc());
1081 if (itA == valueMapping.end() || itB == valueMapping.end() ||
1082 itC == valueMapping.end())
1084 Value opA = itA->second, opB = itB->second, opC = itC->second;
1085 int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1086 int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1087 int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1088 Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
1090 valueMapping[op.getResult()] = matmul;
1104 cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1105 auto scalarConstant =
1106 arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat);
1108 auto vecType = cast<VectorType>(op.getType());
1110 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1111 auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1112 type, scalarConstant);
1113 valueMapping[op.getResult()] = matrix;
1127 auto vecType = op.getResultVectorType();
1129 vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1130 auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1131 type, op.getSource());
1132 valueMapping[op.getResult()] = matrix;
1146 auto operands = llvm::to_vector<4>(loop.getInitArgs());
1147 llvm::append_range(operands, newInitArgs);
1148 scf::ForOp newLoop =
1149 scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(),
1150 loop.getUpperBound(), loop.getStep(), operands);
1153 newLoop.getRegion().getBlocks().splice(
1154 newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1155 for (
Value operand : newInitArgs)
1156 newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1158 for (
auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1159 loop.getNumResults())))
1162 LDBG() <<
"newLoop now: " << newLoop;
1163 LDBG() <<
"stripped scf.for: " << loop;
1164 LDBG() <<
"erase: " << loop;
1177 for (
const auto &operand : llvm::enumerate(op.getInitArgs())) {
1178 auto it = valueMapping.find(operand.value());
1179 if (it == valueMapping.end()) {
1180 LDBG() <<
"no value mapping for: " << operand.value();
1183 argMapping.push_back(std::make_pair(
1184 operand.index(), op.getInitArgs().size() + newOperands.size()));
1185 newOperands.push_back(it->second);
1189 Block &loopBody = *newForOp.getBody();
1190 for (
auto mapping : argMapping) {
1191 valueMapping[newForOp.getResult(mapping.first)] =
1192 newForOp.getResult(mapping.second);
1194 newForOp.getNumInductionVars())] =
1195 loopBody.
getArgument(mapping.second + newForOp.getNumInductionVars());
1198 LDBG() <<
"scf.for to: " << newForOp;
1208 auto loop = cast<scf::ForOp>(op->getParentOp());
1209 auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1210 for (
const auto &operand : llvm::enumerate(op.getOperands())) {
1211 auto it = valueMapping.find(operand.value());
1212 if (it == valueMapping.end())
1216 yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1217 yieldOperands.push_back(it->second);
1219 scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
1221 LDBG() <<
"erase: " << op;
1229 gpu::MMAElementwiseOp opType,
1236 auto it = valueMapping.find(operand);
1237 if (it == valueMapping.end())
1239 matrixOperands.push_back(it->second);
1241 auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].
getType());
1242 if (opType == gpu::MMAElementwiseOp::EXTF ||
1243 opType == gpu::MMAElementwiseOp::TRUNCF) {
1247 vectorType.getElementType(),
1248 resultType.getOperand());
1251 Value newOp = gpu::SubgroupMmaElementwiseOp::create(
1252 rewriter, op->
getLoc(), resultType, matrixOperands, opType);
1260 patterns.
add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1265 patterns.
add<CombineTransferReadOpTranspose>(patterns.
getContext());
1273 auto globalRes = LogicalResult::success();
1275 LDBG() <<
"Process op: " << *op;
1277 auto res = LogicalResult::success();
1278 if (
auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1280 }
else if (
auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1282 }
else if (
auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1284 }
else if (
auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1286 }
else if (
auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1288 }
else if (
auto forOp = dyn_cast<scf::ForOp>(op)) {
1290 }
else if (
auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1296 globalRes = failure();
1307 .Case([&](vector::TransferReadOp transferReadOp) {
1311 .Case([&](vector::TransferWriteOp transferWriteOp) {
1315 .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1319 .Case([&](vector::ContractionOp contractionOp) {
1323 .Case([&](scf::ForOp forOp) {
1326 .Case([&](scf::YieldOp yieldOp) {
1329 .Case([&](arith::ConstantOp constOp) {
1333 return op->
emitError() <<
"unhandled vector to mma type: " << *op;
1337 <<
"failed to convert op during vector-to-nvgpu conversion";
1345struct ConvertVectorToGPUPass
1346 :
public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1348 explicit ConvertVectorToGPUPass(
bool useNvGpu_) {
1349 useNvGpu.setValue(useNvGpu_);
1352 void runOnOperation()
override {
1356 return signalPassFailure();
1362 return signalPassFailure();
1372 return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool fpTruncSupportsMMAMatrixType(arith::TruncFOp extOp)
static const char * inferFragType(Operation *op)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info)
Returns the vector type which represents a matrix fragment.
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to an MMA constant matrix.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires a transpose.
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or an nvgpu.ldmatrix operation.
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool isFirstResultLastMapDimension(AffineMap permutationMap)
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type, AffineMap permutationMap)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
MLIRContext * getContext() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
IntegerAttr getIndexAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
bool hasOneUse()
Returns true if this operation has exactly one use.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
user_iterator user_begin()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
bool canLowerToWarpMatrixOperation(vector::TransferWriteOp op)
Returns true if the vector.transfer_write op can be lowered to a warp-level matrix store operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contract a, b, c with row-major matmul semantics to a contraction with M...
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs) - 1].
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
llvm::SetVector< T, Vector, Set, N > SetVector
SliceOptions ForwardSliceOptions
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm....
llvm::TypeSwitch< T, ResultT > TypeSwitch
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This trait tags element-wise ops on vectors or tensors.