25 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
26 #include "mlir/Conversion/Passes.h.inc"
33 struct ArithToAMDGPUConversionPass final
34 : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
35 using impl::ArithToAMDGPUConversionPassBase<
36 ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
38 void runOnOperation()
override;
44 LogicalResult match(arith::ExtFOp op)
const override;
48 struct TruncFToFloat8RewritePattern final :
OpRewritePattern<arith::TruncFOp> {
49 bool saturateFP8 =
false;
50 TruncFToFloat8RewritePattern(
MLIRContext *ctx,
bool saturateFP8,
56 LogicalResult match(arith::TruncFOp op)
const override;
60 struct TruncfToFloat16RewritePattern final
65 LogicalResult match(arith::TruncFOp op)
const override;
73 if (elementType.
isF32())
76 return rewriter.
create<arith::TruncFOp>(loc, elementType, f32);
78 return rewriter.
create<arith::ExtFOp>(loc, elementType, f32);
79 llvm_unreachable(
"The only 32-bit float type is f32");
82 LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op)
const {
83 Type inType = op.getIn().getType();
84 if (
auto inVecType = dyn_cast<VectorType>(inType)) {
85 if (inVecType.isScalable())
87 inType = inVecType.getElementType();
95 Value in = op.getIn();
97 auto inType = dyn_cast<VectorType>(in.
getType());
99 Value asFloat = rewriter.
create<amdgpu::ExtPackedFp8Op>(
104 int64_t numElements = inType.getNumElements();
106 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
107 if (inType.getShape().empty()) {
112 rewriter.
create<arith::ExtFOp>(loc, outElemType, scalarIn);
113 Value result = rewriter.
create<vector::InsertOp>(loc, scalarExt, zero,
118 VectorType outType = cast<VectorType>(op.getOut().getType());
120 outType.getElementType());
123 if (inType.getRank() > 1) {
125 inType.getElementType());
126 in = rewriter.
create<vector::ShapeCastOp>(loc, inType, in);
129 for (int64_t i = 0; i < numElements; i += 4) {
130 int64_t elemsThisOp =
std::min(numElements, i + 4) - i;
131 Value inSlice = rewriter.
create<vector::ExtractStridedSliceOp>(
132 loc, in, i, elemsThisOp, 1);
133 for (int64_t
j = 0;
j < elemsThisOp; ++
j) {
134 Value asFloat = rewriter.
create<amdgpu::ExtPackedFp8Op>(
137 result = rewriter.
create<vector::InsertOp>(loc, asType, result, i +
j);
141 if (inType.getRank() != outType.getRank()) {
142 result = rewriter.
create<vector::ShapeCastOp>(loc, outType, result);
156 llvm_unreachable(
"The only 32-bit float type is f32");
167 const llvm::fltSemantics &sourceSem =
169 const llvm::fltSemantics &targetSem =
170 cast<FloatType>(outElemType).getFloatSemantics();
172 APFloat
min = APFloat::getLargest(targetSem,
true);
173 APFloat
max = APFloat::getLargest(targetSem,
false);
174 bool ignoredLosesInfo =
false;
178 (void)
min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
179 (void)
max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
185 rewriter, loc, sourceType,
186 APFloat::getInf(sourceSem,
false));
188 rewriter, loc, sourceType, APFloat::getInf(sourceSem,
true));
190 loc, arith::CmpFPredicate::OEQ, source, inf);
192 loc, arith::CmpFPredicate::OEQ, source, negInf);
194 loc, arith::CmpFPredicate::UNO, source, source);
196 loc, rewriter.
create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
198 Value clampedBelow = rewriter.
create<arith::MaximumFOp>(loc, source, minCst);
199 Value clamped = rewriter.
create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
201 rewriter.
create<arith::SelectOp>(loc, isNonFinite, source, clamped);
205 LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op)
const {
207 if (op.getRoundingmodeAttr())
209 Type outType = op.getOut().getType();
210 if (
auto outVecType = dyn_cast<VectorType>(outType)) {
211 if (outVecType.isScalable())
213 outType = outVecType.getElementType();
216 if (inType && inType.getWidth() <= 8 && saturateFP8)
225 Value in = op.getIn();
228 in =
clampInput(rewriter, loc, outElemType, in);
229 auto inVectorTy = dyn_cast<VectorType>(in.
getType());
233 Value asF8s = rewriter.
create<amdgpu::PackedTrunc2xFp8Op>(
234 loc, truncResType, asFloat,
nullptr, 0,
236 Value result = rewriter.
create<vector::ExtractOp>(loc, asF8s, 0);
239 VectorType outType = cast<VectorType>(op.getOut().getType());
240 int64_t numElements = outType.getNumElements();
242 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
243 if (outType.getShape().empty()) {
248 rewriter.
create<arith::TruncFOp>(loc, outElemType, scalarIn);
249 Value result = rewriter.
create<vector::InsertOp>(loc, scalarTrunc, zero,
255 outType.getElementType());
258 if (inVectorTy.getRank() > 1) {
260 inVectorTy.getElementType());
261 in = rewriter.
create<vector::ShapeCastOp>(loc, inVectorTy, in);
264 for (int64_t i = 0; i < numElements; i += 4) {
265 int64_t elemsThisOp =
std::min(numElements, i + 4) - i;
266 Value thisResult =
nullptr;
267 for (int64_t
j = 0;
j < elemsThisOp;
j += 2) {
268 Value elemA = rewriter.
create<vector::ExtractOp>(loc, in, i +
j);
270 Value asFloatB =
nullptr;
271 if (
j + 1 < elemsThisOp) {
272 Value elemB = rewriter.
create<vector::ExtractOp>(loc, in, i +
j + 1);
273 asFloatB =
castToF32(elemB, loc, rewriter);
275 thisResult = rewriter.
create<amdgpu::PackedTrunc2xFp8Op>(
276 loc, truncResType, asFloatA, asFloatB,
j / 2, thisResult);
279 thisResult = rewriter.
create<vector::ExtractStridedSliceOp>(
280 loc, thisResult, 0, elemsThisOp, 1);
281 result = rewriter.
create<vector::InsertStridedSliceOp>(loc, thisResult,
285 if (inVectorTy.getRank() != outType.getRank()) {
286 result = rewriter.
create<vector::ShapeCastOp>(loc, outType, result);
292 LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op)
const {
293 Type outType = op.getOut().getType();
295 if (
auto outVecType = dyn_cast<VectorType>(outType)) {
296 if (outVecType.isScalable())
298 outType = outVecType.getElementType();
300 return success(outType.
isF16() && inputType.
isF32());
306 Value in = op.getIn();
309 auto inVectorTy = dyn_cast<VectorType>(in.
getType());
315 rewriter.
create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
316 Value result = rewriter.
create<vector::ExtractOp>(loc, asF16s, 0);
319 VectorType outType = cast<VectorType>(op.getOut().getType());
320 int64_t numElements = outType.getNumElements();
322 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
325 if (inVectorTy.getRank() > 1) {
327 inVectorTy.getElementType());
328 in = rewriter.
create<vector::ShapeCastOp>(loc, inVectorTy, in);
333 for (int64_t i = 0; i < numElements; i += 2) {
334 int64_t elemsThisOp =
std::min(numElements, i + 2) - i;
335 Value thisResult =
nullptr;
336 Value elemA = rewriter.
create<vector::ExtractOp>(loc, in, i);
339 if (elemsThisOp == 2) {
340 elemB = rewriter.
create<vector::ExtractOp>(loc, in, i + 1);
344 rewriter.
create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
347 thisResult = rewriter.
create<vector::ExtractStridedSliceOp>(
348 loc, thisResult, 0, elemsThisOp, 1);
349 result = rewriter.
create<vector::InsertStridedSliceOp>(loc, thisResult,
353 if (inVectorTy.getRank() != outType.getRank()) {
354 result = rewriter.
create<vector::ShapeCastOp>(loc, outType, result);
362 bool saturateFP8Truncf,
bool allowPackedF16Rtz,
Chipset chipset) {
364 if (convertFP8Arithmetic) {
365 patterns.
add<ExtFOnFloat8RewritePattern>(patterns.
getContext());
366 patterns.
add<TruncFToFloat8RewritePattern>(patterns.
getContext(),
367 saturateFP8Truncf, chipset);
369 if (allowPackedF16Rtz)
370 patterns.
add<TruncfToFloat16RewritePattern>(patterns.
getContext());
373 void ArithToAMDGPUConversionPass::runOnOperation() {
378 if (failed(maybeChipset)) {
380 return signalPassFailure();
383 bool convertFP8Arithmetic =
384 maybeChipset->majorVersion == 9 && *maybeChipset >=
Chipset(9, 4, 0);
386 patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
389 return signalPassFailure();
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter)
static Value castF32To(Type elementType, Value f32, Location loc, PatternRewriter &rewriter)
static Value clampInput(PatternRewriter &rewriter, Location loc, Type outElemType, Value source)
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
FloatAttr getFloatAttr(Type type, double value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isFloat8E4M3FNUZ() const
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isFloat8E5M2FNUZ() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, bool convertFP8Arithmetic, bool saturateFP8Truncf, bool allowPackedF16Rtz, amdgpu::Chipset chipset)
Add patterns for rewriting arith.extf and arith.truncf on FP8 types to wrappers around AMDGPU–specifi...
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.