#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
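// The impl::ArithToAMDGPUConversionPassBase used below is generated from the
// pass declaration in Passes.td by the include above.

// Chipsets that are special-cased by the patterns in this file (a sketch;
// the constants were declared as `constexpr Chipset` values in this file).
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);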
struct ArithToAMDGPUConversionPass final
    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
  using impl::ArithToAMDGPUConversionPassBase<
      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;

  void runOnOperation() override;
};
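// Rewrites arith.extf from a supported 8-bit float type into
// amdgpu.ext_packed_fp8, which unpacks fp8 values from a packed 32-bit word.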
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
  Chipset chipset;

  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const override;
};
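// Rewrites arith.truncf to a supported 8-bit float type into
// amdgpu.packed_trunc_2xfp8, optionally saturating (clamping) the input to
// the target type's finite range first.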
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
  bool saturateFP8 = false;
  Chipset chipset;

  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                               Chipset chipset, PatternBenefit benefit)
      : OpRewritePattern::OpRewritePattern(ctx, benefit),
        saturateFP8(saturateFP8), chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};
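// Rewrites arith.truncf from f32 to f16 into ROCDL's packed
// round-toward-zero conversion, rocdl.cvt.pkrtz.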
struct TruncfToFloat16RewritePattern final
    : OpRewritePattern<arith::TruncFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};
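// Rewrites arith.scaling_extf into amdgpu.scaled_ext_packed, converting
// blocks of the input that share a single scale value.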
struct ScalingExtFRewritePattern final
    : OpRewritePattern<arith::ScalingExtFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                PatternRewriter &rewriter) const override;
};
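// Rewrites arith.scaling_truncf into amdgpu.packed_scaled_trunc, the
// narrowing counterpart of the pattern above.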
struct ScalingTruncFRewritePattern final
    : OpRewritePattern<arith::ScalingTruncFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                PatternRewriter &rewriter) const override;
};
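// Returns true if `elementType` is a float8 type the given chipset can
// convert natively: the FNUZ variants on gfx942, or the OCP variants on
// chipsets that implement them.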
static bool isSupportedF8(Type elementType, Chipset chipset) {
  if (chipset == kGfx942)
    return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
  if (hasOcpFp8(chipset))
    return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
  return false;
}
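// Casts an f32 (or vector of f32) value to `desType`, truncating or
// extending as needed; returns the value unchanged if `desType` is f32.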
static Value castF32To(Type desType, Value f32, Location loc,
                       PatternRewriter &rewriter) {
  Type elementType = getElementTypeOrSelf(desType);
  if (elementType.isF32())
    return f32;
  if (elementType.getIntOrFloatBitWidth() < 32)
    return arith::TruncFOp::create(rewriter, loc, desType, f32);
  if (elementType.getIntOrFloatBitWidth() > 32)
    return arith::ExtFOp::create(rewriter, loc, desType, f32);
  llvm_unreachable("The only 32-bit float type is f32");
}
LogicalResult
ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
                                            PatternRewriter &rewriter) const {
  Type inType = op.getIn().getType();
  auto inVecType = dyn_cast<VectorType>(inType);
  if (inVecType) {
    if (inVecType.isScalable())
      return failure();
    inType = inVecType.getElementType();
  }
  if (!isSupportedF8(inType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType extResType = VectorType::get(2, rewriter.getF32Type());
  if (!inVecType) {
    Value asFloat = amdgpu::ExtPackedFp8Op::create(
        rewriter, loc, rewriter.getF32Type(), in, 0);
    Value result = castF32To(outElemType, asFloat, loc, rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = inVecType.getNumElements();
  Value zero = arith::ConstantOp::create(
      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  VectorType outType = cast<VectorType>(op.getOut().getType());
  if (inVecType.getShape().empty()) {
    Value zerodSplat =
        rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
    Value scalarIn =
        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
    Value scalarExt =
        arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
    Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
                                            zerodSplat, ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }
  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outType.getElementType());
  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
  if (inVecType.getRank() > 1) {
    inVecType = VectorType::get(SmallVector<int64_t>{numElements},
                                inVecType.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
  }
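  // Walk the flattened vector one packed 32-bit word (four fp8 elements) at
  // a time, unpacking a pair of elements per amdgpu.ext_packed_fp8 op.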
  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
                                                          elemsThisOp, 1);
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      if (i + j + 1 < numElements) {
        Value asFloats = amdgpu::ExtPackedFp8Op::create(
            rewriter, loc, extResType, inSlice, j / 2);
        Type desType = VectorType::get(2, outElemType);
        Value asType = castF32To(desType, asFloats, loc, rewriter);
        result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
                                                      result, i + j, 1);
      } else {
        Value asFloat = amdgpu::ExtPackedFp8Op::create(
            rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
        Value asType = castF32To(outElemType, asFloat, loc, rewriter);
        result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
      }
    }
  }
  if (inVecType.getRank() != outType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}
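// Widens or narrows `value` to f32, the element type consumed by the packed
// truncation intrinsics below.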
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
  Type type = value.getType();
  if (type.isF32())
    return value;
  if (type.getIntOrFloatBitWidth() < 32)
    return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
  if (type.getIntOrFloatBitWidth() > 32)
    return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
  llvm_unreachable("The only 32-bit float type is f32");
}
static Value clampInput(PatternRewriter &rewriter, Location loc,
                        Type outElemType, Value source) {
  Type sourceType = source.getType();
  const llvm::fltSemantics &sourceSem =
      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
  const llvm::fltSemantics &targetSem =
      cast<FloatType>(outElemType).getFloatSemantics();

  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
  bool ignoredLosesInfo = false;
  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);
  Value inf = createScalarOrSplatConstant(
      rewriter, loc, sourceType,
      APFloat::getInf(sourceSem, /*Negative=*/false));
  Value negInf = createScalarOrSplatConstant(
      rewriter, loc, sourceType,
      APFloat::getInf(sourceSem, /*Negative=*/true));
  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, inf);
  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, negInf);
  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::UNO, source, source);
  Value isNonFinite = arith::OrIOp::create(
      rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
      isNan);
  Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
  Value clamped =
      arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
  return arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
}
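// Lowering for arith.truncf to fp8. Only the default rounding mode is
// handled, and fp8-to-fp8 truncation is rejected when saturation is
// requested.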
LogicalResult
TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
                                              PatternRewriter &rewriter) const {
  if (op.getRoundingmodeAttr())
    return failure();
  Type outType = op.getOut().getType();
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  auto inType =
      dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
  if (inType && inType.getWidth() <= 8 && saturateFP8)
    return failure();
  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  if (saturateFP8)
    in = clampInput(rewriter, loc, outElemType, in);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());
  VectorType truncResType = VectorType::get(4, outElemType);
  if (!outVecType) {
    Value asFloat = castToF32(in, loc, rewriter);
    Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
        rewriter, loc, truncResType, asFloat, nullptr, 0,
        /*existing=*/nullptr);
    Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = outVecType.getNumElements();
  Value zero = arith::ConstantOp::create(
      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  if (outVecType.getShape().empty()) {
    Value scalarIn =
        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
    Value scalarTrunc =
        arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
    Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
                                            ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }
  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outVecType.getElementType());
  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);
  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
  }
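  // Pack up to four fp8 results into each 32-bit word, two elements per
  // amdgpu.packed_trunc_2xfp8 op; a lone trailing element leaves asFloatB
  // null.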
  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value thisResult = nullptr;
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
      Value asFloatA = castToF32(elemA, loc, rewriter);
      Value asFloatB = nullptr;
      if (j + 1 < elemsThisOp) {
        Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
        asFloatB = castToF32(elemB, loc, rewriter);
      }
      thisResult = amdgpu::PackedTrunc2xFp8Op::create(
          rewriter, loc, truncResType, asFloatA, asFloatB,
          j / 2, thisResult);
    }
    if (elemsThisOp < 4)
      thisResult = vector::ExtractStridedSliceOp::create(
          rewriter, loc, thisResult, 0, elemsThisOp, 1);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
                                                  result, i, 1);
  }
  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}
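// Lowering for arith.truncf from f32 to f16 via rocdl.cvt.pkrtz, which
// converts two f32 values per instruction.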
LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
    arith::TruncFOp op, PatternRewriter &rewriter) const {
  Type outType = op.getOut().getType();
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType truncResType = VectorType::get(2, outElemType);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());

  if (!inVectorTy) {
    auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
    Value asF16s =
        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
    Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = outVecType.getNumElements();
  Value zero = arith::ConstantOp::create(
      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
  }
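  // Convert two elements per cvt.pkrtz op; a lone trailing element is paired
  // with a poison value and the extra lane is sliced away below.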
  for (int64_t i = 0; i < numElements; i += 2) {
    int64_t elemsThisOp = std::min(numElements, i + 2) - i;
    Value thisResult = nullptr;
    Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
    Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());

    if (elemsThisOp == 2) {
      elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
    }

    thisResult =
        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
    thisResult = vector::ExtractStridedSliceOp::create(
        rewriter, loc, thisResult, 0, elemsThisOp, 1);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
                                                  result, i, 1);
  }
  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}
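// Get the broadcasted / splatted value for a chain of ops.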
static Value getOriginalVectorValue(Value value) {
  Value current = value;
  while (Operation *definingOp = current.getDefiningOp()) {
    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
                      .Case<vector::ShapeCastOp>([&current](auto op) {
                        current = op.getSource();
                        return true;
                      })
                      .Case<vector::BroadcastOp>([&current](auto op) {
                        current = op.getSource();
                        return true;
                      })
                      .Default([](Operation *) { return false; });
    if (!skipOp)
      break;
  }
  return current;
}
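// Lowering for arith.scaling_extf: each block of input elements that shares
// one scale value is converted through amdgpu.scaled_ext_packed.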
LogicalResult
ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                           PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  constexpr int64_t opOutWidth = 2;

  Value in = op.getIn();
  Value scale = op.getScale();
  Value out = op.getOut();

  Type f32 = rewriter.getF32Type();
  Type inType = getElementTypeOrSelf(in);
  Type scaleType = getElementTypeOrSelf(scale);
  Type outType = getElementTypeOrSelf(out);
  int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();

  VectorType outVecType = dyn_cast<VectorType>(out.getType());
  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
  if (outVecType && outVecType.isScalable())
    return failure();
  Type scaleF32Type =
      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
  if (scaleType.getIntOrFloatBitWidth() < 32)
    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
  else if (scaleType.getIntOrFloatBitWidth() > 32)
    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);

  if (!outVecType) {
    Value inCast = vector::BroadcastOp::create(rewriter, loc,
                                               VectorType::get(1, inType), in);
    Value scaleExt = amdgpu::ScaledExtPackedOp::create(
        rewriter, loc, extScaleResultType, inCast, scale, 0);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
    return success();
  }
  VectorType inVecType = cast<VectorType>(in.getType());
  Value origScale = getOriginalVectorValue(op.getScale());
  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

  ArrayRef<int64_t> inShape = inVecType.getShape();
  SmallVector<int64_t> originalScaleShape;
  if (origScaleVecType)
    llvm::append_range(originalScaleShape, origScaleVecType.getShape());

  originalScaleShape.insert(originalScaleShape.end(),
                            inShape.size() - originalScaleShape.size(), 1);
500 "failed to derive block size from broadcast or splat operation");
  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                         rewriter.getFloatAttr(outType, 0.0));
  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
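  // Iterate over the scale-block tiles of the input (offsets produced by
  // StaticTileOffsetRange), flattening each tile to 1-D and converting it a
  // few elements at a time against a single extracted scale.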
  SmallVector<int64_t> strides(inShape.size(), 1);
  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
    Value block = vector::ExtractStridedSliceOp::create(
        rewriter, loc, in, offsets, ratio, strides);
    VectorType block1DType = VectorType::get(blockSize, inType);
    Value block1D =
        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
    Value uniformScale =
        vector::ExtractOp::create(rewriter, loc, scale, offsets);
    VectorType blockResultType = VectorType::get(blockSize, outType);
    Value blockResult =
        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
         i < blockSize;
         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
      Value inSlice = vector::ExtractStridedSliceOp::create(
          rewriter, loc, block1D, i, inSliceWidth, 1);
      for (int64_t j = 0,
                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
           j < inSliceWidth; j += outSliceWidth,
                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
            rewriter, loc, extScaleResultType, inSlice, uniformScale,
            j / opOutWidth);
        if (outSliceWidth < opOutWidth) {
          scaleExt = vector::ExtractStridedSliceOp::create(
              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
        }
        blockResult = vector::InsertStridedSliceOp::create(
            rewriter, loc, scaleExt, blockResult, i + j, 1);
      }
    }

    VectorType resultType = VectorType::get(ratio, outType);
    Value cast =
        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
                                                  offsets, strides);
  }

  rewriter.replaceOp(op, result);
  return success();
}
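// Lowering for arith.scaling_truncf, the mirror image of the pattern above:
// blocks that share one scale are narrowed via amdgpu.packed_scaled_trunc.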
LogicalResult
ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                             PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  constexpr int64_t opInWidth = 2;

  Value in = op.getIn();
  Value scale = op.getScale();
  Value out = op.getOut();

  Type f32 = rewriter.getF32Type();
  Type inType = getElementTypeOrSelf(in);
  Type scaleType = getElementTypeOrSelf(scale);
  Type outType = getElementTypeOrSelf(out);
  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();

  VectorType outVecType = dyn_cast<VectorType>(out.getType());
  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
  if (outVecType && outVecType.isScalable())
    return failure();
  Type scaleF32Type =
      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
  if (scaleType.getIntOrFloatBitWidth() < 32)
    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
  else if (scaleType.getIntOrFloatBitWidth() > 32)
    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                         rewriter.getFloatAttr(outType, 0.0));
  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);

  if (!outVecType) {
    VectorType inVecType = VectorType::get(1, inType);
    Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
    Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
        rewriter, loc, truncScaleResultType, inCast, scale, 0,
        /*existing=*/nullptr);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
    return success();
  }
  VectorType inVecType = cast<VectorType>(in.getType());
  Value origScale = getOriginalVectorValue(op.getScale());
  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

  ArrayRef<int64_t> inShape = inVecType.getShape();
  SmallVector<int64_t> scaleShape;
  if (origScaleVecType)
    llvm::append_range(scaleShape, origScaleVecType.getShape());

  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
617 "failed to derive block size from broadcast or splat operation");
  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
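  // As in the scaling_extf lowering, walk the scale-block tiles of the input
  // and narrow each flattened block slice by slice.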
  SmallVector<int64_t> strides(inShape.size(), 1);
  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
    Value block = vector::ExtractStridedSliceOp::create(
        rewriter, loc, in, offsets, ratio, strides);
    VectorType block1DType = VectorType::get(blockSize, inType);
    Value block1D =
        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
    Value uniformScale =
        vector::ExtractOp::create(rewriter, loc, scale, offsets);
    VectorType blockResultType = VectorType::get(blockSize, outType);
    Value blockResult =
        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
         i < blockSize; i += outSliceWidth,
                 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
      Value scaleTrunc;
      if (outSliceWidth <= opInWidth) {
        Value slice = vector::ExtractStridedSliceOp::create(
            rewriter, loc, block1D, i, outSliceWidth, 1);
        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
            /*existing=*/nullptr);
      } else {
        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
                                                 truncScaleResultType, zero);
        for (int64_t j = 0,
                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
             j < outSliceWidth; j += opInWidth,
                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
          Value slice = vector::ExtractStridedSliceOp::create(
              rewriter, loc, block1D, i + j, inSliceWidth, 1);
          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
              rewriter, loc, truncScaleResultType, slice, uniformScale,
              j / opInWidth, scaleTrunc);
        }
      }
      if (outSliceWidth != opOutWidth) {
        scaleTrunc = vector::ExtractStridedSliceOp::create(
            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
      }
      blockResult = vector::InsertStridedSliceOp::create(
          rewriter, loc, scaleTrunc, blockResult, i, 1);
    }
    VectorType resultType = VectorType::get(ratio, outType);
    Value cast =
        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
                                                  offsets, strides);
  }

  rewriter.replaceOp(op, result);
  return success();
}
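// Registers the rewrite patterns above. A typical driver looks like the
// following sketch (the pass's runOnOperation below does essentially this):
//
//   RewritePatternSet patterns(ctx);
//   arith::populateArithToAMDGPUConversionPatterns(
//       patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
//       /*allowPackedF16Rtz=*/true, /*supportsScaledExtTrunc=*/false,
//       chipset);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();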
void mlir::arith::populateArithToAMDGPUConversionPatterns(
    RewritePatternSet &patterns, bool convertFP8Arithmetic,
    bool saturateFP8Truncf, bool allowPackedF16Rtz,
    bool supportsScaledExtTrunc, amdgpu::Chipset chipset,
    PatternBenefit benefit) {
  if (convertFP8Arithmetic) {
    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
                                             benefit);
    patterns.add<TruncFToFloat8RewritePattern>(
        patterns.getContext(), saturateFP8Truncf, chipset, benefit);
  }
  if (allowPackedF16Rtz)
    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
  if (supportsScaledExtTrunc) {
    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
  }
}
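// Pass driver: parse the requested chipset, derive the feature flags from
// it, and greedily apply the patterns registered above.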
void ArithToAMDGPUConversionPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }

  bool convertFP8Arithmetic =
      *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
  bool supportsScaledExtTrunc = *maybeChipset == kGfx950;
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
      supportsScaledExtTrunc, *maybeChipset);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    return signalPassFailure();
}