#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"

#include "llvm/ADT/ArrayRef.h"

using namespace mlir;
#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
//===----------------------------------------------------------------------===//
// ApplyNVGPUToNVVMConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  // Map gpu.address_space enum values onto NVVM's numeric address spaces.
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
      });
  // Device-side async tokens cannot be materialized in NVVM; lower them to a
  // dummy i32 so they can be dropped during conversion.
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 32));
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        Type elemType = type.getFragmented().getElementType();
        int64_t sizeM = type.getFragmented().getDimSize(0);
        int64_t sizeN = type.getFragmented().getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        SmallVector<Type> innerStructBody;
        for (unsigned i = 0; i < numMembers; i++)
          innerStructBody.push_back(elemType);
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            type.getContext(), innerStructBody);

        // One inner struct per wgmma tile along the M dimension.
        SmallVector<Type> structBody;
        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
          structBody.push_back(innerStructType);

        auto convertedType =
            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
        return llvmTypeConverter.convertType(convertedType);
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        nvgpu::getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        return LLVM::LLVMPointerType::get(type.getContext());
      });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
//===----------------------------------------------------------------------===//
// CreateAsyncGroupsOp
//===----------------------------------------------------------------------===//

void transform::CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTarget(), effects);
  transform::producesHandle(getResult(), effects);
  transform::modifiesPayload(effects);
}
//===----------------------------------------------------------------------===//
// PipelineSharedMemoryCopiesOp
//===----------------------------------------------------------------------===//

/// Returns true if the given type lives in the default (global) memory space.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
}

/// Returns true if the given type lives in the shared (workgroup) memory
/// space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

/// Returns the value produced by a load from the default memory space, or
/// null if the operation is not such a load.
static Value getValueLoadedFromGlobal(Operation *op) {
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return nullptr;

  auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
  if (!loadType || !hasDefaultMemorySpace(loadType))
    return nullptr;
  return load;
}

/// Returns true if the operation is storing the given value into shared
/// memory.
static bool isStoreToShared(Operation *op, Value v) {
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;

  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
  return storeType && hasSharedMemorySpace(storeType);
}
/// Collects the ops that form stage 0 of the pipelined loop: copies from
/// global to shared memory and the barriers guarding asynchronous copies.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {
  llvm::SmallPtrSet<Operation *, 4> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Defer barriers until a copy that needs them is found.
    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    // Async copies and their group-creation ops go to stage 0, together with
    // the barriers collected so far.
    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }

    // A synchronous copy is a transfer_read from global memory whose single
    // use is a transfer_write into shared memory; both ops go to stage 0.
    Value loaded = getValueLoadedFromGlobal(&op);
    if (loaded && loaded.hasOneUse() &&
        isStoreToShared(*loaded.getUsers().begin(), loaded)) {
      ops.insert(&op);
      ops.insert(*loaded.getUsers().begin());
    }
  }
  return success();
}
/// Hook for the loop pipeliner that sets the "number of groups in flight"
/// attribute of async wait operations corresponding to pipelined shared
/// memory copies.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  // Only adjust wait ops that do not already carry a group count.
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // Epilogue: the schedule drains one more group per unrolled iteration.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
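// Illustrative sketch (not part of the upstream file): the group counts the
// hook above assigns for a hypothetical pipeline depth of 3.
namespace async_groups_in_flight_example {
constexpr unsigned depth = 3;
// Kernel (and prologue) waits keep depth - 1 groups in flight.
static_assert(depth - 1 == 2, "kernel: wait with 2 groups in flight");
// Each epilogue iteration drains one more group.
static_assert(depth - 1 - 0 == 2, "epilogue iteration 0: 2 groups in flight");
static_assert(depth - 1 - 1 == 1, "epilogue iteration 1: 1 group in flight");
} // namespace async_groups_in_flight_example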
/// Hook for the loop pipeliner that assigns a pipeline stage to each operation
/// of the loop body: the stage-0 ops and their backward slice run at stage 0,
/// everything else at stage `depth`, while preserving the original order.
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op))
      getBackwardSlice(&op, &dependencies, options);
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}
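// Illustrative sketch (not part of the upstream file): the stage each op ends
// up with in the schedule built above. All values here are hypothetical.
namespace pipeline_stage_assignment_example {
constexpr unsigned stageOf(bool inStage0Slice, unsigned depth) {
  return inStage0Slice ? 0u : depth;
}
static_assert(stageOf(/*inStage0Slice=*/true, /*depth=*/3) == 0,
              "copies and their backward slice run at stage 0");
static_assert(stageOf(/*inStage0Slice=*/false, /*depth=*/3) == 3,
              "all remaining ops run at stage `depth`");
} // namespace pipeline_stage_assignment_example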
/// Hook for the loop pipeliner: replaces `op` with a predicated version and
/// returns the resulting operation; returns the original op if no predication
/// is necessary and null if the op cannot be predicated.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  // Memory-effect-free ops, barriers and async group/wait ops may safely run
  // more times than the original trip count, so they need no predication.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated.
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Generate: srcElements = predicate ? originalSrcElements : 0, so that a
  // disabled iteration degenerates into a zero-fill of the destination.
  Location loc = asyncCopyOp->getLoc();
  Value dstElements =
      rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}
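// Illustrative sketch (not part of the upstream file): the source element
// count the predicated copy reads; a false predicate switches the copy into a
// pure zero-fill of the destination.
namespace predicated_async_copy_example {
constexpr int srcElements(bool predicate, int originalSrcElements) {
  return predicate ? originalSrcElements : 0;
}
static_assert(srcElements(true, 8) == 8, "active iteration copies 8 elements");
static_assert(srcElements(false, 8) == 0,
              "disabled iteration reads nothing and zero-fills 8 elements");
} // namespace predicated_async_copy_example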
/// Applies loop pipelining with the given depth to the given loop so that
/// copies into shared memory are pipelined. Returns the error state and the
/// pipelined op; the latter is null in case of any failure.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp->getLoc(),
                               "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp->getLoc(), "no shared memory copy"),
        scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      scf::pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  if (modifiedIR) {
    return std::make_tuple(DiagnosedSilenceableFailure::definiteFailure(),
                           scf::ForOp());
  }
  return std::make_tuple(
      emitSilenceableFailure(forOp->getLoc(),
                             "pipelining preconditions are not met"),
      scf::ForOp());
}
DiagnosedSilenceableFailure transform::PipelineSharedMemoryCopiesOp::applyToOne(
    transform::TransformRewriter &rewriter, scf::ForOp forOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto definiteDiag =
        emitDefiniteFailure(getLoc(), "irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      definiteDiag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      definiteDiag.attachNote(getLoc())
          << "try setting " << getPeelEpilogueAttrName();
    }
    return definiteDiag;
  }

  return std::move(diag);
}
//===----------------------------------------------------------------------===//
// RewriteMatmulAsMmaSyncOp
//===----------------------------------------------------------------------===//

/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  RowColIndexing(AffineExpr row, AffineExpr col)
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  AffineExpr row() const { return first; }
  AffineExpr col() const { return second; }

  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct that provides a simple mapping from matmul operations to the
/// corresponding mma.sync operation.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
               SmallVector<int64_t>>
        vectorShapes;
    SmallVector<int64_t> mmaShape;
    bool tf32Enabled;
  };

  /// Return the index calculators for the given `opShape` and element types,
  /// or failure if the combination is not supported.
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);

  /// Instruction-specific row/column indexings for each fragment.
  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx);
  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx);
  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx);
  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx);
  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx);
  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx);

  /// Helpers to create the load/store operations; the per-fragment layout is
  /// provided through the IndexCalculator callback.
  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      const IndexCalculator &indexFn);
  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn,
                                      ArrayRef<int64_t> vectorShape);
  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             const IndexCalculator &indexFn);
  Operation *buildMmaSyncMemRefStoreOperand(OpBuilder &b, Location loc,
                                            Value vectorToStore,
                                            OpFoldResult laneId, Value memref,
                                            IndexCalculator indexFn,
                                            ArrayRef<int64_t> vectorShape);

  OpBuilder &b;
  Location loc;
  OpFoldResult laneId;
};
/// Iterates over all individual elements of the given vector value, invoking
/// `applyFn` on each element and `reduceFn` on its result together with the
/// linear and delinearized indices of the element.
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = cast<VectorType>(vector.getType());
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}
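// Illustrative sketch (not part of the upstream file): how the linear index
// maps to (row, col) positions for a hypothetical 2x4 vector, i.e. strides
// {4, 1} as computeStrides would return.
namespace vector_element_delinearize_example {
constexpr int64_t kRowStride = 4; // strides[0] for shape {2, 4}.
constexpr int64_t row(int64_t linearIdx) { return linearIdx / kRowStride; }
constexpr int64_t col(int64_t linearIdx) { return linearIdx % kRowStride; }
static_assert(row(0) == 0 && col(0) == 0, "linear 0 -> (0, 0)");
static_assert(row(5) == 1 && col(5) == 1, "linear 5 -> (1, 1)");
static_assert(row(7) == 1 && col(7) == 3, "linear 7 -> (1, 3)");
} // namespace vector_element_delinearize_example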
SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    res.push_back(b.create<memref::LoadOp>(loc, memref, ValueRange{row, col}));
  }
  return res;
}

Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = b.create<vector::InsertOp>(loc, v, res, indices);
      });
  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}

Operation *MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
}

static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs(lhs.begin(), lhs.end());
  SmallVector<int64_t> vrhs(rhs.begin(), rhs.end());
  SmallVector<int64_t> vres(res.begin(), res.end());
  return std::make_tuple(vlhs, vrhs, vres);
}
FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // TODO: Tablegen this instead of switching by hand.
  Type f16 = Float16Type::get(b.getContext());
  Type f32 = Float32Type::get(b.getContext());
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       llvm::to_vector(opShape),
                       /*tf32Enabled=*/true};
  }
  // f16 accumulation variant.
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       llvm::to_vector(opShape),
                       /*tf32Enabled=*/false};
  }
  return failure();
}
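// Illustrative sketch (not part of the upstream file): per-thread element
// counts behind the fragment shapes above; each of the 32 threads in a warp
// owns an equal slice of the A, B and C tiles.
namespace mma_fragment_size_example {
constexpr int64_t elemsPerThread(int64_t rows, int64_t cols) {
  return rows * cols / 32;
}
// m16n8k4 tf32: 2 A elements, 1 B element, 4 C elements per thread.
static_assert(elemsPerThread(16, 4) == 2, "A fragment: 2 elements");
static_assert(elemsPerThread(4, 8) == 1, "B fragment: 1 element");
static_assert(elemsPerThread(16, 8) == 4, "C fragment: 4 elements");
// m16n8k16 f16: 8 A elements, 4 B elements, 4 C elements per thread.
static_assert(elemsPerThread(16, 16) == 8, "A fragment: 8 elements");
static_assert(elemsPerThread(16, 8) == 4, "B and C fragments: 4 elements");
} // namespace mma_fragment_size_example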
FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}
DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmul ops, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId; for now assume a single warp.
    Value laneId = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Hopper builders.
//===----------------------------------------------------------------------===//

/// Helper to create the base Hopper-specific operations that are reused in
/// various other places.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<nvgpu::MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create a tma descriptor op to initiate the transfer from global to shared
  /// memory. This must be done before the launch op, on the host.
  TypedValue<nvgpu::TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Return the number of bytes that will be transferred.
  OpFoldResult
  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
                    TypedValue<MemRefType> sharedMemref,
                    TypedValue<nvgpu::MBarrierGroupType> barrier,
                    SmallVectorImpl<Operation *> &loadOps);
  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  /// If threadIdx.x == 0, do the TMA request and wait; otherwise just wait.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<nvgpu::MBarrierGroupType> barrier);

  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);

  RewriterBase &rewriter;
  Location loc;
};
SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value cond =
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
  rewriter.create<scf::IfOp>(
      loc, cond,
      /*thenBuilder=*/
      [&](OpBuilder &lb, Location loc) {
        SmallVector<OpFoldResult> sizes;
        sizes.reserve(globalDescriptors.size());
        for (auto [desc, shmem] : llvm::zip_equal(
                 globalDescriptors, sharedMemBuffers)) {
          OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
          sizes.push_back(sz);
        }
        // Only thread 0 issues the TMA requests, but the barrier expects the
        // full transaction size.
        buildBarrierArriveTx(barrier, sizes);
        rewriter.create<scf::YieldOp>(loc);
      },
      /*elseBuilder=*/
      [&](OpBuilder &lb, Location loc) {
        // The other threads only arrive on the barrier, expecting 0 bytes.
        buildBarrierArriveTx(barrier,
                             getAsIndexOpFoldResult(rewriter.getContext(), 0));
        rewriter.create<scf::YieldOp>(loc);
      });
  return loadOps;
}
static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
}

TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
      loc,
      nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierInitOp>(
      loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
      zero, Value());
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
}
TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = rewriter.create<memref::CastOp>(
      loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
      loc,
      nvgpu::TensorMapDescriptorType::get(
          rewriter.getContext(),
          MemRefType::Builder(memref.getType())
              .setMemorySpace(sharedMemorySpace),
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}
OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
      Value(), Value());
  loadOps.push_back(loadOp);

  // Compute the number of bytes this load will bring into shared memory.
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
  return res;
}
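// Illustrative sketch (not part of the upstream file): the byte count the
// expression above folds to for a hypothetical 64x128 f16 shared-memory
// destination.
namespace tma_transaction_bytes_example {
constexpr int64_t numBytes(int64_t d0, int64_t d1, int64_t elementBitWidth) {
  return d0 * d1 * (elementBitWidth / 8);
}
static_assert(numBytes(64, 128, /*f16=*/16) == 16384,
              "one 64x128 f16 tile expects a 16 KiB transaction");
} // namespace tma_transaction_bytes_example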
void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
                                                   Value());
}

void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Type i1 = rewriter.getI1Type();
  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
  // 10M is an arbitrary, but large enough, number of cycles before retry.
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
                                                  ticksBeforeRetry, zero);
}
//===----------------------------------------------------------------------===//
// RewriteCopyAsTmaOp
//===----------------------------------------------------------------------===//

/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  MLIRContext *ctx = rewriter.getContext();
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // 1. Init a barrier object in shared memory, sized for all block threads.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = bx * by * bz;
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});

  TypedValue<nvgpu::MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  SmallVector<TypedValue<MemRefType>> shmems;
  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    // 2. Build a global memory descriptor on the host.
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    // 3. Remember the shared memory destination.
    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  // 4. Load from global to shared memory using tma, on thread 0 only.
  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);

  // 5. Spin-wait until the data has arrived.
  buildTryWaitParity(barrier);

  // 6. Erase the now-rewritten copies.
  for (Operation *op : copyOps)
    rewriter.eraseOp(op);

  return results;
}
DiagnosedSilenceableFailure
transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp = nullptr, *failingOp = nullptr;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        // All payload ops must be linalg.copy nested under the same launch.
        bool fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  // TODO: more robust detection of copy, with transposes etc.
  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));

  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
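// Illustrative sketch (not part of the upstream file): how a downstream tool
// could make these transform ops available to the transform interpreter,
// using the registration entry point defined above. `registerMyDialects` is a
// hypothetical name.
static void registerMyDialects(DialectRegistry &registry) {
  // Attaches the NVGPU transform extension (and its generated dialects) to
  // every MLIRContext built from this registry.
  nvgpu::registerTransformDialectExtension(registry);
}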