41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/SmallVector.h"
43#include "llvm/ADT/TypeSwitch.h"
44#include "llvm/Support/DebugLog.h"
45#include "llvm/Support/ErrorHandling.h"
46#include "llvm/Support/InterleavedRange.h"
47#include "llvm/Support/LogicalResult.h"
56#define DEBUG_TYPE "gpu-transforms"
62void transform::ApplyGPUToNVVMConversionPatternsOp::populatePatterns(
72transform::ApplyGPUToNVVMConversionPatternsOp::verifyTypeConverter(
73 transform::TypeConverterBuilderOpInterface builder) {
74 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
79void transform::ApplyGPUWwmaToNVVMConversionPatternsOp::populatePatterns(
86transform::ApplyGPUWwmaToNVVMConversionPatternsOp::verifyTypeConverter(
87 transform::TypeConverterBuilderOpInterface builder) {
88 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
93void transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
100LogicalResult transform::ApplyGPUSubgroupReduceToNVVMConversionPatternsOp::
101 verifyTypeConverter(transform::TypeConverterBuilderOpInterface builder) {
102 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
107void transform::ApplyGPUToROCDLConversionPatternsOp::populatePatterns(
111 FailureOr<amdgpu::Chipset> maybeChipset =
113 assert(llvm::succeeded(maybeChipset) &&
"expected valid chipset");
119transform::ApplyGPUToROCDLConversionPatternsOp::verifyTypeConverter(
120 transform::TypeConverterBuilderOpInterface builder) {
121 FailureOr<amdgpu::Chipset> maybeChipset =
123 if (
failed(maybeChipset)) {
124 return emitOpError(
"Invalid chipset name: " + getChipset());
126 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
139void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns(
141 std::optional<StringRef> chipsetName = getChipset();
142 std::optional<amdgpu::Chipset> maybeChipset;
144 FailureOr<amdgpu::Chipset> parsedChipset =
146 assert(llvm::succeeded(parsedChipset) &&
"expected valid chipset");
147 maybeChipset = parsedChipset;
159static std::optional<SmallVector<int64_t>>
163 for (
auto [
index, iter] : llvm::enumerate(
contract.getIteratorTypes())) {
165 order.push_back(
index);
169 llvm::SmallDenseSet<int64_t> dims;
171 dims.insert(cast<AffineDimExpr>(expr).getPosition());
174 for (
auto [
index, iter] : llvm::enumerate(
contract.getIteratorTypes())) {
176 order.push_back(
index);
180 for (
auto [
index, iter] : llvm::enumerate(
contract.getIteratorTypes())) {
182 order.push_back(
index);
190static std::optional<SmallVector<int64_t>>
192 if (
auto contract = dyn_cast<vector::ContractionOp>(op)) {
194 if (contractRank < 3)
197 nativeSize.append({m, n, k});
200 if (
auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
201 int64_t writeRank = writeOp.getVectorType().getRank();
205 nativeSize.append({m, n});
208 if (
auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
211 VectorType sliceType;
213 auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
216 auto vecType = cast<VectorType>(extract.getResult().getType());
217 if (sliceType && sliceType != vecType)
221 return llvm::to_vector(sliceType.getShape());
224 if (
auto vecType = dyn_cast<VectorType>(op->
getResultTypes()[0])) {
227 if (vecType.getRank() < 2)
234 VectorType sliceType;
236 auto extract = dyn_cast<vector::ExtractStridedSliceOp>(users);
239 auto vecType = cast<VectorType>(extract.getResult().getType());
240 if (sliceType && sliceType != vecType)
245 return llvm::to_vector(sliceType.getShape());
250 nativeSize.append({m, n});
257void transform::ApplyUnrollVectorsSubgroupMmaOp::populatePatterns(
260 auto contract = dyn_cast<vector::ContractionOp>(op);
273 vector::populateVectorUnrollPatterns(
275 .setNativeShapeFn(nativeShapeFn)
276 .setUnrollTraversalOrderFn(unrollOrder));
// Empty tag types used to select, at compile time, which kind of mapping a
// templated check is validating. They carry no data or behavior; the templated
// code in this file compares its `MappingKindType` parameter against
// `BlockMappingKind` and `ThreadMappingKind` with `std::is_same`.
// Common base of the mapping-kind tags.
293struct MappingKind {};
// Tag for the 'block' mapping kind.
294struct BlockMappingKind : MappingKind {};
// Tag for the 'thread' or 'warp' mapping kind (see the diagnostic emitted by
// the `ThreadMappingKind` check).
295struct ThreadMappingKind : MappingKind {};
301 if (transformOp.has_value())
302 return transformOp->emitDefiniteFailure() << message;
307template <
typename MappingKindType>
310 scf::ForallOp forallOp) {
311 if (!forallOp.getMapping().has_value()) {
313 "scf.forall op requires a mapping attribute");
316 bool hasBlockMapping = llvm::any_of(forallOp.getMapping().value(),
317 llvm::IsaPred<GPUBlockMappingAttr>);
318 bool hasWarpgroupMapping = llvm::any_of(
319 forallOp.getMapping().value(), llvm::IsaPred<GPUWarpgroupMappingAttr>);
320 bool hasWarpMapping = llvm::any_of(forallOp.getMapping().value(),
321 llvm::IsaPred<GPUWarpMappingAttr>);
322 bool hasThreadMapping = llvm::any_of(forallOp.getMapping().value(),
323 llvm::IsaPred<GPUThreadMappingAttr>);
324 bool hasLaneMapping = llvm::any_of(forallOp.getMapping().value(),
325 llvm::IsaPred<GPULaneMappingAttr>);
327 countMappingTypes += hasBlockMapping ? 1 : 0;
328 countMappingTypes += hasWarpgroupMapping ? 1 : 0;
329 countMappingTypes += hasWarpMapping ? 1 : 0;
330 countMappingTypes += hasThreadMapping ? 1 : 0;
331 countMappingTypes += hasLaneMapping ? 1 : 0;
332 if (countMappingTypes > 1) {
334 transformOp, forallOp,
335 "cannot mix different mapping types, use nesting");
337 if (std::is_same<MappingKindType, BlockMappingKind>::value &&
340 transformOp, forallOp,
341 "scf.forall op requires a mapping attribute of kind 'block'");
343 if (std::is_same<MappingKindType, ThreadMappingKind>::value &&
344 !hasLaneMapping && !hasThreadMapping && !hasWarpMapping &&
345 !hasWarpgroupMapping) {
347 "scf.forall op requires a mapping attribute "
348 "of kind 'thread' or 'warp'");
352 for (
Attribute map : forallOp.getMapping()->getValue()) {
353 if (seen.contains(map)) {
355 transformOp, forallOp,
356 "duplicate attribute, cannot map different loops "
357 "to the same mapping id");
362 auto isLinear = [](DeviceMappingAttrInterface attr) {
363 return attr.isLinearMapping();
365 if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) &&
366 !llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) {
368 transformOp, forallOp,
369 "cannot mix linear and non-linear mapping modes");
372 FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
373 forallOp.getDeviceMaskingAttr();
374 if (succeeded(maybeMaskingAttr) && *maybeMaskingAttr &&
375 !forallOp.usesLinearMapping()) {
377 transformOp, forallOp,
378 "device masking is only available in linear mapping mode");
384template <
typename MappingKindType>
387 scf::ForallOp forallOp) {
395 if (!forallOp.isNormalized())
397 "unsupported non-normalized loops");
398 if (forallOp.getNumResults() > 0)
400 "only bufferized scf.forall can be mapped");
401 bool useLinearMapping = forallOp.usesLinearMapping();
404 int64_t maxNumMappingsSupported =
405 useLinearMapping ? (getMaxEnumValForMappingId() -
406 static_cast<uint64_t
>(MappingId::DimZ))
408 if (forallOp.getRank() > maxNumMappingsSupported) {
410 "scf.forall with rank > ")
411 << maxNumMappingsSupported
412 <<
" does not lower for the specified mapping attribute type";
414 auto numParallelIterations =
416 if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
418 transformOp, forallOp,
419 "requires statically sized, normalized forall op");
431template <
typename OpTy,
typename OperationOrBlock>
436 parent->walk([&](OpTy idOp) {
437 if (availableMappingSizes[
static_cast<int64_t>(idOp.getDimension())] == 1)
443 RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
446 LDBG() <<
"--start rewriteOneForallCommonImpl";
449 auto numParallelIterations =
451 assert(forallOp.isNormalized() && numParallelIterations.has_value() &&
452 "requires statically sized, normalized forall op");
455 forallOp.getDeviceMappingAttrs();
457 forallMappingAttrs.insert_range(forallMappingAttrsVec);
459 return cast<DeviceMappingAttrInterface>(a).getMappingId() <
460 cast<DeviceMappingAttrInterface>(
b).getMappingId();
466 DeviceMappingAttrInterface maxMapping = cast<DeviceMappingAttrInterface>(
467 *llvm::max_element(forallMappingAttrs, comparator));
468 DeviceMappingAttrInterface maxLinearMapping;
469 if (maxMapping.isLinearMapping())
470 maxLinearMapping = maxMapping;
473 if (maxLinearMapping && comparator(maxLinearMapping, attr))
476 if (!forallMappingAttrs.insert(attr))
479 tmpMappingSizes.push_back(1);
481 LDBG() <<
"----tmpMappingSizes extracted from scf.forall op: "
482 << llvm::interleaved(tmpMappingSizes);
486 forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator);
487 LDBG() <<
"----forallMappingSizes: " << llvm::interleaved(forallMappingSizes);
488 LDBG() <<
"----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs);
495 bool originalBasisWasProvided = !originalBasis.empty();
496 if (!originalBasisWasProvided) {
497 LDBG() <<
"----originalBasis was not provided, deriving it and there will "
500 originalBasis = forallMappingSizes;
501 while (originalBasis.size() < 3)
502 originalBasis.push_back(1);
504 LDBG() <<
"----originalBasis was provided, using it, there will be "
507 LDBG() <<
"------originalBasis: " << llvm::interleaved(originalBasis);
510 gpuIdBuilder.
idBuilder(rewriter, loc, forallMappingSizes, originalBasis);
511 if (!builderResult.
errorMsg.empty())
514 LDBG() << builderResult;
520 for (
auto [iv, dim] : llvm::zip_equal(
521 forallOp.getInductionVars(),
522 forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) {
523 auto mappingAttr = cast<DeviceMappingAttrInterface>(dim);
524 Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()];
525 LDBG() <<
"----map: " << iv <<
" to " << peIdOp;
533 if (originalBasisWasProvided) {
535 predicate = predicate ? arith::AndIOp::create(rewriter, loc, predicate,
543 rewriter.
eraseOp(forallOp.getTerminator());
548 auto ifOp = scf::IfOp::create(rewriter, loc, predicate,
550 targetBlock = ifOp.thenBlock();
551 insertionPoint = ifOp.thenBlock()->
begin();
555 targetBlock = forallOp->getBlock();
558 Block &sourceBlock = forallOp.getRegion().
front();
563 for (
Value loopIndex : forallOp.getInductionVars()) {
571 LDBG() <<
"----result forallMappingSizes: "
572 << llvm::interleaved(forallMappingSizes);
573 LDBG() <<
"----result mappingIdOps: " << llvm::interleaved(mappingIdOps);
584 RewriterBase &rewriter, TransformOpInterface transformOp,
587 LDBG() <<
"Start mapForallToBlocksImpl";
595 if (!
diag.succeeded())
600 Block *parentBlock = forallOp->getBlock();
612 rewriter, transformOp, forallOp,
613 gridDims, rewriteResult, gpuIdBuilder);
617 if (!
diag.succeeded())
621 if (gridDims.empty()) {
623 while (gridDims.size() < 3)
624 gridDims.push_back(1);
626 assert(gridDims.size() == 3 &&
"Need 3-D gridDims");
638 scf::ForallOp &topLevelForallOp,
639 TransformOpInterface transformOp) {
640 auto walkResult =
target->walk([&](scf::ForallOp forallOp) {
641 if (forallOp->getParentOfType<scf::ForallOp>())
643 if (topLevelForallOp)
646 topLevelForallOp = forallOp;
650 if (walkResult.wasInterrupted() || !topLevelForallOp)
651 return transformOp.emitSilenceableError()
652 <<
"could not find a unique topLevel scf.forall";
659 LaunchOp gpuLaunch = dyn_cast<LaunchOp>(
target);
660 auto transformOp = cast<TransformOpInterface>(getOperation());
662 if (!getGenerateGpuLaunch() && !gpuLaunch) {
664 emitSilenceableError()
665 <<
"Given target is not gpu.launch, set `generate_gpu_launch` "
667 diag.attachNote(
target->getLoc()) <<
"when applied to this payload op";
671 scf::ForallOp topLevelForallOp;
673 target, topLevelForallOp, transformOp);
674 if (!
diag.succeeded()) {
675 diag.attachNote(
target->getLoc()) <<
"when applied to this payload op";
678 assert(topLevelForallOp &&
"expect an scf.forall");
681 if (!getGenerateGpuLaunch() && gridDims.size() != 3)
682 return transformOp.emitDefiniteFailure(
"transform require size-3 mapping");
688 if (getGenerateGpuLaunch()) {
691 if (!
diag.succeeded())
696 rewriter.
eraseOp(topLevelForallOp);
697 topLevelForallOp = cast<scf::ForallOp>(newForallOp);
701 bool useLinearMapping =
false;
702 if (topLevelForallOp.getMapping())
703 useLinearMapping = topLevelForallOp.usesLinearMapping();
705 FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
706 topLevelForallOp.getDeviceMaskingAttr();
707 assert(succeeded(maybeMaskingAttr) &&
"unexpected failed maybeMaskingAttr");
708 assert((!*maybeMaskingAttr || useLinearMapping) &&
709 "masking requires linear mapping");
715 rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
716 if (!
diag.succeeded())
722 cast<TransformOpInterface>(getOperation()), gridDims[0],
723 gridDims[1], gridDims[2]);
729LogicalResult transform::MapForallToBlocks::verify() {
730 if (!getGridDims().empty() && getGridDims().size() != 3) {
731 return emitOpError() <<
"transform requires empty or size-3 grid_dims";
741 std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
743 int factor,
bool useLinearMapping =
false) {
744 if (llvm::any_of(blockOrGridSizes, [](
int64_t i) {
return i <= 0; })) {
746 "block/grid sizes must be strictly positive");
748 if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
750 transformOp, forallOp,
751 Twine(
"3-D mapping: size of threadIdx.x must be a multiple of ") +
755 bool hasZeroParallelIteration =
756 llvm::any_of(numParallelIterations, [](
int64_t i) {
return i == 0; });
760 int64_t requiredResourceCount =
761 hasZeroParallelIteration ? 0
764 if (requiredResourceCount > availableResourceCount) {
766 transformOp, forallOp,
767 Twine(
"the number of required parallel resources (blocks or "
769 Twine(requiredResourceCount) +
770 " overflows the number of available resources " +
771 Twine(availableResourceCount));
781 DeviceMappingAttrInterface mappingAttr =
782 forallOp.getDeviceMappingAttrs().front();
783 bool useLinearMapping = mappingAttr.isLinearMapping();
786 auto numParallelIterations =
788 if (!forallOp.isNormalized() || !numParallelIterations.has_value()) {
790 transformOp, forallOp,
791 "requires statically sized, normalized forall op");
794 if (isa<GPUWarpgroupMappingAttr>(mappingAttr)) {
796 }
else if (isa<GPUWarpMappingAttr>(mappingAttr)) {
801 blockSizes, factor, useLinearMapping);
802 if (!
diag.succeeded())
805 FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
806 forallOp.getDeviceMaskingAttr();
807 assert(succeeded(maybeMaskingAttr) &&
"unexpected failed maybeMaskingAttr");
808 assert((!*maybeMaskingAttr || useLinearMapping) &&
809 "masking requires linear mapping");
815 .Case([&](GPUWarpgroupMappingAttr) {
819 .Case([&](GPUWarpMappingAttr) {
823 .Case([&](GPUThreadMappingAttr) {
826 .Case([&](GPULaneMappingAttr) {
830 .DefaultUnreachable(
"unknown mapping attribute");
835 RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
837 bool syncAfterDistribute) {
845 if (!
diag.succeeded())
853 transformOp, forallOp, blockSizes, warpSize, gpuIdBuilder);
854 if (!
diag.succeeded())
864 rewriter, transformOp, forallOp, blockSizes, rewriteResult, gpuIdBuilder);
865 if (!
diag.succeeded())
868 if (syncAfterDistribute)
869 BarrierOp::create(rewriter, loc);
875 RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
877 bool syncAfterDistribute) {
878 LDBG() <<
"Start mapNestedForallToThreadsImpl";
879 if (blockDims.size() != 3) {
881 "requires size-3 thread mapping");
890 rewriter, transformOp, forallOp, blockDims, warpSize,
891 syncAfterDistribute);
892 if (
diag.isDefiniteFailure())
894 if (
diag.succeeded())
912 LaunchOp gpuLaunch = dyn_cast<LaunchOp>(
target);
913 auto transformOp = cast<TransformOpInterface>(getOperation());
917 return emitSilenceableError() <<
"Given target is not a gpu.launch";
922 checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
923 blockDims[0], blockDims[1], blockDims[2]);
924 if (
diag.isSilenceableFailure()) {
925 diag.attachNote(getLoc()) << getBlockDimsAttrName() <<
" is too large";
932 std::nullopt, std::nullopt, blockDims[0], blockDims[1],
938 getWarpSize(), getSyncAfterDistribute());
940 results.
push_back(gpuLaunch.getOperation());
951class GPUTransformDialectExtension
953 GPUTransformDialectExtension> {
957 GPUTransformDialectExtension() {
958 declareGeneratedDialect<GPUDialect>();
959 declareGeneratedDialect<amdgpu::AMDGPUDialect>();
960 declareGeneratedDialect<arith::ArithDialect>();
961 declareGeneratedDialect<scf::SCFDialect>();
962 registerTransformOps<
964#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
970#define GET_OP_CLASSES
971#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static std::string diag(const llvm::Value &value)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Base type for affine expression.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
OpListType::iterator iterator
OpListType & getOperations()
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
bool succeeded() const
Returns true if this is a success.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
result_type_range getResultTypes()
user_range getUsers()
Returns a range of all users.
unsigned getNumResults()
Return the number of results held by this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
bool wasInterrupted() const
Returns true if the walk was interrupted.
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
void registerTransformDialectExtension(DialectRegistry &registry)
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Include the generated interface declarations.
void populateGpuToROCDLConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, gpu::amd::Runtime runtime, amdgpu::Chipset chipset)
Collect a set of patterns to convert from the GPU dialect to ROCDL.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
void populateGpuRewritePatterns(RewritePatternSet &patterns)
Collect all patterns to rewrite ops within the GPU dialect.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter)
Configure the LLVM type convert to convert types and address spaces from the GPU dialect to NVVM.
void populateGpuSubgroupReduceOpLoweringPattern(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate GpuSubgroupReduce pattern to NVVM.
void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert from the GPU dialect to NVVM.
llvm::TypeSwitch< T, ResultT > TypeSwitch
void populateGpuPromoteShuffleToAMDGPUPatterns(RewritePatternSet &patterns, std::optional< amdgpu::Chipset > maybeChipset)
Tries to promote gpu.shuffles to specialized AMDGPU intrinsics.
std::optional< SmallVector< int64_t > > getConstantIntValues(ArrayRef< OpFoldResult > ofrs)
If all ofrs are constant integers or IntegerAttrs, return the integers.
SmallVector< Value > getValuesSortedByKey(ArrayRef< Attribute > keys, ArrayRef< Value > values, llvm::function_ref< bool(Attribute, Attribute)> compare)
Helper to sort values according to matching keys.
void populateGpuEliminateBarriersPatterns(RewritePatternSet &patterns)
Erase barriers that do not enforce conflicting memory side effects.
void populateGpuWMMAToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
Struct to return the result of the rewrite of a forall operation.
SmallVector< Value > mappingIds
SmallVector< int64_t > mappingSizes
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Options that control the vector unrolling.