MLIR  19.0.0git
ArithToAMDGPU.cpp
Go to the documentation of this file.
1 //===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Pass/Pass.h"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 struct ArithToAMDGPUConversionPass final
31  : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
32  using impl::ArithToAMDGPUConversionPassBase<
33  ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
34 
35  void runOnOperation() override;
36 };
37 
38 struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
40 
41  LogicalResult match(arith::ExtFOp op) const override;
42  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
43 };
44 
45 struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
46  bool saturateFP8 = false;
47  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
48  : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}
49 
50  LogicalResult match(arith::TruncFOp op) const override;
51  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
52 };
53 } // end namespace
54 
55 static Value castF32To(Type elementType, Value f32, Location loc,
56  PatternRewriter &rewriter) {
57  if (elementType.isF32())
58  return f32;
59  if (elementType.getIntOrFloatBitWidth() < 32)
60  return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
61  if (elementType.getIntOrFloatBitWidth() > 32)
62  return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
63  llvm_unreachable("The only 32-bit float type is f32");
64 }
65 
66 LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
67  Type inType = op.getIn().getType();
68  if (auto inVecType = dyn_cast<VectorType>(inType)) {
69  if (inVecType.isScalable())
70  return failure();
71  if (inVecType.getShape().size() > 1)
72  // Multi-dimensional vectors are currently unsupported.
73  return failure();
74  inType = inVecType.getElementType();
75  }
76  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
77 }
78 
79 void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
80  PatternRewriter &rewriter) const {
81  Location loc = op.getLoc();
82  Value in = op.getIn();
83  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
84  if (!isa<VectorType>(in.getType())) {
85  Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
86  loc, rewriter.getF32Type(), in, 0);
87  Value result = castF32To(outElemType, asFloat, loc, rewriter);
88  return rewriter.replaceOp(op, result);
89  }
90  VectorType inType = cast<VectorType>(in.getType());
91  int64_t numElements = inType.getNumElements();
92  Value zero = rewriter.create<arith::ConstantOp>(
93  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
94  Value result =
95  rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
96  if (inType.getShape().empty()) {
97  Value scalarIn =
98  rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
99  // Recurse to send the 0-D vector case to the 1-D vector case
100  Value scalarExt =
101  rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
102  result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
104  return rewriter.replaceOp(op, result);
105  }
106  for (int64_t i = 0; i < numElements; i += 4) {
107  int64_t elemsThisOp = std::min(numElements, i + 4) - i;
108  Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
109  loc, in, i, elemsThisOp, 1);
110  for (int64_t j = 0; j < elemsThisOp; ++j) {
111  Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
112  loc, rewriter.getF32Type(), inSlice, j);
113  Value asType = castF32To(outElemType, asFloat, loc, rewriter);
114  result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
115  }
116  }
117  rewriter.replaceOp(op, result);
118 }
119 
120 static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
121  Type type = value.getType();
122  if (type.isF32())
123  return value;
124  if (type.getIntOrFloatBitWidth() < 32)
125  return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
126  if (type.getIntOrFloatBitWidth() > 32)
127  return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
128  llvm_unreachable("The only 32-bit float type is f32");
129 }
130 
131 // If `in` is a finite value, clamp it between the maximum and minimum values
132 // of `outElemType` so that subsequent conversion instructions don't
133 // overflow those out-of-range values to NaN. These semantics are commonly
134 // used in machine-learning contexts where failure to clamp would lead to
135 // excessive NaN production.
136 static Value clampInput(PatternRewriter &rewriter, Location loc,
137  Type outElemType, Value source) {
138  Type sourceType = source.getType();
139  const llvm::fltSemantics &sourceSem =
140  cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
141  const llvm::fltSemantics &targetSem =
142  cast<FloatType>(outElemType).getFloatSemantics();
143 
144  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
145  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
146  bool ignoredLosesInfo = false;
147  // We can ignore conversion failures here because this conversion promotes
148  // from a smaller type to a larger one - ex. there can be no loss of precision
149  // when casting fp8 to f16.
150  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
151  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
152 
153  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
154  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
155 
157  rewriter, loc, sourceType,
158  APFloat::getInf(sourceSem, /*Negative=*/false));
160  rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
161  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
162  loc, arith::CmpFPredicate::OEQ, source, inf);
163  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
164  loc, arith::CmpFPredicate::OEQ, source, negInf);
165  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
166  loc, arith::CmpFPredicate::UNO, source, source);
167  Value isNonFinite = rewriter.create<arith::OrIOp>(
168  loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
169 
170  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
171  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
172  Value res =
173  rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
174  return res;
175 }
176 
177 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
178  // Only supporting default rounding mode as of now.
179  if (op.getRoundingmodeAttr())
180  return failure();
181  Type outType = op.getOut().getType();
182  if (auto outVecType = dyn_cast<VectorType>(outType)) {
183  if (outVecType.isScalable())
184  return failure();
185  if (outVecType.getShape().size() > 1)
186  // Multi-dimensional vectors are currently unsupported.
187  return failure();
188  outType = outVecType.getElementType();
189  }
190  auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
191  if (inType && inType.getWidth() <= 8 && saturateFP8)
192  // Conversion between 8-bit floats is not supported with truncation enabled.
193  return failure();
194  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
195 }
196 
197 void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
198  PatternRewriter &rewriter) const {
199  Location loc = op.getLoc();
200  Value in = op.getIn();
201  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
202  if (saturateFP8)
203  in = clampInput(rewriter, loc, outElemType, in);
204  VectorType truncResType = VectorType::get(4, outElemType);
205  if (!isa<VectorType>(in.getType())) {
206  Value asFloat = castToF32(in, loc, rewriter);
207  Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
208  loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
209  /*existing=*/nullptr);
210  Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
211  return rewriter.replaceOp(op, result);
212  }
213  VectorType outType = cast<VectorType>(op.getOut().getType());
214  int64_t numElements = outType.getNumElements();
215  Value zero = rewriter.create<arith::ConstantOp>(
216  loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
217  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
218  if (outType.getShape().empty()) {
219  Value scalarIn =
220  rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
221  // Recurse to send the 0-D vector case to the 1-D vector case
222  Value scalarTrunc =
223  rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
224  result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
226  return rewriter.replaceOp(op, result);
227  }
228 
229  for (int64_t i = 0; i < numElements; i += 4) {
230  int64_t elemsThisOp = std::min(numElements, i + 4) - i;
231  Value thisResult = nullptr;
232  for (int64_t j = 0; j < elemsThisOp; j += 2) {
233  Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
234  Value asFloatA = castToF32(elemA, loc, rewriter);
235  Value asFloatB = nullptr;
236  if (j + 1 < elemsThisOp) {
237  Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
238  asFloatB = castToF32(elemB, loc, rewriter);
239  }
240  thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
241  loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
242  }
243  if (elemsThisOp < 4)
244  thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
245  loc, thisResult, 0, elemsThisOp, 1);
246  result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
247  result, i, 1);
248  }
249  rewriter.replaceOp(op, result);
250 }
251 
253  RewritePatternSet &patterns, bool saturateFP8TruncF) {
254  patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
255  patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
256  saturateFP8TruncF);
257 }
258 
259 void ArithToAMDGPUConversionPass::runOnOperation() {
260  Operation *op = getOperation();
261  RewritePatternSet patterns(op->getContext());
262  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
263  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
264  return signalPassFailure();
265 }
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter)
static Value castF32To(Type elementType, Value f32, Location loc, PatternRewriter &rewriter)
static Value clampInput(PatternRewriter &rewriter, Location loc, Type outElemType, Value source)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
FloatType getF32Type()
Definition: Builders.cpp:63
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF32() const
Definition: Types.cpp:51
bool isFloat8E4M3FNUZ() const
Definition: Types.cpp:42
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
bool isFloat8E5M2FNUZ() const
Definition: Types.cpp:39
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, bool saturateFP8TruncF)
Add patterns for rewriting arith.extf and arith.truncf on FP8 types to wrappers around AMDGPU–specifi...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition: Utils.cpp:201
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.