30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
42 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
53 void AMDGPUDialect::initialize() {
56 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
59 #define GET_ATTRDEF_LIST
60 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
62 addInterfaces<AMDGPUInlinerInterface>();
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
/// Convert the type `source` to one with the same sizes and strides - and
/// offset, unless `resetOffset` is true, in which case the offset is reset
/// to 0 - in the fat raw buffer address space. Fails if the offset should be
/// reset but the layout of `source` is neither the identity nor a strided
/// layout.
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MemRefType::Builder mb(source);
  mb.setMemorySpace(amdgpu::AddressSpaceAttr::get(
      source.getContext(), amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout = StridedLayoutAttr::get(
        source.getContext(), 0, stridedLayout.getStrides());
    // Fold to the identity layout when the zero-offset strides are the
    // canonical row-major strides for the shape.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides())
      newLayout = MemRefLayoutAttrInterface{};
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}
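// Worked example (illustrative): with resetOffset = true,
//   memref<8x8xf32, strided<[8, 1], offset: 16>>
// maps to
//   memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// since computeSuffixProduct({8, 8}) == {8, 1} equals the strides, so the
// zero-offset strided layout folds away to the identity layout.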
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = SmallVector<Type>{*resultType};
  return success();
}
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}

static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}
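// Note: these predicates accept both raw integer address spaces (0/1 for
// global, 3 for workgroup, 7 for fat raw buffer, matching the AMDGPU
// backend's numbering) and the dialect attributes such as
// #gpu.address_space<global> or #amdgpu.address_space<fat_raw_buffer>.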
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  Attribute memorySpace = bufferType.getMemorySpace();
  if (!hasGlobalMemorySpace(memorySpace))
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  if (result > std::numeric_limits<uint32_t>::max())
    // Overflow means don't drop the access.
    return false;
  return result >= bufferType.getNumElements();
}
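// Worked example (illustrative): for memref<4x8xi32> (strides {8, 1},
// offset 0, 32 elements) with constant indices (3, 9) and no extra offsets:
//   result = 0 + 3 * 8 + 9 * 1 = 33 >= 32,
// so the access is statically out of bounds and may be folded away.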
namespace {
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op, PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    rw.eraseOp(op);
    return success();
  }
};
} // namespace
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
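// Net effect of the patterns registered above (illustrative): a
// bounds-checked raw-buffer access that is statically out of bounds folds to
// a zero constant of the result type (loads, cmpswap) or is erased outright
// (stores and the other atomics).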
LogicalResult ScaledExtPacked816Op::verify() {
  int blockSize = getBlockSize();
  assert((blockSize == 16 || blockSize == 32) && "invalid block size");
  int firstScaleByte = getFirstScaleByte();
  if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
    return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1.");
  }
  if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
    return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2.");
  }

  return success();
}
/// Parser for the `custom<MNKDimensionList>` assembly format used by WMMAOp.
ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
                                                IntegerAttr &m, IntegerAttr &n,
                                                IntegerAttr &k) {
  SmallVector<int64_t, 3> dimensions;
  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
                                /*withTrailingX=*/false))
    return failure();
  if (dimensions.size() != 3)
    return parser.emitError(parser.getCurrentLocation())
           << "expected 3 dimensions in MNK dimension list";

  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
  return success();
}
LogicalResult WMMAOp::verify() {
  auto sourceAType = cast<VectorType>(getSourceA().getType());
  auto sourceBType = cast<VectorType>(getSourceB().getType());
  auto destType = cast<VectorType>(getDestC().getType());

  Type sourceAElemType = sourceAType.getElementType();
  Type sourceBElemType = sourceBType.getElementType();
  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceAType << " vs. " << sourceBType;
  }

  bool isDestFloat = destType.getElementType().isFloat();
  bool isSrcFloat = sourceAElemType.isFloat();

  if (isDestFloat && !isSrcFloat)
    return emitOpError("expected float sources with float destination");
  if (!isDestFloat && isSrcFloat)
    return emitOpError("expected int sources with int destination");

  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
    return emitOpError(
               "source element types must match (except for fp8/bf8) but have ")
           << sourceAType << " and " << sourceBType;
  }

  if (isSrcFloat) {
    if (getClamp())
      return emitOpError("clamp flag is not supported for float types");
    if (getUnsignedA() || getUnsignedB())
      return emitOpError("unsigned flags are not supported for float types");
  }
  return success();
}
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) || sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }

  // Normalize the wider integer source types to counts of i8 elements.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.isIntOrFloat() && srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }
  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }
  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
LogicalResult PermlaneSwapOp::verify() {
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must either be 16 or 32.");

  return success();
}
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type inner most dim must be contiguous");

  auto elemType = srcType.getElementType();
  // Check $src and $dst element types are the same.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  // The transfer size must be 8, 16, 32, 96, or 128 bits.
  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError("Transfering type size must be 8, 16, 32, 96 or 128 "
                       "bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
namespace {
/// If the source or destination of a GatherToLDSOp is produced by a
/// memref.cast that can be folded into its consumer, bypass the cast.
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(gatherOp,
                                   [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};
} // namespace
void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // Element type size (bits) -> expected vector length.
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end()) {
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }
  if (numElements != validNumElems->second) {
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;
  }

  return success();
}
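// Illustrative (the exact table entries are reconstructed from the hardware
// transpose-load shapes): 8 x i8 = 64 bits and 16 x 6-bit = 96 bits, so
// vector<8xi8> verifies while vector<4xi8> does not.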
namespace {
/// Pack small-float scale values that feed a scaled MFMA into a four-byte
/// slice of their source vector and record which byte to read through the
/// scale-index (opsel) attributes.
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // Operands 3 and 4 are the A and B scales.
    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp)
        return rewriter.notifyMatchFailure(op, "defining op not a vector.insert");
      if (isa<VectorType>(insertOp.getValueToStore().getType()))
        return rewriter.notifyMatchFailure(op, "scaled mfma operand already packed");

      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp)
        return rewriter.notifyMatchFailure(op, "defining op not a vector.extract");

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType)
        return rewriter.notifyMatchFailure(op, "not a vector type");
      if (!scaleSrcType.hasStaticShape())
        return rewriter.notifyMatchFailure(op, "dynamic dims not yet supported");

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4)
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");

      // Compute the linearized index of the extracted scale.
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      }
      int64_t idx = linearize(extractedPos, extractSizes);

      // Slice out the aligned group of four scales containing idx; opsel
      // selects the byte within that group (clamped near the end).
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      int64_t size = 4l;
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }
      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType = VectorType::get({numElements}, scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
          ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
      rewriter.modifyOpInPlace(op, [&] {
        op->setOperand(opIdx, extract);
        setOpsel(opIdx, opsel);
      });
    }
    return success();
  }
};
} // namespace
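// Worked example (illustrative): extracting scale index idx = 6 from eight
// scales gives offset = 6 - (6 % 4) = 4 and opsel = 6 - 4 = 2; since
// numElements - offset == 4 == size, no end-clamping applies, so the rewrite
// slices elements [4, 8) and records byte 2 in the opsel index.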
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"