#include "llvm/ADT/ArrayRef.h"

#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
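//===----------------------------------------------------------------------===//
// ApplyNVGPUToNVVMConversionPatternsOp
//
// Sets up the type conversions (async tokens, mbarriers, warpgroup
// accumulators, TMA descriptors) on the LLVMTypeConverter and populates the
// NVGPU -> NVVM conversion patterns.
//===----------------------------------------------------------------------===//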
void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 32));
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        Type elemType = type.getFragmented().getElementType();
        int64_t sizeM = type.getFragmented().getDimSize(0);
        int64_t sizeN = type.getFragmented().getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        SmallVector<Type> innerStructBody;
        for (unsigned i = 0; i < numMembers; i++)
          innerStructBody.push_back(elemType);
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            type.getContext(), innerStructBody);

        SmallVector<Type> structBody;
        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
          structBody.push_back(innerStructType);

        auto convertedType =
            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
        return llvmTypeConverter.convertType(convertedType);
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        return LLVM::LLVMPointerType::get(type.getContext());
      });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
void transform::CreateAsyncGroupsOp::getEffects(
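//===----------------------------------------------------------------------===//
// Shared-memory copy pipelining
//
// Helpers that classify global -> shared transfers and drive
// scf::pipelineForLoop for loops containing device async copies.
//===----------------------------------------------------------------------===//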
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
  auto load = dyn_cast<vector::TransferReadOp>(op);
  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());

  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
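// Collect the async copy / async group ops that make up pipeline stage 0,
// together with the gpu.barrier ops guarding them.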
    if (op.getNumRegions() > 0)
    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
                                       unsigned iteration, unsigned depth) {
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
  int numGroupInFlight = 0;
    numGroupInFlight = depth - 1;
    numGroupInFlight = depth - 1 - iteration;
  waitOp.setNumGroups(numGroupInFlight);
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    return visited->getBlock() == forOp.getBody();
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op)) {
      assert(result.succeeded() && "expected a backward slice");
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
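// Predication hook used when the epilogue is not peeled: barriers and
// async-group ops are kept as-is, while nvgpu.device_async_copy is rewritten
// to copy zero elements whenever the predicate is false.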
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  Location loc = asyncCopyOp->getLoc();
  Value dstElements =
      rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
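// Assemble the PipeliningOption (schedule, annotations, optional predication
// of the epilogue) and run scf::pipelineForLoop on the target loop.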
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
                        bool epiloguePeeling) {
    return std::make_tuple(
  if (stage0Ops.empty()) {
    return std::make_tuple(
  unsigned maxDepth = depth;
                           unsigned iteration) {
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
  FailureOr<scf::ForOp> maybePipelined =
  if (succeeded(maybePipelined)) {
    return std::make_tuple(

      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
  if (diag.isDefiniteFailure()) {
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
  return std::move(diag);
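//===----------------------------------------------------------------------===//
// Matmul to mma.sync rewrite
//
// RowColIndexing encodes a (row, col) pair of affine expressions;
// MmaSyncBuilder maps a linalg matmul to an nvgpu.mma.sync op together with
// the per-lane load/store indexings.
//===----------------------------------------------------------------------===//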
  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;

      : b(b), loc(loc), laneId(laneId) {}
  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

  std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;

                                    const IndexCalculator &indexFn);
                                    IndexCalculator indexFn,
                                    const IndexCalculator &indexFn);
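// Walk every scalar element of an n-D vector value: applyFn produces a value
// for each delinearized position and reduceFn consumes it.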
template <typename ApplyFn, typename ReduceFn>
  VectorType vectorType = cast<VectorType>(vector.getType());
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    reduceFn(applyFn(vector, idx, indices), idx, indices);
                                               const IndexCalculator &indexFn) {
  for (auto indexing : indexings) {

Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
        return loads[linearIdx];
        res = b.create<vector::InsertOp>(loc, v, res, indices);

    Value memref, const IndexCalculator &indexFn) {
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    res.push_back(store);

        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
        toStore.push_back(v);
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));

  return std::make_tuple(vlhs, vrhs, vres);
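// Select the fragment index calculators and vector shapes for the supported
// mma.sync variants (m16n8k4 tf32 and m16n8k16 f16).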
FailureOr<MmaSyncBuilder::MmaSyncInfo>
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
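// Rewrite handler: only plain linalg.matmul without user-defined indexing maps
// is currently mapped to mma.sync.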
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
       << "unsupported target op: " << linalgOp;
  diag.attachNote(linalgOp->getLoc()) << "target op";
      : rewriter(rewriter), loc(loc) {}

  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
                                gpu::LaunchOp launchOp);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
      rewriter.create<scf::IfOp>(
        sizes.reserve(globalDescriptors.size());
        for (auto [desc, shmem] :
             llvm::zip_equal(globalDescriptors, sharedMemBuffers)) {
          OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
        buildBarrierArriveTx(barrier, sizes);
        rewriter.create<scf::YieldOp>(loc);
        rewriter.create<scf::YieldOp>(loc);
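// Allocate an mbarrier in workgroup memory, initialize it with the expected
// thread count, and synchronize with a gpu.barrier.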
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierInitOp>(
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
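// Create the nvgpu.tma.create.descriptor op (on the host side, next to the
// gpu.launch) describing the global memref to transfer.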
                                             gpu::LaunchOp launchOp) {
  Value unrankedMemRef = rewriter.create<memref::CastOp>(
                                       memref.getType().getMemorySpace()),
  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
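// Issue the TMA async load from the descriptor into shared memory and return
// the transfer size in bytes that the mbarrier should expect.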
                                SmallVectorImpl<Operation *> &loadOps) {
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
  loadOps.push_back(loadOp);
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
      prodExprInBytes, mixedSizes);
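// Arrive on the mbarrier with the total expected transaction size, then spin
// on mbarrier.try_wait.parity until the TMA loads have completed.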
                                       ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,

  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
                                                  ticksBeforeRetry, zero);
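// For each linalg.copy: build a TMA descriptor for the global source, record
// the shared-memory destination, then issue the loads predicated on thread 0
// and wait on the barrier.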
  if (copyOps.empty())
  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});
      buildAndInitBarrierInSharedMemory(numThreads);
    auto copyOp = cast<linalg::CopyOp>(op);
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
  buildTryWaitParity(barrier);
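// Transform op entry point: all payload ops must be linalg.copy ops nested
// under a single common gpu.launch, otherwise a silenceable error is emitted.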
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
            !isa<linalg::CopyOp>(op);
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"