30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/TypeSwitch.h"
42 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
  addInterfaces<AMDGPUInlinerInterface>();
}
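// Illustrative usage sketch (not part of this file): a tool that wants these
// ops available registers and loads the dialect through the standard MLIR
// entry points.
//
//   mlir::DialectRegistry registry;
//   registry.insert<mlir::amdgpu::AMDGPUDialect>();
//   mlir::MLIRContext context(registry);
//   context.loadDialect<mlir::amdgpu::AMDGPUDialect>();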
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MemRefType::Builder mb(source);
  mb.setMemorySpace(amdgpu::AddressSpaceAttr::get(
      source.getContext(), amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout = StridedLayoutAttr::get(
        source.getContext(), /*offset=*/0, stridedLayout.getStrides());
    // Drop the explicit layout if resetting the offset yields the identity.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = MemRefLayoutAttrInterface{};
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}
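// Worked example of the type computation above (illustrative types): with
// resetOffset = true, a source of
//   memref<8x8xf32, strided<[8, 1], offset: ?>>
// maps to
//   memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// because [8, 1] is exactly the identity suffix product of the 8x8 shape, so
// resetting the offset leaves the identity layout and the strided attribute
// is dropped entirely.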
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = SmallVector<Type>{*resultType};
  return success();
}
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
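// For reference, the predicates above accept (list illustrative, derived
// directly from the checks):
//   hasGlobalMemorySpace:       0 : i64, 1 : i64, #gpu.address_space<global>
//   hasWorkgroupMemorySpace:    3 : i64, #gpu.address_space<workgroup>
//   hasFatRawBufferMemorySpace: 7 : i64,
//                               #amdgpu.address_space<fat_raw_buffer>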
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  if (!hasGlobalMemorySpace(bufferType.getMemorySpace()))
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  return result >= bufferType.getNumElements();
}
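// Worked example (hypothetical op): for a buffer of type memref<16x16xf32>
// (strides [16, 1], offset 0), constant indices (15, 15), and an indexOffset
// of 32:
//   result = 0 + 32 + 15 * 16 + 15 = 287 >= 256 = getNumElements()
// so the access is statically out of bounds and the patterns below may fold
// the op away.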
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    rw.eraseOp(op);
    return success();
  }
};
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
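// Net effect of the registrations above (schematic): a bounds-checked load
// whose address is provably past the end of a static-shaped buffer folds to
// a zero constant, and a provably out-of-bounds store or atomic is erased,
// e.g.
//
//   %v = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%c300]
//       : memref<256xf32>, i32 -> f32
//   // canonicalizes to: %v = arith.constant 0.000000e+00 : f32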
LogicalResult WMMAOp::verify() {
  Type sourceAType = getSourceA().getType();
  Type sourceBType = getSourceB().getType();
  Type destType = getDestC().getType();

  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
  VectorType destVectorType = dyn_cast<VectorType>(destType);

  Type sourceAElemType = sourceVectorAType.getElementType();
  Type sourceBElemType = sourceVectorBType.getElementType();
  Type destElemType = destVectorType.getElementType();

  if (sourceVectorAType.getNumElements() !=
      sourceVectorBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceVectorAType << " vs. " << sourceVectorBType;
  }

  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
  bool isSrcFloat =
      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
          sourceAElemType);

  if (isDestFloat && !isSrcFloat) {
    return emitOpError("Expected float sources with float destination");
  }

  if (!isDestFloat && isSrcFloat) {
    return emitOpError("Expected int sources with int destination");
  }

  if (sourceAElemType != sourceBElemType &&
      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
    return emitOpError(
               "source element types must match (except for fp8) but have ")
           << sourceAType << " and " << sourceBType;
  }
  return success();
}
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) ||
      sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }

  // Normalize the packed integer source types to i8 element counts.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");
  return success();
}
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }
  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }
  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must be either 16 or 32.");
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type innermost dim must be contiguous");

  auto elemType = srcType.getElementType();
  // Check $src and $dst element types are the same.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  auto transferType = getTransferType();
  size_t transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transferring type size must be 8, 16, 32, 96 or 128 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(
              gatherOp, [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};

void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}
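// Schematic effect of the fold above: a memref.cast that merely erases
// static shape information is bypassed, so
//
//   %c = memref.cast %src : memref<64xf16> to memref<?xf16>
//   amdgpu.gather_to_lds %c[%i], %lds[%j] : ...
//
// becomes a gather_to_lds reading %src directly, recovering the more static
// type for later analyses.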
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // Map of element-type size (bits) to the required number of elements; each
  // pairing transfers a fixed total width per lane. (Entry values
  // reconstructed from the ds_read_tr instruction family; the excerpt elided
  // them, so treat this table as an assumption.)
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end()) {
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }
  if (numElements != validNumElems->second) {
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;
  }
  return success();
}
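// Worked example (under the table assumption above): vector<4xf16> gives
// elementTypeSize = 16 and numElements = 4 (64 bits per lane), matching the
// {16, 4} entry; vector<8xf16> would be rejected with the element-count
// mismatch error.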
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    // Walk the vector.insert/vector.extract pair feeding each scale operand
    // (operands 3 and 4) to find the scale's position in its source vector.
    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.insert");
      if (isa<VectorType>(insertOp.getValueToStore().getType()))
        return rewriter.notifyMatchFailure(
            op, "scaled mfma operand already packed");
      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.extract");

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType)
        return rewriter.notifyMatchFailure(op, "not a vector operand");
      if (!scaleSrcType.hasStaticShape())
        return rewriter.notifyMatchFailure(op,
                                           "dynamic dims not yet supported");

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4)
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");

      // Linearize the extract position into a flat index of the source.
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      }
      int64_t idx = linearize(extractedPos, extractSizes);

      // Choose a 4-element window [offset, offset + 4) containing idx; opsel
      // is idx's position within that window. Clamp the window to the end of
      // the source vector when it would overrun.
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      int64_t size = 4;
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4;
      }

      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType =
          VectorType::get(ArrayRef<int64_t>{numElements}, scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
          ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
      rewriter.modifyOpInPlace(op, [&] {
        op->setOperand(opIdx, extract);
        setOpsel(opIdx, opsel);
      });
    }
    return success();
  }
};
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}
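// Worked example of the windowing arithmetic in PackScales: for 8 scales and
// an extract at linear index 5, offset = 5 - (5 % 4) = 4 and opsel = 1, so
// the pattern extracts elements [4, 8) and selects the second one. For 6
// scales at index 5, the window would overrun, so offset becomes 6 - 4 = 2
// and opsel = 4 - (6 - 5) = 3: window [2, 6), still selecting element 5.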
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"