25 #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
26 #include "mlir/Conversion/Passes.h.inc"
36 struct ArithToAMDGPUConversionPass final
37 : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
38 using impl::ArithToAMDGPUConversionPassBase<
39 ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;
41 void runOnOperation()
override;
51 LogicalResult matchAndRewrite(arith::ExtFOp op,
55 struct TruncFToFloat8RewritePattern final :
OpRewritePattern<arith::TruncFOp> {
56 bool saturateFP8 =
false;
57 TruncFToFloat8RewritePattern(
MLIRContext *ctx,
bool saturateFP8,
63 LogicalResult matchAndRewrite(arith::TruncFOp op,
67 struct TruncfToFloat16RewritePattern final
72 LogicalResult matchAndRewrite(arith::TruncFOp op,
80 return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
82 return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
89 if (elementType.
isF32())
92 return rewriter.
create<arith::TruncFOp>(loc, desType, f32);
94 return rewriter.
create<arith::ExtFOp>(loc, desType, f32);
95 llvm_unreachable(
"The only 32-bit float type is f32");
99 ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
101 Type inType = op.getIn().getType();
102 auto inVecType = dyn_cast<VectorType>(inType);
104 if (inVecType.isScalable())
106 inType = inVecType.getElementType();
112 Value in = op.getIn();
116 Value asFloat = rewriter.
create<amdgpu::ExtPackedFp8Op>(
122 int64_t numElements = inVecType.getNumElements();
125 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
126 VectorType outType = cast<VectorType>(op.getOut().getType());
128 if (inVecType.getShape().empty()) {
130 rewriter.
createOrFold<vector::SplatOp>(loc, outType, zero);
134 rewriter.
create<arith::ExtFOp>(loc, outElemType, scalarIn);
135 Value result = rewriter.
create<vector::InsertOp>(loc, scalarExt, zerodSplat,
142 outType.getElementType());
145 if (inVecType.getRank() > 1) {
147 inVecType.getElementType());
148 in = rewriter.
create<vector::ShapeCastOp>(loc, inVecType, in);
151 for (int64_t i = 0; i < numElements; i += 4) {
152 int64_t elemsThisOp =
std::min(numElements, i + 4) - i;
153 Value inSlice = rewriter.
create<vector::ExtractStridedSliceOp>(
154 loc, in, i, elemsThisOp, 1);
155 for (int64_t
j = 0;
j < elemsThisOp;
j += 2) {
156 if (i +
j + 1 < numElements) {
157 Value asFloats = rewriter.
create<amdgpu::ExtPackedFp8Op>(
158 loc, extResType, inSlice,
j / 2);
161 result = rewriter.
create<vector::InsertStridedSliceOp>(
162 loc, asType, result, i +
j, 1);
164 Value asFloat = rewriter.
create<amdgpu::ExtPackedFp8Op>(
167 result = rewriter.
create<vector::InsertOp>(loc, asType, result, i +
j);
172 if (inVecType.getRank() != outType.getRank()) {
173 result = rewriter.
create<vector::ShapeCastOp>(loc, outType, result);
188 llvm_unreachable(
"The only 32-bit float type is f32");
199 const llvm::fltSemantics &sourceSem =
201 const llvm::fltSemantics &targetSem =
202 cast<FloatType>(outElemType).getFloatSemantics();
204 APFloat
min = APFloat::getLargest(targetSem,
true);
205 APFloat
max = APFloat::getLargest(targetSem,
false);
206 bool ignoredLosesInfo =
false;
210 (void)
min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
211 (void)
max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
217 rewriter, loc, sourceType,
218 APFloat::getInf(sourceSem,
false));
220 rewriter, loc, sourceType, APFloat::getInf(sourceSem,
true));
222 loc, arith::CmpFPredicate::OEQ, source, inf);
224 loc, arith::CmpFPredicate::OEQ, source, negInf);
226 loc, arith::CmpFPredicate::UNO, source, source);
228 loc, rewriter.
create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
230 Value clampedBelow = rewriter.
create<arith::MaximumFOp>(loc, source, minCst);
231 Value clamped = rewriter.
create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
233 rewriter.
create<arith::SelectOp>(loc, isNonFinite, source, clamped);
238 TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
241 if (op.getRoundingmodeAttr())
243 Type outType = op.getOut().getType();
244 auto outVecType = dyn_cast<VectorType>(outType);
246 if (outVecType.isScalable())
248 outType = outVecType.getElementType();
251 if (inType && inType.getWidth() <= 8 && saturateFP8)
259 Value in = op.getIn();
262 in =
clampInput(rewriter, loc, outElemType, in);
263 auto inVectorTy = dyn_cast<VectorType>(in.
getType());
267 Value asF8s = rewriter.
create<amdgpu::PackedTrunc2xFp8Op>(
268 loc, truncResType, asFloat,
nullptr, 0,
270 Value result = rewriter.
create<vector::ExtractOp>(loc, asF8s, 0);
275 int64_t numElements = outVecType.getNumElements();
277 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
278 if (outVecType.getShape().empty()) {
283 rewriter.
create<arith::TruncFOp>(loc, outElemType, scalarIn);
284 Value result = rewriter.
create<vector::InsertOp>(loc, scalarTrunc, zero,
291 outVecType.getElementType());
294 if (inVectorTy.getRank() > 1) {
296 inVectorTy.getElementType());
297 in = rewriter.
create<vector::ShapeCastOp>(loc, inVectorTy, in);
300 for (int64_t i = 0; i < numElements; i += 4) {
301 int64_t elemsThisOp =
std::min(numElements, i + 4) - i;
302 Value thisResult =
nullptr;
303 for (int64_t
j = 0;
j < elemsThisOp;
j += 2) {
304 Value elemA = rewriter.
create<vector::ExtractOp>(loc, in, i +
j);
306 Value asFloatB =
nullptr;
307 if (
j + 1 < elemsThisOp) {
308 Value elemB = rewriter.
create<vector::ExtractOp>(loc, in, i +
j + 1);
309 asFloatB =
castToF32(elemB, loc, rewriter);
311 thisResult = rewriter.
create<amdgpu::PackedTrunc2xFp8Op>(
312 loc, truncResType, asFloatA, asFloatB,
j / 2, thisResult);
315 thisResult = rewriter.
create<vector::ExtractStridedSliceOp>(
316 loc, thisResult, 0, elemsThisOp, 1);
317 result = rewriter.
create<vector::InsertStridedSliceOp>(loc, thisResult,
321 if (inVectorTy.getRank() != outVecType.getRank()) {
322 result = rewriter.
create<vector::ShapeCastOp>(loc, outVecType, result);
329 LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
331 Type outType = op.getOut().getType();
333 auto outVecType = dyn_cast<VectorType>(outType);
335 if (outVecType.isScalable())
337 outType = outVecType.getElementType();
343 Value in = op.getIn();
346 auto inVectorTy = dyn_cast<VectorType>(in.
getType());
352 rewriter.
create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
353 Value result = rewriter.
create<vector::ExtractOp>(loc, asF16s, 0);
357 int64_t numElements = outVecType.getNumElements();
359 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
362 if (inVectorTy.getRank() > 1) {
364 inVectorTy.getElementType());
365 in = rewriter.
create<vector::ShapeCastOp>(loc, inVectorTy, in);
370 for (int64_t i = 0; i < numElements; i += 2) {
371 int64_t elemsThisOp =
std::min(numElements, i + 2) - i;
372 Value thisResult =
nullptr;
373 Value elemA = rewriter.
create<vector::ExtractOp>(loc, in, i);
376 if (elemsThisOp == 2) {
377 elemB = rewriter.
create<vector::ExtractOp>(loc, in, i + 1);
381 rewriter.
create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
384 thisResult = rewriter.
create<vector::ExtractStridedSliceOp>(
385 loc, thisResult, 0, elemsThisOp, 1);
386 result = rewriter.
create<vector::InsertStridedSliceOp>(loc, thisResult,
390 if (inVectorTy.getRank() != outVecType.getRank()) {
391 result = rewriter.
create<vector::ShapeCastOp>(loc, outVecType, result);
400 bool saturateFP8Truncf,
bool allowPackedF16Rtz,
Chipset chipset) {
402 if (convertFP8Arithmetic) {
405 saturateFP8Truncf, chipset);
407 if (allowPackedF16Rtz)
411 void ArithToAMDGPUConversionPass::runOnOperation() {
416 if (failed(maybeChipset)) {
418 return signalPassFailure();
421 bool convertFP8Arithmetic =
424 patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
427 return signalPassFailure();
constexpr Chipset kGfx942
static Value castF32To(Type desType, Value f32, Location loc, PatternRewriter &rewriter)
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter)
static bool isSupportedF8(Type elementType, Chipset chipset)
static Value clampInput(PatternRewriter &rewriter, Location loc, Type outElemType, Value source)
static MLIRContext * getContext(OpFoldResult val)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
FloatAttr getFloatAttr(Type type, double value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool hasOcpFp8(const Chipset &chipset)
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns, bool convertFP8Arithmetic, bool saturateFP8Truncf, bool allowPackedF16Rtz, amdgpu::Chipset chipset)
Add patterns for rewriting arith.extf and arith.truncf on FP8 types to wrappers around AMDGPU-specifi...
Include the generated interface declarations.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Represents the amdgpu gfx chipset version, e.g., gfx90a, gfx942, gfx1103.
static FailureOr< Chipset > parse(StringRef name)
Parses the chipset version string and returns the chipset on success, and failure otherwise.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.