// NOTE(review): Doxygen page-header residue from the scrape ("MLIR 21.0.0git /
// ArithToAMDGPU.cpp / Go to the documentation of this file.") — kept only as a
// comment so the file remains valid C++.
1 //===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 namespace mlir {
25 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
26 #include "mlir/Conversion/Passes.h.inc"
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace mlir::amdgpu;
31 
32 namespace {
33 // Define commonly used chipsets versions for convenience.
34 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
35 
36 struct ArithToAMDGPUConversionPass final
37  : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
38  using impl::ArithToAMDGPUConversionPassBase<
39  ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
40 
41  void runOnOperation() override;
42 };
43 
44 struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
46 
47  Chipset chipset;
48  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
49  : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
50 
51  LogicalResult matchAndRewrite(arith::ExtFOp op,
52  PatternRewriter &rewriter) const override;
53 };
54 
55 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
56  bool saturateFP8 = false;
57  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
58  Chipset chipset)
59  : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
60  chipset(chipset) {}
61  Chipset chipset;
62 
63  LogicalResult matchAndRewrite(arith::TruncFOp op,
64  PatternRewriter &rewriter) const override;
65 };
66 
67 struct TruncfToFloat16RewritePattern final
68  : public OpRewritePattern<arith::TruncFOp> {
69 
71 
72  LogicalResult matchAndRewrite(arith::TruncFOp op,
73  PatternRewriter &rewriter) const override;
74 };
75 
76 } // end namespace
77 
78 static bool isSupportedF8(Type elementType, Chipset chipset) {
79  if (chipset == kGfx942)
80  return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
81  if (hasOcpFp8(chipset))
82  return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
83  return false;
84 }
85 
86 static Value castF32To(Type desType, Value f32, Location loc,
87  PatternRewriter &rewriter) {
88  Type elementType = getElementTypeOrSelf(desType);
89  if (elementType.isF32())
90  return f32;
91  if (elementType.getIntOrFloatBitWidth() < 32)
92  return rewriter.create<arith::TruncFOp>(loc, desType, f32);
93  if (elementType.getIntOrFloatBitWidth() > 32)
94  return rewriter.create<arith::ExtFOp>(loc, desType, f32);
95  llvm_unreachable("The only 32-bit float type is f32");
96 }
97 
98 LogicalResult
99 ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
100  PatternRewriter &rewriter) const {
101  Type inType = op.getIn().getType();
102  auto inVecType = dyn_cast<VectorType>(inType);
103  if (inVecType) {
104  if (inVecType.isScalable())
105  return failure();
106  inType = inVecType.getElementType();
107  }
108  if (!isSupportedF8(inType, chipset))
109  return failure();
110 
111  Location loc = op.getLoc();
112  Value in = op.getIn();
113  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
114  VectorType extResType = VectorType::get(2, rewriter.getF32Type());
115  if (!inVecType) {
116  Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
117  loc, rewriter.getF32Type(), in, 0);
118  Value result = castF32To(outElemType, asFloat, loc, rewriter);
119  rewriter.replaceOp(op, result);
120  return success();
121  }
122  int64_t numElements = inVecType.getNumElements();
123 
124  Value zero = rewriter.create<arith::ConstantOp>(
125  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
126  VectorType outType = cast<VectorType>(op.getOut().getType());
127 
128  if (inVecType.getShape().empty()) {
129  Value zerodSplat =
130  rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
131  Value scalarIn =
132  rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
133  Value scalarExt =
134  rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
135  Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
137  rewriter.replaceOp(op, result);
138  return success();
139  }
140 
141  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
142  outType.getElementType());
143  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
144 
145  if (inVecType.getRank() > 1) {
146  inVecType = VectorType::get(SmallVector<int64_t>{numElements},
147  inVecType.getElementType());
148  in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
149  }
150 
151  for (int64_t i = 0; i < numElements; i += 4) {
152  int64_t elemsThisOp = std::min(numElements, i + 4) - i;
153  Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
154  loc, in, i, elemsThisOp, 1);
155  for (int64_t j = 0; j < elemsThisOp; j += 2) {
156  if (i + j + 1 < numElements) { // Convert two 8-bit elements
157  Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
158  loc, extResType, inSlice, j / 2);
159  Type desType = VectorType::get(2, outElemType);
160  Value asType = castF32To(desType, asFloats, loc, rewriter);
161  result = rewriter.create<vector::InsertStridedSliceOp>(
162  loc, asType, result, i + j, 1);
163  } else { // Convert a 8-bit element
164  Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
165  loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
166  Value asType = castF32To(outElemType, asFloat, loc, rewriter);
167  result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
168  }
169  }
170  }
171 
172  if (inVecType.getRank() != outType.getRank()) {
173  result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
174  }
175 
176  rewriter.replaceOp(op, result);
177  return success();
178 }
179 
180 static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
181  Type type = value.getType();
182  if (type.isF32())
183  return value;
184  if (type.getIntOrFloatBitWidth() < 32)
185  return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
186  if (type.getIntOrFloatBitWidth() > 32)
187  return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
188  llvm_unreachable("The only 32-bit float type is f32");
189 }
190 
191 // If `in` is a finite value, clamp it between the maximum and minimum values
192 // of `outElemType` so that subsequent conversion instructions don't
193 // overflow those out-of-range values to NaN. These semantics are commonly
194 // used in machine-learning contexts where failure to clamp would lead to
195 // excessive NaN production.
196 static Value clampInput(PatternRewriter &rewriter, Location loc,
197  Type outElemType, Value source) {
198  Type sourceType = source.getType();
199  const llvm::fltSemantics &sourceSem =
200  cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
201  const llvm::fltSemantics &targetSem =
202  cast<FloatType>(outElemType).getFloatSemantics();
203 
204  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
205  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
206  bool ignoredLosesInfo = false;
207  // We can ignore conversion failures here because this conversion promotes
208  // from a smaller type to a larger one - ex. there can be no loss of precision
209  // when casting fp8 to f16.
210  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
211  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
212 
213  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
214  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
215 
217  rewriter, loc, sourceType,
218  APFloat::getInf(sourceSem, /*Negative=*/false));
220  rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
221  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
222  loc, arith::CmpFPredicate::OEQ, source, inf);
223  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
224  loc, arith::CmpFPredicate::OEQ, source, negInf);
225  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
226  loc, arith::CmpFPredicate::UNO, source, source);
227  Value isNonFinite = rewriter.create<arith::OrIOp>(
228  loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
229 
230  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
231  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
232  Value res =
233  rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
234  return res;
235 }
236 
237 LogicalResult
238 TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
239  PatternRewriter &rewriter) const {
240  // Only supporting default rounding mode as of now.
241  if (op.getRoundingmodeAttr())
242  return failure();
243  Type outType = op.getOut().getType();
244  auto outVecType = dyn_cast<VectorType>(outType);
245  if (outVecType) {
246  if (outVecType.isScalable())
247  return failure();
248  outType = outVecType.getElementType();
249  }
250  auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
251  if (inType && inType.getWidth() <= 8 && saturateFP8)
252  // Conversion between 8-bit floats is not supported with truncation enabled.
253  return failure();
254 
255  if (!isSupportedF8(outType, chipset))
256  return failure();
257 
258  Location loc = op.getLoc();
259  Value in = op.getIn();
260  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
261  if (saturateFP8)
262  in = clampInput(rewriter, loc, outElemType, in);
263  auto inVectorTy = dyn_cast<VectorType>(in.getType());
264  VectorType truncResType = VectorType::get(4, outElemType);
265  if (!inVectorTy) {
266  Value asFloat = castToF32(in, loc, rewriter);
267  Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
268  loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
269  /*existing=*/nullptr);
270  Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
271  rewriter.replaceOp(op, result);
272  return success();
273  }
274 
275  int64_t numElements = outVecType.getNumElements();
276  Value zero = rewriter.create<arith::ConstantOp>(
277  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
278  if (outVecType.getShape().empty()) {
279  Value scalarIn =
280  rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
281  // Recurse to send the 0-D vector case to the 1-D vector case
282  Value scalarTrunc =
283  rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
284  Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
286  rewriter.replaceOp(op, result);
287  return success();
288  }
289 
290  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
291  outVecType.getElementType());
292  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
293 
294  if (inVectorTy.getRank() > 1) {
295  inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
296  inVectorTy.getElementType());
297  in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
298  }
299 
300  for (int64_t i = 0; i < numElements; i += 4) {
301  int64_t elemsThisOp = std::min(numElements, i + 4) - i;
302  Value thisResult = nullptr;
303  for (int64_t j = 0; j < elemsThisOp; j += 2) {
304  Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
305  Value asFloatA = castToF32(elemA, loc, rewriter);
306  Value asFloatB = nullptr;
307  if (j + 1 < elemsThisOp) {
308  Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
309  asFloatB = castToF32(elemB, loc, rewriter);
310  }
311  thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
312  loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
313  }
314  if (elemsThisOp < 4)
315  thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
316  loc, thisResult, 0, elemsThisOp, 1);
317  result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
318  result, i, 1);
319  }
320 
321  if (inVectorTy.getRank() != outVecType.getRank()) {
322  result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
323  }
324 
325  rewriter.replaceOp(op, result);
326  return success();
327 }
328 
329 LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
330  arith::TruncFOp op, PatternRewriter &rewriter) const {
331  Type outType = op.getOut().getType();
332  Type inputType = getElementTypeOrSelf(op.getIn());
333  auto outVecType = dyn_cast<VectorType>(outType);
334  if (outVecType) {
335  if (outVecType.isScalable())
336  return failure();
337  outType = outVecType.getElementType();
338  }
339  if (!(outType.isF16() && inputType.isF32()))
340  return failure();
341 
342  Location loc = op.getLoc();
343  Value in = op.getIn();
344  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
345  VectorType truncResType = VectorType::get(2, outElemType);
346  auto inVectorTy = dyn_cast<VectorType>(in.getType());
347 
348  // Handle the case where input type is not a vector type
349  if (!inVectorTy) {
350  auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
351  Value asF16s =
352  rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
353  Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
354  rewriter.replaceOp(op, result);
355  return success();
356  }
357  int64_t numElements = outVecType.getNumElements();
358  Value zero = rewriter.createOrFold<arith::ConstantOp>(
359  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
360  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
361 
362  if (inVectorTy.getRank() > 1) {
363  inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
364  inVectorTy.getElementType());
365  in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
366  }
367 
368  // Handle the vector case. We also handle the (uncommon) case where the vector
369  // length is odd
370  for (int64_t i = 0; i < numElements; i += 2) {
371  int64_t elemsThisOp = std::min(numElements, i + 2) - i;
372  Value thisResult = nullptr;
373  Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
374  Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
375 
376  if (elemsThisOp == 2) {
377  elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
378  }
379 
380  thisResult =
381  rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
382  // Place back the truncated result into the possibly larger vector. If we
383  // are operating on a size 2 vector, these operations should be folded away
384  thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
385  loc, thisResult, 0, elemsThisOp, 1);
386  result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
387  result, i, 1);
388  }
389 
390  if (inVectorTy.getRank() != outVecType.getRank()) {
391  result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
392  }
393 
394  rewriter.replaceOp(op, result);
395  return success();
396 }
397 
399  RewritePatternSet &patterns, bool convertFP8Arithmetic,
400  bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
401 
402  if (convertFP8Arithmetic) {
403  patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset);
404  patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
405  saturateFP8Truncf, chipset);
406  }
407  if (allowPackedF16Rtz)
408  patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
409 }
410 
411 void ArithToAMDGPUConversionPass::runOnOperation() {
412  Operation *op = getOperation();
413  MLIRContext *ctx = &getContext();
415  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
416  if (failed(maybeChipset)) {
417  emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
418  return signalPassFailure();
419  }
420 
421  bool convertFP8Arithmetic =
422  *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
424  patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
425  *maybeChipset);
426  if (failed(applyPatternsGreedily(op, std::move(patterns))))
427  return signalPassFailure();
428 }
constexpr Chipset kGfx942
static Value castF32To(Type desType, Value f32, Location loc, PatternRewriter &rewriter)
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter)
static bool isSupportedF8(Type elementType, Chipset chipset)
static Value clampInput(PatternRewriter &rewriter, Location loc, Type outElemType, Value source)
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatType getF32Type()
Definition: Builders.cpp:43
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:250
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:40
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool hasOcpFp8(const Chipset &chipset)
Definition: Chipset.h:52
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, bool convertFP8Arithmetic, bool saturateFP8Truncf, bool allowPackedF16Rtz, amdgpu::Chipset chipset)
Add patterns for rewriting arith.extf and arith.truncf on FP8 types to wrappers around AMDGPU–specifi...
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition: Utils.cpp:271
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:318
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:323
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
Definition: Chipset.h:22
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Definition: Chipset.cpp:14
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.