32 #include "llvm/ADT/ArrayRef.h"
40 #define DEBUG_TYPE "nvgpu-transforms"
46 void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
53 llvmTypeConverter, [](gpu::AddressSpace space) ->
unsigned {
55 case gpu::AddressSpace::Global:
56 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
57 case gpu::AddressSpace::Workgroup:
58 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
59 case gpu::AddressSpace::Private:
62 llvm_unreachable(
"unknown address space enum value");
63 return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
65 llvmTypeConverter.addConversion(
66 [&](nvgpu::DeviceAsyncTokenType type) ->
Type {
67 return llvmTypeConverter.convertType(
70 llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) ->
Type {
71 return llvmTypeConverter.convertType(
74 llvmTypeConverter.addConversion(
75 [&](nvgpu::WarpgroupAccumulatorType type) ->
Type {
76 Type elemType = type.getFragmented().getElementType();
77 int64_t sizeM = type.getFragmented().getDimSize(0);
78 int64_t sizeN = type.getFragmented().getDimSize(1);
82 numMembers = sizeN / 2;
83 else if (elemType.
isF16())
84 numMembers = sizeN / 4;
86 llvm_unreachable(
"unsupported type for warpgroup accumulator");
89 for (
unsigned i = 0; i < numMembers; i++)
90 innerStructBody.push_back(elemType);
91 auto innerStructType = LLVM::LLVMStructType::getLiteral(
92 type.getContext(), innerStructBody);
96 structBody.push_back(innerStructType);
99 LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
100 return llvmTypeConverter.convertType(convertedType);
102 llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) ->
Type {
103 return llvmTypeConverter.convertType(
106 llvmTypeConverter.addConversion(
107 [&](nvgpu::WarpgroupMatrixDescriptorType type) ->
Type {
108 return llvmTypeConverter.convertType(
111 llvmTypeConverter.addConversion(
112 [&](nvgpu::TensorMapDescriptorType type) ->
Type {
119 transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
120 transform::TypeConverterBuilderOpInterface builder) {
121 if (builder.getTypeConverterType() !=
"LLVMTypeConverter")
122 return emitOpError(
"expected LLVMTypeConverter");
130 void transform::CreateAsyncGroupsOp::getEffects(
157 dyn_cast_if_present<gpu::AddressSpaceAttr>(type.
getMemorySpace());
159 space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
166 auto load = dyn_cast<vector::TransferReadOp>(op);
170 auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
179 auto store = dyn_cast<vector::TransferWriteOp>(op);
180 if (!store || store.getVector() != v)
183 auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
213 if (op.getNumRegions() > 0)
216 if (isa<gpu::BarrierOp>(op)) {
217 barriers.insert(&op);
221 if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
223 ops.insert(std::make_move_iterator(barriers.begin()),
224 std::make_move_iterator(barriers.end()));
225 assert(barriers.empty() &&
226 "expected to have moved the barriers into another set");
246 unsigned iteration,
unsigned depth) {
249 auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
250 if (!waitOp || waitOp.getNumGroups())
253 int numGroupInFlight = 0;
256 numGroupInFlight = depth - 1;
263 numGroupInFlight = depth - 1 - iteration;
265 waitOp.setNumGroups(numGroupInFlight);
280 std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
284 return visited->
getBlock() == forOp.getBody();
287 for (
Operation &op : forOp.getBody()->getOperations()) {
288 if (stage0Ops.contains(&op)) {
290 assert(result.succeeded() &&
"expected a backward slice");
295 for (
Operation &op : forOp.getBody()->getOperations()) {
296 if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
297 opsWithPipelineStages.emplace_back(&op, depth);
299 for (
Operation &op : forOp.getBody()->getOperations()) {
300 if (dependencies.contains(&op))
301 opsWithPipelineStages.emplace_back(&op, 0);
315 isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
316 nvgpu::DeviceAsyncWaitOp>(op)) {
321 auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
330 Location loc = asyncCopyOp->getLoc();
331 Value dstElements = arith::ConstantOp::create(
332 rewriter, loc, asyncCopyOp.getDstElementsAttr());
333 Value originalSrcElement =
334 asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
336 auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
337 originalSrcElement, c0Index);
338 auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create(
340 asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
341 asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
343 rewriter.
replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
344 return asyncCopyZeroFillOp;
352 static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
354 bool epiloguePeeling) {
357 return std::make_tuple(
361 if (stage0Ops.empty()) {
362 return std::make_tuple(
367 unsigned maxDepth = depth;
370 unsigned iteration) {
374 [&](scf::ForOp schedulingFor,
375 std::vector<std::pair<Operation *, unsigned>> &ops) {
376 if (schedulingFor != forOp)
380 options.annotateFn = setAnnotation;
381 if (!epiloguePeeling) {
389 FailureOr<scf::ForOp> maybePipelined =
391 if (succeeded(maybePipelined)) {
395 return std::make_tuple(
406 rewriter, forOp,
static_cast<int64_t
>(getDepth()), getPeelEpilogue());
407 if (
diag.succeeded()) {
411 if (
diag.isDefiniteFailure()) {
413 if (!getPeelEpilogue()) {
414 diag.attachNote(forOp->getLoc()) <<
"couldn't predicate?";
415 diag.attachNote(getLoc()) <<
"try setting " << getPeelEpilogueAttrName();
420 return std::move(
diag);
436 void print(llvm::raw_ostream &os)
const {
437 os <<
"- indexing: " << first <<
", " << second;
446 : b(b), loc(loc), laneId(laneId) {}
449 std::function<SmallVector<RowColIndexing>(
MLIRContext *)>;
453 FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
457 std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
607 const IndexCalculator &indexFn);
617 IndexCalculator indexFn,
626 const IndexCalculator &indexFn);
649 template <
typename ApplyFn,
typename ReduceFn>
652 VectorType vectorType = cast<VectorType>(vector.
getType());
655 for (int64_t idx = 0, e =
vectorShape[0] * strides[0]; idx < e; ++idx) {
657 reduceFn(applyFn(vector, idx, indices), idx, indices);
664 const IndexCalculator &indexFn) {
670 for (
auto indexing : indexings) {
673 auto load = memref::LoadOp::create(b, loc, memref,
ValueRange{row, col});
679 Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
682 auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
686 Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
691 return loads[linearIdx];
695 res = vector::InsertOp::create(b, loc, v, res, indices);
703 Value memref,
const IndexCalculator &indexFn) {
708 for (
auto [indexing, val] :
709 llvm::zip_equal(indexFn(b.
getContext()), toStore)) {
713 memref::StoreOp::create(b, loc, val, memref,
ValueRange{row, col});
714 res.push_back(store);
728 return vector::ExtractOp::create(b, loc, vectorToStore, indices);
732 toStore.push_back(v);
734 return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
744 return std::make_tuple(vlhs, vrhs, vres);
747 FailureOr<MmaSyncBuilder::MmaSyncInfo>
754 elementalTypes ==
TypeRange{f32, f32, f32}) {
755 return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
756 &MmaSyncBuilder::m16n8k4tf32Rhs,
757 &MmaSyncBuilder::m16n8k4tf32Res),
766 return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
767 &MmaSyncBuilder::m16n8k16f16Rhs,
768 &MmaSyncBuilder::m16n8k16f16Res),
777 Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
778 Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
779 Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
780 assert(cast<MemRefType>(lhsMemRef.
getType()).getRank() == 2 &&
781 "expected lhs to be a 2D memref");
782 assert(cast<MemRefType>(rhsMemRef.
getType()).getRank() == 2 &&
783 "expected rhs to be a 2D memref");
784 assert(cast<MemRefType>(resMemRef.
getType()).getRank() == 2 &&
785 "expected res to be a 2D memref");
787 int64_t m = cast<MemRefType>(lhsMemRef.
getType()).getShape()[0];
788 int64_t n = cast<MemRefType>(rhsMemRef.
getType()).getShape()[1];
789 int64_t k = cast<MemRefType>(lhsMemRef.
getType()).getShape()[1];
794 FailureOr<MmaSyncInfo> maybeInfo =
795 getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
799 MmaSyncInfo info = *maybeInfo;
800 auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
801 auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
802 Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
803 lhsIndexFn, lhsShape);
804 Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
805 rhsIndexFn, rhsShape);
806 Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
807 resIndexFn, resShape);
808 res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape,
810 buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
821 if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
824 if (linalgOp.hasUserDefinedMaps()) {
825 return emitSilenceableError()
826 <<
"only matmul ops with non-extended semantics are supported";
830 Value laneId = gpu::ThreadIdOp::create(
831 rewriter, loc, rewriter.
getIndexType(), gpu::Dimension::x);
832 if (succeeded(
MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
838 <<
"unsupported target op: " << linalgOp;
839 diag.attachNote(linalgOp->getLoc()) <<
"target op";
855 : rewriter(rewriter), loc(loc) {}
858 buildAndInitBarrierInSharedMemory(
OpFoldResult numThreads);
864 gpu::LaunchOp launchOp);
896 Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
897 Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
900 scf::IfOp::create(rewriter,
906 sizes.reserve(globalDescriptors.size());
907 for (
auto [desc, shmem] : llvm::zip_equal(
908 globalDescriptors, sharedMemBuffers)) {
909 OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
914 buildBarrierArriveTx(barrier, sizes);
915 scf::YieldOp::create(rewriter, loc);
922 scf::YieldOp::create(rewriter, loc);
930 b.
getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
937 Value barrier = nvgpu::MBarrierCreateOp::create(
941 nvgpu::MBarrierInitOp::create(
942 rewriter, loc, barrier,
945 gpu::BarrierOp::create(rewriter, loc);
946 return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
951 gpu::LaunchOp launchOp) {
954 Value unrankedMemRef = memref::CastOp::create(
957 memref.getType().getMemorySpace()),
965 Value desc = nvgpu::TmaCreateDescriptorOp::create(
971 TensorMapSwizzleKind::SWIZZLE_NONE,
972 TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
973 TensorMapInterleaveKind::INTERLEAVE_NONE),
974 unrankedMemRef, sizes);
975 return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
982 SmallVectorImpl<Operation *> &loadOps) {
985 Operation *loadOp = nvgpu::TmaAsyncLoadOp::create(
986 rewriter, loc, sharedMemref, barrier, globalDesc,
ValueRange{zero, zero},
988 loadOps.push_back(loadOp);
994 (sharedMemref.getType().getElementTypeBitWidth() / 8);
996 prodExprInBytes, mixedSizes);
1002 ArrayRef<OpFoldResult> mixedSizes) {
1003 assert(!mixedSizes.empty() &&
"expecte non-empty sizes");
1012 nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
1019 Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
1023 Value ticksBeforeRetry =
1026 nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
1027 ticksBeforeRetry, zero);
1044 if (copyOps.empty())
1047 auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
1048 assert(launchOp &&
"expected launch op");
1057 rewriter, loc, prod,
1058 ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
1059 launchOp.getBlockSizeZ()});
1062 buildAndInitBarrierInSharedMemory(numThreads);
1067 auto copyOp = cast<linalg::CopyOp>(op);
1069 cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
1070 assert(inMemRef.getType().getRank() == 2 &&
1071 "expected in to be a 2D memref");
1075 buildGlobalMemRefDescriptor(inMemRef, launchOp);
1076 globalDescs.push_back(globalDesc);
1080 cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
1081 shmems.push_back(shmem);
1088 buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
1091 buildTryWaitParity(barrier);
1104 auto payloadOps = state.getPayloadOps(getTarget());
1105 gpu::LaunchOp commonLaunchOp;
1107 if (llvm::any_of(payloadOps, [&](
Operation *op) {
1108 if (!commonLaunchOp) {
1114 !isa<linalg::CopyOp>(op);
1120 emitSilenceableError()
1121 <<
"target ops must be linalg::CopyOp nested under a common "
1122 "gpu.LaunchOp to be rewritten because the tma descriptors need to "
1123 "be created on the host.\nBut got: "
1124 << *firstOp <<
"\nand " << *failingOp;
1139 class NVGPUTransformDialectExtension
1141 NVGPUTransformDialectExtension> {
1145 NVGPUTransformDialectExtension() {
1146 declareGeneratedDialect<arith::ArithDialect>();
1147 declareGeneratedDialect<affine::AffineDialect>();
1148 declareGeneratedDialect<nvgpu::NVGPUDialect>();
1149 declareGeneratedDialect<NVVM::NVVMDialect>();
1150 declareGeneratedDialect<vector::VectorDialect>();
1151 registerTransformOps<
1153 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1159 #define GET_OP_CLASSES
1160 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
MLIRContext * getContext() const
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setMemorySpace(Attribute newMemorySpace)
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
MemRefType getMBarrierMemrefType(MLIRContext *context, MBarrierGroupType barrierType)
Return the memref type that can be used to represent an mbarrier object.
void registerTransformDialectExtension(DialectRegistry &registry)
void createAsyncGroups(RewriterBase &rewriter, Operation *op, bool bypassL1)
Convert global->shared vector transfers to async device copies.
FailureOr< ForOp > pipelineForLoop(RewriterBase &rewriter, ForOp forOp, const PipeliningOption &options, bool *modifiedIR=nullptr)
Generate a pipelined version of the scf.for loop based on the schedule given as option.
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e. all the transitive defs of op).
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void populateNVGPUToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter, const MemorySpaceMapping &mapping)
Populates memory space attribute conversion rules for lowering gpu.address_space to integer values.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
int64_t computeSum(ArrayRef< int64_t > basis)
Self-explicit.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
Helper to create the tma operations corresponding to linalg::CopyOp.
SmallVector< Operation * > rewrite(ArrayRef< Operation * > copyOps)
CopyBuilder(RewriterBase &rewriter, Location loc)
Helper to create the base Hopper-specific operations that are reused in various other places.
OpFoldResult buildTmaAsyncLoad(TypedValue< nvgpu::TensorMapDescriptorType > globalDesc, TypedValue< MemRefType > sharedMemref, TypedValue< nvgpu::MBarrierGroupType > barrier, SmallVectorImpl< Operation * > &loadOps)
Build a tma load from global memory to shared memory using barrier to synchronize.
TypedValue< nvgpu::MBarrierGroupType > buildAndInitBarrierInSharedMemory(OpFoldResult numThreads)
void buildTryWaitParity(TypedValue< nvgpu::MBarrierGroupType > barrier)
TypedValue< nvgpu::TensorMapDescriptorType > buildGlobalMemRefDescriptor(TypedValue< MemRefType > memref, gpu::LaunchOp launchOp)
Create tma descriptor op to initiate transfer from global to shared memory.
SmallVector< Operation * > buildPredicateLoadsOnThread0(ArrayRef< TypedValue< nvgpu::TensorMapDescriptorType >> globalDescriptors, ArrayRef< TypedValue< MemRefType >> sharedMemBuffers, TypedValue< nvgpu::MBarrierGroupType > barrier)
If threadIdx.x == 0 does TMA request + wait, else just wait.
void buildBarrierArriveTx(TypedValue< nvgpu::MBarrierGroupType > barrier, ArrayRef< OpFoldResult > sizes)
HopperBuilder(RewriterBase &rewriter, Location loc)
Helper struct to provide a simple mapping from matmul operations to the corresponding mma.sync operation.
std::function< SmallVector< RowColIndexing >(MLIRContext *)> IndexCalculator
MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
FailureOr< Operation * > buildMmaSync(LinalgOp linalgOp)
Create the mma.sync operation corresponding to linalgOp along with all the supporting load/store and ...
Helper struct to encode a pair of row/column indexings in the form of affine expressions.
RowColIndexing(AffineExpr row, AffineExpr col)
void print(llvm::raw_ostream &os) const
Options to dictate how loops should be pipelined.