22 #define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
23 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
30 struct AmdgpuEmulateAtomicsPass
31 :
public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
32 AmdgpuEmulateAtomicsPass> {
33 using AmdgpuEmulateAtomicsPassBase<
34 AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
35 void runOnOperation()
override;
38 template <
typename AtomicOp,
typename ArithOp>
41 using Adaptor =
typename AtomicOp::Adaptor;
44 matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
50 enum class DataArgAction : unsigned char {
63 DataArgAction action) {
64 newAttrs.reserve(attrs.size());
66 if (attr.getName().getValue() !=
"operandSegmentSizes") {
67 newAttrs.push_back(attr);
70 auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
74 case DataArgAction::Drop:
76 context, segmentAttr.asArrayRef().drop_front());
78 case DataArgAction::Duplicate: {
81 newVals.push_back(oldVals[0]);
82 newVals.append(oldVals.begin(), oldVals.end());
95 auto vectorType = dyn_cast<VectorType>(val.
getType());
100 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
103 Value bitcast = rewriter.
create<vector::BitCastOp>(loc, allBitsVecType, val);
104 Value scalar = rewriter.
create<vector::ExtractOp>(loc, bitcast, 0);
108 template <
typename AtomicOp,
typename ArithOp>
109 LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
110 AtomicOp atomicOp, Adaptor adaptor,
116 Value data = operands.take_front()[0];
117 ValueRange invariantArgs = operands.drop_front();
123 rewriter.
create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
130 rewriter.
create<cf::BranchOp>(loc, loopBlock, initialLoad);
133 Value prevLoad = loopBlock->getArgument(0);
134 Value operated = rewriter.
create<ArithOp>(loc, data, prevLoad);
140 cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
141 Value atomicRes = rewriter.
create<RawBufferAtomicCmpswapOp>(
142 loc, dataType, cmpswapArgs, cmpswapAttrs);
151 if (
auto floatDataTy = dyn_cast<FloatType>(dataType)) {
154 rewriter.
create<arith::BitcastOp>(loc, equivInt, prevLoad);
155 atomicResForCompare =
156 rewriter.
create<arith::BitcastOp>(loc, equivInt, atomicRes);
159 loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
161 loopBlock, atomicRes);
175 [](RawBufferAtomicFaddOp op) ->
bool {
177 return !isa<Float16Type, BFloat16Type>(elemType);
182 if (chipset >=
Chipset(9, 0, 0xa) && chipset !=
Chipset(9, 4, 1)) {
186 [](RawBufferAtomicFmaxOp op) ->
bool {
187 return op.getValue().getType().isF64();
192 if (chipset ==
Chipset(9, 4, 1)) {
195 target.
addIllegalOp<RawBufferAtomicFaddOp, RawBufferAtomicFmaxOp,
196 RawBufferAtomicSmaxOp, RawBufferAtomicUminOp>();
200 RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
201 RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
202 RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
203 RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
207 void AmdgpuEmulateAtomicsPass::runOnOperation() {
210 if (failed(maybeChipset)) {
212 return signalPassFailure();
218 target.markUnknownOpDynamicallyLegal(
219 [](
Operation *op) ->
bool {
return true; });
223 return signalPassFailure();
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, Value val)
static void patchOperandSegmentSizes(ArrayRef< NamedAttribute > attrs, SmallVectorImpl< NamedAttribute > &newAttrs, DataArgAction action)
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
IntegerType getIntegerType(unsigned width)
This class implements a pattern rewriter for use with ConversionPatterns.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.