18 #include "mlir/IR/TypeUtilities.h"
21 namespace mlir::amdgpu {
23 #include "mlir/Dialect/AMDGPU/Transforms/"
24 } // namespace mlir::amdgpu
26 using namespace mlir;
27 using namespace mlir::amdgpu;
29 namespace {
// Pass object for atomic emulation. It derives from the
// AmdgpuEmulateAtomicsPassBase class template in amdgpu::impl (instantiated
// on this type), re-exports that base's constructors, and overrides
// runOnOperation(), whose definition appears later in this file.
30 struct AmdgpuEmulateAtomicsPass
31  : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
32  AmdgpuEmulateAtomicsPass> {
// Inherit every constructor of the base class.
33  using AmdgpuEmulateAtomicsPassBase<
34  AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
// Entry point of the pass.
35  void runOnOperation() override;
36 };
// Conversion pattern matching ops of type `AtomicOp`.
//   AtomicOp - the atomic op class this pattern is rooted on.
//   ArithOp  - the op class that matchAndRewrite() instantiates to combine
//              the atomic's data operand with the previously loaded value.
38 template <typename AtomicOp, typename ArithOp>
39 struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
// Shorthand for the operand adaptor type of the matched op.
41  using Adaptor = typename AtomicOp::Adaptor;
// Declared here; the out-of-line definition appears later in this file.
43  LogicalResult
44  matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
45  ConversionPatternRewriter &rewriter) const override;
46 };
47 } // namespace
49 namespace {
// Selects how the leading entry of an operandSegmentSizes array is adjusted
// when the attributes of an atomic op are reused for a replacement op.
50 enum class DataArgAction : unsigned char {
// Repeat the first segment size, so the array gains one leading entry
// (requested in this file when building the cmpswap op).
51  Duplicate,
// Remove the first segment size, so the array loses its leading entry
// (requested in this file when building the load op).
52  Drop,
53 };
54 } // namespace
56 // Fix up the fact that, when we're migrating from a general buffer atomic
57 // to a load or to a CAS, the number of operands, and thus the number of
58 // entries needed in operandSegmentSizes, needs to change. We use this method
59 // because we'd like to preserve unknown attributes on the atomic instead of
60 // discarding them.
63  DataArgAction action) {
// NOTE(review): the opening lines of this signature are not visible in this
// view. Judging by the body, `attrs` is the attribute list being read and
// `newAttrs` is the output list being filled in.
// The output receives exactly one entry per input attribute, so size it once.
64  newAttrs.reserve(attrs.size());
65  for (NamedAttribute attr : attrs) {
// Every attribute other than operandSegmentSizes is copied through untouched.
66  if (attr.getName().getValue() != "operandSegmentSizes") {
67  newAttrs.push_back(attr);
68  continue;
69  }
// From here on `attr` is operandSegmentSizes; cast<> asserts that its value
// is a dense i32 array.
70  auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
71  MLIRContext *context = segmentAttr.getContext();
72  DenseI32ArrayAttr newSegments;
73  switch (action) {
// Drop: rebuild the array without its first entry.
74  case DataArgAction::Drop:
75  newSegments = DenseI32ArrayAttr::get(
76  context, segmentAttr.asArrayRef().drop_front());
77  break;
// Duplicate: rebuild the array with its first entry repeated at the front,
// i.e. [a, b, ...] becomes [a, a, b, ...].
78  case DataArgAction::Duplicate: {
79  SmallVector<int32_t> newVals;
80  ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
81  newVals.push_back(oldVals[0]);
82  newVals.append(oldVals.begin(), oldVals.end());
83  newSegments = DenseI32ArrayAttr::get(context, newVals);
84  break;
85  }
86  }
// Re-emit the attribute under its original name with the adjusted sizes.
87  newAttrs.push_back(NamedAttribute(attr.getName(), newSegments));
88  }
89 }
91 // A helper function to flatten a vector value to a scalar containing its bits,
92 // returning the value unchanged if it is not a vector.
94  Value val) {
// NOTE(review): the opening line of this signature is not visible in this
// view; the body additionally uses `rewriter` and `loc`.
// Anything that is not a vector is handed back unchanged and no ops are
// created.
95  auto vectorType = dyn_cast<VectorType>(val.getType());
96  if (!vectorType)
97  return val;
// Total number of bits in the vector: element bit width times element count.
99  int64_t bitwidth =
100  vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
// Reinterpret the whole vector as a one-element vector of a single integer
// of that width, then pull out element 0 to obtain a scalar integer.
101  Type allBitsType = rewriter.getIntegerType(bitwidth);
102  auto allBitsVecType = VectorType::get({1}, allBitsType);
103  Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
104  Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
105  return scalar;
106 }
108 template <typename AtomicOp, typename ArithOp>
109 LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
110  AtomicOp atomicOp, Adaptor adaptor,
111  ConversionPatternRewriter &rewriter) const {
112  Location loc = atomicOp.getLoc();
114  ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
115  ValueRange operands = adaptor.getOperands();
116  Value data = operands.take_front()[0];
117  ValueRange invariantArgs = operands.drop_front();
118  Type dataType = data.getType();
120  SmallVector<NamedAttribute> loadAttrs;
121  patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
122  Value initialLoad =
123  rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
124  Block *currentBlock = rewriter.getInsertionBlock();
125  Block *afterAtomic =
126  rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
127  Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
129  rewriter.setInsertionPointToEnd(currentBlock);
130  rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);
132  rewriter.setInsertionPointToEnd(loopBlock);
133  Value prevLoad = loopBlock->getArgument(0);
134  Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
135  dataType = operated.getType();
137  SmallVector<NamedAttribute> cmpswapAttrs;
138  patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
139  SmallVector<Value> cmpswapArgs = {operated, prevLoad};
140  cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
141  Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
142  loc, dataType, cmpswapArgs, cmpswapAttrs);
144  // We care about exact bitwise equality here, so do some bitcasts.
145  // These will fold away during lowering to the ROCDL dialect, where
146  // an int->float bitcast is introduced to account for the fact that cmpswap
147  // only takes integer arguments.
149  Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
150  Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
151  if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
152  Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
153  prevLoadForCompare =
154  rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
155  atomicResForCompare =
156  rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
157  }
158  Value canLeave = rewriter.create<arith::CmpIOp>(
159  loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
160  rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
161  loopBlock, atomicRes);
162  rewriter.eraseOp(atomicOp);
163  return success();
164 }
167  ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
// NOTE(review): the opening line of this signature is not visible in this
// view. The body records, per chipset, which raw-buffer atomic ops are
// illegal on `target`, then registers the CAS-loop patterns in `patterns`.
168  // gfx10 has no atomic adds. Chipsets ordered before Chipset(9, 0, 8) are
// treated the same way: Fadd is marked illegal for both.
169  if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
170  target.addIllegalOp<RawBufferAtomicFaddOp>();
171  }
172  // gfx11 has no fp16 atomics: Fadd is legal there unless the element type
// of its value is f16 or bf16.
173  if (chipset.majorVersion == 11) {
174  target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
175  [](RawBufferAtomicFaddOp op) -> bool {
176  Type elemType = getElementTypeOrSelf(op.getValue().getType());
177  return !isa<Float16Type, BFloat16Type>(elemType);
178  });
179  }
180  // gfx9 has anywhere from no to very limited support for floating-point
// min and max.
181  if (chipset.majorVersion == 9) {
182  if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) {
183  // gfx90a supports f64 max (and min, but we don't have a min wrapper right
184  // now) but all other types need to be emulated.
185  target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
186  [](RawBufferAtomicFmaxOp op) -> bool {
187  return op.getValue().getType().isF64();
188  });
189  } else {
// Remaining gfx9 chipsets (including Chipset(9, 4, 1)): Fmax is always
// illegal.
190  target.addIllegalOp<RawBufferAtomicFmaxOp>();
191  }
192  if (chipset == Chipset(9, 4, 1)) {
193  // gfx941 requires non-CAS atomics to be implemented with CAS loops.
194  // The workaround here mirrors HIP and OpenMP.
195  target.addIllegalOp<RawBufferAtomicFaddOp, RawBufferAtomicFmaxOp,
196  RawBufferAtomicSmaxOp, RawBufferAtomicUminOp>();
197  }
198  }
// All four CAS-loop patterns are registered unconditionally, whatever the
// chipset. Each pairs an atomic op with the arith op used inside its loop.
199  patterns.add<
200  RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
201  RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
202  RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
203  RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
204  patterns.getContext());
205 }
207 void AmdgpuEmulateAtomicsPass::runOnOperation() {
208  Operation *op = getOperation();
209  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
210  if (failed(maybeChipset)) {
211  emitError(op->getLoc(), "Invalid chipset name: " + chipset);
212  return signalPassFailure();
213  }
215  MLIRContext &ctx = getContext();
216  ConversionTarget target(ctx);
217  RewritePatternSet patterns(&ctx);
218  target.markUnknownOpDynamicallyLegal(
219  [](Operation *op) -> bool { return true; });
221  populateAmdgpuEmulateAtomicsPatterns(target, patterns, *maybeChipset);
222  if (failed(applyPartialConversion(op, target, std::move(patterns))))
223  return signalPassFailure();
224 }
