#include <type_traits>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "vector-to-gpu"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
// For a vector TransferOpType xferOp, an empty indices vector, and an
// AffineMap representing offsets to apply to the indices, computes the
// distributed transfer indices.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = rewriter.getAffineDimExpr(dims.size() - 1);
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
      continue;
    }
  }
}
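// Illustrative note (a sketch, not from the original file): with
// offsetMap = affine_map<(d0) -> (d0 floordiv 4, d0 mod 4)> and
// dimValues = {%laneId}, a read at [%i, %j] is rewritten to
// [%i + %laneId floordiv 4, %j + %laneId mod 4]; each result of the offset
// map is added onto the index addressed by the permutation map.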
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, contract.getContext());
  };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
    return false;
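// For the WMMA path the contraction must be the canonical row-major matmul
// C[m, n] += A[m, k] * B[k, n]; the mma.sync path instead expects B in
// transposed (column-major) form, i.e. the maps {{m, k}, {n, k}, {m, n}}.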
  return permutationMap ==
             AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
  auto memrefType = dyn_cast<MemRefType>(type);
  if (memrefType.getRank() < 2)
    return std::nullopt;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamic)
    return std::nullopt;
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (readOp.getVectorType().getElementType().isInteger(8))
    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))
      return false;
  auto broadcastInnerDim =
      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
  if (writeOp.getTransferRank() == 0)
    return false;
  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  auto vecType = dyn_cast<VectorType>(constantOp.getType());
  if (!vecType || vecType.getRank() != 2)
    return false;
  return isa<SplatElementsAttr>(constantOp.getValue());
  return broadcastOp.getResultVectorType().getRank() == 2;
// Return true if this integer extend op can be folded into a contract op.
template <typename ExtOpTy>
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
    return false;
  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
}
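// Illustrative fold (a sketch, not from the original file): an i8 read that
// feeds a contract through a sign extension,
//   %r = vector.transfer_read ... : memref<16x16xi8>, vector<16x16xi8>
//   %e = arith.extsi %r : vector<16x16xi8> to vector<16x16xi32>
//   vector.contract ... %e ...
// can keep the narrow load and let the MMA fragment carry the extension.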
static std::optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::SubFOp>(op))
    return gpu::MMAElementwiseOp::SUBF;
  if (isa<arith::MaximumFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinimumFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  if (isa<arith::AddIOp>(op))
    return gpu::MMAElementwiseOp::ADDI;
  if (isa<arith::MulIOp>(op))
    return gpu::MMAElementwiseOp::MULI;
  if (isa<arith::SubIOp>(op))
    return gpu::MMAElementwiseOp::SUBI;
  if (isa<arith::DivSIOp>(op))
    return gpu::MMAElementwiseOp::DIVS;
  if (isa<arith::DivUIOp>(op))
    return gpu::MMAElementwiseOp::DIVU;
  if (isa<arith::NegFOp>(op))
    return gpu::MMAElementwiseOp::NEGATEF;
  if (isa<arith::ExtFOp>(op))
    return gpu::MMAElementwiseOp::EXTF;
  return std::nullopt;
}
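// Example mapping (following the table above): an arith.maximumf on 2-D
// vectors is later rebuilt by convertElementwiseOp as a
// gpu.subgroup_mma_elementwise op with opType = MAXF operating on
// !gpu.mma_matrix values.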
  if (failed(warpMatrixInfo))
    return false;
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getRhs().getType()));
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getAcc().getType()));
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
    return useNvGpu &&
           extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
    return fpExtendSupportsMMAMatrixType(fpExtend);
  return elementwiseSupportsMMAMatrixType(op);
  unsigned currentIndex = 0;
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backward slice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
    slice.insert(backwardSlice.begin(), backwardSlice.end());
    // Compute and insert the forward slice starting from currentOp. scf.for
    // is special-cased: only values flowing out of its results are followed,
    // not the whole region.
    forwardSlice.clear();
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
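// Design note: unlike a single use-def slice, this fixed-point loop keeps
// growing the set until no new ops are added, so a value that flows through an
// scf.for iter_arg still pulls in the loop-body ops it transitively depends
// on or feeds.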
    return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
  backwardSliceOptions.filter = hasVectorDest;
  forwardSliceOptions.filter = hasVectorSrc;
    if (opToConvert.contains(contract.getOperation()))
      return;
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          if (!supportsMMaMatrixType(op, useNvGpu)) {
            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
            return true;
          }
          return false;
        }))
      return;
    opToConvert.insert(dependentOps.begin(), dependentOps.end());
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
    auto infer = [&](MapList m) {
      return AffineMap::inferFromExprList(m, op.getContext());
    };
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    // This is the classical row-major matmul, nothing to do.
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
      return rewriter.notifyMatchFailure(op, "contraction already prepared");
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      return rewriter.notifyMatchFailure(op, "unexpected contraction case");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
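// Illustrative before/after (a sketch, not from the original file): a
// contraction with maps {{k, m}, {k, n}, {m, n}}, i.e. A given column-major,
// becomes
//   %lhsT = vector.transpose %lhs, [1, 0] : vector<4x8xf16> to vector<8x4xf16>
//   vector.contract ... %lhsT, %rhs, %acc ...
// so every case funnels into the canonical row-major matmul form above.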
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
    Value source = op.getVector();
    Type resultType = op.getType();
      resultType =
          VectorType::get(cast<VectorType>(resultType).getShape(),
                          cast<VectorType>(source.getType()).getElementType());
    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
    if (transferReadOp.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "unsupported 0-d transfer read");
    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());
    Value result =
        rewriter
            .create<vector::TransferReadOp>(
                loc, resultType, transferReadOp.getSource(),
                transferReadOp.getIndices(), AffineMapAttr::get(newMap),
                transferReadOp.getPadding(), transferReadOp.getMask(),
                transferReadOp.getInBoundsAttr())
            .getResult();

    // Fuse through the integer extend op if present.
    if (isa<arith::ExtSIOp>(extOp))
      result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
                   .getResult();
    else if (isa<arith::ExtUIOp>(extOp))
      result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
                   .getResult();
    else
      result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
                   .getResult();
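// Illustrative rewrite (a sketch, not from the original file):
//   %r = vector.transfer_read %mem[%i, %j], %pad : memref<?x?xf16>, vector<4x8xf16>
//   %t = vector.transpose %r, [1, 0] : vector<4x8xf16> to vector<8x4xf16>
// folds into a single transfer_read with a composed (transposed) permutation
// map producing vector<8x4xf16>, which later lowers to a transposed MMA load.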
    auto contract = dyn_cast<vector::ContractionOp>(users);
  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op) &&
         "expected convertible operation");
  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  // Handle broadcast by setting the stride to 0.
  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
    assert(cstExpr.getValue() == 0);
    stride = 0;
  }

  auto elType = op.getVectorType().getElementType();
  bool isSignedExtend = isa<arith::ExtSIOp>(user);
  if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
    elType = IntegerType::get(
        op.getContext(), cast<IntegerType>(elType).getWidth(),
        isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
    mappingResult = user->getResult(0);
  }
  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride),
      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
  valueMapping[mappingResult] = load;

  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end()) {
    LLVM_DEBUG(DBGS() << "no mapping\n");
    return rewriter.notifyMatchFailure(op, "no mapping");
  }

  Value matrix = it->second;
  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
  (void)store;

  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  if (auto vecType = dyn_cast<VectorType>(elType))
    elType = vecType.getElementType();
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
  if (!dense) {
    LLVM_DEBUG(DBGS() << "not a splat\n");
    return rewriter.notifyMatchFailure(op, "not a splat");
  }
    LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
                         "is not a 2d operand\n");
    return failure();
  }

  auto exprM = dyn_cast<AffineDimExpr>(dM);
  auto exprN = dyn_cast<AffineDimExpr>(dN);

  if (!exprM || !exprN) {
    LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
                         "expressions, so the transpose cannot be "
                         "determined.\n");
    return failure();
  }

  return exprM.getPosition() > exprN.getPosition();
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
    LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
    return rewriter.notifyMatchFailure(
        op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
    LLVM_DEBUG(DBGS() << "failed to convert vector.transfer_read to ldmatrix. "
                      << "Op should likely not be converted to a "
                         "nvgpu.ldmatrix call.\n");
    return rewriter.notifyMatchFailure(
        op, "failed to convert vector.transfer_read to ldmatrix; this op "
            "likely should not be converted to a nvgpu.ldmatrix call.");
  auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
  if (failed(offsets)) {
    LLVM_DEBUG(DBGS() << "no offsets\n");
    return rewriter.notifyMatchFailure(op, "no offsets");
  }

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
                                         indices);

  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
  valueMapping[op] = newOp->getResult(0);
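// Illustrative result (a sketch, not from the original file): each lane issues
//   %f = nvgpu.ldmatrix %shm[%row, %col] {numTiles = 4 : i32, transpose = false}
//          : memref<128x32xf16, 3> -> vector<4x2xf16>
// with per-lane indices produced by getXferIndices from the laneId->coord map.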
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
    return rewriter.notifyMatchFailure(
        op, "Failed to deduce register fragment type during "
            "conversion to distributed non-ldmatrix compatible load");

  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  Value fill = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      rewriter.getZeroAttr(vectorType.getElementType()));
  Value result =
      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // If we are not transposing, then we can use vectorized loads. Otherwise, we
  // must load each element individually.
  if (!isTransposeLoad) {
    if (!isa<VectorType>(loadedElType)) {
      loadedElType = VectorType::get({1}, loadedElType);
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          rewriter, op.getLoc(), *warpMatrixInfo);
      if (failed(coords))
        return rewriter.notifyMatchFailure(op, "no coords");

      Value logicalValueId = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
                                                 op.getSource(), newIndices);
      result = rewriter.create<vector::InsertOp>(loc, el, result, i);
    }
  } else {
    if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
      loadedElType = vecType.getElementType();
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {

        Value logicalValueId = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIndexType(),
            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            rewriter, op.getLoc(), *warpMatrixInfo);
        if (failed(coords))
          return rewriter.notifyMatchFailure(op, "no coords");

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = rewriter.create<memref::LoadOp>(loc, loadedElType,
                                                   op.getSource(), newIndices);
        result = rewriter.create<vector::InsertOp>(
            loc, el, result, ArrayRef<int64_t>{i, innerIdx});
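// Design note: this non-ldmatrix path distributes the fragment manually. For
// row-major (non-transposed) loads each lane can issue a vector.load of
// elementsPerRegister contiguous values; a transposed load makes neighboring
// fragment elements non-contiguous in memory, so it falls back to a scalar
// memref.load per element.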
  auto addressSpace =
      dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return addressSpace &&
         addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  bool isLdMatrixCompatible =
      isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When transposing the B operand, ldmatrix only works if there are at least
  // 8 rows to read and the width to read for the transpose is 128 bits.
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(rewriter, op, valueMapping);

  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value matrix = it->second;

  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        rewriter, op.getLoc(), *warpMatrixInfo);
    if (failed(coords))
      return rewriter.notifyMatchFailure(op, "no coords");

    Value el = rewriter.create<vector::ExtractOp>(loc, matrix,
                                                  ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
    rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  for (auto attr : arrayAttr)
    results.push_back(cast<IntegerAttr>(attr).getInt());
static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
                           vector::ExtractStridedSliceOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  if (failed(mmaSyncFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");

  // Find the vector.transfer_read whose result vector is being sliced.
  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  if (failed(ldFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");

  assert((mmaSyncFragmentInfo->elementsPerRegister ==
          ldFragmentInfo->elementsPerRegister) &&
         "Number of elements per register should be same for load and mma.sync");

  std::array<int64_t, 2> strides = {1,
                                    1}; // stride for extract slice is always 1.
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  auto sourceVector = it->second;

  std::array<int64_t, 2> sliceOffset = {0, 0};

  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported.";
  if (offsets[0])
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);

  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
      loc, sourceVector, sliceOffset, sliceShape, strides);

  valueMapping[op] = newOp;
  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
      /*b_transpose=*/UnitAttr());
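// Illustrative result (a sketch, not from the original file):
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//          : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
//          -> !gpu.mma_matrix<16x16xf16, "COp">
// replacing the vector.contract once all three operands are mapped fragments.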
  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
  auto splat =
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = cast<VectorType>(op.getType());
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, scalarConstant);
  auto vecType = op.getResultVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, op.getSource());
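// Illustrative result (a sketch, not from the original file): a scalar
// broadcast such as
//   %b = vector.broadcast %s : f16 to vector<16x16xf16>
// becomes
//   %m = gpu.subgroup_mma_constant_matrix %s : !gpu.mma_matrix<16x16xf16, "COp">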
  auto operands = llvm::to_vector<4>(loop.getInitArgs());
  llvm::append_range(operands, newInitArgs);
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands);
  rewriter.eraseBlock(newLoop.getBody());

  newLoop.getRegion().getBlocks().splice(
      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
  for (Value operand : newInitArgs)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));

  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
  LLVM_DEBUG(DBGS() << "erase: " << loop);
  rewriter.eraseOp(loop);
  return newLoop;
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end()) {
      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
      continue;
    }
    argMapping.push_back(std::make_pair(
        operand.index(), op.getInitArgs().size() + newOperands.size()));
    newOperands.push_back(it->second);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }

  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of the old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
static LogicalResult
convertElementwiseOp(RewriterBase &rewriter, Operation *op,
                     gpu::MMAElementwiseOp opType,
                     llvm::DenseMap<Value, Value> &valueMapping) {
  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands()) {
    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
      return rewriter.notifyMatchFailure(op, "no mapping");
    matrixOperands.push_back(it->second);
  }
  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
  if (opType == gpu::MMAElementwiseOp::EXTF) {
    // The floating-point extension case has a different result type.
    auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                         vectorType.getElementType(),
                                         resultType.getOperand());
  }
  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), resultType, matrixOperands, opType);
  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
      patterns.getContext());
  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
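// Typical driver (a sketch; the exact surrounding pipeline is an assumption,
// the functions named are from this file):
//   RewritePatternSet patterns(ctx);
//   populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
//   if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns))))
//     return failure();
//   (void)convertVectorToMMAOps(rewriter, rootOp);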
    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      res = convertTransferReadOp(rewriter, transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      res = convertContractOp(rewriter, contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      res = convertConstantOp(rewriter, constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      res = convertForOp(rewriter, forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      res = convertYieldOp(rewriter, yieldOp, valueMapping);
        .Case([&](vector::TransferReadOp transferReadOp) {
          return convertTransferReadToLoads(rewriter, transferReadOp,
                                            valueMapping);
        })
        .Case([&](vector::TransferWriteOp transferWriteOp) {
          return convertTransferWriteToStores(rewriter, transferWriteOp,
                                              valueMapping);
        })
        .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
          return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
                                            valueMapping);
        })
        .Case([&](vector::ContractionOp contractionOp) {
          return convertContractOpToMmaSync(rewriter, contractionOp,
                                            valueMapping);
        })
        .Case([&](scf::ForOp forOp) {
          return convertForOp(rewriter, forOp, valueMapping);
        })
        .Case([&](scf::YieldOp yieldOp) {
          return convertYieldOp(rewriter, yieldOp, valueMapping);
        })
        .Case([&](arith::ConstantOp constOp) {
          return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
        })
        .Default([&](Operation *op) {
          return op->emitError() << "unhandled vector to mma type: " << *op;
        })
        .failed()) {
      return op->emitOpError()
             << "failed to convert op during vector-to-nvgpu conversion";
struct ConvertVectorToGPUPass
    : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    IRRewriter rewriter(&getContext());
    if (useNvGpu) {
      if (failed(
              convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
        return signalPassFailure();
      return;
    }
    (void)convertVectorToMMAOps(rewriter, getOperation());
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}
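// Typical registration in a pipeline (a sketch; the pass-manager setup is an
// assumption, the factory above is from this file):
//   pm.addNestedPass<func::FuncOp>(
//       createConvertVectorToGPUPass(/*useNvGpu=*/true));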