MLIR  19.0.0git
EmulateAtomics.cpp
Go to the documentation of this file.
1 //===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 
19 namespace mlir::amdgpu {
20 #define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
21 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
22 } // namespace mlir::amdgpu
23 
24 using namespace mlir;
25 using namespace mlir::amdgpu;
26 
27 namespace {
28 struct AmdgpuEmulateAtomicsPass
29  : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
30  AmdgpuEmulateAtomicsPass> {
31  using AmdgpuEmulateAtomicsPassBase<
32  AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
33  void runOnOperation() override;
34 };
35 
36 template <typename AtomicOp, typename ArithOp>
37 struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
39  using Adaptor = typename AtomicOp::Adaptor;
40 
42  matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
43  ConversionPatternRewriter &rewriter) const override;
44 };
45 } // namespace
46 
namespace {
/// Describes how the data operand's entry in `operandSegmentSizes` must change
/// when a buffer atomic's operands are reused to build a different op.
enum class DataArgAction : unsigned char {
  /// Repeat the data segment (the cmpswap takes both a source and a compare
  /// value in front of the shared buffer operands).
  Duplicate,
  /// Remove the data segment (the load has no data operand).
  Drop
};
} // namespace
53 
54 // Fix up the fact that, when we're migrating from a general bugffer atomic
55 // to a load or to a CAS, the number of openrands, and thus the number of
56 // entries needed in operandSegmentSizes, needs to change. We use this method
57 // because we'd like to preserve unknown attributes on the atomic instead of
58 // discarding them.
61  DataArgAction action) {
62  newAttrs.reserve(attrs.size());
63  for (NamedAttribute attr : attrs) {
64  if (attr.getName().getValue() != "operandSegmentSizes") {
65  newAttrs.push_back(attr);
66  continue;
67  }
68  auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
69  MLIRContext *context = segmentAttr.getContext();
70  DenseI32ArrayAttr newSegments;
71  switch (action) {
72  case DataArgAction::Drop:
73  newSegments = DenseI32ArrayAttr::get(
74  context, segmentAttr.asArrayRef().drop_front());
75  break;
76  case DataArgAction::Duplicate: {
77  SmallVector<int32_t> newVals;
78  ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
79  newVals.push_back(oldVals[0]);
80  newVals.append(oldVals.begin(), oldVals.end());
81  newSegments = DenseI32ArrayAttr::get(context, newVals);
82  break;
83  }
84  }
85  newAttrs.push_back(NamedAttribute(attr.getName(), newSegments));
86  }
87 }
88 
89 template <typename AtomicOp, typename ArithOp>
90 LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
91  AtomicOp atomicOp, Adaptor adaptor,
92  ConversionPatternRewriter &rewriter) const {
93  Location loc = atomicOp.getLoc();
94 
95  ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
96  ValueRange operands = adaptor.getOperands();
97  Value data = operands.take_front()[0];
98  ValueRange invariantArgs = operands.drop_front();
99  Type dataType = data.getType();
100 
101  SmallVector<NamedAttribute> loadAttrs;
102  patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
103  Value initialLoad =
104  rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
105  Block *currentBlock = rewriter.getInsertionBlock();
106  Block *afterAtomic =
107  rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
108  Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
109 
110  rewriter.setInsertionPointToEnd(currentBlock);
111  rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);
112 
113  rewriter.setInsertionPointToEnd(loopBlock);
114  Value prevLoad = loopBlock->getArgument(0);
115  Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
116 
117  SmallVector<NamedAttribute> cmpswapAttrs;
118  patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
119  SmallVector<Value> cmpswapArgs = {operated, prevLoad};
120  cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
121  Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
122  loc, dataType, cmpswapArgs, cmpswapAttrs);
123 
124  // We care about exact bitwise equality here, so do some bitcasts.
125  // These will fold away during lowering to the ROCDL dialect, where
126  // an int->float bitcast is introduced to account for the fact that cmpswap
127  // only takes integer arguments.
128 
129  Value prevLoadForCompare = prevLoad;
130  Value atomicResForCompare = atomicRes;
131  if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
132  Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
133  prevLoadForCompare =
134  rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
135  atomicResForCompare =
136  rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
137  }
138  Value canLeave = rewriter.create<arith::CmpIOp>(
139  loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
140  rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
141  loopBlock, atomicRes);
142  rewriter.eraseOp(atomicOp);
143  return success();
144 }
145 
147  ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
148  // gfx10 has no atomic adds.
149  if (chipset.majorVersion == 10 || chipset.majorVersion < 9 ||
150  (chipset.majorVersion == 9 && chipset.minorVersion < 0x08)) {
151  target.addIllegalOp<RawBufferAtomicFaddOp>();
152  }
153  // gfx9 has no to a very limited support for floating-point min and max.
154  if (chipset.majorVersion == 9) {
155  if (chipset.minorVersion >= 0x0a && chipset.minorVersion != 0x41) {
156  // gfx90a supports f64 max (and min, but we don't have a min wrapper right
157  // now) but all other types need to be emulated.
158  target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
159  [](RawBufferAtomicFmaxOp op) -> bool {
160  return op.getValue().getType().isF64();
161  });
162  } else {
163  target.addIllegalOp<RawBufferAtomicFmaxOp>();
164  }
165  if (chipset.minorVersion == 0x41) {
166  // gfx941 requires non-CAS atomics to be implemented with CAS loops.
167  // The workaround here mirrors HIP and OpenMP.
168  target.addIllegalOp<RawBufferAtomicFaddOp, RawBufferAtomicFmaxOp,
169  RawBufferAtomicSmaxOp, RawBufferAtomicUminOp>();
170  }
171  }
172  patterns.add<
173  RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
174  RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
175  RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
176  RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
177  patterns.getContext());
178 }
179 
180 void AmdgpuEmulateAtomicsPass::runOnOperation() {
181  Operation *op = getOperation();
182  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
183  if (failed(maybeChipset)) {
184  emitError(op->getLoc(), "Invalid chipset name: " + chipset);
185  return signalPassFailure();
186  }
187 
188  MLIRContext &ctx = getContext();
189  ConversionTarget target(ctx);
190  RewritePatternSet patterns(&ctx);
191  target.markUnknownOpDynamicallyLegal(
192  [](Operation *op) -> bool { return true; });
193 
194  populateAmdgpuEmulateAtomicsPatterns(target, patterns, *maybeChipset);
195  if (failed(applyPartialConversion(op, target, std::move(patterns))))
196  return signalPassFailure();
197 }
static void patchOperandSegmentSizes(ArrayRef< NamedAttribute > attrs, SmallVectorImpl< NamedAttribute > &newAttrs, DataArgAction action)
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
Definition: Block.h:30
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
This class implements a pattern rewriter for use with ConversionPatterns.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:447
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:438
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static FailureOr< Chipset > parse(StringRef name)
Definition: Chipset.cpp:16
unsigned majorVersion
Definition: Chipset.h:21
unsigned minorVersion
Definition: Chipset.h:22