#include "llvm/ADT/ArrayRef.h"

#define DEBUG_TYPE "nvgpu-transforms"
void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  populateCommonGPUTypeAndAttributeConversions(llvmTypeConverter);
  llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 32));
  });
  llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type {
    Type elemType = type.getFragmented().getElementType();
    int64_t sizeM = type.getFragmented().getDimSize(0);
    int64_t sizeN = type.getFragmented().getDimSize(1);

    unsigned numMembers;
    if (elemType.isF32() || elemType.isInteger(32))
      numMembers = sizeN / 2;
    else if (elemType.isF16())
      numMembers = sizeN / 4;
    else
      llvm_unreachable("unsupported type for warpgroup accumulator");

    SmallVector<Type> innerStructBody;
    for (unsigned i = 0; i < numMembers; i++)
      innerStructBody.push_back(elemType);
    auto innerStructType =
        LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

    SmallVector<Type> structBody;
    for (int i = 0; i < sizeM; i += kWgmmaSizeM)
      structBody.push_back(innerStructType);

    auto convertedType =
        LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
    return llvmTypeConverter.convertType(convertedType);
  });
  llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type {
    return LLVM::LLVMPointerType::get(type.getContext());
  });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
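// Note: NVGPU token-like types have no LLVM counterpart of their own: device
// async tokens become i32, mbarrier tokens and warpgroup matrix descriptors
// become i64, mbarrier groups lower to the memref type returned by
// getMBarrierMemrefType, and tensor map descriptors become bare LLVM pointers.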
LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}
void CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // ...
}

static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

// Helpers recognizing a load from global memory and a store of that loaded
// value into shared memory:
  auto load = dyn_cast<vector::TransferReadOp>(op);
  // ...
  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
  // ...
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;
  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }
static void setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                                       scf::PipeliningOption::PipelinerPart part,
                                       unsigned iteration, unsigned depth) {
  auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
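// Note: for a pipeline depth `depth`, the kernel and prologue wait while
// keeping `depth - 1` copy groups in flight; epilogue iteration `i` waits
// until at most `depth - 1 - i` groups remain, draining the outstanding
// copies one group per peeled iteration.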
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op)) {
      LogicalResult result = getBackwardSlice(&op, &dependencies, options);
      assert(result.succeeded() && "expected a backward slice");
      (void)result;
    }
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}
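// Note: the backward slice (restricted to the loop body) of every stage-0 op
// is also assigned stage 0; everything else except scf.yield is placed at
// stage `depth`, so the uses of the copied data trail the copies by `depth`
// iterations.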
300 isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) {
305 auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op);
314 Location loc = asyncCopyOp->getLoc();
315 Value dstElements = arith::ConstantOp::create(
316 rewriter, loc, asyncCopyOp.getDstElementsAttr());
317 Value originalSrcElement =
318 asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
320 auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
321 originalSrcElement, c0Index);
322 auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create(
323 rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
324 asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
325 asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
327 rewriter.
replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
328 return asyncCopyZeroFillOp;
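// Note: predication does not wrap the copy in scf.if; it clamps the number of
// source elements to 0 via arith.select, so a predicated-off iteration
// degenerates into a zero-fill of the destination (hence
// `asyncCopyZeroFillOp`) while keeping the async-copy token structure intact
// for the pipeliner.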
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR ? DiagnosedSilenceableFailure::definiteFailure()
                 : emitSilenceableFailure(forOp->getLoc()),
      scf::ForOp());
}

DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return diag;
  }
  return std::move(diag);
}
  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
  // ...
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    // ...
  };
  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
    AffineExpr dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup},
            RowColIndexing{groupID + 8, threadIDInGroup}};
  }
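  // Note: lane 5 of the warp has groupID = 1 and threadIDInGroup = 1, so for
  // the m16n8k4 tf32 LHS fragment it owns elements (1, 1) and (9, 1).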
  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
    AffineExpr dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup, groupID}};
  }
  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
    AffineExpr dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }
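  // Note: each lane owns a 2x2 sub-tile of the 16x8 accumulator; lane 5
  // (groupID = 1, threadIDInGroup = 1) holds (1, 2), (1, 3), (9, 2) and
  // (9, 3).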
  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
    AffineExpr dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}};
  }
  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
    AffineExpr dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup * 2 + 0, groupID},
            RowColIndexing{threadIDInGroup * 2 + 1, groupID},
            RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID},
            RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}};
  }
  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
    AffineExpr dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }
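  // Note: the f16 variants reuse the same groupID / threadIDInGroup
  // decomposition of the lane id; the m16n8k16 LHS repeats the 16x8 pattern
  // with a column offset of 8 to cover K = 16, and the result layout matches
  // the m16n8 accumulator of the tf32 case.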
589 SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
590 OpFoldResult laneId, Value memref,
599 Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
600 OpFoldResult laneId, Value memref,
607 SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
609 OpFoldResult laneId, Value memref,
618 SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
619 OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
633template <
typename ApplyFn,
typename ReduceFn>
636 VectorType vectorType = cast<VectorType>(
vector.getType());
SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}
Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = vector::InsertOp::create(b, loc, v, res, indices);
      });
  return res;
}
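// Note: the operand fragment is materialized by broadcasting the first scalar
// load into a vector of the fragment shape and then overwriting each element
// with the load whose linear index delinearizes to that position.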
SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
    OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
    Value memref, const IndexCalculator &indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}
SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return vector::ExtractOp::create(b, loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);
}
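// Note: the store path mirrors the load path: the accumulator fragment is
// unpacked element by element with vector.extract and written back through
// the same per-lane (row, col) indexing.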
  return std::make_tuple(vlhs, vrhs, vres);
FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/true};
  }
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape},
                       /*tf32Enabled=*/false};
  }
  return failure();
}
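// Note: only two 1-1 mappings are recognized here: m16n8k4 with all-f32
// operands (executed as tf32) and m16n8k16 with all-f16 operands; any other
// shape/type combination propagates a failure and the transform reports an
// unsupported target op.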
FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
         "expected res to be a 2D memref");
  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  const MmaSyncInfo &info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape,
                          info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}
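// Note: the accumulator is loaded like the other operands, combined with the
// lhs/rhs fragments by the created mma.sync op, and the updated value is
// stored back with the same per-lane indexing, so every lane only touches its
// own fragment elements.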
DiagnosedSilenceableFailure RewriteMatmulAsMmaSyncOp::applyToOne(
    TransformRewriter &rewriter, LinalgOp linalgOp,
    ApplyToEachResultList &results, TransformState &state) {
  bool fail = true;
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    // Do not let matmuls with extended semantics through this transform.
    if (linalgOp.hasUserDefinedMaps()) {
      return emitSilenceableError()
             << "only matmul ops with non-extended semantics are supported";
    }
    Location loc = linalgOp.getLoc();
    Value laneId = gpu::ThreadIdOp::create(
        rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}
  TypedValue<TensorMapDescriptorType> buildGlobalMemRefDescriptor(
      TypedValue<MemRefType> memref, gpu::LaunchOp launchOp);
  SmallVector<OpFoldResult> sizes;
  sizes.reserve(globalDescriptors.size());
  for (auto [desc, shmem] :
       llvm::zip_equal(globalDescriptors, sharedMemBuffers)) {
    OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
    sizes.push_back(sz);
  }
static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
}
TypedValue<MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = MBarrierCreateOp::create(
      rewriter, loc,
      MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierInitOp::create(
      rewriter, loc, barrier,
      getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
      Value());
  gpu::BarrierOp::create(rewriter, loc);
  return cast<TypedValue<MBarrierGroupType>>(barrier);
}
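// Note: the mbarrier is allocated in workgroup (shared) memory and initialized
// with the number of participating threads; the trailing gpu.barrier makes the
// initialized state visible to the whole block before any TMA traffic starts.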
TypedValue<TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = memref::CastOp::create(
      rewriter, loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value desc = TmaCreateDescriptorOp::create(
      rewriter, loc,
      TensorMapDescriptorType::get(
          rewriter.getContext(),
          MemRefType::Builder(memref.getType())
              .setMemorySpace(sharedMemorySpace),
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<TensorMapDescriptorType>>(desc);
}
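// Note: the insertion point is moved to just before the surrounding
// gpu.launch, so the TMA descriptor is created on the host; only the resulting
// tensormap.descriptor value is consumed inside the kernel by the async loads.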
  Operation *loadOp =
      TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc,
                             ValueRange{zero, zero}, zero, Value(), Value());
  loadOps.push_back(loadOp);
  // ...
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  return affine::makeComposedFoldedAffineApply(rewriter, loc, prodExprInBytes,
                                               mixedSizes);
}

// ...
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
                                          Value());
}

void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) {
  // ...
  Value ticksBeforeRetry =
      arith::ConstantIndexOp::create(rewriter, loc, 10000000);
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
                                         ticksBeforeRetry, zero);
}
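// Note: the expect-tx arrival announces how many bytes the TMA loads will
// deposit for the current phase, and try_wait.parity re-checks the barrier
// every `ticksBeforeRetry` ticks until that phase completes.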
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  rewriter.setInsertionPoint(copyOps.front());
  // ...
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});
  SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs;
  SmallVector<TypedValue<MemRefType>> shmems;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    TypedValue<TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }
  rewriter.setInsertionPoint(copyOps.front());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp, *failingOp;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)

  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};

void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::string diag(const llvm::Value &value)
constexpr int kWgmmaSizeM
M size of wgmma.mma_async instruction.
static llvm::ManagedStatic< PassManagerOptions > options
static std::optional< VectorShape > vectorShape(Type type)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Base type for affine expression.
AffineExpr floorDiv(uint64_t v) const
Attributes are known-constant values of operations.
This class provides a shared interface for ranked and unranked memref types.
Attribute getMemorySpace() const
Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpaceAsInt() const
[deprecated] Returns the memory space in old raw integer representation.
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setMemorySpace(Attribute newMemorySpace)
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Block * getBlock()
Returns the operation block that contains this operation.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isInteger() const
Return true if this is an integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
void populateCommonGPUTypeAndAttributeConversions(TypeConverter &typeConverter)
Remap common GPU memory spaces (Workgroup, Private, etc) to LLVM address spaces.
MemRefType getMBarrierMemrefType(MLIRContext *context, MBarrierGroupType barrierType)
Return the memref type that can be used to represent an mbarrier object.
void registerTransformDialectExtension(DialectRegistry ®istry)
void createAsyncGroups(RewriterBase &rewriter, Operation *op, bool bypassL1)
Convert global->shared vector transfers to async device copies.
Include the generated interface declarations.
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
llvm::SetVector< T, Vector, Set, N > SetVector
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
void populateNVGPUToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
const FrozenRewritePatternSet & patterns
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
int64_t computeSum(ArrayRef< int64_t > basis)
Self-explicit.
void bindSymbolsList(MLIRContext *ctx, MutableArrayRef< AffineExprTy > exprs)
Helper to create the tma operations corresponding to linalg::CopyOp.
SmallVector< Operation * > rewrite(ArrayRef< Operation * > copyOps)
CopyBuilder(RewriterBase &rewriter, Location loc)
void buildBarrierArriveTx(TypedValue< MBarrierGroupType > barrier, ArrayRef< OpFoldResult > sizes)
OpFoldResult buildTmaAsyncLoad(TypedValue< TensorMapDescriptorType > globalDesc, TypedValue< MemRefType > sharedMemref, TypedValue< MBarrierGroupType > barrier, SmallVectorImpl< Operation * > &loadOps)
Build a tma load from global memory to shared memory using barrier to synchronize.
TypedValue< TensorMapDescriptorType > buildGlobalMemRefDescriptor(TypedValue< MemRefType > memref, gpu::LaunchOp launchOp)
Create tma descriptor op to initiate transfer from global to shared memory.
void buildTryWaitParity(TypedValue< MBarrierGroupType > barrier)
TypedValue< MBarrierGroupType > buildAndInitBarrierInSharedMemory(OpFoldResult numThreads)
SmallVector< Operation * > buildPredicateLoadsOnThread0(ArrayRef< TypedValue< TensorMapDescriptorType > > globalDescriptors, ArrayRef< TypedValue< MemRefType > > sharedMemBuffers, TypedValue< MBarrierGroupType > barrier)
If threadIdx.x == 0 does TMA request + wait, else just wait.
HopperBuilder(RewriterBase &rewriter, Location loc)
Helper struct to provide a simple mapping from matmul operations to the corresponding mma....
MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
std::function< SmallVector< RowColIndexing >(MLIRContext *)> IndexCalculator
FailureOr< Operation * > buildMmaSync(LinalgOp linalgOp)
Create the mma.sync operation corresponding to linalgOp along with all the supporting load/store and ...
Helper struct to encode a pair of row/column indexings in the form of affine expressions.
RowColIndexing(AffineExpr row, AffineExpr col)
void print(llvm::raw_ostream &os) const
Options to dictate how loops should be pipelined.