#include "llvm/ADT/ArrayRef.h"

#define DEBUG_TYPE "nvgpu-transforms"
void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
      });
  // Device-side async tokens and mbarrier tokens lower to plain integers.
  llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 32));
  });
  llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  // A warpgroup accumulator lowers to a struct of structs: one inner struct
  // per wgmma tile along M, each holding that thread's fragment elements.
  llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type {
    Type elemType = type.getFragmented().getElementType();
    int64_t sizeM = type.getFragmented().getDimSize(0);
    int64_t sizeN = type.getFragmented().getDimSize(1);

    unsigned numMembers;
    if (elemType.isF32() || elemType.isInteger(32))
      numMembers = sizeN / 2;
    else if (elemType.isF16())
      numMembers = sizeN / 4;
    else
      llvm_unreachable("unsupported type for warpgroup accumulator");

    SmallVector<Type> innerStructBody;
    for (unsigned i = 0; i < numMembers; i++)
      innerStructBody.push_back(elemType);
    auto innerStructType =
        LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

    SmallVector<Type> structBody;
    for (int i = 0; i < sizeM; i += kWgmmaSizeM)
      structBody.push_back(innerStructType);

    auto convertedType =
        LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
    return llvmTypeConverter.convertType(convertedType);
  });
  llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type {
    return LLVM::LLVMPointerType::get(type.getContext());
  });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
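// Illustration only (not part of the pass): the accumulator-to-struct mapping
// above is plain arithmetic. Assuming kWgmmaSizeM is 64, a hypothetical
// 128x64 f32 accumulator yields inner structs of sizeN / 2 == 32 f32 members,
// repeated sizeM / kWgmmaSizeM == 2 times in the outer struct.
#if 0
#include <cassert>
#include <cstdint>

static void sketchAccumulatorLayout() {
  constexpr int64_t kWgmmaSizeMSketch = 64; // assumed value of kWgmmaSizeM
  int64_t sizeM = 128, sizeN = 64;          // hypothetical f32 accumulator
  int64_t numMembers = sizeN / 2;           // f32 / i32 branch above
  int64_t numInnerStructs = 0;
  for (int64_t i = 0; i < sizeM; i += kWgmmaSizeMSketch)
    ++numInnerStructs;
  assert(numMembers == 32 && numInnerStructs == 2);
}
#endif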
LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
void CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // ...
}

// Fragments of the static helpers used by `createAsyncGroups` to classify
// memory accesses.
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
  // ...
  auto load = dyn_cast<vector::TransferReadOp>(op);
  // ...
  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
  // ...
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;
  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
  // Collect pipeline stage-0 ops: barriers are grouped with the async copy /
  // group-creation ops they guard.
  for (Operation &op : *forOp.getBody()) {
    if (op.getNumRegions() > 0)
      return {}; // bail on nested regions
    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }
    if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }
  }
/// Hook for the loop pipeliner: sets the number of async copy groups allowed
/// to remain in flight on each device_async_wait it processes.
static void setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                                       scf::PipeliningOption::PipelinerPart part,
                                       unsigned iteration, unsigned depth) {
  auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // Epilogue: one fewer group remains in flight per peeled iteration.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
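// Illustration only (not from the pass): for a hypothetical pipeline depth of
// 3, the formula above keeps 2 copy groups in flight in the prologue and
// kernel, and 2 - iteration groups in the peeled epilogue.
#if 0
#include <cassert>

static int sketchGroupsInFlight(bool isEpilogue, unsigned iteration,
                                unsigned depth) {
  return isEpilogue ? static_cast<int>(depth - 1 - iteration)
                    : static_cast<int>(depth - 1);
}

static void sketchCheckGroupsInFlight() {
  assert(sketchGroupsInFlight(/*isEpilogue=*/false, /*iteration=*/0,
                              /*depth=*/3) == 2);
  assert(sketchGroupsInFlight(/*isEpilogue=*/true, /*iteration=*/1,
                              /*depth=*/3) == 1);
}
#endif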
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, const llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  // Compute the backward slice of the stage-0 ops within the loop body:
  // everything in that slice stays at stage 0, everything else runs at stage
  // `depth`, preserving the original order.
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op)) {
      LogicalResult result = getBackwardSlice(&op, &dependencies, options);
      assert(result.succeeded() && "expected a backward slice");
      (void)result;
    }
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}
  // Ops without memory effects and barrier-like ops can be executed
  // speculatively and need no predication.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) {
    return op;
  }
  // Otherwise only async copies can be predicated: rewrite the copy so it
  // transfers the original number of elements when the predicate holds and
  // zero elements otherwise.
  auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  Location loc = asyncCopyOp->getLoc();
  Value dstElements = arith::ConstantOp::create(
      rewriter, loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
                                             originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create(
      rewriter, loc, /*...*/,
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      /*...*/);
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  // ... (early bail-out)
  //   return std::make_tuple(/*...*/);
  // ... (collect the stage-0 ops: async copies and the barriers guarding them)
  if (stage0Ops.empty()) {
    return std::make_tuple(/*...*/);
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    // Predicate the pipelined body instead of peeling an epilogue.
    // ...
  }

  // ... (set the insertion point; `modifiedIR` tracks whether IR was changed)
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    // ...
  }
  return std::make_tuple(/*...*/);
}
  // applyToOne of the shared-memory-copy pipelining transform op (fragment).
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    // ... (emit a definite failure)
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    // ...
  }
  return std::move(diag);
}
  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct that maps a matmul op to the corresponding mma.sync op
/// (fragments of the class declaration).
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    // ...
  };

  // ... (per-instruction row/col indexing builders and memref load/store
  //      helpers, each parameterized by an IndexCalculator)
};
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = cast<VectorType>(vector.getType());
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}
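// Illustration only (not from the pass): for a 2x2 vector, computeStrides
// returns {2, 1}, so the loop above visits linear indices 0..3, which
// delinearize to (0,0), (0,1), (1,0), (1,1).
#if 0
#include <array>
#include <cassert>
#include <cstdint>

static void sketchDelinearize() {
  const std::array<int64_t, 2> shape = {2, 2};
  const std::array<int64_t, 2> strides = {shape[1], 1}; // row-major strides
  for (int64_t idx = 0, e = shape[0] * strides[0]; idx < e; ++idx) {
    int64_t row = idx / strides[0];
    int64_t col = idx % strides[0];
    assert(row >= 0 && row < shape[0] && col >= 0 && col < shape[1]);
  }
}
#endif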
SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    const IndexCalculator &indexFn) {
  // ...
  for (auto indexing : indexings) {
    // ... (compute `row` and `col` from the affine indexing and `laneId`)
    auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});
    // ...
  }
  // ...
}
Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
  // Broadcast the first loaded scalar into a vector of `vectorShape`, then
  // insert every loaded scalar at its delinearized position.
  // ...
  Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/[&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/[&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = vector::InsertOp::create(b, loc, v, res, indices);
      });
  return res;
}
SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  // ...
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    // ... (compute `row` and `col` from the affine indexing and `laneId`)
    Operation *store =
        memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/[&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return vector::ExtractOp::create(b, loc, vectorToStore, indices);
      },
      /*reduceFn=*/[&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
}
  return std::make_tuple(vlhs, vrhs, vres);
}

FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // ... (f16/f32 element types)
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       /*...*/};
  }
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       /*...*/};
  }
  return failure();
}
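// Illustration only (not from the pass): an IndexCalculator encodes, per lane,
// the (row, col) matrix coordinates of the elements that lane owns. For the
// m16n8k4 tf32 LHS fragment, the PTX-documented layout is groupID = lane / 4,
// threadIDInGroup = lane % 4, with the two elements at rows groupID and
// groupID + 8, both in column threadIDInGroup. A plain-C++ sketch:
#if 0
#include <cassert>
#include <utility>
#include <vector>

static std::vector<std::pair<int, int>> sketchM16N8K4Tf32Lhs(int laneId) {
  int groupId = laneId / 4;
  int threadIdInGroup = laneId % 4;
  return {{groupId, threadIdInGroup}, {groupId + 8, threadIdInGroup}};
}

static void sketchCheckLane5() {
  auto rc = sketchM16N8K4Tf32Lhs(/*laneId=*/5);
  assert(rc[0] == std::make_pair(1, 1) && rc[1] == std::make_pair(9, 1));
}
#endif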
FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res =
      MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}
  // applyToOne of the matmul-to-mma.sync rewrite (fragment).
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    // Reject matmuls with extended (user-defined indexing map) semantics.
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    }
    Location loc = linalgOp.getLoc();
    // Assume a single warp: the lane id is threadIdx.x.
    Value laneId = gpu::ThreadIdOp::create(
        rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp))) {
      // ... (erase the matmul and report success)
    }
  }
  // Otherwise report the unsupported target.
  DiagnosedSilenceableFailure diag = emitSilenceableError()
                                     << "unsupported target op: " << linalgOp;
  diag.attachNote(linalgOp->getLoc()) << "target op";
  return diag;
}
/// Helper to create the base Hopper-specific operations reused below
/// (fragments of the class declaration).
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}
  TypedValue<MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
  TypedValue<TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);
  // ...
};
SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<MBarrierGroupType> barrier) {
  // ...
  Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
  Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                     tidx, zero);
  // threadIdx.x == 0 issues the TMA loads and arrives with the expected
  // transaction count; all other threads only arrive on the barrier.
  scf::IfOp::create(
      rewriter, loc, cond,
      /*thenBuilder=*/[&](OpBuilder &lb, Location loc) {
        SmallVector<OpFoldResult> sizes;
        sizes.reserve(globalDescriptors.size());
        for (auto [desc, shmem] :
             llvm::zip_equal(globalDescriptors, sharedMemBuffers)) {
          OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
          sizes.push_back(sz);
        }
        buildBarrierArriveTx(barrier, sizes);
        scf::YieldOp::create(rewriter, loc);
      },
      /*elseBuilder=*/[&](OpBuilder &lb, Location loc) {
        // ...
        scf::YieldOp::create(rewriter, loc);
      });
  return loadOps;
}
static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
}

TypedValue<MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = MBarrierCreateOp::create(
      rewriter, loc, /*...*/);
  // ...
  nvgpu::MBarrierInitOp::create(
      rewriter, loc, barrier,
      getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), /*...*/);
  gpu::BarrierOp::create(rewriter, loc);
  return cast<TypedValue<MBarrierGroupType>>(barrier);
}
TypedValue<TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  // The descriptor must be created on the host, before the launch op.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = memref::CastOp::create(
      rewriter, loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  // ... (materialize the memref sizes as values)
  Value desc = TmaCreateDescriptorOp::create(
      rewriter, loc,
      TensorMapDescriptorType::get(/*...*/,
                                   TensorMapSwizzleKind::SWIZZLE_NONE,
                                   TensorMapL2PromoKind::L2PROMO_NONE,
                                   TensorMapOOBKind::OOB_ZERO,
                                   TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<TensorMapDescriptorType>>(desc);
}
OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref, TypedValue<MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  // ...
  Operation *loadOp =
      TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc,
                             /*...*/);
  loadOps.push_back(loadOp);
  // Return the number of bytes this load transfers: the product of the shared
  // memref sizes times the element size in bytes.
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  // ...
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  return affine::makeComposedFoldedAffineApply(rewriter, loc, prodExprInBytes,
                                               mixedSizes);
}
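// Illustration only (not from the pass): the transaction-count expression
// above is just product(shared memref sizes) * element-size-in-bytes. For a
// hypothetical 64x32xf16 shared buffer that is 64 * 32 * 2 = 4096 bytes.
#if 0
#include <cassert>
#include <cstdint>
#include <vector>

static int64_t sketchTmaTransactionBytes(const std::vector<int64_t> &sizes,
                                         int64_t elementBitWidth) {
  int64_t numElements = 1;
  for (int64_t s : sizes)
    numElements *= s;
  return numElements * (elementBitWidth / 8);
}

static void sketchCheckBytes() {
  assert(sketchTmaTransactionBytes({64, 32}, /*elementBitWidth=*/16) == 4096);
}
#endif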
void HopperBuilder::buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
                                         ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  // ... (fold `mixedSizes` into a single transaction count `sizeVal`)
  nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
                                          /*...*/);
}

void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) {
  Type i1 = rewriter.getI1Type();
  Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
  // Spin on the mbarrier with a large, arbitrary number of ticks before retry.
  Value ticksBeforeRetry =
      arith::ConstantIndexOp::create(rewriter, loc, /*...*/);
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
                                         ticksBeforeRetry, zero);
}
SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // 1. Initialize a barrier in shared memory, expecting
  //    blockDim.x * blockDim.y * blockDim.z arriving threads.
  // ...
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});
  TypedValue<MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  // 2. For each copy, build the TMA descriptor on the host and record the
  //    shared memory destination.
  SmallVector<TypedValue<MemRefType>> shmems;
  SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");
    TypedValue<TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);
    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  // 3. Issue the TMA loads on thread 0 and spin on the barrier until done.
  // ...
  buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
  buildTryWaitParity(barrier);
  // ...
}
  // apply() of the copy-to-TMA rewrite: all targets must be linalg::CopyOp ops
  // nested under one common gpu.launch (fragment).
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp, *failingOp;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        bool fail = op->getParentOfType<gpu::LaunchOp>() != commonLaunchOp ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }
  // ...
namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

// ... (nvgpu::registerTransformDialectExtension adds this extension to a
//      DialectRegistry)

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
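// Illustration only (not part of this file): a client typically hooks the
// extension into its dialect registry before creating an MLIRContext, along
// the lines of the following sketch.
#if 0
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

static void sketchRegisterNVGPUTransformOps() {
  mlir::DialectRegistry registry;
  mlir::nvgpu::registerTransformDialectExtension(registry);
  mlir::MLIRContext context(registry);
}
#endif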