22 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
23 #include "mlir/Conversion/Passes.h.inc"
29 struct ArithToAMDGPUConversionPass final
30 : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
31 using impl::ArithToAMDGPUConversionPassBase<
32 ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
34 void runOnOperation()
override;
40 LogicalResult match(arith::ExtFOp op)
const override;
44 struct TruncFToFloat8RewritePattern final :
OpRewritePattern<arith::TruncFOp> {
45 bool saturateFP8 =
false;
46 TruncFToFloat8RewritePattern(
MLIRContext *ctx,
bool saturateFP8)
49 LogicalResult match(arith::TruncFOp op)
const override;
56 if (elementType.
isF32())
59 return rewriter.
create<arith::TruncFOp>(loc, elementType, f32);
61 return rewriter.
create<arith::ExtFOp>(loc, elementType, f32);
62 llvm_unreachable(
"The only 32-bit float type is f32");
65 LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op)
const {
66 Type inType = op.getIn().getType();
67 if (
auto inVecType = dyn_cast<VectorType>(inType)) {
68 if (inVecType.isScalable())
70 inType = inVecType.getElementType();
78 Value in = op.getIn();
80 auto inType = dyn_cast<VectorType>(in.
getType());
82 Value asFloat = rewriter.
create<amdgpu::ExtPackedFp8Op>(
87 int64_t numElements = inType.getNumElements();
89 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
90 if (inType.getShape().empty()) {
95 rewriter.
create<arith::ExtFOp>(loc, outElemType, scalarIn);
96 Value result = rewriter.
create<vector::InsertOp>(loc, scalarExt, zero,
101 VectorType outType = cast<VectorType>(op.getOut().getType());
103 outType.getElementType());
106 if (inType.getRank() > 1) {
108 inType.getElementType());
109 in = rewriter.
create<vector::ShapeCastOp>(loc, inType, in);
112 for (int64_t i = 0; i < numElements; i += 4) {
113 int64_t elemsThisOp =
std::min(numElements, i + 4) - i;
114 Value inSlice = rewriter.
create<vector::ExtractStridedSliceOp>(
115 loc, in, i, elemsThisOp, 1);
116 for (int64_t
j = 0;
j < elemsThisOp; ++
j) {
117 Value asFloat = rewriter.
create<amdgpu::ExtPackedFp8Op>(
120 result = rewriter.
create<vector::InsertOp>(loc, asType, result, i +
j);
124 if (inType.getRank() != outType.getRank()) {
125 result = rewriter.
create<vector::ShapeCastOp>(loc, outType, result);
139 llvm_unreachable(
"The only 32-bit float type is f32");
150 const llvm::fltSemantics &sourceSem =
152 const llvm::fltSemantics &targetSem =
153 cast<FloatType>(outElemType).getFloatSemantics();
155 APFloat
min = APFloat::getLargest(targetSem,
true);
156 APFloat
max = APFloat::getLargest(targetSem,
false);
157 bool ignoredLosesInfo =
false;
161 (void)
min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
162 (void)
max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
168 rewriter, loc, sourceType,
169 APFloat::getInf(sourceSem,
false));
171 rewriter, loc, sourceType, APFloat::getInf(sourceSem,
true));
173 loc, arith::CmpFPredicate::OEQ, source, inf);
175 loc, arith::CmpFPredicate::OEQ, source, negInf);
177 loc, arith::CmpFPredicate::UNO, source, source);
179 loc, rewriter.
create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
181 Value clampedBelow = rewriter.
create<arith::MaximumFOp>(loc, source, minCst);
182 Value clamped = rewriter.
create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
184 rewriter.
create<arith::SelectOp>(loc, isNonFinite, source, clamped);
188 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op)
const {
190 if (op.getRoundingmodeAttr())
192 Type outType = op.getOut().getType();
193 if (
auto outVecType = dyn_cast<VectorType>(outType)) {
194 if (outVecType.isScalable())
196 outType = outVecType.getElementType();
199 if (inType && inType.getWidth() <= 8 && saturateFP8)
208 Value in = op.getIn();
211 in =
clampInput(rewriter, loc, outElemType, in);
212 auto inVectorTy = dyn_cast<VectorType>(in.
getType());
216 Value asF8s = rewriter.
create<amdgpu::PackedTrunc2xFp8Op>(
217 loc, truncResType, asFloat,
nullptr, 0,
219 Value result = rewriter.
create<vector::ExtractOp>(loc, asF8s, 0);
222 VectorType outType = cast<VectorType>(op.getOut().getType());
223 int64_t numElements = outType.getNumElements();
225 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
226 if (outType.getShape().empty()) {
231 rewriter.
create<arith::TruncFOp>(loc, outElemType, scalarIn);
232 Value result = rewriter.
create<vector::InsertOp>(loc, scalarTrunc, zero,
238 outType.getElementType());
241 if (inVectorTy.getRank() > 1) {
243 inVectorTy.getElementType());
244 in = rewriter.
create<vector::ShapeCastOp>(loc, inVectorTy, in);
247 for (int64_t i = 0; i < numElements; i += 4) {
248 int64_t elemsThisOp =
std::min(numElements, i + 4) - i;
249 Value thisResult =
nullptr;
250 for (int64_t
j = 0;
j < elemsThisOp;
j += 2) {
251 Value elemA = rewriter.
create<vector::ExtractOp>(loc, in, i +
j);
253 Value asFloatB =
nullptr;
254 if (
j + 1 < elemsThisOp) {
255 Value elemB = rewriter.
create<vector::ExtractOp>(loc, in, i +
j + 1);
256 asFloatB =
castToF32(elemB, loc, rewriter);
258 thisResult = rewriter.
create<amdgpu::PackedTrunc2xFp8Op>(
259 loc, truncResType, asFloatA, asFloatB,
j / 2, thisResult);
262 thisResult = rewriter.
create<vector::ExtractStridedSliceOp>(
263 loc, thisResult, 0, elemsThisOp, 1);
264 result = rewriter.
create<vector::InsertStridedSliceOp>(loc, thisResult,
268 if (inVectorTy.getRank() != outType.getRank()) {
269 result = rewriter.
create<vector::ShapeCastOp>(loc, outType, result);
277 patterns.
add<ExtFOnFloat8RewritePattern>(patterns.
getContext());
278 patterns.
add<TruncFToFloat8RewritePattern>(patterns.
getContext(),
282 void ArithToAMDGPUConversionPass::runOnOperation() {
287 return signalPassFailure();
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter)
static Value castF32To(Type elementType, Value f32, Location loc, PatternRewriter &rewriter)
static Value clampInput(PatternRewriter &rewriter, Location loc, Type outElemType, Value source)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computed analysis.
FloatAttr getFloatAttr(Type type, double value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat8E4M3FNUZ() const
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isFloat8E5M2FNUZ() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, bool saturateFP8TruncF)
Add patterns for rewriting arith.extf and arith.truncf on FP8 types to wrappers around AMDGPU-specific intrinsics.
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.