23 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
24 #include "mlir/Conversion/Passes.h.inc"
30 struct ArithToAMDGPUConversionPass final
31 : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
32 using impl::ArithToAMDGPUConversionPassBase<
33 ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
35 void runOnOperation()
override;
45 struct TruncFToFloat8RewritePattern final :
OpRewritePattern<arith::TruncFOp> {
46 bool saturateFP8 =
false;
47 TruncFToFloat8RewritePattern(
MLIRContext *ctx,
bool saturateFP8)
57 if (elementType.
isF32())
60 return rewriter.
create<arith::TruncFOp>(loc, elementType, f32);
62 return rewriter.
create<arith::ExtFOp>(loc, elementType, f32);
63 llvm_unreachable(
"The only 32-bit float type is f32");
66 LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op)
const {
67 Type inType = op.getIn().getType();
68 if (
auto inVecType = dyn_cast<VectorType>(inType)) {
69 if (inVecType.isScalable())
71 if (inVecType.getShape().size() > 1)
74 inType = inVecType.getElementType();
82 Value in = op.getIn();
84 if (!isa<VectorType>(in.
getType())) {
85 Value asFloat = rewriter.
create<amdgpu::ExtPackedFp8Op>(
90 VectorType inType = cast<VectorType>(in.
getType());
91 int64_t numElements = inType.getNumElements();
93 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
95 rewriter.
createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
96 if (inType.getShape().empty()) {
101 rewriter.
create<arith::ExtFOp>(loc, outElemType, scalarIn);
102 result = rewriter.
create<vector::InsertOp>(loc, scalarExt, zero,
106 for (int64_t i = 0; i < numElements; i += 4) {
107 int64_t elemsThisOp =
std::min(numElements, i + 4) - i;
108 Value inSlice = rewriter.
create<vector::ExtractStridedSliceOp>(
109 loc, in, i, elemsThisOp, 1);
110 for (int64_t
j = 0;
j < elemsThisOp; ++
j) {
111 Value asFloat = rewriter.
create<amdgpu::ExtPackedFp8Op>(
114 result = rewriter.
create<vector::InsertOp>(loc, asType, result, i +
j);
128 llvm_unreachable(
"The only 32-bit float type is f32");
139 const llvm::fltSemantics &sourceSem =
141 const llvm::fltSemantics &targetSem =
142 cast<FloatType>(outElemType).getFloatSemantics();
144 APFloat
min = APFloat::getLargest(targetSem,
true);
145 APFloat
max = APFloat::getLargest(targetSem,
false);
146 bool ignoredLosesInfo =
false;
150 (void)
min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
151 (void)
max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
157 rewriter, loc, sourceType,
158 APFloat::getInf(sourceSem,
false));
160 rewriter, loc, sourceType, APFloat::getInf(sourceSem,
true));
162 loc, arith::CmpFPredicate::OEQ, source, inf);
164 loc, arith::CmpFPredicate::OEQ, source, negInf);
166 loc, arith::CmpFPredicate::UNO, source, source);
168 loc, rewriter.
create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
170 Value clampedBelow = rewriter.
create<arith::MaximumFOp>(loc, source, minCst);
171 Value clamped = rewriter.
create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
173 rewriter.
create<arith::SelectOp>(loc, isNonFinite, source, clamped);
177 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op)
const {
179 if (op.getRoundingmodeAttr())
181 Type outType = op.getOut().getType();
182 if (
auto outVecType = dyn_cast<VectorType>(outType)) {
183 if (outVecType.isScalable())
185 if (outVecType.getShape().size() > 1)
188 outType = outVecType.getElementType();
191 if (inType && inType.getWidth() <= 8 && saturateFP8)
200 Value in = op.getIn();
203 in =
clampInput(rewriter, loc, outElemType, in);
205 if (!isa<VectorType>(in.
getType())) {
207 Value asF8s = rewriter.
create<amdgpu::PackedTrunc2xFp8Op>(
208 loc, truncResType, asFloat,
nullptr, 0,
210 Value result = rewriter.
create<vector::ExtractOp>(loc, asF8s, 0);
213 VectorType outType = cast<VectorType>(op.getOut().getType());
214 int64_t numElements = outType.getNumElements();
216 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
218 if (outType.getShape().empty()) {
223 rewriter.
create<arith::TruncFOp>(loc, outElemType, scalarIn);
224 result = rewriter.
create<vector::InsertOp>(loc, scalarTrunc, zero,
229 for (int64_t i = 0; i < numElements; i += 4) {
230 int64_t elemsThisOp =
std::min(numElements, i + 4) - i;
231 Value thisResult =
nullptr;
232 for (int64_t
j = 0;
j < elemsThisOp;
j += 2) {
233 Value elemA = rewriter.
create<vector::ExtractOp>(loc, in, i +
j);
235 Value asFloatB =
nullptr;
236 if (
j + 1 < elemsThisOp) {
237 Value elemB = rewriter.
create<vector::ExtractOp>(loc, in, i +
j + 1);
238 asFloatB =
castToF32(elemB, loc, rewriter);
240 thisResult = rewriter.
create<amdgpu::PackedTrunc2xFp8Op>(
241 loc, truncResType, asFloatA, asFloatB,
j / 2, thisResult);
244 thisResult = rewriter.
create<vector::ExtractStridedSliceOp>(
245 loc, thisResult, 0, elemsThisOp, 1);
246 result = rewriter.
create<vector::InsertStridedSliceOp>(loc, thisResult,
254 patterns.
add<ExtFOnFloat8RewritePattern>(patterns.
getContext());
255 patterns.
add<TruncFToFloat8RewritePattern>(patterns.
getContext(),
259 void ArithToAMDGPUConversionPass::runOnOperation() {
264 return signalPassFailure();
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter)
static Value castF32To(Type elementType, Value f32, Location loc, PatternRewriter &rewriter)
static Value clampInput(PatternRewriter &rewriter, Location loc, Type outElemType, Value source)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
FloatAttr getFloatAttr(Type type, double value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat8E4M3FNUZ() const
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isFloat8E5M2FNUZ() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, bool saturateFP8TruncF)
Add patterns for rewriting arith.extf and arith.truncf on FP8 types to wrappers around AMDGPU–specifi...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.