MLIR  20.0.0git
ArithToAMDGPU.cpp
Go to the documentation of this file.
1 //===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Pass/Pass.h"
23 
24 namespace mlir {
25 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
26 #include "mlir/Conversion/Passes.h.inc"
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace mlir::amdgpu;
31 
32 namespace {
33 struct ArithToAMDGPUConversionPass final
34  : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
35  using impl::ArithToAMDGPUConversionPassBase<
36  ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
37 
38  void runOnOperation() override;
39 };
40 
41 struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
43 
44  LogicalResult match(arith::ExtFOp op) const override;
45  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
46 };
47 
48 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
49  bool saturateFP8 = false;
50  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
51  Chipset chipset)
52  : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
53  chipset(chipset) {}
54  Chipset chipset;
55 
56  LogicalResult match(arith::TruncFOp op) const override;
57  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
58 };
59 
60 struct TruncfToFloat16RewritePattern final
61  : public OpRewritePattern<arith::TruncFOp> {
62 
64 
65  LogicalResult match(arith::TruncFOp op) const override;
66  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
67 };
68 
69 } // end namespace
70 
71 static Value castF32To(Type elementType, Value f32, Location loc,
72  PatternRewriter &rewriter) {
73  if (elementType.isF32())
74  return f32;
75  if (elementType.getIntOrFloatBitWidth() < 32)
76  return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
77  if (elementType.getIntOrFloatBitWidth() > 32)
78  return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
79  llvm_unreachable("The only 32-bit float type is f32");
80 }
81 
82 LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
83  Type inType = op.getIn().getType();
84  if (auto inVecType = dyn_cast<VectorType>(inType)) {
85  if (inVecType.isScalable())
86  return failure();
87  inType = inVecType.getElementType();
88  }
89  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
90 }
91 
92 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
93  PatternRewriter &rewriter) const {
94  Location loc = op.getLoc();
95  Value in = op.getIn();
96  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
97  auto inType = dyn_cast<VectorType>(in.getType());
98  if (!inType) {
99  Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
100  loc, rewriter.getF32Type(), in, 0);
101  Value result = castF32To(outElemType, asFloat, loc, rewriter);
102  return rewriter.replaceOp(op, result);
103  }
104  int64_t numElements = inType.getNumElements();
105  Value zero = rewriter.create<arith::ConstantOp>(
106  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
107  if (inType.getShape().empty()) {
108  Value scalarIn =
109  rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
110  // Recurse to send the 0-D vector case to the 1-D vector case
111  Value scalarExt =
112  rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
113  Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
115  return rewriter.replaceOp(op, result);
116  }
117 
118  VectorType outType = cast<VectorType>(op.getOut().getType());
119  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
120  outType.getElementType());
121  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
122 
123  if (inType.getRank() > 1) {
124  inType = VectorType::get(SmallVector<int64_t>{numElements},
125  inType.getElementType());
126  in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
127  }
128 
129  for (int64_t i = 0; i < numElements; i += 4) {
130  int64_t elemsThisOp = std::min(numElements, i + 4) - i;
131  Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
132  loc, in, i, elemsThisOp, 1);
133  for (int64_t j = 0; j < elemsThisOp; ++j) {
134  Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
135  loc, rewriter.getF32Type(), inSlice, j);
136  Value asType = castF32To(outElemType, asFloat, loc, rewriter);
137  result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
138  }
139  }
140 
141  if (inType.getRank() != outType.getRank()) {
142  result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
143  }
144 
145  rewriter.replaceOp(op, result);
146 }
147 
148 static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
149  Type type = value.getType();
150  if (type.isF32())
151  return value;
152  if (type.getIntOrFloatBitWidth() < 32)
153  return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
154  if (type.getIntOrFloatBitWidth() > 32)
155  return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
156  llvm_unreachable("The only 32-bit float type is f32");
157 }
158 
159 // If `in` is a finite value, clamp it between the maximum and minimum values
160 // of `outElemType` so that subsequent conversion instructions don't
161 // overflow those out-of-range values to NaN. These semantics are commonly
162 // used in machine-learning contexts where failure to clamp would lead to
163 // excessive NaN production.
164 static Value clampInput(PatternRewriter &rewriter, Location loc,
165  Type outElemType, Value source) {
166  Type sourceType = source.getType();
167  const llvm::fltSemantics &sourceSem =
168  cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
169  const llvm::fltSemantics &targetSem =
170  cast<FloatType>(outElemType).getFloatSemantics();
171 
172  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
173  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
174  bool ignoredLosesInfo = false;
175  // We can ignore conversion failures here because this conversion promotes
176  // from a smaller type to a larger one - ex. there can be no loss of precision
177  // when casting fp8 to f16.
178  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
179  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
180 
181  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
182  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
183 
185  rewriter, loc, sourceType,
186  APFloat::getInf(sourceSem, /*Negative=*/false));
188  rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
189  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
190  loc, arith::CmpFPredicate::OEQ, source, inf);
191  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
192  loc, arith::CmpFPredicate::OEQ, source, negInf);
193  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
194  loc, arith::CmpFPredicate::UNO, source, source);
195  Value isNonFinite = rewriter.create<arith::OrIOp>(
196  loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
197 
198  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
199  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
200  Value res =
201  rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
202  return res;
203 }
204 
205 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
206  // Only supporting default rounding mode as of now.
207  if (op.getRoundingmodeAttr())
208  return failure();
209  Type outType = op.getOut().getType();
210  if (auto outVecType = dyn_cast<VectorType>(outType)) {
211  if (outVecType.isScalable())
212  return failure();
213  outType = outVecType.getElementType();
214  }
215  auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
216  if (inType && inType.getWidth() <= 8 && saturateFP8)
217  // Conversion between 8-bit floats is not supported with truncation enabled.
218  return failure();
219  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
220 }
221 
222 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
223  PatternRewriter &rewriter) const {
224  Location loc = op.getLoc();
225  Value in = op.getIn();
226  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
227  if (saturateFP8)
228  in = clampInput(rewriter, loc, outElemType, in);
229  auto inVectorTy = dyn_cast<VectorType>(in.getType());
230  VectorType truncResType = VectorType::get(4, outElemType);
231  if (!inVectorTy) {
232  Value asFloat = castToF32(in, loc, rewriter);
233  Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
234  loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
235  /*existing=*/nullptr);
236  Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
237  return rewriter.replaceOp(op, result);
238  }
239  VectorType outType = cast<VectorType>(op.getOut().getType());
240  int64_t numElements = outType.getNumElements();
241  Value zero = rewriter.create<arith::ConstantOp>(
242  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
243  if (outType.getShape().empty()) {
244  Value scalarIn =
245  rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
246  // Recurse to send the 0-D vector case to the 1-D vector case
247  Value scalarTrunc =
248  rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
249  Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
251  return rewriter.replaceOp(op, result);
252  }
253 
254  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
255  outType.getElementType());
256  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
257 
258  if (inVectorTy.getRank() > 1) {
259  inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
260  inVectorTy.getElementType());
261  in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
262  }
263 
264  for (int64_t i = 0; i < numElements; i += 4) {
265  int64_t elemsThisOp = std::min(numElements, i + 4) - i;
266  Value thisResult = nullptr;
267  for (int64_t j = 0; j < elemsThisOp; j += 2) {
268  Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
269  Value asFloatA = castToF32(elemA, loc, rewriter);
270  Value asFloatB = nullptr;
271  if (j + 1 < elemsThisOp) {
272  Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
273  asFloatB = castToF32(elemB, loc, rewriter);
274  }
275  thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
276  loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
277  }
278  if (elemsThisOp < 4)
279  thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
280  loc, thisResult, 0, elemsThisOp, 1);
281  result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
282  result, i, 1);
283  }
284 
285  if (inVectorTy.getRank() != outType.getRank()) {
286  result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
287  }
288 
289  rewriter.replaceOp(op, result);
290 }
291 
292 LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
293  Type outType = op.getOut().getType();
294  Type inputType = getElementTypeOrSelf(op.getIn());
295  if (auto outVecType = dyn_cast<VectorType>(outType)) {
296  if (outVecType.isScalable())
297  return failure();
298  outType = outVecType.getElementType();
299  }
300  return success(outType.isF16() && inputType.isF32());
301 }
302 
303 void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
304  PatternRewriter &rewriter) const {
305  Location loc = op.getLoc();
306  Value in = op.getIn();
307  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
308  VectorType truncResType = VectorType::get(2, outElemType);
309  auto inVectorTy = dyn_cast<VectorType>(in.getType());
310 
311  // Handle the case where input type is not a vector type
312  if (!inVectorTy) {
313  auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
314  Value asF16s =
315  rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
316  Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
317  return rewriter.replaceOp(op, result);
318  }
319  VectorType outType = cast<VectorType>(op.getOut().getType());
320  int64_t numElements = outType.getNumElements();
321  Value zero = rewriter.createOrFold<arith::ConstantOp>(
322  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
323  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
324 
325  if (inVectorTy.getRank() > 1) {
326  inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
327  inVectorTy.getElementType());
328  in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
329  }
330 
331  // Handle the vector case. We also handle the (uncommon) case where the vector
332  // length is odd
333  for (int64_t i = 0; i < numElements; i += 2) {
334  int64_t elemsThisOp = std::min(numElements, i + 2) - i;
335  Value thisResult = nullptr;
336  Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
337  Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
338 
339  if (elemsThisOp == 2) {
340  elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
341  }
342 
343  thisResult =
344  rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
345  // Place back the truncated result into the possibly larger vector. If we
346  // are operating on a size 2 vector, these operations should be folded away
347  thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
348  loc, thisResult, 0, elemsThisOp, 1);
349  result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
350  result, i, 1);
351  }
352 
353  if (inVectorTy.getRank() != outType.getRank()) {
354  result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
355  }
356 
357  rewriter.replaceOp(op, result);
358 }
359 
361  RewritePatternSet &patterns, bool convertFP8Arithmetic,
362  bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
363 
364  if (convertFP8Arithmetic) {
365  patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
366  patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
367  saturateFP8Truncf, chipset);
368  }
369  if (allowPackedF16Rtz)
370  patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
371 }
372 
373 void ArithToAMDGPUConversionPass::runOnOperation() {
374  Operation *op = getOperation();
375  MLIRContext *ctx = &getContext();
376  RewritePatternSet patterns(op->getContext());
377  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
378  if (failed(maybeChipset)) {
379  emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
380  return signalPassFailure();
381  }
382 
383  bool convertFP8Arithmetic =
384  maybeChipset->majorVersion == 9 && *maybeChipset >= Chipset(9, 4, 0);
386  patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
387  *maybeChipset);
388  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
389  return signalPassFailure();
390 }
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter)
static Value castF32To(Type elementType, Value f32, Location loc, PatternRewriter &rewriter)
static Value clampInput(PatternRewriter &rewriter, Location loc, Type outElemType, Value source)
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
Definition: SCCP.cpp:67
FloatType getF32Type()
Definition: Builders.cpp:87
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:59
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:46
bool isF16() const
Definition: Types.cpp:57
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:43
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, bool convertFP8Arithmetic, bool saturateFP8Truncf, bool allowPackedF16Rtz, amdgpu::Chipset chipset)
Add patterns for rewriting arith.extf and arith.truncf on FP8 types to wrappers around AMDGPU–specifi...
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition: Utils.cpp:271
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition: Chipset.h:22
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition: Chipset.cpp:14
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.