MLIR  20.0.0git
EmulateAtomics.cpp
Go to the documentation of this file.
1 //===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/TypeUtilities.h"
20 
21 namespace mlir::amdgpu {
22 #define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
23 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
24 } // namespace mlir::amdgpu
25 
26 using namespace mlir;
27 using namespace mlir::amdgpu;
28 
29 namespace {
30 struct AmdgpuEmulateAtomicsPass
31  : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
32  AmdgpuEmulateAtomicsPass> {
33  using AmdgpuEmulateAtomicsPassBase<
34  AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
35  void runOnOperation() override;
36 };
37 
38 template <typename AtomicOp, typename ArithOp>
39 struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
41  using Adaptor = typename AtomicOp::Adaptor;
42 
43  LogicalResult
44  matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
45  ConversionPatternRewriter &rewriter) const override;
46 };
47 } // namespace
48 
namespace {
/// Selects how patchOperandSegmentSizes() rewrites the leading (data operand)
/// entry of an atomic op's operandSegmentSizes attribute.
enum class DataArgAction : unsigned char {
  /// Repeat the data segment: the cmpswap takes two data-shaped operands
  /// (the new value and the expected value) ahead of the shared operands.
  Duplicate = 0,
  /// Remove the data segment: the plain load has no data operand.
  Drop = 1,
};
} // namespace
55 
56 // Fix up the fact that, when we're migrating from a general bugffer atomic
57 // to a load or to a CAS, the number of openrands, and thus the number of
58 // entries needed in operandSegmentSizes, needs to change. We use this method
59 // because we'd like to preserve unknown attributes on the atomic instead of
60 // discarding them.
63  DataArgAction action) {
64  newAttrs.reserve(attrs.size());
65  for (NamedAttribute attr : attrs) {
66  if (attr.getName().getValue() != "operandSegmentSizes") {
67  newAttrs.push_back(attr);
68  continue;
69  }
70  auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
71  MLIRContext *context = segmentAttr.getContext();
72  DenseI32ArrayAttr newSegments;
73  switch (action) {
74  case DataArgAction::Drop:
75  newSegments = DenseI32ArrayAttr::get(
76  context, segmentAttr.asArrayRef().drop_front());
77  break;
78  case DataArgAction::Duplicate: {
79  SmallVector<int32_t> newVals;
80  ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
81  newVals.push_back(oldVals[0]);
82  newVals.append(oldVals.begin(), oldVals.end());
83  newSegments = DenseI32ArrayAttr::get(context, newVals);
84  break;
85  }
86  }
87  newAttrs.push_back(NamedAttribute(attr.getName(), newSegments));
88  }
89 }
90 
91 // A helper function to flatten a vector value to a scalar containing its bits,
92 // returning the value itself if othetwise.
94  Value val) {
95  auto vectorType = dyn_cast<VectorType>(val.getType());
96  if (!vectorType)
97  return val;
98 
99  int64_t bitwidth =
100  vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
101  Type allBitsType = rewriter.getIntegerType(bitwidth);
102  auto allBitsVecType = VectorType::get({1}, allBitsType);
103  Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
104  Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
105  return scalar;
106 }
107 
108 template <typename AtomicOp, typename ArithOp>
109 LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
110  AtomicOp atomicOp, Adaptor adaptor,
111  ConversionPatternRewriter &rewriter) const {
112  Location loc = atomicOp.getLoc();
113 
114  ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
115  ValueRange operands = adaptor.getOperands();
116  Value data = operands.take_front()[0];
117  ValueRange invariantArgs = operands.drop_front();
118  Type dataType = data.getType();
119 
120  SmallVector<NamedAttribute> loadAttrs;
121  patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
122  Value initialLoad =
123  rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
124  Block *currentBlock = rewriter.getInsertionBlock();
125  Block *afterAtomic =
126  rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
127  Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
128 
129  rewriter.setInsertionPointToEnd(currentBlock);
130  rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);
131 
132  rewriter.setInsertionPointToEnd(loopBlock);
133  Value prevLoad = loopBlock->getArgument(0);
134  Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
135  dataType = operated.getType();
136 
137  SmallVector<NamedAttribute> cmpswapAttrs;
138  patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
139  SmallVector<Value> cmpswapArgs = {operated, prevLoad};
140  cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
141  Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
142  loc, dataType, cmpswapArgs, cmpswapAttrs);
143 
144  // We care about exact bitwise equality here, so do some bitcasts.
145  // These will fold away during lowering to the ROCDL dialect, where
146  // an int->float bitcast is introduced to account for the fact that cmpswap
147  // only takes integer arguments.
148 
149  Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
150  Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
151  if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
152  Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
153  prevLoadForCompare =
154  rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
155  atomicResForCompare =
156  rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
157  }
158  Value canLeave = rewriter.create<arith::CmpIOp>(
159  loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
160  rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
161  loopBlock, atomicRes);
162  rewriter.eraseOp(atomicOp);
163  return success();
164 }
165 
167  ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
168  // gfx10 has no atomic adds.
169  if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
170  target.addIllegalOp<RawBufferAtomicFaddOp>();
171  }
172  // gfx11 has no fp16 atomics
173  if (chipset.majorVersion == 11) {
174  target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
175  [](RawBufferAtomicFaddOp op) -> bool {
176  Type elemType = getElementTypeOrSelf(op.getValue().getType());
177  return !isa<Float16Type, BFloat16Type>(elemType);
178  });
179  }
180  // gfx9 has no to a very limited support for floating-point min and max.
181  if (chipset.majorVersion == 9) {
182  if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) {
183  // gfx90a supports f64 max (and min, but we don't have a min wrapper right
184  // now) but all other types need to be emulated.
185  target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
186  [](RawBufferAtomicFmaxOp op) -> bool {
187  return op.getValue().getType().isF64();
188  });
189  } else {
190  target.addIllegalOp<RawBufferAtomicFmaxOp>();
191  }
192  if (chipset == Chipset(9, 4, 1)) {
193  // gfx941 requires non-CAS atomics to be implemented with CAS loops.
194  // The workaround here mirrors HIP and OpenMP.
195  target.addIllegalOp<RawBufferAtomicFaddOp, RawBufferAtomicFmaxOp,
196  RawBufferAtomicSmaxOp, RawBufferAtomicUminOp>();
197  }
198  }
199  patterns.add<
200  RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
201  RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
202  RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
203  RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
204  patterns.getContext());
205 }
206 
207 void AmdgpuEmulateAtomicsPass::runOnOperation() {
208  Operation *op = getOperation();
209  FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
210  if (failed(maybeChipset)) {
211  emitError(op->getLoc(), "Invalid chipset name: " + chipset);
212  return signalPassFailure();
213  }
214 
215  MLIRContext &ctx = getContext();
216  ConversionTarget target(ctx);
217  RewritePatternSet patterns(&ctx);
218  target.markUnknownOpDynamicallyLegal(
219  [](Operation *op) -> bool { return true; });
220 
221  populateAmdgpuEmulateAtomicsPatterns(target, patterns, *maybeChipset);
222  if (failed(applyPartialConversion(op, target, std::move(patterns))))
223  return signalPassFailure();
224 }
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, Value val)
static void patchOperandSegmentSizes(ArrayRef< NamedAttribute > attrs, SmallVectorImpl< NamedAttribute > &newAttrs, DataArgAction action)
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
Definition: Block.h:31
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
This class implements a pattern rewriter for use with ConversionPatterns.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:453
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:444
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:450
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block, and return it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition: Chipset.h:22
unsigned majorVersion
Definition: Chipset.h:23
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition: Chipset.cpp:14