#include "llvm/ADT/ArrayRef.h"

#define DEBUG_TYPE "nvgpu-transforms"
void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kGlobalMemorySpace);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(
              NVVM::NVVMMemorySpace::kSharedMemorySpace);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
      });
  // Device-side async tokens cannot be materialized in NVVM; convert them to
  // a dummy i32 so they can easily be dropped during conversion.
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 32));
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        Type elemType = type.getFragmented().getElementType();
        int64_t sizeM = type.getFragmented().getDimSize(0);
        int64_t sizeN = type.getFragmented().getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        SmallVector<Type> innerStructBody;
        for (unsigned i = 0; i < numMembers; i++)
          innerStructBody.push_back(elemType);
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            type.getContext(), innerStructBody);

        SmallVector<Type> structBody;
        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
          structBody.push_back(innerStructType);

        auto convertedType =
            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
        return llvmTypeConverter.convertType(convertedType);
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        nvgpu::getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        return LLVM::LLVMPointerType::get(type.getContext());
      });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
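// Worked example for the WarpgroupAccumulatorType conversion above (assuming
// kWgmmaSizeM == 64, the M size of a wgmma.mma_async instruction): a fragment
// of type vector<128x128xf32> gives numMembers = 128 / 2 = 64, so the inner
// struct packs 64 f32 values and the outer struct holds 128 / 64 = 2 such
// inner structs.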
LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
void transform::CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
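// Sketch of the corresponding applyToOne (assumed here; in particular the
// getBypassL1() accessor name is an assumption): it forwards to
// nvgpu::createAsyncGroups, which converts global->shared vector transfers
// into asynchronous device copies.
DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
    transform::TransformRewriter &rewriter, Operation *target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}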
/// Returns true if the given type has the shared (workgroup) memory space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

/// Returns the value loaded by `op` from the default (global) memory space,
/// or null if `op` is not such a load.
static Value getValueLoadedFromGlobal(Operation *op) {
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return nullptr;
  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
  if (!loadType || !hasDefaultMemorySpace(loadType))
    return nullptr;
  return load;
}

/// Returns true if `op` stores the value `v` into shared memory.
static bool isStoreToShared(Operation *op, Value v) {
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;
  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
  return storeType && hasSharedMemorySpace(storeType);
}
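/// Collects the operations that constitute "stage 0" of the pipelined loop:
/// copies from global to shared memory together with the gpu.barrier ops that
/// guard async copies. Fails on loops containing ops with nested regions.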
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {
  llvm::SmallPtrSet<Operation *, 4> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }

    // ... (synchronous global-to-shared copies are also collected here)
  }
  return success();
}
/// Hook for the loop pipeliner that sets the "number of groups in flight" of
/// async wait operations corresponding to pipelined shared memory copies.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  // Only adjust async wait ops whose group count has not been set yet.
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // In the epilogue, fewer and fewer copy groups remain in flight.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
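// Worked example: with depth == 3, waits in the kernel keep
// depth - 1 == 2 async copy groups in flight; in the epilogue, iteration 0
// waits with 2 groups in flight, iteration 1 with 1, and the last iteration
// with 0, fully draining the pipeline.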
/// Hook for the loop pipeliner that assigns a stage to every operation in the
/// loop body: stage-0 ops and their backward slice run at stage 0, everything
/// else runs at stage `depth`. The original operation order is preserved.
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op)) {
      LogicalResult result = getBackwardSlice(&op, &dependencies, options);
      assert(result.succeeded() && "expected a backward slice");
      (void)result;
    }
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}
/// Hook for the loop pipeliner: replaces `op` with a predicated version, or
/// returns `op` itself if it may safely execute speculatively, or null if the
/// op cannot be predicated.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  // Memory-effect-free ops and barriers may run more times than in the
  // original loop, so they need no predication.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated: the copy is
  // kept, but its number of source elements becomes 0 when the predicate is
  // false, so the destination is zero-filled instead of read out of bounds.
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  Location loc = asyncCopyOp->getLoc();
  Value dstElements = arith::ConstantOp::create(
      rewriter, loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
                                             originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create(
      rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}
/// Applies loop pipelining with the given depth to `forOp` so that copies into
/// shared memory are pipelined. Returns the error state and the pipelined loop
/// (null on failure).
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR
          ? DiagnosedSilenceableFailure::definiteFailure()
          : emitSilenceableFailure(forOp, "pipelining precondition failed"),
      scf::ForOp());
}
DiagnosedSilenceableFailure transform::PipelineSharedMemoryCopiesOp::applyToOne(
    transform::TransformRewriter &rewriter, scf::ForOp forOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return diag;
  }

  return std::move(diag);
}
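// When peel_epilogue is not set, the pipeliner relies on
// replaceOpWithPredicatedOp above to predicate the extra iterations instead of
// peeling an epilogue loop; the notes attached to the failure above point the
// user at that trade-off.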
/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  RowColIndexing(AffineExpr row, AffineExpr col)
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  AffineExpr row() const { return first; }
  AffineExpr col() const { return second; }

  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
               SmallVector<int64_t>>
        vectorShapes;
    SmallVector<int64_t> mmaShape;
    bool tf32Enabled;
  };

  /// Select the per-operand index calculators, vector shapes and mma shape
  /// for the given operation shape and element types, if supported.
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);

  // Per-lane (row, col) index calculators for the supported mma.sync variants
  // and the memref load/store helpers that consume them are declared here
  // (their definitions follow below); the full signatures are elided.

  OpBuilder &b;
  Location loc;
  OpFoldResult laneId;
};
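// Illustrative sketch only (not one of the index calculators defined in this
// file): an IndexCalculator builds per-lane (row, col) affine expressions in
// terms of the laneId dimension, which the load/store helpers below fold with
// affine::makeComposedFoldedAffineApply.
static SmallVector<RowColIndexing> exampleIndexCalculator(MLIRContext *ctx) {
  AffineExpr dim = getAffineDimExpr(0, ctx); // laneId
  AffineExpr groupId = dim.floorDiv(4);
  AffineExpr threadIdInGroup = dim % 4;
  return {RowColIndexing{groupId, threadIdInGroup},
          RowColIndexing{groupId + 8, threadIdInGroup}};
}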
/// Helper to visit every scalar element of a vector value: `applyFn` produces
/// a value for each linear index and `reduceFn` consumes it together with the
/// delinearized position.
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = cast<VectorType>(vector.getType());
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}
SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}
Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = vector::InsertOp::create(b, loc, v, res, indices);
      });
  return res;
}
SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto store =
        memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return vector::ExtractOp::create(b, loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
}
/// Pack the per-operand vector shapes into a single tuple.
static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs(lhs);
  SmallVector<int64_t> vrhs(rhs);
  SmallVector<int64_t> vres(res);
  return std::make_tuple(vlhs, vrhs, vres);
}
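// Only two mma.sync variants are matched below: m16n8k4 with tf32 operands and
// m16n8k16 with f16 operands. Any other (m, n, k) shape or element-type
// combination makes getIndexCalculators fail, leaving the matmul untouched.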
FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       // ... (per-operand vector shapes and mma shape elided)
                       /*tf32Enabled=*/true};
  }
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       // ... (per-operand vector shapes and mma shape elided)
                       /*tf32Enabled=*/false};
  }
  return failure();
}
FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape,
                                 info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}
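// Note: buildMmaSync is laneId-driven, i.e. the generated loads, mma.sync and
// stores are the per-thread slice of a warp-wide operation; all 32 lanes of
// the warp are expected to execute the rewritten code.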
DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmul, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    // Do not let matmuls with extended semantics through this transform.
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    }
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId; for now assume a single warp.
    Value laneId = gpu::ThreadIdOp::create(
        rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}
/// Helper to create the base Hopper-specific operations that are reused in
/// various other places.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<nvgpu::MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create a tma descriptor op to initiate the transfer from global to shared
  /// memory. This must be done before the launch op, on the host.
  TypedValue<nvgpu::TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Returns the number of bytes that will be transferred.
  OpFoldResult
  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
                    TypedValue<MemRefType> sharedMemref,
                    TypedValue<nvgpu::MBarrierGroupType> barrier,
                    SmallVectorImpl<Operation *> &loadOps);

  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<nvgpu::MBarrierGroupType> barrier);

  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);

  RewriterBase &rewriter;
  Location loc;
};
SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
  Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                     tidx, zero);
  // If threadIdx.x == 0, issue the TMA loads and arrive on the barrier with
  // the total number of bytes expected; other threads only wait.
  scf::IfOp::create(
      rewriter, loc, cond,
      /*thenBuilder=*/
      [&](OpBuilder &lb, Location loc) {
        SmallVector<OpFoldResult> sizes;
        sizes.reserve(globalDescriptors.size());
        for (auto [desc, shmem] :
             llvm::zip_equal(globalDescriptors, sharedMemBuffers)) {
          OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
          sizes.push_back(sz);
        }
        buildBarrierArriveTx(barrier, sizes);
        scf::YieldOp::create(rewriter, loc);
      },
      /*elseBuilder=*/
      [&](OpBuilder &lb, Location loc) {
        // ... (no TMA request is issued on the other threads)
        scf::YieldOp::create(rewriter, loc);
      });
  return loadOps;
}
TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = gpu::AddressSpaceAttr::get(
      rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
  Value barrier = nvgpu::MBarrierCreateOp::create(
      rewriter, loc,
      nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierInitOp::create(
      rewriter, loc, barrier,
      getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
      Value());
  gpu::BarrierOp::create(rewriter, loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
}
TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = memref::CastOp::create(
      rewriter, loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = gpu::AddressSpaceAttr::get(
      rewriter.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
  Value desc = nvgpu::TmaCreateDescriptorOp::create(
      rewriter, loc,
      nvgpu::TensorMapDescriptorType::get(
          rewriter.getContext(),
          MemRefType::Builder(memref.getType())
              .setMemorySpace(sharedMemorySpace),
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}
OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Operation *loadOp = nvgpu::TmaAsyncLoadOp::create(
      rewriter, loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero},
      zero, Value(), Value());
  loadOps.push_back(loadOp);
  // Compute the number of bytes this load transfers: the product of the
  // shared memref sizes times the element size in bytes.
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
  return res;
}
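// Worked example of the size computation above: for a shared memref of type
// memref<128x64xf16, #gpu.address_space<workgroup>>, the expected transaction
// size is 128 * 64 * (16 / 8) = 16384 bytes.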
void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
                                          Value());
}
void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Type i1 = rewriter.getI1Type();
  Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
  // 10M ticks is an arbitrarily large retry window for the wait.
  Value ticksBeforeRetry =
      arith::ConstantIndexOp::create(rewriter, loc, 10000000);
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
                                         ticksBeforeRetry, zero);
}
/// Helper to create the tma operations corresponding to linalg::CopyOp.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(copyOps.front());

  // Compute the number of threads in the block to initialize the barrier.
  MLIRContext *ctx = rewriter.getContext();
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = bx * by * bz;
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});

  TypedValue<nvgpu::MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  SmallVector<TypedValue<MemRefType>> shmems;
  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    // TMA the copy source from global memory to shared memory.
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
  buildTryWaitParity(barrier);

  // ... (the now-redundant linalg.copy ops are erased)
  return results;
}
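// In summary, CopyBuilder::rewrite (a) creates the TMA descriptors on the
// host, right before the gpu.launch, (b) creates and initializes an mbarrier
// in shared memory, (c) has thread 0 issue the TMA loads and the expect-tx
// arrival, and (d) makes every thread wait on the barrier parity before the
// shared-memory data is used.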
DiagnosedSilenceableFailure
transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp, *failingOp;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  // TODO: more robust detection of copy, with transposes etc.
  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
  return DiagnosedSilenceableFailure::success();
}
namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

void mlir::nvgpu::registerTransformDialectExtension(
    DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
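// A minimal usage sketch (assumed host-side setup, not part of this file) of
// how the extension is typically made available before parsing transform IR:
//
//   DialectRegistry registry;
//   nvgpu::registerTransformDialectExtension(registry);
//   MLIRContext context(registry);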