30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
42#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
struct AMDGPUInlinerInterface final : DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
void AMDGPUDialect::initialize() {
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"

#define GET_TYPEDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"

#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

  addInterfaces<AMDGPUInlinerInterface>();
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");

LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    if (source.hasStaticShape()) {
    } else if (source.getRank() <= 1) {
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
  return (MemRefType)(mb);
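// Worked example (hypothetical types): casting
// memref<8x8xf32, strided<[8, 1], offset: ?>> with resetOffset = true yields a
// fat-raw-buffer memref with layout strided<[8, 1], offset: 0>; the dynamic
// offset is folded into the buffer descriptor's base pointer instead of being
// kept in the memref layout.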
LogicalResult FatRawBufferCastOp::inferReturnTypes(
  Adaptor adaptor(operands, attributes, properties, regions);
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  FailureOr<MemRefType> resultType =
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
  if (failed(expectedResultType))
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
           << *expectedResultType << " but got " << getResult().getType();
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;

  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;

  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
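// The integer address spaces above mirror AMDGPU's LLVM numbering: 0
// (flat/generic, treated here as global) and 1 (global) for global memory, 3
// for LDS / workgroup memory, and 7 for the buffer fat pointer address space.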
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
LogicalResult RawBufferAtomicFaddOp::verify() {

LogicalResult RawBufferAtomicFmaxOp::verify() {

LogicalResult RawBufferAtomicSmaxOp::verify() {

LogicalResult RawBufferAtomicUminOp::verify() {

LogicalResult RawBufferAtomicCmpswapOp::verify() {

  return cst.getZExtValue();
template <typename OpType>
  if (!op.getBoundsCheck())
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
  if (op.getSgprOffset()) {
  if (strides.size() != op.getIndices().size())
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    indexVal += stride * *idxVal;
  if (result > std::numeric_limits<uint32_t>::max())
  return result >= bufferType.getNumElements();
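// Worked example: for memref<64x64xf32> the strides are [64, 1] and the
// buffer holds 4096 elements. Constant indices (63, 64) linearize to
// 63 * 64 + 64 * 1 = 4096 >= 4096, so the access is statically out of bounds
// and the patterns below can fold the op away.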
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    Type loadType = op.getResult().getType();
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);

  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
LogicalResult ScaledExtPackedMatrixOp::verify() {
  assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");

  int firstScaleByte = getFirstScaleByte();
  int firstScaleLane = getFirstScaleLane();
  auto sourceType = cast<VectorType>(getSource().getType());
  Type elementType = sourceType.getElementType();
  auto floatType = cast<FloatType>(elementType);
  unsigned bitWidth = floatType.getWidth();

  const bool is_fp8 = bitWidth == 8;
  const bool is_block_16 = blockSize == 16;

    if (!llvm::is_contained({0, 1}, firstScaleByte)) {
      return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
                         "or 1 for f4 and f6.");
    if (!llvm::is_contained({0, 2}, firstScaleByte)) {
      return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
                         "or 2 for f4 and f6.");

    bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
                    ((firstScaleLane == 16) && (firstScaleByte == 2));
      return emitOpError("blockSize of 16 can only have (firstScaleLane, "
                         "firstScaleByte) be (0, 0) or (16, 2) for f8.");
                                      IntegerAttr &m, IntegerAttr &n,
  if (dimensions.size() != 3)
           << "expected 3 dimensions in MNK dimension list";
LogicalResult WMMAOp::verify() {
  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sourceAElemType = sourceAType.getElementType();
  Type sourceBElemType = sourceBType.getElementType();
  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceAType << " vs. " << sourceBType;

  bool isDestFloat = destType.getElementType().isFloat();
  bool isSrcFloat = sourceAElemType.isFloat();

  if (isDestFloat && !isSrcFloat)
    return emitOpError("expected float sources with float destination");
  if (!isDestFloat && isSrcFloat)
    return emitOpError("expected int sources with int destination");

  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
        "source element types must match (except for fp8/bf8) but have ")
        << sourceAType << " and " << sourceBType;

    return emitOpError("clamp flag is not supported for float types");
  if (getUnsignedA() || getUnsignedB())
    return emitOpError("unsigned flags are not supported for float types");
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  Type sourceBElem = sourceBType;
  if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
    sourceBLen = sourceBVector.getNumElements();
    sourceBElem = sourceBVector.getElementType();
  }

    return emitOpError("expected both source operands to have small-float "
                       "elements if one does");
  if (sourceLen != sourceBLen)
        "expected both small-float source vectors to have the same length");

  if (sourceType != sourceBType)
    return emitOpError("expected both non-small-float source operand types "

    sourceElem = b.getI8Type();

    sourceElem = b.getI8Type();

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +

  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
        "negation flags only available for double-precision operations");
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");

  DPPPerm kind = getKind();

  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
            "Each element of quad_perm must be in the range [0, 3]");

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");
struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
  LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
                                PatternRewriter &rewriter) const override {
    auto next = dyn_cast<MemoryCounterWaitOp>(op->getNextNode());

    auto setters = {&MemoryCounterWaitOp::setLoad,
                    &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
                    &MemoryCounterWaitOp::setExp,
                    &MemoryCounterWaitOp::setTensor};
    auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
    auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
                    next.getExp(), next.getTensor()};

    for (auto [setter, lhs, rhs] :
         llvm::zip_equal(setters, lhsVals, rhsVals)) {
        (op.*setter)(std::min(*lhs, *rhs));
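// Two adjacent waits are equivalent to a single wait on the elementwise
// minimum of their counters: waiting for "at most N outstanding loads" and
// then "at most M" is the same as waiting once for min(N, M). The surviving
// op keeps that minimum per counter and the redundant wait is erased.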
void MemoryCounterWaitOp::getCanonicalizationPatterns(
  results.add<FuseMemoryCounterWaitOp>(context);
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type inner most dim must be contiguous");

  auto elemType = srcType.getElementType();
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  auto transferType = getTransferType();
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
        "Transferring type size must be 8, 16, 32, 96 or 128 bits");

        "source memory address space must be global or fat raw buffer");
    return emitOpError("destination memory address space must be Workgroup");
  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
              [&] { operand.assign(castOp.getSource()); });

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());
  results.add<FoldGatherToLDSOfCast>(context);
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end())
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";

  if (numElements != validNumElems->second)
        "Transferring type size mismatch: expected num of elements: ")
        << validNumElems->second;
LogicalResult MakeDmaBaseOp::verify() {
  auto ldsType = cast<MemRefType>(getLds().getType());
  auto globalType = cast<MemRefType>(getGlobal().getType());
        "lds memref must have workgroup address space attribute.");
        "global memref must have global address space attribute.");

  Type elementType = ldsType.getElementType();
  if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, width))
        "element type must be 1, 2, 4, or 8 bytes long but type was ")
        << width << " bits long.";
LogicalResult MakeDmaDescriptorOp::verify() {
  if (globalStaticStrides.empty())
  if (globalStaticStrides.back() != 1)
    return emitOpError("strides for the innermost dimension must be 1.");

  size_t rank = globalStaticSizes.size();
    return emitOpError("tensor and tile must be at most of rank 5.");
  if (rank != globalStaticStrides.size())
    return emitOpError("strides and sizes must have same rank.");
  if (rank != sharedStaticSizes.size())
    return emitOpError("tensor must have same rank as tile.");

  unsigned elementTypeWidth = getElementTypeWidth();
  if (!llvm::is_contained<unsigned>({8, 16, 32, 64}, elementTypeWidth))
        "element type width must be 1, 2, 4 or 8 bytes, but was ")
        << elementTypeWidth << " bits long";

  if (Value atomicBarrierAddress = getAtomicBarrierAddress()) {
    auto atomicBarrierAddressType =
        cast<MemRefType>(atomicBarrierAddress.getType());
      return emitOpError("atomic barrier address must be in LDS.");

  if (getEarlyTimeout() && !getWorkgroupMask())
        "early timeout does not apply when workgroup_mask is not set.");
OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
  setGlobalStaticSizes(staticGlobalSizes);
  getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);

                             staticGlobalStrides);
  setGlobalStaticStrides(staticGlobalStrides);
  getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);

  setSharedStaticSizes(staticSharedSizes);
  getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
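// The fold hook (body partially elided) appears to follow the usual pattern
// for mixed static/dynamic index lists: any dynamic size or stride that has
// become a known constant is promoted into the corresponding static attribute
// array (via the foldDynamicIndexList / dispatchIndexOpFoldResults helpers),
// and the operand lists are reassigned to the remaining dynamic values.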
  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
        op.setScalesIdxA(val);
        op.setScalesIdxB(val);

    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
                                           "defining op not a vector.insert");

      if (isa<VectorType>(insertOp.getValueToStore().getType())) {
            op, "scaled mfma operand already packed");

          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
                                           "defining op not a vector.extract");

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());

      if (!scaleSrcType.hasStaticShape()) {
                                           "dynamic dims not yet supported");

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4) {
            op, "no packing if # of scales less than four");

      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      int64_t idx = linearize(extractedPos, extractSizes);

      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      Type scaleSrcElemType = scaleSrcType.getElementType();
          VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
          ArrayRef{int64_t(1)});

      op->setOperand(opIdx, extract);
      setOpsel(opIdx, opsel);
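// Worked example of the offset/opsel arithmetic (assuming the elided slice
// width `size` is 4): with numElements = 16 and a scale at linearized index
// idx = 6, offset = 6 - (6 % 4) = 4 and opsel = 6 - 4 = 2, so the rewrite
// extracts the 4-element slice [4, 8) and points the op at byte 2 of it. Near
// the tail, e.g. numElements = 6 and idx = 5, the slice is clamped to
// offset = 6 - 4 = 2 with opsel = 4 - (6 - 5) = 3, which still addresses
// element 5 within [2, 6).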
  results.add<PackScales>(context);

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"