#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
namespace {
struct AMDGPUInlinerInterface final : DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // All AMDGPU operations are legal to inline.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }
};
} // namespace
void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
  addInterfaces<AMDGPUInlinerInterface>();
}
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
/// Convert the type `source` to one with the same sizes and strides - and
/// offset, unless `resetOffset` is true, in which case the offset is reset
/// to 0. If the offset should be reset but the layout of `source` cannot
/// express the strided form, returns failure().
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    // If zeroing the offset makes the strides those of the identity layout,
    // use an identity map instead of a redundant strided attribute.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}
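// Worked example (illustrative): with resetOffset = true, a source type
//   memref<8x8xf32, strided<[8, 1], offset: 4>>
// becomes
//   memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// because [8, 1] are the identity strides for an 8x8 shape, while
//   memref<8x8xf32, strided<[16, 1], offset: 4>>
// keeps a strided layout with its offset zeroed.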
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = SmallVector<Type>{*resultType};
  return success();
}

LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}

static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}
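// Note (added for clarity): the integer values mirror the AMDGPU backend's
// address-space numbering, where 0 (flat) and 1 (global) both reach global
// memory, 3 is workgroup/LDS, and 7 is the buffer fat pointer. A memref may
// therefore carry either a raw integer memory space or the corresponding
// gpu/amdgpu dialect address-space attribute.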
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
  if (!isGlobal)
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
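// Example (illustrative): a raw buffer op on a memref<16x16xf32> in address
// space 1 must supply exactly two indices; an unranked memref, or a memref
// in workgroup memory, is rejected outright.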
LogicalResult RawBufferAtomicFaddOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicFmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicSmaxOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicUminOp::verify() {
  return verifyRawBufferOp(*this);
}

LogicalResult RawBufferAtomicCmpswapOp::verify() {
  return verifyRawBufferOp(*this);
}
static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}

template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  // An offset that overflows the 32-bit index wraps at runtime, so it can't
  // be proven out of bounds statically.
  if (result > std::numeric_limits<uint32_t>::max())
    return false;
  return result >= bufferType.getNumElements();
}
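// Worked example (illustrative): for a memref<4x8xf32> buffer, the strides
// are [8, 1]. Constant indices [3, 9] linearize to 3 * 8 + 9 * 1 = 33, which
// is >= 4 * 8 = 32 elements, so a bounds-checked access is statically known
// to be out of bounds and the patterns below may fold the op away.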
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    // Out-of-bounds bounds-checked loads return 0, so fold to a zero
    // constant of the loaded type.
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    // Out-of-bounds bounds-checked writes are no-ops and can be erased.
    rw.eraseOp(op);
    return success();
  }
};
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
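// Illustrative effect (op syntax sketched; see the dialect docs for the
// exact assembly format): given a bounds-checked load such as
//   %v = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%c40]
//        : memref<32xf32>, i32 -> f32
// where %c40 is a constant index past the 32 elements, the load pattern
// replaces %v with a zero constant, while the store and atomic patterns
// simply erase the op.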
LogicalResult ScaledExtPacked816Op::verify() {
  int blockSize = getBlockSize();
  assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
  int firstScaleByte = getFirstScaleByte();
  int firstScaleLane = getFirstScaleLane();
  auto sourceType = cast<VectorType>(getSource().getType());
  Type elementType = sourceType.getElementType();
  auto floatType = cast<FloatType>(elementType);
  unsigned bitWidth = floatType.getWidth();

  const bool is_fp8 = bitWidth == 8;
  const bool is_block_16 = blockSize == 16;

  if (!is_fp8) {
    // f4 and f6 sources.
    if (is_block_16) {
      if (!llvm::is_contained({0, 1}, firstScaleByte)) {
        return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
                           "or 1 for f4 and f6.");
      }
    } else {
      if (!llvm::is_contained({0, 2}, firstScaleByte)) {
        return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
                           "or 2 for f4 and f6.");
      }
    }
  } else if (is_block_16) {
    // f8 sources.
    bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
                    ((firstScaleLane == 1) && (firstScaleByte == 2));
    if (!is_valid) {
      return emitOpError("blockSize of 16 can only have (firstScaleLane, "
                         "firstScaleByte) be (0, 0) or (1, 2) for f8.");
    }
  }

  return success();
}
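// Summary of the scale-placement constraints checked above:
//   f4/f6 sources, blockSize 16: firstScaleByte in {0, 1}
//   f4/f6 sources, blockSize 32: firstScaleByte in {0, 2}
//   f8 sources,    blockSize 16: (firstScaleLane, firstScaleByte) in
//                                {(0, 0), (1, 2)}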
ParseResult amdgpu::parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
                                          IntegerAttr &n, IntegerAttr &k) {
  SmallVector<int64_t, 3> dimensions;
  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
                                /*withTrailingX=*/false))
    return failure();
  if (dimensions.size() != 3)
    return parser.emitError(parser.getCurrentLocation())
           << "expected 3 dimensions in MNK dimension list";
  Builder &b = parser.getBuilder();
  m = b.getI32IntegerAttr(dimensions[0]);
  n = b.getI32IntegerAttr(dimensions[1]);
  k = b.getI32IntegerAttr(dimensions[2]);
  return success();
}
LogicalResult WMMAOp::verify() {
  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sourceAElemType = sourceAType.getElementType();
  Type sourceBElemType = sourceBType.getElementType();
  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceAType << " vs. " << sourceBType;
  }

  bool isDestFloat = destType.getElementType().isFloat();
  bool isSrcFloat = sourceAElemType.isFloat();

  if (isDestFloat && !isSrcFloat)
    return emitOpError("expected float sources with float destination");
  if (!isDestFloat && isSrcFloat)
    return emitOpError("expected int sources with int destination");

  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
    return emitOpError(
               "source element types must match (except for fp8/bf8) but have ")
           << sourceAType << " and " << sourceBType;
  }

  if (isSrcFloat) {
    if (getClamp())
      return emitOpError("clamp flag is not supported for float types");
    if (getUnsignedA() || getUnsignedB())
      return emitOpError("unsigned flags are not supported for float types");
  }
  return success();
}
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) ||
      sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }

  // Normalize the packed integer source types to i8 for the length checks.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));
  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");
  return success();
}
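// Worked example (illustrative): a 32x32x8 MFMA with blocks = 1 expects
// (m * k * blocks) / waveSize = (32 * 8 * 1) / 64 = 4 source values per lane
// and (m * n * blocks) / waveSize = (32 * 32 * 1) / 64 = 16 result values,
// i.e. a destC of type vector<16xf32>.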
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");
  return success();
}
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type inner most dim must be contiguous");

  auto elemType = srcType.getElementType();
  // Check $src and $dst element types are the same.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  // The transfer type determines how many bits each lane moves.
  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError("Transfering type size must be 8, 16, 32, 96 or 128 "
                       "bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");
  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
namespace {
/// If the source or destination of a GatherToLDSOp is produced by a
/// memref.cast that only erases static information, fold the cast away.
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(
              gatherOp, [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};
} // namespace

void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // Maps element type size (bits) to the required number of elements.
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end()) {
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }
  if (numElements != validNumElems->second) {
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;
  }
  return success();
}
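// Example (illustrative): with 16-bit elements the table requires exactly 4
// elements, so a result of vector<4xf16> (one 64-bit transposed load)
// verifies, while vector<8xf16> fails with the expected-number-of-elements
// error.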
namespace {
/// Pack the scale operands of a scaled MFMA: when a scale is produced by
/// inserting a single scalar extracted from a larger vector of scales, pass
/// a 4-byte slice of that vector instead and select the byte to use through
/// the scalesIdxA/scalesIdxB ("opsel") attributes.
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // Operands 3 and 4 are the A and B scales.
    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.insert");
      if (isa<VectorType>(insertOp.getValueToStore().getType()))
        return rewriter.notifyMatchFailure(op,
                                           "scaled mfma operand already packed");
      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.extract");

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType)
        return rewriter.notifyMatchFailure(op, "not a vector type");
      if (!scaleSrcType.hasStaticShape())
        return rewriter.notifyMatchFailure(op,
                                           "dynamic dims not yet supported");

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4)
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");

      // Linearize the position of the extracted scale within its source.
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i)
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      int64_t idx = linearize(extractedPos, extractSizes);

      // Take the aligned 4-byte chunk containing idx; opsel selects the byte
      // within it. Clamp to the end of the vector for the unaligned tail.
      int64_t size = 4;
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }
      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType =
          VectorType::get(ArrayRef{numElements}, scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef{offset}, ArrayRef{size},
          ArrayRef{int64_t(1)});
      rewriter.modifyOpInPlace(op, [&] {
        op->setOperand(opIdx, extract);
        setOpsel(opIdx, opsel);
      });
    }
    return success();
  }
};
} // namespace
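// Worked example (illustrative): extracting scale idx = 6 from a vector of
// 16 scales gives offset = 6 - (6 % 4) = 4 and opsel = 6 - 4 = 2, so the
// pattern feeds the 4-byte slice [4, 8) to the scaled MFMA and selects byte
// 2 of it via the scalesIdx attribute.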
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"