21#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
22#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
29struct AmdgpuEmulateAtomicsPass
30 :
public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
31 AmdgpuEmulateAtomicsPass> {
32 using AmdgpuEmulateAtomicsPassBase<
33 AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
34 void runOnOperation()
override;
37template <
typename AtomicOp,
typename ArithOp>
38struct RawBufferAtomicByCasPattern :
public OpConversionPattern<AtomicOp> {
39 using OpConversionPattern<AtomicOp>::OpConversionPattern;
40 using Adaptor =
typename AtomicOp::Adaptor;
43 matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
44 ConversionPatternRewriter &rewriter)
const override;
49enum class DataArgAction :
unsigned char {
62 DataArgAction action) {
63 newAttrs.reserve(attrs.size());
65 if (attr.getName().getValue() !=
"operandSegmentSizes") {
66 newAttrs.push_back(attr);
69 auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
73 case DataArgAction::Drop:
75 context, segmentAttr.asArrayRef().drop_front());
77 case DataArgAction::Duplicate: {
80 newVals.push_back(oldVals[0]);
81 newVals.append(oldVals.begin(), oldVals.end());
94 auto vectorType = dyn_cast<VectorType>(val.
getType());
99 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
100 Type allBitsType = rewriter.getIntegerType(bitwidth);
101 auto allBitsVecType = VectorType::get({1}, allBitsType);
102 Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val);
103 Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0);
107template <
typename AtomicOp,
typename ArithOp>
108LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
109 AtomicOp atomicOp, Adaptor adaptor,
110 ConversionPatternRewriter &rewriter)
const {
115 Value data = operands.take_front()[0];
116 ValueRange invariantArgs = operands.drop_front();
121 Value initialLoad = RawBufferLoadOp::create(rewriter, loc, dataType,
122 invariantArgs, loadAttrs);
123 Block *currentBlock = rewriter.getInsertionBlock();
125 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
127 Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
129 rewriter.setInsertionPointToEnd(currentBlock);
130 cf::BranchOp::create(rewriter, loc, loopBlock, initialLoad);
132 rewriter.setInsertionPointToEnd(loopBlock);
134 Value operated = ArithOp::create(rewriter, loc, data, prevLoad);
140 cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
141 Value atomicRes = RawBufferAtomicCmpswapOp::create(rewriter, loc, dataType,
142 cmpswapArgs, cmpswapAttrs);
151 if (
auto floatDataTy = dyn_cast<FloatType>(dataType)) {
152 Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
154 arith::BitcastOp::create(rewriter, loc, equivInt, prevLoad);
155 atomicResForCompare =
156 arith::BitcastOp::create(rewriter, loc, equivInt, atomicRes);
159 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
160 atomicResForCompare, prevLoadForCompare);
161 cf::CondBranchOp::create(rewriter, loc, canLeave, afterAtomic,
172 target.addIllegalOp<RawBufferAtomicFaddOp>();
176 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
177 [](RawBufferAtomicFaddOp op) ->
bool {
179 return !isa<Float16Type, BFloat16Type>(elemType);
184 if (chipset >=
Chipset(9, 0, 0xa)) {
187 target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
188 [](RawBufferAtomicFmaxOp op) ->
bool {
189 return op.getValue().getType().isF64();
192 target.addIllegalOp<RawBufferAtomicFmaxOp>();
196 if (chipset <
Chipset(9, 5, 0)) {
197 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
198 [](RawBufferAtomicFaddOp op) ->
bool {
200 return !isa<BFloat16Type>(elemType);
205 RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
206 RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
207 RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
208 RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
212void AmdgpuEmulateAtomicsPass::runOnOperation() {
215 if (failed(maybeChipset)) {
217 return signalPassFailure();
223 target.markUnknownOpDynamicallyLegal(
224 [](
Operation *op) ->
bool {
return true; });
227 if (
failed(applyPartialConversion(op,
target, std::move(patterns))))
228 return signalPassFailure();
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, Value val)
static void patchOperandSegmentSizes(ArrayRef< NamedAttribute > attrs, SmallVectorImpl< NamedAttribute > &newAttrs, DataArgAction action)
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset, PatternBenefit benefit=1)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.