30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TypeSwitch.h"
42#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
namespace {
/// Inliner interface: AMDGPU ops are legal to inline anywhere.
struct AMDGPUInlinerInterface final : DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
  addInterfaces<AMDGPUInlinerInterface>();
}
//===----------------------------------------------------------------------===//
// Packed 8-bit float ops
//===----------------------------------------------------------------------===//

LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
//===----------------------------------------------------------------------===//
// FatRawBufferCastOp
//===----------------------------------------------------------------------===//

/// Convert the type `source` to one with the same sizes and strides - and
/// offset, unless `resetOffset` is true, in which case the offset is reset
/// to 0. Fails if the offset should be reset but the layout of `source` is
/// neither the identity nor a strided layout.
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    // If resetting the offset makes the strides match those of an identity
    // layout, prefer the identity layout itself.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}
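// Worked example (illustrative): for
//   memref<8x8xf32, strided<[8, 1], offset: 16>>
// with resetOffset = true, the strides [8, 1] equal the suffix product of
// the shape, so the result collapses to the identity layout:
//   memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// With resetOffset = false, only the memory space changes.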
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes.push_back(*resultType);
  return success();
}
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}
static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}
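//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//

// The raw buffer ops below share one verifier: the memref must live in
// global memory, be ranked, and be indexed with exactly rank-many indices.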
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());

  if (!isGlobal)
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}
static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // The 32-bit offset overflows; nothing can be concluded statically.
    return false;
  return result >= bufferType.getNumElements();
}
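// Canonicalization: when bounds checking is enabled and the accessed offset
// is a compile-time constant past the end of the buffer, the access can be
// resolved statically. Out-of-bounds loads fold to an all-zero constant of
// the loaded type (matching the hardware's out-of-bounds read semantics for
// raw buffers); out-of-bounds writes are erased.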
namespace {
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();

    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();

    rw.eraseOp(op);
    return success();
  }
};
} // namespace
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}
void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // cmpswap returns the old value, so it is treated like a load here.
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
//===----------------------------------------------------------------------===//
// ScaledExtPacked816Op
//===----------------------------------------------------------------------===//

LogicalResult ScaledExtPacked816Op::verify() {
  int blockSize = getBlockSize();
  assert((blockSize == 16 || blockSize == 32) && "invalid block size");
  int firstScaleByte = getFirstScaleByte();
  if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
    return emitOpError(
        "blockSize of 16 can only have firstScaleByte be 0 or 1.");
  }
  if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
    return emitOpError(
        "blockSize of 32 can only have firstScaleByte be 0 or 2.");
  }
  return success();
}
//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//

/// Parser for the `custom<MNKDimensionList>` assembly format used by WMMAOp.
ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
                                                IntegerAttr &m, IntegerAttr &n,
                                                IntegerAttr &k) {
  SmallVector<int64_t, 3> dimensions;
  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
                                /*withTrailingX=*/false))
    return failure();
  if (dimensions.size() != 3)
    return parser.emitError(parser.getCurrentLocation())
           << "expected 3 dimensions in MNK dimension list";

  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
  return success();
}
LogicalResult WMMAOp::verify() {
  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sourceAElemType = sourceAType.getElementType();
  Type sourceBElemType = sourceBType.getElementType();
  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceAType << " vs. " << sourceBType;
  }

  bool isDestFloat = destType.getElementType().isFloat();
  bool isSrcFloat = sourceAElemType.isFloat();

  if (isDestFloat && !isSrcFloat)
    return emitOpError("expected float sources with float destination");
  if (!isDestFloat && isSrcFloat)
    return emitOpError("expected int sources with int destination");

  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
    return emitOpError(
               "source element types must match (except for fp8/bf8) but have ")
           << sourceAType << " and " << sourceBType;
  }

  if (isSrcFloat) {
    if (getClamp())
      return emitOpError("clamp flag is not supported for float types");
    if (getUnsignedA() || getUnsignedB())
      return emitOpError("unsigned flags are not supported for float types");
  }
  return success();
}
//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//

LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) ||
      sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }
  // Normalize the wider integer source types to counts of i8 values.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
//===----------------------------------------------------------------------===//
// DPPOp
//===----------------------------------------------------------------------===//

LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }
  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }
  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
//===----------------------------------------------------------------------===//
// PermlaneSwapOp
//===----------------------------------------------------------------------===//

LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");

  return success();
}
//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//

LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type inner most dim must be contiguous");

  auto elemType = srcType.getElementType();
  // Check that $src and $dst have the same element type.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  // The transfer must move 1, 2, 4, 12 or 16 bytes per lane.
  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transfering type size must be 8, 16, 32, 96 or 128 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
namespace {
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(
              gatherOp, [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};
} // namespace
void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}
//===----------------------------------------------------------------------===//
// TransposeLoadOp
//===----------------------------------------------------------------------===//

LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // Element bit width -> number of elements per load.
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end()) {
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }
  if (numElements != validNumElems->second) {
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;
  }
  return success();
}
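// With these pairings, every legal transpose load moves 64 bits per lane
// (16 x 4-bit, 8 x 8-bit, 4 x 16-bit), except the 6-bit form, which moves a
// 96-bit packed chunk (16 x 6-bit).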
//===----------------------------------------------------------------------===//
// ScaledMFMAOp
//===----------------------------------------------------------------------===//

namespace {
/// Pack individually extracted scales into 4-byte groups so that the
/// scaled_mfma selects them via its scalesIdx (opsel) fields rather than
/// through single-element insert/extract chains.
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern<ScaledMFMAOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Operand 3 holds the scales for A, operand 4 the scales for B.
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // Each scale operand must be a vector.insert of a scalar extracted from
    // a statically shaped vector of more than four scales.
    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp) {
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.insert");
      }
      if (isa<VectorType>(insertOp.getValueToStore().getType())) {
        return rewriter.notifyMatchFailure(
            op, "scaled mfma operand already packed");
      }
      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp) {
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.extract");
      }

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType) {
        return rewriter.notifyMatchFailure(op, "not a vector type");
      }
      if (!scaleSrcType.hasStaticShape()) {
        return rewriter.notifyMatchFailure(op,
                                           "dynamic dims not yet supported");
      }

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4) {
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");
      }

      // Linearize the position of the extracted scale within its source.
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      }
      int64_t idx = linearize(extractedPos, extractSizes);

      // Take the aligned 4-element group of scales containing idx, clamping
      // the group to the end of the vector when necessary.
      int64_t size = 4;
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }
      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType =
          VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
          ArrayRef{int64_t(1)});
      rewriter.modifyOpInPlace(op, [&] {
        op->setOperand(opIdx, extract);
        setOpsel(opIdx, opsel);
      });
    }
    return success();
  }
};
} // namespace
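// Worked example (illustrative): extracting scale 6 from a
// vector<8xf8E8M0FNU> linearizes to idx = 6, giving offset = 4 and
// opsel = 2. The pattern then feeds the scaled_mfma the 4-element slice
// [4, 8) of the scale vector and records 2 as the scales index, replacing
// the single-element insert chain.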
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"