26 #include "llvm/ADT/TypeSwitch.h"
34 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
36 void AMDGPUDialect::initialize() {
39 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
42 #define GET_ATTRDEF_LIST
43 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
51 if (getExisting() && getExisting().
getType() != getResult().
getType())
52 return emitOpError(
"existing values must have same type as result");
57 if (getExisting() && getExisting().
getType() != getResult().
getType())
58 return emitOpError(
"existing values must have same type as result");
76 MemRefLayoutAttrInterface layout = source.getLayout();
77 if (resetOffset && !layout.isIdentity()) {
78 auto stridedLayout = dyn_cast<StridedLayoutAttr>(layout);
83 return (MemRefType)(mb);
86 LogicalResult FatRawBufferCastOp::inferReturnTypes(
90 Adaptor adaptor(operands, attributes, properties, regions);
92 dyn_cast_if_present<MemRefType>(adaptor.getSource().getType());
95 FailureOr<MemRefType> resultType =
97 if (failed(resultType))
104 FailureOr<MemRefType> expectedResultType =
106 if (failed(expectedResultType))
107 return emitOpError(
"source type ")
108 << getSource().getType() <<
" can't have its offset reset";
109 if (getResult().
getType() != *expectedResultType)
110 return emitOpError(
"expected result type to be ")
111 << *expectedResultType <<
" but got " << getResult().getType();
118 template <
typename T>
120 MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
121 Attribute memorySpace = bufferType.getMemorySpace();
122 bool isGlobal =
false;
125 else if (
auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
126 isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
127 else if (
auto gpuMemorySpace =
128 llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
129 isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
132 return op.emitOpError(
133 "Buffer ops must operate on a memref in global memory");
134 if (!bufferType.hasRank())
135 return op.emitOpError(
136 "Cannot meaningfully buffer_store to an unranked memref");
137 if (
static_cast<int64_t
>(op.getIndices().size()) != bufferType.getRank())
138 return op.emitOpError(
"Expected " + Twine(bufferType.getRank()) +
139 " indices to memref");
172 return cst.getZExtValue();
176 template <
typename OpType>
178 if (!op.getBoundsCheck())
180 MemRefType bufferType = op.getMemref().getType();
181 if (!bufferType.hasStaticShape())
185 if (failed(bufferType.getStridesAndOffset(strides, offset)))
187 int64_t result = offset + op.getIndexOffset().value_or(0);
188 if (op.getSgprOffset()) {
192 result += *sgprOffset;
194 if (strides.size() != op.getIndices().size())
196 int64_t indexVal = 0;
197 for (
auto pair : llvm::zip(strides, op.getIndices())) {
198 int64_t stride = std::get<0>(pair);
199 Value idx = std::get<1>(pair);
203 indexVal += stride * *idxVal;
209 return result >= bufferType.getNumElements();
213 template <
typename OpType>
214 struct RemoveStaticallyOobBufferLoads final :
public OpRewritePattern<OpType> {
217 LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rw)
const override {
220 Type loadType = op.getResult().getType();
227 template <
typename OpType>
228 struct RemoveStaticallyOobBufferWrites final :
public OpRewritePattern<OpType> {
231 LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rw)
const override {
243 results.
add<RemoveStaticallyOobBufferLoads<RawBufferLoadOp>>(context);
248 results.
add<RemoveStaticallyOobBufferWrites<RawBufferStoreOp>>(context);
251 void RawBufferAtomicFaddOp::getCanonicalizationPatterns(
253 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFaddOp>>(context);
256 void RawBufferAtomicFmaxOp::getCanonicalizationPatterns(
258 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicFmaxOp>>(context);
261 void RawBufferAtomicSmaxOp::getCanonicalizationPatterns(
263 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicSmaxOp>>(context);
266 void RawBufferAtomicUminOp::getCanonicalizationPatterns(
268 results.
add<RemoveStaticallyOobBufferWrites<RawBufferAtomicUminOp>>(context);
271 void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
273 results.
add<RemoveStaticallyOobBufferLoads<RawBufferAtomicCmpswapOp>>(
281 Type sourceAType = getSourceA().getType();
282 Type sourceBType = getSourceB().getType();
283 Type destType = getDestC().getType();
285 VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
286 VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
287 VectorType destVectorType = dyn_cast<VectorType>(destType);
289 Type sourceAElemType = sourceVectorAType.getElementType();
290 Type sourceBElemType = sourceVectorBType.getElementType();
291 Type destElemType = destVectorType.getElementType();
293 if (sourceVectorAType.getNumElements() !=
294 sourceVectorBType.getNumElements()) {
295 return emitOpError(
"source vectors have different lengths: ")
296 << sourceVectorAType <<
" vs. " << sourceVectorBType;
299 bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
301 isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
304 if (isDestFloat && !isSrcFloat) {
305 return emitOpError(
"Expected float sources with float destination");
308 if (!isDestFloat && isSrcFloat) {
309 return emitOpError(
"Expected int sources with int destination");
312 if (sourceAElemType != sourceBElemType &&
313 !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
314 isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
316 "source element types much match (except for fp8) but have ")
317 << sourceAType <<
" and " << sourceBType;
326 constexpr uint32_t waveSize = 64;
329 Type sourceType = getSourceA().getType();
330 Type destType = getDestC().getType();
332 Type sourceElem = sourceType, destElem = destType;
333 uint32_t sourceLen = 1, destLen = 1;
334 if (
auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
335 sourceLen = sourceVector.getNumElements();
336 sourceElem = sourceVector.getElementType();
338 if (
auto destVector = llvm::dyn_cast<VectorType>(destType)) {
339 destLen = destVector.getNumElements();
340 destElem = destVector.getElementType();
343 Type sourceBType = getSourceB().getType();
345 int64_t sourceBLen = 1;
346 Type sourceBElem = sourceBType;
347 if (
auto sourceBVector = llvm::dyn_cast<VectorType>(sourceBType)) {
348 sourceBLen = sourceBVector.getNumElements();
349 sourceBElem = sourceBVector.getElementType();
352 return emitOpError(
"expected both source operands to have f8 elements");
353 if (sourceLen != sourceBLen)
355 "expected both f8 source vectors to have the same length");
357 if (sourceType != sourceBType)
359 "expected both non-f8 source operand types to match exactly");
364 sourceElem = b.getI8Type();
368 sourceElem = b.getI8Type();
371 int64_t numSourceElems = (
getM() * getK() * getBlocks()) / waveSize;
372 if (sourceLen != numSourceElems)
373 return emitOpError(
"expected " + Twine(numSourceElems) +
374 " source values for this operation but got " +
377 int64_t numDestElems = (
getM() *
getN() * getBlocks()) / waveSize;
378 if (destLen != numDestElems)
379 return emitOpError(
"expected " + Twine(numDestElems) +
380 " result values for this operation but got " +
383 if (destElem.isF64() && getBlgp() != MFMAPermB::none)
385 "double-precision ops do not support permuting lanes of B");
386 if (destElem.isF64() && getCbsz() != 0)
388 "double-precision ops do not support permuting lanes of A");
389 if (getAbid() >= (1u << getCbsz()))
391 "block ID for permuting A (abid) must be below 2 ** cbsz");
393 if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64())
395 "negation flags only available for double-precision operations");
404 Type srcType = getSrc().getType();
406 return emitOpError(
"integer and floating point types larger than 64 bits "
407 "are not supported");
410 DPPPerm
kind = getKind();
415 case DPPPerm::quad_perm: {
416 auto quadPermAttr = dyn_cast_or_null<ArrayAttr>(permArgument);
417 if (!quadPermAttr || quadPermAttr.size() != 4) {
418 return emitOpError(
"quad_perm attribute must have exactly 4 elements");
420 for (
auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
421 int32_t num = elem.getInt();
422 if (num < 0 || num > 3) {
424 "Each element of quad_perm must be in the range [0, 3]");
429 case DPPPerm::row_shl:
430 case DPPPerm::row_shr:
431 case DPPPerm::row_ror: {
433 return emitOpError(
"Attribute '" + Twine(stringifyDPPPerm(
kind)) +
434 "' value not specified");
436 if (
auto intAttr = dyn_cast<IntegerAttr>(permArgument)) {
437 uint32_t attrValue = intAttr.getInt();
438 if (attrValue < 1 || attrValue > 15) {
439 return emitOpError(
"Attribute value must be between 1 and 15");
444 case DPPPerm::wave_shl:
445 case DPPPerm::wave_shr:
446 case DPPPerm::wave_rol:
447 case DPPPerm::wave_ror:
448 case DPPPerm::row_mirror:
449 case DPPPerm::row_half_mirror:
450 case DPPPerm::row_bcast_15:
451 case DPPPerm::row_bcast_31: {
452 if (permArgument && !isa<UnitAttr>(permArgument)) {
453 return emitOpError(
"Expected unit attribute for permArgument, but found "
454 "non-trivial argument");
462 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
464 #define GET_ATTRDEF_CLASSES
465 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
467 #define GET_OP_CLASSES
468 #include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
static LogicalResult verifyRawBufferOp(T &op)
static FailureOr< MemRefType > getFatRawBufferTypeLike(MemRefType source, bool resetOffset)
Convert the type source to one with the same sizes and strides - and offset, unless resetOffset is true, in which case the offset is reset.
static std::optional< uint32_t > getConstantUint32(Value v)
static bool staticallyOutOfBounds(OpType op)
static MLIRContext * getContext(OpFoldResult val)
union mlir::linalg::@1179::ArityGroupAndKind::Kind kind
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
This class is a general helper class for creating context-global objects like types,...
TypedAttr getZeroAttr(Type type)
MLIRContext is the top-level object for a collection of MLIR operations.
This is a builder type that keeps local references to arguments.
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
Builder & setMemorySpace(Attribute newMemorySpace)
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class provides an abstraction over the different types of ranges over Regions.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat() const
Return true if this is a float type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
uint64_t getN(LevelType lt)
uint64_t getM(LevelType lt)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...