21#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
22#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
// Declaration of the AmdgpuEmulateAtomicsPass struct.
// NOTE(review): this listing omits the base-class line between the two
// fragments below as well as the closing brace; only what is shown is described.
29struct AmdgpuEmulateAtomicsPass
31 AmdgpuEmulateAtomicsPass> {
// Pass entry point, declared as an override here; the definition appears
// out of line as AmdgpuEmulateAtomicsPass::runOnOperation().
34 void runOnOperation()
override;
// Conversion pattern parameterized on the atomic op it matches (AtomicOp) and
// on a second op type (ArithOp). The out-of-line matchAndRewrite creates an
// ArithOp from the atomic's data operand and a previously loaded value.
// NOTE(review): the closing brace of this struct is not visible in the listing.
37template <
typename AtomicOp,
typename ArithOp>
38struct RawBufferAtomicByCasPattern :
public OpConversionPattern<AtomicOp> {
// Inherit the constructors of OpConversionPattern<AtomicOp>.
39 using OpConversionPattern<AtomicOp>::OpConversionPattern;
// Shorthand for the adaptor type associated with the matched op.
40 using Adaptor =
typename AtomicOp::Adaptor;
// Rewrite hook, declared const override and defined out of line.
// NOTE(review): the return-type line preceding this declaration is missing
// from the listing; the out-of-line definition returns LogicalResult.
43 matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
44 ConversionPatternRewriter &rewriter)
const override;
// Scoped enumeration with an unsigned char underlying type. The enumerators
// Drop and Duplicate are used in switch cases elsewhere in this file.
// NOTE(review): the enumerator list and closing brace are not visible here.
49enum class DataArgAction :
unsigned char {
// Fragments of a function that copies `attrs` into `newAttrs` while rewriting
// the "operandSegmentSizes" entry according to `action`.
// NOTE(review): presumably patchOperandSegmentSizes; the start of the
// signature, the loop header, the switch header and several statements are
// missing from this listing, so only the visible lines are described.
62 DataArgAction action) {
// Pre-size the output for as many attributes as the input holds.
63 newAttrs.reserve(attrs.size());
// Any attribute not named "operandSegmentSizes" is appended as-is.
65 if (attr.getName().getValue() !=
"operandSegmentSizes") {
66 newAttrs.push_back(attr);
// The segment-size attribute's value is cast to a dense i32 array attribute.
69 auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
// Drop: the old segment sizes minus their first entry are passed on (the
// start of the receiving call is not visible in this listing).
73 case DataArgAction::Drop:
75 context, segmentAttr.asArrayRef().drop_front());
// Duplicate: oldVals[0] is pushed and then all of oldVals is appended, so that
// first value ends up in newVals twice (assuming newVals starts empty; its
// declaration is not visible - TODO confirm).
77 case DataArgAction::Duplicate: {
80 newVals.push_back(oldVals[0]);
81 newVals.append(oldVals.begin(), oldVals.end());
// Fragments of a helper that turns a vector value into a single integer scalar
// holding all of its bits.
// NOTE(review): presumably flattenVecToBits; its signature, any early exits
// and its return statement are missing from this listing.
// dyn_cast leaves vectorType null when `val` is not of vector type.
94 auto vectorType = dyn_cast<VectorType>(val.
getType());
// Element bit width times element count (presumably assigned to `bitwidth`,
// which is used on the next line; the left-hand side is not visible).
99 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
// Integer type of that width, and a one-element vector of that integer type.
100 Type allBitsType = rewriter.getIntegerType(bitwidth);
101 auto allBitsVecType = VectorType::get({1}, allBitsType);
// Bitcast the input to the one-element vector, then extract element 0.
102 Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val);
103 Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0);
// Out-of-line definition of the pattern's rewrite hook. From the visible lines:
// it emits an initial load, builds a loop block that applies ArithOp to the
// data operand and the last loaded value, issues a compare-and-swap, and
// branches back into the loop until the compare-and-swap result equals the
// value that was loaded; the original atomic op is then erased.
// NOTE(review): a number of lines (declarations of operands, loc, dataType,
// loadAttrs, afterAtomic, prevLoad, cmpswapArgs, cmpswapAttrs, the return,
// ...) are missing from this listing.
107template <
typename AtomicOp,
typename ArithOp>
108LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
109 AtomicOp atomicOp, Adaptor adaptor,
110 ConversionPatternRewriter &rewriter)
const {
// The first operand is the data value; every remaining operand is passed
// unchanged to both the load and the compare-and-swap below.
115 Value data = operands.take_front()[0];
116 ValueRange invariantArgs = operands.drop_front();
// Initial load of a dataType value using the invariant operands.
121 Value initialLoad = RawBufferLoadOp::create(rewriter, loc, dataType,
122 invariantArgs, loadAttrs);
// Split the current block at the insertion point (the result is presumably
// stored in afterAtomic; the left-hand side is not visible), then create the
// loop block in front of afterAtomic with one block argument of dataType.
123 Block *currentBlock = rewriter.getInsertionBlock();
125 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
126 Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
// Terminate the original block with a branch into the loop, seeding the loop's
// block argument with the initially loaded value.
128 rewriter.setInsertionPointToEnd(currentBlock);
129 cf::BranchOp::create(rewriter, loc, loopBlock, initialLoad);
// Loop body: combine the data operand with prevLoad (presumably the loop
// block's argument - its definition is not visible).
131 rewriter.setInsertionPointToEnd(loopBlock);
133 Value operated = ArithOp::create(rewriter, loc, data, prevLoad);
// The compare-and-swap takes the invariant operands after whatever was placed
// in cmpswapArgs earlier (not visible in this listing).
139 cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
140 Value atomicRes = RawBufferAtomicCmpswapOp::create(rewriter, loc, dataType,
141 cmpswapArgs, cmpswapAttrs);
// For floating-point data, both values are bitcast to an integer of the same
// width so the equality check below is an integer comparison.
150 if (
auto floatDataTy = dyn_cast<FloatType>(dataType)) {
151 Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
153 arith::BitcastOp::create(rewriter, loc, equivInt, prevLoad);
154 atomicResForCompare =
155 arith::BitcastOp::create(rewriter, loc, equivInt, atomicRes);
// Integer equality between the compare-and-swap result and the previous load
// (presumably assigned to canLeave, which is used by the branch below).
158 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
159 atomicResForCompare, prevLoadForCompare);
// If equal, leave to afterAtomic with no arguments; otherwise re-enter the
// loop block with atomicRes as its new block argument.
160 cf::CondBranchOp::create(rewriter, loc, canLeave, afterAtomic,
ValueRange{},
161 loopBlock, atomicRes);
// The original atomic op is erased once the replacement has been emitted.
162 rewriter.eraseOp(atomicOp);
// Fragments of a function that configures op legality on `target` based on
// `chipset` and registers RawBufferAtomicByCasPattern instantiations.
// NOTE(review): presumably populateAmdgpuEmulateAtomicsPatterns; the signature
// and several of the guarding conditions are missing from this listing, so the
// circumstances under which each statement runs are only partly visible.
// Marks RawBufferAtomicFaddOp illegal (guarding condition not visible).
171 target.addIllegalOp<RawBufferAtomicFaddOp>();
// RawBufferAtomicFaddOp is legal unless elemType is f16 or bf16 (the
// definition of elemType and the guarding condition are not visible).
175 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
176 [](RawBufferAtomicFaddOp op) ->
bool {
178 return !isa<Float16Type, BFloat16Type>(elemType);
// For chipsets at or above Chipset(9, 0, 0xa): RawBufferAtomicFmaxOp is legal
// only when the type of its value operand is f64.
183 if (chipset >=
Chipset(9, 0, 0xa)) {
186 target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
187 [](RawBufferAtomicFmaxOp op) ->
bool {
188 return op.getValue().getType().isF64();
// Marks RawBufferAtomicFmaxOp illegal (presumably the else branch of the check
// above - the intervening line is not visible).
191 target.addIllegalOp<RawBufferAtomicFmaxOp>();
// For chipsets below Chipset(9, 5, 0): RawBufferAtomicFaddOp is legal unless
// elemType is bf16.
195 if (chipset <
Chipset(9, 5, 0)) {
196 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
197 [](RawBufferAtomicFaddOp op) ->
bool {
199 return !isa<BFloat16Type>(elemType);
// Pattern instantiations, each pairing an atomic op with the arith op that
// matchAndRewrite creates from the data operand and the loaded value (the
// start and end of this registration call are not visible).
204 RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
205 RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
206 RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
207 RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
// Out-of-line definition of the pass entry point.
// NOTE(review): most of the body (how maybeChipset and target are created,
// pattern population, the conversion call) is missing from this listing.
211void AmdgpuEmulateAtomicsPass::runOnOperation() {
// If maybeChipset holds a failure, signal pass failure and return.
214 if (failed(maybeChipset)) {
216 return signalPassFailure();
// Ops with no explicit legality setting are treated as legal: the callback
// returns true for every operation.
222 target.markUnknownOpDynamicallyLegal(
223 [](
Operation *op) ->
bool {
return true; });
// Second failure exit (its guarding condition is not visible in this listing).
227 return signalPassFailure();
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, Value val)
static void patchOperandSegmentSizes(ArrayRef< NamedAttribute > attrs, SmallVectorImpl< NamedAttribute > &newAttrs, DataArgAction action)
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset, PatternBenefit benefit=1)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.