#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"

void AMDGPUDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
      >();
}
LogicalResult PackedTrunc2xFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedStochRoundFp8Op::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}

LogicalResult PackedScaledTruncOp::verify() {
  if (getExisting() && getExisting().getType() != getResult().getType())
    return emitOpError("existing values must have same type as result");
  return success();
}
/// Convert the type `source` to one with the same sizes and strides (and
/// offset, unless `resetOffset` is true, in which case the offset is reset
/// to 0). Fails if the offset should be reset but the layout of `source` is
/// neither the identity layout nor a strided layout.
static FailureOr<MemRefType> getFatRawBufferTypeLike(MemRefType source,
                                                     bool resetOffset) {
  MLIRContext *ctx = source.getContext();
  MemRefType::Builder mb(source);
  mb.setMemorySpace(
      amdgpu::AddressSpaceAttr::get(ctx, amdgpu::AddressSpace::FatRawBuffer));
  MemRefLayoutAttrInterface layout = source.getLayout();
  if (resetOffset && !layout.isIdentity()) {
    auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
    if (!stridedLayout)
      return failure();
    MemRefLayoutAttrInterface newLayout =
        StridedLayoutAttr::get(ctx, 0, stridedLayout.getStrides());
    // If zeroing the offset makes the strides equal the identity layout's
    // suffix product, prefer the identity layout itself.
    SmallVector<int64_t> stridesIfIdentity;
    if (source.hasStaticShape()) {
      stridesIfIdentity = computeSuffixProduct(source.getShape());
    } else if (source.getRank() <= 1) {
      stridesIfIdentity = SmallVector<int64_t>(source.getRank(), 1);
    }
    if (stridesIfIdentity == stridedLayout.getStrides()) {
      newLayout = AffineMapAttr::get(
          AffineMap::getMultiDimIdentityMap(source.getRank(), ctx));
    }
    mb.setLayout(newLayout);
  }
  return (MemRefType)(mb);
}
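// Illustrative example (not from the upstream source): with resetOffset =
// true, a source type
//   memref<8x8xf32, strided<[8, 1], offset: ?>>
// becomes
//   memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>
// because strides [8, 1] equal the identity suffix product of the shape;
// a layout that is neither identity nor strided makes the conversion fail.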
LogicalResult FatRawBufferCastOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ValueRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  Adaptor adaptor(operands, attributes, properties, regions);
  auto sourceType =
      dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();
  FailureOr<MemRefType> resultType =
      getFatRawBufferTypeLike(sourceType, adaptor.getResetOffset());
  if (failed(resultType))
    return failure();
  inferredReturnTypes = SmallVector<Type>{*resultType};
  return success();
}
LogicalResult FatRawBufferCastOp::verify() {
  FailureOr<MemRefType> expectedResultType =
      getFatRawBufferTypeLike(getSource().getType(), getResetOffset());
  if (failed(expectedResultType))
    return emitOpError("source type ")
           << getSource().getType() << " can't have its offset reset";
  if (getResult().getType() != *expectedResultType)
    return emitOpError("expected result type to be ")
           << *expectedResultType << " but got " << getResult().getType();
  return success();
}
static bool hasGlobalMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return true;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
  return false;
}
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 3;
  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
  return false;
}
static bool hasFatRawBufferMemorySpace(Attribute memorySpace) {
  if (!memorySpace)
    return false;
  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
    return intMemorySpace.getInt() == 7;
  if (auto gpuMemorySpace = dyn_cast<amdgpu::AddressSpaceAttr>(memorySpace))
    return gpuMemorySpace.getValue() == amdgpu::AddressSpace::FatRawBuffer;
  return false;
}
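// Exposition note (added): these numeric values follow the AMDGPU target's
// address-space numbering: 0 (flat/generic) and 1 (global) are accepted as
// global above, 3 is workgroup memory (LDS), and 7 is the buffer
// fat pointer address space.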
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
  MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
  Attribute memorySpace = bufferType.getMemorySpace();
  if (!hasGlobalMemorySpace(memorySpace))
    return op.emitOpError(
        "Buffer ops must operate on a memref in global memory");
  if (!bufferType.hasRank())
    return op.emitOpError(
        "Cannot meaningfully buffer_store to an unranked memref");
  if (static_cast<int64_t>(op.getIndices().size()) != bufferType.getRank())
    return op.emitOpError("Expected " + Twine(bufferType.getRank()) +
                          " indices to memref");
  return success();
}
/// Return the value of `v` if it is a 32-bit integer constant, and
/// std::nullopt otherwise.
static std::optional<uint32_t> getConstantUint32(Value v) {
  APInt cst;
  if (!v.getType().isInteger(32))
    return std::nullopt;
  if (matchPattern(v, m_ConstantInt(&cst)))
    return cst.getZExtValue();
  return std::nullopt;
}
template <typename OpType>
static bool staticallyOutOfBounds(OpType op) {
  if (!op.getBoundsCheck())
    return false;
  MemRefType bufferType = op.getMemref().getType();
  if (!bufferType.hasStaticShape())
    return false;
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(bufferType.getStridesAndOffset(strides, offset)))
    return false;
  int64_t result = offset + op.getIndexOffset().value_or(0);
  if (op.getSgprOffset()) {
    std::optional<uint32_t> sgprOffset = getConstantUint32(op.getSgprOffset());
    if (!sgprOffset)
      return false;
    result += *sgprOffset;
  }
  if (strides.size() != op.getIndices().size())
    return false;
  int64_t indexVal = 0;
  for (auto pair : llvm::zip(strides, op.getIndices())) {
    int64_t stride = std::get<0>(pair);
    Value idx = std::get<1>(pair);
    std::optional<uint32_t> idxVal = getConstantUint32(idx);
    if (!idxVal)
      return false;
    indexVal += stride * *idxVal;
  }
  result += indexVal;
  return result >= bufferType.getNumElements();
}
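// Worked example (illustrative): for a memref<4x4xf32> (strides [4, 1],
// offset 0, 16 elements) with constant indices (3, 3) and indexOffset = 1:
//   result = 0 + 1 + (3 * 4 + 3 * 1) = 16 >= 16 elements,
// so the access is statically out of bounds.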
namespace {
template <typename OpType>
struct RemoveStaticallyOobBufferLoads final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    // A statically out-of-bounds load reads 0; replace the op with a zero
    // constant of the load's result type.
    Type loadType = op.getResult().getType();
    rw.replaceOpWithNewOp<arith::ConstantOp>(op, loadType,
                                             rw.getZeroAttr(loadType));
    return success();
  }
};
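// Example (illustrative; exact assembly syntax may differ): a raw buffer
// load with boundsCheck = true whose statically computed index lands past
// the end of its memref folds to
//   %v = arith.constant 0.000000e+00 : f32
// since hardware bounds-checked buffer loads return 0 when out of range.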
template <typename OpType>
struct RemoveStaticallyOobBufferWrites final : public OpRewritePattern<OpType> {
  using OpRewritePattern<OpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpType op,
                                PatternRewriter &rw) const override {
    if (!staticallyOutOfBounds(op))
      return failure();
    // A statically out-of-bounds write is a no-op; erase it.
    rw.eraseOp(op);
    return success();
  }
};
} // namespace
void RawBufferLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
}

void RawBufferStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
}

void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
}

void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
}

void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
}

void RawBufferAtomicUminOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
}

void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
      context);
}
LogicalResult WMMAOp::verify() {
  Type sourceAType = getSourceA().getType();
  Type sourceBType = getSourceB().getType();
  Type destType = getDestC().getType();

  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
  VectorType destVectorType = dyn_cast<VectorType>(destType);

  Type sourceAElemType = sourceVectorAType.getElementType();
  Type sourceBElemType = sourceVectorBType.getElementType();
  Type destElemType = destVectorType.getElementType();

  if (sourceVectorAType.getNumElements() !=
      sourceVectorBType.getNumElements()) {
    return emitOpError("source vectors have different lengths: ")
           << sourceVectorAType << " vs. " << sourceVectorBType;
  }

  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
  bool isSrcFloat =
      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
          sourceAElemType);

  if (isDestFloat && !isSrcFloat) {
    return emitOpError("Expected float sources with float destination");
  }

  if (!isDestFloat && isSrcFloat) {
    return emitOpError("Expected int sources with int destination");
  }

  if (sourceAElemType != sourceBElemType &&
      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
    return emitOpError(
               "source element types must match (except for fp8) but have ")
           << sourceAType << " and " << sourceBType;
  }
  return success();
}
LogicalResult MFMAOp::verify() {
  constexpr uint32_t waveSize = 64;
  Builder b(getContext());

  Type sourceType = getSourceA().getType();
  Type destType = getDestC().getType();

  Type sourceElem = sourceType, destElem = destType;
  uint32_t sourceLen = 1, destLen = 1;
  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
    sourceLen = sourceVector.getNumElements();
    sourceElem = sourceVector.getElementType();
  }
  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
    destLen = destVector.getNumElements();
    destElem = destVector.getElementType();
  }

  Type sourceBType = getSourceB().getType();
  if (sourceElem.isFloat(8) || sourceElem.isFloat(6) ||
      sourceElem.isFloat(4)) {
    int64_t sourceBLen = 1;
    Type sourceBElem = sourceBType;
    if (auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
      sourceBLen = sourceBVector.getNumElements();
      sourceBElem = sourceBVector.getElementType();
    }
    if (!sourceBElem.isFloat(8) && !sourceBElem.isFloat(6) &&
        !sourceBElem.isFloat(4))
      return emitOpError("expected both source operands to have small-float "
                         "elements if one does");
    if (sourceLen != sourceBLen)
      return emitOpError(
          "expected both small-float source vectors to have the same length");
  } else {
    if (sourceType != sourceBType)
      return emitOpError("expected both non-small-float source operand types "
                         "to match exactly");
  }

  // Normalize the wider integer types the compiler expects to i8.
  if (sourceElem.isInteger(32)) {
    sourceLen *= 4;
    sourceElem = b.getI8Type();
  }
  if (sourceElem.isInteger(64)) {
    sourceLen *= 8;
    sourceElem = b.getI8Type();
  }

  int64_t numSourceElems = (getM() * getK() * getBlocks()) / waveSize;
  if (sourceLen != numSourceElems)
    return emitOpError("expected " + Twine(numSourceElems) +
                       " source values for this operation but got " +
                       Twine(sourceLen));

  int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize;
  if (destLen != numDestElems)
    return emitOpError("expected " + Twine(numDestElems) +
                       " result values for this operation but got " +
                       Twine(destLen));

  if (destElem.isF64() && getBlgp() != MFMAPermB::none)
    return emitOpError(
        "double-precision ops do not support permuting lanes of B");
  if (destElem.isF64() && getCbsz() != 0)
    return emitOpError(
        "double-precision ops do not support permuting lanes of A");
  if (getAbid() >= (1u << getCbsz()))
    return emitOpError(
        "block ID for permuting A (abid) must be below 2 ** cbsz");

  if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
    return emitOpError(
        "negation flags only available for double-precision operations");

  return success();
}
LogicalResult DPPOp::verify() {
  Type srcType = getSrc().getType();
  if (srcType.getIntOrFloatBitWidth() > 64) {
    return emitOpError("integer and floating point types larger than 64 bits "
                       "are not supported");
  }

  DPPPerm kind = getKind();
  Attribute permArgument = getPermArgument().value_or(Attribute{});

  switch (kind) {
  case DPPPerm::quad_perm: {
    auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
    if (!quadPermAttr || quadPermAttr.size() != 4) {
      return emitOpError("quad_perm attribute must have exactly 4 elements");
    }
    for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
      int32_t num = elem.getInt();
      if (num < 0 || num > 3) {
        return emitOpError(
            "Each element of quad_perm must be in the range [0, 3]");
      }
    }
    break;
  }

  case DPPPerm::row_shl:
  case DPPPerm::row_shr:
  case DPPPerm::row_ror: {
    if (!permArgument) {
      return emitOpError("Attribute '" + Twine(stringifyDPPPerm(kind)) +
                         "' value not specified");
    }
    if (auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
      uint32_t attrValue = intAttr.getInt();
      if (attrValue < 1 || attrValue > 15) {
        return emitOpError("Attribute value must be between 1 and 15");
      }
    }
    break;
  }

  case DPPPerm::wave_shl:
  case DPPPerm::wave_shr:
  case DPPPerm::wave_rol:
  case DPPPerm::wave_ror:
  case DPPPerm::row_mirror:
  case DPPPerm::row_half_mirror:
  case DPPPerm::row_bcast_15:
  case DPPPerm::row_bcast_31: {
    if (permArgument && !isa<UnitAttr>(permArgument)) {
      return emitOpError("Expected unit attribute for permArgument, but found "
                         "non-trivial argument");
    }
    break;
  }
  }
  return success();
}
LogicalResult GatherToLDSOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());
  MemRefType dstType = cast<MemRefType>(getDst().getType());

  if (!dstType.areTrailingDimsContiguous(dstType.getRank()))
    return emitOpError("destination types must be contiguous");

  auto elemType = srcType.getElementType();
  // Check that the source and destination element types match.
  if (elemType != dstType.getElementType())
    return emitOpError("source and destination element types must match");

  // Transfer sizes of 1, 2, 4, 12, or 16 bytes are supported.
  auto transferType = getTransferType();
  int transferSize;
  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
    transferSize = vectorTransfer.getNumElements() *
                   vectorTransfer.getElementTypeBitWidth();
  } else {
    transferSize = transferType.getIntOrFloatBitWidth();
  }
  if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
    return emitOpError(
        "Transferring type size must be 8, 16, 32, 96 or 128 bits");

  if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
      !hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
    return emitOpError(
        "source memory address space must be global or fat raw buffer");

  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
    return emitOpError("destination memory address space must be Workgroup");

  return success();
}
namespace {
/// Fold memref.cast ops that only discard static information into the
/// source and destination operands of a GatherToLDSOp.
struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
                                PatternRewriter &rewriter) const override {
    bool modified = false;
    auto foldCast = [&](OpOperand &operand) {
      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
          rewriter.modifyOpInPlace(
              gatherOp, [&] { operand.assign(castOp.getSource()); });
          modified = true;
        }
      }
    };

    foldCast(gatherOp.getSrcMutable());
    foldCast(gatherOp.getDstMutable());

    return success(modified);
  }
};
} // namespace

void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<FoldGatherToLDSOfCast>(context);
}
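// Example (illustrative): when the source operand comes through
//   %cast = memref.cast %src : memref<64xf32> to memref<?xf32>
// the pattern rewrites the gather in place to read %src directly, since
// such a cast only discards static shape information.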
LogicalResult TransposeLoadOp::verify() {
  MemRefType srcType = cast<MemRefType>(getSrc().getType());

  if (!hasWorkgroupMemorySpace(srcType.getMemorySpace()))
    return emitOpError("source memory address space must be Workgroup");

  auto transferType = cast<VectorType>(getType());
  size_t numElements = transferType.getNumElements();
  size_t elementTypeSize =
      transferType.getElementType().getIntOrFloatBitWidth();

  // Maps element type size (in bits) to the number of elements a transpose
  // load must transfer.
  const llvm::SmallDenseMap<size_t, size_t> KValidLoadSizeMap = {
      {4, 16},
      {6, 16},
      {8, 8},
      {16, 4},
  };

  auto validNumElems = KValidLoadSizeMap.find(elementTypeSize);
  if (validNumElems == KValidLoadSizeMap.end()) {
    return emitOpError("Unsupported element type size for transpose load: ")
           << elementTypeSize << " bits";
  }
  if (numElements != validNumElems->second) {
    return emitOpError(
               "Transferring type size mismatch: expected num of elements: ")
           << validNumElems->second;
  }
  return success();
}
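// Example (illustrative): with 16-bit elements the result must be a
// 4-element vector (4 x 16 = 64 bits), e.g. vector<4xf16>, while 6-bit
// element types must load 16 elements (96 bits); any other combination of
// element size and vector length is rejected by the map above.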
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"