#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <type_traits>

using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;
#define DEBUG_TYPE "gpu-transforms"
#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"

#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
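
// Debug helpers: under `-debug-only=gpu-transforms`, LDBG(X) prints X on one
// line prefixed with [gpu-transforms]; DBGS_ALIAS() is the equivalent stream
// for the gpu-transforms-alias channel.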
void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](AddressSpace space) -> unsigned {
        switch (space) {
        case AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
      });
  llvmTypeConverter.addConversion(
      [&](MMAMatrixType type) -> Type { return convertMMAToLLVMType(type); });
  populateGpuToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
LogicalResult
transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuWMMAToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult
transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    populatePatterns(TypeConverter &typeConverter,
                     RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuSubgroupReduceOpLoweringPattern(llvmTypeConverter, patterns);
}

LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
    verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
static std::optional<SmallVector<int64_t>>
gpuMmaUnrollOrder(vector::ContractionOp contract) {
  SmallVector<int64_t> order;
  // ... reduction iterators first ...
      order.push_back(index);
  // ... positions used by the LHS indexing map ...
  llvm::SmallDenseSet<int64_t> dims;
      dims.insert(cast<AffineDimExpr>(expr).getPosition());
  // ... then parallel iterators used by the LHS ...
      order.push_back(index);
  // ... then the remaining parallel iterators ...
      order.push_back(index);
  return order;
}
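
/// Return the native tile each op should be unrolled to: (m, n, k) for
/// vector.contract, (m, n) for transfer_write and trailing elementwise ops,
/// and, for transfer_read, the shape consumed by its extract_strided_slice
/// users.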
static std::optional<SmallVector<int64_t>>
getSubgroupMmaNativeVectorSize(Operation *op, int64_t m, int64_t n,
                               int64_t k) {
  if (auto contract = dyn_cast<vector::ContractionOp>(op)) {
    int64_t contractRank = contract.getIteratorTypes().size();
    if (contractRank < 3)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(contractRank - 3, 1);
    nativeSize.append({m, n, k});
    return nativeSize;
  }
  if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
    int64_t writeRank = writeOp.getVectorType().getRank();
    if (writeRank < 2)
      return std::nullopt;
    SmallVector<int64_t> nativeSize(writeRank - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
    // Match the shape expected by the extract_strided_slice users.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = cast<VectorType>(extract.getResult().getType());
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    return llvm::to_vector(sliceType.getShape());
  }
  if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
    if (vecType.getRank() < 2)
      return std::nullopt;
    // Prefer the shape used by extract_strided_slice users, if any.
    VectorType sliceType;
    for (Operation *users : op->getUsers()) {
      auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
      if (!extract)
        return std::nullopt;
      auto vecType = cast<VectorType>(extract.getResult().getType());
      if (sliceType && sliceType != vecType)
        return std::nullopt;
      sliceType = vecType;
    }
    if (sliceType)
      return llvm::to_vector(sliceType.getShape());
    // Otherwise unroll elementwise ops to the output tile shape.
    SmallVector<int64_t> nativeSize(vecType.getRank() - 2, 1);
    nativeSize.append({m, n});
    return nativeSize;
  }
  return std::nullopt;
}
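
// ApplyUnrollVectorsSubgroupMmaOp wires the two helpers above into the vector
// unrolling infrastructure: traversal order from gpuMmaUnrollOrder and native
// shape from getSubgroupMmaNativeVectorSize, parameterized by the op's m/n/k
// attributes.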
void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
    RewritePatternSet &patterns) {
  auto unrollOrder = [](Operation *op) -> std::optional<SmallVector<int64_t>> {
    auto contract = dyn_cast<vector::ContractionOp>(op);
    if (!contract)
      return std::nullopt;
    return gpuMmaUnrollOrder(contract);
  };
  // ... build `nativeShapeFn` from getSubgroupMmaNativeVectorSize and the
  // op's m/n/k attributes ...
  vector::populateVectorUnrollPatterns(
      patterns, vector::UnrollVectorOptions()
                    .setNativeShapeFn(nativeShapeFn)
                    .setUnrollTraversalOrderFn(unrollOrder));
}
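
// Empty tag types used to select, at compile time, which mapping checks apply
// (block mapping vs. thread/warp/warpgroup mapping).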
struct MappingKind {};
struct BlockMappingKind : MappingKind {};
struct ThreadMappingKind : MappingKind {};
static DiagnosedSilenceableFailure
definiteFailureHelper(std::optional<TransformOpInterface> transformOp,
                      Operation *target, const Twine &message) {
  if (transformOp.has_value())
    return transformOp->emitDefiniteFailure() << message;
  return emitDefiniteFailure(target, message);
}
template <typename MappingKindType>
static DiagnosedSilenceableFailure
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
                           scf::ForallOp forallOp) {
  if (!forallOp.getMapping().has_value()) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute");
  }

  bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
                                      llvm::IsaPred<GPUBlockMappingAttr>);
  bool hasWarpgroupMapping = llvm::any_of(
      forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
  bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
                                     llvm::IsaPred<GPUWarpMappingAttr>);
  bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
                                       llvm::IsaPred<GPUThreadMappingAttr>);
  int64_t countMappingTypes = 0;
  countMappingTypes += hasBlockMapping ? 1 : 0;
  countMappingTypes += hasWarpgroupMapping ? 1 : 0;
  countMappingTypes += hasWarpMapping ? 1 : 0;
  countMappingTypes += hasThreadMapping ? 1 : 0;
  if (countMappingTypes > 1) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix different mapping types, use nesting");
  }
  if (std::is_same<MappingKindType, BlockMappingKind>::value &&
      !hasBlockMapping) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "scf.forall op requires a mapping attribute of kind 'block'");
  }
  if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
      !hasThreadMapping && !hasWarpMapping && !hasWarpgroupMapping) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall op requires a mapping attribute "
                                 "of kind 'thread' or 'warp'");
  }

  DenseSet<Attribute> seen;
  for (Attribute map : forallOp.getMapping()->getValue()) {
    if (seen.contains(map)) {
      return definiteFailureHelper(
          transformOp, forallOp,
          "duplicate attribute, cannot map different loops "
          "to the same mapping id");
    }
    seen.insert(map);
  }

  auto isLinear = [](Attribute a) {
    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
  };
  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "cannot mix linear and non-linear mapping modes");
  }

  return DiagnosedSilenceableFailure::success();
}
template <typename MappingKindType>
static DiagnosedSilenceableFailure
verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
  // Check the kinds of the mapping attributes.
  DiagnosedSilenceableFailure typeRes =
      checkMappingAttributeTypes<MappingKindType>(transformOp, forallOp);
  if (!typeRes.succeeded())
    return typeRes;

  // Perform other, non-type verifications.
  if (!forallOp.isNormalized())
    return definiteFailureHelper(transformOp, forallOp,
                                 "unsupported non-normalized loops");
  if (forallOp.getNumResults() > 0)
    return definiteFailureHelper(transformOp, forallOp,
                                 "only bufferized scf.forall can be mapped");
  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
                              forallOp.getMapping()->getValue().front())
                              .isLinearMapping();
  int64_t maxNumMappingsSupported =
      useLinearMapping ? (getMaxEnumValForMappingId() -
                          static_cast<uint64_t>(MappingId::DimZ))
                       : 3;
  if (forallOp.getRank() > maxNumMappingsSupported) {
    return definiteFailureHelper(transformOp, forallOp,
                                 "scf.forall with rank > ")
           << maxNumMappingsSupported
           << " does not lower for the specified mapping attribute type";
  }
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  return DiagnosedSilenceableFailure::success();
}
template <typename OpTy, typename OperationOrBlock>
static void
replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc,
                            OperationOrBlock *parent, Value replacement,
                            ArrayRef<int64_t> availableMappingSizes) {
  parent->walk([&](OpTy idOp) {
    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
  });
}
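
/// Common rewrite for mapping one scf.forall to hardware ids: complete and
/// sort the mapping attributes, materialize ids through `gpuIdBuilder`,
/// predicate the body when fewer loop iterations than hardware ids are active,
/// then inline the forall body and replace its induction variables with ids.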
static DiagnosedSilenceableFailure rewriteOneForallCommonImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes,
    ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) {
  LDBG("--start rewriteOneForallCommonImpl");

  // GPU-specific verifications have already run; only assert here.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
         "requires statically sized, normalized forall op");
  SmallVector<int64_t> tmpMappingSizes = numParallelIterations.value();
  SetVector<Attribute> forallMappingAttrs;
  forallMappingAttrs.insert(forallOp.getMapping()->getValue().begin(),
                            forallOp.getMapping()->getValue().end());
  auto comparator = [](Attribute a, Attribute b) -> bool {
    return cast<DeviceMappingAttrInterface>(a).getMappingId() <
           cast<DeviceMappingAttrInterface>(b).getMappingId();
  };

  // Complete the mapping to a full mapping (with unit sizes) if necessary.
  DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
      *llvm::max_element(forallMappingAttrs, comparator));
  DeviceMappingAttrInterface maxLinearMapping;
  if (maxMapping.isLinearMapping())
    maxLinearMapping = maxMapping;
  for (auto attr : gpuIdBuilder.mappingAttributes) {
    // If attr overflows the linear mapping, just skip it.
    if (maxLinearMapping && comparator(maxLinearMapping, attr))
      continue;
    // Insert failure means the attribute was already mapped explicitly.
    if (!forallMappingAttrs.insert(attr))
      continue;
    tmpMappingSizes.push_back(1);
  }
  LLVM_DEBUG(
      llvm::interleaveComma(
          tmpMappingSizes,
          DBGS() << "----tmpMappingSizes extracted from scf.forall op: ");
      llvm::dbgs() << "\n");

  // Sort the sizes by the corresponding mapping id.
  SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey(
      forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 forallMappingAttrs, DBGS() << "----forallMappingAttrs: ");
             llvm::dbgs() << "\n");

  // Generate the hardware ids, defaulting the basis to the (3-D padded)
  // forall sizes when none was provided.
  Location loc = forallOp.getLoc();
  SmallVector<int64_t> originalBasis(availableMappingSizes);
  bool originalBasisWasProvided = !originalBasis.empty();
  if (!originalBasisWasProvided) {
    originalBasis = forallMappingSizes;
    while (originalBasis.size() < 3)
      originalBasis.push_back(1);
  }
  IdBuilderResult builderResult =
      gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis);

  // Map the forall induction variables to the hardware ids.
  SmallVector<Value> mappingIdOps = builderResult.mappingIdOps;
  IRMapping bvm;
  for (auto [iv, dim] : llvm::zip_equal(
           forallOp.getInductionVars(),
           forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
    Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
    bvm.map(iv, peIdOp);
  }

  // If a basis was provided, predicate the body wherever fewer loop
  // iterations than hardware ids are active.
  Value predicate;
  if (originalBasisWasProvided) {
    // ... activeMappingSizes / availableMappingSizes / activeIdOps are taken
    // from the id builder result ...
    LLVM_DEBUG(
        llvm::interleaveComma(activeMappingSizes,
                              DBGS() << "----activeMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(availableMappingSizes,
                              DBGS() << "----availableMappingSizes: ");
        llvm::dbgs() << "\n";
        llvm::interleaveComma(activeIdOps, DBGS() << "----activeIdOps: ");
        llvm::dbgs() << "\n");
    for (auto [activeId, activeMappingSize, availableMappingSize] :
         llvm::zip_equal(activeIdOps, activeMappingSizes,
                         availableMappingSizes)) {
      if (activeMappingSize > availableMappingSize) {
        return definiteFailureHelper(
            transformOp, forallOp,
            "Trying to map to fewer GPU threads than loop iterations but "
            "overprovisioning is not yet supported. "
            "Try additional tiling of the before mapping or map to more "
            "threads.");
      }
      if (activeMappingSize == availableMappingSize)
        continue;
      Value idx =
          rewriter.create<arith::ConstantIndexOp>(loc, activeMappingSize);
      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, activeId, idx);
      LDBG("----predicate: " << tmpPredicate);
      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                             tmpPredicate)
                            : tmpPredicate;
    }
  }

  // Move the body of the forall, under an scf.if when predicated.
  rewriter.eraseOp(forallOp.getTerminator());
  Block *targetBlock;
  Block::iterator insertionPoint;
  if (predicate) {
    // If predicated, move at the beginning of the scf.if then-block.
    auto ifOp = rewriter.create<scf::IfOp>(loc, predicate,
                                           /*withElseRegion=*/false);
    targetBlock = ifOp.thenBlock();
    insertionPoint = ifOp.thenBlock()->begin();
  } else {
    // Otherwise, move inline at the current rewriter insertion point.
    targetBlock = forallOp->getBlock();
    insertionPoint = rewriter.getInsertionPoint();
  }
  Block &sourceBlock = forallOp.getRegion().front();
  targetBlock->getOperations().splice(insertionPoint,
                                      sourceBlock.getOperations());

  // RAUW the induction variables with the mapped hardware ids.
  for (Value loopIndex : forallOp.getInductionVars()) {
    Value threadIdx = bvm.lookup(loopIndex);
    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
  }

  // Erase the forall and return the mapping result.
  rewriter.eraseOp(forallOp);
  result = ForallRewriteResult{forallMappingSizes, mappingIdOps};
  LLVM_DEBUG(llvm::interleaveComma(forallMappingSizes,
                                   DBGS() << "----result forallMappingSizes: ");
             llvm::dbgs() << "\n"; llvm::interleaveComma(
                 mappingIdOps, DBGS() << "----result mappingIdOps: ");
             llvm::dbgs() << "\n");
  return DiagnosedSilenceableFailure::success();
}
DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
    RewriterBase &rewriter, TransformOpInterface transformOp,
    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
    const GpuIdBuilder &gpuIdBuilder) {
  LDBG("Start mapForallToBlocksImpl");

  DiagnosedSilenceableFailure diag =
      verifyGpuMapping<BlockMappingKind>(transformOp, forallOp);
  if (!diag.succeeded())
    return diag;

  Location loc = forallOp.getLoc();
  Block *parentBlock = forallOp->getBlock();
  Value zero;
  {
    // Create an early zero index value for replacements, then restore the
    // insertion point.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(parentBlock);
    zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  }

  ForallRewriteResult rewriteResult;
  diag = rewriteOneForallCommonImpl(rewriter, transformOp, forallOp, gridDims,
                                    rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;

  // If gridDims was not provided already, infer it from the rewrite result
  // and pad it to 3-D.
  if (gridDims.empty()) {
    gridDims = rewriteResult.mappingSizes;
    while (gridDims.size() < 3)
      gridDims.push_back(1);
  }
  assert(gridDims.size() == 3 && "Need 3-D gridDims");

  // Simplify the IR for dimensions statically known to be of size 1.
  replaceUnitMappingIdsHelper<BlockDimOp>(rewriter, loc, parentBlock, zero,
                                          gridDims);

  return DiagnosedSilenceableFailure::success();
}
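
/// Find the unique top-level scf.forall nested under `target`; fail if there
/// is none or more than one.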
static DiagnosedSilenceableFailure
findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp,
                     TransformOpInterface transformOp) {
  auto walkResult = target->walk([&](scf::ForallOp forallOp) {
    if (forallOp->getParentOfType<scf::ForallOp>())
      return WalkResult::advance();
    if (topLevelForallOp)
      return WalkResult::interrupt();
    topLevelForallOp = forallOp;
    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted() || !topLevelForallOp)
    return transformOp.emitSilenceableError()
           << "could not find a unique topLevel scf.forall";
  return DiagnosedSilenceableFailure::success();
}
DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, transform::TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!getGenerateGpuLaunch() && !gpuLaunch) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "Given target is not gpu.launch, set `generate_gpu_launch` "
           "attribute";
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }

  scf::ForallOp topLevelForallOp;
  DiagnosedSilenceableFailure diag =
      findTopLevelForallOp(target, topLevelForallOp, transformOp);
  if (!diag.succeeded()) {
    diag.attachNote(target->getLoc()) << "when applied to this payload op";
    return diag;
  }
  assert(topLevelForallOp && "expect an scf.forall");

  SmallVector<int64_t> gridDims{getGridDims()};
  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
    return transformOp.emitDefiniteFailure("transform require size-3 mapping");

  // Optionally generate the gpu.launch and move the forall into its body.
  if (getGenerateGpuLaunch()) {
    // ... create the gpu.launch, clone the forall into it as `newForallOp` ...
    if (!diag.succeeded())
      return diag;
    // ...
    rewriter.eraseOp(topLevelForallOp);
    topLevelForallOp = cast<scf::ForallOp>(newForallOp);
  }

  // The block id builder adapts to linear or 3-D mapping attributes.
  bool useLinearMapping = false;
  if (topLevelForallOp.getMapping()) {
    auto mappingAttr = cast<DeviceMappingAttrInterface>(
        topLevelForallOp.getMapping()->getValue().front());
    useLinearMapping = mappingAttr.isLinearMapping();
  }
  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);

  diag = mapForallToBlocksImpl(rewriter, transformOp, topLevelForallOp,
                               gridDims, gpuBlockIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Late-set the launch grid dimensions from the computed gridDims.
  diag = alterGpuLaunch(rewriter, gpuLaunch,
                        cast<TransformOpInterface>(getOperation()), gridDims[0],
                        gridDims[1], gridDims[2]);

  results.push_back(gpuLaunch);
  return diag;
}

LogicalResult transform::MapForallToBlocks::verify() {
  if (!getGridDims().empty() && getGridDims().size() != 3) {
    return emitOpError() << "transform requires empty or size-3 grid_dims";
  }
  return success();
}
static DiagnosedSilenceableFailure
checkMappingSpec(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp,
                 ArrayRef<int64_t> numParallelIterations,
                 ArrayRef<int64_t> blockOrGridSizes, int factor,
                 bool useLinearMapping = false) {
  if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
    return definiteFailureHelper(
        transformOp, forallOp,
        Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
            std::to_string(factor));
  }
  if (computeProduct(numParallelIterations) * factor >
      computeProduct(blockOrGridSizes)) {
    return definiteFailureHelper(
        transformOp, forallOp,
        Twine("the number of required parallel resources (blocks or "
              "threads) ") +
            std::to_string(computeProduct(numParallelIterations) * factor) +
            std::string(" overflows the number of available resources ") +
            std::to_string(computeProduct(blockOrGridSizes)));
  }
  return DiagnosedSilenceableFailure::success();
}
static DiagnosedSilenceableFailure
getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                   scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                   int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
  auto mappingAttr = cast<DeviceMappingAttrInterface>(
      forallOp.getMapping()->getValue().front());
  bool useLinearMapping = mappingAttr.isLinearMapping();

  // Sanity checks that may result in runtime verification errors.
  auto numParallelIterations =
      getConstantIntValues(forallOp.getMixedUpperBound());
  if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
    return definiteFailureHelper(
        transformOp, forallOp,
        "requires statically sized, normalized forall op");
  }
  int64_t factor = 1;
  if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
    factor = GpuWarpgroupIdBuilder::kNumWarpsPerGroup * warpSize;
  } else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
    factor = warpSize;
  }
  DiagnosedSilenceableFailure diag =
      checkMappingSpec(transformOp, forallOp, numParallelIterations.value(),
                       blockSizes, factor, useLinearMapping);
  if (!diag.succeeded())
    return diag;

  MLIRContext *ctx = forallOp.getContext();
  gpuIdBuilder =
      TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
          .Case([&](GPUWarpgroupMappingAttr) {
            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUWarpMappingAttr) {
            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
          })
          .Case([&](GPUThreadMappingAttr) {
            return GpuThreadIdBuilder(ctx, useLinearMapping);
          })
          .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
            llvm_unreachable("unknown mapping attribute");
          });
  return DiagnosedSilenceableFailure::success();
}
DiagnosedSilenceableFailure mlir::transform::gpu::mapOneForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize,
    bool syncAfterDistribute) {
  DiagnosedSilenceableFailure diag =
      verifyGpuMapping<ThreadMappingKind>(transformOp, forallOp);
  if (!diag.succeeded())
    return diag;

  GpuIdBuilder gpuIdBuilder;
  diag = getThreadIdBuilder(transformOp, forallOp, blockSizes, warpSize,
                            gpuIdBuilder);
  if (!diag.succeeded())
    return diag;

  Location loc = forallOp.getLoc();
  // ...
  ForallRewriteResult rewriteResult;
  diag = rewriteOneForallCommonImpl(rewriter, transformOp, forallOp,
                                    blockSizes, rewriteResult, gpuIdBuilder);
  if (!diag.succeeded())
    return diag;

  // Insert a barrier after distribution if requested.
  if (syncAfterDistribute)
    rewriter.create<BarrierOp>(loc);
  return DiagnosedSilenceableFailure::success();
}
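
/// Walk `target` and apply the one-forall mapping to every nested scf.forall,
/// requiring a size-3 `blockDims`; ids of unit dimensions are then rewritten
/// to the zero constant to simplify the IR.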
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
    Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize,
    bool syncAfterDistribute) {
  LDBG("Start mapNestedForallToThreadsImpl");
  if (blockDims.size() != 3) {
    return definiteFailureHelper(transformOp, target,
                                 "requires size-3 thread mapping");
  }

  // Create an early zero index value for replacements.
  Location loc = target->getLoc();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
    diag = mapOneForallToThreadsImpl(rewriter, transformOp, forallOp,
                                     blockDims, warpSize, syncAfterDistribute);
    if (diag.isDefiniteFailure())
      return WalkResult::interrupt();
    if (diag.succeeded())
      return WalkResult::skip();
    return WalkResult::advance();
  });
  if (walkResult.wasInterrupted())
    return diag;

  // Simplify the IR for dimensions statically known to be of size 1.
  replaceUnitMappingIdsHelper<ThreadIdOp>(rewriter, loc, target, zero,
                                          blockDims);
  return DiagnosedSilenceableFailure::success();
}
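
// A sketch of the corresponding transform IR for the op below (exact
// attribute spelling may differ between MLIR versions):
//
//   transform.gpu.map_nested_forall_to_threads %launch
//       block_dims = [32, 4, 1]
//       : (!transform.any_op) -> !transform.any_op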
DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
  auto transformOp = cast<TransformOpInterface>(getOperation());

  if (!gpuLaunch)
    return emitSilenceableError() << "Given target is not a gpu.launch";

  SmallVector<int64_t> blockDims{getBlockDims()};
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                     blockDims[0], blockDims[1], blockDims[2]);
  if (diag.isSilenceableFailure()) {
    diag.attachNote(getLoc()) << getBlockDimsAttrName() << " is too large";
    return diag;
  }

  // Set the GPU launch configuration for the block dims early.
  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
                        blockDims[2]);

  rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
  diag =
      mapNestedForallToThreadsImpl(rewriter, transformOp, gpuLaunch, blockDims,
                                   getWarpSize(), getSyncAfterDistribute());

  results.push_back(gpuLaunch.getOperation());
  return diag;
}
namespace {
class GPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          GPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GPUTransformDialectExtension)

  GPUTransformDialectExtension() {
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<GPUDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"