#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"

void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
}
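// Bodies of three op verifiers that each require the optional `existing`
// operand, when present, to have the same type as the op's result.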
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");

  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");

  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout = StridedLayoutAttr::get(
        source.getContext(), /*offset=*/0, stridedLayout.getStrides());
    // If dropping the offset yields an identity layout, drop the layout
    // attribute entirely.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides())
      newLayout = MemRefLayoutAttrInterface{};
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
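// FatRawBufferCastOp: the result type is inferred from the source memref and
// the reset-offset flag, presumably by reusing getFatRawBufferTypeLike.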
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes.push_back(*resultType);
  return success();
}

LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}

static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}

static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}
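// Common verifier for the raw buffer load/store/atomic ops: the memref must
// live in global memory, be ranked, and be indexed with one index per
// dimension.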
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  if (!hasGlobalMemorySpace(bufferType.getMemorySpace()))
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
  return cst.getZExtValue();
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  return result >= bufferType.getNumElements();
}
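// Canonicalization support: raw buffer accesses whose bounds check is
// statically guaranteed to fail can be simplified away, so statically
// out-of-bounds loads fold to a zero value and out-of-bounds writes are
// erased.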
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, rw.getZeroAttr(loadType));
    return success();
  }
};
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    rw.eraseOp(op);
    return success();
  }
};
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
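// Verifier body for the wave-level matrix multiply op (presumably WMMAOp):
// both sources must be vectors of equal length, float sources must pair with
// a float destination, and mixed fp8 source element types are tolerated.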
  Type sourceAType = getSourceA().getType();
  Type sourceBType = getSourceB().getType();
  Type destType = getDestC().getType();

  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
  VectorType destVectorType = dyn_cast<VectorType>(destType);

  Type sourceAElemType = sourceVectorAType.getElementType();
  Type sourceBElemType = sourceVectorBType.getElementType();
  Type destElemType = destVectorType.getElementType();

  if (sourceVectorAType.getNumElements() !=
      sourceVectorBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceVectorAType << " vs. " << sourceVectorBType;
  }

  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
  bool isSrcFloat =
      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
          sourceAElemType);

  if (isDestFloat && !isSrcFloat) {
    return emitOpError("Expected float sources with float destination");
  }

  if (!isDestFloat && isSrcFloat) {
    return emitOpError("Expected int sources with int destination");
  }

  if (sourceAElemType != sourceBElemType &&
      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
    return emitOpError(
               "source element types must match (except for fp8) but have ")
           << sourceAType << " and " << sourceBType;
  }
  constexpr uint32_t waveSize = 64;

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  // Assumed guard: treat fp8 element types as the "small float" case.
  if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElem)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElem))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else if (sourceType != sourceBType) {
    return emitOpError("expected both non-small-float source operand types "
                       "to match");
  }

  // Small-float source element types are accounted for as i8 when computing
  // the expected per-lane element counts below.
  Builder b(getContext());
  sourceElem = b.getI8Type();
  sourceElem = b.getI8Type();

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");
  Type srcType = getSrc().getType();
  if (srcType.isIntOrFloat() && srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }
  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }
  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  unsigned rowLength = getRowLength();

  if (rowLength != 16 && rowLength != 32)
    return emitOpError("row_length attribute must be either 16 or 32.");
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(1))
    return emitOpError("destination type inner most dim must be contiguous");

  auto elemType = srcType.getElementType();
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  auto transferType = getTransferType();
  int64_t transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transferring type size must be 8, 16, 32, 96 or 128 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(
              gatherOp, [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};
void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}
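// Verifier body for the transpose load op: the source must live in workgroup
// (LDS) memory, and the loaded vector's element size / element count must be
// one of the supported combinations.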
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();
  // Assumed mapping (element bit-width -> expected element count) for the
  // supported transpose-load shapes.
  const llvm::SmallDenseMap<size_t, size_t> kValidLoadSizeMap = {
      {4, 16}, {6, 16}, {8, 8}, {16, 4}};

  auto validNumElems = kValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == kValidLoadSizeMap.end()) {
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }
  if (numElements != validNumElems->second) {
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;
  }
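// Canonicalization pattern for ScaledMFMAOp: when a scale operand is built by
// vector.insert of a single element extracted from a larger scale vector,
// replace it with a 4-element slice of that vector and record which entry to
// use via the scalesIdxA / scalesIdxB (opsel) fields.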
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // Operands 3 and 4 are the A and B scales; record the selected sub-word
    // through scalesIdxA / scalesIdxB.
    auto setOpsel = [&op](unsigned idx, int64_t val) {
      switch (idx) {
      case 3:
        op.setScalesIdxA(val);
        break;
      case 4:
        op.setScalesIdxB(val);
        break;
      default:
        break;
      }
    };

    for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
      if (!insertOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.insert");
      if (isa<VectorType>(insertOp.getValueToStore().getType()))
        return rewriter.notifyMatchFailure(
            op, "scaled mfma operand already packed");

      auto extractOp =
          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
      if (!extractOp)
        return rewriter.notifyMatchFailure(op,
                                           "defining op not a vector.extract");

      Value scaleSrc = extractOp.getOperand(0);
      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
      if (!scaleSrcType)
        return rewriter.notifyMatchFailure(op, "scale source is not a vector");
      if (!scaleSrcType.hasStaticShape())
        return rewriter.notifyMatchFailure(op, "dynamic dims not yet supported");

      int64_t numElements = scaleSrcType.getNumElements();
      if (numElements <= 4)
        return rewriter.notifyMatchFailure(
            op, "no packing if # of scales less than four");

      // Linearize the (reversed) static extraction position.
      auto extractedPos = llvm::to_vector_of<int64_t>(
          llvm::reverse(extractOp.getStaticPosition()));
      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
      int64_t scaleSrcRank = scaleSrcType.getRank();
      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
      for (int64_t i = 1; i < scaleSrcRank; ++i) {
        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
      }
      int64_t idx = linearize(extractedPos, extractSizes);

      // Pick a 4-element window containing idx; opsel selects the element
      // within that window, clamping the window to the end of the vector.
      constexpr int64_t size = 4;
      int64_t offset = idx - (idx % 4);
      int64_t opsel = idx - offset;
      if (numElements - offset < size) {
        opsel = size - (numElements - idx);
        offset = numElements - 4l;
      }

      Type scaleSrcElemType = scaleSrcType.getElementType();
      auto newSrcType = VectorType::get({numElements}, scaleSrcElemType);
      Value newScaleSrc =
          vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
      auto extract = vector::ExtractStridedSliceOp::create(
          rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
          ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
      op->setOperand(opIdx, extract);
      setOpsel(opIdx, opsel);
    }
    return success();
  }
};
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"