MLIR 22.0.0git
EmulateAtomics.cpp
Go to the documentation of this file.
1//===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics ------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
19
20namespace mlir::amdgpu {
21#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
22#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
23} // namespace mlir::amdgpu
24
25using namespace mlir;
26using namespace mlir::amdgpu;
27
28namespace {
29struct AmdgpuEmulateAtomicsPass
31 AmdgpuEmulateAtomicsPass> {
33 AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
34 void runOnOperation() override;
35};
37template <typename AtomicOp, typename ArithOp>
38struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
39 using OpConversionPattern<AtomicOp>::OpConversionPattern;
40 using Adaptor = typename AtomicOp::Adaptor;
42 LogicalResult
43 matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
44 ConversionPatternRewriter &rewriter) const override;
45};
46} // namespace
47
48namespace {
49enum class DataArgAction : unsigned char {
50 Duplicate,
51 Drop,
52};
53} // namespace
54
55// Fix up the fact that, when we're migrating from a general buffer atomic
56// to a load or to a CAS, the number of operands, and thus the number of
57// entries needed in operandSegmentSizes, needs to change. We use this method
58// because we'd like to preserve unknown attributes on the atomic instead of
59// discarding them.
62 DataArgAction action) {
63 newAttrs.reserve(attrs.size());
64 for (NamedAttribute attr : attrs) {
65 if (attr.getName().getValue() != "operandSegmentSizes") {
66 newAttrs.push_back(attr);
67 continue;
68 }
69 auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
70 MLIRContext *context = segmentAttr.getContext();
71 DenseI32ArrayAttr newSegments;
72 switch (action) {
73 case DataArgAction::Drop:
74 newSegments = DenseI32ArrayAttr::get(
75 context, segmentAttr.asArrayRef().drop_front());
76 break;
77 case DataArgAction::Duplicate: {
79 ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
80 newVals.push_back(oldVals[0]);
81 newVals.append(oldVals.begin(), oldVals.end());
82 newSegments = DenseI32ArrayAttr::get(context, newVals);
83 break;
84 }
85 }
86 newAttrs.push_back(NamedAttribute(attr.getName(), newSegments));
87 }
88}
89
90// A helper function to flatten a vector value to a scalar containing its bits,
91// returning the value itself if it is not a vector.
92static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
93 Value val) {
94 auto vectorType = dyn_cast<VectorType>(val.getType());
95 if (!vectorType)
96 return val;
97
98 int64_t bitwidth =
99 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
100 Type allBitsType = rewriter.getIntegerType(bitwidth);
101 auto allBitsVecType = VectorType::get({1}, allBitsType);
102 Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val);
103 Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0);
104 return scalar;
105}
106
107template <typename AtomicOp, typename ArithOp>
108LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
109 AtomicOp atomicOp, Adaptor adaptor,
110 ConversionPatternRewriter &rewriter) const {
111 Location loc = atomicOp.getLoc();
112
113 ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
114 ValueRange operands = adaptor.getOperands();
115 Value data = operands.take_front()[0];
116 ValueRange invariantArgs = operands.drop_front();
117 Type dataType = data.getType();
118
120 patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
121 Value initialLoad = RawBufferLoadOp::create(rewriter, loc, dataType,
122 invariantArgs, loadAttrs);
123 Block *currentBlock = rewriter.getInsertionBlock();
124 Block *afterAtomic =
125 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
126 Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
127
128 rewriter.setInsertionPointToEnd(currentBlock);
129 cf::BranchOp::create(rewriter, loc, loopBlock, initialLoad);
130
131 rewriter.setInsertionPointToEnd(loopBlock);
132 Value prevLoad = loopBlock->getArgument(0);
133 Value operated = ArithOp::create(rewriter, loc, data, prevLoad);
134 dataType = operated.getType();
135
136 SmallVector<NamedAttribute> cmpswapAttrs;
137 patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
138 SmallVector<Value> cmpswapArgs = {operated, prevLoad};
139 cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
140 Value atomicRes = RawBufferAtomicCmpswapOp::create(rewriter, loc, dataType,
141 cmpswapArgs, cmpswapAttrs);
142
143 // We care about exact bitwise equality here, so do some bitcasts.
144 // These will fold away during lowering to the ROCDL dialect, where
145 // an int->float bitcast is introduced to account for the fact that cmpswap
146 // only takes integer arguments.
147
148 Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
149 Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
150 if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
151 Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
152 prevLoadForCompare =
153 arith::BitcastOp::create(rewriter, loc, equivInt, prevLoad);
154 atomicResForCompare =
155 arith::BitcastOp::create(rewriter, loc, equivInt, atomicRes);
156 }
157 Value canLeave =
158 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
159 atomicResForCompare, prevLoadForCompare);
160 cf::CondBranchOp::create(rewriter, loc, canLeave, afterAtomic, ValueRange{},
161 loopBlock, atomicRes);
162 rewriter.eraseOp(atomicOp);
163 return success();
164}
165
168 PatternBenefit benefit) {
169 // gfx10 and chipsets older than gfx908 have no atomic fadd.
170 if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
171 target.addIllegalOp<RawBufferAtomicFaddOp>();
172 }
173 // gfx11 has no fp16 or bf16 atomic adds.
174 if (chipset.majorVersion == 11) {
175 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
176 [](RawBufferAtomicFaddOp op) -> bool {
177 Type elemType = getElementTypeOrSelf(op.getValue().getType());
178 return !isa<Float16Type, BFloat16Type>(elemType);
179 });
180 }
181 // gfx9 has no to very limited support for floating-point min and max.
182 if (chipset.majorVersion == 9) {
183 if (chipset >= Chipset(9, 0, 0xa)) {
184 // gfx90a supports f64 max (and min, but we don't have a min wrapper right
185 // now) but all other types need to be emulated.
186 target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
187 [](RawBufferAtomicFmaxOp op) -> bool {
188 return op.getValue().getType().isF64();
189 });
190 } else {
191 target.addIllegalOp<RawBufferAtomicFmaxOp>();
192 }
193 // TODO(https://github.com/llvm/llvm-project/issues/129206): Refactor
194 // this to avoid hardcoding ISA version: gfx950 has bf16 atomics.
195 if (chipset < Chipset(9, 5, 0)) {
196 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
197 [](RawBufferAtomicFaddOp op) -> bool {
198 Type elemType = getElementTypeOrSelf(op.getValue().getType());
199 return !isa<BFloat16Type>(elemType);
200 });
201 }
202 }
203 patterns.add<
204 RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
205 RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
206 RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
207 RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
208 patterns.getContext(), benefit);
209}
210
211void AmdgpuEmulateAtomicsPass::runOnOperation() {
212 Operation *op = getOperation();
213 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
214 if (failed(maybeChipset)) {
215 emitError(op->getLoc(), "Invalid chipset name: " + chipset);
216 return signalPassFailure();
217 }
218
219 MLIRContext &ctx = getContext();
222 target.markUnknownOpDynamicallyLegal(
223 [](Operation *op) -> bool { return true; });
224
226 if (failed(applyPartialConversion(op, target, std::move(patterns))))
227 return signalPassFailure();
228}
return success()
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, Value val)
static void patchOperandSegmentSizes(ArrayRef< NamedAttribute > attrs, SmallVectorImpl< NamedAttribute > &newAttrs, DataArgAction action)
b getContext())
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset, PatternBenefit benefit=1)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14