#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
struct AMDGPUInlinerInterface final : DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    // All of this dialect's ops are safe to inline.
    return true;
  }
};
void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
  addInterfaces<AMDGPUInlinerInterface>();
}
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
/// Convert the type `source` to one with the same sizes and strides - and
/// offset, unless `resetOffset` is true, in which case the offset is reset to
/// 0. Fails if the offset should be reset but the layout of `source` is
/// neither the identity layout nor a strided layout.
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    // If resetting the offset yields the identity layout, use the identity
    // map instead of an explicit strided attribute.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}
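// Worked example (illustrative): resetting the offset on
//   memref<8x8xf32, strided<[8, 1], offset: ?>>
// keeps strides [8, 1], which equals computeSuffixProduct({8, 8}), so the
// strided layout collapses back to the identity layout and the result is
//   memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>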
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = SmallVector<Type>{*resultType};
  return success();
}
FailureOr<OpFoldResult> FatRawBufferCastOp::reifyDimOfResult(
    OpBuilder &builder, int64_t dim, unsigned resultIndex) {
  assert(resultIndex == 0 && "FatRawBufferCastOp has a single result");
  Value source = getSource();
  auto sourceType = cast<MemRefType>(source.getType());
  if (sourceType.isDynamicDim(dim))
    return OpFoldResult(
        builder.createOrFold<memref::DimOp>(getLoc(), source, dim));
  return OpFoldResult(builder.getIndexAttr(sourceType.getDimSize(dim)));
}
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}

static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}
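// The three helpers above accept both the numeric AMDGPU address spaces
// (0/1 = flat/global, 3 = workgroup/LDS, 7 = fat raw buffer) and their
// symbolic gpu/amdgpu dialect attribute equivalents, so the verifiers below
// work with memrefs written in either convention.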
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  Attribute memorySpace = bufferType.getMemorySpace();
  if (!hasGlobalMemorySpace(memorySpace))
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}
static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  // An index that overflows the 32-bit offset range cannot be proven out of
  // bounds here.
  if (result > std::numeric_limits<uint32_t>::max())
    return false;
  return result >= bufferType.getNumElements();
}
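// Worked example (hypothetical shapes): for a memref<4x4xi32> buffer
// (strides [4, 1], offset 0, 16 elements), a bounds-checked access at
// constant indices [3, 2] with an index offset of 8 linearizes to
// 0 + 8 + 3*4 + 2*1 = 22 >= 16, so the access is statically out of bounds
// and the patterns below may fold the op away.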
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    rw.eraseOp(op);
    return success();
  }
};
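// These two patterns encode the hardware's raw-buffer bounds-checking
// semantics: an out-of-bounds buffer load returns 0 (hence the replacement
// with a zero constant of the load's type), and an out-of-bounds buffer
// write is silently dropped (hence the plain erase).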
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
LogicalResult ScaledExtPackedMatrixOp::verify() {
  int blockSize = getBlockSize();
  assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");

  int firstScaleByte = getFirstScaleByte();
  int firstScaleLane = getFirstScaleLane();
  auto sourceType = cast<VectorType>(getSource().getType());
  Type elementType = sourceType.getElementType();
  auto floatType = cast<FloatType>(elementType);
  unsigned bitWidth = floatType.getWidth();

  const bool is_fp8 = bitWidth == 8;
  const bool is_block_16 = blockSize == 16;

  if (!is_fp8) {
    if (is_block_16) {
      if (!llvm::is_contained({0, 1}, firstScaleByte)) {
        return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
                           "or 1 for f4 and f6.");
      }
    } else {
      if (!llvm::is_contained({0, 2}, firstScaleByte)) {
        return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
                           "or 2 for f4 and f6.");
      }
    }
  } else if (is_block_16) {
    bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
                    ((firstScaleLane == 16) && (firstScaleByte == 2));
    if (!is_valid) {
      return emitOpError("blockSize of 16 can only have (firstScaleLane, "
                         "firstScaleByte) be (0, 0) or (16, 2) for f8.");
    }
  }
  return success();
}
/// Parser for the custom<MNKDimensionList> custom assembly format used by
/// WMMAOp.
ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
                                  IntegerAttr &n, IntegerAttr &k) {
  SmallVector<int64_t, 3> dimensions;
  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
                                /*withTrailingX=*/false))
    return failure();
  if (dimensions.size() != 3)
    return parser.emitError(parser.getCurrentLocation())
           << "expected 3 dimensions in MNK dimension list";

  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
  return success();
}
LogicalResult WMMAOp::verify() {
  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sourceAElemType = sourceAType.getElementType();
  Type sourceBElemType = sourceBType.getElementType();
  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceAType << " vs. " << sourceBType;
  }

  bool isDestFloat = destType.getElementType().isFloat();
  bool isSrcFloat = sourceAElemType.isFloat();

  if (isDestFloat && !isSrcFloat)
    return emitOpError("expected float sources with float destination");
  if (!isDestFloat && isSrcFloat)
    return emitOpError("expected int sources with int destination");

  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
    return emitOpError(
               "source element types must match (except for fp8/bf8) but have ")
           << sourceAType << " and " << sourceBType;
  }

  if (isSrcFloat) {
    if (getClamp())
      return emitOpError("clamp flag is not supported for float types");
    if (getUnsignedA() || getUnsignedB())
      return emitOpError("unsigned flags are not supported for float types");
  }
  return success();
}
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) ||
      sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }

  // Normalize the packed integer types to their i8 element counts.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");
  return success();
}
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");

  return success();
}
struct FuseMemoryCounterWaitOp final : OpRewritePattern<MemoryCounterWaitOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MemoryCounterWaitOp op,
                                PatternRewriter &rewriter) const override {
    // getNextNode() may be null when `op` is the last op in its block.
    auto next = dyn_cast_if_present<MemoryCounterWaitOp>(op->getNextNode());
    if (!next)
      return failure();

    auto setters = {&MemoryCounterWaitOp::setLoad,
                    &MemoryCounterWaitOp::setStore, &MemoryCounterWaitOp::setDs,
                    &MemoryCounterWaitOp::setExp,
                    &MemoryCounterWaitOp::setTensor};
    auto lhsVals = {op.getLoad(), op.getStore(), op.getDs(), op.getExp(),
                    op.getTensor()};
    auto rhsVals = {next.getLoad(), next.getStore(), next.getDs(),
                    next.getExp(), next.getTensor()};
    rewriter.modifyOpInPlace(op, [&] {
      for (auto [setter, lhs, rhs] :
           llvm::zip_equal(setters, lhsVals, rhsVals)) {
        if (lhs && rhs)
          (op.*setter)(std::min(*lhs, *rhs));
        else if (rhs)
          (op.*setter)(*rhs);
      }
    });
    rewriter.eraseOp(next);
    return success();
  }
};
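// Fusion example (hypothetical counters): two back-to-back
// amdgpu.memory_counter_wait ops waiting for load counts 2 and 1 merge into
// a single wait with load count min(2, 1) = 1; a counter set on only the
// second op is copied over, since the fused op must satisfy both waits.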
void MemoryCounterWaitOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FuseMemoryCounterWaitOp>(context);
}
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type inner most dim must be contiguous");

  auto elemType = srcType.getElementType();
  // The source and destination element types must match.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError("Transfering type size must be 8, 16, 32, 96 or 128 "
                       "bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
/// Fold away memref.cast ops feeding a GatherToLDSOp when the cast only
/// erases static information.
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(
              gatherOp, [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};

void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // Element size (in bits) -> expected number of elements.
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end())
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";

  if (numElements != validNumElems->second)
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;

  return success();
}
template <typename BaseOp>
static LogicalResult verifyBase(BaseOp op) {
  auto ldsType = cast<MemRefType>(op.getLds().getType());
  auto globalType = cast<MemRefType>(op.getGlobal().getType());

  if (!hasWorkgroupMemorySpace(ldsType.getMemorySpace()))
    return op.emitOpError(
        "lds memref must have workgroup address space attribute.");
  if (!hasGlobalMemorySpace(globalType.getMemorySpace()))
    return op.emitOpError(
        "global memref must have global address space attribute.");

  Type elementType = ldsType.getElementType();
  unsigned width = elementType.getIntOrFloatBitWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
    return op.emitOpError(
               "element type must be 1, 2, 4, or 8 bytes long but type was ")
           << width << " bits long.";
  return success();
}

LogicalResult MakeDmaBaseOp::verify() { return verifyBase(*this); }
// Element-width and index-type validation for the gather DMA base (the
// opening lines of the enclosing verifier fall outside this excerpt):
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, width))
    return emitError() << "element type must be 1, 2, 4, or 8 bytes wide but "
                          "type "
                       << elementType << " is " << width / 8 << " bytes wide.";

  Type i16 = IntegerType::get(ctx, 16);
  Type i32 = IntegerType::get(ctx, 32);
  if (!llvm::is_contained({i16, i32}, indexType))
    return emitError() << "index type must be i16 or i32 but index type is "
                       << indexType;

LogicalResult MakeGatherDmaBaseOp::verify() { return verifyBase(*this); }
LogicalResult MakeDmaDescriptorOp::verify() {
  ArrayRef<int64_t> globalStaticStrides = getGlobalStaticStrides();
  if (globalStaticStrides.empty())
    return emitOpError("strides must not be empty.");
  if (globalStaticStrides.back() != 1)
    return emitOpError("strides for the innermost dimension must be 1.");

  ArrayRef<int64_t> globalStaticSizes = getGlobalStaticSizes();
  size_t rank = globalStaticSizes.size();
  if (rank > 5)
    return emitOpError("tensor and tile must be at most of rank 5.");
  if (rank != globalStaticStrides.size())
    return emitOpError("strides and sizes must have same rank.");

  ArrayRef<int64_t> sharedStaticSizes = getSharedStaticSizes();
  if (rank != sharedStaticSizes.size())
    return emitOpError("tensor must have same rank as tile.");

  unsigned elementTypeWidth = getElementTypeWidth();
  if (!llvm::is_contained({8u, 16u, 32u, 64u}, elementTypeWidth))
    return emitOpError(
               "element type width must be 1, 2, 4 or 8 bytes, but was ")
           << elementTypeWidth << " bits long";

  if (Value atomicBarrierAddress = getAtomicBarrierAddress()) {
    auto atomicBarrierAddressType =
        cast<MemRefType>(atomicBarrierAddress.getType());
    if (!hasWorkgroupMemorySpace(atomicBarrierAddressType.getMemorySpace()))
      return emitOpError("atomic barrier address must be in LDS.");
  }

  if (getEarlyTimeout() && !getWorkgroupMask())
    return emitOpError(
        "early timeout does not apply when workgroup_mask is not set.");

  return success();
}
OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
  // Fold newly-constant dynamic sizes/strides into the static lists.
  // ... (the mixed lists and the foldDynamicIndexList calls fall outside
  // this excerpt) ...
  dispatchIndexOpFoldResults(mixedGlobalSizes, dynamicGlobalSizes,
                             staticGlobalSizes);
  setGlobalStaticSizes(staticGlobalSizes);
  getGlobalDynamicSizesMutable().assign(dynamicGlobalSizes);
  dispatchIndexOpFoldResults(mixedGlobalStrides, dynamicGlobalStrides,
                             staticGlobalStrides);
  setGlobalStaticStrides(staticGlobalStrides);
  getGlobalDynamicStridesMutable().assign(dynamicGlobalStrides);
  dispatchIndexOpFoldResults(mixedSharedSizes, dynamicSharedSizes,
                             staticSharedSizes);
  setSharedStaticSizes(staticSharedSizes);
  getSharedDynamicSizesMutable().assign(dynamicSharedSizes);
  return getResult();
}
/// Pack small-float scale operands of a ScaledMFMAOp: when a scale is a
/// single element inserted from a larger vector of scales, replace it with a
/// 4-element strided slice of that vector and record which lane of the slice
/// to use via the scalesIdxA/B ("opsel") attributes.
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // Operands 3 and 4 are the A and B scales.
    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.insert");
      if (isa<VectorType>(insertOp.getValueToStore().getType()))
        return rewriter.notifyMatchFailure(op,
                                           "scaled mfma operand already packed");

      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.extract");

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType)
        return rewriter.notifyMatchFailure(op, "scale source is not a vector");
      if (!scaleSrcType.hasStaticShape())
        return rewriter.notifyMatchFailure(op,
                                           "dynamic dims not yet supported");

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4)
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");

      // Compute the linearized index of the extracted scale.
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      }
      int64_t idx = linearize(extractedPos, extractSizes);

      // Take the aligned 4-element slice containing idx; opsel selects the
      // element within that slice. If the slice would run off the end of the
      // vector, shift it left and adjust opsel accordingly.
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      int64_t size = 4l;
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }
      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType =
          VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
          ArrayRef{int64_t(1)});
      rewriter.modifyOpInPlace(op, [&] {
        op->setOperand(opIdx, extract);
        setOpsel(opIdx, opsel);
      });
    }
    return success();
  }
};
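// Worked example of the offset/opsel arithmetic (hypothetical scales):
// extracting element 5 from a vector<6xf8E8M0FNU> of scales gives idx = 5,
// so initially offset = 4 and opsel = 1; since numElements - offset = 2 < 4,
// the slice shifts left to offset = 2 with opsel = 3, i.e. the op now reads
// scales [2, 6) and selects position 3 (the original element 5).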
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUTypes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"