MLIR  20.0.0git
ArithToAMDGPU.cpp
Go to the documentation of this file.
1 //===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Pass/Pass.h"
20 
21 namespace mlir {
22 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
23 #include "mlir/Conversion/Passes.h.inc"
24 } // namespace mlir
25 
26 using namespace mlir;
27 
28 namespace {
29 struct ArithToAMDGPUConversionPass final
30  : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
31  using impl::ArithToAMDGPUConversionPassBase<
32  ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
33 
34  void runOnOperation() override;
35 };
36 
37 struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
39 
40  LogicalResult match(arith::ExtFOp op) const override;
41  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
42 };
43 
44 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
45  bool saturateFP8 = false;
46  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
47  : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
48 
49  LogicalResult match(arith::TruncFOp op) const override;
50  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
51 };
52 } // end namespace
53 
54 static Value castF32To(Type elementType, Value f32, Location loc,
55  PatternRewriter &rewriter) {
56  if (elementType.isF32())
57  return f32;
58  if (elementType.getIntOrFloatBitWidth() < 32)
59  return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
60  if (elementType.getIntOrFloatBitWidth() > 32)
61  return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
62  llvm_unreachable("The only 32-bit float type is f32");
63 }
64 
65 LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
66  Type inType = op.getIn().getType();
67  if (auto inVecType = dyn_cast<VectorType>(inType)) {
68  if (inVecType.isScalable())
69  return failure();
70  inType = inVecType.getElementType();
71  }
72  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
73 }
74 
75 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
76  PatternRewriter &rewriter) const {
77  Location loc = op.getLoc();
78  Value in = op.getIn();
79  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
80  auto inType = dyn_cast<VectorType>(in.getType());
81  if (!inType) {
82  Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
83  loc, rewriter.getF32Type(), in, 0);
84  Value result = castF32To(outElemType, asFloat, loc, rewriter);
85  return rewriter.replaceOp(op, result);
86  }
87  int64_t numElements = inType.getNumElements();
88  Value zero = rewriter.create<arith::ConstantOp>(
89  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
90  if (inType.getShape().empty()) {
91  Value scalarIn =
92  rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
93  // Recurse to send the 0-D vector case to the 1-D vector case
94  Value scalarExt =
95  rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
96  Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
98  return rewriter.replaceOp(op, result);
99  }
100 
101  VectorType outType = cast<VectorType>(op.getOut().getType());
102  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
103  outType.getElementType());
104  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
105 
106  if (inType.getRank() > 1) {
107  inType = VectorType::get(SmallVector<int64_t>{numElements},
108  inType.getElementType());
109  in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
110  }
111 
112  for (int64_t i = 0; i < numElements; i += 4) {
113  int64_t elemsThisOp = std::min(numElements, i + 4) - i;
114  Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
115  loc, in, i, elemsThisOp, 1);
116  for (int64_t j = 0; j < elemsThisOp; ++j) {
117  Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
118  loc, rewriter.getF32Type(), inSlice, j);
119  Value asType = castF32To(outElemType, asFloat, loc, rewriter);
120  result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
121  }
122  }
123 
124  if (inType.getRank() != outType.getRank()) {
125  result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
126  }
127 
128  rewriter.replaceOp(op, result);
129 }
130 
131 static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
132  Type type = value.getType();
133  if (type.isF32())
134  return value;
135  if (type.getIntOrFloatBitWidth() < 32)
136  return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
137  if (type.getIntOrFloatBitWidth() > 32)
138  return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
139  llvm_unreachable("The only 32-bit float type is f32");
140 }
141 
142 // If `in` is a finite value, clamp it between the maximum and minimum values
143 // of `outElemType` so that subsequent conversion instructions don't
144 // overflow those out-of-range values to NaN. These semantics are commonly
145 // used in machine-learning contexts where failure to clamp would lead to
146 // excessive NaN production.
147 static Value clampInput(PatternRewriter &rewriter, Location loc,
148  Type outElemType, Value source) {
149  Type sourceType = source.getType();
150  const llvm::fltSemantics &sourceSem =
151  cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
152  const llvm::fltSemantics &targetSem =
153  cast<FloatType>(outElemType).getFloatSemantics();
154 
155  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
156  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
157  bool ignoredLosesInfo = false;
158  // We can ignore conversion failures here because this conversion promotes
159  // from a smaller type to a larger one - ex. there can be no loss of precision
160  // when casting fp8 to f16.
161  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
162  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
163 
164  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
165  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
166 
168  rewriter, loc, sourceType,
169  APFloat::getInf(sourceSem, /*Negative=*/false));
171  rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
172  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
173  loc, arith::CmpFPredicate::OEQ, source, inf);
174  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
175  loc, arith::CmpFPredicate::OEQ, source, negInf);
176  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
177  loc, arith::CmpFPredicate::UNO, source, source);
178  Value isNonFinite = rewriter.create<arith::OrIOp>(
179  loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
180 
181  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
182  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
183  Value res =
184  rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
185  return res;
186 }
187 
188 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
189  // Only supporting default rounding mode as of now.
190  if (op.getRoundingmodeAttr())
191  return failure();
192  Type outType = op.getOut().getType();
193  if (auto outVecType = dyn_cast<VectorType>(outType)) {
194  if (outVecType.isScalable())
195  return failure();
196  outType = outVecType.getElementType();
197  }
198  auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
199  if (inType && inType.getWidth() <= 8 && saturateFP8)
200  // Conversion between 8-bit floats is not supported with truncation enabled.
201  return failure();
202  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
203 }
204 
205 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
206  PatternRewriter &rewriter) const {
207  Location loc = op.getLoc();
208  Value in = op.getIn();
209  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
210  if (saturateFP8)
211  in = clampInput(rewriter, loc, outElemType, in);
212  auto inVectorTy = dyn_cast<VectorType>(in.getType());
213  VectorType truncResType = VectorType::get(4, outElemType);
214  if (!inVectorTy) {
215  Value asFloat = castToF32(in, loc, rewriter);
216  Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
217  loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
218  /*existing=*/nullptr);
219  Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
220  return rewriter.replaceOp(op, result);
221  }
222  VectorType outType = cast<VectorType>(op.getOut().getType());
223  int64_t numElements = outType.getNumElements();
224  Value zero = rewriter.create<arith::ConstantOp>(
225  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
226  if (outType.getShape().empty()) {
227  Value scalarIn =
228  rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
229  // Recurse to send the 0-D vector case to the 1-D vector case
230  Value scalarTrunc =
231  rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
232  Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
234  return rewriter.replaceOp(op, result);
235  }
236 
237  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
238  outType.getElementType());
239  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
240 
241  if (inVectorTy.getRank() > 1) {
242  inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
243  inVectorTy.getElementType());
244  in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
245  }
246 
247  for (int64_t i = 0; i < numElements; i += 4) {
248  int64_t elemsThisOp = std::min(numElements, i + 4) - i;
249  Value thisResult = nullptr;
250  for (int64_t j = 0; j < elemsThisOp; j += 2) {
251  Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
252  Value asFloatA = castToF32(elemA, loc, rewriter);
253  Value asFloatB = nullptr;
254  if (j + 1 < elemsThisOp) {
255  Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
256  asFloatB = castToF32(elemB, loc, rewriter);
257  }
258  thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
259  loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
260  }
261  if (elemsThisOp < 4)
262  thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
263  loc, thisResult, 0, elemsThisOp, 1);
264  result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
265  result, i, 1);
266  }
267 
268  if (inVectorTy.getRank() != outType.getRank()) {
269  result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
270  }
271 
272  rewriter.replaceOp(op, result);
273 }
274 
276  RewritePatternSet &patterns, bool saturateFP8TruncF) {
277  patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
278  patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
279  saturateFP8TruncF);
280 }
281 
282 void ArithToAMDGPUConversionPass::runOnOperation() {
283  Operation *op = getOperation();
284  RewritePatternSet patterns(op->getContext());
285  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
286  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
287  return signalPassFailure();
288 }
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter)
static Value castF32To(Type elementType, Value f32, Location loc, PatternRewriter &rewriter)
static Value clampInput(PatternRewriter &rewriter, Location loc, Type outElemType, Value source)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
FloatType getF32Type()
Definition: Builders.cpp:67
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:265
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:523
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:52
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:43
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:40
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, bool saturateFP8TruncF)
Add patterns for rewriting arith.extf and arith.truncf on FP8 types to wrappers around AMDGPU–specifi...
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition: Utils.cpp:271
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.