EmulateAtomics.cpp
//===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu

using namespace mlir;
using namespace mlir::amdgpu;

namespace {
struct AmdgpuEmulateAtomicsPass
    : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
          AmdgpuEmulateAtomicsPass> {
  using AmdgpuEmulateAtomicsPassBase<
      AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
  void runOnOperation() override;
};

template <typename AtomicOp, typename ArithOp>
struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
  using OpConversionPattern<AtomicOp>::OpConversionPattern;
  using Adaptor = typename AtomicOp::Adaptor;

  LogicalResult
  matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

namespace {
enum class DataArgAction : unsigned char {
  Duplicate,
  Drop,
};
} // namespace

// Fix up the fact that, when we're migrating from a general buffer atomic
// to a load or to a CAS, the number of operands, and thus the number of
// entries needed in operandSegmentSizes, needs to change. We use this method
// because we'd like to preserve unknown attributes on the atomic instead of
// discarding them.
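// For illustration only (hypothetical sizes; the real layout comes from the
// op's ODS definition): if an atomic carried operandSegmentSizes = [1, 1, 1, 1]
// with the data value counted by the first entry, Drop would produce [1, 1, 1]
// for the plain buffer load, while Duplicate would produce [1, 1, 1, 1, 1] so
// that the cmpswap's source and comparison operands are both accounted for.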
static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
                                     SmallVectorImpl<NamedAttribute> &newAttrs,
                                     DataArgAction action) {
  newAttrs.reserve(attrs.size());
  for (NamedAttribute attr : attrs) {
    if (attr.getName().getValue() != "operandSegmentSizes") {
      newAttrs.push_back(attr);
      continue;
    }
    auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
    MLIRContext *context = segmentAttr.getContext();
    DenseI32ArrayAttr newSegments;
    switch (action) {
    case DataArgAction::Drop:
      newSegments = DenseI32ArrayAttr::get(
          context, segmentAttr.asArrayRef().drop_front());
      break;
    case DataArgAction::Duplicate: {
      SmallVector<int32_t> newVals;
      ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
      newVals.push_back(oldVals[0]);
      newVals.append(oldVals.begin(), oldVals.end());
      newSegments = DenseI32ArrayAttr::get(context, newVals);
      break;
    }
    }
    newAttrs.push_back(NamedAttribute(attr.getName(), newSegments));
  }
}

// A helper function to flatten a vector value to a scalar containing its bits,
// returning the value itself if it is not a vector.
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
                              Value val) {
  auto vectorType = dyn_cast<VectorType>(val.getType());
  if (!vectorType)
    return val;

  int64_t bitwidth =
      vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  Type allBitsType = rewriter.getIntegerType(bitwidth);
  auto allBitsVecType = VectorType::get({1}, allBitsType);
  Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
  Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
  return scalar;
}
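
// For example (illustrative): a vector<2xf16> value has 2 * 16 = 32 bits, so
// it is bitcast to vector<1xi32> and the single element is extracted, yielding
// one i32 scalar that carries all of the original bits for the integer
// comparison performed in the CAS loop below.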

template <typename AtomicOp, typename ArithOp>
LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
    AtomicOp atomicOp, Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = atomicOp.getLoc();

  ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
  ValueRange operands = adaptor.getOperands();
  Value data = operands.take_front()[0];
  ValueRange invariantArgs = operands.drop_front();
  Type dataType = data.getType();

  SmallVector<NamedAttribute> loadAttrs;
  patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
  Value initialLoad =
      rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
  Block *currentBlock = rewriter.getInsertionBlock();
  Block *afterAtomic =
      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
  Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});

  rewriter.setInsertionPointToEnd(currentBlock);
  rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);

  rewriter.setInsertionPointToEnd(loopBlock);
  Value prevLoad = loopBlock->getArgument(0);
  Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
  dataType = operated.getType();

  SmallVector<NamedAttribute> cmpswapAttrs;
  patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
  SmallVector<Value> cmpswapArgs = {operated, prevLoad};
  cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
  Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
      loc, dataType, cmpswapArgs, cmpswapAttrs);

  // We care about exact bitwise equality here, so do some bitcasts.
  // These will fold away during lowering to the ROCDL dialect, where
  // an int->float bitcast is introduced to account for the fact that cmpswap
  // only takes integer arguments.

  Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
  Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
  if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
    Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
    prevLoadForCompare =
        rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
    atomicResForCompare =
        rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
  }
  Value canLeave = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
  rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
                                    loopBlock, atomicRes);
  rewriter.eraseOp(atomicOp);
  return success();
}
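
// Schematically, the rewrite above turns an emulated atomic into a load
// followed by a compare-and-swap loop (a sketch; value names and the exact
// printed form are illustrative):
//
//   %init = amdgpu.raw_buffer_load ...
//   cf.br ^loop(%init)
// ^loop(%prev):
//   %new = <arith op> %data, %prev          // e.g. arith.addf for fadd
//   %res = amdgpu.raw_buffer_atomic_cmpswap %new, %prev, ...
//   %eq  = arith.cmpi eq, bits(%res), bits(%prev)
//   cf.cond_br %eq, ^after, ^loop(%res)
// ^after:
//   ...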

void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
    ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset,
    PatternBenefit benefit) {
  // gfx10 has no atomic adds.
  if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
    target.addIllegalOp<RawBufferAtomicFaddOp>();
  }
  // gfx11 has no fp16 atomics.
  if (chipset.majorVersion == 11) {
    target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
        [](RawBufferAtomicFaddOp op) -> bool {
          Type elemType = getElementTypeOrSelf(op.getValue().getType());
          return !isa<Float16Type, BFloat16Type>(elemType);
        });
  }
  // gfx9 has no, or only very limited, support for floating-point min and max.
  if (chipset.majorVersion == 9) {
    if (chipset >= Chipset(9, 0, 0xa)) {
      // gfx90a supports f64 max (and min, but we don't have a min wrapper right
      // now) but all other types need to be emulated.
      target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
          [](RawBufferAtomicFmaxOp op) -> bool {
            return op.getValue().getType().isF64();
          });
    } else {
      target.addIllegalOp<RawBufferAtomicFmaxOp>();
    }
    // TODO(https://github.com/llvm/llvm-project/issues/129206): Refactor
    // this to avoid hardcoding ISA version: gfx950 has bf16 atomics.
    if (chipset < Chipset(9, 5, 0)) {
      target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
          [](RawBufferAtomicFaddOp op) -> bool {
            Type elemType = getElementTypeOrSelf(op.getValue().getType());
            return !isa<BFloat16Type>(elemType);
          });
    }
  }
  patterns.add<
      RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
      RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
      patterns.getContext(), benefit);
}
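
// Putting the legality rules above together: on gfx90a, for example, an f64
// raw_buffer_atomic_fmax is left as-is while an f32 one is rewritten into the
// CAS loop; on gfx11, f16 and bf16 raw_buffer_atomic_fadd are rewritten while
// f32 fadd stays legal.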

void AmdgpuEmulateAtomicsPass::runOnOperation() {
  Operation *op = getOperation();
  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(op->getLoc(), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }

  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  RewritePatternSet patterns(&ctx);
  target.markUnknownOpDynamicallyLegal(
      [](Operation *op) -> bool { return true; });

  populateAmdgpuEmulateAtomicsPatterns(target, patterns, *maybeChipset);
  if (failed(applyPartialConversion(op, target, std::move(patterns))))
    return signalPassFailure();
}
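
// Usage sketch (assuming the pass is registered in Passes.td under the
// command-line name "amdgpu-emulate-atomics" with a "chipset" option):
//
//   mlir-opt --amdgpu-emulate-atomics=chipset=gfx1030 input.mlir
//
// On a gfx10 target this marks raw_buffer_atomic_fadd illegal, so each such op
// is rewritten by RawBufferAtomicByCasPattern into the load + cmpswap loop.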