#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"

#include "llvm/ADT/ArrayRef.h"

#define DEBUG_TYPE "nvgpu-transforms"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::transform;
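//===----------------------------------------------------------------------===//
// Transform dialect extension for the NVGPU dialect: conversion-pattern
// descriptors, shared-memory copy pipelining, matmul -> mma.sync rewriting,
// and linalg.copy -> TMA rewriting.
//
// Hedged usage sketch (transform IR). The op names come from this extension;
// the exact handle types and attribute spellings may differ between MLIR
// versions:
//
//   %matmul = transform.structured.match ops{["linalg.matmul"]} in %func
//     : (!transform.any_op) -> !transform.any_op
//   transform.nvgpu.rewrite_matmul_as_mma_sync %matmul
//     : (!transform.any_op) -> ()
//===----------------------------------------------------------------------===//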
void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateGpuMemorySpaceAttributeConversions(
      llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
        switch (space) {
        case gpu::AddressSpace::Global:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
        case gpu::AddressSpace::Workgroup:
          return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
        case gpu::AddressSpace::Private:
          return 0;
        }
        llvm_unreachable("unknown address space enum value");
        return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
      });
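  // NVVM identifies address spaces numerically (0 = generic, 1 = global,
  // 3 = shared), so gpu.address_space attributes on memrefs are rewritten to
  // these integers before the types reach the LLVM/NVVM dialects. The
  // conversions below handle NVGPU-specific types with no direct NVVM
  // counterpart: tokens become plain integers, descriptors become i64 values
  // or opaque pointers.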
  llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 32));
  });
  llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type {
    Type elemType = type.getFragmented().getElementType();
    int64_t sizeM = type.getFragmented().getDimSize(0);
    int64_t sizeN = type.getFragmented().getDimSize(1);

    unsigned numMembers;
    if (elemType.isF32() || elemType.isInteger(32))
      numMembers = sizeN / 2;
    else if (elemType.isF16())
      numMembers = sizeN / 4;
    else
      llvm_unreachable("unsupported type for warpgroup accumulator");

    SmallVector<Type> innerStructBody;
    for (unsigned i = 0; i < numMembers; i++)
      innerStructBody.push_back(elemType);
    auto innerStructType =
        LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

    SmallVector<Type> structBody;
    for (int i = 0; i < sizeM; i += kWgmmaSizeM)
      structBody.push_back(innerStructType);

    auto convertedType =
        LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
    return llvmTypeConverter.convertType(convertedType);
  });
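  // Layout note: the accumulator lowers to a struct of identical inner
  // structs, one per kWgmmaSizeM (64) rows, each inner struct holding the
  // per-thread fragment registers (sizeN / 2 members for 32-bit element
  // types, sizeN / 4 for f16).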
  llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type {
    return LLVM::LLVMPointerType::get(type.getContext());
  });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}

LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
void CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  consumesHandle(getTargetMutable(), effects);
  producesHandle(getOperation()->getOpResults(), effects);
  modifiesPayload(effects);
}
/// Returns true if the given type has the default (global) memory space.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
}

/// Returns true if the given type lives in the shared (workgroup) memory
/// space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}
/// Returns true if the operation is a load from the default memory space.
static bool isLoadFromGlobal(Operation *op) {
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return false;
  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
  return loadType && hasDefaultMemorySpace(loadType);
}

/// Returns true if the operation stores the given value into shared memory.
static bool isStoreToShared(Operation *op, Value v) {
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;
  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
  return storeType && hasSharedMemorySpace(storeType);
}
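// The helpers above classify the memrefs touched by vector transfers so that
// global -> shared copies can be recognized. Roughly, a pair such as
//
//   %v = vector.transfer_read %global[...] : memref<?xf32>, vector<4xf32>
//   vector.transfer_write %v, %shared[...] : vector<4xf32>, memref<?xf32, 3>
//
// is what gets rewritten into nvgpu.device_async_copy plus group create/wait
// ops (hedged sketch, not copied from a test).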
/// Collects the ops that belong to stage 0 of the pipelined loop: async
/// copies into shared memory and the barriers guarding them.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {
  llvm::SmallPtrSet<Operation *, 4> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }
  }
  return success();
}
/// Hook for the loop pipeliner that sets the number of "groups in flight" on
/// async wait ops, depending on which part of the pipeline (prologue, kernel,
/// epilogue) the op ends up in.
static void setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                                       scf::PipeliningOption::PipelinerPart part,
                                       unsigned iteration, unsigned depth) {
  // Only handle wait ops whose group count has not been set yet.
  auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // Epilogue: the number of groups still in flight shrinks with each peeled
    // iteration.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
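// Worked example: with depth = 3, waits in the prologue and kernel keep
// 3 - 1 = 2 copy groups in flight; epilogue iteration i keeps 2 - i, so the
// last peeled iteration waits for all outstanding groups.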
/// Hook for the loop pipeliner that assigns a pipeline stage to each op:
/// stage-0 ops and their backward slice within the loop body get stage 0,
/// everything else (except scf.yield) gets stage `depth`.
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op)) {
      LogicalResult result = getBackwardSlice(&op, &dependencies, options);
      assert(result.succeeded() && "expected a backward slice");
      (void)result;
    }
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}
/// Hook for the loop pipeliner that predicates ops when the epilogue is not
/// peeled. Side-effect free ops and barriers are safe to over-execute; async
/// copies are predicated by forcing their source element count to zero.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) {
    return op;
  }

  auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Generate: srcElements = predicate ? originalSrcElements : 0.
  Location loc = asyncCopyOp->getLoc();
  Value dstElements = arith::ConstantOp::create(
      rewriter, loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
  auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
                                             originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create(
      rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}
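// Predication note: instead of guarding the copy with scf.if, the rewrite
// keeps the nvgpu.device_async_copy but makes it transfer zero source
// elements when the predicate is false, so the destination is zero-filled
// rather than read out of bounds.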
/// Applies loop pipelining with the given depth to `forOp` so that copies
/// into shared memory are prefetched `depth` iterations ahead. Returns the
/// error state and the pipelined loop (null on failure).
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  if (modifiedIR) {
    return std::make_tuple(DiagnosedSilenceableFailure::definiteFailure(),
                           scf::ForOp());
  }
  return std::make_tuple(
      emitSilenceableFailure(forOp, "pipelining preconditions failed"),
      scf::ForOp());
}
DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return std::move(diag);
  }

  return std::move(diag);
}
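// Hedged usage sketch (transform IR); attribute spelling may differ across
// versions:
//
//   %for = transform.structured.match ops{["scf.for"]} in %func
//     : (!transform.any_op) -> !transform.any_op
//   transform.nvgpu.pipeline_shared_memory_copies failures(suppress) %for
//     {depth = 2} : (!transform.any_op) -> !transform.any_op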
/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  RowColIndexing(AffineExpr row, AffineExpr col)
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  AffineExpr row() const { return first; }
  AffineExpr col() const { return second; }

  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation. This is constrained to the case where
/// the matmul matches the mma.sync operation 1-1.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
               SmallVector<int64_t>>
        vectorShapes;
    SmallVector<int64_t> mmaShape;
    bool tf32Enabled;
  };

  /// Indexing for the m16n8k4 tf32 LHS: each lane owns one element in row
  /// groupID and one in row groupID + 8, at column threadIDInGroup
  /// (groupID = laneid / 4, threadIDInGroup = laneid % 4).
  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup},
            RowColIndexing{groupID + 8, threadIDInGroup}};
  }
  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup, groupID}};
  }

  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }
  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}};
  }

  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup * 2 + 0, groupID},
            RowColIndexing{threadIDInGroup * 2 + 1, groupID},
            RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID},
            RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}};
  }

  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }
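  // These per-lane (row, col) tables mirror the operand fragment layouts of
  // the PTX mma.sync instructions (m16n8k4 tf32 and m16n8k16 f16): groupID
  // (laneid / 4) selects the row/column group and threadIDInGroup (laneid % 4)
  // selects the element within the group.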
  /// Build the memref.load ops that fetch this lane's operand elements, as
  /// dictated by `indexFn`.
  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      const IndexCalculator &indexFn);

  /// Assemble the loaded scalars into the vector operand expected by
  /// mma.sync.
  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn,
                                      ArrayRef<int64_t> vectorShape);

  /// Build the memref.store ops that write this lane's result elements back.
  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             const IndexCalculator &indexFn);

  /// Scatter the result vector back to memory element by element.
  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);

  /// Return the index calculators and shapes for the supported mma.sync
  /// configurations, or failure if the (m, n, k) x element-type combination
  /// is not supported.
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);

  OpBuilder &b;
  Location loc;
  OpFoldResult laneId;
};

/// Apply `applyFn` to each scalar element of `vector` (in linearized order)
/// and then `reduceFn` to the produced values, again in linearized order.
template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = cast<VectorType>(vector.getType());
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}
SmallVector<Value>
MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                 OpFoldResult laneId, Value memref,
                                 const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}
Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = vector::InsertOp::create(b, loc, v, res, indices);
      });

  return res;
}
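// The operand vector is materialized by broadcasting the first scalar load
// (to get a value of the right vector type) and then overwriting every
// element with its own load via vector.insert.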
SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}
SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  toStore.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return vector::ExtractOp::create(b, loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
}
static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs(lhs);
  SmallVector<int64_t> vrhs(rhs);
  SmallVector<int64_t> vres(res);
  return std::make_tuple(vlhs, vrhs, vres);
}

FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // TODO: support more shapes and element types.
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/true};
  }
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/false};
  }
  return failure();
}
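// Supported configurations so far: (m, n, k) = (16, 8, 4) with all-f32
// operands (lowered as tf32) and (16, 8, 16) with all-f16 operands. Any other
// shape/type combination makes buildMmaSync fail and leaves the matmul
// untouched.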
FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape,
                          info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}
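// Flow of buildMmaSync: per-lane loads of the lhs/rhs/accumulator fragments,
// a single nvgpu.mma.sync on the assembled vectors, then per-lane stores of
// the result fragment back into the result memref. The lane id is the only
// thread-dependent input; the RowColIndexing tables encode everything else.
// The rewrite assumes the matmul already runs at warp granularity and that
// operands are 2D memrefs of a supported shape/element type.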
DiagnosedSilenceableFailure RewriteMatmulAsMmaSyncOp::applyToOne(
    TransformRewriter &rewriter, LinalgOp linalgOp,
    ApplyToEachResultList &results, TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmul, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    // Do not let matmuls with extended semantics through this transform.
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    }
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId; for now assume a single warp.
    Value laneId = gpu::ThreadIdOp::create(
        rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}
/// Helper that builds the NVGPU TMA and mbarrier ops needed to copy data from
/// global to shared memory on Hopper-class GPUs.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  /// Build and initialize an mbarrier in shared memory expecting `numThreads`
  /// arrivals.
  TypedValue<MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create a tma descriptor op to initiate the transfer from global to
  /// shared memory. This must be done before the launch op, on the host.
  TypedValue<TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Returns the number of bytes that will be transferred.
  OpFoldResult buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc,
                                 TypedValue<MemRefType> sharedMemref,
                                 TypedValue<MBarrierGroupType> barrier,
                                 SmallVectorImpl<Operation *> &loadOps);

  void buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  void buildTryWaitParity(TypedValue<MBarrierGroupType> barrier);

  /// If threadIdx.x == 0, issue the TMA requests and arm the barrier; all
  /// other threads only wait.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<MBarrierGroupType> barrier);

  RewriterBase &rewriter;
  Location loc;
};
SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
  Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                     tidx, zero);
  // Only thread 0 issues the TMA requests and arms the barrier with the total
  // number of bytes expected; the other threads arrive expecting 0 bytes.
  auto ifOp = scf::IfOp::create(rewriter, loc, cond, /*withElseRegion=*/true);
  rewriter.setInsertionPointToStart(ifOp.thenBlock());
  SmallVector<OpFoldResult> sizes;
  sizes.reserve(globalDescriptors.size());
  for (auto [desc, shmem] :
       llvm::zip_equal(globalDescriptors, sharedMemBuffers)) {
    OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
    sizes.push_back(sz);
  }
  buildBarrierArriveTx(barrier, sizes);
  rewriter.setInsertionPointToStart(ifOp.elseBlock());
  buildBarrierArriveTx(barrier,
                       getAsIndexOpFoldResult(rewriter.getContext(), 0));
  rewriter.setInsertionPointAfter(ifOp);
  return loadOps;
}
static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
}

TypedValue<MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = MBarrierCreateOp::create(
      rewriter, loc,
      MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierInitOp::create(
      rewriter, loc, barrier,
      getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
      Value());
  gpu::BarrierOp::create(rewriter, loc);
  return cast<TypedValue<MBarrierGroupType>>(barrier);
}
TypedValue<TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = memref::CastOp::create(
      rewriter, loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value desc = TmaCreateDescriptorOp::create(
      rewriter, loc,
      TensorMapDescriptorType::get(rewriter.getContext(),
                                   MemRefType::Builder(memref.getType())
                                       .setMemorySpace(sharedMemorySpace),
                                   TensorMapSwizzleKind::SWIZZLE_NONE,
                                   TensorMapL2PromoKind::L2PROMO_NONE,
                                   TensorMapOOBKind::OOB_ZERO,
                                   TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<TensorMapDescriptorType>>(desc);
}
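// The tensor map (TMA descriptor) must be created on the host, which is why
// the insertion point is moved right before the gpu.launch: the descriptor is
// defined outside the launch region and then simply used inside it.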
OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Operation *loadOp =
      TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc,
                             ValueRange{zero, zero}, zero, Value(), Value());
  loadOps.push_back(loadOp);

  // Return the number of bytes this load brings into shared memory: the
  // product of the shared buffer sizes times the element width in bytes.
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  return affine::makeComposedFoldedAffineApply(rewriter, loc,
                                               prodExprInBytes, mixedSizes);
}
void HopperBuilder::buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
                                         ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
                                          Value());
}
void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) {
  Value parity =
      arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(false));
  // An arbitrary but large enough number of cycles before retrying the wait.
  Value ticksBeforeRetry =
      arith::ConstantIndexOp::create(rewriter, loc, 10000000);
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
                                         ticksBeforeRetry, zero);
}
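// Synchronization protocol built by the helpers above: a single mbarrier
// tracks the TMA transfers; thread 0 issues the loads and "arrives" with the
// total expected byte count, and every thread then spins in
// mbarrier.try_wait.parity until the barrier phase flips.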
/// Helper to rewrite `linalg.copy` ops into TMA transfers.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  MLIRContext *ctx = rewriter.getContext();
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // 1. Initialize a barrier in shared memory, armed for all block threads.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, bx * by * bz,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});
  TypedValue<MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  // 2. Collect, for each copy, the host-side descriptor and the shared
  // memory destination.
  SmallVector<TypedValue<MemRefType>> shmems;
  SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    TypedValue<TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  // 3. Issue the TMA loads on thread 0 and wait on all threads.
  rewriter.setInsertionPoint(copyOps.front());
  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
  buildTryWaitParity(barrier);

  // 4. The original copies are now dead.
  for (Operation *op : copyOps)
    rewriter.eraseOp(op);

  return results;
}
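// Hedged usage sketch (transform IR); handle types and attributes may differ
// across versions:
//
//   %copies = transform.structured.match ops{["linalg.copy"]} in %func
//     : (!transform.any_op) -> !transform.any_op
//   transform.nvgpu.rewrite_copy_as_tma %copies : (!transform.any_op) -> ()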
DiagnosedSilenceableFailure
RewriteCopyAsTmaOp::apply(TransformRewriter &rewriter,
                          TransformResults &transformResults,
                          TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp = nullptr, *failingOp = nullptr;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  // TODO: more robust detection of copies, with transposes etc.
  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));

  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"

void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}