#include "llvm/ADT/ArrayRef.h"

#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
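/// Registers the type conversions and rewrite patterns that lower NVGPU
/// constructs to NVVM: gpu address spaces become NVVM memory-space integers,
/// and NVGPU-specific types (async tokens, mbarriers, warpgroup accumulators
/// and matrix/tensor-map descriptors) are mapped to the LLVM types they are
/// materialized as on the device.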
void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 32));
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
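  // The warpgroup accumulator is lowered to a struct of structs: one inner
  // struct per kWgmmaSizeM rows of the fragmented 2-D vector, each holding
  // sizeN/2 members for f32/i32 elements or sizeN/4 members for f16 elements.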
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        Type elemType = type.getFragmented().getElementType();
        int64_t sizeM = type.getFragmented().getDimSize(0);
        int64_t sizeN = type.getFragmented().getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        SmallVector<Type> innerStructBody;
        for (unsigned i = 0; i < numMembers; i++)
          innerStructBody.push_back(elemType);
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            type.getContext(), innerStructBody);

        SmallVector<Type> structBody;
        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
          structBody.push_back(innerStructType);

        auto convertedType =
            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
        return llvmTypeConverter.convertType(convertedType);
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        nvgpu::getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        return LLVM::LLVMPointerType::get(type.getContext());
      });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
void transform::CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // ...
}
// ...
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
  // ...
  auto load = dyn_cast<vector::TransferReadOp>(op);
  // ...
  auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
  // ...
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;
  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
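/// Collect the ops that form stage 0 of the software pipeline: device async
/// copies, the group-create ops that commit them, and any gpu.barrier that
/// appears between them.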
    // Bail out on ops with nested regions.
    if (op.getNumRegions() > 0)
      return {};
    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }
    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }
  // ...
static void setAsyncWaitGroupsInFlight(
    OpBuilder &builder, Operation *op,
    scf::PipeliningOption::PipelinerPart part, unsigned iteration,
    unsigned depth) {
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
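/// Hook for the loop pipeliner that assigns a pipeline stage to each op in the
/// loop body: ops in the backward slice of the stage-0 copies get stage 0,
/// everything else (except the terminator) gets stage `depth`.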
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options;
  options.inclusive = true;
  options.filter = [&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  };
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op))
      getBackwardSlice(&op, &dependencies, options);
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}
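/// Hook for the loop pipeliner that predicates ops in the prologue/epilogue.
/// Side-effect-free ops, barriers and async group/wait ops need no
/// predication; an async copy is predicated by selecting zero source elements
/// when the predicate is false, turning it into a zero-fill of the
/// destination.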
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
    return op;
  }

  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  Location loc = asyncCopyOp->getLoc();
  Value dstElements = rewriter.create<arith::ConstantOp>(
      loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}
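/// Applies software pipelining with the requested depth to `forOp` so that
/// copies into shared memory overlap with computation. Returns the resulting
/// error state together with the pipelined loop; the failure is definite when
/// the IR was already modified and silenceable otherwise.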
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  // ...
    return std::make_tuple(/*...*/);
  // ...
  if (stage0Ops.empty()) {
    return std::make_tuple(/*...*/);
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    // ...
  }
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    // ...
  }
  return std::make_tuple(/*...*/);
}
DiagnosedSilenceableFailure transform::PipelineSharedMemoryCopiesOp::applyToOne(
    transform::TransformRewriter &rewriter, scf::ForOp forOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto diag = emitDefiniteFailure(/*...*/);
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return diag;
  }
  return std::move(diag);
}
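/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions; MmaSyncBuilder uses lists of these to describe which
/// (row, col) element each lane loads or stores for a given mma.sync variant.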
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  // ...
  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

  // ...
  std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
  // ...
  SmallVector<Value> buildMemRefLoads(/*...*/,
                                      const IndexCalculator &indexFn);
  Value buildMmaSyncMemRefLoadOperand(/*...*/, IndexCalculator indexFn,
                                      /*...*/);
  SmallVector<Operation *> buildMemRefStores(/*...*/,
                                             const IndexCalculator &indexFn);
  // ...
};
/// Apply `applyFn` to each element of `vector` and combine the results with
/// `reduceFn`, iterating in linearized order.
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = cast<VectorType>(vector.getType());
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}
SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    const IndexCalculator &indexFn) {
  // Load one scalar per (row, col) position produced by indexFn.
  // ...
  for (auto indexing : indexings) {
    // ...
  }
  return res;
}

Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
  // ...
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
  // Insert each loaded scalar into the vector at its delinearized position.
  foreachIndividualVectorElement(
      res,
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = b.create<vector::InsertOp>(loc, v, res, indices);
      });
  return res;
}
SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    // ...
    res.push_back(store);
  }
  return res;
}

// buildMmaSyncMemRefStoreOperand: extract each scalar from the result vector
// and store it back to the memref at the indexFn-provided position.
  // ...
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
      },
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
}

// ...
  return std::make_tuple(vlhs, vrhs, vres);
}
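/// Select the per-operand row/column index calculators and vector shapes for
/// the supported mma.sync variants; currently m16n8k4 with tf32 operands and
/// m16n8k16 with f16 operands.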
FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // ...
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       /*...*/};
  }
  // ...
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       /*...*/};
  // ...
}
FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}
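/// RewriteMatmulAsMmaSyncOp: for each targeted linalg matmul on 2-D memrefs,
/// try to emit the equivalent per-lane loads, nvgpu mma.sync and stores via
/// MmaSyncBuilder; emit a silenceable error for unsupported targets.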
DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool fail = true;
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    }
    // ...
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }
  return DiagnosedSilenceableFailure::success();
}
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<nvgpu::MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  TypedValue<nvgpu::TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);
  // ...
};
SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value cond =
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
  rewriter.create<scf::IfOp>(
      loc, cond,
      /*thenBuilder=*/
      [&](OpBuilder &lb, Location loc) {
        SmallVector<OpFoldResult> sizes;
        sizes.reserve(globalDescriptors.size());
        for (auto [desc, shmem] : llvm::zip_equal(
                 globalDescriptors, sharedMemBuffers)) {
          OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
          sizes.push_back(sz);
        }
        // Signal the expected number of bytes to the barrier.
        buildBarrierArriveTx(barrier, sizes);
        rewriter.create<scf::YieldOp>(loc);
      },
      /*elseBuilder=*/
      [&](OpBuilder &lb, Location loc) {
        // ...
        rewriter.create<scf::YieldOp>(loc);
      });
  return loadOps;
}
static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  auto space = gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
  return cast<Attribute>(space);
}

TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
      loc,
      nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierInitOp>(
      loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
      zero, Value());
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
}
TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  // The descriptor must be created on the host, before the launch op.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = rewriter.create<memref::CastOp>(
      loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  // ...
  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
      loc,
      nvgpu::TensorMapDescriptorType::get(
          /*...*/,
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}
OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
      /*...*/);
  loadOps.push_back(loadOp);
  // Return the transfer size in bytes: product of the shared memref sizes
  // times the element byte width.
  // ...
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
  return res;
}
void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  // The expected transaction count is the sum of all transfer sizes.
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
                                                   Value());
}

void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Type i1 = rewriter.getI1Type();
  Value parity = rewriter.create<LLVM::ConstantOp>(loc, i1, 0);
  // 10M is an arbitrary, not-too-small number of ticks to wait before retry.
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
                                                  ticksBeforeRetry, zero);
}
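/// CopyBuilder (a HopperBuilder) rewrites a set of linalg.copy ops nested in a
/// common gpu.launch into TMA transfers: build and initialize a shared-memory
/// barrier sized by the launch block dimensions, create one tensor-map
/// descriptor per copy on the host, issue the TMA loads on thread 0 together
/// with the expected transaction count, then have all threads wait on the
/// barrier parity.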
SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");
  // ...
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});
  TypedValue<nvgpu::MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
  SmallVector<TypedValue<MemRefType>> shmems;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
  buildTryWaitParity(barrier);
  // ...
  return results;
}
DiagnosedSilenceableFailure transform::RewriteCopyAsTmaOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp = nullptr, *failingOp = nullptr;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
  return DiagnosedSilenceableFailure::success();
}
namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
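// A minimal usage sketch (an assumption, not part of this file): a tool that
// wants these transform ops available registers the extension on its
// DialectRegistry before creating a context, e.g.:
//
//   DialectRegistry registry;
//   nvgpu::registerTransformDialectExtension(registry);
//   MLIRContext context(registry);
//
// The transform ops defined here (matmul-to-mma.sync, copy-to-tma, shared
// memory copy pipelining) can then be used from transform dialect scripts
// interpreted in that context.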