#include <type_traits>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "vector-to-gpu"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
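// getXferIndices: for a vector TransferOpType xferOp, an empty indices
// vector, and an AffineMap representing offsets to the original indices,
// computes the adjusted indices of the transfer.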
template <typename TransferOpType>
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      Value prevIdx = indices[dim.getPosition()];
      dims.push_back(prevIdx);
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
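// contractSupportsMMAMatrixType: return true if the vector.contract can be
// converted to an MMA matmul (row-major matmul indexing maps with parallel,
// parallel, reduction iterators).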
  auto iteratorTypes = contract.getIteratorTypes().getValue();
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
  return permutationMap ==
             AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||

  auto memrefType = dyn_cast<MemRefType>(type);
  if (memrefType.getRank() < 2)
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamic)
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
  if (readOp.getVectorType().getElementType().isInteger(8))
    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))
  auto broadcastInnerDim =

  if (writeOp.getTransferRank() == 0)
  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
  if (!writeOp.getPermutationMap().isMinorIdentity())
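// constantSupportsMMAMatrixType: return true if the constant is a splat to a
// 2D vector so that it can be converted to an MMA constant matrix op.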
  auto vecType = dyn_cast<VectorType>(constantOp.getType());
  if (!vecType || vecType.getRank() != 2)
  return isa<SplatElementsAttr>(constantOp.getValue());
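// broadcastSupportsMMAMatrixType: return true if this is a broadcast from
// scalar to a 2D vector.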
  return broadcastOp.getResultVectorType().getRank() == 2;
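// integerExtendSupportsMMAMatrixType: return true if this integer extend op
// can be folded into a contract op.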
template <typename ExtOpTy>
  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
  return llvm::all_of(extOp->getUsers(), [](Operation *user) {
    return isa<vector::ContractionOp>(user);
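// convertElementwiseOpToMMA: return the MMA elementwise enum associated with
// op if it is supported.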
static std::optional<gpu::MMAElementwiseOp>
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::SubFOp>(op))
    return gpu::MMAElementwiseOp::SUBF;
  if (isa<arith::MaximumFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinimumFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  if (isa<arith::AddIOp>(op))
    return gpu::MMAElementwiseOp::ADDI;
  if (isa<arith::MulIOp>(op))
    return gpu::MMAElementwiseOp::MULI;
  if (isa<arith::SubIOp>(op))
    return gpu::MMAElementwiseOp::SUBI;
  if (isa<arith::DivSIOp>(op))
    return gpu::MMAElementwiseOp::DIVS;
  if (isa<arith::DivUIOp>(op))
    return gpu::MMAElementwiseOp::DIVU;
  if (isa<arith::NegFOp>(op))
    return gpu::MMAElementwiseOp::NEGATEF;
  if (isa<arith::ExtFOp>(op))
    return gpu::MMAElementwiseOp::EXTF;
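// elementwiseSupportsMMAMatrixType: return true if the op is supported as an
// elementwise op on MMAMatrix type.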
  if (failed(warpMatrixInfo))
      cast<VectorType>((*contractOp).getRhs().getType()));
      cast<VectorType>((*contractOp).getAcc().getType()));

  if (isa<scf::ForOp, scf::YieldOp>(op))
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
    return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
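// getSliceContract: return an unsorted slice, handling scf.for regions
// differently than getSlice.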
  unsigned currentIndex = 0;
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    backwardSlice.clear();
    slice.insert(backwardSlice.begin(), backwardSlice.end());
    forwardSlice.clear();
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
    slice.insert(forwardSlice.begin(), forwardSlice.end());
346 [](
Type t) { return isa<VectorType>(t); });
349 backwardSliceOptions.
filter = hasVectorDest;
353 [](
Type t) { return isa<VectorType>(t); });
356 forwardSliceOptions.
filter = hasVectorSrc;
360 if (opToConvert.contains(
contract.getOperation()))
367 if (llvm::any_of(dependentOps, [useNvGpu](
Operation *op) {
369 LLVM_DEBUG(
DBGS() <<
"cannot convert op: " << *op <<
"\n");
376 opToConvert.insert(dependentOps.begin(), dependentOps.end());
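// PrepareContractToGPUMMA: pattern that rewrites a vector.contract into the
// row-major matmul form expected by the GPU MMA lowering, transposing the lhs
// and/or rhs with vector.transpose where needed.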
struct PrepareContractToGPUMMA
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
        op.getIteratorTypes());
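// CombineTransferReadOpTranspose: pattern that merges a vector.transpose into
// the defining vector.transfer_read by composing the transpose into the read's
// permutation map, since the MMA load op can transpose while loading.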
struct CombineTransferReadOpTranspose final
    Value source = op.getVector();
    Type resultType = op.getType();
                          cast<VectorType>(source.getType()).getElementType());
    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
    if (transferReadOp.getTransferRank() == 0)
    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
        permutationMap.compose(transferReadOp.getPermutationMap());
            .create<vector::TransferReadOp>(
                loc, resultType, transferReadOp.getSource(),
                transferReadOp.getPadding(), transferReadOp.getMask(),
                transferReadOp.getInBoundsAttr())
      if (isa<arith::ExtSIOp>(extOp))
        result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
        result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
    auto contract = dyn_cast<vector::ContractionOp>(users);
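// convertTransferReadOp: lower a supported vector.transfer_read to
// gpu.subgroup_mma_load_matrix.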
  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
         "expected convertible operation");
  std::optional<int64_t> stride =
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
    assert(cstExpr.getValue() == 0);
  auto elType = op.getVectorType().getElementType();
    bool isSignedExtend = isa<arith::ExtSIOp>(user);
    if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
          op.getContext(), cast<IntegerType>(elType).getWidth(),
      mappingResult = user->getResult(0);
  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
  valueMapping[mappingResult] = load;

  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
  std::optional<int64_t> stride =
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end()) {
    LLVM_DEBUG(DBGS() << "no mapping\n");
  Value matrix = it->second;
  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.getSource(), op.getIndices(),

  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  if (auto vecType = dyn_cast<VectorType>(elType))
    elType = vecType.getElementType();
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
    LLVM_DEBUG(DBGS() << "not a splat\n");
    LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
                         "is not a 2d operand\n");
  auto exprM = dyn_cast<AffineDimExpr>(dM);
  auto exprN = dyn_cast<AffineDimExpr>(dN);
  if (!exprM || !exprN) {
    LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
                         "expressions, then transpose cannot be determined.\n");
  return exprM.getPosition() > exprN.getPosition();
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
        op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
        << "failed to convert vector.transfer_read to ldmatrix. "
        << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
        op, "failed to convert vector.transfer_read to ldmatrix; this op "
            "likely should not be converted to a nvgpu.ldmatrix call.");
  auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
    LLVM_DEBUG(DBGS() << "no offsets\n");
  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  if (failed(warpMatrixInfo))
        op, "Failed to deduce register fragment type during "
            "conversion to distributed non-ldmatrix compatible load");
  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
  Type loadedElType = regInfo->registerLLVMType;
      op.getLoc(), vectorType.getElementType(),
      rewriter.getZeroAttr(vectorType.getElementType()));
      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
  if (!isTransposeLoad) {
    if (!isa<VectorType>(loadedElType)) {
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
          rewriter, op.getLoc(), *warpMatrixInfo);
      Value logicalValueId = rewriter.create<arith::ConstantOp>(
          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
      getXferIndices<vector::TransferReadOp>(
          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
      Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
                                                 op.getSource(), newIndices);
      result = rewriter.create<vector::InsertOp>(loc, el, result, i);
    if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
      loadedElType = vecType.getElementType();
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
        Value logicalValueId = rewriter.create<arith::ConstantOp>(
            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
            rewriter, op.getLoc(), *warpMatrixInfo);
        getXferIndices<vector::TransferReadOp>(
            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
                                                op.getSource(), newIndices);
        result = rewriter.create<vector::InsertOp>(
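// isSharedMemory: return true if this is a shared memory memref type.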
      dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
      addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace())
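// convertTransferReadToLoads: converts a vector.transfer_read operation
// directly to either a vector.load or an nvgpu.ldmatrix operation.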
  if (failed(warpMatrixInfo))
  bool isLdMatrixCompatible =
  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;
  if (!isLdMatrixCompatible)
  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end())
  Value matrix = it->second;
  if (failed(warpMatrixInfo))
  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = rewriter.create<arith::ConstantOp>(
        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
        rewriter, op.getLoc(), *warpMatrixInfo);
    getXferIndices<vector::TransferWriteOp>(
        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
    rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");

  for (auto attr : arrayAttr)
    results.push_back(cast<IntegerAttr>(attr).getInt());
                          vector::ExtractStridedSliceOp op,
  if (failed(warpMatrixInfo))
  if (failed(mmaSyncFragmentInfo))
  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
  if (failed(warpMatrixInfo))
  if (failed(ldFragmentInfo))
         (mmaSyncFragmentInfo->elementsPerRegister ==
          ldFragmentInfo->elementsPerRegister) &&
         "Number of elements per register should be same for load and mma.sync");
  std::array<int64_t, 2> strides = {1,
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
  auto sourceVector = it->second;
  std::array<int64_t, 2> sliceOffset = {0, 0};
  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported. ";
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
      loc, sourceVector, sliceOffset, sliceShape, strides);
  valueMapping[op] = newOp;
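// convertContractOp: lower a vector.contract whose operands have been mapped
// to MMA matrices into gpu.subgroup_mma_compute.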
  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, UnitAttr(),
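// convertContractOpToMmaSync: lower a vector.contract whose operands have been
// distributed to registers into nvgpu.mma.sync.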
  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
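// convertConstantOp: convert a 2D splat ConstantOp to a
// SubgroupMmaConstantMatrix op.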
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  auto vecType = cast<VectorType>(op.getType());
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, scalarConstant);
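// convertBroadcastOp: convert a vector.broadcast from scalar to a
// SubgroupMmaConstantMatrix op.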
  auto vecType = op.getResultVectorType();
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, op.getSource());
  auto operands = llvm::to_vector<4>(loop.getInitArgs());
  llvm::append_range(operands, newInitArgs);
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
  newLoop.getBody()->erase();
  newLoop.getRegion().getBlocks().splice(
      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
  for (Value operand : newInitArgs)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))

  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
  LLVM_DEBUG(DBGS() << "erase: " << loop);
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end()) {
      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
    argMapping.push_back(std::make_pair(
        operand.index(), op.getInitArgs().size() + newOperands.size()));
    newOperands.push_back(it->second);

  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
                                     newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());

  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
    yieldOperands.push_back(it->second);
  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
                                        gpu::MMAElementwiseOp opType,
    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
    matrixOperands.push_back(it->second);
  if (opType == gpu::MMAElementwiseOp::EXTF) {
                                         vectorType.getElementType(),
                                         resultType.getOperand());
  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), resultType, matrixOperands, opType);
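// populatePrepareVectorToMMAPatterns: patterns to transform vector ops into a
// canonical form suitable for conversion to MMA matrix operations.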
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
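// convertVectorToMMAOps: convert vector ops to MMA matrix operations nested
// under rootOp.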
    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
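// convertVectorToNVVMCompatibleMMASync: convert vector ops nested under rootOp
// to vector and GPU operations compatible with the nvvm.mma.sync lowering
// path.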
            .Case([&](vector::TransferReadOp transferReadOp) {
            .Case([&](vector::TransferWriteOp transferWriteOp) {
            .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
            .Case([&](vector::ContractionOp contractionOp) {
            .Case([&](scf::ForOp forOp) {
            .Case([&](scf::YieldOp yieldOp) {
            .Case([&](arith::ConstantOp constOp) {
              return op->emitError() << "unhandled vector to mma type: " << *op;
             << "failed to convert op during vector-to-nvgpu conversion";
struct ConvertVectorToGPUPass
    : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);

  void runOnOperation() override {
      return signalPassFailure();
      return signalPassFailure();

  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
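// Usage sketch (illustrative, not part of this file): adding the pass to a
// pass pipeline, assuming a PassManager `pm` over a module with func.func
// nesting.
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<func::FuncOp>(createConvertVectorToGPUPass(/*useNvGpu=*/true));
//   if (failed(pm.run(module)))
//     llvm::errs() << "vector-to-gpu conversion failed\n";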