MLIR 23.0.0git
EmulateAtomics.cpp
Go to the documentation of this file.
//===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
10
19
20namespace mlir::amdgpu {
21#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
22#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
23} // namespace mlir::amdgpu
24
25using namespace mlir;
26using namespace mlir::amdgpu;
27
28namespace {
29struct AmdgpuEmulateAtomicsPass
30 : public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
31 AmdgpuEmulateAtomicsPass> {
32 using AmdgpuEmulateAtomicsPassBase<
33 AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
34 void runOnOperation() override;
35};
36
37template <typename AtomicOp, typename ArithOp>
38struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
39 using OpConversionPattern<AtomicOp>::OpConversionPattern;
40 using Adaptor = typename AtomicOp::Adaptor;
41
42 LogicalResult
43 matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
44 ConversionPatternRewriter &rewriter) const override;
45};
46} // namespace
47
48namespace {
49enum class DataArgAction : unsigned char {
50 Duplicate,
51 Drop,
52};
53} // namespace
54
55// Fix up the fact that, when we're migrating from a general bugffer atomic
56// to a load or to a CAS, the number of openrands, and thus the number of
57// entries needed in operandSegmentSizes, needs to change. We use this method
58// because we'd like to preserve unknown attributes on the atomic instead of
59// discarding them.
62 DataArgAction action) {
63 newAttrs.reserve(attrs.size());
64 for (NamedAttribute attr : attrs) {
65 if (attr.getName().getValue() != "operandSegmentSizes") {
66 newAttrs.push_back(attr);
67 continue;
68 }
69 auto segmentAttr = cast<DenseI32ArrayAttr>(attr.getValue());
70 MLIRContext *context = segmentAttr.getContext();
71 DenseI32ArrayAttr newSegments;
72 switch (action) {
73 case DataArgAction::Drop:
74 newSegments = DenseI32ArrayAttr::get(
75 context, segmentAttr.asArrayRef().drop_front());
76 break;
77 case DataArgAction::Duplicate: {
79 ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
80 newVals.push_back(oldVals[0]);
81 newVals.append(oldVals.begin(), oldVals.end());
82 newSegments = DenseI32ArrayAttr::get(context, newVals);
83 break;
84 }
85 }
86 newAttrs.push_back(NamedAttribute(attr.getName(), newSegments));
87 }
88}
89
90// A helper function to flatten a vector value to a scalar containing its bits,
91// returning the value itself if othetwise.
92static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
93 Value val) {
94 auto vectorType = dyn_cast<VectorType>(val.getType());
95 if (!vectorType)
96 return val;
97
98 int64_t bitwidth =
99 vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
100 Type allBitsType = rewriter.getIntegerType(bitwidth);
101 auto allBitsVecType = VectorType::get({1}, allBitsType);
102 Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val);
103 Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0);
104 return scalar;
105}
106
107template <typename AtomicOp, typename ArithOp>
108LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
109 AtomicOp atomicOp, Adaptor adaptor,
110 ConversionPatternRewriter &rewriter) const {
111 Location loc = atomicOp.getLoc();
112
113 ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
114 ValueRange operands = adaptor.getOperands();
115 Value data = operands.take_front()[0];
116 ValueRange invariantArgs = operands.drop_front();
117 Type dataType = data.getType();
118
120 patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
121 Value initialLoad = RawBufferLoadOp::create(rewriter, loc, dataType,
122 invariantArgs, loadAttrs);
123 Block *currentBlock = rewriter.getInsertionBlock();
124 Block *afterAtomic =
125 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
126 afterAtomic->addArgument(dataType, loc);
127 Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
128
129 rewriter.setInsertionPointToEnd(currentBlock);
130 cf::BranchOp::create(rewriter, loc, loopBlock, initialLoad);
131
132 rewriter.setInsertionPointToEnd(loopBlock);
133 Value prevLoad = loopBlock->getArgument(0);
134 Value operated = ArithOp::create(rewriter, loc, data, prevLoad);
135 dataType = operated.getType();
136
137 SmallVector<NamedAttribute> cmpswapAttrs;
138 patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
139 SmallVector<Value> cmpswapArgs = {operated, prevLoad};
140 cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
141 Value atomicRes = RawBufferAtomicCmpswapOp::create(rewriter, loc, dataType,
142 cmpswapArgs, cmpswapAttrs);
143
144 // We care about exact bitwise equality here, so do some bitcasts.
145 // These will fold away during lowering to the ROCDL dialect, where
146 // an int->float bitcast is introduced to account for the fact that cmpswap
147 // only takes integer arguments.
148
149 Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
150 Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
151 if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
152 Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
153 prevLoadForCompare =
154 arith::BitcastOp::create(rewriter, loc, equivInt, prevLoad);
155 atomicResForCompare =
156 arith::BitcastOp::create(rewriter, loc, equivInt, atomicRes);
157 }
158 Value canLeave =
159 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
160 atomicResForCompare, prevLoadForCompare);
161 cf::CondBranchOp::create(rewriter, loc, canLeave, afterAtomic,
162 ValueRange{prevLoad}, loopBlock, atomicRes);
163 rewriter.replaceOp(atomicOp, ValueRange{afterAtomic->getArgument(0)});
164 return success();
165}
166
169 PatternBenefit benefit) {
170 // gfx10 has no atomic adds.
171 if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
172 target.addIllegalOp<RawBufferAtomicFaddOp>();
173 }
174 // gfx11 has no fp16 atomics
175 if (chipset.majorVersion == 11) {
176 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
177 [](RawBufferAtomicFaddOp op) -> bool {
178 Type elemType = getElementTypeOrSelf(op.getValue().getType());
179 return !isa<Float16Type, BFloat16Type>(elemType);
180 });
181 }
182 // gfx9 has no to a very limited support for floating-point min and max.
183 if (chipset.majorVersion == 9) {
184 if (chipset >= Chipset(9, 0, 0xa)) {
185 // gfx90a supports f64 max (and min, but we don't have a min wrapper right
186 // now) but all other types need to be emulated.
187 target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
188 [](RawBufferAtomicFmaxOp op) -> bool {
189 return op.getValue().getType().isF64();
190 });
191 } else {
192 target.addIllegalOp<RawBufferAtomicFmaxOp>();
193 }
194 // TODO(https://github.com/llvm/llvm-project/issues/129206): Refactor
195 // this to avoid hardcoding ISA version: gfx950 has bf16 atomics.
196 if (chipset < Chipset(9, 5, 0)) {
197 target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
198 [](RawBufferAtomicFaddOp op) -> bool {
199 Type elemType = getElementTypeOrSelf(op.getValue().getType());
200 return !isa<BFloat16Type>(elemType);
201 });
202 }
203 }
204 patterns.add<
205 RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
206 RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaximumFOp>,
207 RawBufferAtomicByCasPattern<RawBufferAtomicSmaxOp, arith::MaxSIOp>,
208 RawBufferAtomicByCasPattern<RawBufferAtomicUminOp, arith::MinUIOp>>(
209 patterns.getContext(), benefit);
210}
211
212void AmdgpuEmulateAtomicsPass::runOnOperation() {
213 Operation *op = getOperation();
214 FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
215 if (failed(maybeChipset)) {
216 emitError(op->getLoc(), "Invalid chipset name: " + chipset);
217 return signalPassFailure();
218 }
219
220 MLIRContext &ctx = getContext();
222 RewritePatternSet patterns(&ctx);
223 target.markUnknownOpDynamicallyLegal(
224 [](Operation *op) -> bool { return true; });
225
226 populateAmdgpuEmulateAtomicsPatterns(target, patterns, *maybeChipset);
227 if (failed(applyPartialConversion(op, target, std::move(patterns))))
228 return signalPassFailure();
229}
return success()
static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, Value val)
static void patchOperandSegmentSizes(ArrayRef< NamedAttribute > attrs, SmallVectorImpl< NamedAttribute > &newAttrs, DataArgAction action)
b getContext())
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset, PatternBenefit benefit=1)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition Chipset.h:22
unsigned majorVersion
Definition Chipset.h:23
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition Chipset.cpp:14