// Conversion pattern parameterized on two op types: AtomicOp (the op being
// matched) and ArithOp (the arithmetic op created inside matchAndRewrite).
// NOTE(review): the leading numbers on these lines look like line numbers
// fused in by a listing extraction, and some lines (e.g. the return type of
// matchAndRewrite and the closing brace) are not visible here — confirm
// against the real file.
37template <
typename AtomicOp,
typename ArithOp>
38struct RawBufferAtomicByCasPattern :
public OpConversionPattern<AtomicOp> {
// Re-use the constructors of OpConversionPattern<AtomicOp>.
39 using OpConversionPattern<AtomicOp>::OpConversionPattern;
// Shorthand for the Adaptor type nested in AtomicOp.
40 using Adaptor =
typename AtomicOp::Adaptor;
// Overrides the base-class rewrite hook; the body is defined out of line
// further down in this file.
43 matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
44 ConversionPatternRewriter &rewriter)
const override;
// Scoped enum stored as an unsigned char. The enumerator list is not visible
// here; code later in this file switches on DataArgAction::Drop and
// DataArgAction::Duplicate.
49enum class DataArgAction :
unsigned char {
// Tail of a helper's parameter list; the function name and its earlier
// parameters are not visible here. `action` selects which case below runs
// for the "operandSegmentSizes" attribute.
62 DataArgAction action) {
// Pre-size the output list to the number of incoming attributes.
63 newAttrs.reserve(attrs.size());
// Any attribute whose name is not "operandSegmentSizes" is appended to
// newAttrs as-is.
65 if (attr.getName().getValue() !=
"operandSegmentSizes") {
66 newAttrs.push_back(attr);
// cast<> asserts rather than returning null, so the attribute value is
// required to be a DenseI32ArrayAttr at this point.
69 auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
// Drop: the old segment sizes minus their first entry are passed, together
// with `context`, to a call whose callee is not visible here.
73 case DataArgAction::Drop:
75 context, segmentAttr.asArrayRef().drop_front());
// Duplicate: append oldVals[0] and then every element of oldVals, so the
// first entry ends up listed twice (assuming newVals starts empty and
// oldVals holds the old segment sizes — TODO confirm, neither
// initialization is visible here).
77 case DataArgAction::Duplicate: {
80 newVals.push_back(oldVals[0]);
81 newVals.append(oldVals.begin(), oldVals.end());
// Fragment of a helper (header not visible) that reinterprets a vector value
// as one integer scalar carrying all of the vector's bits.
// dyn_cast yields a null VectorType when `val` is not a vector; how that case
// is handled is not visible here.
94 auto vectorType = dyn_cast<VectorType>(val.
getType());
// Total bit count: per-element bit width times the number of elements
// (presumably assigned to `bitwidth`, which is used on the next line —
// the left-hand side is not visible).
99 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
// Integer type wide enough to hold every bit of the vector.
100 Type allBitsType = rewriter.getIntegerType(bitwidth);
// Single-element vector of that integer type, used as the bitcast target.
101 auto allBitsVecType = VectorType::get({1}, allBitsType);
102 Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val);
// Pull element 0 out of the one-element vector to obtain the scalar.
103 Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0);
// Out-of-line definition of RawBufferAtomicByCasPattern::matchAndRewrite.
// The visible statements build a loop: load the current value, apply ArithOp
// to it, try to publish the result with RawBufferAtomicCmpswapOp, and leave
// the loop once the value returned by the compare-and-swap equals the value
// the iteration started from. Several statements (including the definitions
// of `operands`, `loc`, `dataType`, `loadAttrs`, `cmpswapAttrs`, `prevLoad`
// and the final return) are not visible in this fragment.
107template <
typename AtomicOp,
typename ArithOp>
108LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
109 AtomicOp atomicOp, Adaptor adaptor,
110 ConversionPatternRewriter &rewriter)
const {
// The first operand is the data value; everything after it is reused
// unchanged for both the load and the compare-and-swap.
115 Value data = operands.take_front()[0];
116 ValueRange invariantArgs = operands.drop_front();
// Read the starting value with the same non-data operands.
121 Value initialLoad = RawBufferLoadOp::create(rewriter, loc, dataType,
122 invariantArgs, loadAttrs);
123 Block *currentBlock = rewriter.getInsertionBlock();
// Split the current block at the insertion point; the result is presumably
// stored in `afterAtomic`, which is used below (the assignment is not
// visible here).
125 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
// New block placed before afterAtomic, with one block argument of dataType.
127 Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
// Enter the loop from the original block, passing the initial load as the
// loop block's argument.
129 rewriter.setInsertionPointToEnd(currentBlock);
130 cf::BranchOp::create(rewriter, loc, loopBlock, initialLoad);
132 rewriter.setInsertionPointToEnd(loopBlock);
// Combine the data operand with `prevLoad` (presumably the loop block's
// argument — its definition is not visible here).
134 Value operated = ArithOp::create(rewriter, loc, data, prevLoad);
// The compare-and-swap reuses the invariant operands; the entries placed in
// cmpswapArgs before them are not visible here.
140 cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
141 Value atomicRes = RawBufferAtomicCmpswapOp::create(rewriter, loc, dataType,
142 cmpswapArgs, cmpswapAttrs);
// For floating-point data, both values are bitcast to an integer of the same
// width so the exit test below can be an integer equality comparison.
151 if (
auto floatDataTy = dyn_cast<FloatType>(dataType)) {
152 Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
// Presumably assigned to prevLoadForCompare (left-hand side not visible).
154 arith::BitcastOp::create(rewriter, loc, equivInt, prevLoad);
155 atomicResForCompare =
156 arith::BitcastOp::create(rewriter, loc, equivInt, atomicRes);
// Integer-equality test between the compare-and-swap result and the
// previously loaded value; presumably assigned to `canLeave`, which is used
// by the branch below (left-hand side not visible).
159 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
160 atomicResForCompare, prevLoadForCompare);
// When canLeave is true, control goes to afterAtomic; the remaining
// arguments of this branch are cut off in this fragment.
161 cf::CondBranchOp::create(rewriter, loc, canLeave, afterAtomic,
// Fragment of the legality / pattern registration code; the enclosing
// function header and the conditions guarding the first two statements are
// not visible here.
// Marks RawBufferAtomicFaddOp illegal outright.
172 target.addIllegalOp<RawBufferAtomicFaddOp>();
// RawBufferAtomicFaddOp is legal unless `elemType` is f16 or bf16 (how
// elemType is derived from the op is not visible here).
176 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
177 [](RawBufferAtomicFaddOp op) ->
bool {
179 return !isa<Float16Type, BFloat16Type>(elemType);
// For chipsets at or above Chipset(9, 0, 0xa): RawBufferAtomicFmaxOp is legal
// only when its value operand has type f64.
184 if (chipset >=
Chipset(9, 0, 0xa)) {
187 target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
188 [](RawBufferAtomicFmaxOp op) ->
bool {
189 return op.getValue().getType().isF64();
// Presumably the else-branch of the check above (the `else` is not visible):
// RawBufferAtomicFmaxOp is marked illegal.
192 target.addIllegalOp<RawBufferAtomicFmaxOp>();
// For chipsets below Chipset(9, 5, 0): RawBufferAtomicFaddOp is legal unless
// `elemType` is bf16.
196 if (chipset <
Chipset(9, 5, 0)) {
197 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
198 [](RawBufferAtomicFaddOp op) ->
bool {
200 return !isa<BFloat16Type>(elemType);
// Instantiations of the pattern, each pairing an atomic op with the
// arithmetic op applied inside the rewrite: fadd/AddFOp, fmax/MaximumFOp,
// smax/MaxSIOp, umin/MinUIOp. The call these template arguments belong to is
// not visible here.
205 RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
206 RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
207 RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
208 RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
// Pass entry point. The visible statements validate the chipset, configure
// the conversion target, and run a partial conversion; the setup of
// `maybeChipset`, `target`, `patterns` and `op` is not visible here.
212void AmdgpuEmulateAtomicsPass::runOnOperation() {
// Abort the pass when `maybeChipset` holds a failure.
215 if (failed(maybeChipset)) {
217 return signalPassFailure();
// Ops with no explicit legality setting are treated as legal: the callback
// returns true unconditionally.
223 target.markUnknownOpDynamicallyLegal(
224 [](
Operation *op) ->
bool {
return true; });
// Run the partial conversion and report pass failure if it fails.
227 if (
failed(applyPartialConversion(op,
target, std::move(patterns))))
228 return signalPassFailure();
MLIRContext is the top-level object for a collection of MLIR operations.