#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "vector-to-gpu"

#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
template <typename TransferOpType>
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      dims.push_back(prevIdx);
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);

  auto infer = [&](MapList m) {
  auto iteratorTypes = contract.getIteratorTypes().getValue();
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))

  const unsigned nDim = permutationMap.getNumDims();
  if (0 == nDim || permutationMap.getResults().empty())
static std::optional<int64_t>
  auto memrefType = dyn_cast<MemRefType>(type);
  if (memrefType.getRank() < 2)
  if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
  unsigned strideIndex = strides.size();
    if (auto cst = dyn_cast<AffineConstantExpr>(result)) {
      if (0 != cst.getValue())
    auto dim = dyn_cast<AffineDimExpr>(result);
    strideIndex = std::min(strideIndex, dim.getPosition());
  if (strideIndex + 1 >= strides.size())
  const int64_t stride = strides[strideIndex];
  if (stride == ShapedType::kDynamic)
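// Illustrative note (not from the original source): for a 2-D memref with the
// default identity layout, e.g. memref<16x32xf16>, getStridesAndOffset yields
// strides [32, 1], so the statically known row stride is 32; a dynamically
// strided type such as memref<?x?xf16, strided<[?, 1]>> has
// ShapedType::kDynamic as its leading stride, and no static value can be
// reported.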
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)

  if (readOp.getVectorType().getElementType().isInteger(8))
    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))

  return llvm::is_contained(permutationMap.getResults(), innerDim);

  if (writeOp.getTransferRank() == 0)
  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
  std::optional<int64_t> stride =
  if (!stride.has_value() || stride.value() == 0)
  return permutationMap.getResult(1) == innerDim;

  auto vecType = dyn_cast<VectorType>(constantOp.getType());
  if (!vecType || vecType.getRank() != 2)
  return isa<SplatElementsAttr>(constantOp.getValue());

  return broadcastOp.getResultVectorType().getRank() == 2;

template <typename ExtOpTy>
  auto transferReadOp =
      extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);

static std::optional<gpu::MMAElementwiseOp>
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::SubFOp>(op))
    return gpu::MMAElementwiseOp::SUBF;
  if (isa<arith::MaximumFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinimumFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  if (isa<arith::AddIOp>(op))
    return gpu::MMAElementwiseOp::ADDI;
  if (isa<arith::MulIOp>(op))
    return gpu::MMAElementwiseOp::MULI;
  if (isa<arith::SubIOp>(op))
    return gpu::MMAElementwiseOp::SUBI;
  if (isa<arith::DivSIOp>(op))
    return gpu::MMAElementwiseOp::DIVS;
  if (isa<arith::DivUIOp>(op))
    return gpu::MMAElementwiseOp::DIVU;
  if (isa<arith::NegFOp>(op))
    return gpu::MMAElementwiseOp::NEGATEF;
  if (isa<arith::ExtFOp>(op))
    return gpu::MMAElementwiseOp::EXTF;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
  if (failed(warpMatrixInfo))
  if (failed(contractOp))
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getRhs().getType()));
  return (cast<VectorType>(op->getResult(0).getType()) ==
          cast<VectorType>((*contractOp).getAcc().getType()));
  if (isa<scf::ForOp, scf::YieldOp>(op))
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))

  unsigned currentIndex = 0;
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    backwardSlice.clear();
    assert(result.succeeded() && "expected a backward slice");
    slice.insert_range(backwardSlice);
    forwardSlice.clear();
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
    slice.insert_range(forwardSlice);

    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
  backwardSliceOptions.filter = hasVectorDest;
  forwardSliceOptions.filter = hasVectorSrc;
    if (!isa<vector::ContractionOp>(nestedOp) &&
    if (opToConvert.contains(nestedOp))
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          LDBG() << "cannot convert op: " << *op;
    opToConvert.insert_range(dependentOps);
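// Descriptive note (added, based on the fragments above): starting from each
// vector.contract, the backward slice collects the ops producing its vector
// operands and the forward slice the ops consuming its result; if any op in
// that dependency chain is unsupported (see the "cannot convert op" debug
// message), the chain appears to be skipped as a whole, otherwise all of it
// is added to opToConvert.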
struct PrepareContractToGPUMMA

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [&](MapList m) {
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {

                                        op.getIteratorTypes());
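// Illustrative example (not in the original): a contraction whose indexing
// maps are {(m, k), (n, k), (m, n)}, i.e. a matmul with a column-major RHS,
// gets its RHS wrapped in a vector.transpose so that the rewritten
// vector.contract uses the canonical row-major maps {(m, k), (k, n), (m, n)}
// expected by the MMA lowering.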
struct CombineTransferReadOpTranspose final

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    Value source = op.getVector();
    Type resultType = op.getType();

        VectorType::get(cast<VectorType>(resultType).getShape(),
                        cast<VectorType>(source.getType()).getElementType());

    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
    if (transferReadOp.getTransferRank() == 0)
    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())

    AffineMap permutationMap =
        permutationMap.compose(transferReadOp.getPermutationMap());

    auto loc = op.getLoc();
    Value result = vector::TransferReadOp::create(
                       rewriter, loc, resultType, transferReadOp.getBase(),
                       transferReadOp.getIndices(), AffineMapAttr::get(newMap),
                       transferReadOp.getPadding(), transferReadOp.getMask(),
                       transferReadOp.getInBoundsAttr())
    if (isa<arith::ExtSIOp>(extOp))
      result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result)
    else if (isa<arith::ExtUIOp>(extOp))
      result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result)
      result = arith::ExtFOp::create(rewriter, loc, op.getType(), result)
  auto contract = dyn_cast<vector::ContractionOp>(users);

  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
         "expected convertible operation");

  std::optional<int64_t> stride =
  if (!stride.has_value()) {
    LDBG() << "no stride";

  Value mappingResult = op.getResult();
  auto elType = op.getVectorType().getElementType();
  if (op->hasOneUse()) {
    auto *user = *op->user_begin();
    if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
      elType = IntegerType::get(
          op.getContext(), cast<IntegerType>(elType).getWidth(),
          isa<arith::ExtSIOp>(user) ? IntegerType::Signed
                                    : IntegerType::Unsigned);
      mappingResult = user->getResult(0);

  Value load = gpu::SubgroupMmaLoadMatrixOp::create(
      rewriter, op.getLoc(), type, op.getBase(), op.getIndices(),
      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
  valueMapping[mappingResult] = load;

  LDBG() << "transfer read to: " << load;
  std::optional<int64_t> stride =
  if (!stride.has_value()) {
    LDBG() << "no stride";

  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end()) {
    LDBG() << "no mapping";

  Value matrix = it->second;
  auto store = gpu::SubgroupMmaStoreMatrixOp::create(
      rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(),

  LDBG() << "transfer write to: " << store;
  LDBG() << "erase: " << op;
                               regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = dyn_cast<VectorType>(elType))
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
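// Illustrative example (not in the original): with regInfo describing 2
// registers per fragment, 2 elements per register, and an f16 register type,
// the fragment shape is {2, 2} and this helper returns vector<2x2xf16>.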
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
  if (failed(warpMatrixInfo)) {
    LDBG() << "no warpMatrixInfo";

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LDBG() << "not mma sync reg info";

  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
    LDBG() << "not a splat";

      rewriter, op.getLoc(), vectorType,
  valueMapping[op.getResult()] = result;

    LDBG() << "Failed because the result of `vector.transfer_read` "
              "is not a 2d operand";

  auto exprM = dyn_cast<AffineDimExpr>(dM);
  auto exprN = dyn_cast<AffineDimExpr>(dN);
  if (!exprM || !exprN) {
    LDBG() << "Failed because expressions are not affine dim "
              "expressions, then transpose cannot be determined.";

  return exprM.getPosition() > exprN.getPosition();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
  if (failed(warpMatrixInfo)) {
    LDBG() << "no warpMatrixInfo";

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LDBG() << "not mma sync reg info";

  if (failed(transpose)) {
    LDBG() << "failed to determine the transpose";
        op, "Op should likely not be converted to a nvgpu.ldmatrix call.");

  FailureOr<nvgpu::LdMatrixParams> params =
      nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
  if (failed(params)) {
    LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
           << "Op should likely not be converted to a nvgpu.ldmatrix call.";
        op, "failed to convert vector.transfer_read to ldmatrix; this op "
            "likely should not be converted to a nvgpu.ldmatrix call.");

  auto laneId = gpu::LaneIdOp::create(rewriter, loc, nullptr);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
  if (failed(offsets)) {
    LDBG() << "no offsets";

  nvgpu::LdMatrixOp newOp =
      nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(),
                                indices, *transpose, params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
  if (failed(warpMatrixInfo))
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
        op, "Failed to deduce register fragment type during "
            "conversion to distributed non-ldmatrix compatible load");

  Value laneId = gpu::LaneIdOp::create(rewriter, loc, nullptr);

  Type loadedElType = regInfo->registerLLVMType;

  Value fill = arith::ConstantOp::create(
      rewriter, op.getLoc(), vectorType.getElementType(),
      rewriter.getZeroAttr(vectorType.getElementType()));
      vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  if (!isTransposeLoad) {
    if (!isa<VectorType>(loadedElType)) {
      loadedElType = VectorType::get({1}, loadedElType);

    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          rewriter, op.getLoc(), *warpMatrixInfo);

      Value logicalValueId = arith::ConstantOp::create(
          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = vector::LoadOp::create(rewriter, loc, loadedElType,
                                        op.getBase(), newIndices);
      result = vector::InsertOp::create(rewriter, loc, el, result, i);

    if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
      loadedElType = vecType.getElementType();
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];

        Value logicalValueId = arith::ConstantOp::create(
            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            rewriter, op.getLoc(), *warpMatrixInfo);
            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType,
                                          op.getBase(), newIndices);
        result = vector::InsertOp::create(rewriter, op.getLoc(), el, result,

  valueMapping[op.getResult()] = result;
      dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return addressSpace &&
         addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
  if (failed(warpMatrixInfo))

  bool isLdMatrixCompatible =
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)

  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end())
  Value matrix = it->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
  if (failed(warpMatrixInfo))
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);

  Value laneId = gpu::LaneIdOp::create(rewriter, loc, nullptr);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = arith::ConstantOp::create(
        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        rewriter, op.getLoc(), *warpMatrixInfo);
        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
    vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);

  LDBG() << "erase: " << op;
  for (auto attr : arrayAttr)
    results.push_back(cast<IntegerAttr>(attr).getInt());

                           vector::ExtractStridedSliceOp op,

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
  if (failed(warpMatrixInfo))
  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(mmaSyncFragmentInfo))

  auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
  if (failed(warpMatrixInfo))
  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(ldFragmentInfo))

         (mmaSyncFragmentInfo->elementsPerRegister ==
          ldFragmentInfo->elementsPerRegister) &&
         "Number of elements per register should be same for load and mma.sync");

  std::array<int64_t, 2> strides = {1,
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
  auto sourceVector = it->second;

  std::array<int64_t, 2> sliceOffset = {0, 0};

  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported. ";
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);

  Value newOp = vector::ExtractStridedSliceOp::create(
      rewriter, loc, sourceVector, sliceOffset, sliceShape, strides);

  valueMapping[op] = newOp;
  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(),
  valueMapping[op.getResult()] = matmul;

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
  Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
  valueMapping[op.getResult()] = matmul;
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat);
  auto vecType = cast<VectorType>(op.getType());
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
                                                         type, scalarConstant);
  valueMapping[op.getResult()] = matrix;
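// Illustrative example (not in the original): a splat accumulator such as
//   %acc = arith.constant dense<0.0> : vector<16x16xf16>
// is rewritten into a scalar constant fed to gpu.subgroup_mma_constant_matrix,
// producing a value of type !gpu.mma_matrix<16x16xf16, "COp"> (the fragment
// string depends on how the constant is used in the contraction).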
  auto vecType = op.getResultVectorType();
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
                                                         type, op.getSource());
  valueMapping[op.getResult()] = matrix;

  auto operands = llvm::to_vector<4>(loop.getInitArgs());
  llvm::append_range(operands, newInitArgs);
  scf::ForOp newLoop =
      scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(),
                         loop.getUpperBound(), loop.getStep(), operands);

  newLoop.getRegion().getBlocks().splice(
      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
  for (Value operand : newInitArgs)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))

  LDBG() << "newLoop now: " << newLoop;
  LDBG() << "stripped scf.for: " << loop;
  LDBG() << "erase: " << loop;

  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end()) {
      LDBG() << "no value mapping for: " << operand.value();
    argMapping.push_back(std::make_pair(
        operand.index(), op.getInitArgs().size() + newOperands.size()));
    newOperands.push_back(it->second);

  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
                           newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());

  LDBG() << "scf.for to: " << newForOp;

  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
    yieldOperands.push_back(it->second);
  scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);

  LDBG() << "erase: " << op;
                                        gpu::MMAElementwiseOp opType,

    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
    matrixOperands.push_back(it->second);
  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
  if (opType == gpu::MMAElementwiseOp::EXTF) {
                                          vectorType.getElementType(),
                                          resultType.getOperand());
  Value newOp = gpu::SubgroupMmaElementwiseOp::create(
      rewriter, op->getLoc(), resultType, matrixOperands, opType);

  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
  auto globalRes = LogicalResult::success();
    LDBG() << "Process op: " << *op;
    auto res = LogicalResult::success();
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      globalRes = failure();

        .Case([&](vector::TransferReadOp transferReadOp) {
        .Case([&](vector::TransferWriteOp transferWriteOp) {
        .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
        .Case([&](vector::ContractionOp contractionOp) {
        .Case([&](scf::ForOp forOp) {
        .Case([&](scf::YieldOp yieldOp) {
        .Case([&](arith::ConstantOp constOp) {
          return op->emitError() << "unhandled vector to mma type: " << *op;
             << "failed to convert op during vector-to-nvgpu conversion";
struct ConvertVectorToGPUPass

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);

  void runOnOperation() override {
      return signalPassFailure();
      return signalPassFailure();

  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
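// Usage sketch (not part of this file, assumptions noted inline): one way to
// run this conversion is to nest the pass created by
// createConvertVectorToGPUPass on each function in a module. The surrounding
// plumbing (PassManager, func dialect headers) is standard MLIR and is shown
// here only for illustration.
#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

static mlir::LogicalResult lowerVectorToGPU(mlir::ModuleOp module,
                                            bool useNvGpu) {
  mlir::PassManager pm(module.getContext());
  // With useNvGpu = false the pass targets gpu.subgroup_mma_* (wmma) ops;
  // with useNvGpu = true it targets nvgpu.mma.sync / nvgpu.ldmatrix.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createConvertVectorToGPUPass(useNvGpu));
  return pm.run(module);
}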
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static const char * inferFragType(Operation *op)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info)
Returns the vector type which represents a matrix fragment.
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires a transpose.
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool isFirstResultLastMapDimension(AffineMap permutationMap)
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type, AffineMap permutationMap)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)