//===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

namespace {
struct AmdgpuEmulateAtomicsPass
    : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
          AmdgpuEmulateAtomicsPass> {
  using AmdgpuEmulateAtomicsPassBase<
      AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
  void runOnOperation() override;
};

template <typename AtomicOp, typename ArithOp>
struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
  using OpConversionPattern<AtomicOp>::OpConversionPattern;
  using Adaptor = typename AtomicOp::Adaptor;

  LogicalResult
  matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

namespace {
enum class DataArgAction : unsigned char {
  Duplicate,
  Drop,
};
} // namespace

// Fix up the fact that, when we're migrating from a general buffer atomic
// to a load or to a CAS, the number of operands, and thus the number of
// entries needed in operandSegmentSizes, needs to change. We use this method
// because we'd like to preserve unknown attributes on the atomic instead of
// discarding them.
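//
// As an illustration (hypothetical segment sizes, not taken from the actual
// op definitions), if the data operand is the first segment of
// operandSegmentSizes = [1, 1, 1, 1, 0]:
//   Drop      -> [1, 1, 1, 0]        (data operand removed for the plain load)
//   Duplicate -> [1, 1, 1, 1, 1, 0]  (data operand counted twice, since the
//                                     cmpswap needs both a source and a
//                                     compare value)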
static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
                                     SmallVectorImpl<NamedAttribute> &newAttrs,
                                     DataArgAction action) {
  newAttrs.reserve(attrs.size());
  for (NamedAttribute attr : attrs) {
    if (attr.getName().getValue() != "operandSegmentSizes") {
      newAttrs.push_back(attr);
      continue;
    }
    auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
    MLIRContext *context = segmentAttr.getContext();
    DenseI32ArrayAttr newSegments;
    switch (action) {
    case DataArgAction::Drop:
      newSegments = DenseI32ArrayAttr::get(
          context, segmentAttr.asArrayRef().drop_front());
      break;
    case DataArgAction::Duplicate: {
      SmallVector<int32_t> newVals;
      ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
      newVals.push_back(oldVals[0]);
      newVals.append(oldVals.begin(), oldVals.end());
      newSegments = DenseI32ArrayAttr::get(context, newVals);
      break;
    }
    }
    newAttrs.push_back(NamedAttribute(attr.getName(), newSegments));
  }
}

// A helper function to flatten a vector value to a scalar containing its bits,
// returning the value itself otherwise.
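// For example, a vector<2xf16> value is bitcast to a vector<1xi32> and the
// single i32 element is extracted; a scalar input is returned as-is.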
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
                              Value val) {
  auto vectorType = dyn_cast<VectorType>(val.getType());
  if (!vectorType)
    return val;

  int64_t bitwidth =
      vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  Type allBitsType = rewriter.getIntegerType(bitwidth);
  auto allBitsVecType = VectorType::get({1}, allBitsType);
  Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val);
  Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0);
  return scalar;
}

template <typename AtomicOp, typename ArithOp>
LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
    AtomicOp atomicOp, Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = atomicOp.getLoc();

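  // Rough sketch of the control flow this pattern emits (pseudo-IR for
  // orientation only; the real ops are built below):
  //   %init = amdgpu.raw_buffer_load ...
  //   cf.br ^loop(%init)
  // ^loop(%prev):
  //   %new = <ArithOp> %data, %prev
  //   %old = amdgpu.raw_buffer_atomic_cmpswap %new, %prev, ...
  //   cf.cond_br (bits(%old) == bits(%prev)), ^after, ^loop(%old)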
  ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
  ValueRange operands = adaptor.getOperands();
  Value data = operands.take_front()[0];
  ValueRange invariantArgs = operands.drop_front();
  Type dataType = data.getType();

  SmallVector<NamedAttribute> loadAttrs;
  patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
  Value initialLoad = RawBufferLoadOp::create(rewriter, loc, dataType,
                                              invariantArgs, loadAttrs);
  Block *currentBlock = rewriter.getInsertionBlock();
  Block *afterAtomic =
      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
  Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});

  rewriter.setInsertionPointToEnd(currentBlock);
  cf::BranchOp::create(rewriter, loc, loopBlock, initialLoad);

  rewriter.setInsertionPointToEnd(loopBlock);
  Value prevLoad = loopBlock->getArgument(0);
  Value operated = ArithOp::create(rewriter, loc, data, prevLoad);
  dataType = operated.getType();

  SmallVector<NamedAttribute> cmpswapAttrs;
  patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
  SmallVector<Value> cmpswapArgs = {operated, prevLoad};
  cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
  Value atomicRes = RawBufferAtomicCmpswapOp::create(rewriter, loc, dataType,
                                                     cmpswapArgs, cmpswapAttrs);

  // We care about exact bitwise equality here, so do some bitcasts.
  // These will fold away during lowering to the ROCDL dialect, where
  // an int->float bitcast is introduced to account for the fact that cmpswap
  // only takes integer arguments.
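  // A floating-point compare would be incorrect here: NaN never compares
  // equal to itself (the loop would spin forever) and +0.0 compares equal to
  // -0.0 (the loop could exit even though the cmpswap did not store).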

  Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
  Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
  if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
    Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
    prevLoadForCompare =
        arith::BitcastOp::create(rewriter, loc, equivInt, prevLoad);
    atomicResForCompare =
        arith::BitcastOp::create(rewriter, loc, equivInt, atomicRes);
  }
  Value canLeave =
      arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                            atomicResForCompare, prevLoadForCompare);
  cf::CondBranchOp::create(rewriter, loc, canLeave, afterAtomic, ValueRange{},
                           loopBlock, atomicRes);
  rewriter.eraseOp(atomicOp);
  return success();
}

void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
    ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset,
    PatternBenefit benefit) {
  // gfx10, and gfx9 chips before gfx908, have no buffer atomic fadd.
  if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
    target.addIllegalOp<RawBufferAtomicFaddOp>();
  }
  // gfx11 has no f16 or bf16 atomics.
  if (chipset.majorVersion == 11) {
    target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
        [](RawBufferAtomicFaddOp op) -> bool {
          Type elemType = getElementTypeOrSelf(op.getValue().getType());
          return !isa<Float16Type, BFloat16Type>(elemType);
        });
  }
  // gfx9 has no, or only very limited, support for floating-point min and max.
  if (chipset.majorVersion == 9) {
    if (chipset >= Chipset(9, 0, 0xa)) {
      // gfx90a supports f64 max (and min, but we don't have a min wrapper right
      // now) but all other types need to be emulated.
      target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
          [](RawBufferAtomicFmaxOp op) -> bool {
            return op.getValue().getType().isF64();
          });
    } else {
      target.addIllegalOp<RawBufferAtomicFmaxOp>();
    }
    // TODO(https://github.com/llvm/llvm-project/issues/129206): Refactor
    // this to avoid hardcoding ISA version: gfx950 has bf16 atomics.
    if (chipset < Chipset(9, 5, 0)) {
      target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
          [](RawBufferAtomicFaddOp op) -> bool {
            Type elemType = getElementTypeOrSelf(op.getValue().getType());
            return !isa<BFloat16Type>(elemType);
          });
    }
  }
  patterns.add<
      RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
      patterns.getContext(), benefit);
}

void AmdgpuEmulateAtomicsPass::runOnOperation() {
  Operation *op = getOperation();
  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(op->getLoc(), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }

  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  RewritePatternSet patterns(&ctx);
  target.markUnknownOpDynamicallyLegal(
      [](Operation *op) -> bool { return true; });
  populateAmdgpuEmulateAtomicsPatterns(target, patterns, *maybeChipset);
  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    return signalPassFailure();
}
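
// Example invocation (assuming this pass is registered under the name
// "amdgpu-emulate-atomics" with a "chipset" option, as suggested by the
// GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS include above):
//   mlir-opt --amdgpu-emulate-atomics=chipset=gfx1030 input.mlir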