#include "llvm/ADT/ArrayRef.h"
// ...

#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    /* ... */) {
  // ...
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(/* ... */);
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(/* ... */);
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        Type elemType = type.getFragmented().getElementType();
        int64_t sizeM = type.getFragmented().getDimSize(0);
        int64_t sizeN = type.getFragmented().getDimSize(1);
        // ...
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");
        // ...
        for (unsigned i = 0; i < numMembers; i++)
          innerStructBody.push_back(elemType);
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            type.getContext(), innerStructBody);
        // ...
          structBody.push_back(innerStructType);
        // ...
        return llvmTypeConverter.convertType(convertedType);
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(/* ... */);
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(/* ... */);
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        // ...
      });
  // ...
}

// ...
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  // ...
}
void transform::CreateAsyncGroupsOp::getEffects(
    /* ... */
// ...

// Helpers that recognize global -> shared memory transfers and collect the
// stage-0 operations (the copies plus the barriers guarding them).
// ...
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
// ...
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
// ...
  auto load = dyn_cast<vector::TransferReadOp>(op);
// ...
  auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
// ...
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
// ...
  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
// ...
    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
// ...
    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
// ...
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
// ...
// Pipeliner hook that decides how many async copy groups may remain in flight
// at each nvgpu.device_async_wait (function signature elided).
                           unsigned iteration, unsigned depth) {
// ...
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;
// ...
  int numGroupInFlight = 0;
// ...
    // Steady state: keep depth - 1 groups in flight.
    numGroupInFlight = depth - 1;
// ...
    // Peeled epilogue: drain one more group per iteration.
    numGroupInFlight = depth - 1 - iteration;
// ...
  waitOp.setNumGroups(numGroupInFlight);
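//===----------------------------------------------------------------------===//
// Illustrative sketch (not from the original pass): the groups-in-flight
// arithmetic above as a standalone function, assuming the first branch runs
// for the prologue/kernel parts and the second for the peeled epilogue.
//===----------------------------------------------------------------------===//
#include <algorithm>

enum class PipelinePart { Prologue, Kernel, Epilogue };

// With pipelining depth `depth`, the steady state keeps depth - 1 copy groups
// in flight; each epilogue iteration then waits for one more group to drain.
inline int asyncGroupsInFlight(PipelinePart part, unsigned iteration,
                               unsigned depth) {
  if (part == PipelinePart::Epilogue)
    return std::max(0, static_cast<int>(depth) - 1 - static_cast<int>(iteration));
  return static_cast<int>(depth) - 1;
}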
// Assigns pipeline stages: the ops that the shared-memory copies depend on
// get stage 0, every other op in the loop body gets the last stage
// (signature partially elided).
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
// ...
    return visited->getBlock() == forOp.getBody();
// ...
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op))
// ...
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
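//===----------------------------------------------------------------------===//
// Illustrative sketch (not from the original pass): the stage assignment above
// on plain data. `isStage0Dependency` stands in for membership in the backward
// slice of the shared-memory copies; the real code additionally skips the loop
// terminator (scf.yield).
//===----------------------------------------------------------------------===//
#include <functional>
#include <string>
#include <utility>
#include <vector>

using StagedOps = std::vector<std::pair<std::string, unsigned>>;

// Ops feeding the copies run at stage 0 so the copies can be issued `depth`
// iterations ahead; everything else runs at the last stage.
inline StagedOps assignPipelineStages(
    const std::vector<std::string> &loopBodyOps, unsigned depth,
    const std::function<bool(const std::string &)> &isStage0Dependency) {
  StagedOps staged;
  for (const std::string &op : loopBodyOps)
    if (!isStage0Dependency(op))
      staged.emplace_back(op, depth);
  for (const std::string &op : loopBodyOps)
    if (isStage0Dependency(op))
      staged.emplace_back(op, 0);
  return staged;
}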
// Predication hook for the pipeliner: barriers and async-group ops are
// special-cased, and an async copy is rewritten so that, when the predicate is
// false, it transfers zero source elements (zero-fill instead of an
// out-of-bounds read).
// ...
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
// ...
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
// ...
  Location loc = asyncCopyOp->getLoc();
// ...
  Value dstElements =
      rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      /* ... */
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      /* ... */);
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
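//===----------------------------------------------------------------------===//
// Illustrative sketch (not from the original pass): the effective element
// count used by the predicated copy above.
//===----------------------------------------------------------------------===//
#include <cstdint>

// When the predicate is false, the rewritten nvgpu.device_async_copy still
// executes but copies 0 source elements, so the destination is zero-filled.
inline int64_t predicatedSrcElements(bool predicate, int64_t srcElements) {
  return predicate ? srcElements : 0;
}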
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
// ... (name and leading parameters elided)
                        bool epiloguePeeling) {
// ...
    return std::make_tuple(
// ...
  if (stage0Ops.empty()) {
    return std::make_tuple(
// ...
  }
  unsigned maxDepth = depth;
// ...
                          unsigned iteration) {
// ...
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
// ...
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
// ...
  }
  return std::make_tuple(
// ...
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    // ...
  }
  if (diag.isDefiniteFailure()) {
    // ...
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    // ...
  }
  return std::move(diag);
// RowColIndexing: encodes a pair of row/column indexings as affine
// expressions.
  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
// ...

// MmaSyncBuilder: maps a matmul operation to the corresponding mma.sync,
// including the supporting per-lane loads and stores.
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
// ...
  std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
// ...
                               IndexCalculator indexFn);
// ...
                               IndexCalculator indexFn,
// ...
                               IndexCalculator indexFn);
template <typename ApplyFn, typename ReduceFn>
// ... (signature elided)
  VectorType vectorType = vector.getType().cast<VectorType>();
// ...
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
// ...
    reduceFn(applyFn(vector, idx, indices), idx, indices);
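//===----------------------------------------------------------------------===//
// Illustrative sketch (not from the original pass): the apply/reduce traversal
// above over a row-major 2-D shape, with the delinearization done inline.
// Names and callable shapes here are placeholders for this example.
//===----------------------------------------------------------------------===//
#include <array>
#include <cstdint>

// Visits every linear index of a rows x cols vector, recovers the 2-D
// indices, applies `applyFn`, and folds the result with `reduceFn`.
template <typename ApplyFn, typename ReduceFn>
void forEachElement2D(int64_t rows, int64_t cols, ApplyFn &&applyFn,
                      ReduceFn &&reduceFn) {
  const int64_t stride0 = cols; // row-major stride of the leading dimension
  for (int64_t idx = 0, e = rows * stride0; idx < e; ++idx) {
    std::array<int64_t, 2> indices = {idx / stride0, idx % stride0};
    reduceFn(applyFn(idx, indices), idx, indices);
  }
}
// Example use: sum f(i, j) over a 2x4 shape.
//   double acc = 0;
//   forEachElement2D(2, 4,
//       [](int64_t, std::array<int64_t, 2> ij) { return double(ij[0] + ij[1]); },
//       [&](double v, int64_t, std::array<int64_t, 2>) { acc += v; });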
// buildMemRefLoads: per-lane loads at the positions produced by indexFn.
// ...
                                      IndexCalculator indexFn) {
// ...
  for (auto indexing : indexings) {
// ...

Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    /* ... */) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);
// ...
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
// ...
    return loads[linearIdx];
// ...
    res = b.create<vector::InsertOp>(loc, v, res, indices);
// ...

// buildMemRefStores / buildMmaSyncMemRefStoreOperand: the mirror-image stores.
// ...
                              Value memref, IndexCalculator indexFn) {
// ...
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
// ...
    res.push_back(store);
// ...
    return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
// ...
    toStore.push_back(v);
// ...
  return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);

// ...
  return std::make_tuple(vlhs, vrhs, vres);
// ...

// Select the mma.sync variant matching the problem shape and element types.
// ...
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
// ...
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
// ...
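//===----------------------------------------------------------------------===//
// Illustrative sketch (not from the original pass): the variant selection
// above expressed over plain enums. Only the two cases visible in this file
// (m16n8k4 on f32 computed via tf32, m16n8k16 on f16) are modeled; the f16
// accumulator-type condition is elided in the fragment, so it is left
// unchecked here. Everything else is treated as unsupported.
//===----------------------------------------------------------------------===//
#include <cstdint>
#include <optional>
#include <string>

enum class ElemKind { F16, F32 };

inline std::optional<std::string>
selectMmaSyncVariant(int64_t m, int64_t n, int64_t k, ElemKind lhs,
                     ElemKind rhs, ElemKind res) {
  if (m == 16 && n == 8 && k == 4 && lhs == ElemKind::F32 &&
      rhs == ElemKind::F32 && res == ElemKind::F32)
    return "m16n8k4 (f32 via tf32)";
  if (m == 16 && n == 8 && k == 16 && lhs == ElemKind::F16 &&
      rhs == ElemKind::F16)
    return "m16n8k16 (f16)";
  return std::nullopt; // caller reports a silenceable failure
}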
FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(lhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(rhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(resMemRef.getType().cast<MemRefType>().getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
// ...
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
// ...
  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   /* ... */);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 /* ... */);
// ...
}

// ...
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
// ...
         << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
// ...
// HopperBuilder: base Hopper-specific operations reused below.
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<nvgpu::MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
// ...
  TypedValue<nvgpu::TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);
// ...

// buildPredicateLoadsOnThread0: if threadIdx.x == 0, issue the TMA requests
// and the expect-tx arrive; otherwise only wait (definition header elided).
// ...
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
// ...
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
// ...
      rewriter.create<scf::IfOp>(
// ...
        sizes.reserve(globalDescriptors.size());
        for (auto [desc, shmem] :
             llvm::zip_equal(globalDescriptors, sharedMemBuffers)) {
          OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
// ...
        }
        buildBarrierArriveTx(barrier, sizes);
        rewriter.create<scf::YieldOp>(loc);
// ...
        rewriter.create<scf::YieldOp>(loc);
// ...

// ...
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
// ...

TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
// ...
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
// ...
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierInitOp>(
// ...
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
}

TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
// ...
  Value unrankedMemRef = rewriter.create<memref::CastOp>(
// ...
          memref.getType().getMemorySpace()),
// ...
  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
// ...
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}

OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    /* ... */
    SmallVectorImpl<Operation *> &loadOps) {
// ...
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
// ...
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
      /* ... */);
  loadOps.push_back(loadOp);
// ...
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
// ...
      prodExprInBytes, mixedSizes);
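//===----------------------------------------------------------------------===//
// Illustrative sketch (not from the original pass): the transaction byte count
// that the code above folds as an affine expression, computed on plain
// integers. The mbarrier arrive.expect_tx later expects the sum of these
// values over all buffers loaded behind the same barrier.
//===----------------------------------------------------------------------===//
#include <cstdint>
#include <vector>

// Bytes deposited into shared memory by one TMA load of a buffer with the
// given static sizes and element bit width.
inline int64_t tmaTransactionBytes(const std::vector<int64_t> &sizes,
                                   int64_t elementBitWidth) {
  int64_t numElements = 1;
  for (int64_t s : sizes)
    numElements *= s;
  return numElements * (elementBitWidth / 8);
}
// Example: a 128x64 f16 tile contributes 128 * 64 * 2 = 16384 bytes.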
void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
// ...
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
                                                   /* ... */);
}

void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Value parity = rewriter.create<arith::ConstantIndexOp>(loc, 0);
// ...
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
                                                  ticksBeforeRetry, zero);
}

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
// ...
  if (copyOps.empty())
// ...
  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");
// ...
  // Total number of threads in the block: blockDim.x * blockDim.y * blockDim.z.
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});
// ...
      buildAndInitBarrierInSharedMemory(numThreads);
// ...
    auto copyOp = cast<linalg::CopyOp>(op);
// ...
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");
// ...
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);
// ...
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
// ...
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
// ...
  buildTryWaitParity(barrier);
// ...
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
// ...
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
// ...
               !isa<linalg::CopyOp>(op);
// ...
    emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
// ...
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
// ...

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
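
//===----------------------------------------------------------------------===//
// Illustrative sketch (not from the original file): how a tool could make
// these transform ops available via the registerTransformDialectExtension
// entry point referenced above. The header path and the mlir::nvgpu namespace
// are assumptions of this sketch.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
#include "mlir/IR/DialectRegistry.h"

static void
registerNVGPUTransformExtensionForTool(mlir::DialectRegistry &registry) {
  // Makes the NVGPU transform ops usable by the tool's transform interpreter.
  mlir::nvgpu::registerTransformDialectExtension(registry);
}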