#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <type_traits>
#define DEBUG_TYPE "gpu-transforms"
#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
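// Note: DBGS()/LDBG() and DBGS_ALIAS() only emit output in debug builds when
// the tool is run with `-debug-only=gpu-transforms` (respectively
// `gpu-transforms-alias`).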
void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
      });
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}
LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
/// Pick an unrolling order that lets tensor-core contractions reuse the LHS
/// register: reduction dimensions first, then parallel dimensions appearing in
/// the LHS indexing map, then the remaining parallel dimensions.
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // Reduction dimensions first.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes()))
    if (vector::isReductionIterator(iter))
      order.push_back(index);

  llvm::SmallDenseSet<int64_t> dims;
  for (AffineExpr expr : contract.getIndexingMapsArray()[0].getResults())
    dims.insert(cast<AffineDimExpr>(expr).getPosition());
  // Then parallel dimensions that are part of the LHS.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes()))
    if (vector::isParallelIterator(iter) && dims.count(index))
      order.push_back(index);
  // Finally the remaining parallel dimensions.
  for (auto [index, iter] : llvm::enumerate(contract.getIteratorTypes()))
    if (vector::isParallelIterator(iter) && !dims.count(index))
      order.push_back(index);
  return order;
}
/// Returns the target vector size for the target operation based on the native
/// vector size specified with `m`, `n`, `k`.
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n, int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Transfer reads may need different shapes depending on their use; match
    // the shape consumed by the extract_strided_slice users.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = extract.getResult().getType().cast<VectorType>();
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
    if (auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>()) {
      if (vecType.getRank() < 2)
        return std::nullopt;
      // Prefer the shape consumed by extract_strided_slice users, if any.
      VectorType sliceType;
      for (Operation *users : op->getUsers()) {
        auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
        if (!extract)
          return std::nullopt;
        auto vecType = extract.getResult().getType().cast<VectorType>();
        if (sliceType && sliceType != vecType)
          return std::nullopt;
        sliceType = vecType;
      }
      if (sliceType)
        return llvm::to_vector(sliceType.getShape());
      // Otherwise map elementwise ops to the output shape.
      SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
      nativeSize.append({m, n});
      return nativeSize;
    }
  }
  return std::nullopt;
}
void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };
  int64_t m = getM(), n = getN(), k = getK();
  auto nativeShapeFn =
      [m, n, k](Operation *op) -> std::optional<SmallVector<int64_t>> {
    return getSubgroupMmaNativeVectorSize(op, m, n, k);
  };
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}
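// The two callbacks above drive vector unrolling: contractions are unrolled to
// the native subgroup-MMA shape {m, n, k} (transfer and elementwise ops to
// {m, n}), visiting dimensions in the LHS-reuse order computed by
// gpuMmaUnrollOrder.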
  return isa<memref::AssumeAlignmentOp>(op);
  if (op->hasAttr("__parallel_region_boundary_for_test"))
    return true;
  return isa<GPUFuncOp, LaunchOp>(op);
  return isa<scf::IfOp, memref::AllocaScopeOp>(op);

  return isa_and_nonnull<memref::AllocOp, memref::AllocaOp>(op);
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Read>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Write>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Allocate>());
  effects.emplace_back(MemoryEffects::Effect::get<MemoryEffects::Free>());
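// Conservative fallback: when nothing more precise is known about an op, all
// four memory effect kinds (read, write, allocate, free) are reported so that
// the barrier-elimination analysis below stays sound.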
               bool ignoreBarriers = true) {
  // Skip barriers to avoid infinite recursion when querying their own effects.
  if (ignoreBarriers && isa<BarrierOp>(op))
    return true;

  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
    SmallVector<MemoryEffects::EffectInstance> localEffects;
    iface.getEffects(localEffects);
    llvm::append_range(effects, localEffects);
    return true;
  }

  // Recurse into nested regions.
  for (auto &block : region) {
    for (auto &innerOp : block)
                 bool stopAtBarrier) {
  if (region && !llvm::hasSingleElement(region->getBlocks())) {

  for (Operation *it = op->getPrevNode(); it != nullptr;
       it = it->getPrevNode()) {
    if (isa<BarrierOp>(it)) {
  bool conservative = false;

  return !conservative;
                bool stopAtBarrier) {
  if (region && !llvm::hasSingleElement(region->getBlocks())) {

  for (Operation *it = op->getNextNode(); it != nullptr;
       it = it->getNextNode()) {
    if (isa<BarrierOp>(it)) {
  bool conservative = false;

  return !conservative;
    bool shouldContinue =
        TypeSwitch<Operation *, bool>(v.getDefiningOp())
            .Case<memref::CastOp, memref::SubViewOp, memref::ViewOp>(
                [&](auto op) {
                  v = op.getSource();
                  return true;
                })
            .Case<memref::TransposeOp>([&](auto op) {
              v = op.getIn();
              return true;
            })
            .Case<memref::CollapseShapeOp, memref::ExpandShapeOp>([&](auto op) {
              v = op.getSrc();
              return true;
            })
            .Default([](Operation *) { return false; });
  auto arg = dyn_cast<BlockArgument>(v);
  return arg && isa<FunctionOpInterface>(arg.getOwner()->getParentOp());
      .Case([](ViewLikeOpInterface viewLike) { return viewLike.getViewSource(); })
      .Case([](CastOpInterface castLike) { return castLike->getOperand(0); })
      .Case([](memref::TransposeOp transpose) { return transpose.getIn(); })
      .Case<memref::ExpandShapeOp, memref::CollapseShapeOp>(
          [](auto op) { return op.getSrc(); })
      .Case<memref::StoreOp, vector::TransferWriteOp>(
          [&](auto op) { return op.getValue() == v; })
      .Case<vector::StoreOp, vector::MaskedStoreOp>(
          [&](auto op) { return op.getValueToStore() == v; })
      .Case([](memref::DeallocOp) { return false; })
      .Default([](Operation *) { return std::nullopt; });
  while (!todo.empty()) {
    Value v = todo.pop_back_val();
    for (Operation *user : v.getUsers()) {
      // A user that is known to only read the value cannot capture it.
      auto iface = dyn_cast<MemoryEffectOpInterface>(user);
      if (iface) {
        SmallVector<MemoryEffects::EffectInstance> effects;
        iface.getEffects(effects);
        if (llvm::all_of(effects, [](const auto &effect) {
              return isa<MemoryEffects::Read>(effect.getEffect());
            }))
          continue;
      }

      if (!knownCaptureStatus || *knownCaptureStatus)
        return true;
  // Values derived from the same base memref do alias.
  if (first == second) {
    return true;
  }

  // Different globals cannot alias.
  if (auto globFirst = first.getDefiningOp<memref::GetGlobalOp>()) {
    if (auto globSecond = second.getDefiningOp<memref::GetGlobalOp>()) {
      return globFirst.getNameAttr() == globSecond.getNameAttr();
    }
  }

  // Two function arguments marked as noalias cannot alias.
  auto isNoaliasFuncArgument = [](Value value) {
    auto bbArg = dyn_cast<BlockArgument>(value);
    if (!bbArg)
      return false;
    auto iface = dyn_cast<FunctionOpInterface>(bbArg.getOwner()->getParentOp());
    if (!iface)
      return false;
    return iface.getArgAttr(bbArg.getArgNumber(), "llvm.noalias") != nullptr;
  };
  if (isNoaliasFuncArgument(first) && isNoaliasFuncArgument(second))
    return false;

  bool isGlobal[] = {first.getDefiningOp<memref::GetGlobalOp>() != nullptr,
                     second.getDefiningOp<memref::GetGlobalOp>() != nullptr};

  // A distinct allocation or a global cannot alias another distinct allocation
  // or global (same-value and same-global cases were handled above).
  if ((isDistinct[0] || isGlobal[0]) && (isDistinct[1] || isGlobal[1]))
    return false;

  // A distinct allocation cannot have been passed in as a function argument.
  if ((isDistinct[0] && isArg[1]) || (isDistinct[1] && isArg[0]))
    return false;
      // Two reads do not conflict.
      if (isa<MemoryEffects::Read>(before.getEffect()) &&
          isa<MemoryEffects::Read>(after.getEffect())) {
        continue;
      }

      if (isa<MemoryEffects::Allocate>(before.getEffect()) ||
          isa<MemoryEffects::Allocate>(after.getEffect())) {

      if (isa<MemoryEffects::Free>(before.getEffect()))

      LLVM_DEBUG(
          DBGS() << "found a conflict between (before): " << before.getValue()
                 << " read:" << isa<MemoryEffects::Read>(before.getEffect())
                 << " write:" << isa<MemoryEffects::Write>(before.getEffect())
                 << " alloc:"
                 << isa<MemoryEffects::Allocate>(before.getEffect()) << " free:"
                 << isa<MemoryEffects::Free>(before.getEffect()) << "\n");
      LLVM_DEBUG(
          DBGS() << "and (after): " << after.getValue()
                 << " read:" << isa<MemoryEffects::Read>(after.getEffect())
                 << " write:" << isa<MemoryEffects::Write>(after.getEffect())
                 << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect())
                 << " free:" << isa<MemoryEffects::Free>(after.getEffect())
                 << "\n");
  LLVM_DEBUG(DBGS() << "checking the necessity of: " << barrier << " "
                    << barrier.getLoc() << "\n");
    LLVM_DEBUG(DBGS() << "the surrounding barriers are sufficient, removing "
                      << barrier << "\n");
  LLVM_DEBUG(DBGS() << "barrier is necessary: " << barrier << " "
                    << barrier.getLoc() << "\n");
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
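// These empty structs are compile-time tags: the templated helpers below
// (checkMappingAttributeTypes, verifyGpuMapping) dispatch on them to apply
// either the block-mapping or the thread/warp-mapping verification rules to an
// scf.forall op.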
static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return isa<GPUBlockMappingAttr>(attr);
      });
  bool hasWarpgroupMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return isa<GPUWarpgroupMappingAttr>(attr);
      });
  bool hasWarpMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return isa<GPUWarpMappingAttr>(attr);
      });
  bool hasThreadMapping =
      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
        return isa<GPUThreadMappingAttr>(attr);
      });
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}
template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}
static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert(forallOp.getMapping()->getValue().begin(),
                            forallOp.getMapping()->getValue().end());
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };
  // Complete the mapping to a full mapping (with size-1 dimensions) if needed.
  DeviceMappingAttrInterface maxMapping =
      cast<DeviceMappingAttrInterface>(*std::max_element(
          forallMappingAttrs.begin(), forallMappingAttrs.end(), comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows, just skip.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // If the insertion fails, the dimension is already mapped.
    if (!forallMappingAttrs.insert(attr))
      continue;
    // Otherwise this is a new insertion without a size: use size 1.
    tmpMappingSizes.push_back(1);
  }
  LLVM_DEBUG(llvm::interleaveComma(
                 tmpMappingSizes,
                 DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
             llvm::dbgs() << "\n");
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
             llvm::dbgs() << "\n");
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  }
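  // GPU ids are always 3-D (x, y, z): when no basis is provided, the forall
  // mapping sizes are padded with 1s so the id builder below always receives a
  // rank-3 basis.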
  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    bvm.map(iv, peIdOp);
  }
  if (originalBasisWasProvided) {
    LLVM_DEBUG(
        llvm::interleaveComma(
            activeMappingSizes, DBGS() << "----activeMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(
            availableMappingSizes, DBGS() << "----availableMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
        llvm::dbgs() << "\n");
    for (auto [activeId, activeMappingSize, availableMappingSize] :
         llvm::zip_equal(activeIdOps, activeMappingSizes,
                         availableMappingSizes)) {
      if (activeMappingSize > availableMappingSize) {
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
            "Try additional tiling of the before mapping or map to more "
            "threads.");
      }
      if (activeMappingSize == availableMappingSize)
        continue;
      Value idx =
          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, activeId, idx);
      LDBG("----predicate: " << tmpPredicate);
      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                             tmpPredicate)
                            : tmpPredicate;
    }
  // Move the body of forallOp; erase the terminator first, it will not be used.
  rewriter.eraseOp(forallOp.getTerminator());

  auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                         /*withElseRegion=*/false);
  targetBlock = ifOp.thenBlock();
  insertionPoint = ifOp.thenBlock()->begin();
  targetBlock = forallOp->getBlock();
  insertionPoint = Block::iterator(forallOp);

  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----result forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 mappingIdOps, DBGS() << "----result mappingIdOps: ");
             llvm::dbgs() << "\n");
DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");
  DiagnosedSilenceableFailure diag =
      verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
  if (!diag.succeeded())
    return diag;
  Block *parentBlock = forallOp->getBlock();
  ForallRewriteResult rewriteResult;
  diag = rewriteOneForallCommonImpl(rewriter, transformOp, forallOp,
                                    gridDims, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;
  // If gridDims were not provided, take them from the rewrite result.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          gridDims);
  return DiagnosedSilenceableFailure::success();
}
static DiagnosedSilenceableFailure
findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp,
                     TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag = findTopLevelForallOp(
      target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");
  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform require size-3 mapping");
  if (getGenerateGpuLaunch()) {
    diag = createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch);
    if (!diag.succeeded())
      return diag;

    rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
    Operation *newForallOp = rewriter.clone(*topLevelForallOp);
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping()) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(
        topLevelForallOp.getMapping()->getValue().front());
    useLinearMapping = mappingAttr.isLinearMapping();
  }
  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);
  diag = mlir::transform::gpu::mapForallToBlocksImpl(
      rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Set the GPU launch configuration for the grid dims late, this is subject
  // to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
static DiagnosedSilenceableFailure checkMappingSpec(
    std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
    ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
    int factor, bool useLinearMapping = false) {
  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
    return definiteFailureHelper(
        transformOp, forallOp,
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            std::to_string(factor));
  }
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            std::to_string(computeProduct(numParallelIterations) * factor) +
            std::string(" overflows the number of available resources ") +
            std::to_string(computeProduct(blockOrGridSizes)));
  }
  return DiagnosedSilenceableFailure::success();
}
  auto mappingAttr = cast<DeviceMappingAttrInterface>(
      forallOp.getMapping()->getValue().front());
  bool useLinearMapping = mappingAttr.isLinearMapping();

  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;
  gpuIdBuilder =
      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}
static DiagnosedSilenceableFailure mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {
  DiagnosedSilenceableFailure diag =
      verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
  if (!diag.succeeded())
    return diag;

  GpuIdBuilder gpuIdBuilder;
  diag = getThreadIdBuilder(transformOp, forallOp, blockSizes, warpSize,
                            gpuIdBuilder);
  if (!diag.succeeded())
    return diag;

  Location loc = forallOp.getLoc();
  ForallRewriteResult rewriteResult;
  diag = rewriteOneForallCommonImpl(
      rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;

  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);
  return DiagnosedSilenceableFailure::success();
}
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Rewrite all the forall ops nested under `target`.
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mapOneForallToThreadsImpl(rewriter, transformOp, forallOp,
                                     blockDims, warpSize, syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Replace ids of dimensions known to be 1 by 0 to enable DCE.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);
  return DiagnosedSilenceableFailure::success();
}
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";
  SmallVector<int64_t> blockDims{getBlockDims()};
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }
  // Set the GPU launch configuration for the block dims early, this is not
  // subject to IR inspection.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  diag = mlir::transform::gpu::mapNestedForallToThreadsImpl(
      rewriter, transformOp, gpuLaunch, blockDims, getWarpSize(),
      getSyncAfterDistribute());
  results.push_back(gpuLaunch.getOperation());
  return diag;
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  GPUTransformDialectExtension() {
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"