20 #define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
21 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
28 struct AmdgpuEmulateAtomicsPass
29 :
public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
30 AmdgpuEmulateAtomicsPass> {
31 using AmdgpuEmulateAtomicsPassBase<
32 AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
33 void runOnOperation()
override;
36 template <
typename AtomicOp,
typename ArithOp>
39 using Adaptor =
typename AtomicOp::Adaptor;
42 matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
48 enum class DataArgAction : unsigned char {
61 DataArgAction action) {
62 newAttrs.reserve(attrs.size());
64 if (attr.getName().getValue() !=
"operandSegmentSizes") {
65 newAttrs.push_back(attr);
68 auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
72 case DataArgAction::Drop:
74 context, segmentAttr.asArrayRef().drop_front());
76 case DataArgAction::Duplicate: {
79 newVals.push_back(oldVals[0]);
80 newVals.append(oldVals.begin(), oldVals.end());
89 template <
typename AtomicOp,
typename ArithOp>
90 LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
91 AtomicOp atomicOp, Adaptor adaptor,
97 Value data = operands.take_front()[0];
98 ValueRange invariantArgs = operands.drop_front();
104 rewriter.
create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
111 rewriter.
create<cf::BranchOp>(loc, loopBlock, initialLoad);
114 Value prevLoad = loopBlock->getArgument(0);
115 Value operated = rewriter.
create<ArithOp>(loc, data, prevLoad);
120 cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
121 Value atomicRes = rewriter.
create<RawBufferAtomicCmpswapOp>(
122 loc, dataType, cmpswapArgs, cmpswapAttrs);
129 Value prevLoadForCompare = prevLoad;
130 Value atomicResForCompare = atomicRes;
131 if (
auto floatDataTy = dyn_cast<FloatType>(dataType)) {
134 rewriter.
create<arith::BitcastOp>(loc, equivInt, prevLoad);
135 atomicResForCompare =
136 rewriter.
create<arith::BitcastOp>(loc, equivInt, atomicRes);
139 loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
141 loopBlock, atomicRes);
159 [](RawBufferAtomicFmaxOp op) ->
bool {
160 return op.getValue().getType().isF64();
168 target.
addIllegalOp<RawBufferAtomicFaddOp, RawBufferAtomicFmaxOp,
169 RawBufferAtomicSmaxOp, RawBufferAtomicUminOp>();
173 RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
174 RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
175 RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
176 RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
180 void AmdgpuEmulateAtomicsPass::runOnOperation() {
183 if (
failed(maybeChipset)) {
185 return signalPassFailure();
191 target.markUnknownOpDynamicallyLegal(
192 [](
Operation *op) ->
bool {
return true; });
196 return signalPassFailure();
static void patchOperandSegmentSizes(ArrayRef< NamedAttribute > attrs, SmallVectorImpl< NamedAttribute > &newAttrs, DataArgAction action)
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
IntegerType getIntegerType(unsigned width)
This class implements a pattern rewriter for use with ConversionPatterns.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
static FailureOr< Chipset > parse(StringRef name)